shithub: opus

Download patch

ref: b88644b9c7547f09c7e313c2ae5ad3085fba2d4e
parent: 2ec31cc5ccb92e1b472cf3da1d26490412805a79
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Fri Sep 15 13:27:44 EDT 2023

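Quantize the initial state with the RDO-VAE too

The initial state is now coded with dred_encode_latents()/dred_decode_latents()
using its own RDO-VAE quantizer statistics, which is more efficient than the
previous PVQ coding. The state dimension grows from 24 to 80, and
DRED_EXPERIMENTAL_VERSION is bumped to 2 because the bitstream changes.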

--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -372,7 +372,7 @@
 
 
 class StatisticalModel(nn.Module):
-    def __init__(self, quant_levels, latent_dim):
+    def __init__(self, quant_levels, latent_dim, state_dim):
         """ Statistical model for latent space
 
             Computes scaling, deadzone, r, and theta
@@ -383,8 +383,10 @@
 
         # copy parameters
         self.latent_dim     = latent_dim
+        self.state_dim      = state_dim
+        self.total_dim      = latent_dim + state_dim
         self.quant_levels   = quant_levels
-        self.embedding_dim  = 6 * latent_dim
+        self.embedding_dim  = 6 * self.total_dim
 
         # quantization embedding
         self.quant_embedding    = nn.Embedding(quant_levels, self.embedding_dim)
@@ -400,12 +402,12 @@
         x = self.quant_embedding(quant_ids)
 
         # CAVE: theta_soft is not used anymore. Kick it out?
-        quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim])
-        dead_zone   = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim])
-        theta_soft  = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim])
-        r_soft      = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim])
-        theta_hard  = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim])
-        r_hard      = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim])
+        quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim])
+        dead_zone   = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim])
+        theta_soft  = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim])
+        r_soft      = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim])
+        theta_hard  = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim])
+        r_hard      = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim])
 
 
         return {
@@ -445,7 +447,7 @@
         self.state_dropout_rate = state_dropout_rate
 
         # submodules encoder and decoder share the statistical model
-        self.statistical_model = StatisticalModel(quant_levels, latent_dim)
+        self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
         self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
         self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
 
@@ -522,14 +524,19 @@
         z, states = self.core_encoder(features)
 
         # scaling, dead-zone and quantization
-        z = z * statistical_model['quant_scale']
-        z = soft_dead_zone(z, statistical_model['dead_zone'])
+        z = z * statistical_model['quant_scale'][:,:,:self.latent_dim]
+        z = soft_dead_zone(z, statistical_model['dead_zone'][:,:,:self.latent_dim])
 
         # quantization
-        z_q = hard_quantize(z) / statistical_model['quant_scale']
-        z_n = noise_quantize(z) / statistical_model['quant_scale']
-        states_q = soft_pvq(states, self.pvq_num_pulses)
+        z_q = hard_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+        z_n = noise_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+        #states_q = soft_pvq(states, self.pvq_num_pulses)
+        states = states * statistical_model['quant_scale'][:,:,self.latent_dim:]
+        states = soft_dead_zone(states, statistical_model['dead_zone'][:,:,self.latent_dim:])
 
+        states_q = hard_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
+        states_n = noise_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
+
         if self.state_dropout_rate > 0:
             drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
             mask = torch.ones_like(states_q)
@@ -551,6 +558,7 @@
 
             # decoder with soft quantized input
             z_dec_reverse       = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :],  [1])
+            dec_initial_state   = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
             features_reverse    = self.core_decoder(z_dec_reverse, dec_initial_state)
             outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
 
@@ -558,6 +566,7 @@
             'outputs_hard_quant' : outputs_hq,
             'outputs_soft_quant' : outputs_sq,
             'z'                 : z,
+            'states'            : states,
             'statistical_model' : statistical_model
         }
 
@@ -586,11 +595,11 @@
 
         stats = self.statistical_model(q_ids)
 
-        zq = z * stats['quant_scale']
-        zq = soft_dead_zone(zq, stats['dead_zone'])
+        zq = z * stats['quant_scale'][:self.latent_dim]
+        zq = soft_dead_zone(zq, stats['dead_zone'][:self.latent_dim])
         zq = torch.round(zq)
 
-        sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False)
+        sizes = hard_rate_estimate(zq, stats['r_hard'][:,:,:self.latent_dim], stats['theta_hard'][:,:,:self.latent_dim], reduce=False)
 
         return zq, sizes
 
@@ -599,7 +608,7 @@
 
         stats = self.statistical_model(q_ids)
 
-        z = zq / stats['quant_scale']
+        z = zq / stats['quant_scale'][:,:,:self.latent_dim]
 
         return z
 
--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -172,6 +172,7 @@
         running_soft_rate_loss  = 0
         running_total_loss      = 0
         running_rate_metric     = 0
+        running_states_rate_metric     = 0
         previous_total_loss     = 0
         running_first_frame_loss = 0
 
@@ -194,17 +195,21 @@
 
                 # collect outputs
                 z                   = model_output['z']
+                states              = model_output['states']
                 outputs_hard_quant  = model_output['outputs_hard_quant']
                 outputs_soft_quant  = model_output['outputs_soft_quant']
                 statistical_model   = model_output['statistical_model']
 
                 # rate loss
-                hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False)
-                soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False)
-                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate)
-                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate)
+                hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
+                soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
+                states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
+                states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
+                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate))
+                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate))
                 rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
                 hard_rate_metric = torch.mean(hard_rate)
+                states_rate_metric = torch.mean(states_hard_rate)
 
                 ## distortion losses
 
@@ -242,6 +247,7 @@
                 running_soft_dist_loss  += float(distortion_loss_soft_quant.detach().cpu())
                 running_rate_loss       += float(rate_loss.detach().cpu())
                 running_rate_metric     += float(hard_rate_metric.detach().cpu())
+                running_states_rate_metric     += float(states_rate_metric.detach().cpu())
                 running_total_loss      += float(total_loss.detach().cpu())
                 running_first_frame_loss += float(first_frame_loss.detach().cpu())
                 running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
@@ -256,6 +262,7 @@
                         dist_sq=running_soft_dist_loss / (i + 1),
                         rate_loss=running_rate_loss / (i + 1),
                         rate=running_rate_metric / (i + 1),
+                        states_rate=running_states_rate_metric / (i + 1),
                         ffloss=running_first_frame_loss / (i + 1),
                         rateloss_hard=running_hard_rate_loss / (i + 1),
                         rateloss_soft=running_soft_rate_loss / (i + 1)
--- a/silk/dred_coding.c
+++ b/silk/dred_coding.c
@@ -33,8 +33,6 @@
 #include <stdio.h>
 
 #include "celt/entenc.h"
-#include "celt/vq.h"
-#include "celt/cwrs.h"
 #include "celt/laplace.h"
 #include "os_support.h"
 #include "dred_config.h"
@@ -41,8 +39,7 @@
 #include "dred_coding.h"
 
 #define LATENT_DIM 80
-#define PVQ_DIM 24
-#define PVQ_K 82
+#define STATE_DIM 80
 
 int compute_quantizer(int q0, int dQ, int i) {
   int quant;
@@ -53,37 +50,6 @@
   return (int) floor(0.5f + DRED_ENC_Q0 + 1.f * (DRED_ENC_Q1 - DRED_ENC_Q0) * i / (DRED_NUM_REDUNDANCY_FRAMES - 2));
 }
 
-static void encode_pvq(const int *iy, int N, int K, ec_enc *enc) {
-    int fits;
-    celt_assert(N==24 || N==12 || N==6);
-    fits = (N==24 && K<=9) || (N==12 && K<=16) || (N==6);
-    /*printf("encode(%d,%d), fits=%d\n", N, K, fits);*/
-    if (fits) {
-      if (K > 0)
-        encode_pulses(iy, N, K, enc);
-    }
-    else {
-        int N2 = N/2;
-        int K0=0;
-        int i;
-        for (i=0;i<N2;i++) K0 += abs(iy[i]);
-        /* FIXME: Don't use uniform probability for K0. */
-        ec_enc_uint(enc, K0, K+1);
-        /*printf("K0 = %d\n", K0);*/
-        encode_pvq(iy, N2, K0, enc);
-        encode_pvq(&iy[N2], N2, K-K0, enc);
-    }
-}
-
-void dred_encode_state(ec_enc *enc, const float *x) {
-    int iy[PVQ_DIM];
-    float x0[PVQ_DIM];
-    /* Copy state because the PVQ search will trash it. */
-    OPUS_COPY(x0, x, PVQ_DIM);
-    op_pvq_search_c(x0, iy, PVQ_K, PVQ_DIM, 0);
-    encode_pvq(iy, PVQ_DIM, PVQ_K, enc);
-}
-
 void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint16 *r, const opus_uint16 *p0) {
     int i;
     float eps = .1f;
@@ -98,47 +64,6 @@
         /* Make the impossible actually impossible. */
         if (r[i] == 0 || p0[i] >= 32768) q = 0;
         ec_laplace_encode_p0(enc, q, p0[i], r[i]);
-    }
-}
-
-
-
-static void decode_pvq(int *iy, int N, int K, ec_dec *dec) {
-    int fits;
-    celt_assert(N==24 || N==12 || N==6);
-    fits = (N==24 && K<=9) || (N==12 && K<=16) || (N==6);
-    /*printf("encode(%d,%d), fits=%d\n", N, K, fits);*/
-    if (fits) {
-      if (K > 0)
-        decode_pulses(iy, N, K, dec);
-      else
-        OPUS_CLEAR(iy, N);
-    }
-    else {
-        int N2 = N/2;
-        int K0;
-        /* FIXME: Don't use uniform probability for K0. */
-        K0 = ec_dec_uint(dec, K+1);
-        /*printf("K0 = %d\n", K0);*/
-        decode_pvq(iy, N2, K0, dec);
-        decode_pvq(&iy[N2], N2, K-K0, dec);
-    }
-}
-
-void dred_decode_state(ec_enc *dec, float *x) {
-    int k;
-    int iy[PVQ_DIM];
-    float norm = 0;
-    decode_pvq(iy, PVQ_DIM, PVQ_K, dec);
-    /*printf("tell: %d\n", ec_tell(dec)-tell1);*/
-    for (k = 0; k < PVQ_DIM; k++)
-    {
-        norm += (float) iy[k] * iy[k];
-    }
-    norm = 1.f / sqrt(norm);
-    for (k = 0; k < PVQ_DIM; k++)
-    {
-        x[k] = iy[k] * norm;
     }
 }
 
--- a/silk/dred_config.h
+++ b/silk/dred_config.h
@@ -32,7 +32,7 @@
 #define DRED_EXTENSION_ID 126
 
 /* Remove these two completely once DRED gets an extension number assigned. */
-#define DRED_EXPERIMENTAL_VERSION 1
+#define DRED_EXPERIMENTAL_VERSION 2
 #define DRED_EXPERIMENTAL_BYTES 2
 
 
@@ -41,7 +41,7 @@
 /* these are inpart duplicates to the values defined in dred_rdovae_constants.h */
 #define DRED_NUM_FEATURES 20
 #define DRED_LATENT_DIM 80
-#define DRED_STATE_DIM 24
+#define DRED_STATE_DIM 80
 #define DRED_SILK_ENCODER_DELAY (79+12-80)
 #define DRED_FRAME_SIZE 160
 #define DRED_DFRAME_SIZE (2 * (DRED_FRAME_SIZE))
--- a/silk/dred_decoder.c
+++ b/silk/dred_decoder.c
@@ -54,6 +54,7 @@
   int offset;
   int q0;
   int dQ;
+  int state_qoffset;
 
 
   /* since features are decoded in quadruples, it makes no sense to go with an uneven number of redundancy frames */
@@ -66,7 +67,14 @@
   dQ = ec_dec_uint(&ec, 8);
   /*printf("%d %d %d\n", dred_offset, q0, dQ);*/
 
-  dred_decode_state(&ec, dec->state);
+  //dred_decode_state(&ec, dec->state);
+  state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM;
+  dred_decode_latents(
+      &ec,
+      dec->state,
+      quant_scales + state_qoffset,
+      r + state_qoffset,
+      p0 + state_qoffset);
 
   /* decode newest to oldest and store oldest to newest */
   for (i = 0; i < IMIN(DRED_NUM_REDUNDANCY_FRAMES, (min_feature_frames+1)/2); i += 2)
@@ -75,7 +83,7 @@
       if (8*num_bytes - ec_tell(&ec) <= 7)
          break;
       q_level = compute_quantizer(q0, dQ, i/2);
-      offset = q_level * DRED_LATENT_DIM;
+      offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
       dred_decode_latents(
           &ec,
           &dec->latents[(i/2)*DRED_LATENT_DIM],
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -197,7 +197,7 @@
             /* 15 ms (6*2.5 ms) is the ideal offset for DRED because it corresponds to our vocoder look-ahead. */
             if (enc->dred_offset < 6) {
                 enc->dred_offset += 8;
-                OPUS_COPY(enc->initial_state, enc->state_buffer, 24);
+                OPUS_COPY(enc->initial_state, enc->state_buffer, DRED_STATE_DIM);
             } else {
                 enc->latent_offset++;
             }
@@ -221,6 +221,7 @@
     int ec_buffer_fill;
     int q0;
     int dQ;
+    int state_qoffset;
 
     /* entropy coding of state and latents */
     ec_enc_init(&ec_encoder, buf, max_bytes);
@@ -229,8 +230,14 @@
     ec_enc_uint(&ec_encoder, enc->dred_offset, 32);
     ec_enc_uint(&ec_encoder, q0, 16);
     ec_enc_uint(&ec_encoder, dQ, 8);
-    dred_encode_state(&ec_encoder, enc->initial_state);
-
+    state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM;
+    dred_encode_latents(
+        &ec_encoder,
+        enc->initial_state,
+        quant_scales + state_qoffset,
+        dead_zone + state_qoffset,
+        r + state_qoffset,
+        p0 + state_qoffset);
     for (i = 0; i < IMIN(2*max_chunks, enc->latents_buffer_fill-enc->latent_offset-1); i += 2)
     {
         ec_enc ec_bak;
@@ -237,7 +244,7 @@
         ec_bak = ec_encoder;
 
         q_level = compute_quantizer(q0, dQ, i/2);
-        offset = q_level * DRED_LATENT_DIM;
+        offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
 
         dred_encode_latents(
             &ec_encoder,
--- a/silk/dred_encoder.h
+++ b/silk/dred_encoder.h
@@ -50,8 +50,8 @@
     int latent_offset;
     float latents_buffer[DRED_MAX_FRAMES * DRED_LATENT_DIM];
     int latents_buffer_fill;
-    float state_buffer[24];
-    float initial_state[24];
+    float state_buffer[DRED_STATE_DIM];
+    float initial_state[DRED_STATE_DIM];
     float resample_mem[RESAMPLING_ORDER + 1];
     LPCNetEncState lpcnet_enc_state;
     RDOVAEEncState rdovae_enc;
--