ref: 2b4652f9f6b59b3c95a31da3b4348f5d4eea068e
parent: bce779886de72547db991c5d4dec1ae1fc1a7a80
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Fri Dec 25 22:20:20 EST 2020
WIP: cleanup
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -35,7 +35,7 @@
#include <immintrin.h>
#ifdef __AVX2__
-static __m256 exp8_approx(__m256 X)
+static inline __m256 exp8_approx(__m256 X)
{
const __m256 K0 = _mm256_set1_ps(0.99992522f);
const __m256 K1 = _mm256_set1_ps(0.69583354f);
@@ -60,7 +60,7 @@
#else
#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
-static __m128 exp4_approx(__m128 X)
+static inline __m128 exp4_approx(__m128 X)
{
const __m128 K0 = _mm_set1_ps(0.99992522f);
const __m128 K1 = _mm_set1_ps(0.69583354f);
@@ -82,7 +82,7 @@
Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
return Y;
}
-static __m256 exp8_approx(__m256 X)
+static inline __m256 exp8_approx(__m256 X)
{
__m256 Y;
__m128 Xhi, Xlo, Yhi, Ylo;
@@ -96,7 +96,7 @@
}
#endif
-static float celt_exp(float x)
+static inline float celt_exp(float x)
{
float out[8];
__m256 X, Y;
@@ -106,7 +106,7 @@
return out[0];
}
-static void softmax(float *y, const float *x, int N)
+static inline void softmax(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-7;i+=8)
@@ -120,7 +120,7 @@
y[i] = celt_exp(x[i]);
}
-static void vec_tanh(float *y, const float *x, int N)
+static inline void vec_tanh(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-7;i+=8)
@@ -142,7 +142,7 @@
}
}
-static void vec_sigmoid(float *y, const float *x, int N)
+static inline void vec_sigmoid(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-7;i+=8)
@@ -163,7 +163,7 @@
}
}
-static void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+static inline void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
int i, j;
for (i=0;i<rows;i+=16)
@@ -189,7 +189,7 @@
_mm256_storeu_ps (&y[8], vy8);
}
}
-static void sparse_sgemv_accum16(float *out, const float *weights, int rows, const int *idx, const float *x)
+static inline void sparse_sgemv_accum16(float *out, const float *weights, int rows, const int *idx, const float *x)
{
int i, j;
for (i=0;i<rows;i+=16)
@@ -222,7 +222,17 @@
}
#ifdef DOT_PROD
-static void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, const int *idx, const float *x)
+
+#define USE_SU_BIAS /* NOTE(review): not referenced in the visible hunks; presumably enables a bias correction for the signed-weight/unsigned-input product -- confirm against the users of this macro. */
+
+#define MAX_INPUTS (2048) /* Number of entries in the local quantized-input array of the scalar sparse_sgemv_accum8x4(). */
+
+
+#define SCALE (128.f*127.f) /* Factor applied to out[] before the integer-input accumulation in sparse_sgemv_accum8x4(). */
+#define SCALE_1 (1.f/128.f/127.f) /* Reciprocal of SCALE; applied to out[] after the accumulation to undo the pre-scaling. */
+
+#if 0
+static inline void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, const int *idx, const float *x)
{
int i, j;
for (i=0;i<rows;i+=8)
@@ -247,11 +257,47 @@
_mm256_storeu_ps (&y[0], vy0);
}
}
-
#else
-static void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, int ignore, const int *idx, const float *x)
+static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x) /* Scalar block-sparse matrix-vector accumulate into out[]: blocks of 8 rows x 4 inputs, with the float inputs _x first quantized to integers. */
{
int i, j;
+ unsigned x[MAX_INPUTS]; /* Quantized copy of _x[0..cols-1]. NOTE(review): nothing here checks cols <= MAX_INPUTS -- confirm callers guarantee it. */
+ for (i=0;i<rows;i++) out[i] *= SCALE; /* Pre-scale the accumulators; undone by SCALE_1 after the loops. */
+ for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]); /* Round 127*_x[i] to nearest and offset by 127; presumably _x[i] lies in [-1,1], giving 0..254 -- TODO confirm. */
+ for (i=0;i<rows;i+=8) /* One pass per group of 8 output rows; y[0..7] is written below, so rows is presumably a multiple of 8 -- confirm. */
+ {
+ int colblocks;
+ colblocks = *idx++; /* idx first supplies the number of column blocks stored for this row group... */
+ for (j=0;j<colblocks;j++)
+ {
+ int pos;
+ float * restrict y;
+ int xj0, xj1, xj2, xj3;
+ pos = 4 * (*idx++); /* ...then one block index per block; each block covers 4 consecutive inputs starting at 4*index. */
+ xj0 = x[pos+0];
+ xj1 = x[pos+1];
+ xj2 = x[pos+2];
+ xj3 = x[pos+3];
+ y = &out[i];
+ y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3); /* Row k of the block uses weights w[4k..4k+3]. NOTE(review): the +127 input offset adds 127*(sum of those 4 weights) to each term; presumably compensated elsewhere (cf. USE_SU_BIAS) -- confirm. */
+ y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+ y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+ y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+ y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+ y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+ y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+ y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+ w += 32; /* Advance past the 8x4 = 32 weights consumed by this block. */
+ }
+ }
+ for (i=0;i<rows;i++) out[i] *= SCALE_1; /* Undo the initial multiplication by SCALE. */
+}
+#endif
+
+#else /*DOT_PROD*/
+static inline void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, int ignore, const int *idx, const float *x)
+{
+ int i, j;
(void)ignore;
for (i=0;i<rows;i+=8)
{
@@ -288,6 +334,6 @@
_mm256_storeu_ps (&y[0], vy0);
}
}
-#endif
+#endif /*DOT_PROD*/
#endif /*VEC_AVX_H*/
--
⑨