ref: 73a05f55c75505b3ef05f931dbb054002ba0e2ba
parent: cc285186993a66e9886957b2b5c106ec56c9949e
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Tue Dec 22 11:52:24 EST 2020
wip: 8x4 block-sparse GEMV
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -301,7 +301,7 @@
for (i=0;i<N;i++)
recur[k*N + i] += gru->diag_weights[k*N + i]*state[i];
}
- sparse_sgemv_accum16(recur, gru->recurrent_weights, 3*N, gru->idx, state);
+ sparse_sgemv_accum8x4(recur, gru->recurrent_weights, 3*N, gru->idx, state);
for (i=0;i<2*N;i++)
zrh[i] += recur[i];
compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
--- a/dnn/training_tf2/dump_lpcnet.py
+++ b/dnn/training_tf2/dump_lpcnet.py
@@ -60,6 +60,7 @@
def printSparseVector(f, A, name):
N = A.shape[0]
W = np.zeros((0,))
+ W0 = np.zeros((0,))
diag = np.concatenate([np.diag(A[:,:N]), np.diag(A[:,N:2*N]), np.diag(A[:,2*N:])])
A[:,:N] = A[:,:N] - np.diag(np.diag(A[:,:N]))
A[:,N:2*N] = A[:,N:2*N] - np.diag(np.diag(A[:,N:2*N]))
@@ -76,9 +77,15 @@
nb_nonzero = nb_nonzero + 1
idx = np.append(idx, j)
vblock = block.transpose((1,0)).reshape((-1,))
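+            # W gets the transposed (vblock) block order for DOT_PROD; W0 keeps the plain row-major block order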
+ W0 = np.concatenate([W0, block.reshape((-1,))])
W = np.concatenate([W, vblock])
idx[pos] = nb_nonzero
+ f.write('#ifdef DOT_PROD\n')
printVector(f, W, name)
+ f.write('#else /*DOT_PROD*/\n')
+ printVector(f, W0, name)
+ f.write('#endif /*DOT_PROD*/\n')
#idx = np.tile(np.concatenate([np.array([N]), np.arange(N)]), 3*N//16)
printVector(f, idx, name + '_idx', dtype='int')
return;
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -162,3 +162,129 @@
}
}
}
+
+#ifdef DOT_PROD
+
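+/* Sparse block matrix-vector product with 8x4 blocks (8 output rows by
+   4 input columns).  For each stripe of 8 rows, idx holds the number of
+   non-zero blocks followed by their column indices.  Here w[i*4+j] links
+   input j to output i, a layout suited to SIMD dot-product kernels. */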
+static void sparse_sgemv_accum8x4(float *out, const float *w, int rows, const int *idx, const float *x)
+{
+ int i, j;
+ for (i=0;i<rows;i+=8)
+ {
+ int cols;
+ cols = *idx++;
+ for (j=0;j<cols;j++)
+ {
+ int pos;
+ float * restrict y;
+ float xj0, xj1, xj2, xj3;
+ pos = 4 * (*idx++);
+ xj0 = x[pos+0];
+ xj1 = x[pos+1];
+ xj2 = x[pos+2];
+ xj3 = x[pos+3];
+ y = &out[i];
+ y[0] += w[0]*xj0;
+ y[1] += w[4]*xj0;
+ y[2] += w[8]*xj0;
+ y[3] += w[12]*xj0;
+ y[4] += w[16]*xj0;
+ y[5] += w[20]*xj0;
+ y[6] += w[24]*xj0;
+ y[7] += w[28]*xj0;
+
+ y[0] += w[1]*xj1;
+ y[1] += w[5]*xj1;
+ y[2] += w[9]*xj1;
+ y[3] += w[13]*xj1;
+ y[4] += w[17]*xj1;
+ y[5] += w[21]*xj1;
+ y[6] += w[25]*xj1;
+ y[7] += w[29]*xj1;
+
+ y[0] += w[2]*xj2;
+ y[1] += w[6]*xj2;
+ y[2] += w[10]*xj2;
+ y[3] += w[14]*xj2;
+ y[4] += w[18]*xj2;
+ y[5] += w[22]*xj2;
+ y[6] += w[26]*xj2;
+ y[7] += w[30]*xj2;
+
+ y[0] += w[3]*xj3;
+ y[1] += w[7]*xj3;
+ y[2] += w[11]*xj3;
+ y[3] += w[15]*xj3;
+ y[4] += w[19]*xj3;
+ y[5] += w[23]*xj3;
+ y[6] += w[27]*xj3;
+ y[7] += w[31]*xj3;
+ w += 32;
+ }
+ }
+}
+
+#else
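+/* Same computation with the plain layout: w[j*8+i] links input j to
+   output i, i.e. 8 consecutive weights share one input value. */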
+static void sparse_sgemv_accum8x4(float *out, const float *w, int rows, const int *idx, const float *x)
+{
+ int i, j;
+ for (i=0;i<rows;i+=8)
+ {
+ int cols;
+ cols = *idx++;
+ for (j=0;j<cols;j++)
+ {
+ int pos;
+ float * restrict y;
+ float xj0, xj1, xj2, xj3;
+ pos = 4 * (*idx++);
+ xj0 = x[pos+0];
+ xj1 = x[pos+1];
+ xj2 = x[pos+2];
+ xj3 = x[pos+3];
+ y = &out[i];
+ y[0] += w[0]*xj0;
+ y[1] += w[1]*xj0;
+ y[2] += w[2]*xj0;
+ y[3] += w[3]*xj0;
+ y[4] += w[4]*xj0;
+ y[5] += w[5]*xj0;
+ y[6] += w[6]*xj0;
+ y[7] += w[7]*xj0;
+
+ y[0] += w[8]*xj1;
+ y[1] += w[9]*xj1;
+ y[2] += w[10]*xj1;
+ y[3] += w[11]*xj1;
+ y[4] += w[12]*xj1;
+ y[5] += w[13]*xj1;
+ y[6] += w[14]*xj1;
+ y[7] += w[15]*xj1;
+
+ y[0] += w[16]*xj2;
+ y[1] += w[17]*xj2;
+ y[2] += w[18]*xj2;
+ y[3] += w[19]*xj2;
+ y[4] += w[20]*xj2;
+ y[5] += w[21]*xj2;
+ y[6] += w[22]*xj2;
+ y[7] += w[23]*xj2;
+
+ y[0] += w[24]*xj3;
+ y[1] += w[25]*xj3;
+ y[2] += w[26]*xj3;
+ y[3] += w[27]*xj3;
+ y[4] += w[28]*xj3;
+ y[5] += w[29]*xj3;
+ y[6] += w[30]*xj3;
+ y[7] += w[31]*xj3;
+ w += 32;
+ }
+ }
+}
+#endif
--
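
For reference, a minimal sketch (not part of the patch; the harness and
values are illustrative, and it assumes vec.h can be compiled standalone
with DOT_PROD undefined) exercising the new kernel with one 8x4 block:

#include <stdio.h>
#include "vec.h"

int main(void)
{
   float w[32];
   int idx[2] = {1, 0};        /* one non-zero block, at column block 0 */
   float x[4] = {1, 2, 3, 4};  /* the 4 inputs covered by that block */
   float out[8] = {0};
   int i, j;
   /* Plain layout: w[j*8 + i] is the weight linking input j to output i. */
   for (j=0;j<4;j++)
      for (i=0;i<8;i++)
         w[j*8 + i] = i + 1;
   sparse_sgemv_accum8x4(out, w, 8, idx, x);
   for (i=0;i<8;i++)
      printf("%g\n", out[i]);  /* expect (i+1)*(1+2+3+4) = 10*(i+1) */
   return 0;
}

With DOT_PROD defined, the same result requires the transposed per-block
ordering (w[i*4 + j] linking input j to output i), which is exactly the
vblock/W path that dump_lpcnet.py emits.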