ref: feb32828877ea5e8723ea2a446eb20d7b3fba426
parent: 62b546436fc07035802eb998f61702ee2716db60
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Mon Oct 30 10:08:07 EDT 2023
Don't try to use models that aren't loaded
--- a/celt/celt_decoder.c
+++ b/celt/celt_decoder.c
@@ -721,7 +721,7 @@
if (loss_duration == 0)
{
#ifdef ENABLE_DEEP_PLC
- update_plc_state(lpcnet, decode_mem, &st->plc_preemphasis_mem, C);
+ if (lpcnet->loaded) update_plc_state(lpcnet, decode_mem, &st->plc_preemphasis_mem, C);
#endif
st->last_pitch_index = pitch_index = celt_plc_pitch_search(decode_mem, C, st->arch);
} else {
@@ -914,7 +914,7 @@
} while (++c<C);
#ifdef ENABLE_DEEP_PLC
- if (st->complexity >= 5 || lpcnet->fec_fill_pos > 0) {
+ if (lpcnet->loaded && (st->complexity >= 5 || lpcnet->fec_fill_pos > 0)) {
float overlap_mem;
int samples_needed16k;
celt_sig *buf;
--- a/dnn/lpcnet_plc.c
+++ b/dnn/lpcnet_plc.c
@@ -57,8 +57,10 @@
fargan_init(&st->fargan);
lpcnet_encoder_init(&st->enc);
st->analysis_pos = PLC_BUF_SIZE;
+ st->loaded = 0;
#ifndef USE_WEIGHTS_FILE
ret = init_plc_model(&st->model, lpcnet_plc_arrays);
+ if (ret == 0) st->loaded = 1;
#else
ret = 0;
#endif
@@ -75,11 +77,12 @@
free(list);
if (ret == 0) {
ret = lpcnet_encoder_load_model(&st->enc, data, len);
- } else return -1;
+ }
if (ret == 0) {
- return fargan_load_model(&st->fargan, data, len);
+ ret = fargan_load_model(&st->fargan, data, len);
}
- else return -1;
+ if (ret == 0) st->loaded = 1;
+ return ret;
}
void lpcnet_plc_fec_add(LPCNetPLCState *st, const float *features) {
@@ -105,6 +108,7 @@
float zeros[3*PLC_MAX_RNN_NEURONS] = {0};
float dense_out[PLC_DENSE1_OUT_SIZE];
PLCNetState *net = &st->plc_net;
+ celt_assert(st->loaded);
_lpcnet_compute_dense(&st->model.plc_dense1, dense_out, in);
compute_gruB(&st->model.plc_gru1, zeros, net->plc_gru1_state, dense_out);
compute_gruB(&st->model.plc_gru2, zeros, net->plc_gru2_state, net->plc_gru1_state);
@@ -152,6 +156,7 @@
static const float att_table[10] = {0, 0, -.2, -.2, -.4, -.4, -.8, -.8, -1.6, -1.6};
int lpcnet_plc_conceal(LPCNetPLCState *st, opus_int16 *pcm) {
int i;
+ celt_assert(st->loaded);
if (st->blend == 0) {
int count = 0;
while (st->analysis_pos + FRAME_SIZE <= PLC_BUF_SIZE) {
--- a/dnn/lpcnet_private.h
+++ b/dnn/lpcnet_private.h
@@ -47,6 +47,7 @@
PLCModel model;
FARGANState fargan;
LPCNetEncState enc;
+ int loaded;
int arch;
#define LPCNET_PLC_RESET_START fec
--- a/silk/PLC.c
+++ b/silk/PLC.c
@@ -397,7 +397,7 @@
frame[ i ] = (opus_int16)silk_SAT16( silk_SAT16( silk_RSHIFT_ROUND( silk_SMULWW( sLPC_Q14_ptr[ MAX_LPC_ORDER + i ], prevGain_Q10[ 1 ] ), 8 ) ) );
}
#ifdef ENABLE_DEEP_PLC
- if ( lpcnet != NULL && psDec->sPLC.fs_kHz == 16 ) {
+ if ( lpcnet != NULL && lpcnet->loaded && psDec->sPLC.fs_kHz == 16 ) {
int run_deep_plc = psDec->sPLC.enable_deep_plc || lpcnet->fec_fill_pos != 0;
if( run_deep_plc ) {
for( k = 0; k < psDec->nb_subfr; k += 2 ) {
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -57,6 +57,7 @@
if (ret == 0) {
ret = lpcnet_encoder_load_model(&enc->lpcnet_enc_state, data, len);
}
+ if (ret == 0) enc->loaded = 1;
return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
}
@@ -74,8 +75,9 @@
{
enc->Fs = Fs;
enc->channels = channels;
+ enc->loaded = 0;
#ifndef USE_WEIGHTS_FILE
- init_rdovaeenc(&enc->model, rdovaeenc_arrays);
+ if (init_rdovaeenc(&enc->model, rdovaeenc_arrays) == 0) enc->loaded = 1;
#endif
dred_encoder_reset(enc);
}
@@ -85,6 +87,7 @@
float feature_buffer[2 * 36];
float input_buffer[2*DRED_NUM_FEATURES] = {0};
+ celt_assert(enc->loaded);
/* shift latents buffer */
OPUS_MOVE(enc->latents_buffer + DRED_LATENT_DIM, enc->latents_buffer, (DRED_MAX_FRAMES - 1) * DRED_LATENT_DIM);
@@ -184,6 +187,7 @@
{
int curr_offset16k;
int frame_size16k = frame_size * 16000 / enc->Fs;
+ celt_assert(enc->loaded);
curr_offset16k = 40 + extra_delay*16000/enc->Fs - enc->input_buffer_fill;
enc->dred_offset = (int)floor((curr_offset16k+20.f)/40.f);
enc->latent_offset = 0;
--- a/silk/dred_encoder.h
+++ b/silk/dred_encoder.h
@@ -40,6 +40,9 @@
typedef struct {
RDOVAEEnc model;
+ LPCNetEncState lpcnet_enc_state;
+ RDOVAEEncState rdovae_enc;
+ int loaded;
opus_int32 Fs;
int channels;
@@ -53,8 +56,6 @@
float state_buffer[DRED_STATE_DIM];
float initial_state[DRED_STATE_DIM];
float resample_mem[RESAMPLING_ORDER + 1];
- LPCNetEncState lpcnet_enc_state;
- RDOVAEEncState rdovae_enc;
} DREDEnc;
int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len);
--- a/src/opus_decoder.c
+++ b/src/opus_decoder.c
@@ -1042,7 +1042,7 @@
{
goto bad_arg;
}
- return lpcnet_plc_load_model(&st->lpcnet, data, len);
+ ret = lpcnet_plc_load_model(&st->lpcnet, data, len);
}
break;
#endif
@@ -1156,6 +1156,7 @@
#ifdef ENABLE_DRED
RDOVAEDec model;
#endif
+ int loaded;
int arch;
opus_uint32 magic;
};
@@ -1188,6 +1189,7 @@
parse_weights(&list, data, len);
ret = init_rdovaedec(&dec->model, list);
free(list);
+ if (ret == 0) dec->loaded = 1;
return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
}
#endif
@@ -1194,13 +1196,16 @@
int opus_dred_decoder_init(OpusDREDDecoder *dec)
{
+ int ret = 0;
+ dec->loaded = 0;
#if defined(ENABLE_DRED) && !defined(USE_WEIGHTS_FILE)
- init_rdovaedec(&dec->model, rdovaedec_arrays);
+ ret = init_rdovaedec(&dec->model, rdovaedec_arrays);
+ if (ret == 0) dec->loaded = 1;
#endif
dec->arch = opus_select_arch();
/* To make sure nobody forgets to init, use a magic number. */
dec->magic = 0xD8EDDEC0;
- return OPUS_OK;
+ return (ret == 0) ? OPUS_OK : OPUS_UNIMPLEMENTED;
}
OpusDREDDecoder *opus_dred_decoder_create(int *error)
@@ -1378,6 +1383,7 @@
const unsigned char *payload;
opus_int32 payload_len;
VALIDATE_DRED_DECODER(dred_dec);
+ if (!dred_dec->loaded) return OPUS_UNIMPLEMENTED;
dred->process_stage = -1;
payload_len = dred_find_payload(data, len, &payload);
if (payload_len < 0)
@@ -1412,6 +1418,7 @@
if (dred_dec == NULL || src == NULL || dst == NULL || (src->process_stage != 1 && src->process_stage != 2))
return OPUS_BAD_ARG;
VALIDATE_DRED_DECODER(dred_dec);
+ if (!dred_dec->loaded) return OPUS_UNIMPLEMENTED;
if (src != dst)
OPUS_COPY(dst, src, 1);
if (dst->process_stage == 2)
--- a/src/opus_encoder.c
+++ b/src/opus_encoder.c
@@ -1713,7 +1713,7 @@
#endif
#ifdef ENABLE_DRED
- if ( st->dred_duration > 0 ) {
+ if ( st->dred_duration > 0 && st->dred_encoder.loaded ) {
/* DRED Encoder */
dred_compute_latents( &st->dred_encoder, &pcm_buf[total_buffer*st->channels], frame_size, total_buffer );
} else {
@@ -2255,7 +2255,7 @@
ret += 1+redundancy_bytes;
apply_padding = !st->use_vbr;
#ifdef ENABLE_DRED
- if (st->dred_duration > 0) {
+ if (st->dred_duration > 0 && st->dred_encoder.loaded) {
opus_extension_data extension;
unsigned char buf[DRED_MAX_DATA_SIZE];
int dred_chunks;
@@ -2893,17 +2893,17 @@
}
break;
#ifdef USE_WEIGHTS_FILE
- case OPUS_SET_DNN_BLOB_REQUEST:
- {
- const unsigned char *data = va_arg(ap, const unsigned char *);
- opus_int32 len = va_arg(ap, opus_int32);
- if(len<0 || data == NULL)
- {
- goto bad_arg;
- }
- return dred_encoder_load_model(&st->dred_encoder, data, len);
- }
- break;
+ case OPUS_SET_DNN_BLOB_REQUEST:
+ {
+ const unsigned char *data = va_arg(ap, const unsigned char *);
+ opus_int32 len = va_arg(ap, opus_int32);
+ if(len<0 || data == NULL)
+ {
+ goto bad_arg;
+ }
+ ret = dred_encoder_load_model(&st->dred_encoder, data, len);
+ }
+ break;
#endif
case CELT_GET_MODE_REQUEST:
{
--
⑨