shithub: opus

Download patch

ref: 61459c24e0cedf8f36ed618db78d25fda7d8153b
parent: 79d1a916d0715f7bcd188819abb53631842eeb2d
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Mon Oct 3 20:27:36 EDT 2022

Change decoder architecture to be like the encoder

--- a/dnn/training_tf2/decode_rdovae.py
+++ b/dnn/training_tf2/decode_rdovae.py
@@ -82,8 +82,8 @@
 bits = np.reshape(bits, (nb_sequences, sequence_size//2, 20*4))
 print(bits.shape)
 
-lambda_val = 0.0007 * np.ones((nb_sequences, sequence_size//2, 1))
-quant_id = np.round(10*np.log(lambda_val/.0002)).astype('int16')
+lambda_val = 0.001 * np.ones((nb_sequences, sequence_size//2, 1))
+quant_id = np.round(3.8*np.log(lambda_val/.0002)).astype('int16')
 quant_id = quant_id[:,:,0]
 quant_embed = qembedding(quant_id)
 quant_scale = tf.math.softplus(quant_embed[:,:,:nbits])
@@ -98,7 +98,7 @@
 
 state = np.reshape(state, (nb_sequences, sequence_size//2, 24))
 state = state[:,-1,:]
-state = pvq_quantize(state, 30)
+state = pvq_quantize(state, 82)
 #state = state/(1e-15+tf.norm(state, axis=-1,keepdims=True))
 
 print("shapes are:")
--- a/dnn/training_tf2/encode_rdovae.py
+++ b/dnn/training_tf2/encode_rdovae.py
@@ -105,7 +105,7 @@
 bits.astype('float32').tofile(args.output + "-syms.f32")
 
 lambda_val = 0.001 * np.ones((nb_sequences, sequence_size//2, 1))
-quant_id = np.round(10*np.log(lambda_val/.0002)).astype('int16')
+quant_id = np.round(3.8*np.log(lambda_val/.0002)).astype('int16')
 quant_id = quant_id[:,:,0]
 quant_embed = qembedding(quant_id)
 quant_scale = tf.math.softplus(quant_embed[:,:,:nbits])
@@ -115,7 +115,7 @@
 bits = np.round(apply_dead_zone([bits, dead_zone]).numpy())
 bits = bits/quant_scale
 
-gru_state_dec = pvq_quantize(gru_state_dec, 30)
+gru_state_dec = pvq_quantize(gru_state_dec, 82)
 #gru_state_dec = gru_state_dec/(1e-15+tf.norm(gru_state_dec, axis=-1,keepdims=True))
 gru_state_dec = gru_state_dec[:,-1,:]
 dec_out = decoder([bits[:,1::2,:], gru_state_dec])
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -238,15 +238,15 @@
     gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")
 
     
+    gru = CuDNNGRU if training else GRU
     dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
-    dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2')
+    dec_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense2')
     dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
-    gru = CuDNNGRU if training else GRU
     dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
-    dec_dense5 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
+    dec_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense5')
     dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
-    dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
-    dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
+    dec_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
+    dec_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
 
     dec_final = Dense(bunch*nb_used_features, activation='linear', name='dec_final')
 
@@ -260,10 +260,10 @@
     gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)
 
     dec1 = dec_dense1(time_reverse(bits_input))
-    dec2 = dec_dense2(dec1)
+    dec2 = dec_dense2(dec1, initial_state=gru_state1)
     dec3 = dec_dense3(dec2)
-    dec4 = dec_dense4(dec3, initial_state=gru_state1)
-    dec5 = dec_dense5(dec4, initial_state=gru_state2)
+    dec4 = dec_dense4(dec3, initial_state=gru_state2)
+    dec5 = dec_dense5(dec4)
     dec6 = dec_dense6(dec5, initial_state=gru_state3)
     dec7 = dec_dense7(dec6)
     dec8 = dec_dense8(dec7)
@@ -340,7 +340,7 @@
     ndze_unquant = div([ndze,quant_scale])
 
     mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
-    gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
+    gru_state_dec = Lambda(lambda x: pvq_quantize(x, 82))(gru_state_dec)
     combined_output = []
     unquantized_output = []
     cat = Concatenate(name="out_cat")
--- a/dnn/training_tf2/train_rdovae.py
+++ b/dnn/training_tf2/train_rdovae.py
@@ -124,8 +124,8 @@
 #lambda_val = np.repeat(np.random.uniform(.0007, .002, (features.shape[0], 1, 1)), features.shape[1]//2, axis=1)
 #quant_id = np.round(10*np.log(lambda_val/.0007)).astype('int16')
 #quant_id = quant_id[:,:,0]
-quant_id = np.repeat(np.random.randint(39, size=(features.shape[0], 1, 1), dtype='int16'), features.shape[1]//2, axis=1)
-lambda_val = .0002*np.exp(quant_id/10.)
+quant_id = np.repeat(np.random.randint(16, size=(features.shape[0], 1, 1), dtype='int16'), features.shape[1]//2, axis=1)
+lambda_val = .0002*np.exp(quant_id/3.8)
 quant_id = quant_id[:,:,0]
 
 # dump models to disk as we go
--