ref: 26073861985df7d81688bf6d56bf1360e490ee5a
parent: a8170986ecc6259b54d1ae30404b833a48acbc26
author: jbuethe <jbuethe@amazon.de>
date: Fri Nov 4 12:52:04 EDT 2022
Fix scaling/quantization order: apply hardquant before dividing by quant_scale
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -334,9 +334,10 @@
dzone = Lambda(apply_dead_zone)
dze = dzone([ze,dead_zone])
ndze = noisequant(dze)
+ dze_quant = hardquant(dze)
div = Lambda(lambda x: x[0]/x[1])
- dze_unquant = div([dze,quant_scale])
+ dze_quant = div([dze_quant,quant_scale])
ndze_unquant = div([ndze,quant_scale])
mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
@@ -345,11 +346,11 @@
unquantized_output = []
cat = Concatenate(name="out_cat")
for i in range(bunch//2):
- dze_select = mod_select([dze_unquant, i])
+ dze_select = mod_select([dze_quant, i])
ndze_select = mod_select([ndze_unquant, i])
state_select = mod_select([gru_state_dec, i])
- tmp = split_decoder([hardquant(dze_select), state_select])
+ tmp = split_decoder([dze_select, state_select])
tmp = cat([tmp, lambda_up])
combined_output.append(tmp)
--
⑨