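"""Test-suite runner: resamples audio items, runs them through the model
command chain defined in a setup yaml, and scores the outputs with objective
metrics (WARP-Q, PESQ, pitch error, voicing error)."""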
import os
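# multiprocess (the dill-based fork of multiprocessing) can pickle the worker
# closure defined in __main__ below, which stock multiprocessing cannot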
import multiprocess as multiprocessing
import random
import subprocess
import argparse
import shutil
import sys

import yaml

from utils.files import get_wave_file_list
from utils.warpq import compute_WAPRQ  # name (sic) as exported by utils/warpq.py
from utils.pesq import compute_PESQ
from utils.pitch import compute_pitch_error


parser = argparse.ArgumentParser()
parser.add_argument('setup', type=str, help='setup yaml specifying the end-to-end processing chain for the model under test')
parser.add_argument('input_folder', type=str, help='input folder path')
parser.add_argument('output_folder', type=str, help='output folder path')
parser.add_argument('--num-testitems', type=int, help="number of test items to be processed (default 100)", default=100)
parser.add_argument('--seed', type=int, help='seed for random item selection', default=None)
parser.add_argument('--fs', type=int, help="sampling rate at which the input is presented as a wave file (defaults to 16000)", default=16000)
parser.add_argument('--num-workers', type=int, help="number of subprocesses to be used (default=4)", default=4)
parser.add_argument('--plc-suffix', type=str, default="_is_lost.txt", help="suffix of plc error pattern file: only relevant if command chain uses PLCFILE (default=_is_lost.txt)")
parser.add_argument('--metrics', type=str, default='warpq', help='comma separated list of metrics, supported: {"warpq", "pesq", "pitch_error", "voicing_error"}, default="warpq"')
parser.add_argument('--verbose', action='store_true', help='enables printouts of all commands run in the pipeline')

def check_for_sox_in_path():
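    """Return True if the sox command-line tool is available on the PATH."""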
    r = subprocess.run("sox -h", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return r.returncode == 0


def run_save_sh(command, verbose=False):
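    """Run a shell command and raise a RuntimeError on a non-zero exit code."""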

    if verbose:
        print(f"[run_save_sh] running command {command}...")

    r = subprocess.run(command, shell=True)
    if r.returncode != 0:
        raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}")


def run_processing_chain(input_path, output_path, model_commands, fs, metrics=('warpq',), plc_suffix="_is_lost.txt", verbose=False):
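    """Resample the input to fs, run the model command chain on it, and score
    the output with the requested metrics; returns (output_path, scores)."""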

    # prepare model input
    model_input = output_path + ".resamp.wav"
    run_save_sh(f"sox {input_path} -r {fs} {model_input}", verbose=verbose)

    plcfile = os.path.splitext(input_path)[0] + plc_suffix
    if os.path.isfile(plcfile):
        run_save_sh(f"cp {plcfile} {os.path.dirname(output_path)}")

    # generate model output
    for command in model_commands:
        run_save_sh(command.format(INPUT=model_input, OUTPUT=output_path, PLCFILE=plcfile), verbose=verbose)

    scores = dict()
    cache = dict()
    for metric in metrics:
        if metric == 'warpq':
            # run warpq
            score = compute_WAPRQ(input_path, output_path, sr=fs)
        elif metric == 'pesq':
            # run pesq
            score = compute_PESQ(input_path, output_path, fs=fs)
        elif metric == 'pitch_error':
            if metric in cache:
                score = cache[metric]
            else:
                rval = compute_pitch_error(input_path, output_path, fs=fs)
                score = rval[metric]
                cache['voicing_error'] = rval['voicing_error']
        elif metric == 'voicing_error':
            if metric in cache:
                score = cache[metric]
            else:
                rval = compute_pitch_error(input_path, output_path, fs=fs)
                score = rval[metric]
                cache['pitch_error'] = rval['pitch_error']
        else:
            raise ValueError(f'unknown metric {metric}')

        scores[metric] = score

    return (output_path, scores)


def get_output_path(root_folder, input_path, output_folder):
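    """Mirror the input file's relative path under <output_folder>/processing and return the target wave path."""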

    input_relpath = os.path.relpath(input_path, root_folder)

    os.makedirs(os.path.join(output_folder, 'processing', os.path.dirname(input_relpath)), exist_ok=True)

    output_path = os.path.join(output_folder, 'processing', input_relpath + '.output.wav')

    return output_path


def add_audio_table(f, html_folder, results, title, metric):
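    """Write an HTML table of ranked items with scores and audio players for output and original."""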

    item_folder = os.path.join(html_folder, 'items')
    os.makedirs(item_folder, exist_ok=True)

    # table with results
    f.write(f"""
            <div>
            <h2> {title} </h2>
            <table>
            <tr>
                <th> Rank   </th>
                <th> Name   </th>
                <th> {metric.upper()} </th>
                <th> Audio (out)  </th>
                <th> Audio (orig)  </th>
            </tr>
            """)

    for i, r in enumerate(results):
        item, score = r
        item_name = os.path.basename(item)
        new_item_path = os.path.join(item_folder, item_name)
        shutil.copyfile(item, new_item_path)
        shutil.copyfile(item + '.resamp.wav', os.path.join(item_folder, item_name + '.orig.wav'))

        f.write(f"""
                <tr>
                    <td> {i + 1} </td>
                    <td> {item_name.split('.')[0]} </td>
                    <td> {score:.3f} </td>
                    <td>
                        <audio controls>
                            <source src="items/{item_name}">
                        </audio>
                    </td>
                    <td>
                        <audio controls>
                            <source src="items/{item_name + '.orig.wav'}">
                        </audio>
                    </td>
                </tr>
                """)

    # footer
    f.write("""
            </table>
            </div>
            """)


def create_html(output_folder, results, title, metric):
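    """Write an index.html with Top 20, Median 20 and Flop 20 tables for the given metric."""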

    html_folder = output_folder
    items_folder = os.path.join(html_folder, 'items')
    os.makedirs(html_folder, exist_ok=True)
    os.makedirs(items_folder, exist_ok=True)

    with open(os.path.join(html_folder, 'index.html'), 'w') as f:
        # header and title
        f.write(f"""
                <!DOCTYPE html>
                <html lang="en">
                <head>
                    <meta charset="utf-8">
                    <title>{title}</title>
                    <style>
                        article {{
                            align-items: flex-start;
                            display: flex;
                            flex-wrap: wrap;
                            gap: 4em;
                        }}
                        html {{
                            box-sizing: border-box;
                            font-family: "Amazon Ember", "Source Sans", "Verdana", "Calibri", sans-serif;
                            padding: 2em;
                        }}
                        td {{
                            padding: 3px 7px;
                            text-align: center;
                        }}
                        td:first-child {{
                            text-align: end;
                        }}
                        th {{
                            background: #ff9900;
                            color: #000;
                            font-size: 1.2em;
                            padding: 7px 7px;
                        }}
                    </style>
                </head>
                <body>
                <h1>{title}</h1>
                <article>
                """)

        # top 20
        add_audio_table(f, html_folder, results[:-21: -1], "Top 20", metric)

        # 20 around median
        N = len(results) // 2
        add_audio_table(f, html_folder, results[N + 10 : N - 10: -1], "Median 20", metric)

        # flop 20
        add_audio_table(f, html_folder, results[:20], "Flop 20", metric)

        # footer
        f.write("""
                </article>
                </body>
                </html>
                """)

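# sorting sign per metric: +1 if a higher score is better, -1 if lower is better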
metric_sorting_signs = {
    'warpq'         : -1,
    'pesq'          : 1,
    'pitch_error'   : -1,
    'voicing_error' : -1
}

def is_valid_result(data, metrics):
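    """Check that data is a dict containing a score for every requested metric."""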
    if not isinstance(data, dict):
        return False

    for metric in metrics:
        if metric not in data:
            return False

    return True


def evaluate_results(output_folder, results, metric):
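    """Sort results from worst to best, write score and statistics files, and build the HTML report."""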

    results = sorted(results, key=lambda x : metric_sorting_signs[metric] * x[1])
    with open(os.path.join(output_folder, f'scores_{metric}.txt'), 'w') as f:
        for result in results:
            f.write(f"{os.path.relpath(result[0], output_folder)} {result[1]}\n")


    # some statistics
    n = min(20, len(results))
    mean = sum([r[1] for r in results]) / len(results)
    top_mean = sum([r[1] for r in results[-n:]]) / n
    bottom_mean = sum([r[1] for r in results[:n]]) / n

    with open(os.path.join(output_folder, f'stats_{metric}.txt'), 'w') as f:
        f.write(f"mean score: {mean}\n")
        f.write(f"bottom mean score: {bottom_mean}\n")
        f.write(f"top mean score: {top_mean}\n")

    print(f"\nmean score: {mean}")
    print(f"bottom mean score: {bottom_mean}")
    print(f"top mean score: {top_mean}\n")

    # create output html
    create_html(os.path.join(output_folder, 'html', metric), results, setup['test'], metric)

if __name__ == "__main__":
    args = parser.parse_args()

    # check for sox
    if not check_for_sox_in_path():
        raise RuntimeError("script requires sox")


    # prepare output folder
    if os.path.exists(args.output_folder):
        print("warning: output folder exists")

        reply = input('continue? (y/n): ')
        while reply not in {'y', 'n'}:
            reply = input('continue? (y/n): ')

        if reply == 'n':
            sys.exit(0)
        else:
            # start with a clean slate
            shutil.rmtree(args.output_folder)

    os.makedirs(args.output_folder, exist_ok=True)

    # extract metrics
    metrics = args.metrics.split(",")
    for metric in metrics:
        if metric not in metric_sorting_signs:
            parser.error(f"unknown metric {metric}")

    # read setup
    print(f"loading {args.setup}...")
    with open(args.setup, "r") as f:
        setup = yaml.load(f.read(), yaml.FullLoader)

    model_commands = setup['processing']

    print("\nfound the following model commands:")
    for command in model_commands:
        print(command.format(INPUT='input.wav', OUTPUT='output.wav', PLCFILE='input_is_lost.txt'))

    # store setup to output folder
    setup['input']  = os.path.abspath(args.input_folder)
    setup['output'] = os.path.abspath(args.output_folder)
    setup['seed']   = args.seed
    with open(os.path.join(args.output_folder, 'setup.yml'), 'w') as f:
        yaml.dump(setup, f)

    # get input
    print(f"\nCollecting audio files from {args.input_folder}...")
    file_list = get_wave_file_list(args.input_folder, check_for_features=False)
    print(f"...{len(file_list)} files found\n")

    # sample from file list
    file_list = sorted(file_list)
    random.seed(args.seed)
    random.shuffle(file_list)
    num_testitems = min(args.num_testitems, len(file_list))
    file_list = file_list[:num_testitems]


    print(f"\nlaunching test on {num_testitems} items...")
    # helper function for parallel processing
    def func(input_path):
        output_path = get_output_path(args.input_folder, input_path, args.output_folder)

        try:
            rval = run_processing_chain(input_path, output_path, model_commands, args.fs, metrics=metrics, plc_suffix=args.plc_suffix, verbose=args.verbose)
        except Exception:
            # mark the item as failed; filtered out via is_valid_result below
            rval = (input_path, -1)

        return rval

    with multiprocessing.Pool(args.num_workers) as p:
        results = p.map(func, file_list)

    results_dict = dict()
    for name, values in results:
        if is_valid_result(values, metrics):
            results_dict[name] = values

    if args.verbose:
        print(results_dict)

    # evaluating results
    num_failures = num_testitems - len(results_dict)
    print(f"\nprocessing of {num_failures} items failed\n")

    for metric in metrics:
        print(metric)
        evaluate_results(
            args.output_folder,
            [(name, value[metric]) for name, value in results_dict.items()],
            metric
        )