shithub: opus

Download patch

ref: 26073861985df7d81688bf6d56bf1360e490ee5a
parent: a8170986ecc6259b54d1ae30404b833a48acbc26
author: jbuethe <jbuethe@amazon.de>
date: Fri Nov 4 12:52:04 EDT 2022

Fix scaling/quantization order: hard-quantize the dead-zoned latents before dividing by quant_scale

--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -334,9 +334,10 @@
     dzone = Lambda(apply_dead_zone)
     dze = dzone([ze,dead_zone])
     ndze = noisequant(dze)
+    dze_quant = hardquant(dze)
     
     div = Lambda(lambda x: x[0]/x[1])
-    dze_unquant = div([dze,quant_scale])
+    dze_quant = div([dze_quant,quant_scale])
     ndze_unquant = div([ndze,quant_scale])
 
     mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
@@ -345,11 +346,11 @@
     unquantized_output = []
     cat = Concatenate(name="out_cat")
     for i in range(bunch//2):
-        dze_select = mod_select([dze_unquant, i])
+        dze_select = mod_select([dze_quant, i])
         ndze_select = mod_select([ndze_unquant, i])
         state_select = mod_select([gru_state_dec, i])
 
-        tmp = split_decoder([hardquant(dze_select), state_select])
+        tmp = split_decoder([dze_select, state_select])
         tmp = cat([tmp, lambda_up])
         combined_output.append(tmp)
 
--