ref: 7df2c67be1a976cf10b7094b289180b1b5bb1c94
parent: 3499d0aac76d20ba14918cafb8020278154bf2e6
author: Jan Buethe <jbuethe@amazon.de>
date: Tue Jan 23 12:10:34 EST 2024
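Fixes in OSCE Python code: remove the residual branch from feature_transform in no_lace.py and always include tconv in flop_count in silk_feature_net_pl.py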
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -177,10 +177,7 @@
def feature_transform(self, f, layer):
f0 = f.permute(0, 2, 1)
f = F.pad(f0, [1, 0])
- if self.residual_in_feature_transform:
- f = torch.tanh(layer(f) + f0)
- else:
- f = torch.tanh(layer(f))
+ f = torch.tanh(layer(f))
return f.permute(0, 2, 1)
def forward(self, x, features, periods, numbits, debug=False):
--- a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -92,7 +92,7 @@
def flop_count(self, rate=200):
count = 0
- for conv in [self.conv1, self.conv2] if self.repeat_upsamp else [self.conv1, self.conv2, self.tconv]:
+ for conv in self.conv1, self.conv2, self.tconv:
count += _conv1d_flop_count(conv, rate)
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
--
⑨