shithub: opus

Download patch

ref: 2a8bcf4c0f816e76b0093bbedb15389531a35c23
parent: 2e06c07893521af0c404e2f6150b8637b3e2df36
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sat Feb 12 21:51:22 EST 2022

3-part pitch loss function

--- a/dnn/training_tf2/train_plc.py
+++ b/dnn/training_tf2/train_plc.py
@@ -103,7 +103,7 @@
         y_true = y_true[:,:,:-1]
         e = (y_pred - y_true)*mask
         e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
-        l1_loss = K.mean(K.abs(e)) + bias*K.mean(K.maximum(e[:,:,:1], 0.)) + alpha*K.mean(K.abs(e_bands) + bias*K.maximum(e_bands, 0.)) + 5*K.mean(K.minimum(K.abs(e[:,:,18:19]),1.))
+        l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands)) + K.mean(K.minimum(K.abs(e[:,:,18:19]),1.)) + 8*K.mean(K.minimum(K.abs(e[:,:,18:19]),.4))
         return l1_loss
     return loss
 
@@ -116,6 +116,15 @@
         return l1_loss
     return L1_loss
 
+def plc_ceps_loss():  # metric factory (used in model.compile metrics): masked L1 error over all but the last two feature channels
+    def ceps_loss(y_true,y_pred):
+        mask = y_true[:,:,-1:]  # last channel of y_true carries the mask
+        y_true = y_true[:,:,:-1]  # remaining channels are the regression targets
+        e = (y_pred - y_true)*mask  # zero error where mask is 0; those zeros still count in the mean below
+        l1_loss = K.mean(K.abs(e[:,:,:-2]))  # drop the last two channels (presumably pitch-related; plc_pitch_loss reads index 18) - confirm feature layout
+        return l1_loss
+    return ceps_loss
+
 def plc_band_loss():
     def L1_band_loss(y_true,y_pred):
         mask = y_true[:,:,-1:]
@@ -126,12 +135,21 @@
         return l1_loss
     return L1_band_loss
 
+def plc_pitch_loss():  # metric factory (used in model.compile metrics): masked L1 error on channel 18 (presumably pitch, per the name), clipped at 0.4
+    def pitch_loss(y_true,y_pred):
+        mask = y_true[:,:,-1:]  # last channel of y_true carries the mask
+        y_true = y_true[:,:,:-1]  # remaining channels are the regression targets
+        e = (y_pred - y_true)*mask  # zero error where mask is 0; those zeros still count in the mean below
+        l1_loss = K.mean(K.minimum(K.abs(e[:,:,18:19]),.4))  # same clipped term plc_loss weights by 8; its min(|e|,1.) term and full-L1 share are not reported here
+        return l1_loss
+    return pitch_loss
+
 opt = Adam(lr, decay=decay, beta_2=0.99)
 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
 with strategy.scope():
     model = lpcnet.new_lpcnet_plc_model(rnn_units=args.gru_size, batch_size=batch_size, training=True, quantize=quantize, cond_size=args.cond_size)
-    model.compile(optimizer=opt, loss=plc_loss(alpha=args.band_loss, bias=args.loss_bias), metrics=[plc_l1_loss(), plc_band_loss()])
+    model.compile(optimizer=opt, loss=plc_loss(alpha=args.band_loss, bias=args.loss_bias), metrics=[plc_l1_loss(), plc_ceps_loss(), plc_band_loss(), plc_pitch_loss()])  # NOTE(review): the revised plc_loss expression no longer references bias, so args.loss_bias may now have no effect - confirm
     model.summary()
 
 lpc_order = 16
--