shithub: opus

Download patch

ref: 7df2c67be1a976cf10b7094b289180b1b5bb1c94
parent: 3499d0aac76d20ba14918cafb8020278154bf2e6
author: Jan Buethe <jbuethe@amazon.de>
date: Tue Jan 23 12:10:34 EST 2024

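Fixes in OSCE Python code: remove the residual_in_feature_transform branch from feature_transform (no_lace.py) and always include tconv in flop_count instead of consulting repeat_upsamp (silk_feature_net_pl.py)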

--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -177,10 +177,7 @@
     def feature_transform(self, f, layer):
         f0 = f.permute(0, 2, 1)
         f = F.pad(f0, [1, 0])
-        if self.residual_in_feature_transform:
-            f = torch.tanh(layer(f) + f0)
-        else:
-            f = torch.tanh(layer(f))
+        f = torch.tanh(layer(f))
         return f.permute(0, 2, 1)
 
     def forward(self, x, features, periods, numbits, debug=False):
--- a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -92,7 +92,7 @@
 
     def flop_count(self, rate=200):
         count = 0
-        for conv in [self.conv1, self.conv2] if self.repeat_upsamp else [self.conv1, self.conv2, self.tconv]:
+        for conv in self.conv1, self.conv2, self.tconv:
             count += _conv1d_flop_count(conv, rate)
 
         count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
--