shithub: opus

ref: 006556036a4838a0243690354a0f375882d4c49e
parent: 44fe0556826c23fc4273c55f159351577fa98a57
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Tue Jul 6 14:17:17 EDT 2021

Cleaning up the sparse GRU

It no longer overwrites its input vector
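For orientation, here is the reset-after GRU update that compute_sparse_gru() performs after this patch, written as a minimal dense C sketch. This is an illustration, not the shipped code: sigmoidf() and tanhf stand in for compute_activation() (the real candidate gate uses gru->activation), the dense rec_w matrix stands in for the sparse matrix walked by sparse_sgemv_accum8x4() with gru->idx, and a C99 VLA stands in for the fixed-size recur buffer declared earlier in the real function. The point of the patch shows up in the signature: input (the result of the input matrix multiply) is const and only ever read, while every gate pre-activation accumulates in the local recur buffer.

#include <math.h>

static float sigmoidf(float x) { return 1.f/(1.f + expf(-x)); }

/* state:  h_{t-1}, N values, updated in place to h_t
 * input:  output of the input matrix multiply, 3*N values, layout [z | r | h]
 * diag_w: diagonal recurrent weights, 3*N values (gru->diag_weights)
 * rec_w:  off-diagonal recurrent weights, row-major (3*N) x N; dense
 *         stand-in for the sparse matrix used by sparse_sgemv_accum8x4()
 * bias:   recurrent (reset-after) bias, 3*N values */
void sparse_gru_sketch(float *state, const float *input, const float *diag_w,
                       const float *rec_w, const float *bias, int N)
{
   int i, k;
   float recur[3*N];   /* C99 VLA; the real code uses a fixed-size buffer */
   const float *z = recur;
   const float *r = &recur[N];
   float *h = &recur[2*N];

   /* Pre-activations: bias + diagonal recurrent term for all three gates,
    * plus the input term for z and r only.  The candidate h picks up its
    * input term later, after r is applied: that is the reset-after form. */
   for (k=0;k<2;k++)
      for (i=0;i<N;i++)
         recur[k*N+i] = bias[k*N+i] + diag_w[k*N+i]*state[i] + input[k*N+i];
   for (i=0;i<N;i++)
      recur[2*N+i] = bias[2*N+i] + diag_w[2*N+i]*state[i];

   /* Off-diagonal recurrent contribution for all three gates
    * (sparse_sgemv_accum8x4() in the real code). */
   for (k=0;k<3*N;k++)
      for (i=0;i<N;i++)
         recur[k] += rec_w[k*N+i]*state[i];

   /* Squash the update (z) and reset (r) gates. */
   for (i=0;i<2*N;i++)
      recur[i] = sigmoidf(recur[i]);

   /* Candidate state: r gates only the recurrent part, then the untouched
    * input term is added in (tanhf stands in for gru->activation). */
   for (i=0;i<N;i++)
      h[i] = tanhf(h[i]*r[i] + input[2*N+i]);

   /* Blend old state and candidate. */
   for (i=0;i<N;i++)
      state[i] = z[i]*state[i] + (1.f - z[i])*h[i];
}

The diff below makes the same change in two places: the h gate now gets its input term in the final element-wise pass rather than via the 2*N accumulation loop that used to write back into input, so the input vector survives the call untouched.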

--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -68,7 +68,7 @@
    }
 }
 
-void compute_activation(float *output, float *input, int N, int activation)
+void compute_activation(float *output, const float *input, int N, int activation)
 {
    int i;
    if (activation == ACTIVATION_SIGMOID) {
@@ -322,8 +322,8 @@
       state[i] = h[i];
 }
 
-/* WARNING: for efficiency reasons, this function overwrites the input vector. */
-void compute_sparse_gru(const SparseGRULayer *gru, float *state, float *input)
+/* The input of this GRU is after the input matrix multiply. */
+void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input)
 {
    int i, k;
    int N;
@@ -333,9 +333,9 @@
    float *h;
    const float *bias;
    N = gru->nb_neurons;
-   z = input;
-   r = &input[N];
-   h = &input[2*N];
+   z = recur;
+   r = &recur[N];
+   h = &recur[2*N];
    celt_assert(gru->nb_neurons <= MAX_RNN_NEURONS);
    celt_assert(input != state);
    celt_assert(gru->reset_after);
@@ -344,17 +344,20 @@
 #else
    bias = &gru->bias[3*N];   
 #endif
-   for (k=0;k<3;k++)
+   for (k=0;k<2;k++)
    {
       for (i=0;i<N;i++)
+         recur[k*N + i] = bias[k*N + i] + gru->diag_weights[k*N + i]*state[i] + input[k*N + i];
+   }
+   for (;k<3;k++)
+   {
+      for (i=0;i<N;i++)
          recur[k*N + i] = bias[k*N + i] + gru->diag_weights[k*N + i]*state[i];
    }
    sparse_sgemv_accum8x4(recur, gru->recurrent_weights, 3*N, N, gru->idx, state);
-   for (i=0;i<2*N;i++)
-      input[i] += recur[i];
-   compute_activation(input, input, 2*N, ACTIVATION_SIGMOID);
+   compute_activation(recur, recur, 2*N, ACTIVATION_SIGMOID);
    for (i=0;i<N;i++)
-      h[i] += recur[2*N+i]*r[i];
+      h[i] = h[i]*r[i] + input[2*N+i];
    compute_activation(h, h, N, gru->activation);
    for (i=0;i<N;i++)
       state[i] = z[i]*state[i] + (1-z[i])*h[i];
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -91,7 +91,7 @@
   int dim;
 } EmbeddingLayer;
 
-void compute_activation(float *output, float *input, int N, int activation);
+void compute_activation(float *output, const float *input, int N, int activation);
 
 void compute_dense(const DenseLayer *layer, float *output, const float *input);
 
@@ -105,7 +105,7 @@
 
 void compute_gru3(const GRULayer *gru, float *state, const float *input);
 
-void compute_sparse_gru(const SparseGRULayer *gru, float *state, float *input);
+void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input);
 
 void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input);
 
--
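To show what the const change buys, here is a hypothetical call site. compute_dense(), compute_sparse_gru() and MAX_RNN_NEURONS are real names from nnet.h; the model layout and layer names are invented for the example, and the dense layer is assumed to produce the 3*N input-side pre-activations.

/* Hypothetical caller (layer names invented for illustration).  Before
 * this patch, compute_sparse_gru() destroyed gru_in; now the buffer
 * survives the call. */
float gru_in[3*MAX_RNN_NEURONS];
float gru_state[MAX_RNN_NEURONS] = {0};

compute_dense(&model->gru_input, gru_in, features);  /* input matrix multiply */
compute_sparse_gru(&model->sparse_gru, gru_state, gru_in);
/* gru_in still holds the input-side pre-activations here and can be
 * reused or inspected. */

With input const-qualified in both nnet.c and nnet.h, the caller can compute the input-side product once and keep it, and the compiler can enforce that the GRU only reads it.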