shithub: opus

ref: e7beaec3fb49df389b077799c5d1778ccb68610e
parent: b24c7b433ae9db990dbd52eb0f1b357568fb484c
author: Jan Buethe <jbuethe@amazon.de>
date: Wed Sep 13 12:31:29 EDT 2023

integrated JM's FFT-based adaptive convolution (ada_conv)

Signed-off-by: Jan Buethe <jbuethe@amazon.de>

--- /dev/null
+++ b/dnn/torch/osce/utils/ada_conv.py
@@ -1,0 +1,77 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jean-Marc Valin */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+# x is (batch, nb_in_channels, nb_frames*frame_size)
+# kernels is (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
+def adaconv_kernel(x, kernels, half_window, fft_size=256):
+    device = x.device
+    overlap_size = half_window.size(-1)
+    nb_frames = kernels.size(3)
+    nb_batches = kernels.size(0)
+    nb_out_channels = kernels.size(1)
+    nb_in_channels = kernels.size(2)
+    kernel_size = kernels.size(-1)
+    x = x.reshape(nb_batches, 1, nb_in_channels, nb_frames, -1)
+    frame_size = x.size(-1)
+    # build window: [zeros, rising window, ones, falling window, zeros]
+    window = torch.cat(
+        [
+            torch.zeros(frame_size, device=device),
+            half_window,
+            torch.ones(frame_size - overlap_size, device=device),
+            1 - half_window,
+            torch.zeros(fft_size - 2 * frame_size - overlap_size, device=device)
+        ])
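+    # pad each frame with its neighbours: the previous frame on the left, the
+    # first overlap_size samples of the next frame on the right, then zeros up to fft_size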
+    x_prev = torch.cat([torch.zeros_like(x[:, :, :, :1, :]), x[:, :, :, :-1, :]], dim=-2)
+    x_next = torch.cat([x[:, :, :, 1:, :overlap_size], torch.zeros_like(x[:, :, :, -1:, :overlap_size])], dim=-2)
+    x_padded = torch.cat([x_prev, x, x_next, torch.zeros(nb_batches, 1, nb_in_channels, nb_frames, fft_size - 2 * frame_size - overlap_size, device=device)], -1)
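+    # time-reverse the kernels and zero-pad to fft_size; the flip makes the
+    # FFT-domain product match torch.conv1d's cross-correlation convention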
+    k_padded = torch.cat([torch.flip(kernels, [-1]), torch.zeros(nb_batches, nb_out_channels, nb_in_channels, nb_frames, fft_size-kernel_size, device=device)], dim=-1)
+
+    # compute convolution
+    X = torch.fft.rfft(x_padded, dim=-1)
+    K = torch.fft.rfft(k_padded, dim=-1)
+
+    out = torch.fft.irfft(X * K, dim=-1)
+    # sum over input channels
+    out = torch.sum(out, dim=2)
+    # apply the cross-fading
+    out = window.reshape(1, 1, 1, -1) * out
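+    # overlap-add: keep the windowed current frame and add the faded-out tail
+    # of the previous frame, shifted forward by one frame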
+    crossfaded = out[:, :, :, frame_size:2*frame_size] + torch.cat([torch.zeros(nb_batches, nb_out_channels, 1, frame_size, device=device), out[:, :, :-1, 2*frame_size:3*frame_size]], dim=-2)
+
+    return crossfaded.reshape(nb_batches, nb_out_channels, -1)
\ No newline at end of file
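
adaconv_kernel expects x shaped (batch, nb_in_channels, nb_frames*frame_size) and
kernels shaped (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs), and the
window layout requires fft_size >= 2*frame_size + overlap_size. A minimal shape-check
sketch (the values, the linear fade, and the import path are illustrative assumptions,
not part of the patch):

    import torch
    from ada_conv import adaconv_kernel  # assumes dnn/torch/osce/utils is on the path

    batch, c_in, c_out = 1, 1, 2
    nb_frames, frame_size = 10, 80
    kernel_size, overlap_size = 16, 16   # fft_size=256 >= 2*80 + 16

    x = torch.randn(batch, c_in, nb_frames * frame_size)
    kernels = torch.randn(batch, c_out, c_in, nb_frames, kernel_size)
    half_window = torch.linspace(0., 1., overlap_size)  # any 0 -> 1 fade works here

    y = adaconv_kernel(x, kernels, half_window, fft_size=256)
    print(y.shape)  # torch.Size([1, 2, 800]): (batch, nb_out_channels, nb_frames*frame_size)
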
--- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
@@ -33,6 +33,9 @@
 
 from utils.endoscopy import write_data
 
+from utils.ada_conv import adaconv_kernel
+
+
 class LimitedAdaptiveConv1d(nn.Module):
     COUNTER = 1
 
@@ -184,39 +187,21 @@
             conv_biases  = self.conv_bias(features).permute(0, 2, 1)
 
         # calculate gains
-        conv_gains   = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
+        conv_gains   = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features)) + self.filter_gain_b)
         if debug and batch_size == 1:
             key = self.name + "_gains"
-            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            write_data(key, conv_gains.permute(0, 2, 1).detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
             key = self.name + "_kernels"
             write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
 
 
-        # frame-wise convolution with overlap-add
-        output_frames = []
-        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
-        x = F.pad(x, self.padding)
-        x = F.pad(x, [0, self.overlap_size])
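+        # fold the per-frame gains into the kernels so the FFT path applies them implicitly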
+        conv_kernels = conv_kernels * conv_gains.view(batch_size, num_frames, self.out_channels, 1, 1)
 
-        for i in range(num_frames):
-            xx = x[:, :, i * frame_size : (i + 1) * frame_size + kernel_size - 1 + overlap_size].reshape((1, batch_size * self.in_channels, -1))
-            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
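+        # reorder to (batch, out_channels, in_channels, frames, coeffs) as expected by adaconv_kernel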
+        conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4)
 
-            if self.use_bias:
-                new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
+        output = adaconv_kernel(x, conv_kernels, win1, fft_size=256)
 
-            new_chunk = new_chunk * conv_gains[:, :, i : i + 1]
-
-            # overlapping part
-            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
-
-            # non-overlapping part
-            output_frames.append(new_chunk[:, :, overlap_size : frame_size])
-
-            # mem for next frame
-            overlap_mem = new_chunk[:, :, frame_size :]
-
-        # concatenate chunks
-        output = torch.cat(output_frames, dim=-1)
 
         return output
\ No newline at end of file
--
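
With a single frame and a half-window of ones, the fade-in is disabled and frame 0 has
no predecessor tail to add, so the FFT path reduces to the causal cross-correlation the
removed torch.conv1d loop computed per frame. A quick consistency sketch, under the same
assumed import path as above:

    import torch
    import torch.nn.functional as F
    from ada_conv import adaconv_kernel

    torch.manual_seed(0)
    frame_size, kernel_size, overlap_size = 80, 16, 16

    x = torch.randn(1, 1, frame_size)           # one frame, one in/out channel
    k = torch.randn(1, 1, 1, 1, kernel_size)    # (batch, out, in, frames, coeffs)

    y_fft = adaconv_kernel(x, k, torch.ones(overlap_size), fft_size=256)

    # reference: causal cross-correlation, as in the removed per-frame loop
    y_ref = F.conv1d(F.pad(x, [kernel_size - 1, 0]), k.view(1, 1, kernel_size))

    print(torch.allclose(y_fft, y_ref, atol=1e-4))  # expect True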