shithub: opus

ref: 9859d68bb036ce2274082a6c253c483de9d2289d
parent: 0e5a38fac65fb06dddc75755b3d8b3ca01d6b524
author: jbuethe <jbuethe@amazon.de>
date: Mon Nov 7 11:14:11 EST 2022

changed distortion loss weighting back to 0.5, 0.5

--- a/dnn/training_tf2/encode_rdovae.py
+++ b/dnn/training_tf2/encode_rdovae.py
@@ -104,7 +104,7 @@
 nbits=80
 bits.astype('float32').tofile(args.output + "-syms.f32")
 
-lambda_val = 0.001 * np.ones((nb_sequences, sequence_size//2, 1))
+lambda_val = 0.0002 * np.ones((nb_sequences, sequence_size//2, 1))
 quant_id = np.round(3.8*np.log(lambda_val/.0002)).astype('int16')
 quant_id = quant_id[:,:,0]
 quant_embed = qembedding(quant_id)
--- a/dnn/training_tf2/train_rdovae.py
+++ b/dnn/training_tf2/train_rdovae.py
@@ -99,8 +99,8 @@
 opt = Adam(lr, decay=decay, beta_2=0.99)
 
 with strategy.scope():
-    model, encoder, decoder, _ = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
-    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[.1, .9, 1., .1], metrics={'hard_bits':rdovae.sq_rate_metric})
+    model, encoder, decoder, _ = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size, nb_quant=16)
+    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[.5, .5, 1., .1], metrics={'hard_bits':rdovae.sq_rate_metric})
     model.summary()
 
 lpc_order = 16
--
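
Beyond the loss-weight change named in the commit message, the patch also reverts the encode-time lambda_val to 0.0002 and passes nb_quant=16 into new_rdovae_model. Below is a minimal standalone sketch (not part of the patch) of the lambda -> quant_id mapping taken from encode_rdovae.py above, assuming nb_quant=16 means valid quantizer indices 0..15:

    import numpy as np

    nb_quant = 16  # assumed index range 0..15, per the nb_quant=16 argument above

    for lambda_val in (0.0002, 0.001, 0.01):
        # same mapping as encode_rdovae.py: quant_id = round(3.8*log(lambda/.0002))
        quant_id = int(np.round(3.8 * np.log(lambda_val / 0.0002)))
        assert 0 <= quant_id < nb_quant
        print(f"lambda={lambda_val:g} -> quant_id={quant_id}")

    # lambda=0.0002 -> quant_id=0   (the new encode-time value in this patch)
    # lambda=0.001  -> quant_id=6   (the previous encode-time value)
    # lambda=0.01   -> quant_id=15  (maps to the top index when nb_quant=16)

With the new default, the encoder runs at the lowest rate-distortion trade-off point (quant_id 0 everywhere), while training still sees the full set of 16 quantizer embeddings.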