ref: 006556036a4838a0243690354a0f375882d4c49e
parent: 44fe0556826c23fc4273c55f159351577fa98a57
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Tue Jul 6 14:17:17 EDT 2021
Cleaning up the sparse GRU

It no longer overwrites its input vector.
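
compute_sparse_gru() now takes its input (the result of the input matrix
multiply) as const: the z/r/h pointers alias the local recur[] buffer
instead of the caller's vector, and the input term is folded directly into
the z and r pre-activations, which removes the old copy/add loop. A rough
sketch of the resulting call pattern is below; the helper and struct names
outside nnet.c are placeholders, not the actual call site:

    /* Sketch only: the input matrix product is computed once into a
     * 3*N scratch buffer holding the z, r and h input contributions.
     * compute_input_product() and model are hypothetical names. */
    float gru_input[3*MAX_RNN_NEURONS];
    compute_input_product(gru_input, features);   /* hypothetical input matmul */
    compute_sparse_gru(&model->sparse_gru, gru_state, gru_input);
    /* gru_input is no longer clobbered, so it can be reused or inspected. */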
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -68,7 +68,7 @@
}
}
-void compute_activation(float *output, float *input, int N, int activation)
+void compute_activation(float *output, const float *input, int N, int activation)
{
int i;
if (activation == ACTIVATION_SIGMOID) {
@@ -322,8 +322,8 @@
state[i] = h[i];
}
-/* WARNING: for efficiency reasons, this function overwrites the input vector. */
-void compute_sparse_gru(const SparseGRULayer *gru, float *state, float *input)
+/* The input of this GRU is after the input matrix multiply. */
+void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input)
{
int i, k;
int N;
@@ -333,9 +333,9 @@
float *h;
const float *bias;
N = gru->nb_neurons;
- z = input;
- r = &input[N];
- h = &input[2*N];
+ z = recur;
+ r = &recur[N];
+ h = &recur[2*N];
celt_assert(gru->nb_neurons <= MAX_RNN_NEURONS);
celt_assert(input != state);
celt_assert(gru->reset_after);
@@ -344,17 +344,20 @@
#else
bias = &gru->bias[3*N];
#endif
- for (k=0;k<3;k++)
+ for (k=0;k<2;k++)
{
for (i=0;i<N;i++)
+ recur[k*N + i] = bias[k*N + i] + gru->diag_weights[k*N + i]*state[i] + input[k*N + i];
+ }
+ for (;k<3;k++)
+ {
+ for (i=0;i<N;i++)
recur[k*N + i] = bias[k*N + i] + gru->diag_weights[k*N + i]*state[i];
}
sparse_sgemv_accum8x4(recur, gru->recurrent_weights, 3*N, N, gru->idx, state);
- for (i=0;i<2*N;i++)
- input[i] += recur[i];
- compute_activation(input, input, 2*N, ACTIVATION_SIGMOID);
+ compute_activation(recur, recur, 2*N, ACTIVATION_SIGMOID);
for (i=0;i<N;i++)
- h[i] += recur[2*N+i]*r[i];
+ h[i] = h[i]*r[i] + input[2*N+i];
compute_activation(h, h, N, gru->activation);
for (i=0;i<N;i++)
state[i] = z[i]*state[i] + (1-z[i])*h[i];
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -91,7 +91,7 @@
int dim;
} EmbeddingLayer;
-void compute_activation(float *output, float *input, int N, int activation);
+void compute_activation(float *output, const float *input, int N, int activation);
void compute_dense(const DenseLayer *layer, float *output, const float *input);
@@ -105,7 +105,7 @@
void compute_gru3(const GRULayer *gru, float *state, const float *input);
-void compute_sparse_gru(const SparseGRULayer *gru, float *state, float *input);
+void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input);
void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input);
--