shithub: opus

Download patch

ref: c7b978b923f4d243d8f67c9d865ba23c18c89ae9
parent: 3c694db22607e177326e42673284937e0e4d31aa
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Tue Nov 27 09:37:10 EST 2018

Fix reset_after GRU

--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -173,6 +173,11 @@
    /* Compute update gate. */
    for (i=0;i<N;i++)
       z[i] = gru->bias[i];
+   if (gru->reset_after)
+   {
+      for (i=0;i<N;i++)
+         z[i] += gru->bias[3*N + i];
+   }
    gemm_accum(z, gru->input_weights, N, M, stride, input);
    gemm_accum(z, gru->recurrent_weights, N, N, stride, state);
    compute_activation(z, z, N, ACTIVATION_SIGMOID);
@@ -180,6 +185,11 @@
    /* Compute reset gate. */
    for (i=0;i<N;i++)
       r[i] = gru->bias[N + i];
+   if (gru->reset_after)
+   {
+      for (i=0;i<N;i++)
+         r[i] += gru->bias[4*N + i];
+   }
    gemm_accum(r, &gru->input_weights[N], N, M, stride, input);
    gemm_accum(r, &gru->recurrent_weights[N], N, N, stride, state);
    compute_activation(r, r, N, ACTIVATION_SIGMOID);
@@ -189,8 +199,8 @@
       h[i] = gru->bias[2*N + i];
    if (gru->reset_after)
    {
-      /* WARNING: The reset_after version was never tested. */
-      RNN_CLEAR(tmp, N);
+      for (i=0;i<N;i++)
+         tmp[i] = gru->bias[5*N + i];
       gemm_accum(tmp, &gru->recurrent_weights[2*N], N, N, stride, state);
       for (i=0;i<N;i++)
          h[i] += tmp[i] * r[i];
--