shithub: opus

Download patch

ref: 1db1946f77bed48cdaf6fb1c00611b27275e96ce
parent: 186fa61680c7109ab240004373428a129417d52d
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Mon Jan 31 21:57:50 EST 2022

Support for biased loss

--- a/dnn/training_tf2/train_plc.py
+++ b/dnn/training_tf2/train_plc.py
@@ -46,6 +46,8 @@
 parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
 parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
 parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
+parser.add_argument('--band-loss', metavar='<weight>', default=1.0, type=float, help='weight of band loss (default 1.0)')
+parser.add_argument('--loss-bias', metavar='<bias>', default=0.0, type=float, help='loss bias towards low energy (default 0.0)')
 parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')
 
 
@@ -94,13 +96,13 @@
 if retrain:
     input_model = args.retrain
 
-def plc_loss(alpha=1.0):
+def plc_loss(alpha=1.0, bias=0.):
     def loss(y_true,y_pred):
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
-        e = (y_true - y_pred)*mask
+        e = (y_pred - y_true)*mask
         e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
-        l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands))
+        l1_loss = K.mean(K.abs(e)) + bias*K.mean(K.maximum(e[:,:,:1], 0.)) + alpha*K.mean(K.abs(e_bands) + bias*K.maximum(e_bands, 0.))
         return l1_loss
     return loss
 
@@ -108,7 +110,7 @@
     def L1_loss(y_true,y_pred):
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
-        e = (y_true - y_pred)*mask
+        e = (y_pred - y_true)*mask
         l1_loss = K.mean(K.abs(e))
         return l1_loss
     return L1_loss
@@ -117,7 +119,7 @@
     def L1_band_loss(y_true,y_pred):
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
-        e = (y_true - y_pred)*mask
+        e = (y_pred - y_true)*mask
         e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
         l1_loss = K.mean(K.abs(e_bands))
         return l1_loss
@@ -128,7 +130,7 @@
 
 with strategy.scope():
     model = lpcnet.new_lpcnet_plc_model(rnn_units=args.gru_size, batch_size=batch_size, training=True, quantize=quantize, cond_size=args.cond_size)
-    model.compile(optimizer=opt, loss=plc_loss(alpha=1.), metrics=[plc_l1_loss(), plc_band_loss()])
+    model.compile(optimizer=opt, loss=plc_loss(alpha=args.band_loss, bias=args.loss_bias), metrics=[plc_l1_loss(), plc_band_loss()])
     model.summary()
 
 lpc_order = 16
--