ref: 89db314efb08c9774c40d3b37440ae264e87a7ce
parent: 61459c24e0cedf8f36ed618db78d25fda7d8153b
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Tue Oct 4 11:46:13 EDT 2022
Updating fec_encoder.py for recent changes
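
new_rdovae_model() now returns a fourth value, the quantizer embedding
layer (qembedding), and the encoder takes only the features: quant_id
and enc_lambda are no longer model inputs. Quantization (scale and dead
zone) and the rate estimate are computed outside the model from the
embedding, the quantizer now varies with the redundancy delay, and the
decoder runs once on a batched symbol tensor instead of frame by frame.
Packet writing is temporarily commented out; the packets array is
rebuilt only for debugging.

The new delay-dependent quantizer schedule amounts to the following
(a minimal sketch; the num_redundancy_frames value is an assumed
example, not taken from the diff):

    import numpy as np

    q0, q1 = 2, 10
    num_redundancy_frames = 52            # assumed example value
    # id ramps linearly from q1 toward q0 across the window
    quant_id = np.round(
        q1 + (q0 - q1) * np.arange(num_redundancy_frames // 2)
        / num_redundancy_frames).astype('int16')
    print(quant_id)                       # [10 10 10 10 9 ... 6]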
--- a/dnn/training_tf2/fec_encoder.py
+++ b/dnn/training_tf2/fec_encoder.py
@@ -44,7 +44,7 @@
args = parser.parse_args()
-model, encoder, decoder = new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=1, cond_size=args.cond_size)
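+# new_rdovae_model() now also returns the quantizer embedding layer (qembedding)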
+model, encoder, decoder, qembedding = new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=1, cond_size=args.cond_size)
model.load_weights(args.weights)
lpc_order = 16
@@ -106,26 +106,30 @@
features = features[:, :, :nb_used_features]
features = features[:, :num_subframes, :]
-# lambda and q_id (ToDo: check validity of lambda and q_id)
-enc_lambda = args.enc_lambda * np.ones((1, num_frames, 1))
-quant_id = np.round(10*np.log(enc_lambda/.0007)).astype('int16')
+# use a variable quantizer depending on the redundancy delay
+q0 = 2
+q1 = 10
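+# ramp the quantizer id linearly from q1 toward q0 across the redundancy window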
+quant_id = np.round(q1 + (q0-q1)*np.arange(args.num_redundancy_frames//2)/args.num_redundancy_frames).astype('int16')
+#print(quant_id)
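+# look up the quantization parameters for each id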
+quant_embed = qembedding(quant_id)
# run encoder
print("running fec encoder...")
-symbols, quant_embed_dec, gru_state_dec = encoder.predict([features, quant_id, enc_lambda])
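+# the encoder now takes only the features; quantization parameters come from qembedding instead of encoder outputs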
+symbols, gru_state_dec = encoder.predict(features)
# apply quantization
nsymbols = 80
-dead_zone = tf.math.softplus(quant_embed_dec[:, :, nsymbols : 2 * nsymbols])
-symbols = apply_dead_zone([symbols, dead_zone]).numpy()
-qsymbols = np.round(symbols)
-quant_gru_state_dec = pvq_quantize(gru_state_dec, 30)
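+# slices of quant_embed: [0:n) -> quantization scale, [n:2n) -> dead zone (both softplus-activated)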
+quant_scale = tf.math.softplus(quant_embed[:, :nsymbols]).numpy()
+dead_zone = tf.math.softplus(quant_embed[:, nsymbols : 2 * nsymbols]).numpy()
+#symbols = apply_dead_zone([symbols, dead_zone]).numpy()
+#qsymbols = np.round(symbols)
+quant_gru_state_dec = pvq_quantize(gru_state_dec, 82)
# rate estimate
-hard_distr_embed = tf.math.sigmoid(quant_embed_dec[:, :, 4 * nsymbols : ]).numpy()
-rate_input = np.concatenate((qsymbols, hard_distr_embed, enc_lambda), axis=-1)
-rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
+hard_distr_embed = tf.math.sigmoid(quant_embed[:, 4 * nsymbols : ]).numpy()
+#rate_input = np.concatenate((qsymbols, hard_distr_embed, enc_lambda), axis=-1)
+#rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
# run decoder
input_length = args.num_redundancy_frames // 2
@@ -134,22 +138,48 @@
packets = []
packet_sizes = []
+sym_batch = np.zeros((num_frames-offset, args.num_redundancy_frames//2, nsymbols), dtype='float32')
+quant_state = quant_gru_state_dec[0, offset:num_frames, :]
+# pack symbols for batch processing
for i in range(offset, num_frames):
- print(f"processing frame {i - offset}...")
- features = decoder.predict([qsymbols[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_embed_dec[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_gru_state_dec[:, i, :]])
- packets.append(features)
- packet_size = 8 * int((np.sum(rates[:, i - 2 * input_length + 2 : i + 1 : 2]) + 7) / 8) + 64
- packet_sizes.append(packet_size)
+ sym_batch[i-offset, :, :] = symbols[0, i - 2 * input_length + 2 : i + 1 : 2, :]
+# quantize the symbols: scale, apply the dead zone, then round
+sym_batch = sym_batch * quant_scale
+sym_batch = apply_dead_zone([sym_batch, dead_zone]).numpy()
+sym_batch = np.round(sym_batch)
+hard_distr_embed = np.broadcast_to(hard_distr_embed, (sym_batch.shape[0], sym_batch.shape[1], 2*sym_batch.shape[2]))
+fake_lambda = np.ones((sym_batch.shape[0], sym_batch.shape[1], 1), dtype='float32')
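+# sq_rate_metric() expects a trailing lambda column in its input; feed a dummy value of 1 here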
+rate_input = np.concatenate((sym_batch, hard_distr_embed, fake_lambda), axis=-1)
+rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
+print("rate = ", np.mean(rates))
+
+sym_batch = sym_batch / quant_scale
+print(sym_batch.shape, quant_state.shape)
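+# decode all frames in a single batched call instead of the per-frame loop (kept below for reference)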
+#features = decoder.predict([sym_batch, quant_state])
+features = decoder([sym_batch, quant_state])
+
+#for i in range(offset, num_frames):
+# print(f"processing frame {i - offset}...")
+# features = decoder.predict([qsymbols[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_embed_dec[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_gru_state_dec[:, i, :]])
+# packets.append(features)
+# packet_size = 8 * int((np.sum(rates[:, i - 2 * input_length + 2 : i + 1 : 2]) + 7) / 8) + 64
+# packet_sizes.append(packet_size)
+
+
# write packets
packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output
-write_fec_packets(packet_file, packets, packet_sizes)
+#write_fec_packets(packet_file, packets, packet_sizes)
-print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")
+#print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")
+# recreate the packets array like in the original version, for debugging purposes
+for i in range(offset, num_frames):
+ packets.append(features[i-offset:i-offset+1, :, :])
+
if args.debug_output:
import itertools
@@ -160,6 +190,7 @@
for batch, offset in itertools.product(batches, offsets):
stop = packets[0].shape[1] - offset
+ print(batch, offset, stop)
test_features = np.concatenate([packet[:,stop - batch: stop, :] for packet in packets[::batch//2]], axis=1)
test_features_full = np.zeros((test_features.shape[1], nb_features), dtype=np.float32)
--