shithub: opus

ref: b5a5f14036322957a8669de7f955220baabaa823
parent: b4909e1dd907a255fefd5cf060adf7ced93f2117
author: Jan Buethe <jbuethe@amazon.de>
date: Thu Apr 25 13:12:45 EDT 2024

BBWENet Python implementation
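
This adds the Python training and data-preparation code for bandwidth
extension (BWE) from 16 kHz to 48 kHz: an adversarial training script
(adv_train_bwe_model.py), preprocessing/concatenation tools
(bwe_preproc.py, concatenator.py), the SimpleBWESet dataset, a training
engine, and the BWENet/BBWENet model definitions.

Training is driven by a YAML setup file; a typical invocation (paths
are illustrative) looks like:

    python adv_train_bwe_model.py setup.yml output_dir --device cuda:0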

--- /dev/null
+++ b/dnn/torch/osce/adv_train_bwe_model.py
@@ -1,0 +1,485 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+import sys
+import math as m
+import random
+
+import yaml
+
+from tqdm import tqdm
+
+try:
+    import git
+    has_git = True
+except ImportError:
+    has_git = False
+
+import torch
+from torch.optim.lr_scheduler import LambdaLR
+import torch.nn.functional as F
+
+from scipy.io import wavfile
+import numpy as np
+import pesq
+
+from data import SimpleBWESet
+from models import model_dict
+
+
+from utils.bwe_features import load_inference_data
+from utils.misc import count_parameters, retain_grads, get_grad_norm, create_weights
+
+from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
+from losses.td_lowpass import TDLowpass
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('setup', type=str, help='setup yaml file')
+parser.add_argument('output', type=str, help='output path')
+parser.add_argument('--device', type=str, help='compute device', default=None)
+parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
+parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
+parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
+
+args = parser.parse_args()
+
+
+def preemph(x, gamma):
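+    # first-order pre-emphasis: y[n] = x[n] - gamma * x[n-1], with y[0] = x[0]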
+    y = torch.cat((x[..., 0:1], x[..., 1:] - gamma * x[...,:-1]), dim=-1)
+    return y
+
+torch.set_num_threads(4)
+
+with open(args.setup, 'r') as f:
+    setup = yaml.load(f.read(), yaml.FullLoader)
+
+checkpoint_prefix = 'checkpoint'
+output_prefix = 'output'
+setup_name = 'setup.yml'
+output_file='out.txt'
+
+
+# check model
+if 'name' not in setup['model']:
+    print("warning: no model entry found in setup, using default PitchPostFilter")
+    model_name = 'pitchpostfilter'
+else:
+    model_name = setup['model']['name']
+
+# prepare output folder
+if os.path.exists(args.output):
+    print("warning: output folder exists")
+
+    reply = input('continue? (y/n): ')
+    while reply not in {'y', 'n'}:
+        reply = input('continue? (y/n): ')
+
+    if reply == 'n':
+        sys.exit(0)
+else:
+    os.makedirs(args.output, exist_ok=True)
+
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+# add repo info to setup
+if has_git:
+    working_dir = os.path.split(__file__)[0]
+    try:
+        repo = git.Repo(working_dir, search_parent_directories=True)
+        setup['repo'] = dict()
+        hash = repo.head.object.hexsha
+        urls = list(repo.remote().urls)
+        is_dirty = repo.is_dirty()
+
+        if is_dirty:
+            print("warning: repo is dirty")
+
+        setup['repo']['hash'] = hash
+        setup['repo']['urls'] = urls
+        setup['repo']['dirty'] = is_dirty
+    except Exception:
+        has_git = False
+
+# dump setup
+with open(os.path.join(args.output, setup_name), 'w') as f:
+    yaml.dump(setup, f)
+
+
+if args.testdata is not None:
+
+    testsignal, features = load_inference_data(args.testdata, **setup['data'])
+
+    inference_test = True
+    inference_folder = os.path.join(args.output, 'inference_test')
+    os.makedirs(inference_folder, exist_ok=True)
+
+else:
+    inference_test = False
+
+# training parameters
+batch_size      = setup['training']['batch_size']
+epochs          = setup['training']['epochs']
+lr              = setup['training']['lr']
+lr_decay_factor = setup['training']['lr_decay_factor']
+lr_gen          = lr * setup['training']['gen_lr_reduction']
+lambda_feat     =  setup['training']['lambda_feat']
+lambda_reg      = setup['training']['lambda_reg']
+adv_target      = setup['training'].get('adv_target', 'x_48')
+newloss         = setup['training'].get('newloss', False)
+
+# load training dataset
+data_config = setup['data']
+data = SimpleBWESet(setup['dataset'], **data_config)
+
+# load validation dataset if given
+if 'validation_dataset' in setup:
+    validation_data = SimpleBWESet(setup['validation_dataset'], **data_config)
+
+    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
+
+    run_validation = True
+else:
+    run_validation = False
+
+# create model
+model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
+
+# create discriminator
+print(setup['discriminator']['name'], setup['discriminator']['kwargs'])
+disc_name = setup['discriminator']['name']
+disc = model_dict[disc_name](
+    *setup['discriminator']['args'], **setup['discriminator']['kwargs']
+)
+
+
+# set compute device
+if args.device is None:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+else:
+    device = torch.device(args.device)
+
+# dataloader
+dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
+
+# optimizer on trainable parameters only
+parameters = [p for p in model.parameters() if p.requires_grad]
+optimizer = torch.optim.Adam(parameters, lr=lr_gen)
+
+# disc optimizer
+parameters = [p for p in disc.parameters() if p.requires_grad]
+optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])
+
+# learning rate scheduler
+scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
+
+if args.initial_checkpoint is not None:
+    print(f"loading state dict from {args.initial_checkpoint}...")
+    chkpt = torch.load(args.initial_checkpoint, map_location=device)
+    model.load_state_dict(chkpt['state_dict'])
+
+    if 'disc_state_dict' in chkpt:
+        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
+        disc.load_state_dict(chkpt['disc_state_dict'])
+
+    if 'optimizer_state_dict' in chkpt:
+        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
+        optimizer.load_state_dict(chkpt['optimizer_state_dict'])
+
+    if 'disc_optimizer_state_dict' in chkpt:
+        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
+        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])
+
+    if 'scheduler_state_dict' in chkpt:
+        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
+        scheduler.load_state_dict(chkpt['scheduler_state_dict'])
+
+    # if 'torch_rng_state' in chkpt:
+    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
+    #     torch.set_rng_state(chkpt['torch_rng_state'])
+
+    if 'numpy_rng_state' in chkpt:
+        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
+        np.random.set_state(chkpt['numpy_rng_state'])
+
+    if 'python_rng_state' in chkpt:
+        print(f"setting Python RNG state from {args.initial_checkpoint}...")
+        random.setstate(chkpt['python_rng_state'])
+
+# loss
+w_l1 = setup['training']['loss']['w_l1']
+w_lm = setup['training']['loss']['w_lm']
+w_slm = setup['training']['loss']['w_slm']
+w_sc = setup['training']['loss']['w_sc']
+w_logmel = setup['training']['loss']['w_logmel']
+w_wsc = setup['training']['loss']['w_wsc']
+w_xcorr = setup['training']['loss']['w_xcorr']
+w_sxcorr = setup['training']['loss']['w_sxcorr']
+w_l2 = setup['training']['loss']['w_l2']
+w_tdlp = setup['training']['loss'].get('w_tdlp', 0)
+preemph_gamma = setup['training']['preemph']
+
+w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2 + w_tdlp
+
+fft_sizes_16k = [2048, 1024, 512, 256, 128, 64]
+fft_sizes_48k = [3 * n for n in fft_sizes_16k]
+stftloss = MRSTFTLoss(fft_sizes=fft_sizes_48k, fs=48000, sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
+logmelloss = MRLogMelLoss(fft_sizes=fft_sizes_48k).to(device)
+
+def xcorr_loss(y_true, y_pred):
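+    # one minus the normalized cross-correlation between target and prediction,
+    # averaged over the batch (0 when perfectly correlated, 2 when anti-correlated)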
+    dims = list(range(1, len(y_true.shape)))
+
+    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)
+
+    return torch.mean(loss)
+
+def td_l2_norm(y_true, y_pred):
+    dims = list(range(1, len(y_true.shape)))
+
+    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
+
+    return loss.mean()
+
+def td_l1(y_true, y_pred, pow=0):
+    dims = list(range(1, len(y_true.shape)))
+    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)
+
+    return torch.mean(tmp)
+
+if newloss:
+    tdlp = TDLowpass(31, 4000/24000).to(device)
+else:
+    tdlp = TDLowpass(15, 4000/24000).to(device)
+
+if newloss:
+    def criterion(x, y, x_up):
+        # FD-losses are calculated on preemphasized signals
+        xp = preemph(x, preemph_gamma)
+        yp = preemph(y, preemph_gamma)
+
+        return (w_l1 * td_l1(x, y, pow=1) +  stftloss(xp, yp) + w_logmel * logmelloss(xp, yp)
+                + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + w_tdlp * tdlp(x_up, y)) / w_sum
+else:
+    def criterion(x, y, x_up):
+        # all losses are calculated on preemphasized signals
+        x = preemph(x, preemph_gamma)
+        y = preemph(y, preemph_gamma)
+        x_up = preemph(x_up, preemph_gamma)
+
+        return (w_l1 * td_l1(x, y, pow=1) +  stftloss(x, y) + w_logmel * logmelloss(x, y)
+                + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + w_tdlp * tdlp(x_up, y)) / w_sum
+
+# model checkpoint
+checkpoint = {
+    'setup'         : setup,
+    'state_dict'    : model.state_dict(),
+    'loss'          : -1
+}
+
+
+if not args.no_redirect:
+    print(f"re-directing output to {os.path.join(args.output, output_file)}")
+    sys.stdout = open(os.path.join(args.output, output_file), "w")
+
+
+print("summary:")
+
+print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
+if hasattr(model, 'flop_count'):
+    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
+print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")
+
+
+best_loss = 1e9
+log_interval = 10
+
+
+m_r = 0
+m_f = 0
+s_r = 1
+s_f = 1
+
+def optimizer_to(optim, device):
+    for param in optim.state.values():
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
+
+optimizer_to(optimizer, device)
+optimizer_to(optimizer_disc, device)
+
+retain_grads(model)
+retain_grads(disc)
+
+for ep in range(1, epochs + 1):
+    print(f"training epoch {ep}...")
+
+    model.to(device)
+    disc.to(device)
+    model.train()
+    disc.train()
+
+    running_disc_loss = 0
+    running_adv_loss = 0
+    running_feature_loss = 0
+    running_reg_loss = 0
+    running_disc_grad_norm = 0
+    running_model_grad_norm = 0
+
+    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
+        for i, batch in enumerate(tepoch):
+
+            # set gradients to zero
+            optimizer.zero_grad()
+
+            # push batch to device
+            for key in batch:
+                batch[key] = batch[key].to(device)
+
+            target = batch['x_48']
+            x16 = batch['x_16'].unsqueeze(1)
+            x_up = model.upsampler(x16)
+
+            # calculate model output
+            output = model(x16, batch['features'])
+
+            # pre-emphasize
+            disc_target = preemph(target, preemph_gamma)
+            output_preemph = preemph(output, preemph_gamma)
+
+            # discriminator update
+            scores_gen = disc(output_preemph.detach())
+            scores_real = disc(disc_target.unsqueeze(1))
+
+            disc_loss = 0
+            for score in scores_gen:
+                disc_loss += (score[-1] ** 2).mean()
+                m_f = 0.9 * m_f + 0.1 * score[-1].detach().mean().cpu().item()
+                s_f = 0.9 * s_f + 0.1 * score[-1].detach().std().cpu().item()
+
+            for score in scores_real:
+                disc_loss += ((1 - score[-1]) ** 2).mean()
+                m_r = 0.9 * m_r + 0.1 * score[-1].detach().mean().cpu().item()
+                s_r = 0.9 * s_r + 0.1 * score[-1].detach().std().cpu().item()
+
+            disc_loss = 0.5 * disc_loss / len(scores_gen)
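+            # running estimate of the chance that a fake score beats a real one:
+            # with fake scores F ~ N(m_f, s_f^2) and real scores R ~ N(m_r, s_r^2)
+            # assumed independent, P(F > R) = 0.5 * erfc((m_r - m_f) / sqrt(2 * (s_f^2 + s_r^2)))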
+            winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
+
+            disc.zero_grad()
+            disc_loss.backward()
+
+            running_disc_grad_norm += get_grad_norm(disc).detach().cpu().item()
+
+            optimizer_disc.step()
+
+            # generator update
+            scores_gen = disc(output_preemph)
+
+            # calculate loss
+            loss_reg = criterion(target, output.squeeze(1), x_up)
+
+            num_discs = len(scores_gen)
+            gen_loss = 0
+            for score in scores_gen:
+                gen_loss += ((1 - score[-1]) ** 2).mean() / num_discs
+
+            loss_feat = 0
+            for k in range(num_discs):
+                num_layers = len(scores_gen[k]) - 1
+                f = 4 / num_discs / num_layers
+                for l in range(num_layers):
+                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())
+
+            model.zero_grad()
+
+            (gen_loss + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()
+
+            optimizer.step()
+
+            # sparsification
+            if hasattr(model, 'sparsifier'):
+                model.sparsifier()
+
+            running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
+            running_adv_loss += gen_loss.detach().cpu().item()
+            running_disc_loss += disc_loss.detach().cpu().item()
+            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
+            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()
+
+            # update status bar
+            if i % log_interval == 0:
+                tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
+                                   disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
+                                   feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
+                                   reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
+                                   model_gradnorm=f"{running_model_grad_norm/(i+1):8.7f}",
+                                   disc_gradnorm=f"{running_disc_grad_norm/(i+1):8.7f}",
+                                   wc=f"{100*winning_chance:5.2f}%")
+
+
+    # save checkpoint
+    checkpoint['state_dict'] = model.state_dict()
+    checkpoint['disc_state_dict'] = disc.state_dict()
+    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
+    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
+    checkpoint['torch_rng_state'] = torch.get_rng_state()
+    checkpoint['numpy_rng_state'] = np.random.get_state()
+    checkpoint['python_rng_state'] = random.getstate()
+    checkpoint['adv_loss']   = running_adv_loss/(i + 1)
+    checkpoint['disc_loss']  = running_disc_loss/(i + 1)
+    checkpoint['feature_loss'] = running_feature_loss/(i + 1)
+    checkpoint['reg_loss'] = running_reg_loss/(i + 1)
+
+
+    if inference_test:
+        print("running inference test...")
+        with torch.no_grad():
+            out = model(testsignal.to(device).view(1, 1, -1), features.to(device).unsqueeze(0)).cpu().squeeze().numpy()
+        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 48000, (2**15 * out).astype(np.int16))
+
+
+    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
+    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))
+
+
+    print()
+
+print('Done')
--- /dev/null
+++ b/dnn/torch/osce/bwe_preproc.py
@@ -1,0 +1,365 @@
+import os
+import argparse
+from typing import List
+
+import numpy as np
+from scipy import signal
+from scipy.io import wavfile
+import resampy
+
+
+import math as m
+
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("filelist", type=str, help="file with filenames for concatenation in WAVE format")
+parser.add_argument("target_fs", type=int, help="target sampling rate of concatenated file")
+parser.add_argument("output", type=str, help="output directory")
+parser.add_argument("--basedir", type=str, help="basedir for filenames in filelist, defaults to ./", default="./")
+parser.add_argument("--normalize", action="store_true", help="apply normalization")
+parser.add_argument("--db_max", type=float, help="max DB for random normalization", default=0)
+parser.add_argument("--db_min", type=float, help="min DB for random normalization", default=0)
+parser.add_argument("--random_eq_prob", type=float, help="portion of items to which random eq will be applied (default: 0.4)", default=0.4)
+parser.add_argument("--static_noise_prob", type=float, help="portion of items to which static noise will be added (default: 0.2)", default=0.2)
+parser.add_argument("--random_dc_prob", type=float, help="portion of items to which random dc offset will be added (default: 0.1)", default=0.1)
+parser.add_argument("--rirdir", type=str, default=None, help="folder with room impulse responses in wav format (defaul: None)")
+parser.add_argument("--rir_prob", type=float, default=0.0, help="portion of items to which a random rir is applied (default: 0)")
+parser.add_argument("--verbose", action="store_true")
+
+def read_filelist(basedir, filelist):
+    with open(filelist, "r") as f:
+        files = f.readlines()
+
+    fullfiles = [os.path.join(basedir, f.rstrip('\n')) for f in files if len(f.rstrip('\n')) > 0]
+
+    return fullfiles
+
+def read_wave(file, target_fs):
+    fs, x = wavfile.read(file)
+
+    if fs < target_fs:
+        return None
+        print(f"[read_wave] warning: file {file} will be up-sampled from {fs} to {target_fs} Hz")
+
+    if fs != target_fs:
+        x = resampy.resample(x, fs, target_fs)
+
+    return x.astype(np.float32)
+
+def load_rirs(rirdir, target_fs):
+    """ read rirs (assumed .wav) from subfolders of rirdir """
+
+    rirs = []
+    for dirpath, dirnames, filenames in os.walk(rirdir):
+        for file in filenames:
+            if file.endswith(".wav"):
+                x = read_wave(os.path.join(dirpath, file), target_fs)
+                if x is None: continue
+                x = x / np.max(np.abs(x))
+
+    return rirs
+
+
+lp_coeffs = signal.firwin(151, 20000, fs=48000)
+def apply_20kHz_lp(x, fs):
+    if fs != 48000:
+        return x
+
+    y = np.convolve(x, lp_coeffs, mode='valid')
+    y *= np.max(np.abs(x)) / np.max(np.abs(y) + 1e-6)
+
+    return y
+
+
+def random_normalize(x, db_min, db_max, max_val=2**15 - 1):
+    db = np.random.uniform(db_min, db_max, 1)
+    m = np.abs(x).max()
+    c = 10**(db/20) * max_val / m
+
+    return c * x
+
+def random_resamp16(x, fs=48000):
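+    # randomized 48 -> 16 kHz downsampling: lowpass with a random cutoff in
+    # [7.2, 8] kHz and a random odd number of taps, then decimate by 3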
+    assert fs == 48000, "only supporting 48kHz input sampling rate for now"
+
+    cutoff = 800 * np.random.rand(1) + 7200 # cutoff between 7.2 and 8 kHz
+    numtaps = 2 * np.random.randint(75, 150) + 1
+    a = signal.firwin(numtaps, cutoff, fs=fs)
+
+    x16 = np.convolve(x, a, mode='same')[::3]
+
+    return x16
+
+
+def estimate_bandwidth(x, fs):
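+    # classify bandwidth from mean STFT band energies of active frames by
+    # comparing <8 kHz (wb), 8-16 kHz (swb) and >16 kHz (fb) energy ratios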
+    assert fs >= 44100 and "currently only fs >= 44100 supported"
+    f, t, X = signal.stft(x, nperseg=960, fs=fs)
+    X = X.transpose()
+
+    X_pow = np.abs(X) ** 2
+
+    X_nrg = np.sum(X_pow, axis=1)
+    threshold = np.sort(X_nrg)[int(0.9 * len(X_nrg))] * 0.1
+    X_pow = X_pow[X_nrg > threshold]
+
+    i = 0
+    wb_nrg = 0
+    wb_bands = 0
+    while f[i] < 8000:
+        wb_nrg += np.sum(X_pow[:, i])
+        wb_bands += 1
+        i += 1
+    wb_nrg /= wb_bands
+
+    i += 5 # safety margin
+    swb_nrg = 0
+    swb_bands = 0
+    while f[i] < 16000:
+        swb_nrg += np.sum(X_pow[:, i])
+        swb_bands += 1
+        i += 1
+    swb_nrg /= swb_bands
+
+    i += 5 # safety margin
+    fb_nrg = 0
+    fb_bands = 0
+    while i < X_pow.shape[1]:
+        fb_nrg += np.sum(X_pow[:, i])
+        fb_bands += 1
+        i += 1
+    fb_nrg /= fb_bands
+
+
+    if swb_nrg / wb_nrg < 1e-5:
+        return 'wb'
+    elif fb_nrg / wb_nrg < 1e-7:
+        return 'swb'
+    else:
+        return 'fb'
+
+def _get_random_eq_filter(num_taps=51, min_gain=1/3, max_gain=3, cutoff=8000, fs=48000, num_bands=15):
+
+    nyquist = fs / 2
+    freqs = (np.arange(num_bands)) / (num_bands - 1)
+    cutoff = cutoff/nyquist
+    log_min_gain = m.log(min_gain)
+    log_max_gain = m.log(max_gain)
+    split = int(cutoff * (num_bands - 1)) + 1
+
+
+    log_gains =  np.random.rand(num_bands) * (log_max_gain - log_min_gain) + log_min_gain
+    low_band_mean = np.mean(log_gains[:split])
+    log_gains[:split] -= low_band_mean
+    log_gains[split:] = 0
+    gains = np.exp(log_gains)
+
+    taps = signal.firwin2(num_taps, freqs, gains, nfreqs=127)
+
+
+    return taps
+
+def trim_silence(x, fs, threshold=0.005):
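+    # drop leading/trailing frames whose energy is below a threshold relative
+    # to the 90th-percentile frame energy, keeping a 20-frame margin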
+    frame_size = 320 * fs // 16000
+
+    num_frames = len(x) // frame_size
+    y = x[: frame_size * num_frames]
+
+    frame_nrg = np.sum(y.reshape(-1, frame_size) ** 2, axis=1)
+    ref_nrg = np.sort(frame_nrg)[int(num_frames * 0.9)]
+    silence_threshold = threshold * ref_nrg
+
+    for i, nrg in enumerate(frame_nrg):
+        if nrg > silence_threshold:
+            break
+
+    first_active_frame_index = i
+
+    for i in range(num_frames - 1, -1, -1):
+        if frame_nrg[i] > silence_threshold:
+            break
+
+    last_active_frame_index = i
+
+    i_start = max(first_active_frame_index - 20, 0) * frame_size
+    i_stop = min(last_active_frame_index + 20, num_frames - 1) * frame_size
+
+    return x[i_start:i_stop]
+
+
+
+def random_eq(x, fs, cutoff):
+    taps = _get_random_eq_filter(fs=fs, cutoff=cutoff)
+    y = np.convolve(taps, x.astype(np.float32))
+
+    # rescale
+    y *= np.max(np.abs(x)) / (np.max(np.abs(y)) + 1e-9)
+
+    return y
+
+def static_lowband_noise(x, fs, cutoff, max_gain=0.02):
+    k_lp = (5 * fs // 16000)
+    lp_taps = signal.firwin(2 * k_lp + 1, 2 * cutoff / fs)
+    eq_taps = _get_random_eq_filter(num_bands=9)
+
+    noise = np.random.randn(len(x) + len(lp_taps) + len(eq_taps) - 2)
+    noise = np.convolve(noise, lp_taps, mode='valid')
+    noise = np.convolve(noise, eq_taps, mode='valid')
+
+    gain = np.random.rand(1) * max_gain
+
+    x_max = np.max(np.abs(x))
+
+    noise *= gain * x_max / np.max(np.abs(noise))
+
+    y = x + noise
+    y *= x_max / (np.max(np.abs(y)) + 1e-9)
+
+    return y
+
+def apply_random_rir(x, rirs, rescale=True):
+    i = np.random.randint(0, len(rirs))
+    y = np.convolve(x, rirs[i], mode='same')
+    if rescale: y *= np.max(np.abs(x)) / np.max(np.abs(y) + 1e-6)
+    return y
+
+
+def random_dc_offset(x, max_rel_offset=0.03):
+    x_max = np.max(np.abs(x))
+    offset = x_max * (2 * np.random.rand(1) - 1) * max_rel_offset
+
+    y = x + offset
+    y *= x_max / (np.max(np.abs(y)) + 1e-9)
+
+    return y
+
+
+def concatenate(filelist : List[str],
+                outdir : str,
+                target_fs : int,
+                normalize : bool=True,
+                db_min : float=0,
+                db_max : float=0,
+                rand_eq_prob : float=0,
+                static_noise_prob: float=0,
+                rand_dc_prob : float=0,
+                rirs : List = None,
+                rir_prob : float = 0,
+                verbose=False):
+
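+    # 5 ms raised-cosine cross-fade window (40 samples at 8 kHz, scaled to target_fs)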
+    overlap_size = int(40 * target_fs / 8000)
+    overlap_mem = np.zeros(overlap_size, dtype=np.float32)
+    overlap_win1 = (0.5 + 0.5 * np.cos(np.arange(0, overlap_size) * np.pi / overlap_size)).astype(np.float32)
+    overlap_win2 = np.flipud(overlap_win1)
+
+    # same for 16 kHz
+    assert overlap_size % 3 == 0
+    overlap_size16 = overlap_size // 3
+    overlap_mem16 = np.zeros(overlap_size16, dtype=np.float32)
+    overlap_win1_16 = overlap_win1[::3]
+    overlap_win2_16 = np.flipud(overlap_win1_16)
+
+    output48 = os.path.join(outdir, 'signal_48kHz.s16')
+    output16 = os.path.join(outdir, 'signal_16kHz.s16')
+    os.makedirs(outdir, exist_ok=True)
+
+    with open(output48, 'wb') as f48, open(output16, 'wb') as f16:
+        for file in filelist:
+            x = read_wave(file, target_fs)
+            if x is None: continue
+
+            x = trim_silence(x, target_fs)
+
+            x = apply_20kHz_lp(x, target_fs)
+
+            bwidth = estimate_bandwidth(x, target_fs)
+            if bwidth != 'fb':
+                if verbose: print(f"bandwidth {bwidth} detected: skipping {file}...")
+                continue
+
+            if len(x) < 10 * overlap_size:
+                if verbose: print(f"skipping {file}...")
+                continue
+            elif verbose:
+                print(f"processing {file}...")
+
+            noise_first = np.random.randint(2)
+
+            if np.random.rand(1) < rand_eq_prob:
+                x = random_eq(x, target_fs, 5000)
+
+            if not noise_first:
+                if np.random.rand(1) < rir_prob:
+                    x = apply_random_rir(x, rirs)
+
+            if np.random.rand(1) < static_noise_prob:
+                x = static_lowband_noise(x, target_fs, 8000, max_gain=0.01)
+
+            if noise_first:
+                if np.random.rand(1) < rir_prob:
+                    x = apply_random_rir(x, rirs)
+
+            if np.random.rand(1) < rand_dc_prob:
+                x = random_dc_offset(x)
+
+            # trim final signal to length divisible by 3 to keep 16 and 48 kHz signals in sync
+            x = x[:len(x) - (len(x) % 3)]
+
+            if normalize:
+                x = random_normalize(x, db_min, db_max)
+
+            # write 48 and 16 kHz signals to disk
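+            # note: the overlap-add branch below is disabled (if False); instead each
+            # item is faded in/out and zero-padded, which keeps the 48 kHz and 16 kHz
+            # streams sample-aligned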
+            if False:
+                x1 = x[:-overlap_size]
+                x1[:overlap_size] = overlap_win1 * overlap_mem + overlap_win2 * x1[:overlap_size]
+                f48.write(x1.astype(np.int16).tobytes())
+
+                x16 = random_resamp16(x)
+                x1_16 = x16[:-overlap_size16]
+                x1_16[:overlap_size16] = overlap_win1_16 * overlap_mem16 + overlap_win2_16 * x1_16[:overlap_size16]
+                f16.write(x1_16.astype(np.int16).tobytes())
+
+                # memory update
+                overlap_mem = x[-overlap_size:]
+                overlap_mem16 = x16[-overlap_size16:]
+            else:
+                # window and zero pad signal
+                padding_samples = 3 * 100
+                x[:overlap_size]  *= overlap_win2 # fade in
+                x[-overlap_size:] *= overlap_win1 # fade out
+
+                x = np.concatenate((np.zeros(padding_samples), x, np.zeros(padding_samples)), dtype=x.dtype)
+
+                x16 = random_resamp16(x)
+
+                assert 3*len(x16) == len(x)
+                if np.max(x) > 2**15 - 1 or np.min(x) < -2**15: print("clipping")
+                if np.max(x16) > 2**15 - 1 or np.min(x16) < -2**15: print("clipping")
+                x = np.clip(x, -2**15, 2**15 - 1)
+                x16 = np.clip(x16, -2**15, 2**15 - 1)
+                f48.write(x.astype(np.int16).tobytes())
+                f16.write(x16.astype(np.int16).tobytes())
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    filelist = read_filelist(args.basedir, args.filelist)
+
+    if args.rirdir is not None:
+        rirs = load_rirs(args.rirdir, args.target_fs)
+    else:
+        rirs = []
+
+    concatenate(filelist,
+                args.output,
+                args.target_fs,
+                normalize=args.normalize,
+                db_min=args.db_min,
+                db_max=args.db_max,
+                rand_eq_prob=args.random_eq_prob,
+                static_noise_prob=args.static_noise_prob,
+                rand_dc_prob=args.random_dc_prob,
+                rirs=rirs,
+                rir_prob=args.rir_prob,
+                verbose=args.verbose)
--- /dev/null
+++ b/dnn/torch/osce/concatenator.py
@@ -1,0 +1,85 @@
+import os
+import argparse
+from typing import List
+
+import numpy as np
+from scipy import signal
+from scipy.io import wavfile
+import resampy
+
+
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("filelist", type=str, help="file with filenames for concatenation in WAVE format")
+parser.add_argument("target_fs", type=int, help="target sampling rate of concatenated file")
+parser.add_argument("output", type=str, help="binary output file (integer16)")
+parser.add_argument("--basedir", type=str, help="basedir for filenames in filelist, defaults to ./", default="./")
+parser.add_argument("--normalize", action="store_true", help="apply normalization")
+parser.add_argument("--db_max", type=float, help="max DB for random normalization", default=0)
+parser.add_argument("--db_min", type=float, help="min DB for random normalization", default=0)
+parser.add_argument("--verbose", action="store_true")
+
+def read_filelist(basedir, filelist):
+    with open(filelist, "r") as f:
+        files = f.readlines()
+
+    fullfiles = [os.path.join(basedir, f.rstrip('\n')) for f in files if len(f.rstrip('\n')) > 0]
+
+    return fullfiles
+
+def read_wave(file, target_fs):
+    fs, x = wavfile.read(file)
+
+    if fs < target_fs:
+        return None
+        print(f"[read_wave] warning: file {file} will be up-sampled from {fs} to {target_fs} Hz")
+
+    if fs != target_fs:
+        x = resampy.resample(x, fs, target_fs)
+
+    return x.astype(np.float32)
+
+def random_normalize(x, db_min, db_max, max_val=2**15 - 1):
+    db = np.random.uniform(db_min, db_max, 1)
+    m = np.abs(x).max()
+    c = 10**(db/20) * max_val / m
+
+    return c * x
+
+
+def concatenate(filelist : List[str], output : str, target_fs: int, normalize=True, db_min=0, db_max=0, verbose=False):
+
+    overlap_size = int(40 * target_fs / 8000)
+    overlap_mem = np.zeros(overlap_size, dtype=np.float32)
+    overlap_win1 = (0.5 + 0.5 * np.cos(np.arange(0, overlap_size) * np.pi / overlap_size)).astype(np.float32)
+    overlap_win2 = np.flipud(overlap_win1)
+
+    with open(output, 'wb') as f:
+        for file in filelist:
+            x = read_wave(file, target_fs)
+            if x is None: continue
+
+            if len(x) < 10 * overlap_size:
+                if verbose: print(f"skipping {file}...")
+                continue
+            elif verbose:
+                print(f"processing {file}...")
+
+            if normalize:
+                x = random_normalize(x, db_min, db_max)
+
+            x1 = x[:-overlap_size]
+            x1[:overlap_size] = overlap_win1 * overlap_mem + overlap_win2 * x1[:overlap_size]
+
+            f.write(x1.astype(np.int16).tobytes())
+
+            overlap_mem = x[-overlap_size:]
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    filelist = read_filelist(args.basedir, args.filelist)
+
+    concatenate(filelist, args.output, args.target_fs, normalize=args.normalize, db_min=args.db_min, db_max=args.db_max, verbose=args.verbose)
--- a/dnn/torch/osce/data/__init__.py
+++ b/dnn/torch/osce/data/__init__.py
@@ -1,2 +1,3 @@
 from .silk_enhancement_set import SilkEnhancementSet
-from .lpcnet_vocoding_dataset import LPCNetVocodingDataset
\ No newline at end of file
+from .lpcnet_vocoding_dataset import LPCNetVocodingDataset
+from .simple_bwe_dataset import SimpleBWESet
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/data/simple_bwe_dataset.py
@@ -1,0 +1,93 @@
+"""
+/* Copyright (c) 2024 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+
+from torch.utils.data import Dataset
+import numpy as np
+
+from utils.bwe_features import bwe_feature_factory
+
+
+class SimpleBWESet(Dataset):
+    FRAME_SIZE_16K = 160
+    def __init__(self,
+                 path,
+                 frames_per_sample=100,
+                 spec_num_bands=32,
+                 max_instafreq_bin=40,
+                 upsampling_delay48=13,
+                 ):
+
+        self.frames_per_sample = frames_per_sample
+        self.upsampling_delay48 = upsampling_delay48
+
+        self.signal_16k = np.fromfile(os.path.join(path, 'signal_16kHz.s16'), dtype=np.int16)
+        self.signal_48k = np.fromfile(os.path.join(path, 'signal_48kHz.s16'), dtype=np.int16)
+
+        num_frames = min(len(self.signal_16k) // self.FRAME_SIZE_16K,
+                         len(self.signal_48k) // (3 * self.FRAME_SIZE_16K))
+
+        self.create_features = bwe_feature_factory(spec_num_bands=spec_num_bands, max_instafreq_bin=max_instafreq_bin)
+
+        self.frame_offset = 6
+
+        self.len = (num_frames - self.frame_offset) // frames_per_sample
+
+    def __len__(self):
+        return self.len
+
+    def __getitem__(self, index):
+
+        frame_start = self.frames_per_sample * index + self.frame_offset
+        frame_stop  = frame_start + self.frames_per_sample
+
+        signal_start16 = frame_start * self.FRAME_SIZE_16K
+        signal_stop16  = frame_stop  * self.FRAME_SIZE_16K
+
+        x_16 = self.signal_16k[signal_start16 : signal_stop16].astype(np.float32) / 2**15
+        history_16 = self.signal_16k[signal_start16 - 320 : signal_start16].astype(np.float32) / 2**15
+
+        # dithering
+        x_16 += (np.random.rand(len(x_16)) - 0.5) / 2**15
+        history_16 += (np.random.rand(len(history_16)) - 0.5) / 2**15
+
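+        # shift the 48 kHz target by the upsampler delay so it stays aligned
+        # with the upsampled 16 kHz input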
+        x_48 = self.signal_48k[3 * signal_start16 - self.upsampling_delay48
+                               : 3 * signal_stop16 - self.upsampling_delay48].astype(np.float32) / 2**15
+
+        features = self.create_features(
+              x_16,
+              history_16
+        )
+
+        return {
+            'features'    : features,
+            'x_16'        : x_16.astype(np.float32),
+            'x_48'        : x_48.astype(np.float32),
+            }
--- /dev/null
+++ b/dnn/torch/osce/engine/bwe_engine.py
@@ -1,0 +1,112 @@
+import torch
+from tqdm import tqdm
+import sys
+
+def preemph(x, gamma):
+    y = torch.cat((x[..., 0:1], x[..., 1:] - gamma * x[...,:-1]), dim=-1)
+    return y
+
+def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, preemph_gamma=0, log_interval=10):
+
+    model.to(device)
+    model.train()
+
+    running_loss = 0
+    previous_running_loss = 0
+
+
+    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
+
+        for i, batch in enumerate(tepoch):
+
+            # set gradients to zero
+            optimizer.zero_grad()
+
+            # push batch to device
+            for key in batch:
+                batch[key] = batch[key].to(device)
+
+            target = batch['x_48']
+            x16 = batch['x_16']
+            x_up = model.upsampler(x16.unsqueeze(1))
+
+            # calculate model output
+            output = model(batch['x_16'].unsqueeze(1), batch['features'])
+
+            # pre-emphasize
+            target = preemph(target, preemph_gamma)
+            x_up = preemph(x_up, preemph_gamma)
+            output = preemph(output, preemph_gamma)
+
+            # calculate loss
+            loss = criterion(target, output.squeeze(1), x_up)
+
+            # calculate gradients
+            loss.backward()
+
+            # update weights
+            optimizer.step()
+
+            # update learning rate
+            scheduler.step()
+
+            # sparsification
+            if hasattr(model, 'sparsifier'):
+                model.sparsifier()
+
+            # update running loss
+            running_loss += float(loss.cpu())
+
+            # update status bar
+            if i % log_interval == 0:
+                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
+                previous_running_loss = running_loss
+
+
+    running_loss /= len(dataloader)
+
+    return running_loss
+
+def evaluate(model, criterion, dataloader, device, preemph_gamma=0, log_interval=10):
+
+    model.to(device)
+    model.eval()
+
+    running_loss = 0
+    previous_running_loss = 0
+
+    with torch.no_grad():
+        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
+
+            for i, batch in enumerate(tepoch):
+
+                # push batch to device
+                for key in batch:
+                    batch[key] = batch[key].to(device)
+
+                target = batch['x_48']
+                x_up = model.upsampler(batch['x_16'].unsqueeze(1))
+
+                # calculate model output
+                output = model(batch['x_16'].unsqueeze(1), batch['features'])
+
+                # pre-emphasize
+                target = preemph(target, preemph_gamma)
+                x_up = preemph(x_up, preemph_gamma)
+                output = preemph(output, preemph_gamma)
+
+                # calculate loss
+                loss = criterion(target, output.squeeze(1), x_up)
+
+                # update running loss
+                running_loss += float(loss.cpu())
+
+                # update status bar
+                if i % log_interval == 0:
+                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
+                    previous_running_loss = running_loss
+
+
+        running_loss /= len(dataloader)
+
+        return running_loss
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/extract_setup.py
@@ -1,0 +1,18 @@
+import torch
+import yaml
+import argparse
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('checkpoint', type=str, help='model checkpoint')
+parser.add_argument('setup', type=str, help='setup filename')
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    ckpt = torch.load(args.checkpoint, map_location='cpu')
+
+    setup = ckpt['setup']
+
+    with open(args.setup, "w") as f:
+        yaml.dump(setup, f)
\ No newline at end of file
--- a/dnn/torch/osce/losses/td_lowpass.py
+++ b/dnn/torch/osce/losses/td_lowpass.py
@@ -9,19 +9,21 @@
         super().__init__()
 
         self.b = scipy.signal.firwin(numtaps, cutoff)
-        self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
+        self.weight = torch.nn.Parameter(torch.from_numpy(self.b).float().view(1, 1, -1), requires_grad=False)
         self.power = power
 
     def forward(self, y_true, y_pred):
 
-        assert len(y_true.shape) == 3 and len(y_pred.shape) == 3
+        if len(y_true.shape) < 3: y_true = y_true.unsqueeze(1)
+        if len(y_pred.shape) < 3: y_pred = y_pred.unsqueeze(1)
 
         diff = y_true - y_pred
         diff_lp = torch.nn.functional.conv1d(diff, self.weight)
 
-        loss = torch.mean(torch.abs(diff_lp ** self.power))
+        loss = torch.mean(torch.abs(diff_lp) ** self.power) / (torch.mean(torch.abs(y_true) ** self.power) + 1e-6**self.power)
+        loss = loss ** (1 / self.power)
 
-        return loss, diff_lp
+        return loss
 
     def get_freqz(self):
         freq, response = scipy.signal.freqz(self.b)
--- a/dnn/torch/osce/make_default_setup.py
+++ b/dnn/torch/osce/make_default_setup.py
@@ -66,7 +66,7 @@
 parser = argparse.ArgumentParser()
 
 parser.add_argument('name', type=str, help='name of default setup file')
-parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce'], help='model name', default='lace')
+parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce', 'bwenet', 'bbwenet'], help='model name', default='lace')
 parser.add_argument('--adversarial', action='store_true', help='setup for adversarial training')
 parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
 
--- a/dnn/torch/osce/models/__init__.py
+++ b/dnn/torch/osce/models/__init__.py
@@ -32,6 +32,9 @@
 from .lavoce import LaVoce
 from .lavoce_400 import LaVoce400
 from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc
+from .td_discriminator import TDMultiResolutionDiscriminator as TDMResDisc
+from .bwe_net import BWENet
+from .bbwe_net import BBWENet
 
 model_dict = {
     'lace': LACE,
@@ -39,4 +42,7 @@
     'lavoce': LaVoce,
     'lavoce400': LaVoce400,
     'fdmresdisc': FDMResDisc,
+    'tdmresdisc': TDMResDisc,
+    'bwenet' : BWENet,
+    'bbwenet': BBWENet
 }
--- /dev/null
+++ b/dnn/torch/osce/models/bbwe_net.py
@@ -1,0 +1,257 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.complexity import _conv1d_flop_count
+from utils.layers.silk_upsampler import SilkUpsampler
+from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+from utils.layers.td_shaper import TDShaper
+from dnntools.quantization.softquant import soft_quant
+
+DUMP=False
+
+if DUMP:
+    from scipy.io import wavfile
+    import numpy as np
+    import os
+
+    os.makedirs('dump', exist_ok=True)
+
+    def dump_as_wav(filename, fs, x):
+        s  = x.cpu().squeeze().flatten().numpy()
+        s = 0.5 * s / s.max()
+        wavfile.write(filename, fs, (2**15 * s).astype(np.int16))
+
+
+
+class FloatFeatureNet(nn.Module):
+
+    def __init__(self,
+                 feature_dim=84,
+                 num_channels=256,
+                 upsamp_factor=2,
+                 lookahead=False,
+                 softquant=False):
+
+        super().__init__()
+
+        self.feature_dim = feature_dim
+        self.num_channels = num_channels
+        self.upsamp_factor = upsamp_factor
+        self.lookahead = lookahead
+
+        self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
+        self.conv2 = nn.Conv1d(num_channels, num_channels, 3)
+
+        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
+
+        self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor)
+
+        if softquant:
+            self.conv2 = soft_quant(self.conv2)
+            self.gru   = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.tconv = soft_quant(self.tconv)
+
+    def flop_count(self, rate=100):
+        count = 0
+        for conv in self.conv1, self.conv2, self.tconv:
+            count += _conv1d_flop_count(conv, rate)
+
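+        # GRU cost: 3 gates, each with an input and a recurrent matmul
+        # (2 FLOPs per MAC); the GRU runs after the 2x upsampling transposed
+        # conv, hence the extra upsamp_factor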
+        count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * self.upsamp_factor * rate
+
+        return count
+
+
+    def forward(self, features, state=None):
+        """ features shape: (batch_size, num_frames, feature_dim) """
+
+        batch_size = features.size(0)
+
+        if state is None:
+            state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
+
+
+        features = features.permute(0, 2, 1)
+        if self.lookahead:
+            c = torch.tanh(self.conv1(F.pad(features, [1, 1])))
+            c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
+        else:
+            c = torch.tanh(self.conv1(F.pad(features, [2, 0])))
+            c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
+
+        c = torch.tanh(self.tconv(c))
+
+        c = c.permute(0, 2, 1)
+
+        c, _ = self.gru(c, state)
+
+        return c
+
+
+class Folder(torch.nn.Module):
+    def __init__(self, num_taps, frame_size):
+        super().__init__()
+
+        self.num_taps = num_taps
+        self.frame_size = frame_size
+        assert frame_size % num_taps == 0
+        self.taps = torch.nn.Parameter(torch.randn(num_taps).view(1, 1, -1), requires_grad=True)
+
+
+    def flop_count(self, rate):
+
+        # single multiplication per sample
+        return rate
+
+    def forward(self, x, *args):
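+        # apply a piecewise-constant learned envelope: each of the num_taps
+        # positive gains exp(tap) covers length // num_taps consecutive samples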
+
+        batch_size, num_channels, length = x.shape
+        assert length % self.num_taps == 0
+
+        y = x * torch.repeat_interleave(torch.exp(self.taps), length // self.num_taps, dim=-1)
+
+        return y
+
+class BBWENet(torch.nn.Module):
+    FRAME_SIZE16k=80
+
+    def __init__(self,
+                 feature_dim,
+                 cond_dim=128,
+                 kernel_size16=15,
+                 kernel_size32=15,
+                 kernel_size48=15,
+                 conv_gain_limits_db=[-12, 12], # might be a bit tight
+                 activation="ImPowI",
+                 avg_pool_k32 = 8,
+                 avg_pool_k48 = 12,
+                 interpolate_k32=1,
+                 interpolate_k48=1,
+                 shape_extension=True,
+                 func_extension=True,
+                 shaper='TDShaper',
+                 bias=False,
+                 softquant=False,
+                 lookahead=False,
+                 ):
+
+        super().__init__()
+
+
+        self.feature_dim            = feature_dim
+        self.cond_dim               = cond_dim
+        self.kernel_size16          = kernel_size16
+        self.kernel_size32          = kernel_size32
+        self.kernel_size48          = kernel_size48
+        self.conv_gain_limits_db    = conv_gain_limits_db
+        self.activation             = activation
+        self.shape_extension        = shape_extension
+        self.func_extension         = func_extension
+        self.shaper                 = shaper
+
+        assert shape_extension or func_extension, "at least one of shape_extension and func_extension must be True"
+
+
+        self.frame_size16 = 1 * self.FRAME_SIZE16k
+        self.frame_size32 = 2 * self.FRAME_SIZE16k
+        self.frame_size48 = 3 * self.FRAME_SIZE16k
+
+        # upsampler
+        self.upsampler = SilkUpsampler()
+
+        # feature net
+        self.feature_net = FloatFeatureNet(feature_dim=feature_dim, num_channels=cond_dim, softquant=softquant, lookahead=lookahead)
+
+        # non-linear transforms
+
+        if self.shape_extension:
+            if self.shaper == 'TDShaper':
+                self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k32, interpolate_k=interpolate_k32, bias=bias, softquant=softquant)
+                self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k48, interpolate_k=interpolate_k48, bias=bias, softquant=softquant)
+            elif self.shaper == 'Folder':
+                self.tdshape1 = Folder(8, frame_size=self.frame_size32)
+                self.tdshape2 = Folder(12, frame_size=self.frame_size48)
+            else:
+                raise ValueError(f"unknown shaper {self.shaper}")
+
+        if activation == 'ImPowI':
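+            # 'ImPowI': since |x|^i = exp(i*log|x|), Im(|x|^i) = sin(log|x|),
+            # so this computes x * Im(|x|^i); the 1e-6 avoids log(0)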
+            self.nlfunc = lambda x : x * torch.sin(torch.log(torch.abs(x) + 1e-6))
+        elif activation == "ReLU":
+            self.nlfunc = F.relu
+        else:
+            raise ValueError(f"unknown activation {activation}")
+
+        latent_channels = 1
+        if self.shape_extension: latent_channels += 1
+        if self.func_extension: latent_channels += 1
+
+        # spectral shaping
+        self.af1 = LimitedAdaptiveConv1d(1, latent_channels, self.kernel_size16, cond_dim, frame_size=self.frame_size16, overlap_size=self.frame_size16//2, use_bias=False, padding=[self.kernel_size16 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, softquant=softquant)
+        self.af2 = LimitedAdaptiveConv1d(latent_channels, latent_channels, self.kernel_size32, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size32 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, softquant=softquant)
+        self.af3 = LimitedAdaptiveConv1d(latent_channels, 1, self.kernel_size48, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size48 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, softquant=softquant)
+
+
+    def flop_count(self, rate=16000, verbose=False):
+
+        frame_rate = rate / self.FRAME_SIZE16k
+
+        # feature net
+        feature_net_flops = self.feature_net.flop_count(frame_rate // 2)
+        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate)
+
+        if self.shape_extension:
+            shape_flops = self.tdshape1.flop_count(2*rate) + self.tdshape2.flop_count(3*rate)
+        else:
+            shape_flops = 0
+
+        if verbose:
+            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
+            print(f"shape flops: {shape_flops / 1e6} MFLOPS")
+            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
+
+        return feature_net_flops + af_flops + shape_flops
+
+    def forward(self, x, features, debug=False):
+
+        cf = self.feature_net(features)
+
+        # split into latent_channels channels
+        y16 = self.af1(x, cf, debug=debug)
+
+        # first 2x upsampling step
+        y32 = self.upsampler.hq_2x_up(y16)
+        y32_out = y32[:, 0:1, :] # first channel is bypass channel
+
+        # extend frequencies
+        idx = 1
+        if self.shape_extension:
+            y32_shape = self.tdshape1(y32[:, idx:idx+1, :], cf)
+            y32_out = torch.cat((y32_out, y32_shape), dim=1)
+            idx += 1
+
+        if self.func_extension:
+            y32_func = self.nlfunc(y32[:, idx:idx+1, :])
+            y32_out = torch.cat((y32_out, y32_func), dim=1)
+
+        # mix-select
+        y32_out = self.af2(y32_out, cf)
+
+        # 1.5x upsampling
+        y48 = self.upsampler.interpolate_3_2(y32_out)
+        y48_out = y48[:, 0:1, :] # first channel is bypass channel
+
+        # extend frequencies
+        idx = 1
+        if self.shape_extension:
+            y48_shape = self.tdshape2(y48[:, idx:idx+1, :], cf)
+            y48_out = torch.cat((y48_out, y48_shape), dim=1)
+            idx += 1
+
+        if self.func_extension:
+            y48_func = self.nlfunc(y48[:, idx:idx+1, :])
+            y48_out = torch.cat((y48_out, y48_func), dim=1)
+
+        # 2nd mixing
+        y48_out = self.af3(y48_out, cf)
+
+        return y48_out
--- /dev/null
+++ b/dnn/torch/osce/models/bwe_net.py
@@ -1,0 +1,262 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.complexity import _conv1d_flop_count
+from utils.layers.silk_upsampler import SilkUpsampler
+from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+from utils.layers.td_shaper import TDShaper
+
+
+DUMP=False
+
+if DUMP:
+    from scipy.io import wavfile
+    import numpy as np
+    import os
+
+    os.makedirs('dump', exist_ok=True)
+
+    def dump_as_wav(filename, fs, x):
+        s  = x.cpu().squeeze().flatten().numpy()
+        s = 0.5 * s / s.max()
+        wavfile.write(filename, fs, (2**15 * s).astype(np.int16))
+
+
+
+class FloatFeatureNet(nn.Module):
+
+    def __init__(self,
+                 feature_dim=84,
+                 num_channels=256,
+                 upsamp_factor=2,
+                 lookahead=False):
+
+        super().__init__()
+
+        self.feature_dim = feature_dim
+        self.num_channels = num_channels
+        self.upsamp_factor = upsamp_factor
+        self.lookahead = lookahead
+
+        self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
+        self.conv2 = nn.Conv1d(num_channels, num_channels, 3)
+
+        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
+
+        self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor)
+
+    def flop_count(self, rate=100):
+        count = 0
+        for conv in self.conv1, self.conv2, self.tconv:
+            count += _conv1d_flop_count(conv, rate)
+
+        count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
+
+        return count
+
+
+    def forward(self, features, state=None):
+        """ features shape: (batch_size, num_frames, feature_dim) """
+
+        batch_size = features.size(0)
+
+        if state is None:
+            state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
+
+
+        features = features.permute(0, 2, 1)
+        if self.lookahead:
+            c = torch.tanh(self.conv1(F.pad(features, [1, 1])))
+            c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
+        else:
+            c = torch.tanh(self.conv1(F.pad(features, [2, 0])))
+            c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
+
+        c = torch.tanh(self.tconv(c))
+
+        c = c.permute(0, 2, 1)
+
+        c, _ = self.gru(c, state)
+
+        return c
+
+def sawtooth(x):
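+    # 2*pi-periodic sawtooth ramping from -1 to 1 (for x >= 0; torch.frac keeps the sign of x)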
+    return 2 * torch.frac(0.5 * x / torch.pi) - 1
+
+class BWENet(torch.nn.Module):
+    FRAME_SIZE16k=80
+
+    def __init__(self,
+                 feature_dim,
+                 cond_dim=128,
+                 kernel_size32=15,
+                 kernel_size48=15,
+                 conv_gain_limits_db=[-12, 12],
+                 activation="AdaShape",
+                 avg_pool_k32 = 8,
+                 avg_pool_k48=12,
+                 interpolate_k32=1,
+                 interpolate_k48=1,
+                 use_noise_shaper=False,
+                 use_extra_nl=False,
+                 disable_bias=False
+                 ):
+
+        super().__init__()
+
+
+        self.feature_dim            = feature_dim
+        self.cond_dim               = cond_dim
+        self.kernel_size32          = kernel_size32
+        self.kernel_size48          = kernel_size48
+        self.conv_gain_limits_db    = conv_gain_limits_db
+        self.activation             = activation
+        self.use_noise_shaper       = use_noise_shaper
+        self.use_extra_nl           = use_extra_nl
+
+        self.frame_size32 = 2 * self.FRAME_SIZE16k
+        self.frame_size48 = 3 * self.FRAME_SIZE16k
+
+        # upsampler
+        self.upsampler = SilkUpsampler()
+
+        # feature net
+        self.feature_net = FloatFeatureNet(feature_dim=feature_dim, num_channels=cond_dim)
+
+        # non-linear transforms
+        if activation == "AdaShape":
+            self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k32, interpolate_k=interpolate_k32, bias=not disable_bias)
+            self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k48, interpolate_k=interpolate_k48, bias=not disable_bias)
+            self.act1 = self.tdshape1
+            self.act2 = self.tdshape2
+        elif activation == "ReLU":
+            self.act1 = lambda x, _: F.relu(x)
+            self.act2 = lambda x, _: F.relu(x)
+        elif activation == "Power":
+            self.extaf1 = LimitedAdaptiveConv1d(1, 1, 5, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[4, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, expansion_power=3)
+            self.extaf2 = LimitedAdaptiveConv1d(1, 1, 5, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[4, 0], gain_limits_db=conv_gain_limits_db, norm_p=2, expansion_power=3)
+            self.act1 = self.extaf1
+            self.act2 = self.extaf2
+        elif activation == "ImPowI":
+            self.act1 = lambda x, _ : x * torch.sin(torch.log((2**15) * torch.abs(x) + 1e-6))
+            self.act2 = lambda x, _ : x * torch.sin(torch.log((2**15) * torch.abs(x) + 1e-6))
+        elif activation == "SawLog":
+            self.act1 = lambda x, _ : x * sawtooth(torch.log((2**15) * torch.abs(x) + 1e-6))
+            self.act2 = lambda x, _ : x * sawtooth(torch.log((2**15) * torch.abs(x) + 1e-6))
+        else:
+            raise ValueError(f"unknown activation {activation}")
+
+        if self.use_noise_shaper:
+            self.nshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k32, interpolate_k=2, noise_substitution=True, cutoff=0.45)
+            self.nshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k48, interpolate_k=2, noise_substitution=True, cutoff=0.6)
+            latent_channels = 3
+        elif use_extra_nl:
+            latent_channels = 3
+            self.extra_nl = lambda x: x * torch.sin(torch.log((2**15) * torch.abs(x) + 1e-6))
+        else:
+            latent_channels = 2
+
+        # spectral shaping
+        self.af1 = LimitedAdaptiveConv1d(1, latent_channels, self.kernel_size32, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size32 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2)
+        self.af2 = LimitedAdaptiveConv1d(latent_channels, 1, self.kernel_size32, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size32 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2)
+        self.af3 = LimitedAdaptiveConv1d(1, latent_channels, self.kernel_size48, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size48 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2)
+        self.af4 = LimitedAdaptiveConv1d(latent_channels, 1, self.kernel_size48, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size48 - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=2)
+
+
+    def flop_count(self, rate=16000, verbose=False):
+
+        frame_rate = rate / self.FRAME_SIZE16k
+
+        # feature net
+        feature_net_flops = self.feature_net.flop_count(frame_rate)
+        # af1/af2 run on the 32 kHz signal, af3/af4 on the 48 kHz signal
+        af_flops = self.af1.flop_count(2 * rate) + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate) + self.af4.flop_count(3 * rate)
+
+        if self.activation == 'AdaShape':
+            shape_flops = self.act1.flop_count(2*rate) + self.act2.flop_count(3*rate)
+        else:
+            shape_flops = 0
+
+        if verbose:
+            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
+            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
+
+        return feature_net_flops + af_flops + shape_flops
+
+    def forward(self, x, features, debug=False):
+
+        cf = self.feature_net(features)
+
+        # first 2x upsampling step
+        y32 = self.upsampler.hq_2x_up(x)
+        if DUMP:
+            dump_as_wav('dump/y32_in.wav', 32000, y32)
+
+        # split
+        y32 = self.af1(y32, cf, debug=debug)
+
+        # activation
+        y32_1 = y32[:, 0:1, :]
+        y32_2 = self.act1(y32[:, 1:2, :], cf)
+        if DUMP:
+            dump_as_wav('dump/y32_1.wav', 32000,  y32_1)
+            dump_as_wav('dump/y32_2pre.wav', 32000,  y32[:, 1:2, :])
+            dump_as_wav('dump/y32_2act.wav', 32000,  y32_2)
+
+        if self.use_noise_shaper:
+            y32_3 = self.nshape1(y32[:, 2:3, :], cf)
+            if DUMP:
+                dump_as_wav('dump/y32_3pre.wav', 32000,  y32[:, 2:3, :])
+                dump_as_wav('dump/y32_3act.wav', 32000,  y32_3)
+            y32 = torch.cat((y32_1, y32_2, y32_3), dim=1)
+        elif self.use_extra_nl:
+            y32_3 = self.extra_nl(y32[:, 2:3, :])
+            if DUMP:
+                dump_as_wav('dump/y32_3pre.wav', 32000,  y32[:, 2:3, :])
+                dump_as_wav('dump/y32_3act.wav', 32000,  y32_3)
+            y32 = torch.cat((y32_1, y32_2, y32_3), dim=1)
+        else:
+            y32 = torch.cat((y32_1, y32_2), dim=1)
+
+        # mix
+        y32 = self.af2(y32, cf, debug=debug)
+
+        # 1.5x interpolation
+        y48 = self.upsampler.interpolate_3_2(y32)
+        if DUMP:
+            dump_as_wav('dump/y48_in.wav', 48000, y48)
+
+        # split
+        y48 = self.af3(y48, cf, debug=debug)
+
+        # activate
+        y48_1 = y48[:, 0:1, :]
+        y48_2 = self.act2(y48[:, 1:2, :], cf)
+        if DUMP:
+            dump_as_wav('dump/y48_1.wav', 48000, y48_1)
+            dump_as_wav('dump/y48_2pre.wav', 48000, y48[:, 1:2, :])
+            dump_as_wav('dump/y48_2act.wav', 48000, y48_2)
+
+        if self.use_noise_shaper:
+            y48_3 = self.nshape2(y48[:, 2:3, :], cf)
+            if DUMP:
+                dump_as_wav('dump/y48_3pre.wav', 48000, y48[:, 2:3, :])
+                dump_as_wav('dump/y48_3act.wav', 48000, y48_3)
+            y48 = torch.cat((y48_1, y48_2, y48_3), dim=1)
+
+        elif self.use_extra_nl:
+            y48_3 = self.extra_nl(y48[:, 2:3, :])
+            if DUMP:
+                dump_as_wav('dump/y48_3pre.wav', 48000, y48[:, 2:3, :])
+                dump_as_wav('dump/y48_3act.wav', 48000, y48_3)
+
+            y48 = torch.cat((y48_1, y48_2, y48_3), dim=1)
+        else:
+            y48 = torch.cat((y48_1, y48_2), dim=1)
+
+        # mix
+        y48 = self.af4(y48, cf, debug=debug)
+
+        if DUMP:
+            dump_as_wav('dump/y48_out.wav', 48000, y48)
+
+        return y48
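+
+# end-to-end sketch (illustrative; assumes the 10 ms feature hop of the BWE
+# feature pipeline, whose frames the feature net upsamples 2x to match the 5 ms
+# adaptive-conv frames):
+#   model = BWENet(feature_dim=114)
+#   x = torch.zeros(1, 1, 8000)    # 0.5 s at 16 kHz
+#   f = torch.zeros(1, 50, 114)    # one feature vector per 10 ms frame
+#   y48 = model(x, f)              # -> (1, 1, 24000), i.e. 0.5 s at 48 kHz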
\ No newline at end of file
--- a/dnn/torch/osce/models/fd_discriminator.py
+++ b/dnn/torch/osce/models/fd_discriminator.py
@@ -132,7 +132,7 @@
                  resolution,
                  fs=16000,
                  freq_roi=[50, 7000],
-                 noise_gain=1e-3,
+                 noise_gain=0,
                  fmap_start_index=0
                  ):
         super().__init__()
@@ -150,10 +150,11 @@
         # filter bank for noise shaping
         n_fft = resolution[0]
 
-        self.filterbank = nn.Parameter(
-            gen_filterbank(n_fft // 2, fs, keep_size=True),
-            requires_grad=False
-        )
+        if self.noise_gain > 0:
+            self.filterbank = nn.Parameter(
+                gen_filterbank(n_fft // 2, fs, keep_size=True),
+                requires_grad=False
+            )
 
         # roi bins
         f_step = fs / n_fft
@@ -207,7 +208,7 @@
                 print("warning: exceeded max size while trying to determine receptive field")
 
         # create transposed convolutional kernel
-        #self.tconv_kernel = nn.Parameter(create_kernel(h0, w0, sw, sw), requires_grad=False)
+        self.tconv_kernel = nn.Parameter(create_kernel(h0, w0, sw, sw), requires_grad=False)
 
     def run_layer_stack(self, spec):
 
@@ -294,9 +295,10 @@
         x = torch.abs(x)
 
         # noise floor following spectral envelope
-        smoothed_x = torch.matmul(self.filterbank, x)
-        noise = torch.randn_like(x) * smoothed_x * self.noise_gain
-        x = x + noise
+        if self.noise_gain > 0:
+            smoothed_x = torch.matmul(self.filterbank, x)
+            noise = torch.randn_like(x) * smoothed_x * self.noise_gain
+            x = x + noise
 
         # frequency ROI
         x = x[:, self.start_bin : self.stop_bin + 1, ...]
@@ -643,7 +645,9 @@
             256: (2, 0),
             512: (3, 0),
             1024: (4, 0),
-            2048: (5, 0)
+            2048: (5, 0),
+            4096: (6, 0),
+            8192: (7, 0),
         },
         'down' : {
             64 : (0, 0),
@@ -651,7 +655,9 @@
             256: (2, 0),
             512: (3, 0),
             1024: (4, 0),
-            2048: (5, 0)
+            2048: (5, 0),
+            4096: (6, 0),
+            8192: (7, 0)
         }
     },
     'ft_down': {
@@ -722,6 +728,7 @@
                  max_channels=256,
                  num_layers=5,
                  use_spectral_norm=False,
+                 k_height=3,
                  design=None):
 
         if design is None:
@@ -729,8 +736,9 @@
 
         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
 
-        stretch = configs[design]['stretch'][resolution[0]]
-        down = configs[design]['down'][resolution[0]]
+        resolution_16k = [(r * 16000) // fs for r in resolution]
+        stretch = configs[design]['stretch'][resolution_16k[0]]
+        down = configs[design]['down'][resolution_16k[0]]
 
         self.num_channels = num_channels
         self.num_channels_max = max_channels
@@ -746,7 +754,7 @@
             layers.append(
                 nn.Sequential(
                     FrequencyPositionalEmbedding(),
-                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
+                    norm_f(nn.Conv2d(in_channels, out_channels, (k_height, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
                     nn.ReLU(inplace=True)
                 )
             )
@@ -758,7 +766,7 @@
         layers.append(
             nn.Sequential(
                 FrequencyPositionalEmbedding(),
-                norm_f(nn.Conv2d(in_channels, 1, (3, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
+                norm_f(nn.Conv2d(in_channels, 1, (k_height, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
                 nn.Sigmoid()
             )
         )
--- /dev/null
+++ b/dnn/torch/osce/models/td_discriminator.py
@@ -1,0 +1,150 @@
+"""
+MIT License
+
+Copyright (c) 2020 Jungil Kong
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+# This is an adaptation of the HiFi-Gan discriminators derived from https://github.com/jik876/hifi-gan
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size*dilation - dilation)/2)
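+# e.g. get_padding(5) == 2: "same" padding for odd kernel sizes at stride 1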
+
+LRELU_SLOPE = 0.1
+
+class DiscriminatorP(torch.nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, max_channels=1024):
+        super(DiscriminatorP, self).__init__()
+        self.max_channels = max_channels
+        self.period = period
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList([
+            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(min(self.max_channels, 128), min(self.max_channels, 512), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(min(self.max_channels, 512), min(self.max_channels, 1024), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(min(self.max_channels, 1024), min(self.max_channels, 1024), (kernel_size, 1), 1, padding=(2, 0))),
+        ])
+        self.conv_post = norm_f(Conv2d(min(self.max_channels, 1024), 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0: # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
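+        # e.g. period=3, t=400: reflect-pad by 2 to t=402, then view as (b, c, 134, 3)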
+
+        output = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            output.append(x)
+        x = self.conv_post(x)
+        output.append(x)
+
+        return output
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+    def __init__(self, max_channels=1024):
+        super(MultiPeriodDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            DiscriminatorP(2, max_channels=max_channels),
+            DiscriminatorP(3, max_channels=max_channels),
+            DiscriminatorP(5, max_channels=max_channels),
+            DiscriminatorP(7, max_channels=max_channels),
+            DiscriminatorP(11, max_channels=max_channels),
+        ])
+
+    def forward(self, y):
+        outputs = []
+        for disc in self.discriminators:
+            outputs.append(disc(y))
+
+        return outputs
+
+
+class DiscriminatorS(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False, max_channels=1024):
+        super(DiscriminatorS, self).__init__()
+        self.max_channels = max_channels
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList([
+            norm_f(Conv1d(1, min(self.max_channels, 128), 15, 1, padding=7)),
+            norm_f(Conv1d(min(self.max_channels, 128), min(self.max_channels, 128), 41, 2, groups=4, padding=20)),
+            norm_f(Conv1d(min(self.max_channels, 128), min(self.max_channels, 256), 41, 2, groups=16, padding=20)),
+            norm_f(Conv1d(min(self.max_channels, 256), min(self.max_channels, 512), 41, 4, groups=16, padding=20)),
+            norm_f(Conv1d(min(self.max_channels, 512), min(self.max_channels, 1024), 41, 4, groups=16, padding=20)),
+            norm_f(Conv1d(min(self.max_channels, 1024), min(self.max_channels, 1024), 41, 1, groups=16, padding=20)),
+            norm_f(Conv1d(min(self.max_channels, 1024), min(self.max_channels, 1024), 5, 1, padding=2)),
+        ])
+        self.conv_post = norm_f(Conv1d(min(self.max_channels, 1024), 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        output = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            output.append(x)
+        x = self.conv_post(x)
+        output.append(x)
+
+        return output
+
+
+class MultiScaleDiscriminator(torch.nn.Module):
+    def __init__(self, max_channels=1024):
+        super(MultiScaleDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            DiscriminatorS(use_spectral_norm=True, max_channels=max_channels),
+            DiscriminatorS(max_channels=max_channels),
+            DiscriminatorS(max_channels=max_channels),
+        ])
+        self.meanpools = nn.ModuleList([
+            AvgPool1d(4, 2, padding=2),
+            AvgPool1d(4, 2, padding=2)
+        ])
+
+    def forward(self, y):
+        outputs = []
+        for i, disc in enumerate(self.discriminators):
+            if i > 0:
+                # as in HiFi-GAN, each successive scale sees a 2x average-pooled signal
+                y = self.meanpools[i - 1](y)
+            outputs.append(disc(y))
+
+        return outputs
+
+
+class TDMultiResolutionDiscriminator(torch.nn.Module):
+    def __init__(self, max_channels=1024, **kwargs):
+        super().__init__()
+        print(f"{max_channels=}")
+        self.msd = MultiScaleDiscriminator(max_channels=max_channels)
+        self.mpd = MultiPeriodDiscriminator(max_channels=max_channels)
+
+    def forward(self, y):
+        return self.msd(y) + self.mpd(y)
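+
+# forward returns 3 + 5 = 8 lists of per-layer feature maps (scales first, then
+# periods), each list ending with the corresponding discriminator's score map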
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/pre_to_adv.py
@@ -1,0 +1,30 @@
+import argparse
+import yaml
+
+from utils.templates import setup_dict
+
+parser = argparse.ArgumentParser()
+parser.add_argument('pre_setup_yaml', type=str, help="yaml setup file for pre training")
+parser.add_argument('adv_setup_yaml', type=str, help="path to derived yaml setup file for adversarial training")
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+
+    with open(args.pre_setup_yaml, "r") as f:
+        setup = yaml.load(f, Loader=yaml.FullLoader)
+
+    key = setup['model']['name'] + '_adv'
+
+    try:
+        adv_setup = setup_dict[key]
+    except KeyError:
+        raise KeyError(f"No setup available for {key}") from None
+
+    setup['training'] = adv_setup['training']
+    setup['discriminator'] = adv_setup['discriminator']
+    setup['data']['frames_per_sample'] = 90
+
+    with open(args.adv_setup_yaml, 'w') as f:
+        yaml.dump(setup, f)
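+
+# usage sketch (illustrative file names):
+#   python pre_to_adv.py bwenet_setup.yml bwenet_setup_adv.yml
+# keeps the model section, swaps in the matching '<name>_adv' training and
+# discriminator templates, and sets data.frames_per_sample to 90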
binary files /dev/null b/dnn/torch/osce/stndrd/evaluation/nbwe_dcr1.tar differ
binary files /dev/null b/dnn/torch/osce/stndrd/evaluation/tests/NBWE_ACR.tar differ
binary files /dev/null b/dnn/torch/osce/stndrd/evaluation/tests/dred_journal.tar differ
binary files /dev/null b/dnn/torch/osce/stndrd/evaluation/tests/lace.tar differ
binary files /dev/null b/dnn/torch/osce/stndrd/evaluation/tests/nolace.tar differ
--- /dev/null
+++ b/dnn/torch/osce/test_bwe_model.py
@@ -1,0 +1,101 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import argparse
+
+import torch
+import numpy as np
+
+from scipy.io import wavfile
+
+
+from models import model_dict
+from utils.bwe_features import load_inference_data
+from utils import endoscopy
+
+debug = False
+if debug:
+    args = type('dummy', (object,),
+    {
+        'input'         : 'testitems/all_0_orig.se',
+        'checkpoint'    : 'testout/checkpoints/checkpoint_epoch_5.pth',
+        'output'        : 'out.wav',
+    })()
+else:
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('input', type=str, help='input signal file, 16 kHz wav or raw int16')
+    parser.add_argument('checkpoint', type=str, help='checkpoint file')
+    parser.add_argument('output', type=str, help='output file')
+    parser.add_argument('--debug', action='store_true', help='enables debug output')
+    parser.add_argument('--upsamp', type=str, default=None, help='optional path to output upsampled signal')
+
+    args = parser.parse_args()
+
+
+torch.set_num_threads(2)
+
+input_path = args.input
+checkpoint_file = args.checkpoint
+
+
+output_file = args.output
+if not output_file.endswith('.wav'):
+    output_file += '.wav'
+
+checkpoint = torch.load(checkpoint_file, map_location="cpu")
+
+# check model
+if 'name' not in checkpoint['setup']['model']:
+    print('warning: did not find model name entry in setup, using pitchpostfilter by default')
+    model_name = 'pitchpostfilter'
+else:
+    model_name = checkpoint['setup']['model']['name']
+
+model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
+
+model.load_state_dict(checkpoint['state_dict'])
+
+# generate model input
+setup = checkpoint['setup']
+signal, features = load_inference_data(input_path, **setup['data'])
+
+if args.debug:
+    endoscopy.init()
+with torch.no_grad():
+    out = model(signal.view(1, 1, -1), features.unsqueeze(0)).squeeze().numpy()
+wavfile.write(output_file, 48000, (2**15 * out).astype(np.int16))
+
+if args.upsamp is not None:
+    with torch.no_grad():
+        upsamp = model.upsampler(signal.view(1, 1, -1)).squeeze().numpy()
+    wavfile.write(args.upsamp, 48000, (2**15 * upsamp).astype(np.int16))
+
+if args.debug:
+    endoscopy.close()
--- /dev/null
+++ b/dnn/torch/osce/train_bwe_model.py
@@ -1,0 +1,303 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+seed=1888
+
+import os
+import argparse
+import sys
+import random
+random.seed(seed)
+
+import yaml
+
+try:
+    import git
+    has_git = True
+except ImportError:
+    has_git = False
+
+import torch
+torch.manual_seed(seed)
+torch.backends.cudnn.benchmark = False
+from torch.optim.lr_scheduler import LambdaLR
+
+import numpy as np
+np.random.seed(seed)
+
+from scipy.io import wavfile
+
+
+from data import SimpleBWESet
+from models import model_dict
+from engine.bwe_engine import train_one_epoch, evaluate
+
+
+from utils.bwe_features import load_inference_data
+from utils.misc import count_parameters, count_nonzero_parameters
+
+from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
+from losses.td_lowpass import TDLowpass
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('setup', type=str, help='setup yaml file')
+parser.add_argument('output', type=str, help='output path')
+parser.add_argument('--device', type=str, help='compute device', default=None)
+parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
+parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
+parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
+
+args = parser.parse_args()
+
+
+
+torch.set_num_threads(4)
+
+with open(args.setup, 'r') as f:
+    setup = yaml.load(f.read(), yaml.FullLoader)
+
+checkpoint_prefix = 'checkpoint'
+output_prefix = 'output'
+setup_name = 'setup.yml'
+output_file='out.txt'
+
+
+# check model
+if 'name' not in setup['model']:
+    print('warning: did not find model name entry in setup, using pitchpostfilter by default')
+    model_name = 'pitchpostfilter'
+else:
+    model_name = setup['model']['name']
+
+# prepare output folder
+if os.path.exists(args.output):
+    print("warning: output folder exists")
+
+    reply = input('continue? (y/n): ')
+    while reply not in {'y', 'n'}:
+        reply = input('continue? (y/n): ')
+
+    if reply == 'n':
+        os._exit(0)
+else:
+    os.makedirs(args.output, exist_ok=True)
+
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+# add repo info to setup
+if has_git:
+    working_dir = os.path.split(__file__)[0]
+    try:
+        repo = git.Repo(working_dir, search_parent_directories=True)
+        setup['repo'] = dict()
+        hash = repo.head.object.hexsha
+        urls = list(repo.remote().urls)
+        is_dirty = repo.is_dirty()
+
+        if is_dirty:
+            print("warning: repo is dirty")
+            with open(os.path.join(args.output, 'repo.diff'), "w") as f:
+                f.write(repo.git.execute(["git", "diff"]))
+
+        setup['repo']['hash'] = hash
+        setup['repo']['urls'] = urls
+        setup['repo']['dirty'] = is_dirty
+    except Exception:
+        has_git = False
+
+# dump setup
+with open(os.path.join(args.output, setup_name), 'w') as f:
+    yaml.dump(setup, f)
+
+if args.testdata is not None:
+
+    testsignal, features = load_inference_data(args.testdata, **setup['data'])
+
+    inference_test = True
+    inference_folder = os.path.join(args.output, 'inference_test')
+    os.makedirs(inference_folder, exist_ok=True)
+else:
+    inference_test = False
+
+# training parameters
+batch_size      = setup['training']['batch_size']
+epochs          = setup['training']['epochs']
+lr              = setup['training']['lr']
+lr_decay_factor = setup['training']['lr_decay_factor']
+preemph_gamma   = setup['training']['preemph']
+
+# load training dataset
+data_config = setup['data']
+data = SimpleBWESet(setup['dataset'], **data_config)
+
+# load validation dataset if given
+if 'validation_dataset' in setup:
+    validation_data = SimpleBWESet(setup['validation_dataset'], **data_config)
+
+    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)
+
+    run_validation = True
+else:
+    run_validation = False
+
+# create model
+model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
+
+if args.initial_checkpoint is not None:
+    print(f"loading state dict from {args.initial_checkpoint}...")
+    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
+    model.load_state_dict(chkpt['state_dict'])
+
+# set compute device
+if args.device is None:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+else:
+    device = torch.device(args.device)
+
+# push model to device
+model.to(device)
+
+# dataloader
+dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)
+
+# build the optimizer over trainable parameters only
+parameters = [p for p in model.parameters() if p.requires_grad]
+optimizer = torch.optim.Adam(parameters, lr=lr)
+
+# learning rate scheduler
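+# effective learning rate after n scheduler steps: lr / (1 + lr_decay_factor * n)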
+scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
+
+# loss
+w_l1 = setup['training']['loss']['w_l1']
+w_lm = setup['training']['loss']['w_lm']
+w_slm = setup['training']['loss']['w_slm']
+w_sc = setup['training']['loss']['w_sc']
+w_logmel = setup['training']['loss']['w_logmel']
+w_wsc = setup['training']['loss']['w_wsc']
+w_xcorr = setup['training']['loss']['w_xcorr']
+w_sxcorr = setup['training']['loss']['w_sxcorr']
+w_l2 = setup['training']['loss']['w_l2']
+w_tdlp = setup['training']['loss'].get('w_tdlp', 0)
+
+w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2 + w_tdlp
+
+
+fft_sizes_16k = [2048, 1024, 512, 256, 128, 64]
+fft_sizes_48k = [3 * n for n in fft_sizes_16k]
+stftloss = MRSTFTLoss(fft_sizes=fft_sizes_48k, fs=48000, sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
+logmelloss = MRLogMelLoss(fft_sizes=fft_sizes_48k).to(device)
+
+def xcorr_loss(y_true, y_pred):
+    dims = list(range(1, len(y_true.shape)))
+
+    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)
+
+    return torch.mean(loss)
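+
+# xcorr_loss ranges over [0, 2]: 0 for perfectly correlated, 1 for orthogonal,
+# 2 for anti-correlated signals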
+
+def td_l2_norm(y_true, y_pred):
+    dims = list(range(1, len(y_true.shape)))
+
+    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
+
+    return loss.mean()
+
+def td_l1(y_true, y_pred, pow=0):
+    dims = list(range(1, len(y_true.shape)))
+    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)
+
+    return torch.mean(tmp)
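+
+# pow=0 gives plain mean absolute error; pow=1 normalizes the error by the
+# predicted signal's mean magnitude (a scale-invariant L1)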
+
+tdlp = TDLowpass(15, 4000/24000).to(device)
+
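+# note: stftloss applies w_sc, w_lm, w_wsc, w_slm and w_sxcorr internally,
+# so only the remaining terms are weighted here; the total is normalized by w_sum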
+def criterion(x, y, x_up):
+
+    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
+            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y) + w_tdlp * tdlp(x_up, y)) / w_sum
+
+
+
+# model checkpoint
+checkpoint = {
+    'setup'         : setup,
+    'state_dict'    : model.state_dict(),
+    'loss'          : -1
+}
+
+
+
+
+if not args.no_redirect:
+    print(f"re-directing output to {os.path.join(args.output, output_file)}")
+    sys.stdout = open(os.path.join(args.output, output_file), "w")
+
+print("summary:")
+
+print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
+if hasattr(model, 'flop_count'):
+    print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
+
+
+best_loss = 1e9
+
+for ep in range(1, epochs + 1):
+    print(f"training epoch {ep}...")
+    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, preemph_gamma)
+
+
+    # save checkpoint
+    checkpoint['state_dict'] = model.state_dict()
+    checkpoint['loss']       = new_loss
+
+    if run_validation:
+        print("running validation...")
+        validation_loss = evaluate(model, criterion, validation_dataloader, device, preemph_gamma)
+        checkpoint['validation_loss'] = validation_loss
+
+        if validation_loss < best_loss:
+            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
+            best_loss = validation_loss
+
+    if inference_test:
+        print("running inference test...")
+        with torch.no_grad():
+            out = model(testsignal.to(device).view(1, 1, -1), features.to(device).unsqueeze(0)).cpu().squeeze().numpy()
+        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 48000, (2**15 * out).astype(np.int16))
+
+
+
+    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
+    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))
+
+
+    print(f"non-zero parameters: {count_nonzero_parameters(model)}\n")
+
+print('Done')
--- a/dnn/torch/osce/utils/ada_conv.py
+++ b/dnn/torch/osce/utils/ada_conv.py
@@ -33,7 +33,7 @@
 
 # x is (batch, nb_in_channels, nb_frames*frame_size)
 # kernels is (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
-def adaconv_kernel(x, kernels, half_window, fft_size=256):
+def adaconv_kernel(x, kernels, half_window, fft_size=256, expansion_power=1):
     device=x.device
     overlap_size=half_window.size(-1)
     nb_frames=kernels.size(3)
@@ -55,6 +55,12 @@
     x_prev = torch.cat([torch.zeros_like(x[:, :, :, :1, :]), x[:, :, :, :-1, :]], dim=-2)
     x_next = torch.cat([x[:, :, :, 1:, :overlap_size], torch.zeros_like(x[:, :, :, -1:, :overlap_size])], dim=-2)
     x_padded = torch.cat([x_prev, x, x_next, torch.zeros(nb_batches, 1, nb_in_channels, nb_frames, fft_size - 2 * frame_size - overlap_size, device=device)], -1)
+    if expansion_power != 1:
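+        # nonlinear expansion: raise the padded signal to expansion_power, then
+        # rescale so every frame keeps its original energy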
+        x_target_energy = torch.sum(x_padded ** 2, dim=-1)
+        x_padded = x_padded ** expansion_power
+        x_new_energy = torch.sum(x_padded ** 2, dim=-1)
+        x_padded = x_padded * torch.sqrt(x_target_energy / (x_new_energy + (1e-6 ** expansion_power))).unsqueeze(-1)
+
     k_padded = torch.cat([torch.flip(kernels, [-1]), torch.zeros(nb_batches, nb_out_channels, nb_in_channels, nb_frames, fft_size-kernel_size, device=device)], dim=-1)
 
     # compute convolution
--- /dev/null
+++ b/dnn/torch/osce/utils/bwe_features.py
@@ -1,0 +1,83 @@
+"""
+/* Copyright (c) 2024 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+
+import numpy as np
+import torch
+
+import scipy
+import scipy.signal
+from scipy.io import wavfile
+
+from utils.spec import log_spectrum, instafreq, create_filter_bank
+
+def bwe_feature_factory(
+    spec_num_bands=32,
+    max_instafreq_bin=40
+):
+    """ features for bwe; we work with a fixed window size of 320 and a hop size of 160 """
+
+    w = scipy.signal.windows.cosine(320)
+    fb = create_filter_bank(spec_num_bands, 320, scale='erb', round_center_bins=True, normalize=True)
+
+    def create_features(x, history=None):
+        if history is None:
+            history = np.zeros(320, dtype=np.float32)
+        lmspec = log_spectrum(np.concatenate((history[-160:], x), dtype=x.dtype), frame_size=320, window=w, fb=fb)
+        freqs = instafreq(np.concatenate((history[-320:], x), dtype=x.dtype), frame_size=320, max_bin=max_instafreq_bin, window=w)
+
+        features = np.concatenate((lmspec, freqs), axis=-1, dtype=np.float32)
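+        # feature layout: spec_num_bands log-spectral bands + 2 * (max_instafreq_bin + 1)
+        # instafreq values per 10 ms frame, e.g. 32 + 82 = 114 for the defaults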
+
+        return features
+
+    return create_features
+
+
+def load_inference_data(path,
+                        spec_num_bands=32,
+                        max_instafreq_bin=40,
+                        **kwargs):
+
+    print(f"[load_inference_data]: ignoring keyword arguments {kwargs.keys()}...")
+
+    if path.endswith(".wav"):
+        signal = wavfile.read(path)[1].astype(np.float32) / (2 ** 15)
+    else:
+        signal = np.fromfile(path, dtype=np.int16).astype(np.float32) / (2 ** 15)
+
+    num_frames = len(signal) // 160
+    signal = signal[:num_frames*160]
+    history = np.zeros(320, dtype=np.float32)
+
+    create_features = bwe_feature_factory(spec_num_bands=spec_num_bands, max_instafreq_bin=max_instafreq_bin)
+
+    features = create_features(signal, history)
+
+    return torch.from_numpy(signal), torch.from_numpy(features)
\ No newline at end of file
--- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
@@ -31,6 +31,8 @@
 from torch import nn
 import torch.nn.functional as F
 
+import math as m
+
 from utils.endoscopy import write_data
 
 from utils.ada_conv import adaconv_kernel
@@ -53,6 +55,7 @@
                  norm_p=2,
                  softquant=False,
                  apply_weight_norm=False,
+                 expansion_power=1,
                  **kwargs):
         """
 
@@ -95,6 +98,7 @@
         self.gain_limits_db = gain_limits_db
         self.shape_gain_db  = shape_gain_db
         self.norm_p         = norm_p
+        self.expansion_power = expansion_power
 
         if name is None:
             self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
@@ -123,7 +127,9 @@
 
         self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
 
+        self.fft_size = 2 ** int(m.ceil(m.log2(2 * frame_size + overlap_size)))
 
+
     def flop_count(self, rate):
         frame_rate = rate / self.frame_size
         overlap = self.overlap_size
@@ -194,7 +200,7 @@
 
         conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4)
 
-        output = adaconv_kernel(x, conv_kernels, win1, fft_size=256)
+        output = adaconv_kernel(x, conv_kernels, win1, fft_size=self.fft_size, expansion_power=self.expansion_power)
 
 
         return output
\ No newline at end of file
--- a/dnn/torch/osce/utils/layers/td_shaper.py
+++ b/dnn/torch/osce/utils/layers/td_shaper.py
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+import scipy.signal
 
 from utils.complexity import _conv1d_flop_count
 from utils.softquant import soft_quant
@@ -11,11 +12,15 @@
     def __init__(self,
                  feature_dim,
                  frame_size=160,
-                 avg_pool_k=4,
                  innovate=False,
+                 avg_pool_k=4,
                  pool_after=False,
                  softquant=False,
-                 apply_weight_norm=False
+                 apply_weight_norm=False,
+                 interpolate_k=1,
+                 noise_substitution=False,
+                 cutoff=None,
+                 bias=True,
     ):
         """
 
@@ -38,34 +43,41 @@
 
         super().__init__()
 
+        if innovate:
+            print("warning: option innovate is no longer supported, setting innovate to True will have no effect")
 
         self.feature_dim    = feature_dim
         self.frame_size     = frame_size
         self.avg_pool_k     = avg_pool_k
-        self.innovate       = innovate
         self.pool_after     = pool_after
+        self.interpolate_k  = interpolate_k
+        self.hidden_dim     = frame_size // interpolate_k
+        self.innovate       = innovate
+        self.noise_substitution = noise_substitution
+        self.cutoff             = cutoff
 
         assert frame_size % avg_pool_k == 0
+        assert frame_size % interpolate_k == 0
         self.env_dim = frame_size // avg_pool_k + 1
 
         norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
 
         # feature transform
-        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
-        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
-        self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))
+        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, self.hidden_dim, 2, bias=bias))
+        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, self.hidden_dim, 2, bias=bias))
+        self.feature_alpha2 = norm(nn.Conv1d(self.hidden_dim, self.hidden_dim, 2, bias=bias))
 
+        self.interpolate_weight = nn.Parameter(torch.ones(1, 1, self.interpolate_k) / self.interpolate_k, requires_grad=False)
+
         if softquant:
             self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)
 
-        if self.innovate:
-            self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
-            self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
+        if self.noise_substitution:
+            self.hp = torch.nn.Parameter(torch.from_numpy(scipy.signal.firwin(15, cutoff, pass_zero=False)).float().view(1, 1, -1), requires_grad=False)
+        else:
+            self.hp = None
 
-            self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
-            self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))
 
-
     def flop_count(self, rate):
 
         frame_rate = rate / self.frame_size
@@ -72,13 +84,8 @@
 
         shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
 
-        if self.innovate:
-            inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
-        else:
-            inno_flops = 0
+        return shape_flops
 
-        return shape_flops + inno_flops
-
     def envelope_transform(self, x):
 
         x = torch.abs(x)
@@ -111,9 +118,7 @@
         """
 
         batch_size = x.size(0)
-        num_frames = features.size(1)
         num_samples = x.size(2)
-        frame_size = self.frame_size
 
         # generate temporal envelope
         tenv = self.envelope_transform(x)
@@ -123,23 +128,24 @@
         t = F.pad(tenv.permute(0, 2, 1), [1, 0])
         alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
         alpha = F.leaky_relu(alpha, 0.2)
-        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
+        alpha = self.feature_alpha2(F.pad(alpha, [1, 0]))
+        # reshape and interpolate to size (batch_size, 1, num_samples)
         alpha = alpha.permute(0, 2, 1)
+        alpha = alpha.reshape(batch_size, 1, num_samples // self.interpolate_k)
+        if self.interpolate_k != 1:
+            alpha = F.interpolate(alpha, self.interpolate_k * alpha.size(-1), mode='nearest')
+            alpha = F.conv1d(F.pad(alpha, [self.interpolate_k - 1, 0], mode='reflect'), self.interpolate_weight) # interpolation in log-domain
+        alpha = torch.exp(alpha)
 
-        if self.innovate:
-            inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2)
-            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
-            inno_alpha = inno_alpha.permute(0, 2, 1)
+        # sample-wise shaping in time domain
+        if self.noise_substitution:
+            if self.hp is not None:
+                x = torch.rand_like(x)
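+                # uniform noise in [0, 1); the DC offset is removed by the highpass below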
+                x = F.pad(x, [7, 7], mode='reflect')
+                x = F.conv1d(x, self.hp)
+            else:
+                x = 2 * torch.rand_like(x) - 1
 
-            inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2)
-            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
-            inno_x = inno_x.permute(0, 2, 1)
+        y = alpha * x
 
-        # signal path
-        y = x.reshape(batch_size, num_frames, -1)
-        y = alpha * y
-
-        if self.innovate:
-            y = y + inno_alpha * inno_x
-
-        return y.reshape(batch_size, 1, num_samples)
+        return y
--- a/dnn/torch/osce/utils/misc.py
+++ b/dnn/torch/osce/utils/misc.py
@@ -30,7 +30,7 @@
 import torch
 from torch.nn.utils import remove_weight_norm
 
-def count_parameters(model, verbose=False):
+def count_parameters(model, verbose=False, trainable=True):
     total = 0
     for name, p in model.named_parameters():
         count = torch.ones_like(p).sum().item()
@@ -38,6 +38,8 @@
         if verbose:
             print(f"{name}: {count} parameters")
 
+        if trainable and not p.requires_grad:
+            continue
         total += count
 
     return total
--- a/dnn/torch/osce/utils/spec.py
+++ b/dnn/torch/osce/utils/spec.py
@@ -207,4 +207,35 @@
 
     cepstrum = scipy.fftpack.dct(X, 2, norm='ortho')
 
-    return cepstrum
\ No newline at end of file
+    return cepstrum
+
+def instafreq(x, frame_size, max_bin, window=None):
+
+    assert (2 * len(x)) % frame_size == 0
+    assert frame_size % 2 == 0
+
+    n = len(x)
+    num_even = n // frame_size
+    num_odd  = (n - frame_size // 2) // frame_size
+    num_bins = frame_size // 2 + 1
+
+    x_even = x[:num_even * frame_size].reshape(-1, frame_size)
+    x_odd  = x[frame_size//2 : frame_size//2 + frame_size *  num_odd].reshape(-1, frame_size)
+
+    x_unfold = np.empty((x_even.size + x_odd.size), dtype=x.dtype).reshape((-1, frame_size))
+    x_unfold[::2, :] = x_even
+    x_unfold[1::2, :] = x_odd
+
+    if window is not None:
+        x_unfold *= window.reshape(1, -1)
+
+    X = np.fft.fft(x_unfold, n=frame_size, axis=-1)
+
+    # instantaneous frequency
+    X_trunc = X[..., :max_bin + 1] + 1e-9
+    Y = X_trunc[1:] * np.conj(X_trunc[:-1])
+    Y = Y / (np.abs(Y) + 1e-9)
+
+    instafreq = np.concatenate((np.real(Y), np.imag(Y)), axis=-1, dtype=x.dtype)
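+    # shape: (num_even + num_odd - 1, 2 * (max_bin + 1)); one frame is consumed by
+    # the phase difference between consecutive half-overlapping windows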
+
+    return instafreq
--- a/dnn/torch/osce/utils/templates.py
+++ b/dnn/torch/osce/utils/templates.py
@@ -89,6 +89,230 @@
 }
 
 
+bwenet_setup = {
+    'dataset': '/local2/bwe0_dataset/training',
+    'validation_dataset': '/local2/bwe0_dataset/validation',
+    'model': {
+        'name': 'bwenet',
+        'args': [],
+        'kwargs': {
+            'cond_dim': 128,
+            'conv_gain_limits_db': [-12, 12],
+            'kernel_size32': 15,
+            'kernel_size48': 15,
+            'feature_dim': 114,
+            'activation' : "AdaShape"
+        }
+    },
+    'data': {
+        'frames_per_sample': 100,
+        'spec_num_bands' : 32,
+        'max_instafreq_bin' : 40,
+        'upsampling_delay48' : 13
+    },
+    'training': {
+        'batch_size': 128,
+        'lr': 5.e-4,
+        'lr_decay_factor': 2.5e-5,
+        'epochs': 50,
+        'loss': {
+            'w_l1': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_wsc': 0,
+            'w_xcorr': 0,
+            'w_sxcorr': 1,
+            'w_l2': 0,
+            'w_slm': 2,
+            'w_tdlp': 1
+        },
+        'preemph': 0.9
+    }
+}
+
+bwenet_setup_adv = {
+    'dataset': '/local2/bwe0_dataset/training',
+    'validation_dataset': '/local2/bwe0_dataset/validation',
+    'model': {
+        'name': 'bwenet',
+        'args': [],
+        'kwargs': {
+            'cond_dim': 128,
+            'conv_gain_limits_db': [-12, 12],
+            'kernel_size32': 15,
+            'kernel_size48': 15,
+            'feature_dim': 114,
+            'activation' : "AdaShape"
+        }
+    },
+    'data': {
+        'frames_per_sample': 60,
+        'spec_num_bands' : 32,
+        'max_instafreq_bin' : 40,
+        'upsampling_delay48' : 13
+    },
+    'discriminator': {
+        'args': [],
+        'kwargs': {
+            'architecture': 'free',
+            'design': 'f_down',
+            'fft_sizes_16k': [
+                64,
+                128,
+                256,
+                512,
+                1024,
+                2048,
+            ],
+            'freq_roi': [0, 22000],
+            'fs': 48000,
+            'k_height': 7,
+            'max_channels': 64,
+            'noise_gain': 0.0
+        },
+        'name': 'fdmresdisc',
+    },
+    'training': {
+        'adv_target': 'target_orig',
+        'batch_size': 64,
+        'epochs': 50,
+        'gen_lr_reduction': 1,
+        'lambda_feat': 1.0,
+        'lambda_reg': 0.6,
+        'loss': {
+            'w_l1': 0,
+            'w_l2': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_slm': 1,
+            'w_sxcorr': 2,
+            'w_wsc': 0,
+            'w_xcorr': 0,
+            'w_tdlp': 10,
+        },
+        'lr': 0.0001,
+        'lr_decay_factor': 2.5e-09,
+        'preemph': 0.85
+    }
+}
+
+bbwenet_setup = {
+    'dataset': '/local2/bwe0_dataset/training',
+    'validation_dataset': '/local2/bwe0_dataset/validation',
+    'model': {
+        'name': 'bbwenet',
+        'args': [],
+        'kwargs': {
+            'cond_dim': 128,
+            'conv_gain_limits_db': [-12, 12],
+            'kernel_size32': 25,
+            'kernel_size48': 15,
+            'feature_dim': 114,
+            'activation' : "ImPowI",
+            'interpolate_k32': 2,
+            'interpolate_k48': 2,
+            'func_extension': False,
+            'shape_extension': True,
+            'shaper': 'TDShaper'
+        }
+    },
+    'data': {
+        'frames_per_sample': 90,
+        'spec_num_bands' : 32,
+        'max_instafreq_bin' : 40,
+        'upsampling_delay48' : 13
+    },
+    'training': {
+        'batch_size': 128,
+        'lr': 5.e-4,
+        'lr_decay_factor': 2.5e-5,
+        'epochs': 50,
+        'loss': {
+            'w_l1': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_wsc': 0,
+            'w_xcorr': 0,
+            'w_sxcorr': 2,
+            'w_l2': 10,
+            'w_slm': 1,
+            'w_tdlp': 1
+        },
+        'preemph': 0.85
+    }
+}
+
+bbwenet_setup_adv = {
+    'dataset': '/local2/bwe0_dataset/training',
+    'validation_dataset': '/local2/bwe0_dataset/validation',
+    'model': {
+        'name': 'bwenet',
+        'args': [],
+        'kwargs': {
+            'cond_dim': 128,
+            'conv_gain_limits_db': [-12, 12],
+            'kernel_size32': 15,
+            'kernel_size48': 15,
+            'feature_dim': 114,
+            'activation' : "TDShaper"
+        }
+    },
+    'data': {
+        'frames_per_sample': 60,
+        'spec_num_bands' : 32,
+        'max_instafreq_bin' : 40,
+        'upsampling_delay48' : 13
+    },
+    'discriminator': {
+        'args': [],
+        'kwargs': {
+            'architecture': 'free',
+            'design': 'f_down',
+            'fft_sizes_16k': [
+                64,
+                128,
+                256,
+                512,
+                1024,
+                2048,
+            ],
+            'freq_roi': [0, 22000],
+            'fs': 48000,
+            'k_height': 7,
+            'max_channels': 64,
+            'noise_gain': 0.0
+        },
+        'name': 'fdmresdisc',
+    },
+    'training': {
+        'adv_target': 'target_orig',
+        'batch_size': 64,
+        'epochs': 50,
+        'gen_lr_reduction': 1,
+        'lambda_feat': 1.0,
+        'lambda_reg': 0.6,
+        'loss': {
+            'w_l1': 0,
+            'w_l2': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_slm': 1,
+            'w_sxcorr': 2,
+            'w_wsc': 0,
+            'w_xcorr': 0,
+            'w_tdlp': 10,
+        },
+        'lr': 0.0001,
+        'lr_decay_factor': 2.5e-09,
+        'preemph': 0.85
+    }
+}
+
+
 nolace_setup = {
     'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
     'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
@@ -343,5 +567,9 @@
     'nolace': nolace_setup,
     'nolace_adv': nolace_setup_adv,
     'lavoce': lavoce_setup,
-    'lavoce_adv': lavoce_setup_adv
+    'lavoce_adv': lavoce_setup_adv,
+    'bwenet' : bwenet_setup,
+    'bwenet_adv': bwenet_setup_adv,
+    'bbwenet': bbwenet_setup,
+    'bbwenet_adv': bbwenet_setup_adv
 }
--