# shithub: opus
#
# ref: 35ee397e060283d30c098ae5e17836316bbec08b
# dir: /dnn/torch/lpcnet/models/lpcnet.py/
#
# (web viewer header, kept as comments so that the module parses as Python)
import torch
from torch import nn
import numpy as np

from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
from utils.layers import DualFC
from utils.misc import get_pdf_from_tree


class LPCNet(nn.Module):
    def __init__(self, config):
        """Build the frame rate and sample rate networks described by ``config``.

        Args:
            config: mapping with the model hyper-parameters. Required keys:
                'input_layout', 'feature_history', 'feature_lookahead',
                'feature_dimension', 'period_embedding_dim', 'period_levels',
                'feature_conditioning_dim', 'feature_conv_kernel_size',
                'frame_size', 'signal_levels', 'signal_embedding_dim',
                'gru_a_units', 'gru_b_units' and 'output_levels'.
                Optional keys: 'hsampling' (default False), 'lpc_gamma'
                (default 1) and 'sparsification' (default: nothing is
                sparsified). The 'sparsification' section may contain a
                'gru_a' and/or a 'gru_b' entry, each providing 'params',
                'start', 'stop', 'interval' and 'exponent'.
        """
        super().__init__()

        # input layout and feature context sizes (in frames)
        self.input_layout       = config['input_layout']
        self.feature_history    = config['feature_history']
        self.feature_lookahead  = config['feature_lookahead']

        # frame rate network parameters
        self.feature_dimension          = config['feature_dimension']
        self.period_embedding_dim       = config['period_embedding_dim']
        self.period_levels              = config['period_levels']
        self.feature_channels           = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim   = config['feature_conditioning_dim']
        self.feature_conv_kernel_size   = config['feature_conv_kernel_size']

        # frame rate network layers
        self.period_embedding   = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1      = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2      = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1     = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2     = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)

        # sample rate network parameters
        self.frame_size             = config['frame_size']
        self.signal_levels          = config['signal_levels']
        self.signal_embedding_dim   = config['signal_embedding_dim']
        self.gru_a_units            = config['gru_a_units']
        self.gru_b_units            = config['gru_b_units']
        self.output_levels          = config['output_levels']
        self.hsampling              = config.get('hsampling', False)

        # GRU A sees one embedding per input signal plus the conditioning
        # vector; GRU B sees the output of GRU A plus the conditioning vector
        self.gru_a_input_dim        = len(self.input_layout['signals']) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim        = self.gru_a_units + self.feature_conditioning_dim

        # sample rate network layers
        self.signal_embedding   = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a              = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b              = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
        self.dual_fc            = DualFC(self.gru_b_units, self.output_levels)

        # sparsification: a missing section means that no GRU is sparsified
        self.sparsifier = []
        sparsification_config = config.get('sparsification', {})

        # GRU A (FLOPs are counted with drop_input=True, as before)
        self.gru_a_flops_per_step = self._setup_gru_sparsification(
            self.gru_a, 'gru_a', sparsification_config, drop_input=True)

        # GRU B
        self.gru_b_flops_per_step = self._setup_gru_sparsification(
            self.gru_b, 'gru_b', sparsification_config)

        # inference parameters
        self.lpc_gamma = config.get('lpc_gamma', 1)

    def _setup_gru_sparsification(self, gru, name, sparsification_config, **flops_kwargs):
        """Register a sparsifier for ``gru`` if configured and return its FLOPs per step.

        Args:
            gru: GRU layer to be (optionally) sparsified.
            name: key of this GRU inside ``sparsification_config``.
            sparsification_config: the 'sparsification' section of the model
                config; GRUs without an entry are left dense.
            **flops_kwargs: extra keyword arguments forwarded to
                ``calculate_gru_flops_per_step``.

        Returns:
            The value of ``calculate_gru_flops_per_step`` for this GRU.
        """
        if name not in sparsification_config:
            return calculate_gru_flops_per_step(gru, **flops_kwargs)

        gru_config = sparsification_config[name]
        self.sparsifier.append(
            GRUSparsifier(
                [(gru, gru_config['params'])],
                gru_config['start'],
                gru_config['stop'],
                gru_config['interval'],
                gru_config['exponent'],
            )
        )
        return calculate_gru_flops_per_step(gru, gru_config['params'], **flops_kwargs)

    def sparsify(self):
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def get_gflops(self, fs, verbose=False):
        gflops = 0

        # frame rate network
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a
        gru_a_rate = fs
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # gru b
        gru_b_rate = fs
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity


        # dual fcs
        fc = self.dual_fc
        rate = fs
        input_size = fc.dense1.in_features
        output_size = fc.dense1.out_features
        dual_fc_complexity = 1e-9 *  (4 * input_size * output_size + 22 * output_size) * rate
        if self.hsampling:
            dual_fc_complexity /= 8
        if verbose:
            print(f"dual_fc: {dual_fc_complexity} GFLOPS")
        gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops

    def frame_rate_network(self, features, periods):

        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # convert to channels first and calculate conditioning vector
        c = torch.permute(features, [0, 2, 1])

        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def sample_rate_network(self, signals, c, gru_states):
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
        c_upsampled      = torch.repeat_interleave(c, self.frame_size, dim=1)

        y = torch.concat((embedded_signals, c_upsampled), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c_upsampled), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        if self.hsampling:
            y = torch.sigmoid(y)
            log_probs = torch.log(get_pdf_from_tree(y) + 1e-6)
        else:
            log_probs = torch.log_softmax(y, dim=-1)

        return log_probs, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        if self.hsampling:
            y = torch.sigmoid(y)
            probs = get_pdf_from_tree(y)
        else:
            probs = torch.softmax(y, dim=-1)

        return probs, (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):

        c            = self.frame_rate_network(features, periods)
        log_probs, _ = self.sample_rate_network(signals, c, gru_states)

        return log_probs

    def generate(self, features, periods, lpcs):

        with torch.no_grad():
            device = self.parameters().__next__().device

            num_frames          = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order           = lpcs.shape[-1]
            num_input_signals   = len(self.input_layout['signals'])
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers
            pcm    = torch.zeros((num_frames * self.frame_size + lpc_order))
            output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
            mem = 0

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units))
            gru_b_state = torch.zeros((1, 1, self.gru_b_units))
            gru_states = [gru_a_state, gru_b_state]

            input_signals = torch.zeros((1, 1, num_input_signals), dtype=torch.long) + 128

            # push data to device
            features = features.to(device)
            periods  = periods.to(device)
            lpcs     = lpcs.to(device)

            # lpc weighting
            weights = torch.FloatTensor([self.lpc_gamma ** (i + 1) for i in range(lpc_order)]).to(device)
            lpcs    = lpcs * weights

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr  = features[frame_index + self.feature_history, pitch_corr_position]
                a           = - torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c   = c[:, frame_index : frame_index + 1, :]

                for i in range(self.frame_size):
                    pcm_position    = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # prepare input
                    pred = torch.sum(pcm[pcm_position - lpc_order : pcm_position] * a)
                    if 'prediction' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['prediction']] = lin2ulawq(pred)

                    # run single step of sample rate network
                    probs, gru_states = self.decoder(
                                            input_signals,
                                            current_c,
                                            gru_states
                                    )

                    # sample from output
                    exc_ulaw = sample_excitation(probs, pitch_corr)

                    # signal generation
                    exc = ulaw2lin(exc_ulaw)
                    sig = exc + pred
                    pcm[pcm_position] = sig
                    mem = 0.85 * mem + float(sig)
                    output[output_position] = clip_to_int16(round(mem))

                    # buffer update
                    if 'last_signal' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['last_signal']] = lin2ulawq(sig)

                    if 'last_error' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['last_error']]  = lin2ulawq(exc)

        return output