ref: b9912f7eb3455926987434db5e3dc97e1a24c1e9
dir: /dnn/torch/osce/bwe_preproc.py/
"""Build concatenated training signals for bandwidth extension (BWE).

The script reads a list of WAVE files, keeps the full-band ones, applies
optional random augmentations (equalisation, low-band noise, room impulse
responses, DC offset, level normalisation) and appends every item to two raw
16-bit PCM files in the output directory: ``signal_48kHz.s16`` and a 3:1
decimated ``signal_16kHz.s16``.
"""
import argparse
import math as m
import os
from typing import List

import numpy as np
import resampy
from scipy import signal
from scipy.io import wavfile


parser = argparse.ArgumentParser()
parser.add_argument("filelist", type=str, help="file with filenames for concatenation in WAVE format")
parser.add_argument("target_fs", type=int, help="target sampling rate of concatenated file")
parser.add_argument("output", type=str, help="output directory")
parser.add_argument("--basedir", type=str, help="basedir for filenames in filelist, defaults to ./", default="./")
parser.add_argument("--normalize", action="store_true", help="apply normalization")
parser.add_argument("--db_max", type=float, help="max DB for random normalization", default=0)
parser.add_argument("--db_min", type=float, help="min DB for random normalization", default=0)
parser.add_argument("--random_eq_prob", type=float, help="portion of items to which random eq will be applied (default: 0.4)", default=0.4)
parser.add_argument("--static_noise_prob", type=float, help="portion of items to which static noise will be added (default: 0.2)", default=0.2)
parser.add_argument("--random_dc_prob", type=float, help="portion of items to which random dc offset will be added (default: 0.1)", default=0.1)
parser.add_argument("--rirdir", type=str, default=None, help="folder with room impulse responses in wav format (defaul: None)")
parser.add_argument("--rir_prob", type=float, default=0.0, help="portion of items to which a random rir is applied (default: 0)")
parser.add_argument("--verbose", action="store_true")


def read_filelist(basedir, filelist):
    """Return the non-empty lines of *filelist*, each joined onto *basedir*.

    Only the trailing newline is stripped from each line; lines that are
    empty after stripping are ignored.
    """
    with open(filelist, "r") as f:
        names = (line.rstrip('\n') for line in f)
        return [os.path.join(basedir, name) for name in names if len(name) > 0]


def read_wave(file, target_fs):
    """Read a WAVE file and return its samples as float32 at *target_fs*.

    Files sampled above *target_fs* are resampled with resampy. Files sampled
    below *target_fs* are rejected rather than up-sampled.

    Returns:
        The samples as a float32 array, or ``None`` when the file's sampling
        rate is lower than *target_fs*.
    """
    fs, x = wavfile.read(file)

    if fs < target_fs:
        return None

    if fs != target_fs:
        x = resampy.resample(x, fs, target_fs)

    return x.astype(np.float32)


def load_rirs(rirdir, target_fs):
    """Read room impulse responses (``*.wav``) from *rirdir* and its subfolders.

    Every impulse response is scaled to a peak magnitude of 1. Files that
    cannot be used are skipped: sampling rate below *target_fs* (``read_wave``
    returns ``None``), no samples, or all-zero content (peak normalisation
    would divide by zero).

    NOTE(review): the later use with ``np.convolve`` presumably expects mono
    impulse responses -- confirm against the RIR data set.
    """
    rirs = []
    for dirpath, _dirnames, filenames in os.walk(rirdir):
        for file in filenames:
            if not file.endswith(".wav"):
                continue
            x = read_wave(os.path.join(dirpath, file), target_fs)
            if x is None or x.size == 0:
                continue
            peak = np.max(np.abs(x))
            if peak == 0:
                continue
            rirs.append(x / peak)
    return rirs


# 151-tap FIR low-pass with 20 kHz cutoff, designed for 48 kHz signals
lp_coeffs = signal.firwin(151, 20000, fs=48000)


def apply_20kHz_lp(x, fs):
    """Low-pass *x* at 20 kHz when *fs* is 48000; otherwise return *x* as is.

    The convolution runs in 'valid' mode, so the result is
    ``len(lp_coeffs) - 1`` samples shorter than *x*. The output is rescaled
    to the peak magnitude of the input.
    """
    if fs != 48000:
        return x

    y = np.convolve(x, lp_coeffs, mode='valid')
    y *= np.max(np.abs(x)) / np.max(np.abs(y) + 1e-6)
    return y


def random_normalize(x, db_min, db_max, max_val=2**15 - 1):
    """Scale *x* so its peak sits at a random level relative to *max_val*.

    The level is drawn uniformly (in dB) from ``[db_min, db_max)``. A silent
    signal (peak 0) is returned unscaled, because the gain would be 0/0.
    """
    # draw first so that the random stream is consumed for every item
    db = np.random.uniform(db_min, db_max, 1)
    peak = np.abs(x).max()
    if peak == 0:
        # same dtype as the regular path, where the float64 gain promotes x
        return x.astype(np.float64)
    gain = 10**(db/20) * max_val / peak
    return gain * x


def random_resamp16(x, fs=48000):
    """Decimate a 48 kHz signal by 3 after a randomly drawn FIR low-pass.

    The cutoff is drawn from [7.2, 8) kHz and the (odd) filter length from
    [151, 299] taps. The output contains every third sample of the filtered
    signal ('same' mode convolution).
    """
    assert fs == 48000, "only supporting 48kHz input sampling rate for now"

    cutoff = 800 * np.random.rand(1) + 7200  # cutoff between 7.2 and 8 kHz
    numtaps = 2 * np.random.randint(75, 150) + 1
    a = signal.firwin(numtaps, cutoff, fs=fs)
    x16 = np.convolve(x, a, mode='same')[::3]
    return x16


def estimate_bandwidth(x, fs):
    """Classify *x* as 'wb', 'swb' or 'fb' from its STFT band energies.

    Only frames whose energy exceeds 10% of the 90th-percentile frame energy
    are considered. The mean energy per bin is computed below 8 kHz, between
    8 and 16 kHz and above 16 kHz; five bins are skipped after each band edge
    as a safety margin.

    Returns:
        'wb' if the 8-16 kHz energy is below 1e-5 of the low-band energy,
        'swb' if the >16 kHz energy is below 1e-7 of it, 'fb' otherwise.
    """
    assert fs >= 44100, "currently only fs >= 44100 supported"

    f, t, X = signal.stft(x, nperseg=960, fs=fs)
    X = X.transpose()  # -> (frames, bins)
    X_pow = np.abs(X) ** 2
    X_nrg = np.sum(X_pow, axis=1)
    threshold = np.sort(X_nrg)[int(0.9 * len(X_nrg))] * 0.1
    X_pow = X_pow[X_nrg > threshold]

    # band below 8 kHz
    i = 0
    wb_nrg = 0
    wb_bands = 0
    while f[i] < 8000:
        wb_nrg += np.sum(X_pow[:, i])
        wb_bands += 1
        i += 1
    wb_nrg /= wb_bands

    i += 5  # safety margin

    # band up to 16 kHz
    swb_nrg = 0
    swb_bands = 0
    while f[i] < 16000:
        swb_nrg += np.sum(X_pow[:, i])
        swb_bands += 1
        i += 1
    swb_nrg /= swb_bands

    i += 5  # safety margin

    # all remaining bins
    fb_nrg = 0
    fb_bands = 0
    while i < X_pow.shape[1]:
        fb_nrg += np.sum(X_pow[:, i])
        fb_bands += 1
        i += 1
    fb_nrg /= fb_bands

    if swb_nrg / wb_nrg < 1e-5:
        return 'wb'
    elif fb_nrg / wb_nrg < 1e-7:
        return 'swb'
    else:
        return 'fb'


def _get_random_eq_filter(num_taps=51, min_gain=1/3, max_gain=3, cutoff=8000, fs=48000, num_bands=15):
    """Design an FIR filter with random gains below *cutoff* and unit gain above.

    Gains are drawn log-uniformly between *min_gain* and *max_gain* on
    *num_bands* points spaced uniformly from 0 to Nyquist. The log-gains of
    the points below the cutoff are shifted to zero mean; the remaining
    points are set to gain 1.
    """
    nyquist = fs / 2
    freqs = (np.arange(num_bands)) / (num_bands - 1)
    cutoff = cutoff / nyquist
    log_min_gain = m.log(min_gain)
    log_max_gain = m.log(max_gain)
    split = int(cutoff * (num_bands - 1)) + 1

    log_gains = np.random.rand(num_bands) * (log_max_gain - log_min_gain) + log_min_gain
    low_band_mean = np.mean(log_gains[:split])
    log_gains[:split] -= low_band_mean
    log_gains[split:] = 0
    gains = np.exp(log_gains)

    taps = signal.firwin2(num_taps, freqs, gains, nfreqs=127)
    return taps


def trim_silence(x, fs, threshold=0.005):
    """Cut leading and trailing silence from *x*.

    The signal is split into 20 ms frames. A frame counts as active when its
    energy exceeds *threshold* times the 90th-percentile frame energy. The
    result spans from 20 frames before the first active frame to 20 frames
    after the last active one, clamped to the whole frames of the signal.

    Returns:
        The trimmed slice of *x*. *x* is returned unchanged when it is
        shorter than one frame, and an empty slice is returned when no frame
        is active.
    """
    hangover = 20  # frames kept on either side of the active region
    frame_size = 320 * fs // 16000
    num_frames = len(x) // frame_size
    if num_frames == 0:
        return x

    y = x[: frame_size * num_frames]
    frame_nrg = np.sum(y.reshape(-1, frame_size) ** 2, axis=1)
    ref_nrg = np.sort(frame_nrg)[int(num_frames * 0.9)]
    silence_threshold = threshold * ref_nrg

    active = np.flatnonzero(frame_nrg > silence_threshold)
    if active.size == 0:
        return x[:0]

    first_active_frame_index = active[0]
    last_active_frame_index = active[-1]

    i_start = max(first_active_frame_index - hangover, 0) * frame_size
    # the stop index is exclusive: +1 keeps the last active frame itself
    i_stop = min(last_active_frame_index + 1 + hangover, num_frames) * frame_size

    return x[i_start:i_stop]


def random_eq(x, fs, cutoff):
    """Filter *x* with a random equaliser acting below *cutoff*.

    The convolution runs in 'full' mode, so the result is longer than *x* by
    the filter length minus one. The output is rescaled to the peak of *x*.
    """
    taps = _get_random_eq_filter(fs=fs, cutoff=cutoff)
    y = np.convolve(taps, x.astype(np.float32))

    # rescale
    y *= np.max(np.abs(x)) / np.max(np.abs(y + 1e-9))
    return y


def static_lowband_noise(x, fs, cutoff, max_gain=0.02):
    """Add low-pass filtered, randomly equalised Gaussian noise to *x*.

    The noise peak is a uniformly drawn fraction (up to *max_gain*) of the
    peak of *x*. The sum is rescaled to the original peak of *x*.
    """
    k_lp = (5 * fs // 16000)
    lp_taps = signal.firwin(2 * k_lp + 1, 2 * cutoff / fs)
    eq_taps = _get_random_eq_filter(num_bands=9)

    # sized so that the two 'valid' convolutions leave exactly len(x) samples
    noise = np.random.randn(len(x) + len(lp_taps) + len(eq_taps) - 2)
    noise = np.convolve(noise, lp_taps, mode='valid')
    noise = np.convolve(noise, eq_taps, mode='valid')

    gain = np.random.rand(1) * max_gain
    x_max = np.max(np.abs(x))
    noise *= gain * x_max / np.max(np.abs(noise))

    y = x + noise
    y *= x_max / np.max(np.abs(y + 1e-9))
    return y


def apply_random_rir(x, rirs, rescale=True):
    """Convolve *x* ('same' mode) with a randomly chosen entry of *rirs*.

    With *rescale* set, the output is scaled back to the peak of *x*.
    *rirs* must not be empty.
    """
    i = np.random.randint(0, len(rirs))
    y = np.convolve(x, rirs[i], mode='same')

    if rescale:
        y *= np.max(np.abs(x)) / np.max(np.abs(y) + 1e-6)
    return y


def random_dc_offset(x, max_rel_offset=0.03):
    """Add a random DC offset of at most *max_rel_offset* times the peak of *x*.

    The output is rescaled to the original peak of *x*.
    """
    x_max = np.max(np.abs(x))
    offset = x_max * (2 * np.random.rand(1) - 1) * max_rel_offset
    y = x + offset
    y *= x_max / np.max(np.abs(y + 1e-9))
    return y


def concatenate(filelist : List[str],
                outdir : str,
                target_fs : int,
                normalize : bool=True,
                db_min : float=0,
                db_max : float=0,
                rand_eq_prob : float=0,
                static_noise_prob: float=0,
                rand_dc_prob : float=0,
                rirs : List = None,
                rir_prob : float = 0,
                verbose=False):
    """Augment the files in *filelist* and append them to the output streams.

    For every file the signal is read, silence-trimmed and low-passed. Items
    that are not classified as full-band or that are too short are skipped.
    The remaining items are randomly augmented, optionally normalised, faded
    in and out, zero-padded and written as 16-bit PCM to
    ``signal_48kHz.s16`` and, decimated by 3, to ``signal_16kHz.s16`` inside
    *outdir*.

    Args:
        filelist: paths of the input WAVE files.
        outdir: output directory (created if missing).
        target_fs: sampling rate of the full-rate stream; the decimation
            stage only supports 48000.
        normalize: apply random peak normalisation between *db_min* and
            *db_max* dB.
        rand_eq_prob, static_noise_prob, rand_dc_prob, rir_prob: per-item
            probabilities of the respective augmentation.
        rirs: room impulse responses; required when *rir_prob* > 0.
        verbose: print per-file progress.

    Raises:
        ValueError: if *rir_prob* > 0 but *rirs* is missing or empty.
    """
    # fail before any file is touched instead of at a random item mid-run
    if rir_prob > 0 and (rirs is None or len(rirs) == 0):
        raise ValueError("rir_prob > 0 requires a non-empty list of room impulse responses")

    overlap_size = int(40 * target_fs / 8000)
    overlap_mem = np.zeros(overlap_size, dtype=np.float32)
    overlap_win1 = (0.5 + 0.5 * np.cos(np.arange(0, overlap_size) * np.pi / overlap_size)).astype(np.float32)
    overlap_win2 = np.flipud(overlap_win1)

    # same for 16 kHz
    assert overlap_size % 3 == 0
    overlap_size16 = overlap_size // 3
    overlap_mem16 = np.zeros(overlap_size16, dtype=np.float32)
    overlap_win1_16 = overlap_win1[::3]
    overlap_win2_16 = np.flipud(overlap_win1_16)

    # items shorter than this are skipped
    min_length = 10 * overlap_size

    # Legacy cross-fade concatenation of consecutive items. It is disabled:
    # items are instead faded in/out and separated by zero padding.
    crossfade_items = False

    output48 = os.path.join(outdir, 'signal_48kHz.s16')
    output16 = os.path.join(outdir, 'signal_16kHz.s16')
    os.makedirs(outdir, exist_ok=True)

    with open(output48, 'wb') as f48, open(output16, 'wb') as f16:
        for file in filelist:
            x = read_wave(file, target_fs)
            if x is None:
                if verbose:
                    print(f"sampling rate below {target_fs} Hz: skipping {file}...")
                continue

            x = trim_silence(x, target_fs)

            # Reject short items before filtering and analysis: those stages
            # fail or misbehave on inputs shorter than their filters, and
            # such items would be rejected by the length check below anyway.
            if len(x) < min_length:
                if verbose:
                    print(f"skipping {file}...")
                continue

            x = apply_20kHz_lp(x, target_fs)

            bwidth = estimate_bandwidth(x, target_fs)
            if bwidth != 'fb':
                if verbose:
                    print(f"bandwidth {bwidth} detected: skipping {file}...")
                continue

            if len(x) < min_length:
                if verbose:
                    print(f"skipping {file}...")
                continue
            elif verbose:
                print(f"processing {file}...")

            # randomise whether the RIR is applied before or after the noise
            noise_first = np.random.randint(2)

            if np.random.rand(1) < rand_eq_prob:
                x = random_eq(x, target_fs, 5000)

            if not noise_first:
                if np.random.rand(1) < rir_prob:
                    x = apply_random_rir(x, rirs)

            if np.random.rand(1) < static_noise_prob:
                x = static_lowband_noise(x, target_fs, 8000, max_gain=0.01)

            if noise_first:
                if np.random.rand(1) < rir_prob:
                    x = apply_random_rir(x, rirs)

            if np.random.rand(1) < rand_dc_prob:
                x = random_dc_offset(x)

            # trim final signal to length divisible by 3 to keep 16 and 48 kHz signals in sync
            x = x[:len(x) - (len(x) % 3)]

            if normalize:
                x = random_normalize(x, db_min, db_max)

            # write 48 and 16 kHz signals to disk
            if crossfade_items:
                x1 = x[:-overlap_size]
                x1[:overlap_size] = overlap_win1 * overlap_mem + overlap_win2 * x1[:overlap_size]
                f48.write(x1.astype(np.int16).tobytes())

                x16 = random_resamp16(x, fs=target_fs)
                x1_16 = x16[:-overlap_size16]
                x1_16[:overlap_size16] = overlap_win1_16 * overlap_mem16 + overlap_win2_16 * x1_16[:overlap_size16]
                f16.write(x1_16.astype(np.int16).tobytes())

                # memory update
                overlap_mem = x[-overlap_size:]
                overlap_mem16 = x16[-overlap_size16:]
            else:
                # window and zero pad signal
                padding_samples = 3 * 100
                x[:overlap_size] *= overlap_win2  # fade in
                x[-overlap_size:] *= overlap_win1  # fade out
                x = np.concatenate((np.zeros(padding_samples), x, np.zeros(padding_samples)), dtype=x.dtype)

                # pass the actual rate so the 48 kHz requirement is enforced
                x16 = random_resamp16(x, fs=target_fs)
                assert 3*len(x16) == len(x)

                if np.max(x) > 2**15 - 1 or np.min(x) < -2**15:
                    print("clipping")
                if np.max(x16) > 2**15 - 1 or np.min(x16) < -2**15:
                    print("clipping")

                x = np.clip(x, -2**15, 2**15 - 1)
                x16 = np.clip(x16, -2**15, 2**15 - 1)

                f48.write(x.astype(np.int16).tobytes())
                f16.write(x16.astype(np.int16).tobytes())


if __name__ == "__main__":
    args = parser.parse_args()

    filelist = read_filelist(args.basedir, args.filelist)

    if args.rirdir is not None:
        rirs = load_rirs(args.rirdir, args.target_fs)
    else:
        rirs = []

    concatenate(filelist,
                args.output,
                args.target_fs,
                normalize=args.normalize,
                db_min=args.db_min,
                db_max=args.db_max,
                rand_eq_prob=args.random_eq_prob,
                static_noise_prob=args.static_noise_prob,
                rand_dc_prob=args.random_dc_prob,
                rirs=rirs,
                rir_prob=args.rir_prob,
                verbose=args.verbose)