ref: ba5dde539a21177ad91e84a00dfa0606c8f7149a
parent: ec6e42ba3d82ffaca24c2debd010111a60628408
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Thu Jun 1 00:09:57 EDT 2023
Add reset functions that don't clear the model
--- a/dnn/include/lpcnet.h
+++ b/dnn/include/lpcnet.h
@@ -75,6 +75,13 @@
*/
LPCNET_EXPORT int lpcnet_decoder_init(LPCNetDecState *st);
+/** Resets an LPCNet synthesis state to the values lpcnet_init() leaves it in,
+ * without re-initializing the model or the sampling logit table.
+ * The state must already have been initialized with lpcnet_init().
+ * @param [in] lpcnet <tt>LPCNetState*</tt>: LPCNet state
+ */
+LPCNET_EXPORT void lpcnet_reset(LPCNetState *lpcnet);
+
/** Allocates and initializes a decoder state.
* @returns The newly created state
*/
@@ -186,6 +193,12 @@
LPCNET_EXPORT int lpcnet_plc_get_size(void);
LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options);
+/** Resets a PLC state (including its LPCNet and encoder sub-states) without
+ * re-initializing the model or changing the options given to lpcnet_plc_init().
+ * The state must already have been initialized with lpcnet_plc_init().
+ * @param [in] st <tt>LPCNetPLCState*</tt>: PLC state
+ */
+LPCNET_EXPORT void lpcnet_plc_reset(LPCNetPLCState *st);
LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options);
--- a/dnn/lpcnet.c
+++ b/dnn/lpcnet.c
@@ -171,23 +171,35 @@
return sizeof(LPCNetState);
}
+/* Puts the state back into the condition lpcnet_init() leaves it in without
+ * touching the model: every field from LPCNET_RESET_START to the end of
+ * LPCNetState is zeroed, then last_exc and the RNG seed are restored. */
+LPCNET_EXPORT void lpcnet_reset(LPCNetState *lpcnet)
+{
+ const char* rng_string="LPCNet";
+ RNN_CLEAR((char*)&lpcnet->LPCNET_RESET_START,
+ sizeof(LPCNetState)-
+ ((char*)&lpcnet->LPCNET_RESET_START - (char*)lpcnet));
+ lpcnet->last_exc = lin2ulaw(0.f);
+ kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string));
+}
+
LPCNET_EXPORT int lpcnet_init(LPCNetState *lpcnet)
{
int i;
int ret;
- const char* rng_string="LPCNet";
- memset(lpcnet, 0, lpcnet_get_size());
- lpcnet->last_exc = lin2ulaw(0.f);
for (i=0;i<256;i++) {
float prob = .025f+.95f*i/255.f;
lpcnet->sampling_logit_table[i] = -log((1-prob)/prob);
}
- kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string));
#ifndef USE_WEIGHTS_FILE
ret = init_lpcnet_model(&lpcnet->model, lpcnet_arrays);
#else
+ /* No built-in weights: zero the model, which lpcnet_reset() leaves untouched. */
+ RNN_CLEAR(&lpcnet->model, 1);
ret = 0;
#endif
+ lpcnet_reset(lpcnet);
celt_assert(ret == 0);
return ret;
}
--- a/dnn/lpcnet_plc.c
+++ b/dnn/lpcnet_plc.c
@@ -43,10 +43,14 @@
return sizeof(LPCNetPLCState);
}
-LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) {
- int ret;
- RNN_CLEAR(st, 1);
- lpcnet_init(&st->lpcnet);
+/* Zeroes every field from LPCNET_PLC_RESET_START to the end of the struct and
+ * resets the embedded LPCNet and encoder states, leaving the PLC model and the
+ * option flags configured by lpcnet_plc_init() untouched. */
+LPCNET_EXPORT void lpcnet_plc_reset(LPCNetPLCState *st) {
+ RNN_CLEAR((char*)&st->LPCNET_PLC_RESET_START,
+ sizeof(LPCNetPLCState)-
+ ((char*)&st->LPCNET_PLC_RESET_START - (char*)st));
+ lpcnet_reset(&st->lpcnet);
lpcnet_encoder_init(&st->enc);
RNN_CLEAR(st->pcm, PLC_BUF_SIZE);
st->pcm_fill = PLC_BUF_SIZE;
@@ -55,6 +59,14 @@
st->loss_count = 0;
st->dc_mem = 0;
st->queued_update = 0;
+}
+
+LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) {
+ int ret;
+ /* lpcnet_plc_reset() leaves the PLC model untouched, so start it zeroed. */
+ RNN_CLEAR(&st->model, 1);
+ lpcnet_init(&st->lpcnet);
+ lpcnet_encoder_init(&st->enc);
if ((options&0x3) == LPCNET_PLC_CAUSAL) {
st->enable_blending = 1;
st->non_causal = 0;
@@ -74,6 +86,7 @@
ret = 0;
#endif
celt_assert(ret == 0);
+ lpcnet_plc_reset(st);
return ret;
}
--- a/dnn/lpcnet_private.h
+++ b/dnn/lpcnet_private.h
@@ -26,6 +26,13 @@
#define MAX_FEATURE_BUFFER_SIZE 4
struct LPCNetState {
+ LPCNetModel model;
+ float sampling_logit_table[256];
+ kiss99_ctx rng;
+
+/* Fields from LPCNET_RESET_START to the end of the struct are zeroed by
+ * lpcnet_reset(); model and sampling_logit_table above survive it (rng is re-seeded). */
+#define LPCNET_RESET_START nnet
NNetState nnet;
int last_exc;
float last_sig[LPC_ORDER];
@@ -35,14 +42,11 @@
#if FEATURES_DELAY>0
float old_lpc[FEATURES_DELAY][LPC_ORDER];
#endif
- float sampling_logit_table[256];
float gru_a_condition[3*GRU_A_STATE_SIZE];
float gru_b_condition[3*GRU_B_STATE_SIZE];
int frame_count;
float deemph_mem;
float lpc[LPC_ORDER];
- kiss99_ctx rng;
- LPCNetModel model;
};
struct LPCNetDecState {
@@ -74,8 +78,16 @@
#define PLC_BUF_SIZE (FEATURES_DELAY*FRAME_SIZE + TRAINING_OFFSET)
struct LPCNetPLCState {
+ PLCModel model;
LPCNetState lpcnet;
LPCNetEncState enc;
+ int enable_blending;
+ int non_causal;
+ int remove_dc;
+
+/* Fields from LPCNET_PLC_RESET_START onward are zeroed by lpcnet_plc_reset();
+ * model and option flags above survive it (lpcnet/enc are reset separately). */
+#define LPCNET_PLC_RESET_START fec
float fec[PLC_MAX_FEC][NB_FEATURES];
int fec_keep_pos;
int fec_read_pos;
@@ -89,16 +101,12 @@
int loss_count;
PLCNetState plc_net;
PLCNetState plc_copy[FEATURES_DELAY+1];
- int enable_blending;
- int non_causal;
double dc_mem;
double syn_dc;
- int remove_dc;
short dc_buf[TRAINING_OFFSET];
int queued_update;
short queued_samples[FRAME_SIZE];
- PLCModel model;
};
extern float ceps_codebook1[];
--
⑨