shithub: opus

Download patch

ref: 3e223e6015d2740fcfd4c50d4fdc2d20349fa565
parent: f8f12e7f3c4bb8d48f78f19adf9c38768389dba0
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Fri Jul 9 19:16:11 EDT 2021

Fixes Python inference for the binary probability tree

--- a/dnn/training_tf2/lpcnet.py
+++ b/dnn/training_tf2/lpcnet.py
@@ -44,15 +44,23 @@
 embed_size = 128
 pcm_levels = 2**pcm_bits
 
-def interleave(p):
+def interleave(p, samples):
     p2=tf.expand_dims(p, 3)
     nb_repeats = pcm_levels//(2*p.shape[2])
-    p3 = tf.reshape(tf.repeat(tf.concat([1-p2, p2], 3), nb_repeats), (-1, 2400, pcm_levels))
+    p3 = tf.reshape(tf.repeat(tf.concat([1-p2, p2], 3), nb_repeats), (-1, samples, pcm_levels))
     return p3
 
-def tree_to_pdf(p):
-    return interleave(p[:,:,1:2]) * interleave(p[:,:,2:4]) * interleave(p[:,:,4:8]) * interleave(p[:,:,8:16]) * interleave(p[:,:,16:32]) * interleave(p[:,:,32:64]) * interleave(p[:,:,64:128]) * interleave(p[:,:,128:256])
+def tree_to_pdf(p, samples):
+    return interleave(p[:,:,1:2], samples) * interleave(p[:,:,2:4], samples) * interleave(p[:,:,4:8], samples) * interleave(p[:,:,8:16], samples) \
+         * interleave(p[:,:,16:32], samples) * interleave(p[:,:,32:64], samples) * interleave(p[:,:,64:128], samples) * interleave(p[:,:,128:256], samples)
 
+def tree_to_pdf_train(p):
+    #FIXME: try not to hardcode the 2400 samples (15 frames * 160 samples/frame)
+    return tree_to_pdf(p, 2400)
+
+def tree_to_pdf_infer(p):
+    return tree_to_pdf(p, 1)
+
 def quant_regularizer(x):
     Q = 128
     Q_1 = 1./Q
@@ -197,8 +205,8 @@
     md = MDense(pcm_levels, activation='sigmoid', name='dual_fc')
     gru_out1, _ = rnn(rnn_in)
     gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
-    ulaw_prob = Lambda(tree_to_pdf)(md(gru_out2))
-    
+    ulaw_prob = Lambda(tree_to_pdf_train)(md(gru_out2))
+
     if adaptation:
         rnn.trainable=False
         rnn2.trainable=False
@@ -216,7 +224,7 @@
     dec_rnn_in = Concatenate()([cpcm, dec_feat])
     dec_gru_out1, state1 = rnn(dec_rnn_in, initial_state=dec_state1)
     dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2)
-    dec_ulaw_prob = Lambda(tree_to_pdf)(md(dec_gru_out2))
+    dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2))
 
     decoder = Model([pcm, dec_feat, dec_state1, dec_state2], [dec_ulaw_prob, state1, state2])
     return model, encoder, decoder
--