shithub: opus

Download patch

ref: 105e1d83fad6393b00edb7eb676be483eb4ee2d7
parent: 178672ed1823f2a2fdc7e36e34578383f799f4f6
author: Jan Buethe <jbuethe@amazon.de>
date: Fri Jun 30 17:15:56 EDT 2023

Opus ng lace

--- /dev/null
+++ b/dnn/torch/osce/README.md
@@ -1,0 +1,4 @@
+# Opus Speech Coding Enhancement
+
+This folder hosts models for enhancing SILK. See related Opus repo https://gitlab.xiph.org/xiph/opus/-/tree/exp-neural-silk-enhancement
+for feature generation.
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/data/__init__.py
@@ -1,0 +1,30 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+from .silk_enhancement_set import SilkEnhancementSet
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/data/silk_enhancement_set.py
@@ -1,0 +1,140 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+
+from torch.utils.data import Dataset
+import numpy as np
+
+from utils.silk_features import silk_feature_factory
+from utils.pitch import hangover, calculate_acorr_window
+
+
+class SilkEnhancementSet(Dataset):
+    """Dataset of fixed-length training samples for SILK enhancement models.
+
+    The raw dumps found in ``path`` are loaded completely into memory and cut
+    into consecutive, non-overlapping items of ``frames_per_sample`` frames
+    (``frame_size`` = 80 signal samples per frame).
+
+    Args:
+        path: folder containing ``features_lpc.f32``, ``features_ltp.f32``,
+            ``features_period.s16``, ``features_gain.f32``,
+            ``features_num_bits.s32``, ``features_num_bits_smooth.f32``,
+            ``features_offset.f32``, ``clean_hp.s16`` and ``coded.s16``.
+        frames_per_sample: number of frames per item. Must be a multiple of 4
+            because the numbits arrays hold one value per 4 frames.
+        no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec,
+        num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_offset:
+            forwarded unchanged to ``silk_feature_factory``.
+        preemph: if > 0, ``x[n] -= preemph * x[n - 1]`` is applied to the clean
+            and coded signals after feature extraction.
+        skip: number of samples by which the signal windows are shifted back
+            relative to the frame grid (presumably the codec delay - confirm
+            against the feature generation code).
+        add_double_lag_acorr: forwarded to ``silk_feature_factory``; also
+            doubles the signal history handed to the feature extractor
+            (700 instead of 350 samples).
+
+    Raises:
+        ValueError: if ``frames_per_sample`` is not a multiple of 4.
+    """
+
+    def __init__(self,
+                 path,
+                 frames_per_sample=100,
+                 no_pitch_value=256,
+                 preemph=0.85,
+                 skip=91,
+                 acorr_radius=2,
+                 pitch_hangover=8,
+                 num_bands_clean_spec=64,
+                 num_bands_noisy_spec=18,
+                 noisy_spec_scale='opus',
+                 noisy_apply_dct=True,
+                 add_offset=False,
+                 add_double_lag_acorr=False
+                 ):
+
+        # explicit check instead of assert so it survives `python -O`
+        if frames_per_sample % 4 != 0:
+            raise ValueError(f"frames_per_sample must be a multiple of 4, got {frames_per_sample}")
+
+        self.frame_size = 80
+        self.frames_per_sample = frames_per_sample
+        self.no_pitch_value = no_pitch_value
+        self.preemph = preemph
+        self.skip = skip
+        self.acorr_radius = acorr_radius
+        self.pitch_hangover = pitch_hangover
+        self.num_bands_clean_spec = num_bands_clean_spec
+        self.num_bands_noisy_spec = num_bands_noisy_spec
+        self.noisy_spec_scale = noisy_spec_scale
+        self.add_double_lag_acorr = add_double_lag_acorr
+
+        # per-frame features (one row / entry per frame)
+        self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
+        self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
+        self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
+        self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
+        # numbits arrays are indexed with frame // 4 in __getitem__
+        self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
+        self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
+        self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
+
+        # 16-bit PCM signals
+        self.clean_signal = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
+        self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)
+
+        self.create_features = silk_feature_factory(no_pitch_value,
+                                                    acorr_radius,
+                                                    pitch_hangover,
+                                                    num_bands_clean_spec,
+                                                    num_bands_noisy_spec,
+                                                    noisy_spec_scale,
+                                                    noisy_apply_dct,
+                                                    add_offset,
+                                                    add_double_lag_acorr)
+
+        self.history_len = 700 if add_double_lag_acorr else 350
+        # discard some frames to have enough signal history: skip + history_len
+        # rounded up to whole 320-sample blocks plus two spare blocks, counted
+        # in frames. Being a multiple of 4 keeps frame_start // 4 aligned with
+        # the numbits arrays.
+        self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)
+
+        num_frames = self.clean_signal.shape[0] // self.frame_size - self.skip_frames
+
+        # clamp so that a too-short recording yields an empty dataset instead
+        # of a negative length (which makes len() raise)
+        self.len = max(num_frames // frames_per_sample, 0)
+
+    def __len__(self):
+        """Return the number of complete items available."""
+        return self.len
+
+    def __getitem__(self, index):
+        """Return item ``index`` as a dict of numpy arrays.
+
+        Keys:
+            'features': feature array produced by ``create_features``.
+            'periods':  periods produced by ``create_features`` (int64).
+            'target':   clean signal, float32, scaled by 2**-15.
+            'signals':  coded signal, float32, scaled by 2**-15, shape (-1, 1).
+            'numbits':  float32 array with columns (num_bits, num_bits_smooth),
+                        one row per frame.
+        """
+
+        frame_start = self.frames_per_sample * index + self.skip_frames
+        frame_stop  = frame_start + self.frames_per_sample
+
+        signal_start = frame_start * self.frame_size - self.skip
+        signal_stop  = frame_stop  * self.frame_size - self.skip
+
+        clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15
+        coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15
+
+        coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15
+
+        # features are computed on the coded signal before pre-emphasis
+        features, periods = self.create_features(
+              coded_signal,
+              coded_signal_history,
+              self.lpcs[frame_start : frame_stop],
+              self.gains[frame_start : frame_stop],
+              self.ltps[frame_start : frame_stop],
+              self.periods[frame_start : frame_stop],
+              self.offsets[frame_start : frame_stop]
+        )
+
+        if self.preemph > 0:
+            # the right-hand side is evaluated into a temporary first, so this
+            # is the FIR filter x[n] - preemph * x[n - 1] on the original samples
+            clean_signal[1:] -= self.preemph * clean_signal[: -1]
+            coded_signal[1:] -= self.preemph * coded_signal[: -1]
+
+        # expand the per-4-frames values to one value per frame
+        num_bits        = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
+        num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
+
+        numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)
+
+        return {
+            'features'  : features,
+            'periods'   : periods.astype(np.int64),
+            'target'    : clean_signal.astype(np.float32),
+            'signals'   : coded_signal.reshape(-1, 1).astype(np.float32),
+            'numbits'   : numbits.astype(np.float32)
+            }
--- /dev/null
+++ b/dnn/torch/osce/engine/engine.py
@@ -1,0 +1,101 @@
+import torch
+from tqdm import tqdm
+import sys
+
+def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
+    """Train ``model`` for one pass over ``dataloader``.
+
+    Args:
+        model: called as ``model(signals, features, periods, numbits)`` where
+            ``signals`` is ``batch['signals']`` permuted with (0, 2, 1). May
+            return a single tensor or a list of tensors; for a list the loss
+            is averaged over its entries.
+        criterion: loss called as ``criterion(target, output.squeeze(1))``.
+        optimizer: optimizer stepped once per batch.
+        dataloader: iterable of dicts of tensors with the keys 'target',
+            'signals', 'features', 'periods' and 'numbits'.
+        device: device the model and every batch tensor are moved to.
+        scheduler: learning rate scheduler, stepped once per batch.
+        log_interval: the progress bar is refreshed every ``log_interval`` batches.
+
+    Returns:
+        float: mean loss per batch over the epoch.
+    """
+
+    model.to(device)
+    model.train()
+
+    running_loss = 0
+    previous_running_loss = 0
+    # number of batches already covered by previous_running_loss
+    previous_num_batches = 0
+
+    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
+
+        for i, batch in enumerate(tepoch):
+
+            # set gradients to zero
+            optimizer.zero_grad()
+
+            # push batch to device
+            for key in batch:
+                batch[key] = batch[key].to(device)
+
+            target = batch['target']
+
+            # calculate model output
+            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
+
+            # calculate loss
+            if isinstance(output, list):
+                loss = torch.zeros(1, device=device)
+                for y in output:
+                    loss = loss + criterion(target, y.squeeze(1))
+                loss = loss / len(output)
+            else:
+                loss = criterion(target, output.squeeze(1))
+
+            # calculate gradients
+            loss.backward()
+
+            # update weights
+            optimizer.step()
+
+            # update learning rate
+            scheduler.step()
+
+            # update running loss
+            running_loss += loss.item()
+
+            # update status bar
+            if i % log_interval == 0:
+                num_batches = i + 1
+                # average over the batches seen since the last refresh; at
+                # i == 0 this is a single batch, not log_interval batches
+                current_loss = (running_loss - previous_running_loss) / (num_batches - previous_num_batches)
+                tepoch.set_postfix(running_loss=f"{running_loss / num_batches:8.7f}", current_loss=f"{current_loss:8.7f}")
+                previous_running_loss = running_loss
+                previous_num_batches = num_batches
+
+    running_loss /= len(dataloader)
+
+    return running_loss
+
+def evaluate(model, criterion, dataloader, device, log_interval=10):
+    """Compute the mean loss of ``model`` over ``dataloader`` without gradients.
+
+    Args:
+        model: called as ``model(signals, features, periods, numbits)`` where
+            ``signals`` is ``batch['signals']`` permuted with (0, 2, 1). May
+            return a single tensor or a list of tensors; for a list the loss
+            is averaged over its entries, matching ``train_one_epoch``.
+        criterion: loss called as ``criterion(target, output.squeeze(1))``.
+        dataloader: iterable of dicts of tensors with the keys 'target',
+            'signals', 'features', 'periods' and 'numbits'.
+        device: device the model and every batch tensor are moved to.
+        log_interval: the progress bar is refreshed every ``log_interval`` batches.
+
+    Returns:
+        float: mean loss per batch.
+    """
+
+    model.to(device)
+    model.eval()
+
+    running_loss = 0
+    previous_running_loss = 0
+    # number of batches already covered by previous_running_loss
+    previous_num_batches = 0
+
+    with torch.no_grad():
+        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
+
+            for i, batch in enumerate(tepoch):
+
+                # push batch to device
+                for key in batch:
+                    batch[key] = batch[key].to(device)
+
+                target = batch['target']
+
+                # calculate model output
+                output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
+
+                # calculate loss (models may return a list of outputs)
+                if isinstance(output, list):
+                    loss = torch.zeros(1, device=device)
+                    for y in output:
+                        loss = loss + criterion(target, y.squeeze(1))
+                    loss = loss / len(output)
+                else:
+                    loss = criterion(target, output.squeeze(1))
+
+                # update running loss
+                running_loss += loss.item()
+
+                # update status bar
+                if i % log_interval == 0:
+                    num_batches = i + 1
+                    # average over the batches seen since the last refresh; at
+                    # i == 0 this is a single batch, not log_interval batches
+                    current_loss = (running_loss - previous_running_loss) / (num_batches - previous_num_batches)
+                    tepoch.set_postfix(running_loss=f"{running_loss / num_batches:8.7f}", current_loss=f"{current_loss:8.7f}")
+                    previous_running_loss = running_loss
+                    previous_num_batches = num_batches
+
+        running_loss /= len(dataloader)
+
+        return running_loss
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/losses/stft_loss.py
@@ -1,0 +1,277 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+"""STFT-based Loss modules."""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+import torchaudio
+
+
+def get_window(win_name, win_length, *args, **kwargs):
+    """Create a window tensor from the name of a torch window function.
+
+    Args:
+        win_name (str): one of 'bartlett_window', 'blackman_window',
+            'hamming_window', 'hann_window' or 'kaiser_window'.
+        win_length (int): window length passed to the torch window function.
+        *args, **kwargs: forwarded to the torch window function.
+
+    Returns:
+        Tensor: the window returned by the selected torch function.
+
+    Raises:
+        ValueError: if ``win_name`` is not a supported window name.
+    """
+    window_dict = {
+        'bartlett_window'   : torch.bartlett_window,
+        'blackman_window'   : torch.blackman_window,
+        'hamming_window'    : torch.hamming_window,
+        'hann_window'       : torch.hann_window,
+        'kaiser_window'     : torch.kaiser_window
+    }
+
+    if win_name not in window_dict:
+        raise ValueError(f"unknown window '{win_name}', expected one of {sorted(window_dict)}")
+
+    return window_dict[win_name](win_length, *args, **kwargs)
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Compute the clamped magnitude spectrogram of a batch of signals.
+
+    Args:
+        x (Tensor): input signals, (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): hop size between successive frames.
+        win_length (int): length of the analysis window.
+        window (str): name of the window function, see ``get_window``.
+
+    Returns:
+        Tensor: magnitudes in the layout produced by ``torch.stft``, i.e.
+        (B, fft_size // 2 + 1, #frames), with values clamped to >= 1e-7.
+    """
+
+    # the window is created on the fly and moved next to the signal
+    analysis_window = get_window(window, win_length).to(x.device)
+    spectrum = torch.stft(x, fft_size, hop_size, win_length, analysis_window, return_complex=True)
+
+    magnitude = spectrum.abs()
+
+    return magnitude.clamp(min=1e-7)
+
+def spectral_convergence_loss(Y_true, Y_pred):
+    """Batch mean of the relative magnitude error between two spectrograms.
+
+    For every batch entry the Frobenius norm of ``|Y_true| - |Y_pred|`` over
+    all non-batch dimensions is divided by the Frobenius norm of ``Y_pred``
+    (plus 1e-6); the result is averaged over the batch.
+    """
+    reduce_dims = list(range(1, Y_pred.dim()))
+    magnitude_error = torch.abs(Y_true) - torch.abs(Y_pred)
+    error_norm = torch.norm(magnitude_error, p="fro", dim=reduce_dims)
+    # note: the normalization uses the prediction, not the reference
+    reference_norm = torch.norm(Y_pred, p="fro", dim=reduce_dims)
+    return torch.mean(error_norm / (reference_norm + 1e-6))
+
+
+def log_magnitude_loss(Y_true, Y_pred):
+    """Mean absolute difference between the log magnitudes of two spectrograms.
+
+    1e-15 is added to each magnitude before taking the logarithm.
+    """
+    log_true, log_pred = (torch.log(torch.abs(Y) + 1e-15) for Y in (Y_true, Y_pred))
+
+    return (log_true - log_pred).abs().mean()
+
+def spectral_xcorr_loss(Y_true, Y_pred):
+    """One minus the batch mean of the normalized magnitude cross-correlation.
+
+    Per batch entry the inner product of ``|Y_true|`` and ``|Y_pred|`` over all
+    non-batch dimensions is divided by the square root of the product of their
+    energies (plus 1e-9).
+    """
+    mag_true = torch.abs(Y_true)
+    mag_pred = torch.abs(Y_pred)
+    reduce_dims = list(range(1, mag_pred.dim()))
+
+    inner_product = torch.sum(mag_true * mag_pred, dim=reduce_dims)
+    energy_true = torch.sum(mag_true ** 2, dim=reduce_dims)
+    energy_pred = torch.sum(mag_pred ** 2, dim=reduce_dims)
+    xcorr = inner_product / torch.sqrt(energy_true * energy_pred + 1e-9)
+
+    return 1 - xcorr.mean()
+
+
+
+class MRLogMelLoss(nn.Module):
+    """Multi-resolution log-mel loss.
+
+    One ``torchaudio.transforms.MelSpectrogram`` is created per FFT size; the
+    loss is the mean over resolutions of ``log_magnitude_loss`` between the
+    mel spectrograms of the reference and the prediction.
+
+    Args:
+        fft_sizes: FFT sizes of the individual resolutions.
+        overlap: fractional overlap between frames; the hop size is
+            ``round(fft_size * (1 - overlap))``.
+        fs: sampling rate passed to ``MelSpectrogram``.
+        n_mels: number of mel bands; halved for FFT sizes below 128.
+    """
+    def __init__(self,
+                 fft_sizes=[512, 256, 128, 64],
+                 overlap=0.5,
+                 fs=16000,
+                 n_mels=18
+                 ):
+
+        # initialize nn.Module before assigning any attribute
+        super().__init__()
+
+        # copy so the instance never aliases the shared default list
+        self.fft_sizes  = list(fft_sizes)
+        self.overlap    = overlap
+        self.fs         = fs
+        self.n_mels     = n_mels
+
+        self.mel_specs = []
+        for fft_size in self.fft_sizes:
+            hop_size = int(round(fft_size * (1 - self.overlap)))
+
+            n_mels = self.n_mels
+            if fft_size < 128:
+                n_mels //= 2
+
+            self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))
+
+        # register under explicit names (instead of nn.ModuleList) so the
+        # sub-module names stay 'mel_spec_1', 'mel_spec_2', ...
+        for i, mel_spec in enumerate(self.mel_specs):
+            self.add_module(f'mel_spec_{i+1}', mel_spec)
+
+    def forward(self, y_true, y_pred):
+        """Return the loss as a tensor of shape (1,) on the device of ``y_true``."""
+
+        loss = torch.zeros(1, device=y_true.device)
+
+        for mel_spec in self.mel_specs:
+            Y_true = mel_spec(y_true)
+            Y_pred = mel_spec(y_pred)
+            loss = loss + log_magnitude_loss(Y_true, Y_pred)
+
+        loss = loss / len(self.mel_specs)
+
+        return loss
+
+def create_weight_matrix(num_bins, bins_per_band=10):
+    """Build a (num_bins, num_bins) band-averaging matrix.
+
+    Row ``i`` holds a window of ``bins_per_band`` ones around bin ``i``. The
+    part of the window that would fall outside the matrix is added on top of
+    the first / last bins instead, and everything is divided by
+    ``bins_per_band``.
+    """
+    below = bins_per_band // 2
+    above = bins_per_band - below
+
+    weights = torch.zeros((num_bins, num_bins), dtype=torch.float32)
+
+    for row in range(num_bins):
+        first = max(row - below, 0)
+        last = min(row + above, num_bins)
+        weights[row, first: last] += 1
+
+        # number of window bins cut off at the lower edge
+        cut_low = below - row
+        if cut_low > 0:
+            weights[row, :cut_low] += 1
+
+        # number of window bins cut off at the upper edge
+        cut_high = row + above - num_bins
+        if cut_high > 0:
+            weights[row, -cut_high:] += 1
+
+    return weights / bins_per_band
+
+def weighted_spectral_convergence(Y_true, Y_pred, w):
+    """Spectral convergence weighted by the spectral flatness of the reference.
+
+    Args:
+        Y_true, Y_pred: spectrograms; dimensions 1 and 2 are transposed before
+            the multiplication with ``w`` (assumes (B, bins, frames) - confirm
+            against the caller).
+        w: square averaging matrix applied along the bin axis.
+
+    Returns:
+        Tensor: batch mean of the weighted mean absolute magnitude error
+        divided by the weighted mean reference magnitude.
+    """
+
+    mag_true = torch.abs(Y_true)
+    mag_pred = torch.abs(Y_pred)
+
+    # calculate sfm based weights: ratio of local geometric to arithmetic mean
+    geometric_mean = torch.exp(torch.matmul(torch.log(mag_true + 1e-9).transpose(1, 2), w))
+    arithmetic_mean = torch.matmul(mag_true.transpose(1, 2), w)
+    sfm = geometric_mean / (arithmetic_mean + 1e-9)
+
+    # flat regions (sfm close to 1) get a weight close to 0
+    weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)
+
+    weighted_error = torch.mean(weight * torch.abs(mag_true - mag_pred), dim=[1, 2])
+    weighted_reference = torch.mean(weight * mag_true, dim=[1, 2])
+
+    return torch.mean(weighted_error / (weighted_reference + 1e-9))
+
+def gen_filterbank(N, Fs=16000):
+    """Create an (N, N + 1) smoothing filterbank with ERB-scaled bandwidths.
+
+    Columns correspond to N + 1 input frequencies and rows to N output
+    frequencies, both spaced by ``Fs / (2 * N)`` starting at 0. Each row is
+    normalized to sum to one.
+    """
+    in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2).reshape(1, -1)
+    out_freq = (np.arange(N, dtype='float32')/N*Fs/2).reshape(-1, 1)
+    #ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
+    erb = 24.7 + .108*in_freq
+    # frequency distance in units of the local ERB
+    delta = np.abs(in_freq-out_freq)/erb
+    near_center = (delta<.5).astype('float32')
+    # response in dB: quadratic within half an ERB, linear roll-off beyond
+    response_db = -12*near_center*delta**2 + (1-near_center)*(3-12*delta)
+    response = 10.**(response_db/10.)
+    response = response/response.sum(axis=1, keepdims=True)
+    return torch.from_numpy(response)
+
+def smooth_log_mag(Y_true, Y_pred, filterbank):
+    """Mean absolute log difference of filterbank-smoothed magnitudes.
+
+    Both magnitude spectrograms are left-multiplied with ``filterbank`` and
+    1e-9 is added before taking the logarithm.
+    """
+    smoothed_true = torch.matmul(filterbank, torch.abs(Y_true))
+    smoothed_pred = torch.matmul(filterbank, torch.abs(Y_pred))
+
+    log_difference = torch.log(smoothed_true + 1e-9) - torch.log(smoothed_pred + 1e-9)
+
+    return log_difference.abs().mean()
+
+class MRSTFTLoss(nn.Module):
+    """Multi-resolution STFT loss combining several spectral distances.
+
+    For every FFT size the magnitude spectrograms of reference and prediction
+    are compared with the loss terms whose weight is > 0; the weighted sum is
+    divided by the number of resolutions.
+
+    Args:
+        fft_sizes: FFT sizes of the individual resolutions (the window length
+            equals the FFT size).
+        overlap: fractional overlap between frames; the hop size is
+            ``round(fft_size * (1 - overlap))``.
+        window: window name understood by ``get_window``.
+        fs: sampling rate, used to derive the band width of the SFM weights.
+        log_mag_weight: weight of ``log_magnitude_loss``.
+        sc_weight: weight of ``spectral_convergence_loss``.
+        wsc_weight: weight of ``weighted_spectral_convergence``.
+        smooth_log_mag_weight: weight of ``smooth_log_mag``.
+        sxcorr_weight: weight of ``spectral_xcorr_loss``.
+    """
+    def __init__(self,
+                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
+                 overlap=0.5,
+                 window='hann_window',
+                 fs=16000,
+                 log_mag_weight=1,
+                 sc_weight=0,
+                 wsc_weight=0,
+                 smooth_log_mag_weight=0,
+                 sxcorr_weight=0):
+        super().__init__()
+
+        # copy so the instance never aliases the shared default list
+        self.fft_sizes = list(fft_sizes)
+        self.overlap = overlap
+        self.window = window
+        self.log_mag_weight = log_mag_weight
+        self.sc_weight = sc_weight
+        self.wsc_weight = wsc_weight
+        self.smooth_log_mag_weight = smooth_log_mag_weight
+        self.sxcorr_weight = sxcorr_weight
+        self.fs = fs
+
+        # weights for SFM weighted spectral convergence loss
+        self.wsc_weights = torch.nn.ParameterDict()
+        for fft_size in self.fft_sizes:
+            # number of bins covering roughly 1 kHz, capped at 11 and then
+            # rounded up to an even count
+            width = min(11, int(1000 * fft_size / self.fs + .5))
+            width += width % 2
+            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
+                create_weight_matrix(fft_size // 2 + 1, width),
+                requires_grad=False
+            )
+
+        # filterbanks for smooth log magnitude loss
+        self.filterbanks = torch.nn.ParameterDict()
+        for fft_size in self.fft_sizes:
+            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
+                gen_filterbank(fft_size//2),
+                requires_grad=False
+            )
+
+
+    def forward(self, y_true, y_pred):
+        """Return the combined loss as a tensor of shape (1,).
+
+        Defined as ``forward`` (not ``__call__``) so that calling the module
+        goes through ``nn.Module.__call__`` and registered hooks are honored.
+
+        Args:
+            y_true: reference signals, passed to ``stft``.
+            y_pred: predicted signals, passed to ``stft``.
+        """
+
+        lm_loss = torch.zeros(1, device=y_true.device)
+        sc_loss = torch.zeros(1, device=y_true.device)
+        wsc_loss = torch.zeros(1, device=y_true.device)
+        slm_loss = torch.zeros(1, device=y_true.device)
+        sxcorr_loss = torch.zeros(1, device=y_true.device)
+
+        for fft_size in self.fft_sizes:
+            hop_size = int(round(fft_size * (1 - self.overlap)))
+            win_size = fft_size
+
+            Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
+            Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)
+
+            # only evaluate the terms that contribute to the total
+            if self.log_mag_weight > 0:
+                lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)
+
+            if self.sc_weight > 0:
+                sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)
+
+            if self.wsc_weight > 0:
+                wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])
+
+            if self.smooth_log_mag_weight > 0:
+                slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])
+
+            if self.sxcorr_weight > 0:
+                sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)
+
+
+        total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
+                + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
+                + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)
+
+        return total_loss
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/make_default_setup.py
@@ -1,0 +1,56 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import argparse
+
+import yaml
+
+from utils.templates import setup_dict
+
+# Command line tool that writes the default setup of a model to a YAML file.
+parser = argparse.ArgumentParser()
+
+parser.add_argument('name', type=str, help='name of default setup file')
+parser.add_argument('--model', choices=['lace'], help='model name', default='lace')
+parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
+
+args = parser.parse_args()
+
+setup = setup_dict[args.model]
+
+# update dataset if given
+if args.path2dataset is not None:
+    setup['dataset'] = args.path2dataset
+
+# make sure the output file carries the .yml extension
+name = args.name
+if not name.endswith('.yml'):
+    name += '.yml'
+
+if __name__ == '__main__':
+    with open(name, 'w') as f:
+        yaml.dump(setup, f)
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/models/__init__.py
@@ -1,0 +1,36 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+from .lace import LACE
+
+
+
+# Registry that maps a model name string to the corresponding model class.
+model_dict = {
+    'lace': LACE
+}
--- /dev/null
+++ b/dnn/torch/osce/models/lace.py
@@ -1,0 +1,176 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import numpy as np
+
+from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
+from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+
+from models.nns_base import NNSBase
+from models.silk_feature_net_pl import SilkFeatureNetPL
+from models.silk_feature_net import SilkFeatureNet
+from .scale_embedding import ScaleEmbedding
+
+class LACE(NNSBase):
+    """ Linear-Adaptive Coding Enhancer
+
+    The input signal is filtered by two adaptive comb filters and one adaptive
+    convolution. All three are conditioned on the output of a feature net
+    that consumes the frame features concatenated with a pitch embedding and
+    a numbits embedding.
+    """
+    FRAME_SIZE=80
+
+    def __init__(self,
+                 num_features=47,
+                 pitch_embedding_dim=64,
+                 cond_dim=256,
+                 pitch_max=257,
+                 kernel_size=15,
+                 preemph=0.85,
+                 skip=91,
+                 comb_gain_limit_db=-6,
+                 global_gain_limits_db=[-6, 6],
+                 conv_gain_limits_db=[-6, 6],
+                 numbits_range=[50, 650],
+                 numbits_embedding_dim=8,
+                 hidden_feature_dim=64,
+                 partial_lookahead=True,
+                 norm_p=2):
+        """
+        Args:
+            num_features: dimension of the per-frame input features.
+            pitch_embedding_dim: dimension of the pitch embedding.
+            cond_dim: dimension of the conditioning vectors produced by the
+                feature net.
+            pitch_max: largest pitch value; the embedding table holds
+                ``pitch_max + 1`` entries.
+            kernel_size: kernel size of the comb filters and the adaptive
+                convolution.
+            preemph, skip: passed on to ``NNSBase``.
+            comb_gain_limit_db, global_gain_limits_db: gain limits handed to
+                the comb filters.
+            conv_gain_limits_db: gain limits handed to the adaptive convolution.
+            numbits_range: range handed to the numbits ``ScaleEmbedding``.
+            numbits_embedding_dim: embedding dimension per numbits value.
+            hidden_feature_dim: hidden dimension of ``SilkFeatureNetPL``.
+            partial_lookahead: selects ``SilkFeatureNetPL`` (True) or
+                ``SilkFeatureNet`` (False) as feature net.
+            norm_p: norm order handed to the adaptive layers.
+        """
+
+        super().__init__(skip=skip, preemph=preemph)
+
+
+        self.num_features           = num_features
+        self.cond_dim               = cond_dim
+        self.pitch_max              = pitch_max
+        self.pitch_embedding_dim    = pitch_embedding_dim
+        self.kernel_size            = kernel_size
+        self.preemph                = preemph
+        self.skip                   = skip
+        self.numbits_range          = numbits_range
+        self.numbits_embedding_dim  = numbits_embedding_dim
+        self.hidden_feature_dim     = hidden_feature_dim
+        self.partial_lookahead      = partial_lookahead
+
+        # pitch embedding
+        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
+
+        # numbits embedding
+        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
+
+        # feature net (two numbits values are embedded per frame, hence the factor 2)
+        if partial_lookahead:
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
+        else:
+            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
+
+        # comb filters
+        left_pad = self.kernel_size // 2
+        right_pad = self.kernel_size - 1 - left_pad
+        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+
+        # spectral shaping
+        self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+    def flop_count(self, rate=16000, verbose=False):
+        """Return the summed flop counts of the feature net and the filters.
+
+        The feature net is queried at the frame rate ``rate / FRAME_SIZE`` and
+        the filters at the sample rate ``rate``. With ``verbose`` the
+        individual contributions are printed in MFLOPS.
+        """
+
+        frame_rate = rate / self.FRAME_SIZE
+
+        # feature net
+        feature_net_flops = self.feature_net.flop_count(frame_rate)
+        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
+        af_flops = self.af1.flop_count(rate)
+
+        if verbose:
+            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
+            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
+            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
+
+        return feature_net_flops + comb_flops + af_flops
+
+    def forward(self, x, features, periods, numbits, debug=False):
+        """Enhance the signal ``x``.
+
+        Args:
+            x: input signal fed to the first comb filter.
+            features: per-frame features, concatenated with the embeddings
+                along the last dimension.
+            periods: integer pitch values; a trailing singleton dimension is
+                squeezed before the embedding lookup.
+            numbits: numbits values fed to the numbits embedding.
+            debug: forwarded to the adaptive layers.
+
+        Returns:
+            output of the adaptive convolution ``af1``.
+        """
+
+        periods         = periods.squeeze(-1)
+        pitch_embedding = self.pitch_embedding(periods)
+        numbits_embedding = self.numbits_embedding(numbits).flatten(2)
+
+        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
+        cf = self.feature_net(full_features)
+
+        y = self.cf1(x, cf, periods, debug=debug)
+
+        y = self.cf2(y, cf, periods, debug=debug)
+
+        y = self.af1(y, cf, debug=debug)
+
+        return y
+
+    def get_impulse_responses(self, features, periods, numbits):
+        """ generates impulse responses on frame centers (input without batch dimension)
+
+        Returns:
+            numpy array of shape (num_responses, max_len) where ``max_len`` is
+            ``2 * (pitch_max + kernel_size) + 10`` and ``num_responses`` is the
+            number of frames minus the frames needed to fit ``max_len``
+            samples (0 if the input is shorter than that).
+        """
+
+        num_frames = features.size(0)
+        batch_size = 32
+        max_len = 2 * (self.pitch_max + self.kernel_size) + 10
+
+        # spread out some pulses: batch entry b carries a unit pulse at the
+        # center of frames b, b + batch_size, b + 2 * batch_size, ... so that
+        # successive pulses within one entry are batch_size frames apart
+        x = np.zeros((batch_size, 1, num_frames * self.FRAME_SIZE))
+        for b in range(batch_size):
+            x[b, :, self.FRAME_SIZE // 2 + b * self.FRAME_SIZE :: batch_size * self.FRAME_SIZE] = 1
+
+        # prepare input
+        x = torch.from_numpy(x).float().to(features.device)
+        features = torch.repeat_interleave(features.unsqueeze(0), batch_size, 0)
+        periods = torch.repeat_interleave(periods.unsqueeze(0), batch_size, 0)
+        numbits = torch.repeat_interleave(numbits.unsqueeze(0), batch_size, 0)
+
+        # run network (same computation as a regular forward pass)
+        with torch.no_grad():
+            y = self.forward(x, features, periods, numbits, debug=False)
+
+        # collect responses
+        y = y.detach().squeeze(1).cpu().numpy()
+        cut_frames = (max_len + self.FRAME_SIZE - 1) // self.FRAME_SIZE
+        # clamp: fewer frames than cut_frames must yield an empty result
+        # rather than a negative array dimension
+        num_responses = max(num_frames - cut_frames, 0)
+        responses = np.zeros((num_responses, max_len))
+
+        for i in range(num_responses):
+            b = i % batch_size
+            start = self.FRAME_SIZE // 2 + i * self.FRAME_SIZE
+            stop = start + max_len
+
+            responses[i, :] = y[b, start:stop]
+
+        return responses
--- /dev/null
+++ b/dnn/torch/osce/models/nns_base.py
@@ -1,0 +1,69 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class NNSBase(nn.Module):
+    """ base class for enhancement models, adds inference on a single unbatched item via process() """
+    def __init__(self, skip=91, preemph=0.85):
+        super().__init__()
+
+        self.skip = skip        # number of leading output samples dropped for delay compensation
+        self.preemph = preemph  # coefficient of the de-emphasis filter applied to the model output
+
+    def process(self, sig, features, periods, numbits, debug=False):
+        """ runs the model in eval mode on one item (inputs without batch dimension) and returns int16 samples """
+        import inspect
+        self.eval()
+        # only look at declared parameters of forward, co_varnames also contains its local variables
+        has_numbits = 'numbits' in inspect.signature(self.forward).parameters
+        device = next(iter(self.parameters())).device
+        with torch.no_grad():
+            # run model
+            x = sig.view(1, 1, -1).to(device)
+            f = features.unsqueeze(0).to(device)
+            p = periods.unsqueeze(0).to(device)
+            n = numbits.unsqueeze(0).to(device)
+
+            if has_numbits:
+                y = self.forward(x, f, p, n, debug=debug).squeeze()
+            else:
+                y = self.forward(x, f, p, debug=debug).squeeze()
+
+            # deemphasis
+            if self.preemph > 0:
+                for i in range(len(y) - 1):
+                    y[i + 1] += self.preemph * y[i]
+
+            # delay compensation
+            y = torch.cat((y[self.skip:], torch.zeros(self.skip, dtype=y.dtype, device=y.device)))
+            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
+        return out
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/models/scale_embedding.py
@@ -1,0 +1,68 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import math as m
+import torch
+from torch import nn
+
+
+class ScaleEmbedding(nn.Module):
+    """ sinusoidal embedding of scalar values clipped to [min_val, max_val], optionally on a log scale """
+
+    def __init__(self, dim, min_val, max_val, logscale=False):
+        """ dim is the size of the embedding; with logscale=True the clipping range
+        and the inputs are mapped through the natural logarithm """
+        super().__init__()
+
+        if max_val <= min_val:
+            raise ValueError('min_val must be smaller than max_val')
+
+        if logscale and min_val <= 0:
+            raise ValueError('min_val must be positive when logscale is true')
+
+        # in log-scale mode the clipping range lives in the log domain
+        lower, upper = (m.log(min_val), m.log(max_val)) if logscale else (min_val, max_val)
+
+        self.dim = dim
+        self.logscale = logscale
+        self.min_val = lower
+        self.max_val = upper
+        self.offset = (lower + upper) / 2
+
+        # trainable frequencies, initialized to k * pi / (upper - lower) for k = 1, ..., dim
+        harmonics = torch.arange(1, dim+1, dtype=torch.float32)
+        self.scale_factors = nn.Parameter(harmonics * torch.pi / (upper - lower))
+
+    def forward(self, x):
+        """ maps x of shape (...) to an embedding of shape (..., dim) """
+        if self.logscale:
+            x = torch.log(x)
+
+        centered = torch.clip(x, self.min_val, self.max_val) - self.offset
+        return torch.sin(centered.unsqueeze(-1) * self.scale_factors - 0.5)
--- /dev/null
+++ b/dnn/torch/osce/models/silk_feature_net.py
@@ -1,0 +1,86 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.complexity import _conv1d_flop_count
+
+class SilkFeatureNet(nn.Module):
+    """ frame-rate feature network
+
+    Two tanh-activated 1d convolutions followed by a GRU. With lookahead=True the
+    convolutions are padded symmetrically, otherwise on the left only (causal).
+    """
+
+    def __init__(self,
+                 feature_dim=47,
+                 num_channels=256,
+                 lookahead=False):
+        super(SilkFeatureNet, self).__init__()
+
+        self.feature_dim = feature_dim
+        self.num_channels = num_channels
+        self.lookahead = lookahead
+
+        self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
+        self.conv2 = nn.Conv1d(num_channels, num_channels, 3, dilation=2)
+
+        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
+
+    def flop_count(self, rate=200):
+        """ estimated FLOPs per second when processing `rate` frames per second """
+        conv_flops = sum(_conv1d_flop_count(conv, rate) for conv in (self.conv1, self.conv2))
+        gru_weights = 3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size
+
+        return conv_flops + 2 * gru_weights * rate
+
+    def forward(self, features, state=None):
+        """ features shape: (batch_size, num_frames, feature_dim)
+
+        state is the initial GRU state (zeros if omitted); the output has shape
+        (batch_size, num_frames, num_channels).
+        """
+        if state is None:
+            state = torch.zeros((1, features.size(0), self.num_channels), device=features.device)
+
+        # amount of (left, right) zero padding in front of conv1 and conv2
+        pad1, pad2 = ([1, 1], [2, 2]) if self.lookahead else ([2, 0], [4, 0])
+
+        # the convolutions expect channels first
+        c = features.permute(0, 2, 1)
+        c = torch.tanh(self.conv1(F.pad(c, pad1)))
+        c = torch.tanh(self.conv2(F.pad(c, pad2)))
+
+        # back to channels last for the batch-first GRU
+        c, _ = self.gru(c.permute(0, 2, 1), state)
+
+        return c
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -1,0 +1,90 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.complexity import _conv1d_flop_count
+
+class SilkFeatureNetPL(nn.Module):
+    """ feature net with partial lookahead
+
+    Features are reduced per frame, stacked in groups of four consecutive frames,
+    processed at a quarter of the frame rate and upsampled again by a transposed
+    convolution before they enter the GRU.
+    """
+
+    def __init__(self,
+                 feature_dim=47,
+                 num_channels=256,
+                 hidden_feature_dim=64):
+
+        super(SilkFeatureNetPL, self).__init__()
+
+        self.feature_dim = feature_dim
+        self.num_channels = num_channels
+        self.hidden_feature_dim = hidden_feature_dim
+
+        self.conv1 = nn.Conv1d(feature_dim, self.hidden_feature_dim, 1)
+        self.conv2 = nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2)
+        self.tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4)
+
+        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
+
+    def flop_count(self, rate=200):
+        """ estimated FLOPs per second, every layer is counted with the same `rate` """
+        conv_flops = sum(_conv1d_flop_count(layer, rate) for layer in (self.conv1, self.conv2, self.tconv))
+        gru_weights = 3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size
+
+        return conv_flops + 2 * gru_weights * rate
+
+    def forward(self, features, state=None):
+        """ features shape: (batch_size, num_frames, feature_dim)
+
+        num_frames has to be a multiple of 4 since four consecutive frames are stacked.
+        """
+        batch_size, num_frames = features.size(0), features.size(1)
+
+        if state is None:
+            state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
+
+        # dimensionality reduction
+        reduced = torch.tanh(self.conv1(features.permute(0, 2, 1))).permute(0, 2, 1)
+
+        # frame accumulation: stack 4 consecutive frames along the channel axis
+        stacked = reduced.reshape(batch_size, num_frames // 4, -1).permute(0, 2, 1)
+        c = torch.tanh(self.conv2(F.pad(stacked, [1, 0])))
+
+        # upsampling back to num_frames steps
+        c = self.tconv(c).permute(0, 2, 1)
+        c, _ = self.gru(c, state)
+
+        return c
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/test_model.py
@@ -1,0 +1,96 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import argparse
+
+import torch
+
+from scipy.io import wavfile
+
+
+from models import model_dict
+from utils.silk_features import load_inference_data
+from utils import endoscopy
+
+debug = False  # set to True to use the hard-coded arguments below instead of the command line
+if debug:
+    args = type('dummy', (object,),
+    {
+        'input'         : 'testitems/all_0_orig.se',
+        'checkpoint'    : 'testout/checkpoints/checkpoint_epoch_5.pth',
+        'output'        : 'out.wav',
+        'debug'         : False,
+    })()
+else:
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('input', type=str, help='path to folder with features and signals')
+    parser.add_argument('checkpoint', type=str, help='checkpoint file')
+    parser.add_argument('output', type=str, help='output file')
+    parser.add_argument('--debug', action='store_true', help='enables debug output')
+
+    args = parser.parse_args()
+
+
+torch.set_num_threads(2)
+
+input_folder = args.input
+checkpoint_file = args.checkpoint
+
+
+output_file = args.output
+if not output_file.endswith('.wav'):
+    output_file += '.wav'
+# NOTE: torch.load unpickles the file, only load checkpoints from trusted sources
+checkpoint = torch.load(checkpoint_file, map_location="cpu")
+
+# check model
+if 'name' not in checkpoint['setup']['model']:
+    print(f'warning: did not find model name entry in setup, using pitchpostfilter per default')
+    model_name = 'pitchpostfilter'
+else:
+    model_name = checkpoint['setup']['model']['name']
+
+model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
+
+model.load_state_dict(checkpoint['state_dict'])
+
+# generate model input
+setup = checkpoint['setup']
+signal, features, periods, numbits = load_inference_data(input_folder, **setup['data'])
+
+if args.debug:
+    endoscopy.init()
+
+output = model.process(signal, features, periods, numbits, debug=args.debug)
+
+wavfile.write(output_file, 16000, output.cpu().numpy())
+
+if args.debug:
+    endoscopy.close()
--- /dev/null
+++ b/dnn/torch/osce/train_model.py
@@ -1,0 +1,297 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+import sys
+
+import yaml
+
+try:
+    import git
+    has_git = True
+except:
+    has_git = False
+
+import torch
+from torch.optim.lr_scheduler import LambdaLR
+
+import numpy as np
+
+from scipy.io import wavfile
+
+import pesq
+
+from data import SilkEnhancementSet
+from models import model_dict
+from engine.engine import train_one_epoch, evaluate
+
+
+from utils.silk_features import load_inference_data
+from utils.misc import count_parameters
+
+from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
+
+
+# command line interface
+parser = argparse.ArgumentParser()
+parser.add_argument('setup', type=str, help='setup yaml file')
+parser.add_argument('output', type=str, help='output path')
+parser.add_argument('--device', type=str, help='compute device', default=None)
+parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
+parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
+parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
+
+args = parser.parse_args()
+
+# file names and prefixes used for the outputs of this script
+checkpoint_prefix = 'checkpoint'
+output_prefix = 'output'
+setup_name = 'setup.yml'
+output_file = 'out.txt'
+
+torch.set_num_threads(4)
+
+# the training setup (model, data and training sections) is read from a yaml file
+with open(args.setup, 'r') as f:
+    setup = yaml.load(f, yaml.FullLoader)
+
+# check model
+if 'name' in setup['model']:
+    model_name = setup['model']['name']
+else:
+    print(f'warning: did not find model entry in setup, using default PitchPostFilter')
+    model_name = 'pitchpostfilter'
+
+# prepare output folder (ask before re-using an existing one)
+if os.path.exists(args.output):
+    print("warning: output folder exists")
+
+    reply = input('continue? (y/n): ')
+    while reply not in {'y', 'n'}:
+        reply = input('continue? (y/n): ')
+
+    if reply == 'n':
+        sys.exit(0)  # os._exit() needs a status argument and raised a TypeError here
+else:
+    os.makedirs(args.output, exist_ok=True)
+
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+# add repo info to setup
+if has_git:
+    working_dir = os.path.split(__file__)[0]
+    try:
+        # this script lives in a sub-folder, so the repo root has to be searched upwards
+        repo = git.Repo(working_dir, search_parent_directories=True)
+        repo_info = {
+            'hash'  : repo.head.object.hexsha,
+            'urls'  : list(repo.remote().urls),
+            'dirty' : repo.is_dirty()
+        }
+
+        if repo_info['dirty']:
+            print("warning: repo is dirty")
+
+        setup['repo'] = repo_info  # only stored when all queries succeeded
+    except Exception:
+        has_git = False
+
+# dump setup
+with open(os.path.join(args.output, setup_name), 'w') as f:
+    yaml.dump(setup, f)
+
+ref = None
+if args.testdata is not None:
+    testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])
+
+    inference_test = True
+    inference_folder = os.path.join(args.output, 'inference_test')
+    os.makedirs(inference_folder, exist_ok=True)
+
+    # the clean reference is optional, without it the PESQ evaluation is skipped
+    try:
+        ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
+    except (OSError, ValueError):
+        print(f"warning: could not read clean.s16 from {args.testdata}, skipping PESQ evaluation")
+else:
+    inference_test = False
+
+# training parameters
+batch_size      = setup['training']['batch_size']
+epochs          = setup['training']['epochs']
+lr              = setup['training']['lr']
+lr_decay_factor = setup['training']['lr_decay_factor']
+
+# load training dataset
+data_config = setup['data']
+data = SilkEnhancementSet(setup['dataset'], **data_config)
+
+# load validation dataset if given
+if 'validation_dataset' in setup:
+    validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)
+
+    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)
+
+    run_validation = True
+else:
+    run_validation = False
+
+# create model
+model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
+
+if args.initial_checkpoint is not None:
+    print(f"loading state dict from {args.initial_checkpoint}...")
+    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
+    model.load_state_dict(chkpt['state_dict'])
+
+# set compute device (cuda if available unless --device is given)
+if args.device is None:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+else:
+    device = torch.device(args.device)
+
+# push model to device
+model.to(device)
+
+# dataloader
+dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)
+
+# optimizer is introduced to trainable parameters
+parameters = [p for p in model.parameters() if p.requires_grad]
+optimizer = torch.optim.Adam(parameters, lr=lr)
+
+# learning rate scheduler: scales the initial lr by 1 / (1 + lr_decay_factor * x), x counting scheduler steps
+scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
+
+# loss weights (the weights of the STFT-based terms are passed to MRSTFTLoss below)
+w_l1 = setup['training']['loss']['w_l1']
+w_lm = setup['training']['loss']['w_lm']
+w_slm = setup['training']['loss']['w_slm']
+w_sc = setup['training']['loss']['w_sc']
+w_logmel = setup['training']['loss']['w_logmel']
+w_wsc = setup['training']['loss']['w_wsc']
+w_xcorr = setup['training']['loss']['w_xcorr']
+w_sxcorr = setup['training']['loss']['w_sxcorr']
+w_l2 = setup['training']['loss']['w_l2']
+
+w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
+
+stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
+logmelloss = MRLogMelLoss().to(device)
+
+def xcorr_loss(y_true, y_pred):
+    """ one minus the normalized cross-correlation over all non-batch axes, averaged over the batch """
+    reduce_dims = list(range(1, y_true.dim()))
+    inner = torch.sum(y_true * y_pred, dim=reduce_dims)
+    energy = torch.sum(y_true ** 2, dim=reduce_dims) * torch.sum(y_pred ** 2, dim=reduce_dims)
+    return torch.mean(1 - inner / torch.sqrt(energy + 1e-9))
+
+def td_l2_norm(y_true, y_pred):
+    """ mean squared error divided by the RMS of y_pred (plus 1e-6), averaged over the batch """
+    reduce_dims = list(range(1, y_true.dim()))
+    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
+    rms_pred = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5
+    return (mse / (rms_pred + 1e-6)).mean()
+
+def td_l1(y_true, y_pred, pow=0):
+    """ mean absolute error divided by (mean |y_pred| + 1e-9) ** pow, averaged over the batch """
+    reduce_dims = list(range(1, y_true.dim()))
+    mae = torch.mean(torch.abs(y_true - y_pred), dim=reduce_dims)
+    return torch.mean(mae / ((torch.mean(torch.abs(y_pred), dim=reduce_dims) + 1e-9) ** pow))
+
+def criterion(x, y):
+    """ total training loss: weighted sum of all loss terms divided by the sum of the loss weights """
+    total = w_l1 * td_l1(x, y, pow=1) +  stftloss(x, y) + w_logmel * logmelloss(x, y) + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)
+    return total / w_sum
+
+
+
+# model checkpoint
+checkpoint = {
+    'setup'         : setup,
+    'state_dict'    : model.state_dict(),
+    'loss'          : -1
+}
+
+
+
+# unless --no-redirect is given, everything printed from here on goes to out.txt in the output folder
+if not args.no_redirect:
+    print(f"re-directing output to {os.path.join(args.output, output_file)}")
+    sys.stdout = open(os.path.join(args.output, output_file), "w")
+
+print("summary:")
+# NOTE(review): model.cpu() moves the model off the training device; presumably train_one_epoch moves it back - TODO confirm
+print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
+if hasattr(model, 'flop_count'):
+    print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
+# baseline: PESQ of noisy.s16 against the clean reference, only if a reference was loaded
+if ref is not None:
+    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
+    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
+    print(f"initial MOS (PESQ): {initial_mos}")
+
+best_loss = 1e9  # lowest validation loss seen so far
+
+for ep in range(1, epochs + 1):
+    print(f"training epoch {ep}...")
+    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
+
+
+    # update checkpoint content (it is written to disk at the end of the epoch)
+    checkpoint['state_dict'] = model.state_dict()
+    checkpoint['loss']       = new_loss
+
+    if run_validation:
+        print("running validation...")
+        validation_loss = evaluate(model, criterion, validation_dataloader, device)
+        checkpoint['validation_loss'] = validation_loss
+        # keep a separate copy of the checkpoint with the lowest validation loss
+        if validation_loss < best_loss:
+            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth'))
+            best_loss = validation_loss
+
+    if inference_test:
+        print("running inference test...")
+        out = model.process(testsignal, features, periods, numbits).cpu().numpy()
+        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
+        if ref is not None:
+            mos = pesq.pesq(16000, ref, out, mode='wb')
+            print(f"MOS (PESQ): {mos}")
+
+    # one checkpoint per epoch plus a copy that always holds the latest epoch
+    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
+    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
+
+
+    print()
+
+print('Done')
--- /dev/null
+++ b/dnn/torch/osce/utils/complexity.py
@@ -1,0 +1,35 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+def _conv1d_flop_count(layer, rate):  # FLOP estimate per second for a 1d convolution layer run at `rate` (the rate is divided by the layer's stride)
+    return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0] ) * layer.kernel_size[0]  # NOTE(review): the + 1 presumably covers the bias and the factor 2 presumably counts multiply and add separately - confirm
+
+
+def _dense_flop_count(layer, rate):  # FLOP estimate per second for a dense layer evaluated `rate` times per second
+    return 2 * ((layer.in_features + 1) * layer.out_features * rate )  # NOTE(review): the + 1 presumably covers the bias and the factor 2 presumably counts multiply and add separately - confirm
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/utils/endoscopy.py
@@ -1,0 +1,234 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+""" module for inspecting models during inference """
+
+import os
+
+import yaml
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+
+import torch
+import numpy as np
+
+# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
+_state = dict()
+_folder = 'endoscopy'
+
+def get_gru_gates(gru, input, state):
+    """ recomputes reset, update and new gate of the first GRU layer for a single input and state """
+    n = gru.hidden_size
+
+    # input and recurrent pre-activations of all three gates at once
+    direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
+    recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
+
+    # weights and biases are stacked as (reset, update, new) blocks of hidden_size rows each
+    d_r, d_z, d_n = (direct[k * n : (k + 1) * n] for k in range(3))
+    h_r, h_z, h_n = (recurrent[k * n : (k + 1) * n] for k in range(3))
+    bi_r, bi_z, bi_n = (gru.bias_ih_l0[k * n : (k + 1) * n] for k in range(3))
+    bh_r, bh_z, bh_n = (gru.bias_hh_l0[k * n : (k + 1) * n] for k in range(3))
+
+    reset_gate = torch.sigmoid(d_r + bi_r + h_r + bh_r)
+    update_gate = torch.sigmoid(d_z + bi_z + h_z + bh_z)
+    new_gate = torch.tanh(d_n + bi_n + reset_gate * (h_n + bh_n))
+
+    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
+
+
+def init(folder='endoscopy'):
+    """ selects the output folder for endoscopy data and creates it if it does not exist yet """
+
+    global _folder
+    _folder = folder
+
+    if os.path.exists(folder):
+        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
+    else:
+        os.makedirs(folder)
+
+def write_data(key, data, fs):
+    """ appends data to previous data written under key; the first write of a key creates
+    <key>.bin and <key>.yml in the endoscopy folder, later writes must keep fs, dtype and shape
+    (otherwise a ValueError is raised) """
+    # convert to numpy if torch.Tensor is given (tensors on other devices are copied to the cpu first)
+    if isinstance(data, torch.Tensor):
+        data = data.detach().cpu().numpy()
+    data = np.asarray(data)  # also accepts plain numbers and lists
+
+    if key not in _state:
+        _state[key] = {
+            'fid'   : open(os.path.join(_folder, key + '.bin'), 'wb'),
+            'fs'    : fs,
+            'dim'   : tuple(data.shape),
+            'dtype' : str(data.dtype)
+        }
+
+        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
+            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
+    else:
+        if _state[key]['fs'] != fs:
+            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
+        if _state[key]['dtype'] != str(data.dtype):
+            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
+        if _state[key]['dim'] != tuple(data.shape):
+            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
+
+    _state[key]['fid'].write(data.tobytes())
+
+def close(folder='endoscopy'):
+    """ clean up: closes all files opened by write_data and forgets them, so that keys can be written again (folder is unused) """
+    while _state:
+        _state.popitem()[1]['fid'].close()
+
+
+def read_data(folder='endoscopy'):
+    """ retrieves written data as numpy arrays
+
+    Returns a dict that maps every key to the content of its .yml file plus an entry
+    'data' holding all items written under that key, stacked along a new first axis.
+    """
+    result = dict()
+
+    for name in os.listdir(folder):
+        if not name.endswith('.yml'):
+            continue
+        key = name[:-4]
+        with open(os.path.join(folder, name), 'r') as f:
+            entry = yaml.load(f.read(), yaml.FullLoader)
+        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
+            raw = np.frombuffer(f.read(), dtype=entry['dtype'])
+        entry['data'] = raw.reshape((-1,) + entry['dim'])
+        result[key] = entry
+
+    return result
+
+def get_best_reshape(shape, target_ratio=1):
+    """ calculates the best 2d reshape of shape given the target ratio (rows/cols)
+
+    Returns (1,) for a single pixel and (num_rows, num_columns) otherwise, where
+    num_columns is the largest divisor of the pixel count that does not exceed
+    sqrt(pixel_count / target_ratio) (but is at least 1).
+    """
+    pixel_count = 1
+    for s in shape:
+        pixel_count *= s
+
+    if pixel_count == 1:
+        return (1,)
+
+    # at least one column: a large target_ratio or an empty axis would otherwise
+    # give zero columns and a division by zero in the search below
+    num_columns = max(1, int((pixel_count / target_ratio)**.5))
+    while (pixel_count % num_columns):
+        num_columns -= 1
+
+    return (pixel_count // num_columns, num_columns)
+
+def get_type_and_shape(shape):
+    """ decides how data of the given shape is displayed
+
+    Returns ('plot', (1,)) for single values and ('image', 2d shape) otherwise.
+    """
+    # scalar data comes with an empty shape
+    if len(shape) == 0:
+        shape = (1,)
+
+    # calculate pixel count
+    pixel_count = 1
+    for s in shape:
+        pixel_count *= s
+
+    if pixel_count == 1:
+        return 'plot', (1, )
+
+    # stay with shape if it is truly 2-dimensional, i.e. no axis holds all pixels;
+    # `and` is required, with `or` the test holds for every shape with more than one pixel
+    if len(shape) == 2 and shape[0] != pixel_count and shape[1] != pixel_count:
+        return 'image', shape
+
+    return 'image', get_best_reshape(shape)
+
+def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
+    """ renders data as returned by read_data into an mp4 animation with one subplot per key """
+
+    # determine plot setup (at least one row, fewer than two keys would otherwise give zero rows)
+    keys = sorted(data.keys())
+    num_keys = len(keys)
+    num_rows = max(1, int((num_keys * 3/4) ** .5))
+    num_cols = (num_keys + num_rows - 1) // num_rows
+
+    # squeeze=False keeps axs two-dimensional even for a single row or column
+    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
+    fig.set_size_inches(num_cols * 5, num_rows * 5)
+
+    display = dict()
+
+    fs_max = max([val['fs'] for val in data.values()])
+    num_samples = max([val['data'].shape[0] for val in data.values()])
+
+    # inspect data
+    for i, key in enumerate(keys):
+        axs[i // num_cols, i % num_cols].title.set_text(key)
+
+        display[key] = dict()
+
+        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
+        display[key]['down_factor'] = data[key]['fs'] / fs_max
+
+    start_index = max(start_index, half_signal_window_length)
+    while stop_index < 0:
+        stop_index += num_samples
+
+    stop_index = min(stop_index, num_samples - half_signal_window_length)
+
+    # actual plotting
+    frames = []
+    for index in range(start_index, stop_index):
+        ims = []
+        for i, key in enumerate(keys):
+            ax, values = axs[i // num_cols, i % num_cols], data[key]['data']
+
+            if display[key]['type'] == 'plot':
+                ims.append(ax.plot(values[index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
+
+            elif display[key]['type'] == 'image':
+                # keys with a lower fs hold fewer items, so the index is scaled down (and clamped)
+                feature_index = min(int(round(index * display[key]['down_factor'])), values.shape[0] - 1)
+                ims.append(ax.imshow(values[feature_index].reshape(display[key]['shape']), animated=True))
+
+        frames.append(ims)
+
+    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
+
+    if not filename.endswith('.mp4'):
+        filename += '.mp4'
+
+    ani.save(filename)
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
@@ -1,0 +1,236 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.endoscopy import write_data
+
+class LimitedAdaptiveComb1d(nn.Module):
+    """Adaptive comb filter with frame-wise kernels and bounded gains.
+
+    For every frame a normalized FIR kernel is predicted from the features and
+    applied to a copy of the input that is delayed by the frame's lag. The
+    filtered signal is scaled by a bounded comb gain, added to the undelayed
+    input and scaled by a bounded global gain. Consecutive frames are
+    cross-faded over the first ``overlap_size`` samples of each frame.
+    """
+
+    COUNTER = 1
+
+    # conversion factor from dB to natural-log amplitude, i.e. ln(10) / 20
+    DB_TO_LOG = 0.11512925464970229
+
+    def __init__(self,
+                 kernel_size,
+                 feature_dim,
+                 frame_size=160,
+                 overlap_size=40,
+                 use_bias=True,
+                 padding=None,
+                 max_lag=256,
+                 name=None,
+                 gain_limit_db=10,
+                 global_gain_limits_db=(-6, 6),
+                 norm_p=2):
+        """
+
+        Parameters:
+        -----------
+
+        kernel_size : int
+            length of the adaptive filter kernels
+
+        feature_dim : int
+            dimension of features from which kernels, biases and gains are computed
+
+        frame_size : int, optional
+            frame size, defaults to 160
+
+        overlap_size : int, optional
+            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40
+
+        use_bias : bool, optional
+            if true, biases will be added to output channels. Defaults to True
+
+        padding : List[int, int], optional
+            left and right padding. Defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+
+        max_lag : int, optional
+            maximal pitch lag, defaults to 256
+
+        name: str or None, optional
+            specifies a name attribute for the module. If None the name is auto generated as
+            limited_adaptive_comb1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d
+
+        gain_limit_db : float, optional
+            upper limit in dB for the gain applied to the comb-filtered signal, defaults to 10
+
+        global_gain_limits_db : Sequence[float, float], optional
+            lower and upper limit in dB for the gain applied to the frame output, defaults to (-6, 6)
+
+        norm_p : int or float, optional
+            order of the norm used for normalizing the kernels, defaults to 2
+
+        """
+
+        super(LimitedAdaptiveComb1d, self).__init__()
+
+        self.in_channels   = 1
+        self.out_channels  = 1
+        self.feature_dim   = feature_dim
+        self.kernel_size   = kernel_size
+        self.frame_size    = frame_size
+        self.overlap_size  = overlap_size
+        self.use_bias      = use_bias
+        self.max_lag       = max_lag
+        self.limit_db      = gain_limit_db
+        self.norm_p        = norm_p
+
+        if name is None:
+            self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
+            LimitedAdaptiveComb1d.COUNTER += 1
+        else:
+            self.name = name
+
+        # network for generating convolution weights
+        self.conv_kernel = nn.Linear(feature_dim, kernel_size)
+
+        if self.use_bias:
+            self.conv_bias = nn.Linear(feature_dim,1)
+
+        # comb filter gain: forward() computes exp(-relu(filter_gain(features)) + log_gain_limit),
+        # so the gain never exceeds gain_limit_db. The bias initialization below places the
+        # initial gain around exp(-4) as long as the weight contribution is small.
+        self.filter_gain = nn.Linear(feature_dim, 1)
+        self.log_gain_limit = gain_limit_db * self.DB_TO_LOG
+        with torch.no_grad():
+            self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)
+
+        # global gain: forward() computes exp(a * tanh(.) + b), which stays within the given limits
+        self.global_filter_gain = nn.Linear(feature_dim, 1)
+        log_min = global_gain_limits_db[0] * self.DB_TO_LOG
+        log_max = global_gain_limits_db[1] * self.DB_TO_LOG
+        self.filter_gain_a = (log_max - log_min) / 2
+        self.filter_gain_b = (log_max + log_min) / 2
+
+        if padding is None:
+            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+        else:
+            self.padding = padding
+
+        # half-cosine fade-out window; its flipped version is used as fade-in window
+        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
+
+    def forward(self, x, features, lags, debug=False):
+        """ adaptive comb filtering
+
+
+        Parameters:
+        -----------
+        x : torch.tensor
+            input signal of shape (batch_size, in_channels, num_samples)
+
+        features : torch.tensor
+            frame-wise features of shape (batch_size, num_frames, feature_dim)
+
+        lags: torch.LongTensor
+            frame-wise lags for comb-filtering; indexed as lags[..., i] for frame i.
+            Lags outside [0, max_lag] lead to out-of-range gather indices.
+
+        debug : bool, optional
+            if true and batch_size is 1, gains, kernels and lags are passed to write_data
+
+        Returns:
+        --------
+        torch.tensor
+            output signal of shape (batch_size, out_channels, num_frames * frame_size)
+
+        Raises:
+        -------
+        ValueError
+            if num_samples // frame_size does not equal num_frames
+
+        """
+
+        batch_size = x.size(0)
+        num_frames = features.size(1)
+        num_samples = x.size(2)
+        frame_size = self.frame_size
+        overlap_size = self.overlap_size
+        kernel_size = self.kernel_size
+        win1 = torch.flip(self.overlap_win, [0])
+        win2 = self.overlap_win
+
+        if num_samples // self.frame_size != num_frames:
+            raise ValueError('non matching sizes in LimitedAdaptiveComb1d.forward')
+
+        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
+        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))
+
+        if self.use_bias:
+            conv_biases  = self.conv_bias(features).permute(0, 2, 1)
+
+        # comb gains, bounded from above by the gain limit
+        conv_gains   = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
+        # global gains, bounded by the global gain limits
+        global_conv_gains   = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
+
+        if debug and batch_size == 1:
+            key = self.name + "_gains"
+            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            key = self.name + "_kernels"
+            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            key = self.name + "_lags"
+            write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            key = self.name + "_global_conv_gains"
+            write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+
+
+        # frame-wise convolution with overlap-add
+        output_frames = []
+        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
+        x = F.pad(x, self.padding)
+        # extra history for the lagged signal and extra look-ahead for the cross-fade
+        x = F.pad(x, [self.max_lag, self.overlap_size])
+
+        # sample indices of one (unshifted) analysis window, replicated over batch and channels
+        idx = torch.arange(frame_size + kernel_size - 1 + overlap_size, device=x.device).view(1, 1, -1)
+        idx = torch.repeat_interleave(idx, batch_size, 0)
+        idx = torch.repeat_interleave(idx, self.in_channels, 1)
+
+        # position of the first undelayed input sample in the padded signal
+        offset = self.max_lag + self.padding[0]
+
+        for i in range(num_frames):
+
+            # gather the window of frame i delayed by the frame's lag
+            cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
+            xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))
+
+            # batch items are folded into the channel dimension so that every item gets its own kernel
+            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
+
+
+            if self.use_bias:
+                new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
+
+            # add scaled comb-filter output to the undelayed input and apply the global gain
+            new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])
+
+            # overlapping part
+            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
+
+            # non-overlapping part
+            output_frames.append(new_chunk[:, :, overlap_size : frame_size])
+
+            # mem for next frame
+            overlap_mem = new_chunk[:, :, frame_size :]
+
+        # concatenate chunks
+        output = torch.cat(output_frames, dim=-1)
+
+        return output
+
+    def flop_count(self, rate):
+        """ returns an estimate of the floating point operations per second at sampling rate `rate` """
+        frame_rate = rate / self.frame_size
+        overlap = self.overlap_size
+        overhead = overlap / self.frame_size
+
+        count = 0
+
+        # kernel computation and filtering
+        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
+        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
+        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+        # bias computation
+        if self.use_bias:
+            count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
+
+        # a0 computation (NOTE(review): presumably accounts for the global gain -- confirm)
+        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+        # windowing
+        count += overlap * frame_rate * 3 * self.out_channels
+
+        return count
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
@@ -1,0 +1,222 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.endoscopy import write_data
+
+class LimitedAdaptiveConv1d(nn.Module):
+    """Adaptive 1d convolution with frame-wise kernels and bounded gains.
+
+    For every frame a kernel, an optional bias and a gain are predicted from
+    the features. Kernels are normalized and optionally blended with an
+    identity kernel, gains are restricted to a dB range. Consecutive frames are
+    cross-faded over the first ``overlap_size`` samples of each frame.
+    """
+
+    COUNTER = 1
+
+    # conversion factor from dB to natural-log amplitude, i.e. ln(10) / 20
+    DB_TO_LOG = 0.11512925464970229
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 feature_dim,
+                 frame_size=160,
+                 overlap_size=40,
+                 use_bias=True,
+                 padding=None,
+                 name=None,
+                 gain_limits_db=(-6, 6),
+                 shape_gain_db=0,
+                 norm_p=2):
+        """
+
+        Parameters:
+        -----------
+
+        in_channels : int
+            number of input channels
+
+        out_channels : int
+            number of output channels
+
+        kernel_size : int
+            length of the adaptive filter kernels
+
+        feature_dim : int
+            dimension of features from which kernels, biases and gains are computed
+
+        frame_size : int
+            frame size
+
+        overlap_size : int
+            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame
+
+        use_bias : bool
+            if true, biases will be added to output channels
+
+        padding : List[int, int]
+            left and right padding. Defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+
+        name : str or None
+            name attribute of the module. If None the name is auto generated as
+            limited_adaptive_conv1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveConv1d
+
+        gain_limits_db : Sequence[float, float]
+            lower and upper limit in dB for the frame-wise gains, defaults to (-6, 6)
+
+        shape_gain_db : float
+            determines the weight min(1, 10**(shape_gain_db / 20)) of the predicted kernel when
+            blending it with the identity kernel; the default 0 disables blending
+
+        norm_p : int or float
+            order of the norm used for normalizing the kernels, defaults to 2
+
+        """
+
+        super(LimitedAdaptiveConv1d, self).__init__()
+
+        self.in_channels    = in_channels
+        self.out_channels   = out_channels
+        self.feature_dim    = feature_dim
+        self.kernel_size    = kernel_size
+        self.frame_size     = frame_size
+        self.overlap_size   = overlap_size
+        self.use_bias       = use_bias
+        self.gain_limits_db = gain_limits_db
+        self.shape_gain_db  = shape_gain_db
+        self.norm_p         = norm_p
+
+        if name is None:
+            self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
+            LimitedAdaptiveConv1d.COUNTER += 1
+        else:
+            self.name = name
+
+        # network for generating convolution weights
+        self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size)
+
+        if self.use_bias:
+            self.conv_bias = nn.Linear(feature_dim, out_channels)
+
+        # weight of the predicted kernel in the blend with the identity kernel (never above 1)
+        self.shape_gain = min(1, 10**(shape_gain_db / 20))
+
+        # gain: forward() computes exp(a * tanh(.) + b), which stays within the given limits
+        self.filter_gain = nn.Linear(feature_dim, out_channels)
+        log_min = gain_limits_db[0] * self.DB_TO_LOG
+        log_max = gain_limits_db[1] * self.DB_TO_LOG
+        self.filter_gain_a = (log_max - log_min) / 2
+        self.filter_gain_b = (log_max + log_min) / 2
+
+        if padding is None:
+            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+        else:
+            self.padding = padding
+
+        # half-cosine fade-out window; its flipped version is used as fade-in window
+        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
+
+
+    def flop_count(self, rate):
+        """ returns an estimate of the floating point operations per second at sampling rate `rate` """
+        frame_rate = rate / self.frame_size
+        overlap = self.overlap_size
+        overhead = overlap / self.frame_size
+
+        count = 0
+
+        # kernel computation and filtering
+        # NOTE(review): the kernel and bias prediction terms do not scale with the channel counts
+        # although conv_kernel and conv_bias do -- confirm whether this undercount is intended
+        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
+        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
+
+        # bias computation
+        if self.use_bias:
+            count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
+
+        # gain computation
+
+        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+        # windowing
+        count += 3 * overlap * frame_rate * self.out_channels
+
+        return count
+
+    def forward(self, x, features, debug=False):
+        """ adaptive 1d convolution
+
+
+        Parameters:
+        -----------
+        x : torch.tensor
+            input signal of shape (batch_size, in_channels, num_samples)
+
+        features : torch.tensor
+            frame-wise features of shape (batch_size, num_frames, feature_dim)
+
+        debug : bool, optional
+            if true and batch_size is 1, gains and kernels are passed to write_data
+
+        Returns:
+        --------
+        torch.tensor
+            output signal of shape (batch_size, out_channels, num_frames * frame_size)
+
+        Raises:
+        -------
+        ValueError
+            if num_samples // frame_size does not equal num_frames
+
+        """
+
+        batch_size = x.size(0)
+        num_frames = features.size(1)
+        num_samples = x.size(2)
+        frame_size = self.frame_size
+        overlap_size = self.overlap_size
+        kernel_size = self.kernel_size
+        win1 = torch.flip(self.overlap_win, [0])
+        win2 = self.overlap_win
+
+        if num_samples // self.frame_size != num_frames:
+            raise ValueError('non matching sizes in LimitedAdaptiveConv1d.forward')
+
+        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
+
+        # normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension)
+        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True))
+
+        # limit shape by blending the predicted kernels with an identity kernel
+        # NOTE(review): the unit tap is placed at index padding[1]; for asymmetric padding the
+        # zero-delay tap of the cross-correlation appears to be padding[0] -- confirm
+        id_kernels = torch.zeros_like(conv_kernels)
+        id_kernels[..., self.padding[1]] = 1
+
+        conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels
+
+        if self.use_bias:
+            conv_biases  = self.conv_bias(features).permute(0, 2, 1)
+
+        # calculate gains
+        conv_gains   = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
+        if debug and batch_size == 1:
+            key = self.name + "_gains"
+            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            key = self.name + "_kernels"
+            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+
+
+        # frame-wise convolution with overlap-add
+        output_frames = []
+        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
+        x = F.pad(x, self.padding)
+        # extra look-ahead for the cross-fade of the last frame
+        x = F.pad(x, [0, self.overlap_size])
+
+        for i in range(num_frames):
+            # batch items are folded into the channel dimension so that every item gets its own kernel
+            xx = x[:, :, i * frame_size : (i + 1) * frame_size + kernel_size - 1 + overlap_size].reshape((1, batch_size * self.in_channels, -1))
+            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
+
+            if self.use_bias:
+                new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
+
+            new_chunk = new_chunk * conv_gains[:, :, i : i + 1]
+
+            # overlapping part
+            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
+
+            # non-overlapping part
+            output_frames.append(new_chunk[:, :, overlap_size : frame_size])
+
+            # mem for next frame
+            overlap_mem = new_chunk[:, :, frame_size :]
+
+        # concatenate chunks
+        output = torch.cat(output_frames, dim=-1)
+
+        return output
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/pitch_auto_correlator.py
@@ -1,0 +1,84 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class PitchAutoCorrelator(nn.Module):
+    """Computes frame-wise normalized correlations between a signal and lagged copies of it.
+
+    For every frame the correlation is evaluated at the frame's period and at
+    all lags within ``radius`` samples around it.
+    """
+
+    def __init__(self,
+                 frame_size=80,
+                 pitch_min=32,
+                 pitch_max=300,
+                 radius=2):
+        """
+
+        Parameters:
+        -----------
+        frame_size : int
+            number of samples per frame
+
+        pitch_min : int
+            stored as attribute; not used by forward
+
+        pitch_max : int
+            number of zeros prepended to the signal as history for the lagged copies
+
+        radius : int
+            correlations are computed for the 2 * radius + 1 lags centered at the period
+
+        """
+
+        super().__init__()
+
+        self.frame_size = frame_size
+        self.pitch_min = pitch_min
+        self.pitch_max = pitch_max
+        self.radius = radius
+
+
+    def forward(self, x, periods):
+        """ frame-wise normalized auto-correlation
+
+        Parameters:
+        -----------
+        x : torch.tensor
+            signal of shape (batch_size, channels, num_samples)
+
+        periods : torch.LongTensor
+            periods of shape (batch_size, num_frames)
+
+        Returns:
+        --------
+        torch.tensor
+            correlations of shape (batch_size, channels, num_frames, 2 * radius + 1)
+
+        Raises:
+        -------
+        ValueError
+            if num_samples does not equal num_frames * frame_size
+
+        """
+
+        num_frames = periods.size(1)
+        batch_size = periods.size(0)
+        num_samples = self.frame_size * num_frames
+        channels = x.size(1)
+
+        if num_samples != x.size(-1):
+            raise ValueError(f"expected {num_samples} samples for {num_frames} frames but got {x.size(-1)}")
+
+        # for every sample, indices into the padded signal of the samples lagged by period + offset
+        lag_offsets = torch.arange(-self.radius, self.radius + 1, device=x.device)
+        idx = torch.arange(self.frame_size * num_frames, device=x.device)
+        p_up = torch.repeat_interleave(periods, self.frame_size, 1)
+        lookup = idx + self.pitch_max -  p_up
+        lookup = lookup.unsqueeze(-1) + lag_offsets
+        lookup = lookup.unsqueeze(1)
+
+        # padding
+        x_pad = F.pad(x, [self.pitch_max, 0])
+        x_ext = torch.repeat_interleave(x_pad.unsqueeze(-1), 2 * self.radius + 1, -1)
+
+        # framing (lookup has a singleton channel dimension, so lagged samples are taken from the first channel)
+        x_select = torch.gather(x_ext, 2, lookup)
+        x_frames = x_pad[..., self.pitch_max : ].reshape(batch_size, channels, num_frames, self.frame_size, 1)
+        lag_frames = x_select.reshape(batch_size, 1, num_frames, self.frame_size, -1)
+
+        # calculate auto-correlation
+        dotp = torch.sum(x_frames * lag_frames, dim=-2)
+        frame_nrg = torch.sum(x_frames * x_frames, dim=-2)
+        lag_frame_nrg  = torch.sum(lag_frames * lag_frames, dim=-2)
+
+        # small constant avoids division by zero for silent frames
+        acorr = dotp / torch.sqrt(frame_nrg * lag_frame_nrg + 1e-9)
+
+        return acorr
--- /dev/null
+++ b/dnn/torch/osce/utils/misc.py
@@ -1,0 +1,42 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+def count_parameters(model, verbose=False):
+    """Return the total number of elements in all parameters of `model`.
+
+    Parameters:
+    -----------
+    model : torch.nn.Module
+        model whose named parameters are counted
+
+    verbose : bool, optional
+        if true, the element count of every parameter is printed
+
+    Returns:
+    --------
+    int
+        total number of parameter elements
+    """
+    total = 0
+    for name, p in model.named_parameters():
+        # numel() is exact for every dtype; summing a tensor of ones allocates memory
+        # and overflows for large half-precision parameters
+        count = p.numel()
+
+        if verbose:
+            print(f"{name}: {count} parameters")
+
+        total += count
+
+    return total
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/utils/pitch.py
@@ -1,0 +1,121 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import numpy as np
+
+def hangover(lags, num_frames=10):
+    """Bridge short runs of zero lags by repeating the last non-zero lag.
+
+    Up to `num_frames` consecutive zero entries following a non-zero lag are
+    replaced by that lag; longer runs keep their remaining zeros. Zeros before
+    the first non-zero lag are left unchanged.
+
+    Parameters:
+    -----------
+    lags : numpy.ndarray
+        one-dimensional array of lags in which 0 marks frames without a lag
+
+    num_frames : int, optional
+        maximal number of consecutive zero lags to replace, defaults to 10
+
+    Returns:
+    --------
+    numpy.ndarray
+        modified copy of `lags`; the input is not changed
+    """
+    lags = lags.copy()
+    count = 0
+    last_lag = 0
+
+    for i, lag in enumerate(lags):
+        if lag == 0:
+            if count < num_frames:
+                lags[i] = last_lag
+                count += 1
+        else:
+            count = 0
+            # remember the lag for the hangover period; without this update the
+            # zeros above would be overwritten by the initial value 0, i.e. not at all
+            last_lag = lag
+
+    return lags
+
+
+def smooth_pitch_lags(lags, d=2):
+    """Smooth pitch lags with a triangular kernel, processing groups of four lags.
+
+    Groups consisting of zeros only are left untouched. At the borders of a
+    group the neighbouring lags are used as context; where the context is
+    missing or zero, the mirrored border lags of the group itself are used.
+
+    Parameters:
+    -----------
+    lags : numpy.ndarray
+        one-dimensional array of lags
+
+    d : int, optional
+        number of context lags on each side; the kernel has 2 * d + 1 taps. Must be smaller than 4
+
+    Returns:
+    --------
+    numpy.ndarray
+        copy of `lags` in which the processed groups are replaced by the rounded smoothed lags
+    """
+
+    assert d < 4
+
+    num_silk_frames = len(lags) // 4
+
+    smoothed_lags = lags.copy()
+
+    # normalized triangular kernel 1, ..., d, d + 1, d, ..., 1
+    ramp = np.arange(1, d+1)
+    kernel = np.concatenate((ramp, [d+1], ramp[::-1]), dtype=np.float32)
+    kernel = kernel / np.sum(kernel)
+
+    # left context of the first group: its own mirrored first lags
+    last = lags[0:d][::-1]
+    for i in range(num_silk_frames):
+        start = i * 4
+        stop = start + 4
+        frame = lags[start : stop]
+        tail = frame[4-d:]
+
+        if np.max(np.abs(frame)) == 0:
+            last = tail
+            continue
+
+        # right context: following lags if available and non-zero, mirrored tail otherwise
+        if i == num_silk_frames - 1:
+            following = tail[::-1]
+        else:
+            following = lags[stop : stop + d]
+
+        if np.max(np.abs(following)) == 0:
+            following = tail[::-1]
+
+        # left context: fall back to the mirrored head if the previous lags are zero
+        if np.max(np.abs(last)) == 0:
+            last = frame[0:d][::-1]
+
+        extended = np.concatenate((last, frame, following), dtype=np.float32)
+        smoothed_lags[start : stop] = np.round(np.convolve(extended, kernel, mode='valid'))
+
+        last = tail
+
+    return smoothed_lags
+
+def calculate_acorr_window(x, frame_size, lags, history=None, max_lag=300, radius=2, add_double_lag_acorr=False, no_pitch_threshold=32):
+    """Compute frame-wise normalized correlations of a signal with lagged versions of itself.
+
+    For every frame the correlation is evaluated at the 2 * radius + 1 lags
+    centered at the frame's lag. If `add_double_lag_acorr` is set, a second set
+    of correlations centered at twice the lag is appended, where lags below
+    `no_pitch_threshold` are not doubled.
+
+    Parameters:
+    -----------
+    x : numpy.ndarray
+        one-dimensional signal whose length is a multiple of `frame_size`
+
+    frame_size : int
+        number of samples per frame
+
+    lags : numpy.ndarray
+        one lag per frame
+
+    history : numpy.ndarray or None, optional
+        samples preceding `x`, at least max_lag + radius long. Defaults to zeros. A history
+        that is too short for the doubled lags is extended by zeros at the front.
+
+    max_lag : int, optional
+        maximal lag the history has to cover, defaults to 300
+
+    radius : int, optional
+        number of lags evaluated on each side of the center lag, defaults to 2
+
+    add_double_lag_acorr : bool, optional
+        if true, correlations around twice the lag are appended, defaults to False
+
+    no_pitch_threshold : int, optional
+        lags below this value are not doubled, defaults to 32
+
+    Returns:
+    --------
+    acorrs : numpy.ndarray
+        correlations of shape (num_frames, 2 * radius + 1), or (num_frames, 2 * (2 * radius + 1))
+        with double-lag correlations
+
+    lags : numpy.ndarray
+        copy of the input lags
+
+    Raises:
+    -------
+    ValueError
+        if the history is shorter than max_lag + radius or len(x) is not a multiple of frame_size
+    """
+    eps = 1e-9
+
+    lag_multiplier = 2 if add_double_lag_acorr else 1
+    required_history = lag_multiplier * max_lag + radius
+
+    if history is None:
+        history = np.zeros(required_history, dtype=x.dtype)
+
+    if len(history) < max_lag + radius:
+        raise ValueError(f"history of length {len(history)} is shorter than max_lag + radius = {max_lag + radius}")
+
+    if len(x) % frame_size != 0:
+        raise ValueError(f"signal length {len(x)} is not a multiple of frame size {frame_size}")
+
+    # doubled lags reach further back than max_lag; without extension the slice start
+    # below becomes negative and silently wraps around to the end of the signal
+    if len(history) < required_history:
+        history = np.concatenate((np.zeros(required_history - len(history), dtype=x.dtype), history), dtype=x.dtype)
+
+    offset = len(history)
+
+    num_frames = len(x) // frame_size
+    lags = lags.copy()
+
+    x_ext = np.concatenate((history, x), dtype=x.dtype)
+
+    d = radius
+    num_acorrs = 2 * d + 1
+    acorrs = np.zeros((num_frames, lag_multiplier * num_acorrs), dtype=x.dtype)
+
+    for idx in range(num_frames):
+        lag = lags[idx].item()
+        frame_start = offset + idx * frame_size
+        frame_stop = frame_start + frame_size
+        frame = x_ext[frame_start : frame_stop]
+        frame_nrg = np.dot(frame, frame)
+
+        for k in range(lag_multiplier):
+            lag1 = (k + 1) * lag if lag >= no_pitch_threshold else lag
+            for j in range(num_acorrs):
+                shift = j - d - lag1
+                past = x_ext[frame_start + shift : frame_stop + shift]
+                acorrs[idx, j + k * num_acorrs] = np.dot(frame, past) / np.sqrt(frame_nrg * np.dot(past, past) + eps)
+
+    return acorrs, lags
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/utils/silk_features.py
@@ -1,0 +1,151 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+import os
+
+import numpy as np
+import torch
+
+import scipy
+
+from utils.pitch import hangover, calculate_acorr_window
+from utils.spec import create_filter_bank, cepstrum, log_spectrum, log_spectrum_from_lpc
+
+def spec_from_lpc(a, n_fft=128, eps=1e-9):
+    """Return the power spectrum 1 / |A|^2 of the all-pole filter given by the coefficients `a`.
+
+    A is the DFT of the sequence 1, -a[0], ..., -a[order - 1], zero-padded to `n_fft` samples.
+
+    Parameters:
+    -----------
+    a : numpy.ndarray
+        coefficients of shape (..., order) with order + 1 < n_fft
+
+    n_fft : int, optional
+        DFT size, defaults to 128
+
+    eps : float, optional
+        constant added to |A|^2 before inversion, defaults to 1e-9
+
+    Returns:
+    --------
+    numpy.ndarray
+        spectrum of shape (..., n_fft // 2 + 1)
+    """
+    order = a.shape[-1]
+    assert order + 1 < n_fft
+
+    # zero-padded sequence 1, -a[0], -a[1], ...
+    inverse_filter = np.zeros(a.shape[:-1] + (n_fft,))
+    inverse_filter[..., 0] = 1
+    inverse_filter[..., 1:1 + order] = -a
+
+    # squared magnitude on the non-negative frequency bins
+    num_bins = n_fft//2 + 1
+    transfer = np.fft.fft(inverse_filter, axis=-1)
+    power = np.abs(transfer[..., :num_bins]) ** 2
+
+    return 1 / (power + eps)
+
+def silk_feature_factory(no_pitch_value=256,
+                         acorr_radius=2,
+                         pitch_hangover=8,
+                         num_bands_clean_spec=64,
+                         num_bands_noisy_spec=18,
+                         noisy_spec_scale='opus',
+                         noisy_apply_dct=True,
+                         add_offset=False,
+                         add_double_lag_acorr=False
+                         ):
+    """Create a feature extraction function for the given configuration.
+
+    The analysis window and the filter banks are created once and shared by all
+    calls of the returned function.
+
+    Parameters:
+    -----------
+    no_pitch_value : int
+        value that replaces periods equal to 0 (after hangover)
+
+    acorr_radius : int
+        radius passed to calculate_acorr_window
+
+    pitch_hangover : int
+        if positive, hangover is applied to the periods with this number of frames
+
+    num_bands_clean_spec : int
+        number of bands of the filter bank applied to the spectrum computed from the LPCs
+
+    num_bands_noisy_spec : int
+        number of bands of the filter bank applied to the spectrum of the noisy signal
+
+    noisy_spec_scale : str
+        scale of the filter bank for the noisy signal
+
+    noisy_apply_dct : bool
+        if true, `cepstrum` is used for the noisy signal, otherwise `log_spectrum`
+
+    add_offset : bool
+        if true, the offsets are appended to the features
+
+    add_double_lag_acorr : bool
+        passed to calculate_acorr_window
+
+    Returns:
+    --------
+    callable
+        create_features(noisy, noisy_history, lpcs, gains, ltps, periods, offsets) returning
+        the float32 feature matrix and the post-processed periods as int64 array
+    """
+
+    w = scipy.signal.windows.cosine(320)
+    fb_clean_spec = create_filter_bank(num_bands_clean_spec, 320, scale='erb', round_center_bins=True, normalize=True)
+    fb_noisy_spec = create_filter_bank(num_bands_noisy_spec, 320, scale=noisy_spec_scale, round_center_bins=True, normalize=True)
+
+    # spectral representation of the noisy signal
+    noisy_transform = cepstrum if noisy_apply_dct else log_spectrum
+
+    def create_features(noisy, noisy_history, lpcs, gains, ltps, periods, offsets):
+
+        periods = periods.copy()
+
+        if pitch_hangover > 0:
+            periods = hangover(periods, num_frames=pitch_hangover)
+
+        # remaining frames without period get the dedicated no-pitch value
+        periods[periods == 0] = no_pitch_value
+
+        clean_spectrum = 0.3 * log_spectrum_from_lpc(lpcs, fb=fb_clean_spec, n_fft=320)
+
+        # prepend the last 160 history samples; every spectral frame is repeated twice
+        noisy_ext = np.concatenate((noisy_history[-160:], noisy), dtype=np.float32)
+        noisy_cepstrum = np.repeat(noisy_transform(noisy_ext, 320, fb_noisy_spec, w), 2, 0)
+
+        log_gains = np.log(gains + 1e-9).reshape(-1, 1)
+
+        acorr, _ = calculate_acorr_window(noisy, 80, periods, noisy_history, radius=acorr_radius, add_double_lag_acorr=add_double_lag_acorr)
+
+        parts = [clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains]
+        if add_offset:
+            parts.append(offsets.reshape(-1, 1))
+
+        features = np.concatenate(parts, axis=-1, dtype=np.float32)
+
+        return features, periods.astype(np.int64)
+
+    return create_features
+
+
+
+def load_inference_data(path,
+                        no_pitch_value=256,
+                        skip=92,
+                        preemph=0.85,
+                        acorr_radius=2,
+                        pitch_hangover=8,
+                        num_bands_clean_spec=64,
+                        num_bands_noisy_spec=18,
+                        noisy_spec_scale='opus',
+                        noisy_apply_dct=True,
+                        add_offset=False,
+                        add_double_lag_acorr=False,
+                        **kwargs):
+    """Load signal and feature files from `path` and compute the model input.
+
+    Parameters:
+    -----------
+    path : str
+        folder containing noisy.s16 and the features_*.f32/.s16/.s32 files
+
+    skip : int
+        number of zeros prepended to the signal
+
+    preemph : float
+        if positive, pre-emphasis with this coefficient is applied to the returned signal
+        (after feature computation)
+
+    no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec,
+    noisy_spec_scale, noisy_apply_dct, add_offset, add_double_lag_acorr :
+        passed to silk_feature_factory
+
+    **kwargs :
+        ignored; a message is printed if any are given
+
+    Returns:
+    --------
+    signal : torch.Tensor
+        signal truncated to num_frames * 80 samples
+
+    features : torch.Tensor
+        frame-wise features
+
+    periods : torch.Tensor
+        frame-wise periods
+
+    numbits : torch.Tensor
+        num_bits and num_bits_smooth stacked along the last axis, each row repeated 4 times
+    """
+
+    if kwargs:
+        print(f"[load_inference_data]: ignoring keyword arguments {kwargs.keys()}...")
+
+    lpcs    = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
+    ltps    = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
+    gains   = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
+    periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
+    num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32).astype(np.float32).reshape(-1, 1)
+    num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32).reshape(-1, 1)
+    offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
+
+    # load signal and add back delay
+    signal  = np.fromfile(os.path.join(path, 'noisy.s16'), dtype=np.int16).astype(np.float32) / (2 ** 15)
+    signal = np.concatenate((np.zeros(skip, dtype=np.float32), signal), dtype=np.float32)
+
+    create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_offset, add_double_lag_acorr)
+
+    # truncate everything to a whole number of 320-sample blocks (4 frames of 80 samples each)
+    num_frames = min((len(signal) // 320) * 4, len(lpcs))
+    signal = signal[: num_frames * 80]
+    lpcs = lpcs[: num_frames]
+    ltps = ltps[: num_frames]
+    gains = gains[: num_frames]
+    periods = periods[: num_frames]
+    num_bits = num_bits[: num_frames // 4]
+    # truncate the smoothed bit counts themselves (they must not be replaced by the raw counts)
+    num_bits_smooth = num_bits_smooth[: num_frames // 4]
+    offsets = offsets[: num_frames]
+
+    numbits = np.repeat(np.concatenate((num_bits, num_bits_smooth), axis=-1, dtype=np.float32), 4, axis=0)
+
+    features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods, offsets)
+
+    # pre-emphasis is applied after feature computation
+    if preemph > 0:
+        signal[1:] -= preemph * signal[:-1]
+
+    return torch.from_numpy(signal), torch.from_numpy(features), torch.from_numpy(periods), torch.from_numpy(numbits)
--- /dev/null
+++ b/dnn/torch/osce/utils/spec.py
@@ -1,0 +1,194 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import math as m
+import numpy as np
+import scipy
+
+def erb(f):  # forward map of the 'erb' scale; affine in f, so centers placed with it are linearly spaced in frequency
+    return (1 + f * 4.37) * 24.7
+
+def inv_erb(e):  # inverse of erb: recovers f from e = 24.7 * (4.37 * f + 1)
+    return (-1 + e / 24.7) / 4.37
+
+def bark(f):  # forward map of the 'bark' scale; scalar f only, since math.asinh does not broadcast over arrays
+    return m.asinh(f/600) * 6
+
+def inv_bark(b):  # inverse of bark for a scalar b: f = 600 * sinh(b / 6)
+    return m.sinh(b / 6) * 600
+
+
+# registry of the frequency warpings that create_filter_bank can look up by name;
+# every entry is a [forward, inverse] pair of callables
+scale_dict = dict(bark=[bark, inv_bark])
+scale_dict['erb'] = [erb, inv_erb]
+
+def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
+    """ builds a (num_bands, n_fft // 2 + 1) matrix of triangular band weights over the FFT bins
+
+    scale selects the band centers: 'opus' uses a fixed 18-band layout (num_bands is overridden),
+    any other value must be a key of scale_dict and places the centers via that forward/inverse
+    pair. round_center_bins snaps centers to integer bins, return_upper appends the mirrored
+    upper bins (n_fft columns in total) and normalize scales every band to unit sum.
+    """
+    num_bins = n_fft // 2 + 1
+    fstep = fs / n_fft
+    f_lo, f_hi = 0, fstep * (num_bins - 1)
+
+    if scale != 'opus':
+        to_scale, from_scale = scale_dict[scale]
+        s_lo, s_hi = to_scale(f_lo), to_scale(f_hi)
+        # interior centers advance by 1/num_bands of the warped range; the last center is pinned to f_hi
+        inner = [from_scale(s_lo + i * (s_hi - s_lo) / (num_bands)) for i in range(1, num_bands - 1)]
+        center_bins = (np.array([f_lo] + inner + [f_hi]) - f_lo) / fstep
+    else:
+        if num_bands != 18:
+            print("warning: requested Opus filter bank with num_bands != 18. Adjusting num_bands.")
+            num_bands = 18
+        # centers tabulated as bin indices at 5 ms resolution (per the name), rescaled to n_fft at fs
+        bins_5ms = [0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 12, 14, 16, 20, 24, 28, 34, 40]
+        center_bins = (1000 * n_fft / fs / 5) * np.array(bins_5ms)
+
+    if round_center_bins:
+        center_bins = np.round(center_bins)
+
+    filter_bank = np.zeros((num_bands, num_bins))
+    band = 0
+    for k in range(num_bins):
+        # move on (at most one band per bin) once k lies above the upper center of the current pair
+        if k > center_bins[band + 1]:
+            band += 1
+        lower, upper = center_bins[band], center_bins[band + 1]
+        # linear cross-fade between the two neighbouring bands; the weights of a bin sum to 1
+        frac = (upper - k) / (upper - lower)
+        filter_bank[band][k]     = frac
+        filter_bank[band + 1][k] = 1 - frac
+
+    if return_upper:
+        extend = n_fft - num_bins
+        filter_bank = np.concatenate((filter_bank, np.fliplr(filter_bank[:, 1:extend+1])), axis=1)
+    if normalize:
+        filter_bank = filter_bank / np.sum(filter_bank, axis=1).reshape(-1, 1)
+    return filter_bank
+
+
+def compressed_log_spec(pspec):
+    """ log10 of a 1-d power spectrum, floored at -2 and at a threshold decaying by 2.5 per band
+
+    NOTE(review): each value is also clamped from below by log_max itself, so the output never
+    decreases along the band axis and the decaying threshold never takes effect. The analogous
+    loop in LPCNet's feature extraction appears to clamp at logMax - 8 -- confirm the intent.
+    """
+    lpspec = np.zeros_like(pspec)
+    log_max, follow = -2, -2
+    for i in range(pspec.shape[-1]):
+        decayed = follow - 2.5
+        value = max(log_max, max(decayed, np.log10(pspec[i] + 1e-9)))
+        lpspec[i] = value
+        log_max = max(log_max, value)
+        follow = max(decayed, value)
+    return lpspec
+
+def log_spectrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False, power=1):
+    """ calculates the log magnitude response of the all-pole filter 1 / A(z) given by lpcs a
+
+    a holds the coefficients along its last axis with A(z) = 1 - sum_k a[k] z^-(k+1); leading
+    axes are treated as batch axes. Coefficient k is scaled by gamma^(k+1) first. The response
+    1 / (|A|^power + eps) is taken on n_fft // 2 + 1 bins, optionally projected through the rows
+    of fb, and returned as natural log, or via compressed_log_spec (log10 based) if compress is set.
+    """
+    order = a.shape[-1]
+    assert order + 1 < n_fft
+
+    # per-coefficient weighting gamma^(k+1) (bandwidth expansion for gamma < 1)
+    weights = (gamma ** (1 + np.arange(order))).astype(np.float32)
+
+    # impulse response of A(z), zero-padded to n_fft samples
+    x = np.zeros((*a.shape[:-1], n_fft ))
+    x[..., 0] = 1
+    x[..., 1:1 + order] = -(a * weights)
+
+    num_bins = n_fft//2 + 1
+    S = 1 / (np.abs(np.fft.fft(x, axis=-1)[..., :num_bins]) ** power + eps)
+    Sf = S if fb is None else np.matmul(S, fb.T)
+
+    if compress:
+        return np.apply_along_axis(compressed_log_spec, -1, Sf)
+
+    return np.log(Sf + eps)
+
+def cepstrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False):
+    """ calculates cepstrum from SILK lpcs: orthonormal DCT-II (last axis) of log_spectrum_from_lpc """
+    # a bare `import scipy` is not guaranteed to expose scipy.fftpack on every SciPy version,
+    # so the submodule is imported explicitly where it is used
+    from scipy.fftpack import dct
+
+    Sf = log_spectrum_from_lpc(a, fb, n_fft, eps, gamma, compress)
+    return dct(Sf, 2, norm='ortho')
+
+
+
+def log_spectrum(x, frame_size, fb=None, window=None, power=1):
+    """ calculates the log magnitude spectrum of x on 50% overlapping frames
+
+    x (a 1-d numpy array) is cut into frames of frame_size samples with a hop of frame_size // 2,
+    each frame is optionally multiplied by window, and |FFT|^power of the first
+    frame_size // 2 + 1 bins is optionally projected through the rows of fb. Returns
+    log(. + 1e-9) with one row per frame; 2 * len(x) must be divisible by the even frame_size.
+    """
+    assert(2*len(x)) % frame_size == 0
+    assert frame_size % 2 == 0
+
+    hop = frame_size // 2
+    n = len(x)
+    # frames starting at even / odd multiples of hop that fit completely into x
+    num_frames = n // frame_size + (n - hop) // frame_size
+
+    # row i of the index matrix selects samples i * hop ... i * hop + frame_size - 1 (copies x)
+    start = hop * np.arange(num_frames).reshape(-1, 1)
+    x_unfold = x[start + np.arange(frame_size).reshape(1, -1)]
+
+    if window is not None:
+        x_unfold *= window.reshape(1, -1)
+
+    X = np.abs(np.fft.fft(x_unfold, n=frame_size, axis=-1))[:, :hop + 1] ** power
+    if fb is not None:
+        X = np.matmul(X, fb.T)
+
+    return np.log(X + 1e-9)
+
+
+def cepstrum(x, frame_size, fb=None, window=None):
+    """ calculate cepstrum on 50% overlapping frames: orthonormal DCT-II (last axis) of log_spectrum """
+    # a bare `import scipy` is not guaranteed to expose scipy.fftpack on every SciPy version,
+    # so the submodule is imported explicitly where it is used
+    from scipy.fftpack import dct
+
+    X = log_spectrum(x, frame_size, fb, window)
+    return dct(X, 2, norm='ortho')
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/utils/templates.py
@@ -1,0 +1,92 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+setup_dict = {}  # placeholder only; rebound to the populated registry at the bottom of this module
+
+lace_setup = {  # setup template for the 'lace' model: dataset paths, model, data and training options
+    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
+    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
+    'model': {
+        'name': 'lace',
+        'args': [],
+        'kwargs': {
+            'comb_gain_limit_db': 10,
+            'cond_dim': 128,
+            'conv_gain_limits_db': [-12, 12],
+            'global_gain_limits_db': [-6, 6],
+            'hidden_feature_dim': 96,
+            'kernel_size': 15,
+            'num_features': 93,
+            'numbits_embedding_dim': 8,
+            'numbits_range': [50, 650],
+            'partial_lookahead': True,
+            'pitch_embedding_dim': 64,
+            'pitch_max': 300,
+            'preemph': 0.85,
+            'skip': 91
+        }
+    },
+    # NOTE(review): 'preemph' and 'skip' repeat the model kwargs of the same name -- presumably they must match
+    'data': {
+        'frames_per_sample': 100,
+        'no_pitch_value': 7,
+        'preemph': 0.85,
+        'skip': 91,
+        'pitch_hangover': 8,
+        'acorr_radius': 2,
+        'num_bands_clean_spec': 64,
+        'num_bands_noisy_spec': 18,
+        'noisy_spec_scale': 'opus',
+    },
+    'training': {
+        'batch_size': 256,
+        'lr': 5.e-4,
+        'lr_decay_factor': 2.5e-5,
+        'epochs': 50,
+        'frames_per_sample': 50,
+        'loss': {
+            'w_l1': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_wsc': 0,
+            'w_xcorr': 0,
+            'w_sxcorr': 1,
+            'w_l2': 10,
+            'w_slm': 2
+        }
+    }
+}
+
+
+
+# lookup table of the available setup templates, keyed by name
+# (this rebinds the empty setup_dict created at the top of the module)
+setup_dict = dict(lace=lace_setup)
--