shithub: opus

Download patch

ref: 8569121f6c37e4eeac739eff399f5906c07c0f43
parent: b6ac1c78bb1c6376a818c7a68ab034a68624b19a
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Sep 7 23:12:02 EDT 2022

RDO-VAE work in progress

--- /dev/null
+++ b/dnn/training_tf2/decode_rdovae.py
@@ -1,0 +1,95 @@
+#!/usr/bin/python3
+'''Copyright (c) 2021-2022 Amazon
+   Copyright (c) 2018-2019 Mozilla
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
+   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+
+# Decode RDO-VAE bitstreams written by encode_rdovae.py back into features.
+
+import argparse
+
+parser = argparse.ArgumentParser(description='Decode an RDO-VAE bitstream')
+
+parser.add_argument('bits', metavar='<bits file>', help='prefix of the encoder output files (<prefix>-bits.s16, <prefix>-quant.f32, <prefix>-state.f32)')
+parser.add_argument('output', metavar='<output>', help='output features')
+parser.add_argument('--model', metavar='<model>', default='rdovae', help='RDO-VAE model python definition (without .py)')
+# --weights is mandatory: decoding is meaningless without trained weights.
+group1 = parser.add_mutually_exclusive_group(required=True)
+group1.add_argument('--weights', metavar='<input weights>', help='model weights')
+parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
+parser.add_argument('--batch-size', metavar='<batch size>', default=1, type=int, help='batch size to use (default 1)')
+parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
+
+
+args = parser.parse_args()
+
+import importlib
+rdovae = importlib.import_module(args.model)
+
+import sys
+import numpy as np
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
+import tensorflow.keras.backend as K
+import h5py
+
+import tensorflow as tf
+
+# Try reducing batch_size if you run out of memory on your GPU
+batch_size = args.batch_size
+
+# Must match the nb_bits used when the bitstream was produced.
+nb_bits = 80
+model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=nb_bits, batch_size=batch_size, cond_size=args.cond_size)
+model.load_weights(args.weights)
+
+# Per-step vector widths, read from the decoder's own inputs so they always
+# agree with the model definition (6*nb_bits for the quantizer embedding,
+# nb_state_dim for the state).
+quant_dim = decoder.inputs[1].shape[-1]
+state_dim = decoder.inputs[2].shape[-1]
+
+bits_file = args.bits
+sequence_size = args.seq_length
+# With the default bunch=4 the encoder emits one bit vector / quantizer
+# embedding every 4 frames and one state vector every 2 frames.
+bits_steps = sequence_size//4
+state_steps = sequence_size//2
+
+bits = np.memmap(bits_file + "-bits.s16", dtype='int16', mode='r')
+nb_sequences = len(bits)//(bits_steps*nb_bits)//batch_size*batch_size
+bits = bits[:nb_sequences*bits_steps*nb_bits]
+
+bits = np.reshape(bits, (nb_sequences, bits_steps, nb_bits))
+print(bits.shape)
+
+quant = np.memmap(bits_file + "-quant.f32", dtype='float32', mode='r')
+state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r')
+
+# Keep exactly as many sequences as were kept from the bits file so the
+# reshapes below succeed even when the files hold extra trailing data.
+quant = quant[:nb_sequences*bits_steps*quant_dim]
+state = state[:nb_sequences*state_steps*state_dim]
+quant = np.reshape(quant, (nb_sequences, bits_steps, quant_dim))
+state = np.reshape(state, (nb_sequences, state_steps, state_dim))
+# The decoder only consumes the last state vector of each sequence.
+state = state[:,-1,:]
+
+print("shapes are:")
+print(bits.shape)
+print(quant.shape)
+print(state.shape)
+
+features = decoder.predict([bits, quant, state], batch_size=batch_size)
+
+features.astype('float32').tofile(args.output)
--- /dev/null
+++ b/dnn/training_tf2/encode_rdovae.py
@@ -1,0 +1,114 @@
+#!/usr/bin/python3
+'''Copyright (c) 2021-2022 Amazon
+   Copyright (c) 2018-2019 Mozilla
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
+   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+
+# Encode features with a trained RDO-VAE model and dump the bitstream,
+# quantizer embeddings, encoder state and a reference decode to disk.
+
+import argparse
+
+parser = argparse.ArgumentParser(description='Encode features with an RDO-VAE model')
+
+parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
+parser.add_argument('output', metavar='<output>', help='prefix for the output files')
+parser.add_argument('--model', metavar='<model>', default='rdovae', help='RDO-VAE model python definition (without .py)')
+# --weights is mandatory: encoding needs trained weights.
+group1 = parser.add_mutually_exclusive_group(required=True)
+group1.add_argument('--weights', metavar='<input weights>', help='model weights')
+parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
+parser.add_argument('--batch-size', metavar='<batch size>', default=1, type=int, help='batch size to use (default 1)')
+parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
+
+
+args = parser.parse_args()
+
+import importlib
+rdovae = importlib.import_module(args.model)
+
+import sys
+import numpy as np
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
+import tensorflow.keras.backend as K
+import h5py
+
+import tensorflow as tf
+
+# Try reducing batch_size if you run out of memory on your GPU
+batch_size = args.batch_size
+
+model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
+model.load_weights(args.weights)
+
+lpc_order = 16
+
+feature_file = args.features
+nb_features = model.nb_used_features + lpc_order
+nb_used_features = model.nb_used_features
+sequence_size = args.seq_length
+
+# Each frame holds nb_features floats; only the first nb_used_features are encoded.
+features = np.memmap(feature_file, dtype='float32', mode='r')
+nb_sequences = len(features)//(nb_features*sequence_size)//batch_size*batch_size
+features = features[:nb_sequences*sequence_size*nb_features]
+
+features = np.reshape(features, (nb_sequences, sequence_size, nb_features))
+print(features.shape)
+features = features[:, :, :nb_used_features]
+
+# Fixed rate/distortion trade-off (one value per two frames); quant_id maps
+# lambda to the quantizer-embedding index with the same formula as training.
+lambda_val = 0.001 * np.ones((nb_sequences, sequence_size//2, 1))
+quant_id = np.round(10*np.log(lambda_val/.0007)).astype('int16')
+quant_id = quant_id[:,:,0]
+
+
+bits, quant_embed_dec, gru_state_dec = encoder.predict([features, quant_id, lambda_val], batch_size=batch_size)
+gru_state_dec.astype('float32').tofile(args.output + "-state.f32")
+
+print("shapes are:")
+print(bits.shape)
+print(quant_embed_dec.shape)
+print(gru_state_dec.shape)
+
+features.astype('float32').tofile(args.output + "-input.f32")
+np.round(bits).astype('int16').tofile(args.output + "-bits.s16")
+quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
+
+# NOTE(review): unlike new_rdovae_model, no dead zone is applied to `bits`
+# and the state is not PVQ-quantized before being written/decoded here -- confirm intended.
+gru_state_dec = gru_state_dec[:,-1,:]
+# Decode in batches of the size the decoder was built for instead of one
+# eager call over the whole dataset (which can exhaust device memory).
+dec_out = decoder.predict([bits, quant_embed_dec, gru_state_dec], batch_size=batch_size)
+
+dec_out.astype('float32').tofile(args.output + "-dec_out.f32")
--- /dev/null
+++ b/dnn/training_tf2/rdovae.py
@@ -1,0 +1,340 @@
+#!/usr/bin/python3
+'''Copyright (c) 2022 Amazon
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
+   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+
+import math
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise, AveragePooling1D, RepeatVector
+from tensorflow.compat.v1.keras.layers import CuDNNGRU
+from tensorflow.keras import backend as K
+from tensorflow.keras.constraints import Constraint
+from tensorflow.keras.initializers import Initializer
+from tensorflow.keras.callbacks import Callback
+from tensorflow.keras.regularizers import l1
+import numpy as np
+import h5py
+from uniform_noise import UniformNoise
+
+class WeightClip(Constraint):
+    '''Kernel constraint bounding |w[:, 2i]| + |w[:, 2i+1]| by c.
+
+    Every pair of adjacent columns whose absolute values sum to more than c
+    is rescaled by c/pair_sum; pairs already within the bound are returned
+    unchanged. Assumes an even number of columns.
+    '''
+    def __init__(self, c=2):
+        self.c = c
+
+    def __call__(self, p):
+        # Keep |adjacent weight| sums from exceeding 127 once quantized, to
+        # avoid saturation in SSSE3/AVX2 dot products.
+        limit = self.c
+        pair_sum = tf.abs(p[:, 0::2]) + tf.abs(p[:, 1::2])
+        # Expand each pair sum back to both of its columns.
+        denom = tf.maximum(limit, tf.repeat(pair_sum, 2, axis=1))
+        return limit*p/denom
+
+    def get_config(self):
+        return dict(name=self.__class__.__name__, c=self.c)
+
+# Shared kernel/recurrent constraint used by the constrained layers below.
+constraint = WeightClip(0.496)
+
+def soft_quantize(x):
+    '''Placeholder "soft" quantizer: currently the identity, returns x as is.
+
+    Earlier experiments scaled x by 4 and applied up to three passes of
+    x - (.25/pi)*sin(2*pi*x) here; those steps are disabled.
+    '''
+    return x
+
+def noise_quantize(x):
+    '''Training-time proxy for rounding: add U(-0.5, 0.5) noise, then soft_quantize.
+
+    The noise tensor has the same (dynamic) shape as x; it used to be
+    hard-coded to (128, 16, 80), which broke any other batch size, sequence
+    length or bit count.
+    '''
+    return soft_quantize(x + (K.random_uniform(tf.shape(x)) - .5))
+
+def hard_quantize(x):
+    '''Round to the nearest integer with a straight-through gradient.
+
+    The forward value is round(soft_quantize(x)). The rounding residual is
+    wrapped in stop_gradient, so the backward pass sees only soft_quantize(x).
+    '''
+    soft = soft_quantize(x)
+    residual = tf.stop_gradient(tf.round(soft) - soft)
+    return soft + residual
+
+def apply_dead_zone(x):
+    '''Soft dead-zone shrinkage.
+
+    x is a pair [values, dead_zone]. Returns values - d*tanh(values/(.1+d))
+    with d = .05*dead_zone: small values are pulled towards zero, and large
+    ones are shifted towards zero by roughly d.
+    '''
+    values, dead_zone = x[0], x[1]
+    d = dead_zone*.05
+    return values - d*tf.math.tanh(values/(.1+d))
+
+def rate_loss(y_true,y_pred):
+    '''Approximate rate (in bits) of coding each vector of y_pred, averaged.
+
+    y_true is unused (Keras loss signature). With n = y_pred.shape[-1] and
+    k the L1 norm of each vector:
+        rate = C + (n-1)*log2(k + .112*n^2/(n/1.8 + k)),  C = n - log2(gamma(n)).
+    '''
+    log2_e = 1.4427
+    n = y_pred.shape[-1]
+    # math.lgamma(n) == log(gamma(n)) without overflowing for n > 171, and
+    # avoids np.math, which NumPy 2.0 removed.
+    C = n - log2_e*math.lgamma(n)
+    k = K.sum(K.abs(y_pred), axis=-1)
+    rate = C + (n-1)*log2_e*tf.math.log(k + .112*n**2/(n/1.8+k) )
+    return K.mean(rate)
+
+eps=1e-6
+def safelog2(x):
+    '''Base-2 logarithm of (eps + x); eps keeps the result finite at x == 0.'''
+    return 1.4427*tf.math.log(eps+x)
+
+def feat_dist_loss(y_true,y_pred):
+    '''Weighted distortion between target and decoded features.
+
+    The terms per channel group are:
+    - Channels 0-17 ("ceps"): squared error.
+    - Channel 18 ("pitch"): absolute error scaled by 2/(y_true+2), weighted by
+      max(0, y_true[19:]+.5)^2 and by 10/18.
+    - Channels 19+ ("corr"): squared error scaled by 1/18.
+
+    Returns the mean over all elements.
+    '''
+    true_ceps, pred_ceps = y_true[:,:,:18], y_pred[:,:,:18]
+    true_pitch, pred_pitch = y_true[:,:,18:19], y_pred[:,:,18:19]
+    true_corr, pred_corr = y_true[:,:,19:], y_pred[:,:,19:]
+
+    ceps_err = pred_ceps - true_ceps
+    pitch_err = 2*(pred_pitch - true_pitch)/(true_pitch + 2)
+    corr_err = pred_corr - true_corr
+    pitch_weight = K.square(K.maximum(0., true_corr+.5))
+
+    total = K.square(ceps_err) + 10*(1/18.)*K.abs(pitch_err)*pitch_weight + (1/18.)*K.square(corr_err)
+    return K.mean(total)
+
+def sq1_rate_loss(y_true,y_pred):
+    '''Lambda-weighted rate estimate for the soft (unrounded) latents.
+
+    y_true is unused. The last axis of y_pred packs
+    [latents (n), distribution params (2n), lambda (1)].
+
+    Only the last n params are used, as the decay r of a discrete Laplace
+    model; r is a sigmoid output in new_rdovae_model. Each latent y costs
+    -log2((1-r)/(1+r) * r^|y|) bits. The per-vector sum is scaled by lambda
+    and the result is averaged.
+    '''
+    lambda_val = y_pred[:,:,-1]
+    y_pred = y_pred[:,:,:-1]
+    n = y_pred.shape[-1]//3
+    r = y_pred[:,:,2*n:]
+    y_pred = soft_quantize(y_pred[:,:,:n])
+    rate = -safelog2((1-r)/(1+r)*r**K.abs(y_pred))
+    rate = lambda_val*K.sum(rate, axis=-1)
+    return K.mean(rate)
+
+def sq2_rate_loss(y_true,y_pred):
+    '''Lambda-weighted rate of the hard-rounded latents (y_true unused).
+
+    The last axis of y_pred packs [latents (n), zero-probability params (n),
+    decay r (n), lambda (1)]. Latents are rounded to integers q, and the
+    cost per symbol is:
+    - q == 0: -log2(p0), with p0 = 1 - r^(.5 + .5*param).
+    - q != 0: -log2(.5*(1-p0)*(1-r)*r^(|q|-1)), a sign bit plus a geometric
+      magnitude.
+
+    The per-vector sum is scaled by lambda and averaged.
+    '''
+    lambda_val = y_pred[:,:,-1]
+    packed = y_pred[:,:,:-1]
+    n = packed.shape[-1]//3
+    r = packed[:,:,2*n:]
+    p0 = 1-r**(.5+.5*packed[:,:,n:2*n])
+    mag = K.abs(tf.round(packed[:,:,:n]))
+    # 1 for a zero symbol, 0 for any |q| >= 1.
+    is_zero = K.maximum(0., 1. - mag)**2
+    zero_cost = -is_zero*safelog2(p0*r**mag)
+    nonzero_cost = (1-is_zero)*safelog2(.5*(1-p0)*(1-r)*r**(mag-1))
+    rate = zero_cost - nonzero_cost
+    return K.mean(lambda_val*K.sum(rate, axis=-1))
+
+def sq_rate_metric(y_true,y_pred):
+    '''Estimated bits per vector of the hard-rounded latents, without lambda.
+
+    y_true is unused. The trailing lambda channel of y_pred is dropped. The
+    remaining [latents, zero-probability params, decay r] (n each) are costed
+    per symbol as -log2(p0) for zeros and -log2(.5*(1-p0)*(1-r)*r^(|q|-1))
+    otherwise, with p0 = 1 - r^(.5 + .5*param).
+    '''
+    params = y_pred[:,:,:-1]
+    n = params.shape[-1]//3
+    latents, p0_raw, r = params[:,:,:n], params[:,:,n:2*n], params[:,:,2*n:]
+    p0 = 1-r**(.5+.5*p0_raw)
+    q_abs = K.abs(tf.round(latents))
+    y0 = K.maximum(0., 1. - q_abs)**2
+    bits = -y0*safelog2(p0*r**q_abs) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(q_abs-1))
+    return K.mean(K.sum(bits, axis=-1))
+
+def pvq_quant_search(x, k):
+    '''Search for an integer vector with L1 norm k pointing along x (PVQ).
+
+    x is first normalized to unit L1 norm along the last axis. Starting from
+    round(k*x), the gain applied to x is adjusted for 10 fixed iterations:
+    - L1 norm of the rounded vector too large: the gain shrinks just enough
+      for some component's magnitude to round down.
+    - Too small: it grows just enough for one to round up.
+    - Equal to k: it is kept.
+
+    Returns the last rounded vector, which may still miss k.
+    '''
+    x = x/tf.reduce_sum(tf.abs(x), axis=-1, keepdims=True)
+    gain = k
+    kx = gain*x
+    y = tf.round(kx)
+    for _ in range(10):
+        abs_y = tf.abs(y)
+        abs_kx = tf.abs(kx)
+        norm = tf.reduce_sum(abs_y, axis=-1)
+        # Smallest gain factor that makes some |component| round up ...
+        grow = 1.0001*tf.reduce_min((abs_y+.5)/(abs_kx+1e-15), axis=-1)
+        # ... and largest factor below 1 that makes some |component| round down.
+        shrink = .9999*tf.reduce_max((abs_y-.5)/(abs_kx+1e-15), axis=-1)
+        step = tf.where(norm>k, shrink, grow)
+        step = tf.where(norm==k, tf.ones_like(step), step)
+        gain = gain*tf.expand_dims(step, axis=-1)
+        kx = gain*x
+        y = tf.round(kx)
+    return y
+
+def pvq_quantize(x, k):
+    '''Quantize the direction of x to a PVQ codeword with pulse count k.
+
+    The forward value is the L2-normalized result of pvq_quant_search on the
+    L2-normalized input. The gradient is a straight-through estimator: it is
+    that of the normalized input.
+    '''
+    def l2_normalize(v):
+        return v/(1e-15+tf.norm(v, axis=-1,keepdims=True))
+    unit = l2_normalize(x)
+    codeword = l2_normalize(pvq_quant_search(unit, k))
+    return unit + tf.stop_gradient(codeword - unit)
+
+
+def var_repeat(x):
+    '''Repeat vector x[0] along a new time axis, once per time step of x[1].
+
+    Meant as a Lambda body so the repeat count follows the dynamic sequence
+    length of the reference tensor x[1].
+    '''
+    vector, reference = x[0], x[1]
+    steps = K.shape(reference)[1]
+    return RepeatVector(steps)(vector)
+
+# Width of the global state vector passed from the encoder to the decoder.
+nb_state_dim = 24
+
+def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+    '''Build the RDO-VAE encoder model.
+
+    Inputs:
+      feat       (batch, T, nb_used_features); pairs of frames are merged, so
+                 the network runs at T/2 steps (T assumed even -- confirm).
+      quant_id   (batch, T/2) indices into an nb_quant-entry, 6*nb_bits-wide
+                 quantizer embedding.
+      lambda_val (batch, T/2, 1) rate/distortion trade-off.
+    Outputs:
+      bits                (batch, T/bunch, nb_bits) scaled, unquantized latents.
+      quant_embed_bunched (batch, T/bunch, 6*nb_bits) embedding averaged to the
+                          latent rate.
+      global_bits         (batch, T/2, nb_state_dim) state vector per step.
+    Layer names and creation order are kept stable; weights are loaded from h5.
+    '''
+    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
+
+    quant_id = Input(shape=(None,), batch_size=batch_size)
+    lambda_val = Input(shape=(None, 1), batch_size=batch_size)
+    qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
+    quant_embed = qembedding(quant_id)
+    # Average the embedding over bunch//2 steps to match the latent rate.
+    quant_embed_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(quant_embed)
+
+    # First nb_bits embedding channels -> positive per-latent scale.
+    quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed_bunched))
+
+    # Despite the "dense" names, layers 2, 4 and 6 are GRUs.
+    enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
+    enc_dense2 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
+    enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
+    enc_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
+    enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
+    enc_dense6 = CuDNNGRU(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
+    enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
+    enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
+
+    #bits_dense = Dense(nb_bits, activation='linear', name='bits_dense')
+    # Causal conv (kernel 4) over all hidden layers produces the latents.
+    bits_dense = Conv1D(nb_bits, 4, padding='causal', activation='linear', name='bits_dense')
+
+    # NOTE(review): zero_out is never used below.
+    zero_out = Lambda(lambda x: 0*x)
+    # Two frames per step; the embedding gets no gradient through this path.
+    inputs = Concatenate()([Reshape((-1, 2*nb_used_features))(feat), tf.stop_gradient(quant_embed), lambda_val])
+    #inputs = Concatenate()([feat, tf.stop_gradient(quant_embed), lambda_val])
+    d1 = enc_dense1(inputs)
+    d2 = enc_dense2(d1)
+    d3 = enc_dense3(d2)
+    d4 = enc_dense4(d3)
+    d5 = enc_dense5(d4)
+    # gru_state (final GRU state) is not used; the global state comes from d6.
+    d6, gru_state = enc_dense6(d5)
+    d7 = enc_dense7(d6)
+    d8 = enc_dense8(d7)
+    enc_out = bits_dense(Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8]))
+    # Keep the last step of every group of bunch//2 steps (decimate to T/bunch).
+    enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out)
+    bits = Multiply()([enc_out, quant_scale])
+    global_dense1 = Dense(128, activation='tanh', name='gdense1')
+    global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
+    global_bits = global_dense2(global_dense1(d6))
+
+    encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed_bunched, global_bits], name='encoder')
+    return encoder
+
+def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+    '''Build the RDO-VAE decoder model.
+
+    Inputs:
+      bits_input        (batch, S, nb_bits) quantized latents.
+      quant_embed_input (batch, S, 6*nb_bits) quantizer embedding; its first
+                        nb_bits channels give the scale the latents are divided by.
+      gru_state_input   (batch, nb_state_dim) state vector, repeated over S steps.
+    Output: (batch, S*bunch, nb_used_features) decoded features.
+
+    The network runs over the time-reversed sequence and its output is
+    reversed back. nb_quant is not used here. The model carries .nb_bits and
+    .bunch attributes for new_split_decoder().
+    '''
+    bits_input = Input(shape=(None, nb_bits), batch_size=batch_size)
+    quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size)
+    gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size)
+
+    
+    dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
+    dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2')
+    dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
+    dec_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
+    dec_dense5 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
+    dec_dense6 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
+    dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
+    dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
+
+    # Emits `bunch` frames of features per latent step.
+    dec_final = Dense(bunch*nb_used_features, activation='linear', name='dec_final')
+
+    div = Lambda(lambda x: x[0]/x[1])
+    # Reverse along the time axis (axis 1).
+    time_reverse = Lambda(lambda x: K.reverse(x, 1))
+    #time_reverse = Lambda(lambda x: x)
+    quant_scale_dec = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed_dec')(quant_embed_input))
+    #gru_state_rep = RepeatVector(64//bunch)(gru_state_input)
+
+    # Broadcast the single state vector to every latent step.
+    gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
+
+    # Latents divided by the same softplus scale the encoder multiplied by.
+    dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input), gru_state_rep])
+    dec1 = dec_dense1(time_reverse(dec_inputs))
+    dec2 = dec_dense2(dec1)
+    dec3 = dec_dense3(dec2)
+    dec4 = dec_dense4(dec3)
+    dec5 = dec_dense5(dec4)
+    dec6 = dec_dense6(dec5)
+    dec7 = dec_dense7(dec6)
+    dec8 = dec_dense8(dec7)
+    output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))
+    decoder = Model([bits_input, quant_embed_input, gru_state_input], time_reverse(output), name='decoder')
+    decoder.nb_bits = nb_bits
+    decoder.bunch = bunch
+    return decoder
+
+def new_split_decoder(decoder):
+    '''Wrap `decoder` so a sequence is decoded as independent chunks.
+
+    Chunk boundaries `points` are given in output frames and converted to
+    latent steps by dividing by decoder.bunch. Each chunk is decoded with:
+    - its own slice of bits and embedding;
+    - the state vector at index 2*end-1 of gru_state_input, which here is a
+      sequence (batch, T_state, nb_state_dim).
+    The chunk outputs are concatenated in time. Inputs have no fixed batch
+    size.
+
+    NOTE(review): points are hard-coded to 0..256 frames, so only the first
+    256 frames are decoded. train_rdovae.py defaults to --seq-length 1000,
+    which would not match the 1000-frame targets; confirm 256-frame
+    sequences. The index 2*end-1 presumably assumes bunch=4 with the state at
+    half the frame rate; confirm.
+    '''
+    nb_bits = decoder.nb_bits
+    bunch = decoder.bunch
+    bits_input = Input(shape=(None, nb_bits))
+    quant_embed_input = Input(shape=(None, 6*nb_bits))
+    gru_state_input = Input(shape=(None,nb_state_dim))
+
+    # Slice [begin:end) along time / pick a single time index.
+    range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
+    elem_select = Lambda(lambda x: x[0][:,x[1],:])
+    points = [0, 64, 128, 192, 256]
+    outputs = []
+    for i in range(len(points)-1):
+        begin = points[i]//bunch
+        end = points[i+1]//bunch
+        state = elem_select([gru_state_input, 2*end-1])
+        bits = range_select([bits_input, begin, end])
+        embed = range_select([quant_embed_input, begin, end])
+        outputs.append(decoder([bits, embed, state]))
+    output = Concatenate(axis=1)(outputs)
+    split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split")
+    return split
+
+
+def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+    '''Build the end-to-end RDO-VAE training model.
+
+    Inputs: feat (batch, T, nb_used_features), quant_id (batch, T/2),
+    lambda_val (batch, T/2, 1).
+
+    Outputs, in order:
+      combined_output        features decoded from hard-rounded, dead-zoned
+                             latents (straight-through gradient).
+      unquantized_output     features decoded from latents with additive
+                             uniform noise instead of rounding.
+      unquantized_output_dec same noisy latents, with gradients stopped on the
+                             latents and quantizer embedding.
+      e  ('soft_bits')       [dead-zoned latents, soft distribution params,
+                             lambda]; train_rdovae.py applies sq1_rate_loss.
+      e2 ('hard_bits')       same with the hard distribution params
+                             (sq2_rate_loss).
+    Feature outputs go through new_split_decoder. The encoder state is
+    PVQ-quantized (k=30) before decoding.
+
+    Returns (model, encoder, decoder); model.nb_used_features is also set.
+    '''
+
+    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
+    quant_id = Input(shape=(None,), batch_size=batch_size)
+    lambda_val = Input(shape=(None, 1), batch_size=batch_size)
+    # Lambda averaged down to the latent rate, appended to the rate-loss inputs.
+    lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
+
+    encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2)
+    ze, quant_embed_dec, gru_state_dec = encoder([feat, quant_id, lambda_val])
+
+    decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2)
+    split_decoder = new_split_decoder(decoder)
+
+    # Quantizer embedding layout (6*nb_bits): [scale | dead zone | soft distr (2x) | hard distr (2x)].
+    dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
+    soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
+    hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))
+
+    noisequant = UniformNoise()
+    hardquant = Lambda(hard_quantize)
+    dzone = Lambda(apply_dead_zone)
+    dze = dzone([ze,dead_zone])
+    gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
+    combined_output = split_decoder([hardquant(dze), tf.stop_gradient(quant_embed_dec), gru_state_dec])
+    ndze = noisequant(dze)
+    unquantized_output = split_decoder([ndze, quant_embed_dec, gru_state_dec])
+    unquantized_output_dec = split_decoder([tf.stop_gradient(ndze), tf.stop_gradient(quant_embed_dec), gru_state_dec])
+
+    # Rate-loss inputs carry the unrounded dead-zoned latents; the losses round/soft-quantize internally.
+    e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_bunched])
+    e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_bunched])
+
+
+    model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, unquantized_output_dec, e, e2], name="end2end")
+    model.nb_used_features = nb_used_features
+
+    return model, encoder, decoder
+
--- /dev/null
+++ b/dnn/training_tf2/train_rdovae.py
@@ -1,0 +1,150 @@
+#!/usr/bin/python3
+'''Copyright (c) 2021-2022 Amazon
+   Copyright (c) 2018-2019 Mozilla
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
+   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+
+# Train an RDO-VAE model
+import tensorflow as tf
+strategy = tf.distribute.MultiWorkerMirroredStrategy()
+
+
+import argparse
+
+parser = argparse.ArgumentParser(description='Train a quantization model')
+
+parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
+parser.add_argument('output', metavar='<output>', help='prefix for the trained model files (.h5)')
+parser.add_argument('--model', metavar='<model>', default='rdovae', help='RDO-VAE model python definition (without .py)')
+group1 = parser.add_mutually_exclusive_group()
+group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
+group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
+parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
+parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
+parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
+parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
+parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
+parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
+parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')
+
+
+args = parser.parse_args()
+
+import importlib
+rdovae = importlib.import_module(args.model)
+
+import sys
+import numpy as np
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
+import tensorflow.keras.backend as K
+import h5py
+
+#gpus = tf.config.experimental.list_physical_devices('GPU')
+#if gpus:
+#  try:
+#    tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
+#  except RuntimeError as e:
+#    print(e)
+
+nb_epochs = args.epochs
+
+# Try reducing batch_size if you run out of memory on your GPU
+batch_size = args.batch_size
+
+quantize = args.quantize is not None
+retrain = args.retrain is not None
+
+if quantize:
+    lr = 0.00003
+    decay = 0
+    input_model = args.quantize
+else:
+    lr = 0.001
+    decay = 2.5e-5
+
+if args.lr is not None:
+    lr = args.lr
+
+if args.decay is not None:
+    decay = args.decay
+
+if retrain:
+    input_model = args.retrain
+
+
+with strategy.scope():
+    # Model and optimizer are both created under the distribution strategy's
+    # scope so their variables are distributed across replicas.
+    opt = Adam(lr, decay=decay, beta_2=0.99)
+    model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
+    # One loss per model output: three decoded-feature outputs, then
+    # soft_bits (sq1 rate) and hard_bits (sq2 rate).
+    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[0.5, 0.5, 0., 1., .1], metrics={'split':'mse', 'hard_bits':rdovae.sq_rate_metric})
+    model.summary()
+
+lpc_order = 16
+
+feature_file = args.features
+nb_features = model.nb_used_features + lpc_order
+nb_used_features = model.nb_used_features
+sequence_size = args.seq_length
+
+# Each frame holds nb_features floats; only the first nb_used_features are modeled.
+features = np.memmap(feature_file, dtype='float32', mode='r')
+nb_sequences = len(features)//(nb_features*sequence_size)//batch_size*batch_size
+features = features[:nb_sequences*sequence_size*nb_features]
+
+features = np.reshape(features, (nb_sequences, sequence_size, nb_features))
+print(features.shape)
+features = features[:, :, :nb_used_features]
+
+# One random lambda per sequence, held constant over its T/2 encoder steps;
+# quant_id selects the matching quantizer embedding.
+#lambda_val = np.random.uniform(.0007, .002, (features.shape[0], features.shape[1], 1))
+lambda_val = np.repeat(np.random.uniform(.0007, .002, (features.shape[0], 1, 1)), features.shape[1]//2, axis=1)
+#lambda_val = 0*lambda_val + .001
+quant_id = np.round(10*np.log(lambda_val/.0007)).astype('int16')
+quant_id = quant_id[:,:,0]
+
+# dump models to disk as we go
+checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.cond_size, '{epoch:02d}'))
+
+if quantize or retrain:
+    # Adapt from an existing model (input_model is --quantize or --retrain).
+    model.load_weights(input_model)
+
+model.save_weights('{}_{}_initial.h5'.format(args.output, args.cond_size))
+
+callbacks = [checkpoint]
+
+if args.logdir is not None:
+    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.cond_size)
+    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
+    callbacks.append(tensorboard_callback)
+
+model.fit([features, quant_id, lambda_val], [features, features, features, features, features], batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
--- /dev/null
+++ b/dnn/training_tf2/uniform_noise.py
@@ -1,0 +1,78 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the UniformNoise layer."""
+
+
+import tensorflow.compat.v2 as tf
+
+from tensorflow.keras import backend
+
+from tensorflow.keras.layers import Layer
+
+class UniformNoise(Layer):
+    """Apply additive zero-centered uniform noise.
+
+    Adds noise drawn from U(-stddev, stddev) to the input. Despite its name,
+    `stddev` is the half-width of the uniform range; the actual standard
+    deviation is stddev/sqrt(3). The name is kept for config compatibility.
+
+    As it is a regularization layer, it is only active at training time.
+
+    Args:
+      stddev: Float, half-width of the uniform noise range (default 0.5,
+        i.e. U(-0.5, 0.5)).
+      seed: Integer, optional random seed to enable deterministic behavior.
+
+    Call arguments:
+      inputs: Input tensor (of any rank).
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding noise) or in inference mode (doing nothing).
+
+    Input shape:
+      Arbitrary. Use the keyword argument `input_shape`
+      (tuple of integers, does not include the samples axis)
+      when using this layer as the first layer in a model.
+
+    Output shape:
+      Same shape as input.
+    """
+
+    def __init__(self, stddev=0.5, seed=None, **kwargs):
+        super().__init__(**kwargs)
+        self.supports_masking = True
+        self.stddev = stddev
+        # Stored and forwarded to the RNG; previously accepted but ignored.
+        self.seed = seed
+
+    def call(self, inputs, training=None):
+        def noised():
+            return inputs + backend.random_uniform(
+                shape=tf.shape(inputs),
+                minval=-self.stddev,
+                maxval=self.stddev,
+                dtype=inputs.dtype,
+                seed=self.seed,
+            )
+
+        return backend.in_train_phase(noised, inputs, training=training)
+
+    def get_config(self):
+        config = {"stddev": self.stddev, "seed": self.seed}
+        base_config = super().get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
--