ref: a4f7c157cf879ad768f97c3fa31520ede267bea7
parent: fdd51eb7605bb392b619eee4d3213db1f787a900
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Fri Sep 9 13:03:13 EDT 2022
Stop decimating in the encoder

The encoder now emits its bits and quantizer embedding at the full frame
rate: the AveragePooling1D on the quantizer embedding and the strided
slice on enc_out are removed. The bunch//2 decimation instead happens in
the split decoder's range_select (and on the soft/hard distribution
embeddings), and the split decoder additionally returns the selected
unquantized values (uqbits), which now feed the soft_bits/hard_bits rate
terms in place of the full-rate dze.
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -202,9 +202,8 @@
lambda_val = Input(shape=(None, 1), batch_size=batch_size)
qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
quant_embed = qembedding(quant_id)
- quant_embed_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(quant_embed)
- quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed_bunched))
+ quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed))
enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
enc_dense2 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
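[Editor's note] For context, a minimal standalone sketch (toy shapes, not from this patch) of what the removed AveragePooling1D decimation did to the quantizer embedding: with pool_size = strides = bunch//2 (assuming bunch == 4) it averaged and halved the time axis, so quant_scale used to be computed at half the frame rate; after this change it is computed at the full rate.

    import numpy as np
    import tensorflow as tf

    bunch = 4
    x = np.arange(2 * 8 * 3, dtype=np.float32).reshape(2, 8, 3)  # (batch, time, channels)
    pooled = tf.keras.layers.AveragePooling1D(pool_size=bunch // 2,
                                              strides=bunch // 2,
                                              padding="valid")(x)
    print(x.shape, pooled.shape)  # (2, 8, 3) (2, 4, 3): time axis halved by the pooling
    # After this commit, the softplus quant_scale path sees the full-rate tensor
    # (shape (2, 8, 3) here) instead of the pooled one.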
@@ -230,19 +229,19 @@
d7 = enc_dense7(d6)
d8 = enc_dense8(d7)
enc_out = bits_dense(Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8]))
- enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out)
+ #enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out)
bits = Multiply()([enc_out, quant_scale])
global_dense1 = Dense(128, activation='tanh', name='gdense1')
global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
global_bits = global_dense2(global_dense1(d6))
- encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed_bunched, global_bits], name='encoder')
+ encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
return encoder
def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
- bits_input = Input(shape=(None, nb_bits), batch_size=batch_size)
- quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size)
- gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size)
+ bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
+ quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size, name="dec_embed")
+ gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")
dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
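[Editor's note] A minimal sketch (toy shapes, not from this patch) of the strided slice that is commented out above: with bunch == 4 it kept every second frame of enc_out, starting at index bunch//2 - 1. With it gone, the encoder emits bits for every frame and the decimation moves into the split decoder (see the next hunk).

    import tensorflow as tf

    bunch = 4
    enc_out = tf.reshape(tf.range(1 * 8 * 2, dtype=tf.float32), (1, 8, 2))
    decimate = tf.keras.layers.Lambda(lambda x: x[:, bunch // 2 - 1::bunch // 2])
    print(decimate(enc_out).shape)  # (1, 4, 2): keeps time indices 1, 3, 5, 7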
@@ -282,23 +281,28 @@
def new_split_decoder(decoder):
nb_bits = decoder.nb_bits
bunch = decoder.bunch
- bits_input = Input(shape=(None, nb_bits))
- quant_embed_input = Input(shape=(None, 6*nb_bits))
- gru_state_input = Input(shape=(None,nb_state_dim))
+ bits_input = Input(shape=(None, nb_bits), name="split_bits")
+ uqbits_input = Input(shape=(None, nb_bits), name="split_uqbits")
+ quant_embed_input = Input(shape=(None, 6*nb_bits), name="split_embed")
+ gru_state_input = Input(shape=(None,nb_state_dim), name="split_state")
- range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
+ range_select = Lambda(lambda x: x[0][:,x[1]+bunch//2-1:x[2]:bunch//2,:])
elem_select = Lambda(lambda x: x[0][:,x[1],:])
points = [0, 64, 128, 192, 256]
outputs = []
+ uqbits = []
for i in range(len(points)-1):
- begin = points[i]//bunch
- end = points[i+1]//bunch
- state = elem_select([gru_state_input, 2*end-1])
+ begin = points[i]//2
+ end = points[i+1]//2
+ state = elem_select([gru_state_input, end-1])
bits = range_select([bits_input, begin, end])
+ uq = range_select([uqbits_input, begin, end])
+ uqbits.append(uq)
embed = range_select([quant_embed_input, begin, end])
outputs.append(decoder([bits, embed, state]))
output = Concatenate(axis=1)(outputs)
- split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split")
+ uqbits = Concatenate(axis=1)(uqbits)
+ split = Model([bits_input, uqbits_input, quant_embed_input, gru_state_input], [output, uqbits], name="split")
return split
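[Editor's note] A minimal sketch (toy sizes and window, not from this patch) of what the rewritten range_select now computes: it takes the [begin, end) window at full rate and decimates it by bunch//2 with the same bunch//2 - 1 offset the encoder used to apply, so each split still receives a decimated sequence. The same selection is applied to the new uqbits_input, and the selected unquantized values are returned alongside the decoded output.

    import tensorflow as tf

    bunch = 4
    bits = tf.reshape(tf.range(1 * 16 * 2, dtype=tf.float32), (1, 16, 2))
    begin, end = 0, 8                                        # e.g. points[i] // 2, points[i + 1] // 2
    chunk = bits[:, begin + bunch // 2 - 1:end:bunch // 2, :]  # same slice as range_select
    print(chunk.shape)                                       # (1, 4, 2): frames 1, 3, 5, 7 of the window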
@@ -316,8 +320,8 @@
split_decoder = new_split_decoder(decoder)
dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
- soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
- hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))
+ soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,::2,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
+ hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,::2,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))
noisequant = UniformNoise()
hardquant = Lambda(hard_quantize)
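[Editor's note] For reference, a minimal sketch (toy nb_bits, not from this patch) of the channel layout of the 6*nb_bits quantizer embedding as sliced in this file: the first nb_bits channels feed quant_scale, the next nb_bits the dead zone, and the remaining 4*nb_bits the soft/hard distribution embeddings, which this commit additionally decimates to every second frame ([:, ::2, ...]) so they match the decimated bits.

    import tensorflow as tf

    nb_bits = 4
    embed = tf.random.normal((1, 6, 6 * nb_bits))                        # (batch, time, 6*nb_bits)
    quant_scale = tf.nn.softplus(embed[:, :, :nb_bits])                  # full rate
    dead_zone   = tf.nn.softplus(embed[:, :, nb_bits:2 * nb_bits])       # full rate
    soft_distr  = tf.nn.sigmoid(embed[:, ::2, 2 * nb_bits:4 * nb_bits])  # every 2nd frame
    hard_distr  = tf.nn.sigmoid(embed[:, ::2, 4 * nb_bits:])             # every 2nd frame
    print(soft_distr.shape, hard_distr.shape)                            # (1, 3, 8) (1, 3, 8)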
@@ -324,13 +328,13 @@
dzone = Lambda(apply_dead_zone)
dze = dzone([ze,dead_zone])
gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
- combined_output = split_decoder([hardquant(dze), tf.stop_gradient(quant_embed_dec), gru_state_dec])
+ combined_output, uqbits = split_decoder([hardquant(dze), dze, tf.stop_gradient(quant_embed_dec), gru_state_dec])
ndze = noisequant(dze)
- unquantized_output = split_decoder([ndze, quant_embed_dec, gru_state_dec])
- unquantized_output_dec = split_decoder([tf.stop_gradient(ndze), tf.stop_gradient(quant_embed_dec), gru_state_dec])
+ unquantized_output, uqbits = split_decoder([ndze, dze, quant_embed_dec, gru_state_dec])
+ unquantized_output_dec, uqbits = split_decoder([tf.stop_gradient(ndze), dze, tf.stop_gradient(quant_embed_dec), gru_state_dec])
- e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_bunched])
- e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_bunched])
+ e2 = Concatenate(name="hard_bits")([uqbits, hard_distr_embed, lambda_bunched])
+ e = Concatenate(name="soft_bits")([uqbits, soft_distr_embed, lambda_bunched])
model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, unquantized_output_dec, e, e2], name="end2end")
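[Editor's note] A minimal standalone sketch (not from this patch) of what tf.stop_gradient does in the calls above: the wrapped tensor is used normally in the forward pass, but no gradient flows back through it, so e.g. the unquantized_output_dec path does not push gradients into ndze or the quantizer embedding.

    import tensorflow as tf

    x = tf.Variable([1.0, 2.0])
    with tf.GradientTape() as tape:
        y = tf.reduce_sum(x * tf.stop_gradient(x))
    print(tape.gradient(y, x).numpy())  # [1. 2.]: the stopped factor contributes no gradient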
--