shithub: opus

Download patch

ref: f5c251c5d5faf08d15571b0ba7f34c3474a55fb8
parent: dd114baf4d6fad4501935b23bc14f79b0fb1cdd6
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Mon Sep 19 15:03:09 EDT 2022

Properly align LPC with lookahead in data loader

--- a/dnn/training_tf2/dataloader.py
+++ b/dnn/training_tf2/dataloader.py
@@ -13,7 +13,7 @@
     return rc
 
 class LPCNetLoader(Sequence):
-    def __init__(self, data, features, periods, batch_size, e2e=False):
+    def __init__(self, data, features, periods, batch_size, e2e=False, lookahead=2):
         self.batch_size = batch_size
         self.nb_batches = np.minimum(np.minimum(data.shape[0], features.shape[0]), periods.shape[0])//self.batch_size
         self.data = data[:self.nb_batches*self.batch_size, :]
@@ -20,6 +20,7 @@
         self.features = features[:self.nb_batches*self.batch_size, :]
         self.periods = periods[:self.nb_batches*self.batch_size, :]
         self.e2e = e2e
+        self.lookahead = lookahead
         self.on_epoch_end()
 
     def on_epoch_end(self):
@@ -34,7 +35,10 @@
         periods = self.periods[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
         outputs = [out_data]
         inputs = [in_data, features, periods]
-        lpc = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], 2:-2, -16:]
+        if self.lookahead > 0:
+            lpc = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], 4-self.lookahead:-self.lookahead, -16:]
+        else:
+            lpc = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], 4:, -16:]
         if self.e2e:
             outputs.append(lpc2rc(lpc))
         else:
--- a/dnn/training_tf2/train_lpcnet.py
+++ b/dnn/training_tf2/train_lpcnet.py
@@ -203,7 +203,7 @@
 
 model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
 
-loader = LPCNetLoader(data, features, periods, batch_size, e2e=flag_e2e)
+loader = LPCNetLoader(data, features, periods, batch_size, e2e=flag_e2e, lookahead=args.lookahead)
 
 callbacks = [checkpoint, sparsify, grub_sparsify]
 if args.logdir is not None:
--