shithub: opus

Download patch

ref: 0098fe70ac5a94952956146ed4795341a3639c79
parent: 87f9fbc50cb14fe950cffa97a7bd87c01c09461f
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Fri May 19 14:12:18 EDT 2023

Defer calls to run_frame_network() to save CPU

Calls are deferred until an actual packet loss occurs, and only the
minimum required number of frames is processed at that point.

--- a/dnn/lpcnet.c
+++ b/dnn/lpcnet.c
@@ -119,6 +119,30 @@
     if (lpcnet->frame_count < 1000) lpcnet->frame_count++;
 }
 
+void run_frame_network_deferred(LPCNetState *lpcnet, const float *features) /* Queue one frame of NB_FEATURES floats in lpcnet->feature_buffer; run_frame_network() is not called here. */
+{
+    int max_buffer_size = lpcnet->model.feature_conv1.kernel_size + lpcnet->model.feature_conv2.kernel_size - 2; /* Max frames retained; presumably the history the two feature convolutions need -- TODO confirm. */
+    celt_assert(max_buffer_size <= MAX_FEATURE_BUFFER_SIZE); /* feature_buffer is declared with room for MAX_FEATURE_BUFFER_SIZE frames. */
+    if (lpcnet->feature_buffer_fill == max_buffer_size) { /* Queue full: the oldest frame is discarded. */
+        RNN_MOVE(lpcnet->feature_buffer, &lpcnet->feature_buffer[NB_FEATURES],  (max_buffer_size-1)*NB_FEATURES); /* Shift the newer frames down by one slot. */
+    } else {
+      lpcnet->feature_buffer_fill++; /* Still room: claim the next free slot. */
+    }
+    RNN_COPY(&lpcnet->feature_buffer[(lpcnet->feature_buffer_fill-1)*NB_FEATURES], features, NB_FEATURES); /* Store the new frame in the last occupied slot. */
+}
+
+void run_frame_network_flush(LPCNetState *lpcnet) /* Call run_frame_network() on every queued frame, oldest first, then empty the queue. */
+{
+    int i;
+    for (i=0;i<lpcnet->feature_buffer_fill;i++) {
+        float lpc[LPC_ORDER]; /* lpc and the two condition arrays are scratch outputs, not read after the call; presumably the call is made for its effect on lpcnet -- confirm. */
+        float gru_a_condition[3*GRU_A_STATE_SIZE];
+        float gru_b_condition[3*GRU_B_STATE_SIZE];
+        run_frame_network(lpcnet, gru_a_condition, gru_b_condition, lpc, &lpcnet->feature_buffer[i*NB_FEATURES]);
+    }
+    lpcnet->feature_buffer_fill = 0; /* Mark the queue empty. */
+}
+
 int run_sample_network(LPCNetState *lpcnet, const float *gru_a_condition, const float *gru_b_condition, int last_exc, int last_sig, int pred, const float *sampling_logit_table, kiss99_ctx *rng)
 {
     NNetState *net;
--- a/dnn/lpcnet_plc.c
+++ b/dnn/lpcnet_plc.c
@@ -189,11 +189,8 @@
         st->plc_net = st->plc_copy[FEATURES_DELAY];
         compute_plc_pred(&st->plc_net, st->features, zeros);
         for (i=0;i<FEATURES_DELAY;i++) {
-          float lpc[LPC_ORDER];
-          float gru_a_condition[3*GRU_A_STATE_SIZE];
-          float gru_b_condition[3*GRU_B_STATE_SIZE];
           /* FIXME: backtrack state, replace features. */
-          run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->features);
+          run_frame_network_deferred(&st->lpcnet, st->features);
         }
         copy = st->lpcnet;
         lpcnet_synthesize_impl(&st->lpcnet, &st->features[0], tmp, FRAME_SIZE-TRAINING_OFFSET, 0);
@@ -238,11 +235,8 @@
   }
   if (st->skip_analysis) {
     if (st->enable_blending) {
-      float lpc[LPC_ORDER];
-      float gru_a_condition[3*GRU_A_STATE_SIZE];
-      float gru_b_condition[3*GRU_B_STATE_SIZE];
       /* FIXME: backtrack state, replace features. */
-      run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->enc.features[0]);
+      run_frame_network_deferred(&st->lpcnet, st->enc.features[0]);
     }
     st->skip_analysis--;
   } else {
@@ -250,10 +244,7 @@
     RNN_COPY(output, &st->pcm[0], FRAME_SIZE);
 #ifdef PLC_SKIP_UPDATES
     {
-      float lpc[LPC_ORDER];
-      float gru_a_condition[3*GRU_A_STATE_SIZE];
-      float gru_b_condition[3*GRU_B_STATE_SIZE];
-      run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->enc.features[0]);
+      run_frame_network_deferred(&st->lpcnet, st->enc.features[0]);
     }
 #else
     lpcnet_synthesize_impl(&st->lpcnet, st->enc.features[0], output, FRAME_SIZE, FRAME_SIZE);
@@ -274,6 +265,7 @@
 static int lpcnet_plc_conceal_causal(LPCNetPLCState *st, short *pcm) {
   int i;
   short output[FRAME_SIZE];
+  run_frame_network_flush(&st->lpcnet);
   st->enc.pcount = 0;
   /* If we concealed the previous frame, finish synthesizing the rest of the samples. */
   /* FIXME: Copy/predict features. */
--- a/dnn/lpcnet_private.h
+++ b/dnn/lpcnet_private.h
@@ -23,11 +23,14 @@
 #define FORBIDDEN_INTERP 7
 
 #define PLC_MAX_FEC 100
+#define MAX_FEATURE_BUFFER_SIZE 4 /* Capacity of LPCNetState.feature_buffer, in frames of NB_FEATURES floats. */
 
 struct LPCNetState {
     NNetState nnet;
     int last_exc;
     float last_sig[LPC_ORDER];
+    float feature_buffer[NB_FEATURES*MAX_FEATURE_BUFFER_SIZE];
+    int feature_buffer_fill;
     float last_features[NB_FEATURES];
 #if FEATURES_DELAY>0
     float old_lpc[FEATURES_DELAY][LPC_ORDER];
@@ -114,6 +117,10 @@
 
 void lpcnet_reset_signal(LPCNetState *lpcnet);
 void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features);
+void run_frame_network_deferred(LPCNetState *lpcnet, const float *features);
+void run_frame_network_flush(LPCNetState *lpcnet);
+
+
 void lpcnet_synthesize_tail_impl(LPCNetState *lpcnet, short *output, int N, int preload);
 void lpcnet_synthesize_impl(LPCNetState *lpcnet, const float *features, short *output, int N, int preload);
 void lpcnet_synthesize_blend_impl(LPCNetState *lpcnet, const short *pcm_in, short *output, int N);
--