shithub: opus


ref: 5a51e2eed1166b7435a07f324b10a823721e6752
parent: 1edf5d7986ed13f933062423907628ab0a2cf9e8
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Mon Jul 12 23:09:04 EDT 2021

Adding command-line options to training script
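
The features file, audio data and output name are now positional
arguments, with options for the model definition, quantization,
sparsification density, GRU A size, epoch count and batch size.
For illustration (file names here are hypothetical), training from
scratch might be invoked as:

    python train_lpcnet.py features.f32 data.u8 mymodel --grua-size 384 --epochs 120

and quantizing previously trained weights as:

    python train_lpcnet.py features.f32 data.u8 mymodel_q --quantize mymodel_384_120.h5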

--- a/dnn/training_tf2/dump_lpcnet.py
+++ b/dnn/training_tf2/dump_lpcnet.py
@@ -229,12 +229,15 @@
     return False
 Embedding.dump_layer = dump_embedding_layer
 
+filename = sys.argv[1]
+with h5py.File(filename, "r") as f:
+    units = min(f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape)
 
-model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=384)
+model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units)
 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
 #model.summary()
 
-model.load_weights(sys.argv[1])
+model.load_weights(filename)
 
 if len(sys.argv) > 2:
     cfile = sys.argv[2];
--- a/dnn/training_tf2/train_lpcnet.py
+++ b/dnn/training_tf2/train_lpcnet.py
@@ -25,9 +25,35 @@
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 '''
 
-# Train a LPCNet model (note not a Wavenet model)
+# Train an LPCNet model
 
-import lpcnet
+import argparse
+
+parser = argparse.ArgumentParser(description='Train an LPCNet model')
+
+parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
+parser.add_argument('data', metavar='<audio data file>', help='binary audio data file (uint8)')
+parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
+parser.add_argument('--model', metavar='<model>', default='lpcnet', help='LPCNet model python definition (without .py)')
+parser.add_argument('--quantize', metavar='<input weights>', help='quantize model')
+parser.add_argument('--density', metavar='<global density>', type=float, help='average density of the recurrent weights (default 0.1)')
+parser.add_argument('--density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each recurrent gate (default 0.05, 0.05, 0.2)')
+parser.add_argument('--grua-size', metavar='<units>', default=384, type=int, help='number of units in GRU A (default 384)')
+parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
+parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
+
+
+args = parser.parse_args()
+
+density = (0.05, 0.05, 0.2)
+if args.density_split is not None:
+    density = args.density_split
+elif args.density is not None:
+    density = [0.5*args.density, 0.5*args.density, 2.0*args.density];
+
+import importlib
+lpcnet = importlib.import_module(args.model)
+
 import sys
 import numpy as np
 from tensorflow.keras.optimizers import Adam
@@ -44,16 +70,15 @@
 #  except RuntimeError as e:
 #    print(e)
 
-nb_epochs = 120
+nb_epochs = args.epochs
 
 # Try reducing batch_size if you run out of memory on your GPU
-batch_size = 128
+batch_size = args.batch_size
 
-#Set this to True to adapt an existing model (e.g. on new data)
-adaptation = False
+quantize = args.quantize is not None
 
-if adaptation:
-    lr = 0.0001
+if quantize:
+    lr = 0.00003
     decay = 0
 else:
     lr = 0.001
@@ -63,12 +88,12 @@
 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
 with strategy.scope():
-    model, _, _ = lpcnet.new_lpcnet_model(training=True)
-    model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
+    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, training=True, quantize=quantize)
+    model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
     model.summary()
 
-feature_file = sys.argv[1]
-pcm_file = sys.argv[2]     # 16 bit unsigned short PCM samples
+feature_file = args.features
+pcm_file = args.data     # 16 bit unsigned short PCM samples
 frame_size = model.frame_size
 nb_features = 55
 nb_used_features = model.nb_used_features
@@ -115,15 +140,15 @@
 del in_exc
 
 # dump models to disk as we go
-checkpoint = ModelCheckpoint('lpcnet33e_384_{epoch:02d}.h5')
+checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.grua_size, '{epoch:02d}'))
 
-if adaptation:
+if quantize:
     #Adapting from an existing model
-    model.load_weights('lpcnet33a_384_100.h5')
-    sparsify = lpcnet.Sparsify(0, 0, 1, (0.05, 0.05, 0.2))
+    model.load_weights(args.quantize)
+    sparsify = lpcnet.Sparsify(0, 0, 1, density)
 else:
     #Training from scratch
-    sparsify = lpcnet.Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
+    sparsify = lpcnet.Sparsify(2000, 40000, 400, density)
 
-model.save_weights('lpcnet33e_384_00.h5');
+model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
 model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify])
--
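
A note on the dump_lpcnet.py change above: a GRU's recurrent kernel has
shape (units, 3*units), so taking the minimum of its two dimensions
recovers the unit count without hard-coding it. A minimal sketch under
that assumption, following the same HDF5 path as the patch (the file
name is hypothetical):

    import h5py

    # Recover GRU A's size from saved weights: the recurrent kernel is
    # (units, 3*units), so min() of the shape gives the number of units.
    # The group path mirrors the patch; the file name is hypothetical.
    with h5py.File('mymodel_384_120.h5', 'r') as f:
        shape = f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape
        units = min(shape)   # e.g. min(384, 1152) -> 384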
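
Likewise, the --density option above spreads a single global value
across the three recurrent gates as 0.5x / 0.5x / 2x, so it averages
out to the requested density; a small worked sketch (0.1 is just an
illustrative value):

    # Splitting a global density across the update/reset/state gates,
    # mirroring the patch's logic; 0.1 is an arbitrary illustrative value.
    global_density = 0.1
    density = [0.5 * global_density, 0.5 * global_density, 2.0 * global_density]
    # -> [0.05, 0.05, 0.2], the script's defaults; the mean of the three
    #    values comes back to the requested global density.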