ref: 83657d0e43d80e1e64273d8d8094b04ed8088172
parent: 1707b960dee936f6730544e01ad3cd0fb055dbdb
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Tue Dec 29 22:07:14 EST 2020
Dot product AVX2 code for non-sparse multiply
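
Add a non-sparse sgemv_accum8x4() with an AVX2 implementation in
vec_avx.h and a scalar version in vec.h. To use _mm256_maddubs_epi16(),
which multiplies unsigned bytes by signed bytes, inputs are quantized
as x[i] = 127 + round(127*_x[i]) rather than as signed bytes. Since
dot(w, x+127) = dot(w, x) + 127*sum(w), the extra term is expected to
be absorbed by the subias values that nnet.c now selects under
USE_SU_BIAS.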
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -224,13 +224,14 @@
celt_assert(gru->reset_after);
stride = 3*N;
/* Compute update gate. */
+#ifdef USE_SU_BIAS
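+ /* The 8x4 kernel uses unsigned inputs (x[i]+127), so start from subias, which is assumed to fold in the resulting 127*sum(w) term. */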
for (i=0;i<3*N;i++)
- zrh[i] = gru->bias[i];
-#if 1
- sgemv_accum8x4(zrh, gru->input_weights, 3*N, M, stride, input);
+ zrh[i] = gru->subias[i];
#else
- sgemv_accum(zrh, gru->input_weights, 3*N, M, stride, input);
+ for (i=0;i<3*N;i++)
+ zrh[i] = gru->bias[i];
#endif
+ sgemv_accum8x4(zrh, gru->input_weights, 3*N, M, stride, input);
for (i=0;i<3*N;i++)
recur[i] = gru->bias[3*N + i];
sgemv_accum(recur, gru->recurrent_weights, 3*N, N, stride, state);
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -198,13 +198,16 @@
#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)
+
+#ifdef USE_SU_BIAS
+
static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
{
int i, j;
- signed char x[MAX_INPUTS];
+ unsigned char x[MAX_INPUTS];
(void)col_stride;
for (i=0;i<rows;i++) out[i] *= SCALE;
- for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
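+ /* Quantize to the unsigned range [0,254] so the scalar code matches the unsigned-input SIMD kernels. */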
+ for (i=0;i<cols;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
for (i=0;i<rows;i+=8)
{
for (j=0;j<cols;j+=4)
@@ -230,8 +233,6 @@
for (i=0;i<rows;i++) out[i] *= SCALE_1;
}
-
-#ifdef USE_SU_BIAS
static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
{
int i, j;
@@ -267,6 +268,39 @@
for (i=0;i<rows;i++) out[i] *= SCALE_1;
}
#else /*USE_SU_BIAS*/
+
+static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
+{
+ int i, j;
+ signed char x[MAX_INPUTS];
+ (void)col_stride;
+ for (i=0;i<rows;i++) out[i] *= SCALE;
+ for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
+ for (i=0;i<rows;i+=8)
+ {
+ for (j=0;j<cols;j+=4)
+ {
+ float * restrict y;
+ float xj0, xj1, xj2, xj3;
+ xj0 = x[j+0];
+ xj1 = x[j+1];
+ xj2 = x[j+2];
+ xj3 = x[j+3];
+ y = &out[i];
+ y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+ y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+ y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+ y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+ y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+ y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+ y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+ y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+ w += 32;
+ }
+ }
+ for (i=0;i<rows;i++) out[i] *= SCALE_1;
+}
+
static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
{
int i, j;
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -241,6 +241,88 @@
#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)
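+/* Set to 0 to use the scalar reference version of sgemv_accum8x4() below. */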
+#if 1
+static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
+{
+ __m256i ones;
+ int i, j;
+ unsigned char x[MAX_INPUTS];
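+ /* Accumulate in 32-bit integers; converted back to float and rescaled at the end. */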
+ int out[MAX_OUTPUTS];
+ (void)col_stride;
+ ones = _mm256_set1_epi16(1);
+ for (i=0;i<rows;i++) out[i] = SCALE*_out[i];
+ /* Vectorized version of: x[i] = 127 + (int)floor(.5+127*_x[i]); */
+ __m256 const127 = _mm256_set1_ps(127.f);
+ for (i=0;i<cols;i+=8) {
+ __m256 xf;
+ __m256i xi;
+ xf = _mm256_loadu_ps(&_x[i]);
+ /* Scale and offset in one step: xf = 127*xf + 127. */
+ xf = _mm256_fmadd_ps(xf, const127, const127);
+ xi = _mm256_cvtps_epi32(xf);
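+ /* Narrow the eight 32-bit results to eight unsigned bytes (saturating) in the low 8 bytes of xi. */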
+ xi = _mm256_packus_epi32(xi, _mm256_setzero_si256());
+ xi = _mm256_permute4x64_epi64(xi, 0xD8);
+ xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
+ xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
+ _mm256_storeu_si256 ((__m256i *)&x[i], xi);
+ }
+ for (i=0;i<rows;i+=8)
+ {
+ int * restrict y;
+ __m256i vy0;
+ y = &out[i];
+ vy0 = _mm256_loadu_si256((const __m256i *)&y[0]);
+ for (j=0;j<cols;j+=4)
+ {
+ __m256i tmp;
+ __m256i vxj;
+ __m256i vw;
+ /* Broadcast 4 unsigned input bytes to every 32-bit lane. */
+ vxj = _mm256_set1_epi32(*(int*)&x[j]);
+ vw = _mm256_loadu_si256((const __m256i *)w);
+ /* Multiply unsigned inputs by signed weights, adding adjacent pairs to
+    16 bits, then widen to 32-bit sums by multiplying with ones. */
+ tmp = _mm256_maddubs_epi16(vxj, vw);
+ tmp = _mm256_madd_epi16(tmp, ones);
+ vy0 = _mm256_add_epi32(vy0, tmp);
+ w += 32;
+ }
+ _mm256_storeu_si256 ((__m256i *)&y[0], vy0);
+ }
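+ /* Convert the integer accumulators back to float and undo the fixed-point scaling. */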
+ for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i];
+}
+#else
+static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
+{
+ int i, j;
+ unsigned char x[MAX_INPUTS];
+ (void)col_stride;
+ for (i=0;i<rows;i++) out[i] *= SCALE;
+ for (i=0;i<cols;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
+ for (i=0;i<rows;i+=8)
+ {
+ for (j=0;j<cols;j+=4)
+ {
+ float * restrict y;
+ float xj0, xj1, xj2, xj3;
+ xj0 = x[j+0];
+ xj1 = x[j+1];
+ xj2 = x[j+2];
+ xj3 = x[j+3];
+ y = &out[i];
+ y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+ y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+ y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+ y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+ y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+ y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+ y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+ y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+ w += 32;
+ }
+ }
+ for (i=0;i<rows;i++) out[i] *= SCALE_1;
+}
+#endif
static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
{
--