shithub: opus

ref: 1b13f6313e8413056f6d9f1f15fa994d0dff7a57
parent: 4f4b6242099998d7acf89e17c287dc7f605af607
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Aug 30 14:36:09 EDT 2023

FARGAN initial commit in Opus

Copied/adapted from LPCNet repo

--- /dev/null
+++ b/dnn/torch/fargan/dataset.py
@@ -1,0 +1,52 @@
+import torch
+import numpy as np
+
+class FARGANDataset(torch.utils.data.Dataset):
+    def __init__(self,
+                feature_file,
+                signal_file,
+                frame_size=160,
+                sequence_length=15,
+                lookahead=1,
+                nb_used_features=20,
+                nb_features=36):
+
+        self.frame_size = frame_size
+        self.sequence_length = sequence_length
+        self.lookahead = lookahead
+        self.nb_features = nb_features
+        self.nb_used_features = nb_used_features
+        pcm_chunk_size = self.frame_size*self.sequence_length
+
+        self.data = np.memmap(signal_file, dtype='int16', mode='r')
+        #self.data = self.data[1::2]
+        self.nb_sequences = len(self.data)//(pcm_chunk_size)-1
+        self.data = self.data[(4-self.lookahead)*self.frame_size:]
+        self.data = self.data[:self.nb_sequences*pcm_chunk_size]
+
+
+        self.data = np.reshape(self.data, (self.nb_sequences, pcm_chunk_size))
+
+        self.features = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features))
+        sizeof = self.features.strides[-1]
+        # Overlapping view: each sequence gets its sequence_length feature
+        # frames plus 4 frames of context (3 past, 1 lookahead by default).
+        self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length+4, nb_features),
+                                           strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
+        # map the pitch feature back to an integer period in samples; these
+        # values are later used as embedding indices
+        self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
+
+        self.lpc = self.features[:, :, self.nb_used_features:]
+        self.features = self.features[:, :, :self.nb_used_features]
+        print("lpc_size:", self.lpc.shape)
+
+    def __len__(self):
+        return self.nb_sequences
+
+    def __getitem__(self, index):
+        features = self.features[index, :, :].copy()
+        if self.lookahead != 0:
+            lpc = self.lpc[index, 4-self.lookahead:-self.lookahead, :].copy()
+        else:
+            lpc = self.lpc[index, 4:, :].copy()
+        data = self.data[index, :].copy().astype(np.float32) / 2**15
+        periods = self.periods[index, :].copy()
+
+        return features, periods, data, lpc
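+
+# Hypothetical usage sketch (illustration only; the file names are
+# placeholders for the .f32 feature file and the 16-bit PCM signal file):
+if __name__ == '__main__':
+    dataset = FARGANDataset('features.f32', 'speech.s16', sequence_length=15)
+    features, periods, data, lpc = dataset[0]
+    # With the defaults (sequence_length=15, lookahead=1): features is (19, 20),
+    # periods is (19,), data is (2400,) scaled to [-1, 1), and lpc is (15, 16).
+    print(features.shape, periods.shape, data.shape, lpc.shape)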
--- /dev/null
+++ b/dnn/torch/fargan/fargan.py
@@ -1,0 +1,260 @@
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import filters
+from torch.nn.utils import weight_norm
+
+Fs = 16000
+
+fid_dict = {}
+def dump_signal(x, filename):
+    # Debugging helper, disabled by default: the early return below
+    # short-circuits the dump; remove it to write signals to disk.
+    return
+    if filename in fid_dict:
+        fid = fid_dict[filename]
+    else:
+        fid = open(filename, "wb")
+        fid_dict[filename] = fid
+    x = x.detach().numpy().astype('float32')
+    x.tofile(fid)
+
+
+def sig_l1(y_true, y_pred):
+    return torch.mean(abs(y_true-y_pred))/torch.mean(abs(y_true))
+
+def sig_loss(y_true, y_pred):
+    t = y_true/(1e-15+torch.norm(y_true, dim=-1, p=2, keepdim=True))
+    p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True))
+    return torch.mean(1.-torch.sum(p*t, dim=-1))
+
+
+def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
+    device = x.device
+    batch_size = lpc.size(0)
+
+    nb_frames = lpc.shape[1]
+
+
+    sig = torch.zeros(batch_size, subframe_size+16, device=device)
+    x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
+    out = torch.zeros((batch_size, 0), device=device)
+
+    if gamma is not None:
+        bw = gamma**(torch.arange(1, 17, device=device))
+        lpc = lpc*bw[None,None,:]
+    ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
+    zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
+    a = torch.cat([ones, lpc], -1)
+    a_big = torch.cat([a, zeros], -1)
+    fir_mat_big = filters.toeplitz_from_filter(a_big)
+
+    #print(a_big[:,0,:])
+    for n in range(nb_frames):
+        for k in range(nb_subframes):
+
+            sig = torch.cat([sig[:,subframe_size:], x[:,n*nb_subframes + k, :]], 1)
+            exc = torch.bmm(fir_mat_big[:,n,:,:], sig[:,:,None])
+            out = torch.cat([out, exc[:,-subframe_size:,0]], 1)
+
+    return out
+
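+# Illustrative note (added for clarity; values are hypothetical): the gamma
+# step above is standard bandwidth expansion, a_k -> gamma**k * a_k, so the
+# FIR filter becomes A(z/gamma) and the target is whitened in the perceptually
+# weighted domain. For a 2nd-order example:
+#
+#   a = torch.tensor([-1.8, 0.81])
+#   a * 0.9**torch.arange(1, 3)    # -> tensor([-1.6200, 0.6561])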
+
+# weight initialization and clipping
+def init_weights(module):
+    if isinstance(module, nn.GRU):
+        for p in module.named_parameters():
+            if p[0].startswith('weight_hh_'):
+                nn.init.orthogonal_(p[1])
+
+def gen_phase_embedding(periods, frame_size):
+    device = periods.device
+    batch_size = periods.size(0)
+    nb_frames = periods.size(1)
+    w0 = 2*torch.pi/periods
+    w0_shift = torch.cat([2*torch.pi*torch.rand((batch_size, 1), device=device)/frame_size, w0[:,:-1]], 1)
+    cum_phase = frame_size*torch.cumsum(w0_shift, 1)
+    fine_phase = w0[:,:,None]*torch.broadcast_to(torch.arange(frame_size, device=device), (batch_size, nb_frames, frame_size))
+    embed = torch.unsqueeze(cum_phase, 2) + fine_phase
+    embed = torch.reshape(embed, (batch_size, -1))
+    return torch.cos(embed), torch.sin(embed)
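+
+# gen_phase_embedding returns the cos/sin pair of the cumulative pitch phase:
+# w0 is the per-frame angular frequency, the random shift keeps sequences from
+# all starting at the same phase, and the cumsum keeps the phase continuous
+# across frame boundaries.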
+
+class GLU(nn.Module):
+    def __init__(self, feat_size):
+        super(GLU, self).__init__()
+
+        # Note: reseeds the global RNG so the gate init is deterministic.
+        torch.manual_seed(5)
+
+        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
+
+        self.init_weights()
+
+    def init_weights(self):
+    
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x):
+        
+        out = x * torch.sigmoid(self.gate(x)) 
+        
+        return out
+
+
+class FARGANCond(nn.Module):
+    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
+        super(FARGANCond, self).__init__()
+
+        self.feature_dim = feature_dim
+        self.cond_size = cond_size
+
+        self.pembed = nn.Embedding(256, pembed_dims)
+        self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False)
+        self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
+        self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
+        self.fdense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
+
+        self.apply(init_weights)
+
+    def forward(self, features, period):
+        p = self.pembed(period)
+        features = torch.cat((features, p), -1)
+        tmp = torch.tanh(self.fdense1(features))
+        tmp = tmp.permute(0, 2, 1)
+        tmp = torch.tanh(self.fconv1(tmp))
+        tmp = torch.tanh(self.fconv2(tmp))
+        tmp = tmp.permute(0, 2, 1)
+        tmp = torch.tanh(self.fdense2(tmp))
+        return tmp
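+
+# Note: the two kernel-3 "valid" convolutions above trim two frames from each
+# end (four in total), which is why the dataset supplies sequence_length+4
+# feature frames and FARGAN.forward reads the pitch as period[:, 3+n].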
+
+class FARGANSub(nn.Module):
+    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, passthrough_size=0, has_gain=False):
+        super(FARGANSub, self).__init__()
+
+        self.subframe_size = subframe_size
+        self.nb_subframes = nb_subframes
+        self.cond_size = cond_size
+        self.has_gain = has_gain
+        self.passthrough_size = passthrough_size
+        
+        print("has_gain:", self.has_gain)
+        print("passthrough_size:", self.passthrough_size)
+        
+        gain_param = 1 if self.has_gain else 0
+
+        self.sig_dense1 = nn.Linear(3*self.subframe_size+self.passthrough_size+self.cond_size+gain_param, self.cond_size, bias=False)
+        self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
+        self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
+        self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
+        self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
+        
+        self.dense1_glu = GLU(self.cond_size)
+        self.dense2_glu = GLU(self.cond_size)
+        self.gru1_glu = GLU(self.cond_size)
+        self.gru2_glu = GLU(self.cond_size)
+        self.gru3_glu = GLU(self.cond_size)
+        
+        self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size+self.passthrough_size, bias=False)
+        if self.has_gain:
+            self.gain_dense_out = nn.Linear(self.cond_size, 1)
+
+
+        self.apply(init_weights)
+
+    def forward(self, cond, prev, exc_mem, phase, period, states):
+        device = exc_mem.device
+        #print(cond.shape, prev.shape)
+        
+        dump_signal(prev, 'prev_in.f32')
+
+        idx = 256-torch.maximum(torch.tensor(self.subframe_size, device=device), period[:,None])
+        rng = torch.arange(self.subframe_size, device=device)
+        idx = idx + rng[None,:]
+        prev = torch.gather(exc_mem, 1, idx)
+        #prev = prev*0
+        dump_signal(prev, 'pitch_exc.f32')
+        dump_signal(exc_mem, 'exc_mem.f32')
+        if self.has_gain:
+            gain = torch.norm(prev, dim=1, p=2, keepdim=True)
+            prev = prev/(1e-5+gain)
+            prev = torch.cat([prev, torch.log(1e-5+gain)], 1)
+
+        passthrough = states[3]
+        tmp = torch.cat((cond, prev, passthrough, phase), 1)
+
+        tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
+        tmp = self.dense2_glu(torch.tanh(self.sig_dense2(tmp)))
+        gru1_state = self.gru1(tmp, states[0])
+        gru2_state = self.gru2(self.gru1_glu(gru1_state), states[1])
+        gru3_state = self.gru3(self.gru2_glu(gru2_state), states[2])
+        gru3_out = self.gru3_glu(gru3_state)
+        sig_out = torch.tanh(self.sig_dense_out(gru3_out))
+        if self.passthrough_size != 0:
+            passthrough = sig_out[:,self.subframe_size:]
+            sig_out = sig_out[:,:self.subframe_size]
+        if self.has_gain:
+            out_gain = torch.exp(self.gain_dense_out(gru3_out))
+            sig_out = sig_out * out_gain
+        dump_signal(sig_out, 'exc_out.f32')
+        exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
+        dump_signal(sig_out, 'sig_out.f32')
+        return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
+
+class FARGAN(nn.Module):
+    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
+        super(FARGAN, self).__init__()
+
+        self.subframe_size = subframe_size
+        self.nb_subframes = nb_subframes
+        self.frame_size = self.subframe_size*self.nb_subframes
+        self.feature_dim = feature_dim
+        self.cond_size = cond_size
+        self.has_gain = has_gain
+        self.passthrough_size = passthrough_size
+
+        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
+        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, has_gain=has_gain, passthrough_size=passthrough_size)
+
+    def forward(self, features, period, nb_frames, pre=None, states=None):
+        device = features.device
+        batch_size = features.size(0)
+
+        phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size)
+        #np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw')
+
+        prev = torch.zeros(batch_size, self.subframe_size, device=device)
+        exc_mem = torch.zeros(batch_size, 256, device=device)
+        nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
+
+        if states is None:
+            states = (
+                torch.zeros(batch_size, self.cond_size, device=device),
+                torch.zeros(batch_size, self.cond_size, device=device),
+                torch.zeros(batch_size, self.cond_size, device=device),
+                torch.zeros(batch_size, self.passthrough_size, device=device)
+            )
+
+        sig = torch.zeros((batch_size, 0), device=device)
+        cond = self.cond_net(features, period)
+        passthrough = torch.zeros(batch_size, self.passthrough_size, device=device)
+        for n in range(nb_frames+nb_pre_frames):
+            for k in range(self.nb_subframes):
+                pos = n*self.frame_size + k*self.subframe_size
+                preal = phase_real[:, pos:pos+self.subframe_size]
+                pimag = phase_imag[:, pos:pos+self.subframe_size]
+                phase = torch.cat([preal, pimag], 1)
+                #print("now: ", preal.shape, prev.shape, sig_in.shape)
+                pitch = period[:, 3+n]
+                out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states)
+
+                if n < nb_pre_frames:
+                    out = pre[:, pos:pos+self.subframe_size]
+                    exc_mem[:,-self.subframe_size:] = out
+                else:
+                    sig = torch.cat([sig, out], 1)
+
+                prev = out
+        states = [s.detach() for s in states]
+        return sig, states
+
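+# Hypothetical smoke test (illustration only, not used by training): checks
+# output shapes with zero features and a constant pitch index.
+if __name__ == '__main__':
+    model = FARGAN(has_gain=True)
+    feat = torch.zeros(1, 20, 20)                     # (batch, frames, features)
+    per = torch.full((1, 20), 100, dtype=torch.long)  # pitch embedding indices
+    sig, states = model(feat, per, nb_frames=16)      # cond net trims 4 frames
+    print(sig.shape)                                  # torch.Size([1, 2560])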
--- /dev/null
+++ b/dnn/torch/fargan/filters.py
@@ -1,0 +1,46 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import math
+
+def toeplitz_from_filter(a):
+    device = a.device
+    L = a.size(-1)
+    size0 = (*(a.shape[:-1]), L, L+1)
+    size = (*(a.shape[:-1]), L, L)
+    rnge = torch.arange(0, L, dtype=torch.int64, device=device)
+    z = torch.tensor(0, device=device)
+    idx = torch.maximum(rnge[:,None] - rnge[None,:] + 1, z)
+    a = torch.cat([a[...,:1]*0, a], -1)
+    #print(a)
+    a = a[...,None,:]
+    #print(idx)
+    a = torch.broadcast_to(a, size0)
+    idx = torch.broadcast_to(idx, size)
+    #print(idx)
+    return torch.gather(a, -1, idx)
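+# Example (illustrative): for a = [1, -0.9], toeplitz_from_filter returns the
+# lower-triangular convolution matrix
+#   [[ 1.0,  0.0],
+#    [-0.9,  1.0]]
+# so multiplying it with a signal column vector applies the FIR filter a.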
+
+def filter_iir_response(a, N):
+    device = a.device
+    L = a.size(-1)
+    ar = a.flip(dims=(2,))
+    size = (*(a.shape[:-1]), N)
+    R = torch.zeros(size, device=device)
+    R[:,:,0] = torch.ones((a.shape[:-1]), device=device)
+    for i in range(1, L):
+        R[:,:,i] = - torch.sum(ar[:,:,L-i-1:-1] * R[:,:,:i], axis=-1)
+        #R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,L-i-1:-1], R[:,:,:i])
+    for i in range(L, N):
+        R[:,:,i] = - torch.sum(ar[:,:,:-1] * R[:,:,i-L+1:i], axis=-1)
+        #R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,:-1], R[:,:,i-L+1:i])
+    return R
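+# filter_iir_response returns the first N samples of the impulse response of
+# 1/A(z), so toeplitz_from_filter(R) is the (truncated) inverse of
+# toeplitz_from_filter(a); the check below prints that inverse.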
+
+if __name__ == '__main__':
+    #a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]], [[1, .9, 0], [1, .8, 0]]])
+    a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]]])
+    A = toeplitz_from_filter(a)
+    #print(A)
+    R = filter_iir_response(a, 5)
+    
+    RA = toeplitz_from_filter(R)
+    print(RA)
--- /dev/null
+++ b/dnn/torch/fargan/stft_loss.py
@@ -1,0 +1,184 @@
+"""STFT-based Loss modules."""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import torchaudio
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (Tensor): Window tensor (win_length,).
+    Returns:
+        Tensor: Magnitude spectrogram (B, fft_size // 2 + 1, #frames).
+    """
+
+    # An earlier variant used return_complex=False and took the magnitude of
+    # the real/imaginary parts, optionally converting to dB with
+    # torchaudio.functional.amplitude_to_DB.
+    x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
+    # (kan-bayashi): clamp is needed to avoid nan or inf
+    return torch.clamp(torch.abs(x_stft), min=1e-7)
+
+class SpectralConvergenceLoss(torch.nn.Module):
+    """Spectral convergence loss module."""
+
+    def __init__(self):
+        """Initilize spectral convergence loss module."""
+        super(SpectralConvergenceLoss, self).__init__()
+
+    def forward(self, x_mag, y_mag):
+        """Calculate forward propagation.
+        Args:
+            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #freq_bins, #frames).
+            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #freq_bins, #frames).
+        Returns:
+            Tensor: Spectral convergence loss value.
+        """
+        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+
+class LogSTFTMagnitudeLoss(torch.nn.Module):
+    """Log STFT magnitude loss module."""
+
+    def __init__(self):
+        """Initilize los STFT magnitude loss module."""
+        super(LogSTFTMagnitudeLoss, self).__init__()
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+        Args:
+            x (Tensor): Magnitude spectrogram of predicted signal (B, #freq_bins, #frames).
+            y (Tensor): Magnitude spectrogram of groundtruth signal (B, #freq_bins, #frames).
+        Returns:
+            Tensor: STFT magnitude loss value.
+        """
+        # Alternatives tried and left disabled: L1 on sqrt- or log-domain
+        # magnitudes, L1 on dB-scale magnitudes via
+        # torchaudio.functional.amplitude_to_DB, and a concordance correlation
+        # coefficient (CCC) penalty. Plain L1 on the linear magnitudes is used.
+        return F.l1_loss(y, x)
+
+class STFTLoss(torch.nn.Module):
+    """STFT loss module."""
+
+    def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
+        """Initialize STFT loss module."""
+        super(STFTLoss, self).__init__()
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length).to(device)
+        self.spectral_convergence_loss = SpectralConvergenceLoss()
+        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+        Returns:
+            Tensor: Spectral convergence loss value.
+            Tensor: Log STFT magnitude loss value.
+        """
+        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
+        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
+        sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
+        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+        return sc_loss, mag_loss
+
+
+class MultiResolutionSTFTLoss(torch.nn.Module):
+
+    def __init__(self,
+                 device,
+                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
+                 hop_sizes=[512, 256, 128, 64, 32, 16],
+                 win_lengths=[2048, 1024, 512, 256, 128, 64],
+                 window="hann_window"):
+
+        # Alternative resolution sets (smaller hops, other window lengths)
+        # were tried and left disabled.
+
+        super(MultiResolutionSTFTLoss, self).__init__()
+        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
+        self.stft_losses = torch.nn.ModuleList()
+        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
+            self.stft_losses += [STFTLoss(device, fs, ss, wl, window)]
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+        Returns:
+            Tensor: Multi-resolution spectral convergence loss, averaged over
+                resolutions. (The log STFT magnitude term is computed per
+                resolution but currently unused.)
+        """
+        sc_loss = 0.0
+        for f in self.stft_losses:
+            sc_l, mag_l = f(x, y)
+            sc_loss += sc_l
+        sc_loss /= len(self.stft_losses)
+
+        return sc_loss
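+
+# Hypothetical smoke test (illustration only):
+if __name__ == '__main__':
+    loss_fn = MultiResolutionSTFTLoss(torch.device("cpu"))
+    x = torch.randn(2, 16000)
+    y = torch.randn(2, 16000)
+    print(loss_fn(x, y))  # a single scalar: the averaged spectral convergence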
--- /dev/null
+++ b/dnn/torch/fargan/test_fargan.py
@@ -1,0 +1,107 @@
+import os
+import argparse
+import numpy as np
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import tqdm
+
+import fargan
+from dataset import FARGANDataset
+
+nb_features = 36
+nb_used_features = 20
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('model', type=str, help='FARGAN model checkpoint')
+parser.add_argument('features', type=str, help='path to feature file in .f32 format')
+parser.add_argument('output', type=str, help='path to output file (16-bit PCM)')
+
+parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
+
+
+model_group = parser.add_argument_group(title="model parameters")
+model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
+
+args = parser.parse_args()
+
+if args.cuda_visible_devices is not None:
+    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
+
+
+features_file = args.features
+signal_file = args.output
+
+
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+checkpoint = torch.load(args.model, map_location='cpu')
+
+model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+
+model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
+lpc = features[:,4-1:-1,nb_used_features:]
+features = features[:, :, :nb_used_features]
+periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')
+
+nb_frames = features.shape[1]
+#nb_frames = 1000
+gamma = checkpoint['model_kwargs']['gamma']
+
+def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
+    # Sample-by-sample all-pole synthesis (1/A(z)) of one frame; the 16-sample
+    # state in buffer is weighted by weighting_vector (gamma^i deweighting).
+    out = np.zeros_like(frame)
+    filt = np.flip(filt)
+    inp = frame[:]
+    for i in range(0, inp.shape[0]):
+        s = inp[i] - np.dot(buffer*weighting_vector, filt)
+        buffer[0] = s
+        buffer = np.roll(buffer, -1)
+        out[i] = s
+    return out
+
+def inverse_perceptual_weighting(pw_signal, filters, weighting_vector):
+    # Inverse perceptual weighting: run the weighted-domain signal through
+    # 1/A(z/gamma), one 160-sample frame at a time, carrying the 16-sample
+    # filter state across frames.
+    signal = np.zeros_like(pw_signal)
+    buffer = np.zeros(16)
+    num_frames = pw_signal.shape[0] // 160
+    assert num_frames == filters.shape[0]
+    for frame_idx in range(0, num_frames):
+        in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
+        out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
+        signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
+        buffer[:] = out_sig_frame[-16:]
+    return signal
+
+
+
+if __name__ == '__main__':
+    model.to(device)
+    features = torch.tensor(features).to(device)
+    #lpc = torch.tensor(lpc).to(device)
+    periods = torch.tensor(periods).to(device)
+    
+    sig, _ = model(features, periods, nb_frames - 4)
+    weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
+    sig = sig.detach().cpu().numpy().flatten()
+    sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector)
+    
+    pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
+    pcm.tofile(signal_file)
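+
+# Hypothetical invocation (paths are placeholders):
+#   python test_fargan.py checkpoints/fargan_20.pth features.f32 out.sw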
--- /dev/null
+++ b/dnn/torch/fargan/train_fargan.py
@@ -1,0 +1,155 @@
+import os
+import argparse
+import random
+import numpy as np
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import tqdm
+
+import fargan
+from dataset import FARGANDataset
+from stft_loss import *
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('features', type=str, help='path to feature file in .f32 format')
+parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
+parser.add_argument('output', type=str, help='path to output folder')
+
+parser.add_argument('--suffix', type=str, help="model name suffix", default="")
+parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
+
+
+model_group = parser.add_argument_group(title="model parameters")
+model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
+model_group.add_argument('--has-gain', action='store_true', help="use gain-shape network")
+model_group.add_argument('--passthrough-size', type=int, help="state passing through in addition to audio, default: 0", default=0)
+model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
+
+training_group = parser.add_argument_group(title="training parameters")
+training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
+training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
+training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
+training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
+training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
+training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
+
+args = parser.parse_args()
+
+if args.cuda_visible_devices is not None:
+    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
+
+# checkpoints
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+checkpoint = dict()
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+
+# training parameters
+batch_size = args.batch_size
+lr = args.lr
+epochs = args.epochs
+sequence_length = args.sequence_length
+lr_decay = args.lr_decay
+
+adam_betas = [0.9, 0.99]
+adam_eps = 1e-8
+features_file = args.features
+signal_file = args.signal
+
+# model parameters
+cond_size  = args.cond_size
+
+
+checkpoint['batch_size'] = batch_size
+checkpoint['lr'] = lr
+checkpoint['lr_decay'] = lr_decay
+checkpoint['epochs'] = epochs
+checkpoint['sequence_length'] = sequence_length
+checkpoint['adam_betas'] = adam_betas
+
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+checkpoint['model_args']    = ()
+checkpoint['model_kwargs']  = {'cond_size': cond_size, 'has_gain': args.has_gain, 'passthrough_size': args.passthrough_size, 'gamma': args.gamma}
+print(checkpoint['model_kwargs'])
+model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+#model = fargan.FARGAN()
+#model = nn.DataParallel(model)
+
+if args.initial_checkpoint is not None:
+    # Note: this replaces the checkpoint dict built above with the loaded one,
+    # reusing its metadata (model_args, model_kwargs, ...) for later saving.
+    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
+    model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+checkpoint['state_dict']    = model.state_dict()
+
+
+dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
+
+
+optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
+
+
+# learning rate scheduler
+scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
+
+states = None
+
+spect_loss =  MultiResolutionSTFTLoss(device).to(device)
+
+if __name__ == '__main__':
+    model.to(device)
+
+    for epoch in range(1, epochs + 1):
+
+        running_specc = 0
+        running_cont_loss = 0
+        running_loss = 0
+
+        print(f"training epoch {epoch}...")
+        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
+            for i, (features, periods, target, lpc) in enumerate(tepoch):
+                optimizer.zero_grad()
+                features = features.to(device)
+                lpc = lpc.to(device)
+                periods = periods.to(device)
+                target = target.to(device)
+                target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
+
+                #nb_pre = random.randrange(1, 6)
+                nb_pre = 2
+                pre = target[:, :nb_pre*160]
+                sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
+                sig = torch.cat([pre, sig], -1)
+
+                cont_loss = fargan.sig_l1(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
+                specc_loss = spect_loss(sig, target.detach())
+                loss = .2*cont_loss + specc_loss
+
+                loss.backward()
+                optimizer.step()
+                
+                #model.clip_weights()
+                
+                scheduler.step()
+
+                running_specc += specc_loss.detach().cpu().item()
+                running_cont_loss += cont_loss.detach().cpu().item()
+
+                running_loss += loss.detach().cpu().item()
+                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
+                                   cont_loss=f"{running_cont_loss/(i+1):8.5f}",
+                                   specc=f"{running_specc/(i+1):8.5f}",
+                                   )
+
+        # save checkpoint
+        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
+        checkpoint['state_dict'] = model.state_dict()
+        checkpoint['loss'] = running_loss / len(dataloader)
+        checkpoint['epoch'] = epoch
+        torch.save(checkpoint, checkpoint_path)
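+
+# Hypothetical invocation (paths and options are placeholders):
+#   python train_fargan.py features.f32 speech.s16 out_dir --has-gain --gamma 0.9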
--