shithub: opus

Download patch

ref: ef12c29f14a2821760be1a04363bc46d96dd6ede
parent: 405aa7cf6962d164125e3da1ed0724241da1281a
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Sep 14 13:04:36 EDT 2022

Update encoder/decoder

--- a/dnn/training_tf2/decode_rdovae.py
+++ b/dnn/training_tf2/decode_rdovae.py
@@ -72,17 +72,20 @@
 
 
 bits = np.memmap(bits_file + "-bits.s16", dtype='int16', mode='r')
-nb_sequences = len(bits)//(20*sequence_size)//batch_size*batch_size
-bits = bits[:nb_sequences*sequence_size*20]
+nb_sequences = len(bits)//(40*sequence_size)//batch_size*batch_size
+bits = bits[:nb_sequences*sequence_size*40]
 
-bits = np.reshape(bits, (nb_sequences, sequence_size//4, 20*4))
+bits = np.reshape(bits, (nb_sequences, sequence_size//2, 20*4))
+bits = bits[:,1::2,:]
 print(bits.shape)
 
 quant = np.memmap(bits_file + "-quant.f32", dtype='float32', mode='r')
 state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r')
 
-quant = np.reshape(quant, (nb_sequences, sequence_size//4, 6*20*4))
-state = np.reshape(state, (nb_sequences, sequence_size//2, 16))
+quant = np.reshape(quant, (nb_sequences, sequence_size//2, 6*20*4))
+quant = quant[:,1::2,:]
+
+state = np.reshape(state, (nb_sequences, sequence_size//2, 24))
 state = state[:,-1,:]
 
 print("shapes are:")
--- a/dnn/training_tf2/encode_rdovae.py
+++ b/dnn/training_tf2/encode_rdovae.py
@@ -109,6 +109,6 @@
 quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
 
 gru_state_dec = gru_state_dec[:,-1,:]
-dec_out = decoder([bits, quant_embed_dec, gru_state_dec])
+dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec])
 
 dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -195,7 +195,7 @@
 
 nb_state_dim = 24
 
-def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
     feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
 
     quant_id = Input(shape=(None,), batch_size=batch_size)
@@ -205,12 +205,13 @@
 
     quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed))
 
+    gru = CuDNNGRU if training else GRU
     enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
-    enc_dense2 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
+    enc_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
     enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
-    enc_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
+    enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
     enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
-    enc_dense6 = CuDNNGRU(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
+    enc_dense6 = gru(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
     enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
     enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
 
@@ -238,7 +239,7 @@
     encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
     return encoder
 
-def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
     bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
     quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size, name="dec_embed")
     gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")
@@ -247,9 +248,10 @@
     dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
     dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2')
     dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
-    dec_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
-    dec_dense5 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
-    dec_dense6 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
+    gru = CuDNNGRU if training else GRU
+    dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
+    dec_dense5 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
+    dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
     dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
     dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
 
@@ -313,7 +315,7 @@
     return Concatenate(axis=0)(y)
 
 
-def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
 
     feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
     quant_id = Input(shape=(None,), batch_size=batch_size)
@@ -320,10 +322,10 @@
     lambda_val = Input(shape=(None, 1), batch_size=batch_size)
     lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
 
-    encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2)
+    encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
     ze, quant_embed_dec, gru_state_dec = encoder([feat, quant_id, lambda_val])
 
-    decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2)
+    decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
     split_decoder = new_split_decoder(decoder)
 
     dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
--