ref: ef12c29f14a2821760be1a04363bc46d96dd6ede
parent: 405aa7cf6962d164125e3da1ed0724241da1281a
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Sep 14 13:04:36 EDT 2022
Update encoder/decoder

Add a training argument to new_rdovae_encoder(), new_rdovae_decoder() and
new_rdovae_model() so the cuDNN-backed GRU is only used when training and
the regular GRU is used otherwise. Feed the decoder only every other step
of the bits and quantizer embeddings in encode_rdovae.py and
decode_rdovae.py, and make the decoder GRU state read by decode_rdovae.py
match nb_state_dim (24 instead of 16).
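As a rough sketch (not part of this patch), the training flag only selects
which recurrent layer class gets instantiated; the CuDNNGRU import path and
the make_gru() helper below are assumptions for illustration, not code from
this repo:

    import tensorflow as tf
    from tensorflow.keras.layers import GRU
    # assumed import path for the cuDNN-backed GRU under TF2
    CuDNNGRU = tf.compat.v1.keras.layers.CuDNNGRU

    def make_gru(cond_size, training=False, **kwargs):
        # cuDNN kernel for GPU training, portable GRU for inference/export
        gru = CuDNNGRU if training else GRU
        return gru(cond_size, return_sequences=True, **kwargs)

    enc_dense2 = make_gru(128, training=False, name='enc_dense2')

In the patch itself the same choice is made inline with
"gru = CuDNNGRU if training else GRU" inside the encoder and decoder
builders.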
--- a/dnn/training_tf2/decode_rdovae.py
+++ b/dnn/training_tf2/decode_rdovae.py
@@ -72,17 +72,20 @@
bits = np.memmap(bits_file + "-bits.s16", dtype='int16', mode='r')
-nb_sequences = len(bits)//(20*sequence_size)//batch_size*batch_size
-bits = bits[:nb_sequences*sequence_size*20]
+nb_sequences = len(bits)//(40*sequence_size)//batch_size*batch_size
+bits = bits[:nb_sequences*sequence_size*40]
-bits = np.reshape(bits, (nb_sequences, sequence_size//4, 20*4))
+bits = np.reshape(bits, (nb_sequences, sequence_size//2, 20*4))
+bits = bits[:,1::2,:]
print(bits.shape)
quant = np.memmap(bits_file + "-quant.f32", dtype='float32', mode='r')
state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r')
-quant = np.reshape(quant, (nb_sequences, sequence_size//4, 6*20*4))
-state = np.reshape(state, (nb_sequences, sequence_size//2, 16))
+quant = np.reshape(quant, (nb_sequences, sequence_size//2, 6*20*4))
+quant = quant[:,1::2,:]
+
+state = np.reshape(state, (nb_sequences, sequence_size//2, 24))
state = state[:,-1,:]
print("shapes are:")
--- a/dnn/training_tf2/encode_rdovae.py
+++ b/dnn/training_tf2/encode_rdovae.py
@@ -109,6 +109,6 @@
quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
gru_state_dec = gru_state_dec[:,-1,:]
-dec_out = decoder([bits, quant_embed_dec, gru_state_dec])
+dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec])
dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -195,7 +195,7 @@
nb_state_dim = 24
-def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
quant_id = Input(shape=(None,), batch_size=batch_size)
@@ -205,12 +205,13 @@
quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed))
+ gru = CuDNNGRU if training else GRU
enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
- enc_dense2 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
+ enc_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
- enc_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
+ enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
- enc_dense6 = CuDNNGRU(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
+ enc_dense6 = gru(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
@@ -238,7 +239,7 @@
encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
return encoder
-def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size, name="dec_embed")
gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")
@@ -247,9 +248,10 @@
dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2')
dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
- dec_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
- dec_dense5 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
- dec_dense6 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
+ gru = CuDNNGRU if training else GRU
+ dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
+ dec_dense5 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
+ dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
@@ -313,7 +315,7 @@
return Concatenate(axis=0)(y)
-def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
quant_id = Input(shape=(None,), batch_size=batch_size)
@@ -320,10 +322,10 @@
lambda_val = Input(shape=(None, 1), batch_size=batch_size)
lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
- encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2)
+ encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
ze, quant_embed_dec, gru_state_dec = encoder([feat, quant_id, lambda_val])
- decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2)
+ decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
split_decoder = new_split_decoder(decoder)
dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
--