shithub: opus

Download patch

ref: 931a37caf2a4af703d61f4c2cb52a7811e4939b8
parent: 5ded705cf4ffa13702c78eebecea0fdb2f4ef0de
author: Jan Buethe <jan.buethe@googlemail.com>
date: Tue Jul 22 14:01:48 EDT 2025

first version of highband evaluation script

--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/highband_eval.py
@@ -1,0 +1,91 @@
+import os
+import argparse
+
+import numpy as np
+from scipy.io import wavfile
+from scipy.signal import stft
+
+# Command-line interface: two positional paths, the reference recording
+# first and the recording under test second.
+parser = argparse.ArgumentParser()
+parser.add_argument('ref_wav', help='reference wav file', type=str)
+parser.add_argument('test_wav', help='test wav file', type=str)
+
+# Band edges expressed as bin indices; band i spans edges[i] .. edges[i + 1].
+# (Name suggests the Opus band layout for 5 ms frames -- confirm against codec.)
+opus_eband5ms = [
+    0, 1, 2, 3, 4, 5, 6, 7,
+    8, 10, 12, 14, 16, 20, 24, 28,
+    34, 40, 48, 60, 78, 100,
+]
+# Same layout with every edge doubled.
+opus_eband10ms = [2 * edge for edge in opus_eband5ms]
+
+def create_filter_bank(band_limits, num_bins, normalize=False):
+    """Build a rectangular (0/1) filter bank from a list of band edges.
+
+    Band ``i`` covers the bins ``band_limits[i]`` (inclusive) up to
+    ``band_limits[i + 1]`` (exclusive), following normal slice semantics, so
+    edges beyond ``num_bins`` are clipped.
+
+    Args:
+        band_limits: sequence of integer bin indices, at least two entries.
+        num_bins: number of frequency bins, i.e. the width of each filter.
+        normalize: if True, scale every non-empty filter so its
+            coefficients sum to one. Empty bands are left all-zero instead
+            of being divided by zero.
+
+    Returns:
+        Float array of shape ``(len(band_limits) - 1, num_bins)``.
+
+    Raises:
+        ValueError: if fewer than two band edges are given.
+    """
+    if len(band_limits) < 2:
+        raise ValueError("band_limits must contain at least two edges")
+
+    fb = np.zeros((len(band_limits) - 1, num_bins))
+    for band, (i_start, i_stop) in enumerate(zip(band_limits[:-1], band_limits[1:])):
+        fb[band, i_start:i_stop] = 1
+
+    if normalize:
+        # Number of active bins per band; a zero-width band (repeated edge or
+        # edge past num_bins) would otherwise turn its whole row into NaN.
+        widths = fb.sum(axis=1, keepdims=True)
+        np.divide(fb, widths, out=fb, where=widths > 0)
+
+    return fb
+
+# Filter bank over the doubled band edges, sized for spectra with 241 bins.
+fb = create_filter_bank(opus_eband10ms, num_bins=241)
+
+# band_wise_distortion reports only the bands from this index onwards,
+# i.e. the last four rows of the filter bank output.
+highband_idx = -4
+
+
+def func(x):
+    """Compress band energies by taking their fourth root."""
+    return x ** 0.25
+
+def band_wise_distortion(Xref, Xcut, p=None):
+    """Compare two spectrograms band by band and return the top-band distances.
+
+    Both inputs are turned into power spectra, pooled with the module-level
+    filter bank ``fb``, floored 30 dB below the reference's per-band maximum,
+    compressed with ``func`` and compared by absolute difference.
+
+    Args:
+        Xref: reference spectrogram; its first axis must match the bin count
+            of ``fb`` because it is multiplied as ``fb @ Xref``.
+        Xcut: spectrogram under test, same layout as ``Xref``.
+        p: order passed to ``np.linalg.norm`` to collapse the last axis.
+            With ``None`` the per-frame differences are returned unreduced.
+
+    Returns:
+        The result for the bands selected by ``highband_idx``, scaled by 1000:
+        one value per band when ``p`` is given, one row per band otherwise.
+    """
+    power_ref = np.abs(Xref)**2
+    power_cut = np.abs(Xcut)**2
+
+    bands_ref = fb @ power_ref
+    bands_cut = fb @ power_cut
+
+    # Per-band noise floor, taken from the reference only and applied to both
+    # signals so that very quiet frames do not dominate the comparison.
+    floor = (np.max(bands_ref, axis=1) * (10**(-30/10))).reshape(-1, 1)
+
+    compressed_ref = func(np.maximum(bands_ref, floor))
+    compressed_cut = func(np.maximum(bands_cut, floor))
+    diff = np.abs(compressed_ref - compressed_cut)
+
+    if p is None:
+        return 1000 * diff[highband_idx:]
+
+    # NOTE(review): the differences are divided by the size of the last axis
+    # *before* the norm is taken, so for p != 1 this is not a plain mean --
+    # confirm this scaling is intended.
+    reduced = np.linalg.norm(diff / diff.shape[-1], ord=p, axis=1)
+    return 1000 * reduced[highband_idx:]
+
+
+def _read_mono_wav(path):
+    """Read a single-channel wav file and return ``(sample_rate, samples)``."""
+    rate, samples = wavfile.read(path)
+    if samples.ndim != 1:
+        raise ValueError(f"{path}: expected a mono wav file, got shape {samples.shape}")
+    if np.issubdtype(samples.dtype, np.integer):
+        # Leave integer arithmetic before taking magnitudes: np.abs wraps
+        # around for the most negative value of a signed integer type.
+        samples = samples.astype(np.float64)
+    return rate, samples
+
+
+def main(ref_path, cut_path):
+    """Evaluate the high bands of a test recording against a reference.
+
+    Three distortions are computed with ``band_wise_distortion`` (p=2), each
+    relative to the reference spectrogram:
+
+    * ref-test: the test recording,
+    * ref-lp:   the reference with all bins from 80 upwards set to zero,
+    * ref-flip: the reference with bins from 80 upwards reversed in time.
+
+    The test (and the flipped signal) "pass" when their distortion does not
+    exceed the ref-lp distortion in any reported band. Results are printed.
+
+    Args:
+        ref_path: path of the reference wav file.
+        cut_path: path of the wav file under test.
+
+    Raises:
+        ValueError: if a file is not mono, the sample rates differ, or the
+            reference is silent (it could not be used for normalization).
+    """
+    # Read WAV files
+    fs_ref, x_ref = _read_mono_wav(ref_path)
+    fs_cut, x_cut = _read_mono_wav(cut_path)
+    if fs_ref != fs_cut:
+        raise ValueError(f"sample rate mismatch: {ref_path} has {fs_ref} Hz, {cut_path} has {fs_cut} Hz")
+
+    # Both signals are scaled by the reference peak so that their relative
+    # level is preserved.
+    m = np.max(np.abs(x_ref))
+    if m == 0:
+        raise ValueError(f"{ref_path}: reference signal is silent, cannot normalize")
+    x_ref = x_ref / m
+    x_cut = x_cut / m
+
+    # Compute STFTs
+    _, _, X_ref = stft(x_ref, nperseg=480)
+    _, _, X_cut = stft(x_cut, nperseg=480)
+
+    # Recordings of slightly different length give a different number of
+    # frames, which cannot be subtracted; compare the common prefix only.
+    num_frames = min(X_ref.shape[1], X_cut.shape[1])
+    X_ref = X_ref[:, :num_frames]
+    X_cut = X_cut[:, :num_frames]
+
+    # lowpass reference
+    X_lp = X_ref.copy()
+    X_lp[80:, :] = 0
+
+    # flip comparison
+    X_flip = X_ref.copy()
+    X_flip[80:, :] = X_ref[80:, ::-1]
+
+    # Calculate distortions
+    dist_ref_cut = band_wise_distortion(X_ref, X_cut, 2)
+    dist_ref_lp = band_wise_distortion(X_ref, X_lp, 2)
+    dist_ref_flip = band_wise_distortion(X_ref, X_flip, 2)
+
+    # Print results
+    ref_cut_pass = bool(np.all(dist_ref_cut <= dist_ref_lp))
+    ref_flip_pass = bool(np.all(dist_ref_flip <= dist_ref_lp))
+    print(f"ref-test distortion: {dist_ref_cut} ({'pass' if ref_cut_pass else 'fail'})")
+    print(f"ref-lp   distortion: {dist_ref_lp}")
+    print(f"ref-flip distortion: {dist_ref_flip} ({'pass' if ref_flip_pass else 'fail'})")
+
+    
+if __name__ == "__main__":
+    # Script entry point: both wav paths are taken from the positional
+    # command-line arguments declared on the module-level parser.
+    cli_args = parser.parse_args()
+    main(cli_args.ref_wav, cli_args.test_wav)
\ No newline at end of file
--