shithub: opus

Download patch

ref: c40add59af065f4fdf80048f2dad91d6b4480114
parent: 627aa7f5b3688ba787c69e55e199ba82e2013be0
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Thu Dec 21 11:57:35 EST 2023

lossgen: can now dump weights

--- /dev/null
+++ b/dnn/torch/lossgen/export_lossgen.py
@@ -1,0 +1,107 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+import sys
+
+# Make the sibling weight-exchange package importable regardless of cwd.
+sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('checkpoint', type=str, help='model checkpoint')
+parser.add_argument('output_dir', type=str, help='output folder')
+
+args = parser.parse_args()
+
+# Heavy imports are deliberately deferred until after argument parsing so
+# that `--help` and usage errors stay fast.
+import torch
+import numpy as np
+
+import lossgen
+from wexchange.torch import dump_torch_weights
+from wexchange.c_export import CWriter, print_vector
+
+def c_export(args, model):
+    """Dump the LossGen model weights as auto-generated C source/header.
+
+    Writes lossgen_data.* into args.output_dir and #defines
+    LOSSGEN_MAX_RNN_UNITS as the unit count of the widest exported GRU.
+    """
+
+    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
+
+    writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen')
+    writer.header.write(
+f"""
+#include "opus_types.h"
+"""
+        )
+
+    # Dense layers are exported unquantized.
+    dense_layers = [
+        ('dense_in', "lossgen_dense_in"),
+        ('dense_out', "lossgen_dense_out")
+    ]
+
+    for name, export_name in dense_layers:
+        layer = model.get_submodule(name)
+        dump_torch_weights(writer, layer, name=export_name, verbose=True, quantize=False, scale=None)
+
+    # GRU layers are quantized; dump_torch_weights returns each GRU's unit
+    # count, and the header needs the maximum for scratch-buffer sizing.
+    gru_layers = [
+        ("gru1", "lossgen_gru1"),
+        ("gru2", "lossgen_gru2"),
+    ]
+
+    max_rnn_units = max(dump_torch_weights(writer, model.get_submodule(name), name=export_name, verbose=True, input_sparse=False, quantize=True, scale=None, recurrent_scale=None)
+                        for name, export_name in gru_layers)
+
+    writer.header.write(
+f"""
+
+#define LOSSGEN_MAX_RNN_UNITS {max_rnn_units}
+
+"""
+        )
+
+    writer.close()
+
+
+if __name__ == "__main__":
+
+    os.makedirs(args.output_dir, exist_ok=True)
+    checkpoint = torch.load(args.checkpoint, map_location='cpu')
+    model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+    # strict=False: presumably tolerates checkpoints saved before layers
+    # were added (e.g. dense_in) -- TODO confirm against training history.
+    model.load_state_dict(checkpoint['state_dict'], strict=False)
+    c_export(args, model)
--- a/dnn/torch/lossgen/lossgen.py
+++ b/dnn/torch/lossgen/lossgen.py
@@ -8,7 +8,8 @@
 
         self.gru1_size = gru1_size
         self.gru2_size = gru2_size
-        self.gru1 = nn.GRU(2, self.gru1_size, batch_first=True)
+        self.dense_in = nn.Linear(2, 8)
+        self.gru1 = nn.GRU(8, self.gru1_size, batch_first=True)
         self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
         self.dense_out = nn.Linear(self.gru2_size, 1)
 
@@ -22,7 +23,7 @@
         else:
             gru1_state = states[0]
             gru2_state = states[1]
-        x = torch.cat([loss, perc], dim=-1)
+        x = torch.tanh(self.dense_in(torch.cat([loss, perc], dim=-1)))
         gru1_out, gru1_state = self.gru1(x, gru1_state)
         gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
         return self.dense_out(gru2_out), [gru1_state, gru2_state]
--- a/dnn/torch/lossgen/test_lossgen.py
+++ b/dnn/torch/lossgen/test_lossgen.py
@@ -18,10 +18,7 @@
 
 
 checkpoint = torch.load(args.model, map_location='cpu')
-
 model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
-
-
 model.load_state_dict(checkpoint['state_dict'], strict=False)
 
 states=None
--- a/dnn/torch/lossgen/train_lossgen.py
+++ b/dnn/torch/lossgen/train_lossgen.py
@@ -32,13 +32,13 @@
         return [self.loss[index, :, :], self.perc[index, :, :]+r0+r1]
 
 
-adam_betas = [0.8, 0.99]
+adam_betas = [0.8, 0.98]
 adam_eps = 1e-8
-batch_size=512
-lr_decay = 0.0001
-lr = 0.001
+batch_size=256
+lr_decay = 0.001
+lr = 0.003
 epsilon = 1e-5
-epochs = 20
+epochs = 2000
 checkpoint_dir='checkpoint'
 os.makedirs(checkpoint_dir, exist_ok=True)
 checkpoint = dict()
--