ref: cd2e568bb63cc89b7235640b5714173926dda238
parent: 1db1946f77bed48cdaf6fb1c00611b27275e96ce
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Tue Feb 1 19:22:57 EST 2022
Use a packet-loss trace file instead of uniform random losses
--- a/dnn/training_tf2/plc_loader.py
+++ b/dnn/training_tf2/plc_loader.py
@@ -29,19 +29,25 @@
from tensorflow.keras.utils import Sequence
class PLCLoader(Sequence):
- def __init__(self, features, batch_size):
+ def __init__(self, features, lost, batch_size):
self.batch_size = batch_size
self.nb_batches = features.shape[0]//self.batch_size
- self.features = features[:self.nb_batches*self.batch_size, :]
+ self.features = features[:self.nb_batches*self.batch_size, :, :]
+ # lost: packet-loss trace, used as a flat 1-D sequence here (train_plc.py memory-maps it as int8).
+ # NOTE(review): astype() materializes a float64 copy of the entire trace in RAM (8x the
+ # int8 file size); converting only the per-batch slices would avoid this.
+ self.lost = lost.astype('float')
+ # Trim to a whole number of sequences (features.shape[1] samples each) and drop one more.
+ # on_epoch_end() discards another sequence's worth when it re-chunks at a random offset,
+ # so the raw trace must span at least 3 sequences for any chunk to remain.
+ self.lost = self.lost[:(len(self.lost)//features.shape[1]-1)*features.shape[1]]
self.on_epoch_end()
def on_epoch_end(self):
self.indices = np.arange(self.nb_batches*self.batch_size)
np.random.shuffle(self.indices)
+ # Re-chunk the loss trace into rows of one sequence length, starting at a random phase
+ # each epoch, then draw (with replacement) one row per training sequence.
+ seq_len = self.features.shape[1]
+ offset = np.random.randint(0, high=seq_len)
+ self.lost_offset = np.reshape(self.lost[offset:len(self.lost)-seq_len+offset], (-1, seq_len))
+ # If self.lost holds fewer than two sequences' worth of samples, no rows remain and
+ # randint(high=0) below would fail with an opaque error; report the real cause instead.
+ if self.lost_offset.shape[0] == 0: raise ValueError('packet loss trace is too short for sequence length {}'.format(seq_len))
+ self.lost_indices = np.random.randint(0, high=self.lost_offset.shape[0], size=self.nb_batches*self.batch_size)
def __getitem__(self, index):
features = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
- lost = (np.random.rand(features.shape[0], features.shape[1]) > .2).astype('float')
+ #lost = (np.random.rand(features.shape[0], features.shape[1]) > .2).astype('float')
+ lost = self.lost_offset[self.lost_indices[index*self.batch_size:(index+1)*self.batch_size], :]
lost = np.reshape(lost, (features.shape[0], features.shape[1], 1))
lost_mask = np.tile(lost, (1,1,features.shape[2]))
--- a/dnn/training_tf2/train_plc.py
+++ b/dnn/training_tf2/train_plc.py
@@ -34,6 +34,7 @@
parser = argparse.ArgumentParser(description='Train a PLC model')
parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
+parser.add_argument('lost_file', metavar='<packet loss file>', help='packet loss traces (int8)')
parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
parser.add_argument('--model', metavar='<model>', default='lpcnet_plc', help='PLC model python definition (without .py)')
group1 = parser.add_mutually_exclusive_group()
@@ -151,6 +152,7 @@
features = features[:, :, :nb_used_features]
+lost = np.memmap(args.lost_file, dtype='int8', mode='r')
# dump models to disk as we go
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.gru_size, '{epoch:02d}'))
@@ -164,7 +166,7 @@
model.save_weights('{}_{}_initial.h5'.format(args.output, args.gru_size))
-loader = PLCLoader(features, batch_size)
+loader = PLCLoader(features, lost, batch_size)
callbacks = [checkpoint]
if args.logdir is not None:
--
⑨