shithub: opus

Download patch

ref: 82bec7d2a6a78aaad7e8cb5bf4056dc249a7d040
parent: b3ed2bb9cb8572fff947ca6ec9584f63afd8410e
author: Ralph Giles <giles@thaumas.net>
date: Thu Apr 17 11:16:09 EDT 2025

Remove trailing whitespace from the dnn torch modules

This is general best practice, but we also have a failing GitHub
Actions check complaining about these new files.

Signed-off-by: Jean-Marc Valin <jeanmarcv@google.com>

--- a/dnn/torch/osce/losses/td_lowpass.py
+++ b/dnn/torch/osce/losses/td_lowpass.py
@@ -7,28 +7,23 @@
 class TDLowpass(torch.nn.Module):
     def __init__(self, numtaps, cutoff, power=2):
         super().__init__()
-        
+
         self.b = scipy.signal.firwin(numtaps, cutoff)
         self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
         self.power = power
-        
+
     def forward(self, y_true, y_pred):
-        
+
         assert len(y_true.shape) == 3 and len(y_pred.shape) == 3
-        
+
         diff = y_true - y_pred
         diff_lp = torch.nn.functional.conv1d(diff, self.weight)
-        
+
         loss = torch.mean(torch.abs(diff_lp ** self.power))
-        
+
         return loss, diff_lp
-    
+
     def get_freqz(self):
         freq, response = scipy.signal.freqz(self.b)
-        
+
         return freq, response
-        
-        
-        
-        
-        
\ No newline at end of file
--- a/dnn/torch/osce/silk_16_to_48.py
+++ b/dnn/torch/osce/silk_16_to_48.py
@@ -12,17 +12,17 @@
 
 if __name__ == "__main__":
     args = parser.parse_args()
-    
+
     fs, x = wavfile.read(args.input)
 
     # being lazy for now
     assert fs == 16000 and x.dtype == np.int16
-    
+
     x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)
-    
+
     upsampler = SilkUpsampler()
     y = upsampler(x)
-    
+
     y = y.squeeze().numpy().astype(np.int16)
-    
-    wavfile.write(args.output, 48000, y[13:])
\ No newline at end of file
+
+    wavfile.write(args.output, 48000, y[13:])
--- a/dnn/torch/osce/utils/layers/fir.py
+++ b/dnn/torch/osce/utils/layers/fir.py
@@ -8,20 +8,20 @@
 class FIR(nn.Module):
     def __init__(self, numtaps, bands, desired, fs=2):
         super().__init__()
-        
+
         if numtaps % 2 == 0:
             print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
             numtaps += 1
-        
+
         a = scipy.signal.firls(numtaps, bands, desired, fs=fs)
-        
+
         self.weight = torch.from_numpy(a.astype(np.float32))
-        
+
     def forward(self, x):
         num_channels = x.size(1)
-        
+
         weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)
-        
+
         y = F.conv1d(x, weight, groups=num_channels)
-        
-        return y
\ No newline at end of file
+
+        return y
--