shithub: opus

Download patch

ref: 4f63743f8f814d608df4b249ac796f8edd15ade0
parent: 1b13f6313e8413056f6d9f1f15fa994d0dff7a57
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Thu Aug 31 12:32:15 EDT 2023

explicit signal gain, explicit pitch predictor

--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -9,7 +9,7 @@
 
 fid_dict = {}
 def dump_signal(x, filename):
-    return
+    #return
     if filename in fid_dict:
         fid = fid_dict[filename]
     else:
@@ -162,7 +162,7 @@
 
         self.apply(init_weights)
 
-    def forward(self, cond, prev, exc_mem, phase, period, states):
+    def forward(self, cond, prev, exc_mem, phase, period, states, gain=None):
         device = exc_mem.device
         #print(cond.shape, prev.shape)
         
@@ -176,7 +176,7 @@
         dump_signal(prev, 'pitch_exc.f32')
         dump_signal(exc_mem, 'exc_mem.f32')
         if self.has_gain:
-            gain = torch.norm(prev, dim=1, p=2, keepdim=True)
+            #gain = torch.norm(prev, dim=1, p=2, keepdim=True)
             prev = prev/(1e-5+gain)
             prev = torch.cat([prev, torch.log(1e-5+gain)], 1)
 
@@ -193,10 +193,10 @@
         if self.passthrough_size != 0:
             passthrough = sig_out[:,self.subframe_size:]
             sig_out = sig_out[:,:self.subframe_size]
-        if self.has_gain:
-            out_gain = torch.exp(self.gain_dense_out(gru3_out))
-            sig_out = sig_out * out_gain
         dump_signal(sig_out, 'exc_out.f32')
+        if self.has_gain:
+            pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
+            sig_out = (sig_out + pitch_gain*prev[:,:-1]) * gain
         exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
         dump_signal(sig_out, 'sig_out.f32')
         return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
@@ -246,7 +246,9 @@
                 phase = torch.cat([preal, pimag], 1)
                 #print("now: ", preal.shape, prev.shape, sig_in.shape)
                 pitch = period[:, 3+n]
-                out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states)
+                gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
+                #gain = gain[:,:,None]
+                out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states, gain=gain)
 
                 if n < nb_pre_frames:
                     out = pre[:, pos:pos+self.subframe_size]
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -127,9 +127,9 @@
                 sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                 sig = torch.cat([pre, sig], -1)
 
-                cont_loss = fargan.sig_l1(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
+                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
                 specc_loss = spect_loss(sig, target.detach())
-                loss = .2*cont_loss + specc_loss
+                loss = .00*cont_loss + specc_loss
 
                 loss.backward()
                 optimizer.step()
--