shithub: opus

ref: 35ee397e060283d30c098ae5e17836316bbec08b
dir: /dnn/torch/lpcnet/test_lpcnet.py/

View raw version
import argparse

import torch
import numpy as np


from models import model_dict
from utils.data import load_features
from utils.wav import wavwrite16

# Flip to True to bypass the command line and run with the fixed settings
# in the debug branch below.
debug = False

if not debug:
    # Normal operation: three positional paths plus an optional feature
    # version that defaults to 2.
    parser = argparse.ArgumentParser()

    for _arg_name, _arg_help in (
        ('features', 'feature file'),
        ('checkpoint', 'checkpoint file'),
        ('output', 'output file'),
    ):
        parser.add_argument(_arg_name, type=str, help=_arg_help)

    parser.add_argument('--version', type=int, help='feature version', default=2)

    args = parser.parse_args()
else:
    # Debug operation: build a throwaway object exposing the same attribute
    # names the parser would have produced.
    args = type(
        'dummy',
        (object,),
        {
            'features': 'features.f32',
            'checkpoint': 'checkpoint.pth',
            'output': 'out.wav',
            'version': 2,
        },
    )()


# Ask PyTorch to use two threads.
torch.set_num_threads(2)

# Copy the parsed settings into module-level names. `version` is stored here
# but is not referenced again in the visible part of this script.
version, feature_file, checkpoint_file = args.version, args.features, args.checkpoint

# The output path always ends in '.wav': the suffix is appended when the
# user-supplied name does not already carry it.
output_file = args.output if args.output.endswith('.wav') else args.output + '.wav'

# Load the checkpoint with all tensors mapped to the CPU.
# NOTE(review): torch.load unpickles the file; unless weights-only loading is
# in effect this can execute arbitrary code, so only open trusted checkpoints.
checkpoint = torch.load(checkpoint_file, map_location="cpu")

# check model: the setup section may name the model class to instantiate;
# when the entry is absent, fall back to the default 'lpcnet' and say so.
lpcnet_setup = checkpoint['setup']['lpcnet']
if 'model' not in lpcnet_setup:
    print('warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = lpcnet_setup['model']

# Build the model from the configuration stored in the checkpoint, then
# restore its trained parameters.
model = model_dict[model_name](lpcnet_setup['config'])
model.load_state_dict(checkpoint['state_dict'])

# NOTE(review): the parsed --version value is not forwarded to load_features;
# confirm whether load_features is meant to receive it.
data = load_features(feature_file)

# Synthesize from the loaded features, periods and LPC data.
output = model.generate(data['features'], data['periods'], data['lpcs'])

# Write the result to disk; 16000 is presumably the sample rate in Hz
# (TODO confirm against wavwrite16).
wavwrite16(output_file, output.numpy(), 16000)