ref: 0e5a38fac65fb06dddc75755b3d8b3ca01d6b524
parent: 26073861985df7d81688bf6d56bf1360e490ee5a
author: jbuethe <jbuethe@amazon.de>
date: Mon Nov 7 11:13:48 EST 2022
removed deprecated lambda from fec_encoder
--- a/dnn/training_tf2/fec_encoder.py
+++ b/dnn/training_tf2/fec_encoder.py
@@ -59,7 +59,7 @@
parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)')
parser.add_argument('weights', metavar='<weights>', help='trained model file (.h5)')
- parser.add_argument('enc_lambda', metavar='<lambda>', type=float, help='lambda for controlling encoder rate')
+# parser.add_argument('enc_lambda', metavar='<lambda>', type=float, help='lambda for controlling encoder rate')
parser.add_argument('output', type=str, help='output file (will be extended with .fec)')
parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
@@ -181,13 +181,13 @@
fake_lambda = np.ones((sym_batch.shape[0], sym_batch.shape[1], 1), dtype='float32')
rate_input = np.concatenate((sym_batch, hard_distr_embed, fake_lambda), axis=-1)
rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
-print(rates.shape)
+#print(rates.shape)
print("average rate = ", np.mean(rates[args.num_redundancy_frames:,:]))
#sym_batch.tofile('qsyms.f32')
sym_batch = sym_batch / quant_scale
-print(sym_batch.shape, quant_state.shape)
+#print(sym_batch.shape, quant_state.shape)
#features = decoder.predict([sym_batch, quant_state])
features = decoder([sym_batch, quant_state])
@@ -236,8 +236,10 @@
if args.debug_output:
import itertools
- batches = [2, 4]
- offsets = [0, 4, 20]
+ #batches = [2, 4]
+ batches = [4]
+ #offsets = [0, 4, 20]
+ offsets = [0, (args.num_redundancy_frames - 2)*2]
# sanity checks
# 1. concatenate features at offset 0
for batch, offset in itertools.product(batches, offsets):
@@ -249,6 +251,6 @@
test_features_full = np.zeros((test_features.shape[1], nb_features), dtype=np.float32)
test_features_full[:, :nb_used_features] = test_features[0, :, :]
- print(f"writing debug output {packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32'}")
- test_features_full.tofile(packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32')
+ print(f"writing debug output {packet_file[:-4] + f'_tf_batch{batch}_offset{offset}.f32'}")
+ test_features_full.tofile(packet_file[:-4] + f'_tf_batch{batch}_offset{offset}.f32')
--
⑨