shithub: opus

ref: 2f8b36d691a3802714a54abd7409234e41ec3e21
parent: 72c5ea4129dc6473fb1d82ef3ec3d5714fab8cbd
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Mon Sep 11 19:28:52 EDT 2023

Add conditioning interpolation, fwconv layer
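
The commit pairs two changes: FARGANCond now produces per-subframe conditioning vectors rather than one cond_size vector per frame (the "conditioning interpolation"), and the signal network's first dense layer is replaced by FWConv, which is in effect a kernel-size-3 convolution over frames evaluated one step at a time with explicit state, so it fits inside the autoregressive synthesis loop. A minimal standalone sketch of that second idea (StepConv is a hypothetical name; the patch's FWConv additionally uses weight_norm, a GLU gate, a fixed seed, and orthogonal init):

    import torch
    import torch.nn as nn

    class StepConv(nn.Module):
        """Kernel-size-k convolution over frames, evaluated one frame at a time.

        The (k-1) most recent input frames are carried in an explicit state
        tensor instead of relying on a Conv1d over a whole sequence.
        """
        def __init__(self, in_size, out_size, kernel_size=3):
            super().__init__()
            self.in_size = in_size
            self.lin = nn.Linear(in_size * kernel_size, out_size, bias=False)

        def forward(self, x, state):
            xcat = torch.cat((state, x), -1)    # [x_{t-2}; x_{t-1}; x_t]
            out = torch.tanh(self.lin(xcat))
            return out, xcat[:, self.in_size:]  # drop the oldest frame

    conv = StepConv(240, 256)
    state = torch.zeros(1, 2 * 240)             # (kernel_size-1) * in_size
    y, state = conv(torch.randn(1, 240), state)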

--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -101,7 +101,32 @@
         
         return out
 
+class FWConv(nn.Module):
+    def __init__(self, in_size, out_size, kernel_size=3):
+        super(FWConv, self).__init__()
 
+        torch.manual_seed(5)
+
+        self.in_size = in_size
+        self.kernel_size = kernel_size
+        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
+        self.glu = GLU(out_size)
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x, state):
+        xcat = torch.cat((state, x), -1)
+        #print(x.shape, state.shape, xcat.shape, self.in_size, self.kernel_size)
+        out = self.glu(torch.tanh(self.conv(xcat)))
+        return out, xcat[:,self.in_size:]
+
 class FARGANCond(nn.Module):
     def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
         super(FARGANCond, self).__init__()
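
Note the state convention in FWConv above: forward() returns xcat[:, self.in_size:], i.e. the flattened (kernel_size-1) most recent input frames, so the caller carries a zero buffer of width (kernel_size-1)*in_size. That is exactly what the (4*self.subframe_size+80)*2 initial state allocated further down amounts to. A quick shape check:

    import torch

    in_size, kernel_size = 4 * 40 + 80, 3        # fwc0 sizes in this patch
    state = torch.zeros(1, (kernel_size - 1) * in_size)
    x = torch.randn(1, in_size)
    xcat = torch.cat((state, x), -1)             # kernel_size * in_size wide
    next_state = xcat[:, in_size:]               # keep the newest two frames
    assert next_state.shape[-1] == (4 * 40 + 80) * 2   # 480, matches the init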
@@ -113,7 +138,7 @@
         self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False)
         self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
         self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
-        self.fdense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
+        self.fdense2 = nn.Linear(self.cond_size, 80*4, bias=False)
 
         self.apply(init_weights)
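
This appears to be the "conditioning interpolation" of the commit title: instead of one cond_size vector per frame, fdense2 now emits 80*4 = 320 values, four 80-dimensional conditioning vectors per frame, one per subframe, and FARGAN.forward below picks out the k-th slice for each subframe. A sketch of the layout (batch and frame counts are illustrative):

    import torch

    batch, frames = 8, 12
    cond = torch.randn(batch, frames, 80 * 4)    # FARGANCond output, post-patch
    n, k = 0, 2                                  # frame index, subframe index
    sub = cond[:, n, k * 80 : (k + 1) * 80]      # the slice used in the forward loop
    assert torch.equal(sub, cond.view(batch, frames, 4, 80)[:, n, k])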
 
@@ -138,9 +163,10 @@
         self.has_gain = has_gain
         self.passthrough_size = passthrough_size
         
-        print("has_gain:", self.has_gain)
-        print("passthrough_size:", self.passthrough_size)
-        self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
+        #print("has_gain:", self.has_gain)
+        #print("passthrough_size:", self.passthrough_size)
+        #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
+        self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size)
         self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
         self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
         self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
@@ -176,30 +202,26 @@
         dump_signal(prev, 'pitch_exc.f32')
         dump_signal(exc_mem, 'exc_mem.f32')
 
-        passthrough = states[3]
-        tmp = torch.cat((cond, pred[:,2:-2], prev, passthrough, phase), 1)
+        tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1)
 
-        tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
-        dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(tmp)))
+        #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
+        fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
+        dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out)))
         gru1_state = self.gru1(dense2_out, states[0])
         gru1_out = self.gru1_glu(gru1_state)
-        #gru1_out = torch.cat([gru1_out, fpitch], 1)
         gru2_state = self.gru2(gru1_out, states[1])
         gru2_out = self.gru2_glu(gru2_state)
-        #gru2_out = torch.cat([gru2_out, fpitch], 1)
         gru3_state = self.gru3(gru2_out, states[2])
         gru3_out = self.gru3_glu(gru3_state)
         gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
         sig_out = torch.tanh(self.sig_dense_out(gru3_out))
-        if self.passthrough_size != 0:
-            passthrough = sig_out[:,self.subframe_size:]
-            sig_out = sig_out[:,:self.subframe_size]
         dump_signal(sig_out, 'exc_out.f32')
         taps = self.ptaps_dense(gru3_out)
         taps = .2*taps + torch.exp(taps)
         taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
         dump_signal(taps, 'taps.f32')
-        fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
+        #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
+        fpitch = pred[:,2:-2]
 
         if self.has_gain:
             pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
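
Here the five-tap fractional-pitch filter is disabled: fpitch becomes just the center window pred[:, 2:-2], even though the tap weights are still computed and normalized above it. For reference, a sketch of what the commented-out line computed, using unfold to materialize the five one-sample-shifted windows (softmax stands in for the patch's positive normalization):

    import torch

    B, S = 1, 40                                 # batch, subframe_size
    pred = torch.randn(B, S + 4)                 # subframe plus 2 samples of context per side
    taps = torch.softmax(torch.randn(B, 5), dim=-1)
    windows = pred.unfold(1, S, 1)               # (B, 5, S): the five shifted windows
    fpitch_5tap = (taps[:, :, None] * windows).sum(1)
    fpitch_new = pred[:, 2:-2]                   # what the patch uses instead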
@@ -207,7 +229,7 @@
             sig_out = (sig_out + pitch_gain*fpitch) * gain
         exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
         dump_signal(sig_out, 'sig_out.f32')
-        return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
+        return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state)
 
 class FARGAN(nn.Module):
     def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
@@ -239,7 +261,7 @@
             torch.zeros(batch_size, self.cond_size, device=device),
             torch.zeros(batch_size, self.cond_size, device=device),
             torch.zeros(batch_size, self.cond_size, device=device),
-            torch.zeros(batch_size, self.passthrough_size, device=device)
+            torch.zeros(batch_size, (4*self.subframe_size+80)*2, device=device)
         )
 
         sig = torch.zeros((batch_size, 0), device=device)
@@ -259,7 +281,7 @@
                 pitch = period[:, 3+n]
                 gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
                 #gain = gain[:,:,None]
-                out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states, gain=gain)
+                out, exc_mem, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, phase, pitch, states, gain=gain)
 
                 if n < nb_pre_frames:
                     out = pre[:, pos:pos+self.subframe_size]
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -121,6 +121,8 @@
                 if (np.random.rand() > 0.1):
                     target = target[:, :sequence_length*160]
                     lpc = lpc[:,:sequence_length,:]
+                    features = features[:,:sequence_length+4,:]
+                    periods = periods[:,:sequence_length+4]
                 else:
                     target=target[::2, :]
                     lpc=lpc[::2,:]
--
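
The train_fargan.py hunk mirrors the conditioning network's geometry: FARGANCond runs two kernel-size-3 'valid' convolutions, which together consume four frames of context, so slicing features and periods to sequence_length+4 frames yields exactly sequence_length conditioning frames. A quick check of that arithmetic (channel count and length are illustrative):

    import torch
    import torch.nn as nn

    T = 16                                       # frames in a training sequence
    x = torch.randn(1, 256, T + 4)               # the +4 frames sliced in this hunk
    conv1 = nn.Conv1d(256, 256, kernel_size=3, padding='valid')
    conv2 = nn.Conv1d(256, 256, kernel_size=3, padding='valid')
    assert conv2(conv1(x)).shape[-1] == T        # each 'valid' conv drops 2 frames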