shithub: opus

ref: 040aa437c3791de80b58b5ef83c1412a2d3e608c
parent: 6c2f7e58fdd4d73b24747428230c20eda0890728
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Wed Nov 28 07:37:18 EST 2018

Simpler GRU implementation just for reset_after.
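
In the reset-after GRU formulation (the convention CuDNN uses, and Keras when
reset_after=True), the reset gate scales the recurrent term after the matrix
product: the candidate state is act(W_h*x + b_h + r .* (U_h*h + b_h_rec)),
rather than the original act(W_h*x + U_h*(r .* h) + b_h), where .* is
elementwise multiplication. Because the previous state is never masked before
the recurrent multiply, all 3*N recurrent rows can be applied to the unmodified
state with a single gemm_accum() call, which is what makes the implementation
below simpler. A scalar sketch of the recurrence follows the patch.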

--- a/dnn/lpcnet.c
+++ b/dnn/lpcnet.c
@@ -121,10 +121,10 @@
     compute_embedding(&embed_sig, &in_a[EMBED_SIG_OUT_SIZE], pred);
     compute_embedding(&embed_exc, &in_a[2*EMBED_SIG_OUT_SIZE], last_exc);
     RNN_COPY(&in_a[2*EMBED_SIG_OUT_SIZE + EMBED_EXC_OUT_SIZE], condition, FEATURE_DENSE2_OUT_SIZE);
-    compute_gru(&gru_a, net->gru_a_state, in_a);
+    compute_gru2(&gru_a, net->gru_a_state, in_a);
     RNN_COPY(in_b, net->gru_a_state, GRU_A_STATE_SIZE);
     RNN_COPY(&in_b[GRU_A_STATE_SIZE], condition, FEATURE_DENSE2_OUT_SIZE);
-    compute_gru(&gru_b, net->gru_b_state, in_b);
+    compute_gru2(&gru_b, net->gru_b_state, in_b);
     compute_mdense(&dual_fc, pdf, net->gru_b_state);
 }
 
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -218,6 +218,44 @@
       state[i] = h[i];
 }
 
+void compute_gru2(const GRULayer *gru, float *state, const float *input)
+{
+   int i;
+   int N, M;
+   int stride;
+   float zrh[3*MAX_RNN_NEURONS];
+   float recur[3*MAX_RNN_NEURONS];
+   float *z;
+   float *r;
+   float *h;
+   M = gru->nb_inputs;
+   N = gru->nb_neurons;
+   z = zrh;
+   r = &zrh[N];
+   h = &zrh[2*N];
+   celt_assert(gru->nb_neurons <= MAX_RNN_NEURONS);
+   celt_assert(input != state);
+   celt_assert(gru->reset_after);
+   stride = 3*N;
+   /* Compute update gate. */
+   for (i=0;i<3*N;i++)
+      zrh[i] = gru->bias[i];
+   gemm_accum(zrh, gru->input_weights, 3*N, M, stride, input);
+   for (i=0;i<3*N;i++)
+      recur[i] = gru->bias[3*N + i];
+   gemm_accum(recur, gru->recurrent_weights, 3*N, N, stride, state);
+   for (i=0;i<2*N;i++)
+      zrh[i] += recur[i];
+   compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
+   for (i=0;i<N;i++)
+      h[i] += recur[2*N+i]*r[i];
+   compute_activation(h, h, N, gru->activation);
+   for (i=0;i<N;i++)
+      h[i] = z[i]*state[i] + (1-z[i])*h[i];
+   for (i=0;i<N;i++)
+      state[i] = h[i];
+}
+
 void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
 {
    int i;
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -85,6 +85,8 @@
 
 void compute_gru(const GRULayer *gru, float *state, const float *input);
 
+void compute_gru2(const GRULayer *gru, float *state, const float *input);
+
 void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input);
 
 void compute_embedding(const EmbeddingLayer *layer, float *output, int input);
--
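
Below is a minimal scalar sketch of the recurrence that compute_gru2()
vectorizes, for one neuron and one time step. The names gru_unit_reset_after,
wx_* and uh_* are illustrative only and do not exist in LPCNet; the sketch
assumes a tanh candidate activation and that the input-side products W*x
already include the input bias and the recurrent products U*h already include
the recurrent bias, matching how the zrh and recur buffers are filled in the
patch.

#include <math.h>

/* Logistic sigmoid; compute_gru2() uses compute_activation() with
   ACTIVATION_SIGMOID for the same step. */
static float gru_sigmoid(float x)
{
   return 1.f/(1.f + expf(-x));
}

/* One reset-after GRU unit for one time step.
   wx_z, wx_r, wx_h: input contributions W*x plus input bias (the zrh buffer).
   uh_z, uh_r, uh_h: recurrent contributions U*h_prev plus recurrent bias
                     (the recur buffer).
   h_prev: previous state of this neuron. */
static float gru_unit_reset_after(float wx_z, float wx_r, float wx_h,
                                  float uh_z, float uh_r, float uh_h,
                                  float h_prev)
{
   float z = gru_sigmoid(wx_z + uh_z);  /* update gate */
   float r = gru_sigmoid(wx_r + uh_r);  /* reset gate */
   float cand = tanhf(wx_h + r*uh_h);   /* reset applied after the recurrent product */
   return z*h_prev + (1.f - z)*cand;    /* same interpolation as the final loop */
}

In the patch, compute_gru2() does this for all N neurons at once, with the two
gemm_accum() calls producing the wx_* and uh_* terms for the whole layer.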