shithub: opus

ref: 405aa7cf6962d164125e3da1ed0724241da1281a
parent: 981d06eefda5ecfaf44bce05d75cccd01a3a1e24
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sun Sep 11 00:13:24 EDT 2022

WIP: training with different alignment

--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -94,9 +94,9 @@
     return log2_e*tf.math.log(eps+x)
 
 def feat_dist_loss(y_true,y_pred):
-    ceps = y_pred[:,:,:18] - y_true[:,:,:18]
-    pitch = 2*(y_pred[:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
-    corr = y_pred[:,:,19:] - y_true[:,:,19:]
+    ceps = y_pred[:,:,:,:18] - y_true[:,:,:18]
+    pitch = 2*(y_pred[:,:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
+    corr = y_pred[:,:,:,19:] - y_true[:,:,19:]
     pitch_weight = K.square(K.maximum(0., y_true[:,:,19:]+.5))
     return K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr))
 
@@ -300,7 +300,19 @@
     split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split")
     return split
 
+def tensor_concat(x):
+    #n = x[1]//2
+    #x = x[0]
+    n=2
+    y = []
+    for i in range(n-1):
+        offset = n-1-i
+        tmp = K.concatenate([x[i][:, offset:, :], x[-1][:, -offset:, :]], axis=-2)
+        y.append(tf.expand_dims(tmp, axis=0))
+    y.append(tf.expand_dims(x[-1], axis=0))
+    return Concatenate(axis=0)(y)
 
+
 def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
 
     feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
@@ -315,8 +327,8 @@
     split_decoder = new_split_decoder(decoder)
 
     dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
-    soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,::2,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
-    hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,::2,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))
+    soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
+    hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))
 
     noisequant = UniformNoise()
     hardquant = Lambda(hard_quantize)
@@ -326,19 +338,24 @@
     mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
     gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
     ndze = noisequant(dze)
-    for i in [1]:
+    combined_output = []
+    unquantized_output = []
+    for i in range(bunch//2):
         dze_select = mod_select([dze, i])
         ndze_select = mod_select([ndze, i])
         state_select = mod_select([gru_state_dec, i])
-        combined_output = split_decoder([hardquant(dze_select), tf.stop_gradient(quant_embed_dec), state_select])
-        unquantized_output = split_decoder([ndze_select, quant_embed_dec, state_select])
-        unquantized_output_dec = split_decoder([tf.stop_gradient(ndze_select), tf.stop_gradient(quant_embed_dec), state_select])
+        combined_output.append(split_decoder([hardquant(dze_select), tf.stop_gradient(quant_embed_dec), state_select]))
+        unquantized_output.append(split_decoder([ndze_select, quant_embed_dec, state_select]))
 
-    e2 = Concatenate(name="hard_bits")([dze_select, hard_distr_embed, lambda_bunched])
-    e = Concatenate(name="soft_bits")([dze_select, soft_distr_embed, lambda_bunched])
+    concat = Lambda(tensor_concat, name="output")
+    combined_output = concat(combined_output)
+    unquantized_output = concat(unquantized_output)
+
+    e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_val])
+    e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_val])
 
 
-    model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, unquantized_output_dec, e, e2], name="end2end")
+    model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, e, e2], name="end2end")
     model.nb_used_features = nb_used_features
 
     return model, encoder, decoder
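
The heart of the patch is tensor_concat, which stacks the bunch//2 split-decoder outputs into a single tensor with a new leading alignment axis. Below is a minimal NumPy sketch of the same computation, not the patch's code itself: np.stack stands in for the expand_dims plus Concatenate(axis=0) pair, n=2 mirrors the value hardcoded in the patch, and the shapes and values are illustrative only.

    import numpy as np

    def tensor_concat_sketch(x, n=2):
        # x: list of n arrays of shape (batch, time, feat)
        y = []
        for i in range(n - 1):
            offset = n - 1 - i
            # drop the first `offset` frames of x[i] and refill the tail
            # from the last `offset` frames of the final output, so every
            # slice covers the same frame positions
            y.append(np.concatenate([x[i][:, offset:, :],
                                     x[-1][:, -offset:, :]], axis=-2))
        y.append(x[-1])
        return np.stack(y, axis=0)   # shape (n, batch, time, feat)

    a = np.arange(4.).reshape(1, 4, 1)         # decoder output, phase 0
    b = np.arange(10., 14.).reshape(1, 4, 1)   # decoder output, phase 1
    out = tensor_concat_sketch([a, b])
    print(out.shape)       # (2, 1, 4, 1)
    print(out[0].ravel())  # [ 1.  2.  3. 13.]  (a shifted, tail from b)
    print(out[1].ravel())  # [10. 11. 12. 13.]  (b unchanged)

This stacked 4-D tensor is what feat_dist_loss now receives as y_pred, which is why the slicing in the first hunk gains an extra axis.
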
--- a/dnn/training_tf2/train_rdovae.py
+++ b/dnn/training_tf2/train_rdovae.py
@@ -100,7 +100,7 @@
 
 with strategy.scope():
     model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
-    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[0.5, 0.5, 0., 1., .1], metrics={'split':'mse', 'hard_bits':rdovae.sq_rate_metric})
+    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[0.5, 0.5, 1., .1], metrics={'hard_bits':rdovae.sq_rate_metric})
     model.summary()
 
 lpc_order = 16
@@ -147,4 +147,4 @@
     tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
     callbacks.append(tensorboard_callback)
 
-model.fit([features, quant_id, lambda_val], [features, features, features, features, features], batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
+model.fit([features, quant_id, lambda_val], [features, features, features, features], batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
--
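
On the training side, dropping unquantized_output_dec takes the model from five outputs to four, so compile loses one loss entry and one loss weight, and fit loses one target. Note also that the first two outputs are now 4-D while the features target passed to fit stays 3-D; NumPy-style broadcasting makes feat_dist_loss apply the single target to every alignment. A small sketch with hypothetical shapes:

    import numpy as np

    n, batch, frames, nb_feat = 2, 8, 100, 20
    y_pred = np.random.randn(n, batch, frames, nb_feat)  # stacked alignments
    y_true = np.random.randn(batch, frames, nb_feat)     # single target from fit()

    ceps = y_pred[:, :, :, :18] - y_true[:, :, :18]  # broadcasts over the n axis
    print(ceps.shape)              # (2, 8, 100, 18)
    print(np.square(ceps).mean())  # the mean also averages across alignments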