ref: 35ee397e060283d30c098ae5e17836316bbec08b
dir: /dnn/torch/lpcnet/utils/layers/subconditioner.py/
import torch
from torch import nn


def get_subconditioner(method,
                       number_of_subsamples,
                       pcm_embedding_size,
                       state_size,
                       pcm_levels,
                       number_of_signals,
                       **kwargs):

    subconditioner_dict = {
        'additive'      : AdditiveSubconditioner,
        'concatenative' : ConcatenativeSubconditioner,
        'modulative'    : ModulativeSubconditioner
    }

    return subconditioner_dict[method](number_of_subsamples, pcm_embedding_size,
                                       state_size, pcm_levels, number_of_signals, **kwargs)


class Subconditioner(nn.Module):
    def __init__(self):
        """ upsampling by subconditioning

            Upsamples a sequence of states, conditioning on pcm signals and
            optionally on a feature vector.
        """
        super(Subconditioner, self).__init__()

    def forward(self, states, signals, features=None):
        raise NotImplementedError("Base class should not be called")

    def single_step(self, index, state, signals, features):
        raise NotImplementedError("Base class should not be called")

    def get_output_dim(self, index):
        raise NotImplementedError("Base class should not be called")


class AdditiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 **kwargs):
        """ subconditioning by addition """

        super(AdditiveSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        if self.pcm_embedding_size != self.state_size:
            raise ValueError('For additive subconditioning, state and embedding '
                             + f'sizes must match but got {self.state_size} and {self.pcm_embedding_size}')

        self.embeddings = [None]
        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

            Parameters:
            -----------
            states : torch.tensor
                states of shape (batch, seq_length // s, state_size)
            signals : torch.tensor
                signals of shape (batch, seq_length, number_of_signals)

            Returns:
            --------
            c_states : list of torch.tensor
                list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.sum(embed, dim=2)

            new_states = new_states + embed
            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

            Parameters:
            -----------
            index : int
                position in subconditioning batch
            state : torch.tensor
                state to sub-condition
            signals : torch.tensor
                signals for subconditioning; all but the last dimension
                must match those of state

            Returns:
            --------
            c_state : torch.tensor
                subconditioned state
        """
        if index == 0:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.sum(embed_signals, dim=-2)
            c_state = state + c

        return c_state

    def get_output_dim(self, index):
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples
        flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
        return flops
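# Illustrative sketch (not part of the original module): additive
# subconditioning adds a summed pcm embedding onto the running state at each
# subsample position, which is why state_size must equal pcm_embedding_size.
# All shapes and values below are assumptions chosen for the demo.
def _example_additive():
    sub = AdditiveSubconditioner(number_of_subsamples=4,
                                 pcm_embedding_size=64,
                                 state_size=64,
                                 pcm_levels=256,
                                 number_of_signals=2)
    states = torch.randn(8, 25, 64)               # (batch, seq_length // s, state_size)
    signals = torch.randint(0, 256, (8, 100, 2))  # (batch, seq_length, number_of_signals)
    c_states = sub(states, signals)
    # one state sequence per subsample position; shapes are unchanged
    assert len(c_states) == 4 and c_states[0].shape == states.shape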
class ConcatenativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 recurrent=True,
                 **kwargs):
        """ subconditioning by concatenation """

        super(ConcatenativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.recurrent = recurrent

        self.embeddings = []
        start_index = 0
        if self.recurrent:
            start_index = 1
            self.embeddings.append(None)

        for i in range(start_index, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

            Parameters:
            -----------
            states : torch.tensor
                states of shape (batch, seq_length // s, state_size)
            signals : torch.tensor
                signals of shape (batch, seq_length, number_of_signals)

            Returns:
            --------
            c_states : list of torch.tensor
                list of s subconditioned states
        """
        s = self.number_of_subsamples

        if self.recurrent:
            c_states = [states]
            start = 1
        else:
            c_states = []
            start = 0

        new_states = states
        for i in range(start, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.recurrent:
                new_states = torch.cat((new_states, embed), dim=-1)
            else:
                new_states = torch.cat((states, embed), dim=-1)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

            Parameters:
            -----------
            index : int
                position in subconditioning batch
            state : torch.tensor
                state to sub-condition
            signals : torch.tensor
                signals for subconditioning; all but the last dimension
                must match those of state

            Returns:
            --------
            c_state : torch.tensor
                subconditioned state
        """
        if index == 0 and self.recurrent:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if not self.recurrent and index > 0:
                # overwrite previous conditioning vector
                c_state = torch.cat((state[..., :self.state_size], c), dim=-1)
            else:
                c_state = torch.cat((state, c), dim=-1)

        return c_state

    def get_average_flops_per_step(self):
        return 0

    def get_output_dim(self, index):
        if self.recurrent:
            return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
        else:
            return self.state_size + self.pcm_embedding_size * self.number_of_signals
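# Illustrative sketch (not part of the original module): with recurrent
# concatenation the state grows by pcm_embedding_size * number_of_signals
# at every subsample step, which is exactly what get_output_dim reports.
# Shapes and values below are assumptions chosen for the demo.
def _example_concatenative():
    sub = ConcatenativeSubconditioner(number_of_subsamples=4,
                                      pcm_embedding_size=32,
                                      state_size=64,
                                      pcm_levels=256,
                                      number_of_signals=2,
                                      recurrent=True)
    states = torch.randn(8, 25, 64)
    signals = torch.randint(0, 256, (8, 100, 2))
    c_states = sub(states, signals)
    # output widths: 64, 128, 192, 256
    for i, c in enumerate(c_states):
        assert c.shape[-1] == sub.get_output_dim(i)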
class ModulativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        """ subconditioning by modulation """

        super(ModulativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.state_recurrent = state_recurrent

        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            self.hidden_size += self.pcm_embedding_size
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        self.embeddings = [None]
        self.alphas = [None]
        self.betas = [None]

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alphas[-1])

            self.betas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.betas[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

            Parameters:
            -----------
            states : torch.tensor
                states of shape (batch, seq_length // s, state_size)
            signals : torch.tensor
                signals of shape (batch, seq_length, number_of_signals)

            Returns:
            --------
            c_states : list of torch.tensor
                list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.state_recurrent:
                comp_states = self.state_transform(new_states)
                embed = torch.cat((embed, comp_states), dim=-1)

            alpha = torch.tanh(self.alphas[i](embed))
            beta = torch.tanh(self.betas[i](embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * new_states + beta)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

            Parameters:
            -----------
            index : int
                position in subconditioning batch
            state : torch.tensor
                state to sub-condition
            signals : torch.tensor
                signals for subconditioning; all but the last dimension
                must match those of state

            Returns:
            --------
            c_state : torch.tensor
                subconditioned state
        """
        if index == 0:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if self.state_recurrent:
                r_state = self.state_transform(state)
                c = torch.cat((c, r_state), dim=-1)
            alpha = torch.tanh(self.alphas[index](c))
            beta = torch.tanh(self.betas[index](c))
            c_state = torch.tanh((1 + alpha) * state + beta)

        return c_state

    def get_output_dim(self, index):
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples

        # estimate activation by 10 flops
        # c_state = torch.tanh((1 + alpha) * state + beta)
        flops = 13 * self.state_size

        # hidden size
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # counting 2 * A * B flops for Linear(A, B)
        # alpha = torch.tanh(self.alphas[index](c))
        # beta = torch.tanh(self.betas[index](c))
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # r_state = self.state_transform(state)
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # average over steps
        flops *= (s - 1) / s

        return flops
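# Illustrative sketch (not part of the original module): modulative
# subconditioning is a FiLM-style update, new = tanh((1 + alpha) * old + beta),
# so the state width never changes during upsampling. Shapes and values
# below are assumptions chosen for the demo.
def _example_modulative():
    sub = ModulativeSubconditioner(number_of_subsamples=4,
                                   pcm_embedding_size=32,
                                   state_size=64,
                                   pcm_levels=256,
                                   number_of_signals=2,
                                   state_recurrent=True)
    state = torch.randn(8, 64)
    sig = torch.randint(0, 256, (8, 2))
    # run the per-step inference path; the state keeps its size
    for i in range(1, 4):
        state = sub.single_step(i, state, sig)
    assert state.shape == (8, 64)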
class ComparitiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 error_index=-1,
                 apply_gate=True,
                 normalize=False):
        """ subconditioning by comparison """

        super(ComparitiveSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        self.comparison_size = self.pcm_embedding_size
        self.error_position = error_index
        self.apply_gate = apply_gate
        self.normalize = normalize

        self.state_transform = nn.Linear(self.state_size, self.comparison_size)

        self.alpha_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
        self.beta_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)

        if self.apply_gate:
            self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)

        # embeddings and state transforms
        self.embeddings = [None]
        self.alpha_denses = [None]
        self.beta_denses = [None]
        self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
        self.add_module('state_transform_0', self.state_transforms[0])

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            state_transform = nn.Linear(self.state_size, self.comparison_size)
            self.add_module('state_transform_' + str(i), state_transform)
            self.state_transforms.append(state_transform)

            self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])

            self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.beta_denses[-1])

    def forward(self, states, signals):
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            comp_states = self.state_transforms[i](new_states)

            alpha = torch.tanh(self.alpha_dense(embed))
            beta = torch.tanh(self.beta_dense(embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * comp_states + beta)

            c_states.append(new_states)

        return c_states
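# Illustrative sketch (not part of the original module): the factory maps a
# config string onto one of the classes above, and any method-specific option
# (here the assumed state_recurrent flag) rides along through **kwargs.
# All values are demo assumptions.
if __name__ == '__main__':
    sub = get_subconditioner('modulative',
                             number_of_subsamples=4,
                             pcm_embedding_size=32,
                             state_size=64,
                             pcm_levels=256,
                             number_of_signals=2,
                             state_recurrent=False)
    print(type(sub).__name__, [sub.get_output_dim(i) for i in range(4)])

    _example_additive()
    _example_concatenative()
    _example_modulative()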