ref: e25a585de8d1ba83fbca9fa398ea8a52b6f2bb1b
parent: 4ccfbdff04cf8071084884174915d77996b4a3ad
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Mon Nov 26 19:23:26 EST 2018
Match Python boundary condition
--- a/dnn/lpcnet.c
+++ b/dnn/lpcnet.c
@@ -51,6 +51,7 @@
int last_exc;
short last_sig[LPC_ORDER];
float old_input[FEATURES_DELAY][FEATURE_CONV2_OUT_SIZE];
+ int frame_count;
};
@@ -77,6 +78,12 @@
return (int)floor(.5 + u);
}
+static void print_vector(const float *x, int N) /* Debug helper: prints x[0..N-1] to stdout, space-separated, then a newline. */
+{
+ int i;
+ for (i=0;i<N;i++) printf("%f ", x[i]);
+ printf("\n");
+}
void run_frame_network(LPCNetState *lpcnet, float *condition, float *lpc, const float *features, int pitch)
{
int i;
@@ -89,8 +96,10 @@
RNN_COPY(in, features, NB_FEATURES);
compute_embedding(&embed_pitch, &in[NB_FEATURES], pitch);
compute_conv1d(&feature_conv1, conv1_out, net->feature_conv1_state, in);
+ if (lpcnet->frame_count < FEATURE_CONV1_DELAY) RNN_CLEAR(conv1_out, FEATURE_CONV1_OUT_SIZE);
compute_conv1d(&feature_conv2, conv2_out, net->feature_conv2_state, conv1_out);
celt_assert(FRAME_INPUT_SIZE == FEATURE_CONV2_OUT_SIZE);
+ if (lpcnet->frame_count < FEATURES_DELAY) RNN_CLEAR(conv2_out, FEATURE_CONV2_OUT_SIZE);
for (i=0;i<FEATURE_CONV2_OUT_SIZE;i++) conv2_out[i] += lpcnet->old_input[FEATURES_DELAY-1][i];
memmove(lpcnet->old_input[1], lpcnet->old_input[0], (FEATURES_DELAY-1)*FRAME_INPUT_SIZE*sizeof(in[0]));
memcpy(lpcnet->old_input[0], in, FRAME_INPUT_SIZE*sizeof(in[0]));
@@ -98,6 +107,7 @@
compute_dense(&feature_dense2, condition, dense1_out);
/* FIXME: Actually compute the LPC on the middle frame. */
RNN_CLEAR(lpc, LPC_ORDER);
+ if (lpcnet->frame_count < 1000) lpcnet->frame_count++;
}
void run_sample_network(NNetState *net, float *pdf, const float *condition, int last_exc, int last_sig, int pred)
@@ -140,6 +150,11 @@
/* FIXME: get the pitch gain from 2 frames in the past. */
pitch_gain = features[PITCH_GAIN_FEATURE];
run_frame_network(lpcnet, condition, lpc, features, pitch);
+ if (lpcnet->frame_count <= FEATURES_DELAY)
+ {
+ RNN_CLEAR(output, N);
+ return;
+ }
for (i=0;i<N;i++)
{
int j;
--
⑨