ref: e1181bcad026b0c1f64dacb631a28f833dc21f42
parent: c8cbfa7e9bc8751dd5f396f2616cb681dda10f1c
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sun Jan 30 12:29:33 EST 2022
oops, fix band loss
--- a/dnn/training_tf2/train_plc.py
+++ b/dnn/training_tf2/train_plc.py
@@ -99,8 +99,8 @@
mask = y_true[:,:,-1:]
y_true = y_true[:,:,:-1]
e = (y_true - y_pred)*mask
- e_bands = tf.signal.idct(e, norm='ortho')
- l1_loss = K.mean(K.abs(e) + alpha*K.abs(e_bands))
+ e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
+ l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands))
return l1_loss
return loss
@@ -118,7 +118,7 @@
mask = y_true[:,:,-1:]
y_true = y_true[:,:,:-1]
e = (y_true - y_pred)*mask
- e_bands = tf.signal.idct(e, norm='ortho')
+ e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
l1_loss = K.mean(K.abs(e_bands))
return l1_loss
return L1_band_loss
--
⑨