ref: c5364153a859377dab94f184f8a822bc08ee5def
parent: ab9a09266f770e11eb780ebadcf01474661ee771
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Aug 4 10:02:59 EDT 2021
Add more training options
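
New options:

--retrain <input weights>: continue training from an existing model
  (mutually exclusive with --quantize)
--lr/--decay: override the learning rate and its decay (defaults remain
  0.001/2.5e-5, or 0.00003/0 when quantizing)
--gamma: adjust the u-law compensation used by the end-to-end
  interp_mulaw() loss (default 2.0, should not be less than 1.0)

Hypothetical invocation (all file names below are placeholders):

    python train_lpcnet.py features.f32 data.u8 model.h5 \
        --retrain prev_weights.h5 --lr 0.0001 --decay 0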
--- a/dnn/training_tf2/train_lpcnet.py
+++ b/dnn/training_tf2/train_lpcnet.py
@@ -35,7 +35,9 @@
parser.add_argument('data', metavar='<audio data file>', help='binary audio data file (uint8)')
parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
parser.add_argument('--model', metavar='<model>', default='lpcnet', help='LPCNet model python definition (without .py)')
-parser.add_argument('--quantize', metavar='<input weights>', help='quantize model')
+group1 = parser.add_mutually_exclusive_group()
+group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
+group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
parser.add_argument('--density', metavar='<global density>', type=float, help='average density of the recurrent weights (default 0.1)')
parser.add_argument('--density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each recurrent gate (default 0.05, 0.05, 0.2)')
parser.add_argument('--grub-density', metavar='<global GRU B density>', type=float, help='average density of the recurrent weights (default 1.0)')
@@ -45,7 +47,11 @@
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation)')
+parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
+parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
+parser.add_argument('--gamma', metavar='<gamma>', type=float, help='adjust u-law compensation (default 2.0, should not be less than 1.0)')
+
args = parser.parse_args()
density = (0.05, 0.05, 0.2)
@@ -60,6 +66,8 @@
elif args.grub_density is not None:
grub_density = [0.5*args.grub_density, 0.5*args.grub_density, 2.0*args.grub_density];
+gamma = 2.0 if args.gamma is None else args.gamma
+
import importlib
lpcnet = importlib.import_module(args.model)
@@ -87,14 +95,25 @@
batch_size = args.batch_size
quantize = args.quantize is not None
+retrain = args.retrain is not None
if quantize:
lr = 0.00003
decay = 0
+ input_model = args.quantize
else:
lr = 0.001
decay = 2.5e-5
+if args.lr is not None:
+ lr = args.lr
+
+if args.decay is not None:
+ decay = args.decay
+
+if retrain:
+ input_model = args.retrain
+
flag_e2e = args.flag_e2e
opt = Adam(lr, decay=decay, beta_2=0.99)
@@ -105,7 +124,7 @@
if not flag_e2e:
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
else:
- model.compile(optimizer=opt, loss = interp_mulaw(gamma = 2),metrics=[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss])
+    model.compile(optimizer=opt, loss=interp_mulaw(gamma=gamma), metrics=[metric_cel, metric_icel, metric_exc_sd, metric_oginterploss])
model.summary()
feature_file = args.features
@@ -146,9 +165,9 @@
# dump models to disk as we go
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.grua_size, '{epoch:02d}'))
-if quantize:
+if quantize or retrain:
     #Adapting from an existing model
-    model.load_weights(args.quantize)
+    model.load_weights(input_model)
sparsify = lpcnet.Sparsify(0, 0, 1, density)
grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
else:
--