shithub: opus

Download patch

ref: e1181bcad026b0c1f64dacb631a28f833dc21f42
parent: c8cbfa7e9bc8751dd5f396f2616cb681dda10f1c
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sun Jan 30 12:29:33 EST 2022

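Fix band loss: apply the IDCT only to e[:,:,:-2] (dropping the last two features) in both losses, and average the plain and band L1 terms separately, since their shapes now differ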

--- a/dnn/training_tf2/train_plc.py
+++ b/dnn/training_tf2/train_plc.py
@@ -99,8 +99,8 @@
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
         e = (y_true - y_pred)*mask
-        e_bands = tf.signal.idct(e, norm='ortho')
-        l1_loss = K.mean(K.abs(e) + alpha*K.abs(e_bands))
+        e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
+        l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands))
         return l1_loss
     return loss
 
@@ -118,7 +118,7 @@
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
         e = (y_true - y_pred)*mask
-        e_bands = tf.signal.idct(e, norm='ortho')
+        e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
         l1_loss = K.mean(K.abs(e_bands))
         return l1_loss
     return L1_band_loss
--