shithub: opus

Download patch

ref: 7f7b2a1c662580e214e5fba20eef40816563bfbd
parent: 19a5d6ec03d10920380d5385b7b898e287079a68
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Nov 15 07:58:52 EST 2023

Smaller version of fargan

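About 800k parameters and 600 MFLOPS. The second conditioning convolution (fconv2) is replaced by a dense layer (fdense2), so the conditioning network now has a receptive field of 3 feature vectors (from the single kernel-3 fconv1).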

--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@
 srcdir=`dirname $0`
 test -n "$srcdir" && cd "$srcdir"
 
-dnn/download_model.sh 58923f6
+dnn/download_model.sh 19a5d6e
 
 echo "Updating build configuration files, please wait...."
 
--- a/dnn/fargan.c
+++ b/dnn/fargan.c
@@ -45,17 +45,17 @@
   FARGAN *model;
   float dense_in[NB_FEATURES+COND_NET_PEMBED_OUT_SIZE];
   float conv1_in[COND_NET_FCONV1_IN_SIZE];
-  float conv2_in[COND_NET_FCONV2_IN_SIZE];
+  float fdense2_in[COND_NET_FCONV1_OUT_SIZE];
   model = &st->model;
   celt_assert(FARGAN_FEATURES+COND_NET_PEMBED_OUT_SIZE == model->cond_net_fdense1.nb_inputs);
   celt_assert(COND_NET_FCONV1_IN_SIZE == model->cond_net_fdense1.nb_outputs);
-  celt_assert(COND_NET_FCONV2_IN_SIZE == model->cond_net_fconv1.nb_outputs);
+  celt_assert(COND_NET_FCONV1_OUT_SIZE == model->cond_net_fconv1.nb_outputs);
   OPUS_COPY(&dense_in[NB_FEATURES], &model->cond_net_pembed.float_weights[IMAX(0,IMIN(period-32, 224))*COND_NET_PEMBED_OUT_SIZE], COND_NET_PEMBED_OUT_SIZE);
   OPUS_COPY(dense_in, features, NB_FEATURES);
 
   compute_generic_dense(&model->cond_net_fdense1, conv1_in, dense_in, ACTIVATION_TANH, st->arch);
-  compute_generic_conv1d(&model->cond_net_fconv1, conv2_in, st->cond_conv1_state, conv1_in, COND_NET_FCONV1_IN_SIZE, ACTIVATION_TANH, st->arch);
-  compute_generic_conv1d(&model->cond_net_fconv2, cond, st->cond_conv2_state, conv2_in, COND_NET_FCONV2_IN_SIZE, ACTIVATION_TANH, st->arch);
+  compute_generic_conv1d(&model->cond_net_fconv1, fdense2_in, st->cond_conv1_state, conv1_in, COND_NET_FCONV1_IN_SIZE, ACTIVATION_TANH, st->arch);
+  compute_generic_dense(&model->cond_net_fdense2, cond, fdense2_in, ACTIVATION_TANH, st->arch);
 }
 
 static void fargan_deemphasis(float *pcm, float *deemph_mem) {
@@ -142,7 +142,7 @@
 void fargan_cont(FARGANState *st, const float *pcm0, const float *features0)
 {
   int i;
-  float cond[COND_NET_FCONV2_OUT_SIZE];
+  float cond[COND_NET_FDENSE2_OUT_SIZE];
   float x0[FARGAN_CONT_SAMPLES];
   float dummy[FARGAN_SUBFRAME_SIZE];
   int period=0;
@@ -197,7 +197,7 @@
 static void fargan_synthesize_impl(FARGANState *st, float *pcm, const float *features)
 {
   int subframe;
-  float cond[COND_NET_FCONV2_OUT_SIZE];
+  float cond[COND_NET_FDENSE2_OUT_SIZE];
   int period;
   celt_assert(st->cont_initialized);
 
--- a/dnn/fargan.h
+++ b/dnn/fargan.h
@@ -35,7 +35,7 @@
 #define FARGAN_NB_SUBFRAMES 4
 #define FARGAN_SUBFRAME_SIZE 40
 #define FARGAN_FRAME_SIZE (FARGAN_NB_SUBFRAMES*FARGAN_SUBFRAME_SIZE)
-#define FARGAN_COND_SIZE (COND_NET_FCONV2_OUT_SIZE/FARGAN_NB_SUBFRAMES)
+#define FARGAN_COND_SIZE (COND_NET_FDENSE2_OUT_SIZE/FARGAN_NB_SUBFRAMES)
 #define FARGAN_DEEMPHASIS 0.85f
 
 #define SIG_NET_INPUT_SIZE (FARGAN_COND_SIZE+2*FARGAN_SUBFRAME_SIZE+4)
@@ -49,7 +49,6 @@
   float deemph_mem;
   float pitch_buf[PITCH_MAX_PERIOD];
   float cond_conv1_state[COND_NET_FCONV1_STATE_SIZE];
-  float cond_conv2_state[COND_NET_FCONV2_STATE_SIZE];
   float fwc0_mem[SIG_NET_FWC0_STATE_SIZE];
   float gru1_state[SIG_NET_GRU1_STATE_SIZE];
   float gru2_state[SIG_NET_GRU2_STATE_SIZE];
--- a/dnn/torch/fargan/adv_train_fargan.py
+++ b/dnn/torch/fargan/adv_train_fargan.py
@@ -160,6 +160,10 @@
                 if epoch == 1 and i == 400:
                     for param in model.parameters():
                         param.requires_grad = True
+                    for param in model.cond_net.parameters():
+                        param.requires_grad = False
+                    for param in model.sig_net.cond_gain_dense.parameters():
+                        param.requires_grad = False
 
                 optimizer.zero_grad()
                 features = features.to(device)
@@ -226,7 +230,7 @@
 
                 feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen)
 
-                reg_weight = args.reg_weight + 15./(1 + (batch_count/7600.))
+                reg_weight = args.reg_weight# + 15./(1 + (batch_count/7600.))
                 gen_loss = reg_weight * reg_loss +  feat_loss + loss_gen
 
                 model.zero_grad()
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -125,7 +125,7 @@
         return out
 
 class FWConv(nn.Module):
-    def __init__(self, in_size, out_size, kernel_size=3):
+    def __init__(self, in_size, out_size, kernel_size=2):
         super(FWConv, self).__init__()
 
         torch.manual_seed(5)
@@ -163,7 +163,7 @@
         self.pembed = nn.Embedding(224, pembed_dims)
         self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
         self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
-        self.fconv2 = nn.Conv1d(128, 80*4, kernel_size=3, padding='valid', bias=False)
+        self.fdense2 = nn.Linear(128, 80*4, bias=False)
 
         self.apply(init_weights)
         nb_params = sum(p.numel() for p in self.parameters())
@@ -170,13 +170,15 @@
         print(f"cond model: {nb_params} weights")
 
     def forward(self, features, period):
+        features = features[:,2:,:]
+        period = period[:,2:]
         p = self.pembed(period-32)
         features = torch.cat((features, p), -1)
         tmp = torch.tanh(self.fdense1(features))
         tmp = tmp.permute(0, 2, 1)
         tmp = torch.tanh(self.fconv1(tmp))
-        tmp = torch.tanh(self.fconv2(tmp))
         tmp = tmp.permute(0, 2, 1)
+        tmp = torch.tanh(self.fdense2(tmp))
         #tmp = torch.tanh(self.fdense2(tmp))
         return tmp
 
@@ -190,21 +192,20 @@
         self.cond_gain_dense = nn.Linear(80, 1)
 
         #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
-        self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size)
-        self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
-        self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, 128, bias=False)
+        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192)
+        self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
+        self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
         self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)
 
-        self.dense1_glu = GLU(self.cond_size)
-        self.gru1_glu = GLU(self.cond_size)
+        self.gru1_glu = GLU(160)
         self.gru2_glu = GLU(128)
         self.gru3_glu = GLU(128)
-        self.skip_glu = GLU(self.cond_size)
+        self.skip_glu = GLU(128)
         #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
 
-        self.skip_dense = nn.Linear(2*128+2*self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
-        self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False)
-        self.gain_dense_out = nn.Linear(self.cond_size, 4)
+        self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
+        self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
+        self.gain_dense_out = nn.Linear(192, 4)
 
 
         self.apply(init_weights)
@@ -291,10 +292,10 @@
         nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
 
         states = (
-            torch.zeros(batch_size, self.cond_size, device=device),
+            torch.zeros(batch_size, 160, device=device),
             torch.zeros(batch_size, 128, device=device),
             torch.zeros(batch_size, 128, device=device),
-            torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device)
+            torch.zeros(batch_size, (2*self.subframe_size+80+4)*1, device=device)
         )
 
         sig = torch.zeros((batch_size, 0), device=device)
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -52,7 +52,7 @@
 sequence_length = args.sequence_length
 lr_decay = args.lr_decay
 
-adam_betas = [0.8, 0.99]
+adam_betas = [0.8, 0.95]
 adam_eps = 1e-8
 features_file = args.features
 signal_file = args.signal
--