ref: 72c5ea4129dc6473fb1d82ef3ec3d5714fab8cbd
parent: 108b75c4b189ea2eedd8e9cb0a2b56a9b4424466
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Sep 6 13:15:24 EDT 2023
Only use one frame of pre-loading
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -235,18 +235,21 @@
exc_mem = torch.zeros(batch_size, 256, device=device)
nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
- if states is None:
- states = (
- torch.zeros(batch_size, self.cond_size, device=device),
- torch.zeros(batch_size, self.cond_size, device=device),
- torch.zeros(batch_size, self.cond_size, device=device),
- torch.zeros(batch_size, self.passthrough_size, device=device)
- )
+ states = (
+ torch.zeros(batch_size, self.cond_size, device=device),
+ torch.zeros(batch_size, self.cond_size, device=device),
+ torch.zeros(batch_size, self.cond_size, device=device),
+ torch.zeros(batch_size, self.passthrough_size, device=device)
+ )
sig = torch.zeros((batch_size, 0), device=device)
cond = self.cond_net(features, period)
passthrough = torch.zeros(batch_size, self.passthrough_size, device=device)
- for n in range(nb_frames+nb_pre_frames):
+ if pre is not None:
+ prev[:,:] = pre[:, self.frame_size-self.subframe_size : self.frame_size]
+ exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
+ start = 1 if nb_pre_frames>0 else 0
+ for n in range(start, nb_frames+nb_pre_frames):
for k in range(self.nb_subframes):
pos = n*self.frame_size + k*self.subframe_size
preal = phase_real[:, pos:pos+self.subframe_size]
--
⑨