# shithub: opus
#
# ref: d9df94947de8a754af8ab35b9ca1808d43a26f9e
# dir: /dnn/torch/osce/stndrd/evaluation/highband_eval.py/
#
# View raw version
import os
import argparse

import numpy as np
from scipy.io import wavfile
from scipy.signal import stft

# Command-line interface: two positional wav paths, the reference recording
# first and the recording under test second.
parser = argparse.ArgumentParser(
    description=(
        "Evaluate the high-band distortion of a test wav file against a "
        "reference wav file."
    ),
)
parser.add_argument('ref_wav', type=str, help='reference wav file')
parser.add_argument('test_wav', type=str, help='test wav file')

# Band edge table. The name suggests the Opus band layout for 5 ms frames
# (NOTE(review): confirm against the Opus sources).
opus_eband5ms = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12,
    14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100,
]
# Same table with every edge doubled; these are the edges handed to
# create_filter_bank, where they are used as frequency-bin indices.
opus_eband10ms = [edge * 2 for edge in opus_eband5ms]

def create_filter_bank(band_limits, num_bins, normalize=False):
    """Build a rectangular filter bank as a (bands, bins) matrix.

    Row k is 1 on the bins band_limits[k] .. band_limits[k + 1] - 1 and 0
    everywhere else, so multiplying the matrix with a spectrum sums the
    bins of each band.

    Args:
        band_limits: sequence of bin indices; consecutive entries delimit
            one band (start inclusive, stop exclusive).
        num_bins: number of bins, i.e. the number of columns of the result.
        normalize: if True, each row is divided by its sum so that the
            filter averages over the band instead of summing.

    Returns:
        Array of shape (len(band_limits) - 1, num_bins).
    """
    rows = []
    for lo, hi in zip(band_limits[:-1], band_limits[1:]):
        row = np.zeros(num_bins)
        row[lo:hi] = 1

        if normalize:
            row = row / row.sum()

        rows.append(row)

    # Stack the 1-D band filters into one 2-D matrix, one band per row.
    return np.vstack(rows)

# Filter bank over the doubled band edges. 241 bins is what the nperseg=480
# STFT in main() produces (480 / 2 + 1 one-sided frequency bins).
fb = create_filter_bank(opus_eband10ms, num_bins=241)

# Slice start on the band axis: only the last four bands are reported.
highband_idx = -4


def func(x):
    """Compress band energies with a fourth root before they are compared."""
    return x ** 0.25

def band_wise_distortion(Xref, Xcut, p=None):
    """Compare two spectrograms band by band and return the high-band part.

    Both inputs are converted to power, pooled into bands with the
    module-level filter bank ``fb``, floored, compressed with ``func`` and
    subtracted.

    Args:
        Xref: reference spectrogram; its first axis must match the number
            of columns of ``fb``.
        Xcut: spectrogram to compare, same shape as Xref.
        p: if given, order of the norm taken along axis 1 of the per-band
            differences; if None the differences are returned unreduced.

    Returns:
        1000 times the result for the last ``-highband_idx`` bands: one
        value per band when p is given, otherwise one row per band.
    """
    # Power spectra pooled into band energies.
    ref_bands = fb @ (np.abs(Xref) ** 2)
    cut_bands = fb @ (np.abs(Xcut) ** 2)

    # Per-band floor 30 dB below the largest reference energy of that band,
    # applied to both signals as a column vector.
    floor = (np.max(ref_bands, axis=1) * (10**(-30/10))).reshape(-1, 1)

    ref_comp = func(np.maximum(ref_bands, floor))
    cut_comp = func(np.maximum(cut_bands, floor))
    delta = np.abs(ref_comp - cut_comp)

    if p is not None:
        # NOTE(review): the division by the axis length happens before the
        # norm, so the p-norm is scaled by 1/N rather than averaged.
        delta = np.linalg.norm(delta / delta.shape[-1], ord=p, axis=1)

    return 1000 * delta[highband_idx:]


def main(ref_path, cut_path):
    """Print the high-band distortion of a test wav file against a reference.

    Three distortions relative to the reference are computed with
    band_wise_distortion (p=2):

    * "test": the signal read from cut_path,
    * "lp":   the reference with STFT bins 80 and above set to zero,
    * "flip": the reference with bins 80 and above reversed along the
      frame axis.

    The test signal (and the flip check) is reported as "pass" when its
    distortion does not exceed the lowpass distortion in any reported band.

    Args:
        ref_path: path of the reference wav file.
        cut_path: path of the wav file under test.

    Raises:
        ValueError: if the reference signal is all zeros, because it cannot
            be used for peak normalisation.
    """
    # Read WAV files; the sample rates returned by wavfile.read are ignored.
    _, x_ref = wavfile.read(ref_path)
    _, x_cut = wavfile.read(cut_path)

    # Promote to float before taking magnitudes: np.abs on int16 PCM wraps
    # around for -32768 and would yield a wrong (negative) peak candidate.
    x_ref = np.asarray(x_ref, dtype=np.float64)
    x_cut = np.asarray(x_cut, dtype=np.float64)

    # Both signals are scaled by the peak of the reference.
    m = np.max(np.abs(x_ref))
    if m == 0:
        raise ValueError(f"reference signal in {ref_path!r} is all zeros")
    x_ref = x_ref / m
    x_cut = x_cut / m

    # Compute STFTs
    _, _, X_ref = stft(x_ref, nperseg=480)
    _, _, X_cut = stft(x_cut, nperseg=480)

    # lowpass reference
    X_lp = X_ref.copy()
    X_lp[80:, :] = 0

    # flip comparison
    X_flip = X_ref.copy()
    X_flip[80:, :] = X_ref[80:, ::-1]

    # Calculate distortions
    dist_ref_cut = band_wise_distortion(X_ref, X_cut, 2)
    dist_ref_lp = band_wise_distortion(X_ref, X_lp, 2)
    dist_ref_flip = band_wise_distortion(X_ref, X_flip, 2)

    # Print results; the lowpassed reference is the baseline to beat.
    ref_cut_pass = bool(np.all(dist_ref_cut <= dist_ref_lp))
    ref_flip_pass = bool(np.all(dist_ref_flip <= dist_ref_lp))
    print(f"ref-test distortion: {dist_ref_cut} ({'pass' if ref_cut_pass else 'fail'})")
    print(f"ref-lp   distortion: {dist_ref_lp}")
    print(f"ref-flip distortion: {dist_ref_flip} ({'pass' if ref_flip_pass else 'fail'})")
    
if __name__ == "__main__":
    # Script entry point: both wav paths are taken from the command line.
    cli_args = parser.parse_args()
    main(ref_path=cli_args.ref_wav, cut_path=cli_args.test_wav)