shithub: opus

Download patch

ref: e25a585de8d1ba83fbca9fa398ea8a52b6f2bb1b
parent: 4ccfbdff04cf8071084884174915d77996b4a3ad
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Mon Nov 26 19:23:26 EST 2018

Match Python boundary condition

--- a/dnn/lpcnet.c
+++ b/dnn/lpcnet.c
@@ -51,6 +51,7 @@
     int last_exc;
     short last_sig[LPC_ORDER];
     float old_input[FEATURES_DELAY][FEATURE_CONV2_OUT_SIZE];
+    int frame_count;
 };
 
 
@@ -77,6 +78,12 @@
     return (int)floor(.5 + u);
 }
 
+static void print_vector(const float *x, int N) /* debug: print x[0..N-1] on one line */
+{
+    int i;
+    for (i=0;i<N;i++) printf("%f ", x[i]);
+    printf("\n");
+}
 void run_frame_network(LPCNetState *lpcnet, float *condition, float *lpc, const float *features, int pitch)
 {
     int i;
@@ -89,8 +96,10 @@
     RNN_COPY(in, features, NB_FEATURES);
     compute_embedding(&embed_pitch, &in[NB_FEATURES], pitch);
     compute_conv1d(&feature_conv1, conv1_out, net->feature_conv1_state, in);
+    if (lpcnet->frame_count < FEATURE_CONV1_DELAY) RNN_CLEAR(conv1_out, FEATURE_CONV1_OUT_SIZE);
     compute_conv1d(&feature_conv2, conv2_out, net->feature_conv2_state, conv1_out);
     celt_assert(FRAME_INPUT_SIZE == FEATURE_CONV2_OUT_SIZE);
+    if (lpcnet->frame_count < FEATURES_DELAY) RNN_CLEAR(conv2_out, FEATURE_CONV2_OUT_SIZE);
     for (i=0;i<FEATURE_CONV2_OUT_SIZE;i++) conv2_out[i] += lpcnet->old_input[FEATURES_DELAY-1][i];
     memmove(lpcnet->old_input[1], lpcnet->old_input[0], (FEATURES_DELAY-1)*FRAME_INPUT_SIZE*sizeof(in[0]));
     memcpy(lpcnet->old_input[0], in, FRAME_INPUT_SIZE*sizeof(in[0]));
@@ -98,6 +107,7 @@
     compute_dense(&feature_dense2, condition, dense1_out);
     /* FIXME: Actually compute the LPC on the middle frame. */
     RNN_CLEAR(lpc, LPC_ORDER);
+    if (lpcnet->frame_count < 1000) lpcnet->frame_count++;
 }
 
 void run_sample_network(NNetState *net, float *pdf, const float *condition, int last_exc, int last_sig, int pred)
@@ -140,6 +150,11 @@
     /* FIXME: get the pitch gain from 2 frames in the past. */
     pitch_gain = features[PITCH_GAIN_FEATURE];
     run_frame_network(lpcnet, condition, lpc, features, pitch);
+    if (lpcnet->frame_count <= FEATURES_DELAY)
+    {
+        RNN_CLEAR(output, N);
+        return;
+    }
     for (i=0;i<N;i++)
     {
         int j;
--