shithub: opus

ref: 8cdc8081d8410d7ea8400c5a1fac8f98e58b4d90
parent: 37c9bd8d283c51de587be6185da02c1ea030d6ba
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Fri Oct 8 23:20:22 EDT 2021

Fix non-128 batch sizes

Avoid hardcoding the batch size in the model

--- a/dnn/training_tf2/lpcnet.py
+++ b/dnn/training_tf2/lpcnet.py
@@ -230,10 +230,10 @@
 
 constraint = WeightClip(0.992)
 
-def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
-    pcm = Input(shape=(None, 3), batch_size=128)
-    feat = Input(shape=(None, nb_used_features), batch_size=128)
-    pitch = Input(shape=(None, 1), batch_size=128)
+def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_size=128, training=False, adaptation=False, quantize=False, flag_e2e = False):
+    pcm = Input(shape=(None, 3), batch_size=batch_size)
+    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
+    pitch = Input(shape=(None, 1), batch_size=batch_size)
     dec_feat = Input(shape=(None, 128))
     dec_state1 = Input(shape=(rnn_units1,))
     dec_state2 = Input(shape=(rnn_units2,))
--- a/dnn/training_tf2/train_lpcnet.py
+++ b/dnn/training_tf2/train_lpcnet.py
@@ -121,7 +121,7 @@
 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
 with strategy.scope():
-    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize, flag_e2e = flag_e2e)
+    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, batch_size=batch_size, training=True, quantize=quantize, flag_e2e = flag_e2e)
     if not flag_e2e:
         model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
     else:
--
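
For context, a minimal sketch of how the new batch_size argument is used after this patch. The value 64 is hypothetical; the rnn_units defaults and the (model, _, _) unpacking are taken from the diff above, and this assumes dnn/training_tf2 is on the Python path:

    import lpcnet

    # Any batch size can now reach the Input layers; before this patch
    # they were hardcoded to batch_size=128.
    batch_size = 64  # hypothetical value for illustration
    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=384, rnn_units2=16,
                                          batch_size=batch_size,
                                          training=True)
    model.summary()

The default of batch_size=128 in new_lpcnet_model keeps existing callers that do not pass the argument behaving exactly as before.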