ref: 82bec7d2a6a78aaad7e8cb5bf4056dc249a7d040
parent: b3ed2bb9cb8572fff947ca6ec9584f63afd8410e
author: Ralph Giles <giles@thaumas.net>
date: Thu Apr 17 11:16:09 EDT 2025
Remove trailing whitespace from the dnn torch modules

This is general best practice, but we also have a failing GitHub action
complaining about these new files.

Signed-off-by: Jean-Marc Valin <jeanmarcv@google.com>
--- a/dnn/torch/osce/losses/td_lowpass.py
+++ b/dnn/torch/osce/losses/td_lowpass.py
@@ -7,28 +7,23 @@
class TDLowpass(torch.nn.Module):
def __init__(self, numtaps, cutoff, power=2):
super().__init__()
-
+
self.b = scipy.signal.firwin(numtaps, cutoff)
self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
self.power = power
-
+
def forward(self, y_true, y_pred):
-
+
assert len(y_true.shape) == 3 and len(y_pred.shape) == 3
-
+
diff = y_true - y_pred
diff_lp = torch.nn.functional.conv1d(diff, self.weight)
-
+
loss = torch.mean(torch.abs(diff_lp ** self.power))
-
+
return loss, diff_lp
-
+
def get_freqz(self):
freq, response = scipy.signal.freqz(self.b)
-
+
return freq, response
-
-
-
-
-
\ No newline at end of file
--- a/dnn/torch/osce/silk_16_to_48.py
+++ b/dnn/torch/osce/silk_16_to_48.py
@@ -12,17 +12,17 @@
if __name__ == "__main__":
args = parser.parse_args()
-
+
fs, x = wavfile.read(args.input)
# being lazy for now
assert fs == 16000 and x.dtype == np.int16
-
+
x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)
-
+
upsampler = SilkUpsampler()
y = upsampler(x)
-
+
y = y.squeeze().numpy().astype(np.int16)
-
- wavfile.write(args.output, 48000, y[13:])
\ No newline at end of file
+
+ wavfile.write(args.output, 48000, y[13:])
--- a/dnn/torch/osce/utils/layers/fir.py
+++ b/dnn/torch/osce/utils/layers/fir.py
@@ -8,20 +8,20 @@
class FIR(nn.Module):
def __init__(self, numtaps, bands, desired, fs=2):
super().__init__()
-
+
if numtaps % 2 == 0:
print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
numtaps += 1
-
+
a = scipy.signal.firls(numtaps, bands, desired, fs=fs)
-
+
self.weight = torch.from_numpy(a.astype(np.float32))
-
+
def forward(self, x):
num_channels = x.size(1)
-
+
weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)
-
+
y = F.conv1d(x, weight, groups=num_channels)
-
- return y
\ No newline at end of file
+
+ return y
--