shithub: opus

Download patch

ref: 7fdca7f01dc1c5974b38c0ad77e0174a2c010577
parent: 0e523aa3f40c5e84aa60d70eea1a5fc4a9ff46c8
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Oct 20 19:16:44 EDT 2021

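Minor cleanup: remove the Input_extractor Lambda and pass pcm/dpcm directly to the prediction, error and embedding layers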

--- a/dnn/training_tf2/lpcnet.py
+++ b/dnn/training_tf2/lpcnet.py
@@ -258,20 +258,18 @@
 
     cfeat = fdense2(fdense1(cfeat))
 
-    Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1))
     error_calc = Lambda(lambda x: tf_l2u(x[0] - tf.roll(x[1],1,axis = 1)))
     if flag_e2e:
         lpcoeffs = diff_rc2lpc(name = "rc2lpc")(cfeat)
     else:
         lpcoeffs = Input(shape=(None, lpc_order), batch_size=batch_size)
-    tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs])
-    past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds])
+    tensor_preds = diff_pred(name = "lpc2preds")([pcm,lpcoeffs])
+    past_errors = error_calc([pcm,tensor_preds])
     embed = diff_Embed(name='embed_sig',initializer = PCMInit())
-    cpcm = Concatenate()([tf_l2u(Input_extractor([pcm,0])),tf_l2u(tensor_preds),past_errors])
+    cpcm = Concatenate()([tf_l2u(pcm),tf_l2u(tensor_preds),past_errors])
     cpcm = GaussianNoise(.3)(cpcm)
     cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
-    cpcm_decoder = Concatenate()([Input_extractor([dpcm,0]),Input_extractor([dpcm,1]),Input_extractor([dpcm,2])])
-    cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder))
+    cpcm_decoder = Reshape((-1, embed_size*3))(embed(dpcm))
 
     
     rep = Lambda(lambda x: K.repeat_elements(x, frame_size, 1))
--