ref: 108b75c4b189ea2eedd8e9cb0a2b56a9b4424466
parent: d54b9fb49af339c8ee72a8f54ee7e5beadbd724f
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Sep 6 10:30:21 EDT 2023
Randomly double the training sequence length. This helps with stability at little cost in training speed.
--- a/dnn/torch/fargan/dataset.py
+++ b/dnn/torch/fargan/dataset.py
@@ -20,16 +20,19 @@
self.data = np.memmap(signal_file, dtype='int16', mode='r')
#self.data = self.data[1::2]
- self.nb_sequences = len(self.data)//(pcm_chunk_size)-1
+ self.nb_sequences = len(self.data)//(pcm_chunk_size)-4
self.data = self.data[(4-self.lookahead)*self.frame_size:]
self.data = self.data[:self.nb_sequences*pcm_chunk_size]
- self.data = np.reshape(self.data, (self.nb_sequences, pcm_chunk_size))
+ #self.data = np.reshape(self.data, (self.nb_sequences, pcm_chunk_size))
+ sizeof = self.data.strides[-1]
+ self.data = np.lib.stride_tricks.as_strided(self.data, shape=(self.nb_sequences, pcm_chunk_size*2),
+ strides=(pcm_chunk_size*sizeof, sizeof))
self.features = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features))
sizeof = self.features.strides[-1]
- self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length+4, nb_features),
+ self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length*2+4, nb_features),
strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -118,6 +118,14 @@
features = features.to(device)
lpc = lpc.to(device)
periods = periods.to(device)
+ if (np.random.rand() > 0.1):
+ target = target[:, :sequence_length*160]
+ lpc = lpc[:,:sequence_length,:]
+ else:
+ target=target[::2, :]
+ lpc=lpc[::2,:]
+ features=features[::2,:]
+ periods=periods[::2,:]
target = target.to(device)
target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
--
⑨