shithub: opus

ref: 35ee397e060283d30c098ae5e17836316bbec08b
dir: /dnn/torch/lpcnet/train_lpcnet.py/

View raw version
import os
import argparse
import sys

# GitPython is optional: it is only used further down to record repository
# metadata in the dumped setup. Only ImportError is caught so that unrelated
# problems (and KeyboardInterrupt / SystemExit) are not silently swallowed.
try:
    import git
    has_git = True
except ImportError:
    has_git = False

import yaml


import torch
from torch.optim.lr_scheduler import LambdaLR

from data import LPCNetDataset
from models import model_dict
from engine.lpcnet_engine import train_one_epoch, evaluate
from utils.data import load_features
from utils.wav import wavwrite16


# Set to True to bypass the command line and run with the hard-coded
# arguments below (e.g. from an interactive session or a debugger).
debug = False
if debug:
    # Stand-in for the argparse namespace. The attribute names must match the
    # argparse destinations, i.e. dashes in option names become underscores
    # (the script reads args.no_redirect, args.test_features, ...).
    args = type('dummy', (object,),
    {
        'setup' : 'setup.yml',
        'output' : 'testout',
        'device' : None,
        'test_features' : None,
        'finalize': False,
        'initial_checkpoint': None,
        'no_redirect': False
    })()
else:
    parser = argparse.ArgumentParser("train_lpcnet.py")
    parser.add_argument('setup', type=str, help='setup yaml file')
    parser.add_argument('output', type=str, help='output path')
    parser.add_argument('--device', type=str, help='compute device', default=None)
    parser.add_argument('--test-features', type=str, help='test feature file in v2 format', default=None)
    parser.add_argument('--finalize', action='store_true', help='run single training round with lr=1e-5')
    parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
    parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of output')

    args = parser.parse_args()


# limit torch to four CPU threads
torch.set_num_threads(4)

# read the experiment description from the yaml file given on the command line
# NOTE(review): yaml.FullLoader is used, so the setup file should be trusted
with open(args.setup, 'r') as setup_file:
    setup = yaml.load(setup_file.read(), yaml.FullLoader)

if not args.finalize:
    # regular training run: plain artifact names
    checkpoint_prefix, output_prefix = 'checkpoint', 'output'
    setup_name, output_file = 'setup.yml', 'out.txt'
else:
    # finalization continues from an existing model, so a checkpoint is mandatory
    if args.initial_checkpoint is None:
        raise ValueError('finalization requires initial checkpoint')

    # reset the start/stop entries of every sparsification job to 0
    finalize_config = setup['lpcnet']['config']
    if 'sparsification' in finalize_config:
        for sparsify_job in finalize_config['sparsification'].values():
            sparsify_job['start'] = 0
            sparsify_job['stop'] = 0

    # a single epoch at a fixed learning rate of 1e-5 (decay factor 0)
    setup['training'].update(lr=1.0e-5, lr_decay_factor=0.0, epochs=1)

    # separate artifact names so finalization does not overwrite the
    # results of the regular run
    checkpoint_prefix, output_prefix = 'checkpoint_finalize', 'output_finalize'
    setup_name, output_file = 'setup_finalize.yml', 'out_finalize.txt'


# select the model: use the 'model' entry of the lpcnet section if present,
# otherwise warn and fall back to the default 'lpcnet'
lpcnet_section = setup['lpcnet']
if 'model' in lpcnet_section:
    model_name = lpcnet_section['model']
else:
    print('warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'

# prepare output folder
# An existing folder is only questioned for regular runs; debug and finalize
# runs are expected to reuse it.
if os.path.exists(args.output) and not debug and not args.finalize:
    print("warning: output folder exists")

    # keep asking until the answer is exactly 'y' or 'n'
    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # os._exit() requires a status argument and would raise TypeError
        # when called without one; leave through the regular exit path.
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

# checkpoints are collected in a sub-folder of the output folder
checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)


# add repo info to setup
# Best effort: commit hash, remote urls and dirty flag of the repository that
# contains this script are stored under setup['repo']. Any failure only
# disables the feature; it never aborts training.
if has_git:
    working_dir = os.path.dirname(os.path.abspath(__file__))
    try:
        # the script lives in a sub-folder of the repository, so the .git
        # folder has to be searched for in the parent directories
        repo = git.Repo(working_dir, search_parent_directories=True)
        repo_info = {
            'hash': repo.head.object.hexsha,
            'urls': list(repo.remote().urls),
            'dirty': repo.is_dirty(),
        }
    except Exception as err:
        print(f"warning: could not collect repo info: {err}")
        has_git = False
    else:
        if repo_info['dirty']:
            print("warning: repo is dirty")

        # assigned only after every query succeeded, so the setup never
        # carries a partially filled 'repo' entry
        setup['repo'] = repo_info

# write the (possibly modified) setup into the output folder
setup_dump_path = os.path.join(args.output, setup_name)
with open(setup_dump_path, 'w') as setup_out:
    yaml.dump(setup, setup_out)

# prepare inference test if wanted
# The test is enabled by passing --test-features; its features are loaded once
# here and the generated audio goes to <output>/inference_test.
run_inference_test = args.test_features is not None
if run_inference_test:
    test_features = load_features(args.test_features)
    inference_test_dir = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_test_dir, exist_ok=True)

# pull the training hyper-parameters out of the setup
training_setup = setup['training']
batch_size = training_setup['batch_size']
epochs = training_setup['epochs']
lr = training_setup['lr']
lr_decay_factor = training_setup['lr_decay_factor']

# model configuration shared by the training and validation datasets
lpcnet_config = setup['lpcnet']['config']

# training dataset
training_path = setup['dataset']
dataset_kwargs = dict(
    features=lpcnet_config['features'],
    input_signals=lpcnet_config['signals'],
    target=lpcnet_config['target'],
    frames_per_sample=setup['training']['frames_per_sample'],
    feature_history=lpcnet_config['feature_history'],
    feature_lookahead=lpcnet_config['feature_lookahead'],
    # lpc_gamma is optional in the setup and defaults to 1
    lpc_gamma=lpcnet_config.get('lpc_gamma', 1),
)
data = LPCNetDataset(training_path, **dataset_kwargs)

# validation dataset (optional): built with the same arguments as the
# training dataset and wrapped in a non-shuffling dataloader
run_validation = 'validation_dataset' in setup
if run_validation:
    validation_data = LPCNetDataset(setup['validation_dataset'], **dataset_kwargs)

    validation_dataloader = torch.utils.data.DataLoader(
        validation_data,
        batch_size=batch_size,
        drop_last=True,
        num_workers=4,
    )

# look up the model class by name and instantiate it from the model config
model_class = model_dict[model_name]
model = model_class(setup['lpcnet']['config'])

# optionally initialize the weights from a checkpoint; tensors are mapped to
# the CPU on load
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    initial_state = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(initial_state['state_dict'])

# set compute device
# Without --device, CUDA is used when available and the CPU otherwise.
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# shuffled training dataloader; incomplete final batches are dropped
dataloader = torch.utils.data.DataLoader(
    data,
    batch_size=batch_size,
    drop_last=True,
    shuffle=True,
    num_workers=4,
)

# Adam only receives the parameters that require gradients
trainable_parameters = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_parameters, lr=lr)


# learning rate factor as a function of the step count passed in by the scheduler
def _lr_scale(step):
    return 1 / (1 + lr_decay_factor * step)


scheduler = LambdaLR(optimizer=optimizer, lr_lambda=_lr_scale)

# negative log-likelihood loss
criterion = torch.nn.NLLLoss()

# checkpoint template; 'state_dict' and 'loss' are overwritten after every
# epoch, -1 marks "no loss recorded yet"
checkpoint = dict(
    setup=setup,
    state_dict=model.state_dict(),
    loss=-1,
)

# Unless --no-redirect is given, all further print() output goes to a log file
# in the output folder. The previous stdout is remembered so that it can be
# restored, and the log file closed, when training finishes or fails.
previous_stdout = sys.stdout
log_file = None
if not args.no_redirect:
    log_path = os.path.join(args.output, output_file)
    print(f"re-directing output to {log_path}")
    log_file = open(log_path, "w")
    sys.stdout = log_file

# lowest validation loss seen so far; starting at infinity means any finite
# validation loss of the first epoch is stored as the best checkpoint
best_loss = float('inf')

try:
    for ep in range(1, epochs + 1):
        print(f"training epoch {ep}...")
        new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)

        # refresh the checkpoint with the weights and training loss of this epoch
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss']       = new_loss

        if run_validation:
            print("running validation...")
            validation_loss = evaluate(model, criterion, validation_dataloader, device)
            checkpoint['validation_loss'] = validation_loss

            # keep a separate copy of the checkpoint with the lowest validation loss
            if validation_loss < best_loss:
                torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
                best_loss = validation_loss

        # one checkpoint per epoch plus a rolling "last" checkpoint
        torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
        torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))

        # run inference test
        if run_inference_test:
            # generation runs on the CPU; the model is moved back afterwards
            model.to("cpu")
            print("running inference test...")

            output = model.generate(test_features['features'], test_features['periods'], test_features['lpcs'])

            testfilename = os.path.join(inference_test_dir, output_prefix + f'_epoch_{ep}.wav')

            # written as 16-bit wav with a sample rate argument of 16000
            wavwrite16(testfilename, output.numpy(), 16000)

            model.to(device)

        print()
        # make the progress of this epoch visible in the log right away
        sys.stdout.flush()
finally:
    if log_file is not None:
        sys.stdout = previous_stdout
        log_file.close()