shithub: opus

Download patch

ref: be392e38572760dda8d484285e55fcd5e35ddbfa
parent: 2b4652f9f6b59b3c95a31da3b4348f5d4eea068e
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Fri Dec 25 22:58:48 EST 2020

WIP: Got some AVX2 code working

--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -200,7 +200,7 @@
 static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
 {
    int i, j;
-   unsigned x[MAX_INPUTS];
+   unsigned char x[MAX_INPUTS];
    for (i=0;i<rows;i++) out[i] *= SCALE;
    for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
    for (i=0;i<rows;i+=8)
@@ -235,7 +235,7 @@
 static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
 {
    int i, j;
-   signed x[MAX_INPUTS];
+   signed char x[MAX_INPUTS];
    for (i=0;i<rows;i++) out[i] *= SCALE;
    for (i=0;i<cols;i++) x[i] = floor(.5+127*_x[i]);
    for (i=0;i<rows;i+=8)
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -226,52 +226,81 @@
 #define USE_SU_BIAS
 
 #define MAX_INPUTS (2048)
+#define MAX_OUTPUTS (8192) // capacity of the local int out[] accumulator buffer in sparse_sgemv_accum8x4; rows must not exceed it (not checked there)
 
 
 #define SCALE (128.f*127.f)
 #define SCALE_1 (1.f/128.f/127.f)
 
-#if 0
-static inline void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, const int *idx, const float *x)
+#if 1
+
+static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, const int *idx, const float *_x) // AVX2 version: adds SCALE_1*sum(w*xq) to _out for each 8x4 weight block listed in idx (xq = 8-bit quantized _x, int32 accumulation)
 {
+   __m256i ones; // sixteen int16 ones; _mm256_madd_epi16 with this sums adjacent int16 pairs into int32
    int i, j;
+   unsigned char x[MAX_INPUTS]; // NOTE(review): cols must be <= MAX_INPUTS; not checked
+   int out[MAX_OUTPUTS]; // NOTE(review): rows must be <= MAX_OUTPUTS and presumably a multiple of 8 (each i+=8 step touches y[0..7]); not checked
+   ones = _mm256_set1_epi16(1);
+   for (i=0;i<rows;i++) out[i] = SCALE*_out[i]; // fixed-point accumulators; float->int conversion truncates toward zero
+   for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]); // xq = 127+round(127*_x): 0..254 for _x in [-1,1]; the +127 offset adds 127*sum(w) per row (presumably compensated elsewhere, e.g. USE_SU_BIAS -- confirm); out-of-range _x makes the conversion to unsigned char undefined
    for (i=0;i<rows;i+=8)
    {
-      float * restrict y;
-      int cols;
-      __m256 vy0;
+      int * restrict y;
+      int colblocks;
+      __m256i vy0; // int32 accumulators for rows i..i+7
+      colblocks = *idx++; // idx layout per 8-row block: a count, then that many column-block indices
       y = &out[i];
-      vy0 = _mm256_loadu_ps(&y[0]);
-      cols = *idx++;
-      for (j=0;j<cols;j++)
+      vy0 = _mm256_loadu_si256((const __m256i *)&y[0]);
+      for (j=0;j<colblocks;j++)
       {
-         int id;
-         __m256 vxj;
-         __m256 vw;
-         id = *idx++;
+         __m256i tmp;
+         __m256i vxj;
+         __m256i vw;
+         int pos;
+         int xj0, xj1, xj2, xj3;
+         pos = 4 * (*idx++); // each column block covers 4 consecutive inputs
+         vxj = _mm256_set1_epi32(*(int*)&x[pos]); // broadcast x[pos..pos+3] to every 32-bit lane; NOTE(review): reads unsigned char[] through int* (strict-aliasing violation, possibly unaligned) -- memcpy into an int instead
+         xj0 = x[pos+0]; // xj0..xj3 only feed the scalar y[] updates below
+         xj1 = x[pos+1];
+         xj2 = x[pos+2];
+         xj3 = x[pos+3];
          
-         //kernel goes here
+         vw = _mm256_loadu_si256((const __m256i *)w); //_mm256_lddqu_si256?
+         tmp = _mm256_maddubs_epi16(vxj, vw); // arg 1 is read as unsigned bytes (x), arg 2 as signed bytes (w); adjacent products summed with int16 saturation. NOTE(review): assuming qweight is signed 8-bit, x0*w0+x1*w1 can reach 2*254*127 > 32767 and saturate, unlike the scalar sums
+         tmp = _mm256_madd_epi16(tmp, ones); // lane r = x0*w[4r]+x1*w[4r+1]+x2*w[4r+2]+x3*w[4r+3]
+         vy0 = _mm256_add_epi32(vy0, tmp);

-         weights += 32;
+         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3); // NOTE(review): dead -- the vy0 store after this loop overwrites y[0..7], discarding these scalar sums
+         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+         w += 32; // one block = 8 rows x 4 inputs
       }
-      _mm256_storeu_ps (&y[0], vy0);
+      _mm256_storeu_si256 ((__m256i *)&y[0], vy0); // replaces y[0..7], including the scalar updates made in the loop above
    }
+   for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i]; // undo the SCALE applied on entry
 }
 #else
-static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
+static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
 {
    int i, j;
-   unsigned x[MAX_INPUTS];
-   for (i=0;i<rows;i++) out[i] *= SCALE;
+   unsigned char x[MAX_INPUTS];
+   int out[MAX_OUTPUTS];
+   for (i=0;i<rows;i++) out[i] = SCALE*_out[i];
    for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
    for (i=0;i<rows;i+=8)
    {
+      int * restrict y;
       int colblocks;
       colblocks = *idx++;
+      y = &out[i];
       for (j=0;j<colblocks;j++)
       {
          int pos;
-         float * restrict y;
          int xj0, xj1, xj2, xj3;
          pos = 4 * (*idx++);
          xj0 = x[pos+0];
@@ -278,7 +307,6 @@
          xj1 = x[pos+1];
          xj2 = x[pos+2];
          xj3 = x[pos+3];
-         y = &out[i];
          y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
          y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
          y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
@@ -290,7 +318,7 @@
          w += 32;
       }
    }
-   for (i=0;i<rows;i++) out[i] *= SCALE_1;
+   for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i];
 }
 #endif
 
--