shithub: opus

ref: 08f5ff06bb4c769e48802e87160dd3c1ca230b1e
dir: /dnn/torch/osce/utils/layers/fir.py/

View raw version
import warnings

import numpy as np
import scipy.signal
import torch
from torch import nn
import torch.nn.functional as F


class FIR(nn.Module):
    """Fixed (non-trainable) per-channel FIR filter designed by least squares.

    The filter taps are computed once with ``scipy.signal.firls`` and applied
    independently to every channel of the input through a grouped 1-D
    convolution. No padding is applied, so the time axis shrinks by
    ``numtaps - 1`` samples.

    Args:
        numtaps: Number of filter taps. ``firls`` requires an odd value; an
            even value is increased by one (a warning is emitted).
        bands: Band edges forwarded to ``scipy.signal.firls``, expressed in
            the same units as ``fs``.
        desired: Desired gain at each band edge, forwarded to
            ``scipy.signal.firls``.
        fs: Sampling frequency used by ``firls`` to interpret ``bands``.
            The default of 2 makes 1.0 the Nyquist frequency.
    """

    def __init__(self, numtaps, bands, desired, fs=2):
        super().__init__()

        if numtaps % 2 == 0:
            warnings.warn(
                f"numtaps must be odd, increasing numtaps to {numtaps + 1}",
                stacklevel=2,
            )
            numtaps += 1

        taps = scipy.signal.firls(numtaps, bands, desired, fs=fs)

        # A buffer (unlike a plain tensor attribute) follows .to()/.cuda()/
        # .double() together with the module. persistent=False keeps it out
        # of state_dict, exactly as before, so old checkpoints still load.
        self.register_buffer(
            "weight", torch.from_numpy(taps.astype(np.float32)), persistent=False
        )

    def forward(self, x):
        """Filter every channel of ``x`` with the FIR taps.

        Args:
            x: Input signal laid out as (batch, channels, time); the channel
                count is read from dimension 1 and used as the conv1d group
                count.

        Returns:
            Tensor of shape (batch, channels, time - numtaps + 1).
        """
        num_channels = x.size(1)

        # Match the caller's device/dtype so filtering works no matter where
        # the module was placed or which precision the input uses.
        taps = self.weight.to(device=x.device, dtype=x.dtype)

        # One copy of the taps per channel -> (channels, 1, numtaps) for a
        # depthwise (groups=channels) convolution.
        weight = taps.view(1, 1, -1).repeat(num_channels, 1, 1)

        # conv1d computes cross-correlation; firls taps are symmetric
        # (linear phase), so this equals true convolution with the filter.
        return F.conv1d(x, weight, groups=num_channels)