shithub: opus

Download patch

ref: 544b3e576c8edd1785914c988882b62d60652f26
parent: 98b8be09d56d03d220fff3536842c0703bae865c
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sun Nov 5 22:10:59 EST 2023

DRED: quantize r and p0 parameters with 8 bits

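Only entropy-code non-degenerate symbols (those with r != 0 and p0 != 255); degenerate symbols are forced to zero without calling the Laplace coder, which makes the encoder faster.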

--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@
 srcdir=`dirname $0`
 test -n "$srcdir" && cd "$srcdir"
 
-dnn/download_model.sh c99054d
+dnn/download_model.sh 98b8be0
 
 echo "Updating build configuration files, please wait...."
 
--- a/dnn/dred_rdovae.c
+++ b/dnn/dred_rdovae.c
@@ -79,9 +79,9 @@
 }
 
 
-const opus_uint16 * DRED_rdovae_get_p0_pointer(void)
+const opus_uint8 * DRED_rdovae_get_p0_pointer(void)
 {
-    return &dred_p0_q15[0];
+    return &dred_p0_q8[0];
 }
 
 const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void)
@@ -89,9 +89,9 @@
     return &dred_dead_zone_q10[0];
 }
 
-const opus_uint16 * DRED_rdovae_get_r_pointer(void)
+const opus_uint8 * DRED_rdovae_get_r_pointer(void)
 {
-    return &dred_r_q15[0];
+    return &dred_r_q8[0];
 }
 
 const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void)
--- a/dnn/dred_rdovae.h
+++ b/dnn/dred_rdovae.h
@@ -58,9 +58,9 @@
 
 void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float *qframe, const float * z);
 
-const opus_uint16 * DRED_rdovae_get_p0_pointer(void);
+const opus_uint8 * DRED_rdovae_get_p0_pointer(void);
 const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void);
-const opus_uint16 * DRED_rdovae_get_r_pointer(void);
+const opus_uint8 * DRED_rdovae_get_r_pointer(void);
 const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void);
 
 #endif
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -63,20 +63,20 @@
 
     quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
     dead_zone_q10   = np.round(dead_zone * 2**10).astype(np.uint16)
-    r_q15           = np.round(r * 2**15).astype(np.uint16)
-    p0_q15          = np.round(p0 * 2**15).astype(np.uint16)
+    r_q15           = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
+    p0_q15          = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)
 
     print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
     print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
-    print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False)
-    print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False)
+    print_vector(writer.source, r_q15, 'dred_r_q8', dtype='opus_uint8', static=False)
+    print_vector(writer.source, p0_q15, 'dred_p0_q8', dtype='opus_uint8', static=False)
 
     writer.header.write(
 f"""
 extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
 extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
-extern const opus_uint16 dred_r_q15[{levels * N}];
-extern const opus_uint16 dred_p0_q15[{levels * N}];
+extern const opus_uint8 dred_r_q8[{levels * N}];
+extern const opus_uint8 dred_p0_q8[{levels * N}];
 
 """
     )
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -222,6 +222,9 @@
 
     return clip_weights
 
+def n(x):
+    return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
+
 # RDOVAE module and submodules
 
 class MyConv(nn.Module):
@@ -295,17 +298,17 @@
         device = x.device
 
         # run encoding layer stack
-        x = torch.tanh(self.dense_1(x))
-        x = torch.cat([x, self.gru1(x)[0]], -1)
-        x = torch.cat([x, self.conv1(x)], -1)
-        x = torch.cat([x, self.gru2(x)[0]], -1)
-        x = torch.cat([x, self.conv2(x)], -1)
-        x = torch.cat([x, self.gru3(x)[0]], -1)
-        x = torch.cat([x, self.conv3(x)], -1)
-        x = torch.cat([x, self.gru4(x)[0]], -1)
-        x = torch.cat([x, self.conv4(x)], -1)
-        x = torch.cat([x, self.gru5(x)[0]], -1)
-        x = torch.cat([x, self.conv5(x)], -1)
+        x = n(torch.tanh(self.dense_1(x)))
+        x = torch.cat([x, n(self.gru1(x)[0])], -1)
+        x = torch.cat([x, n(self.conv1(x))], -1)
+        x = torch.cat([x, n(self.gru2(x)[0])], -1)
+        x = torch.cat([x, n(self.conv2(x))], -1)
+        x = torch.cat([x, n(self.gru3(x)[0])], -1)
+        x = torch.cat([x, n(self.conv3(x))], -1)
+        x = torch.cat([x, n(self.gru4(x)[0])], -1)
+        x = torch.cat([x, n(self.conv4(x))], -1)
+        x = torch.cat([x, n(self.gru5(x)[0])], -1)
+        x = torch.cat([x, n(self.conv5(x))], -1)
         z = self.z_dense(x)
 
         # init state for decoder
@@ -372,18 +375,18 @@
         h5_state = gru_state[:,:,384:].contiguous()
 
         # run decoding layer stack
-        x = torch.tanh(self.dense_1(z))
+        x = n(torch.tanh(self.dense_1(z)))
 
-        x = torch.cat([x, self.gru1(x, h1_state)[0]], -1)
-        x = torch.cat([x, self.conv1(x)], -1)
-        x = torch.cat([x, self.gru2(x, h2_state)[0]], -1)
-        x = torch.cat([x, self.conv2(x)], -1)
-        x = torch.cat([x, self.gru3(x, h3_state)[0]], -1)
-        x = torch.cat([x, self.conv3(x)], -1)
-        x = torch.cat([x, self.gru4(x, h4_state)[0]], -1)
-        x = torch.cat([x, self.conv4(x)], -1)
-        x = torch.cat([x, self.gru5(x, h5_state)[0]], -1)
-        x = torch.cat([x, self.conv5(x)], -1)
+        x = torch.cat([x, n(self.gru1(x, h1_state)[0])], -1)
+        x = torch.cat([x, n(self.conv1(x))], -1)
+        x = torch.cat([x, n(self.gru2(x, h2_state)[0])], -1)
+        x = torch.cat([x, n(self.conv2(x))], -1)
+        x = torch.cat([x, n(self.gru3(x, h3_state)[0])], -1)
+        x = torch.cat([x, n(self.conv3(x))], -1)
+        x = torch.cat([x, n(self.gru4(x, h4_state)[0])], -1)
+        x = torch.cat([x, n(self.conv4(x))], -1)
+        x = torch.cat([x, n(self.gru5(x, h5_state)[0])], -1)
+        x = torch.cat([x, n(self.conv5(x))], -1)
 
         # output layer and reshaping
         x10 = self.output(x)
@@ -451,7 +454,7 @@
                  cond_size2,
                  state_dim=24,
                  split_mode='split',
-                 clip_weights=True,
+                 clip_weights=False,
                  pvq_num_pulses=82,
                  state_dropout_rate=0):
 
@@ -487,7 +490,7 @@
         if not type(self.weight_clip_fn) == type(None):
             self.apply(self.weight_clip_fn)
 
-    def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 24):
+    def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
 
         enc_stride = self.enc_stride
         dec_stride = self.dec_stride
--- a/silk/dred_config.h
+++ b/silk/dred_config.h
@@ -32,7 +32,7 @@
 #define DRED_EXTENSION_ID 126
 
 /* Remove these two completely once DRED gets an extension number assigned. */
-#define DRED_EXPERIMENTAL_VERSION 6
+#define DRED_EXPERIMENTAL_VERSION 7
 #define DRED_EXPERIMENTAL_BYTES 2
 
 
--- a/silk/dred_decoder.c
+++ b/silk/dred_decoder.c
@@ -43,11 +43,12 @@
   return (x ^ m) - m;
 }
 
-static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint16 *r, const opus_uint16 *p0, int dim) {
+static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
     int i;
     for (i=0;i<dim;i++) {
         int q;
-        q = ec_laplace_decode_p0(dec, p0[i], r[i]);
+        if (r[i] == 0 || p0[i] == 255) q = 0;
+        else q = ec_laplace_decode_p0(dec, p0[i]<<7, r[i]<<7);
         x[i] = q*256.f/(scale[i] == 0 ? 1 : scale[i]);
     }
 }
@@ -54,9 +55,9 @@
 
 int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames)
 {
-  const opus_uint16 *p0              = DRED_rdovae_get_p0_pointer();
+  const opus_uint8 *p0              = DRED_rdovae_get_p0_pointer();
   const opus_uint16 *quant_scales    = DRED_rdovae_get_quant_scales_pointer();
-  const opus_uint16 *r               = DRED_rdovae_get_r_pointer();
+  const opus_uint8 *r               = DRED_rdovae_get_r_pointer();
   ec_dec ec;
   int q_level;
   int i;
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -217,7 +217,7 @@
     }
 }
 
-static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint16 *r, const opus_uint16 *p0, int dim) {
+static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
     int i;
     int q[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
     float xq[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
@@ -238,16 +238,16 @@
     }
     for (i=0;i<dim;i++) {
         /* Make the impossible actually impossible. */
-        if (r[i] == 0 || p0[i] >= 32767) q[i] = 0;
-        ec_laplace_encode_p0(enc, q[i], p0[i], r[i]);
+        if (r[i] == 0 || p0[i] == 255) q[i] = 0;
+        else ec_laplace_encode_p0(enc, q[i], p0[i]<<7, r[i]<<7);
     }
 }
 
 int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunks, int max_bytes) {
     const opus_uint16 *dead_zone       = DRED_rdovae_get_dead_zone_pointer();
-    const opus_uint16 *p0              = DRED_rdovae_get_p0_pointer();
+    const opus_uint8 *p0              = DRED_rdovae_get_p0_pointer();
     const opus_uint16 *quant_scales    = DRED_rdovae_get_quant_scales_pointer();
-    const opus_uint16 *r               = DRED_rdovae_get_r_pointer();
+    const opus_uint8 *r               = DRED_rdovae_get_r_pointer();
     ec_enc ec_encoder;
 
     int q_level;
--