shithub: opus

Download patch

ref: ba5dde539a21177ad91e84a00dfa0606c8f7149a
parent: ec6e42ba3d82ffaca24c2debd010111a60628408
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Thu Jun 1 00:09:57 EDT 2023

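Add reset functions that don't clear the model

lpcnet_reset() and lpcnet_plc_reset() clear only the run-time state. The
fields that a reset must preserve (the models, the sampling table and the
PLC options, plus the RNG, which is re-seeded instead of cleared) are moved
to the front of LPCNetState and LPCNetPLCState, and everything from
LPCNET_RESET_START / LPCNET_PLC_RESET_START to the end of each struct is
zeroed on reset.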

--- a/dnn/include/lpcnet.h
+++ b/dnn/include/lpcnet.h
@@ -75,6 +75,12 @@
   */
 LPCNET_EXPORT int lpcnet_decoder_init(LPCNetDecState *st);
 
+/** Resets the run-time state of an LPCNet state without clearing its model.
+  * The state must have been initialized with lpcnet_init() beforehand.
+  * @param [in] lpcnet <tt>LPCNetState*</tt>: LPCNet state to reset
+  */
+LPCNET_EXPORT void lpcnet_reset(LPCNetState *lpcnet);
+
 /** Allocates and initializes a decoder state.
   *  @returns The newly created state
   */
@@ -186,6 +192,12 @@
 LPCNET_EXPORT int lpcnet_plc_get_size(void);
 
 LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options);
+
+/** Resets the run-time state of a PLC state without clearing its models or
+  * the options given to lpcnet_plc_init(), which must have been called first.
+  * @param [in] st <tt>LPCNetPLCState*</tt>: PLC state to reset
+  */
+LPCNET_EXPORT void lpcnet_plc_reset(LPCNetPLCState *st);
 
 LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options);
 
--- a/dnn/lpcnet.c
+++ b/dnn/lpcnet.c
@@ -171,23 +171,40 @@
     return sizeof(LPCNetState);
 }
 
+/* Resets the run-time state of an LPCNet state. The model, the sampling
+ * table and any other field declared before LPCNET_RESET_START are left
+ * untouched, so unlike lpcnet_init() this does not reload the model. */
+LPCNET_EXPORT void lpcnet_reset(LPCNetState *lpcnet)
+{
+    const char* rng_string="LPCNet";
+    /* Zero every field from LPCNET_RESET_START to the end of the struct,
+     * then set last_exc and re-seed the rng, which lies before the marker. */
+    RNN_CLEAR((char*)&lpcnet->LPCNET_RESET_START,
+            sizeof(LPCNetState)-
+            ((char*)&lpcnet->LPCNET_RESET_START - (char*)lpcnet));
+    lpcnet->last_exc = lin2ulaw(0.f);
+    kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string));
+}
+
 LPCNET_EXPORT int lpcnet_init(LPCNetState *lpcnet)
 {
     int i;
     int ret;
-    const char* rng_string="LPCNet";
+    /* Zero the whole state first so that no field is left indeterminate, in
+     * particular the model, which is not initialized here when
+     * USE_WEIGHTS_FILE is defined. */
     memset(lpcnet, 0, lpcnet_get_size());
-    lpcnet->last_exc = lin2ulaw(0.f);
     for (i=0;i<256;i++) {
         float prob = .025f+.95f*i/255.f;
         lpcnet->sampling_logit_table[i] = -log((1-prob)/prob);
     }
-    kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string));
 #ifndef USE_WEIGHTS_FILE
     ret = init_lpcnet_model(&lpcnet->model, lpcnet_arrays);
 #else
     ret = 0;
 #endif
+    /* Set last_exc and seed the rng; the model set up above is kept. */
+    lpcnet_reset(lpcnet);
     celt_assert(ret == 0);
     return ret;
 }
--- a/dnn/lpcnet_plc.c
+++ b/dnn/lpcnet_plc.c
@@ -43,10 +43,17 @@
   return sizeof(LPCNetPLCState);
 }
 
-LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) {
-  int ret;
-  RNN_CLEAR(st, 1);
-  lpcnet_init(&st->lpcnet);
+/* Resets the run-time PLC state without clearing the models or the options
+ * set by lpcnet_plc_init(). */
+LPCNET_EXPORT void lpcnet_plc_reset(LPCNetPLCState *st) {
+  /* Zero everything from LPCNET_PLC_RESET_START to the end of the struct.
+   * Fields declared before it are not cleared here: the model and the
+   * enable_blending/non_causal/remove_dc flags are kept, while the LPCNet
+   * and encoder sub-states are reset by the calls below. */
+  RNN_CLEAR((char*)&st->LPCNET_PLC_RESET_START,
+          sizeof(LPCNetPLCState)-
+          ((char*)&st->LPCNET_PLC_RESET_START - (char*)st));
+  lpcnet_reset(&st->lpcnet);
   lpcnet_encoder_init(&st->enc);
   RNN_CLEAR(st->pcm, PLC_BUF_SIZE);
   st->pcm_fill = PLC_BUF_SIZE;
@@ -55,6 +62,15 @@
   st->loss_count = 0;
   st->dc_mem = 0;
   st->queued_update = 0;
+}
+
+LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) {
+  int ret;
+  /* Zero the whole state first so that nothing is left indeterminate, e.g. the
+   * model weights, which are not loaded here when USE_WEIGHTS_FILE is defined. */
+  RNN_CLEAR(st, 1);
+  lpcnet_init(&st->lpcnet);
+  lpcnet_encoder_init(&st->enc);
   if ((options&0x3) == LPCNET_PLC_CAUSAL) {
     st->enable_blending = 1;
     st->non_causal = 0;
@@ -74,6 +90,8 @@
   ret = 0;
 #endif
   celt_assert(ret == 0);
+  /* Bring the run-time state to its initial values; the models are kept. */
+  lpcnet_plc_reset(st);
   return ret;
 }
 
--- a/dnn/lpcnet_private.h
+++ b/dnn/lpcnet_private.h
@@ -26,6 +26,15 @@
 #define MAX_FEATURE_BUFFER_SIZE 4
 
 struct LPCNetState {
+    /* Fields before LPCNET_RESET_START are set up by lpcnet_init() and are
+     * not cleared by lpcnet_reset() (which re-seeds rng instead). */
+    LPCNetModel model;
+    float sampling_logit_table[256];
+    kiss99_ctx rng;
+
+    /* Everything from here to the end of the struct is zeroed by
+     * lpcnet_reset(), so run-time state belongs below this point. */
+#define LPCNET_RESET_START nnet
     NNetState nnet;
     int last_exc;
     float last_sig[LPC_ORDER];
@@ -35,14 +44,11 @@
 #if FEATURES_DELAY>0
     float old_lpc[FEATURES_DELAY][LPC_ORDER];
 #endif
-    float sampling_logit_table[256];
     float gru_a_condition[3*GRU_A_STATE_SIZE];
     float gru_b_condition[3*GRU_B_STATE_SIZE];
     int frame_count;
     float deemph_mem;
     float lpc[LPC_ORDER];
-    kiss99_ctx rng;
-    LPCNetModel model;
 };
 
 struct LPCNetDecState {
@@ -74,8 +80,19 @@
 
 #define PLC_BUF_SIZE (FEATURES_DELAY*FRAME_SIZE + TRAINING_OFFSET)
 struct LPCNetPLCState {
+  /* Fields before LPCNET_PLC_RESET_START are not cleared by
+   * lpcnet_plc_reset(): model and the option flags are kept, while lpcnet
+   * and enc are reset via lpcnet_reset() and lpcnet_encoder_init(). */
+  PLCModel model;
   LPCNetState lpcnet;
   LPCNetEncState enc;
+  int enable_blending;
+  int non_causal;
+  int remove_dc;
+
+  /* Everything from here to the end of the struct is zeroed by
+   * lpcnet_plc_reset(), so run-time state belongs below this point. */
+#define LPCNET_PLC_RESET_START fec
   float fec[PLC_MAX_FEC][NB_FEATURES];
   int fec_keep_pos;
   int fec_read_pos;
@@ -89,16 +106,12 @@
   int loss_count;
   PLCNetState plc_net;
   PLCNetState plc_copy[FEATURES_DELAY+1];
-  int enable_blending;
-  int non_causal;
   double dc_mem;
   double syn_dc;
-  int remove_dc;
 
   short dc_buf[TRAINING_OFFSET];
   int queued_update;
   short queued_samples[FRAME_SIZE];
-  PLCModel model;
 };
 
 extern float ceps_codebook1[];
--