shithub: opus

Download patch

ref: b0e1a2eb95c684114431501cc1bb57fc3fb49842
parent: 3eac8c12e476f68eaa7608cb078379e980d3a7ab
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Aug 2 22:06:15 EDT 2023

Apply continuation after the first subframe

The continuation now matches the Python code.

--- a/dnn/fwgan.c
+++ b/dnn/fwgan.c
@@ -108,28 +108,8 @@
   compute_generic_dense(&model->cont_net_6, tmp2, tmp1, ACTIVATION_TANH);
   compute_generic_dense(&model->cont_net_8, tmp1, tmp2, ACTIVATION_TANH);
   celt_assert(CONT_NET_10_OUT_SIZE == model->cont_net_10.nb_outputs);
-  compute_generic_dense(&model->cont_net_10, cont_inputs, tmp1, ACTIVATION_TANH);
+  compute_generic_dense(&model->cont_net_10, st->cont, tmp1, ACTIVATION_TANH);
 
-  celt_assert(RNN_GRU_STATE_SIZE == model->rnn_cont_fc_0.nb_outputs);
-  compute_generic_dense(&model->rnn_cont_fc_0, st->rnn_state, cont_inputs, ACTIVATION_TANH);
-
-  celt_assert(FWC1_STATE_SIZE == model->fwc1_cont_fc_0.nb_outputs);
-  compute_generic_dense(&model->fwc1_cont_fc_0, st->fwc1_state, cont_inputs, ACTIVATION_TANH);
-  celt_assert(FWC2_STATE_SIZE == model->fwc2_cont_fc_0.nb_outputs);
-  compute_generic_dense(&model->fwc2_cont_fc_0, st->fwc2_state, cont_inputs, ACTIVATION_TANH);
-  celt_assert(FWC3_STATE_SIZE == model->fwc3_cont_fc_0.nb_outputs);
-  compute_generic_dense(&model->fwc3_cont_fc_0, st->fwc3_state, cont_inputs, ACTIVATION_TANH);
-  celt_assert(FWC4_STATE_SIZE == model->fwc4_cont_fc_0.nb_outputs);
-  compute_generic_dense(&model->fwc4_cont_fc_0, st->fwc4_state, cont_inputs, ACTIVATION_TANH);
-  celt_assert(FWC5_STATE_SIZE == model->fwc5_cont_fc_0.nb_outputs);
-  compute_generic_dense(&model->fwc5_cont_fc_0, st->fwc5_state, cont_inputs, ACTIVATION_TANH);
-  celt_assert(FWC6_STATE_SIZE == model->fwc6_cont_fc_0.nb_outputs);
-  compute_generic_dense(&model->fwc6_cont_fc_0, st->fwc6_state, cont_inputs, ACTIVATION_TANH);
-  celt_assert(FWC7_STATE_SIZE == model->fwc7_cont_fc_0.nb_outputs);
-  compute_generic_dense(&model->fwc7_cont_fc_0, st->fwc7_state, cont_inputs, ACTIVATION_TANH);
-
-  /* FIXME: Do we need to handle initial features? How? */
-
   st->cont_initialized = 1;
 }
 
@@ -212,6 +192,29 @@
 
   compute_generic_conv1d(&model->fwc7_fc_0, tmp1, st->fwc7_state, tmp2, FWC6_FC_0_OUT_SIZE, ACTIVATION_LINEAR);
   compute_gated_activation(&model->fwc7_fc_1_gate, pcm, tmp1, ACTIVATION_TANH);
+
+  if (st->cont_initialized == 1) {
+    celt_assert(RNN_GRU_STATE_SIZE == model->rnn_cont_fc_0.nb_outputs);
+    compute_generic_dense(&model->rnn_cont_fc_0, st->rnn_state, st->cont, ACTIVATION_TANH);
+
+    celt_assert(FWC1_STATE_SIZE == model->fwc1_cont_fc_0.nb_outputs);
+    compute_generic_dense(&model->fwc1_cont_fc_0, st->fwc1_state, st->cont, ACTIVATION_TANH);
+    celt_assert(FWC2_STATE_SIZE == model->fwc2_cont_fc_0.nb_outputs);
+    compute_generic_dense(&model->fwc2_cont_fc_0, st->fwc2_state, st->cont, ACTIVATION_TANH);
+    celt_assert(FWC3_STATE_SIZE == model->fwc3_cont_fc_0.nb_outputs);
+    compute_generic_dense(&model->fwc3_cont_fc_0, st->fwc3_state, st->cont, ACTIVATION_TANH);
+    celt_assert(FWC4_STATE_SIZE == model->fwc4_cont_fc_0.nb_outputs);
+    compute_generic_dense(&model->fwc4_cont_fc_0, st->fwc4_state, st->cont, ACTIVATION_TANH);
+    celt_assert(FWC5_STATE_SIZE == model->fwc5_cont_fc_0.nb_outputs);
+    compute_generic_dense(&model->fwc5_cont_fc_0, st->fwc5_state, st->cont, ACTIVATION_TANH);
+    celt_assert(FWC6_STATE_SIZE == model->fwc6_cont_fc_0.nb_outputs);
+    compute_generic_dense(&model->fwc6_cont_fc_0, st->fwc6_state, st->cont, ACTIVATION_TANH);
+    celt_assert(FWC7_STATE_SIZE == model->fwc7_cont_fc_0.nb_outputs);
+    compute_generic_dense(&model->fwc7_cont_fc_0, st->fwc7_state, st->cont, ACTIVATION_TANH);
+    /* FIXME: Do we need to handle initial features? How? */
+    st->cont_initialized = 2;
+  }
+
 }
 
 
--- a/dnn/fwgan.h
+++ b/dnn/fwgan.h
@@ -51,8 +51,8 @@
   float syn_mem[LPC_ORDER];
   float preemph_mem;
   float deemph_mem;
+  float cont[CONT_NET_10_OUT_SIZE];
   float cont_conv1_mem[FEAT_IN_CONV1_CONV_STATE_SIZE];
-  float cont[FEAT_IN_NL1_GATE_OUT_SIZE];
   float rnn_state[RNN_GRU_STATE_SIZE];
   float fwc1_state[FWC1_STATE_SIZE];
   float fwc2_state[FWC2_STATE_SIZE];
--