ref: c7b978b923f4d243d8f67c9d865ba23c18c89ae9
parent: 3c694db22607e177326e42673284937e0e4d31aa
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Tue Nov 27 09:37:10 EST 2018
Fix reset_after GRU
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -173,6 +173,11 @@
/* Compute update gate. */
for (i=0;i<N;i++)
z[i] = gru->bias[i];
+ if (gru->reset_after)
+ {
+ for (i=0;i<N;i++)
+ z[i] += gru->bias[3*N + i];
+ }
gemm_accum(z, gru->input_weights, N, M, stride, input);
gemm_accum(z, gru->recurrent_weights, N, N, stride, state);
compute_activation(z, z, N, ACTIVATION_SIGMOID);
@@ -180,6 +185,11 @@
/* Compute reset gate. */
for (i=0;i<N;i++)
r[i] = gru->bias[N + i];
+ if (gru->reset_after)
+ {
+ for (i=0;i<N;i++)
+ r[i] += gru->bias[4*N + i];
+ }
gemm_accum(r, &gru->input_weights[N], N, M, stride, input);
gemm_accum(r, &gru->recurrent_weights[N], N, N, stride, state);
compute_activation(r, r, N, ACTIVATION_SIGMOID);
@@ -189,8 +199,8 @@
h[i] = gru->bias[2*N + i];
if (gru->reset_after)
{
- /* WARNING: The reset_after version was never tested. */
- RNN_CLEAR(tmp, N);
+ for (i=0;i<N;i++)
+ tmp[i] = gru->bias[5*N + i];
gemm_accum(tmp, &gru->recurrent_weights[2*N], N, N, stride, state);
for (i=0;i<N;i++)
h[i] += tmp[i] * r[i];
--