ref: b214e684c1d234c7952656e174a512a65d1e6f09
parent: 8c3fe6f31d8c48936f061b5134ecec18d3c9f71e
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Fri Jan 1 20:09:00 EST 2021
Neon WIP: Compiles but very slow
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -37,7 +37,7 @@
#ifdef __AVX__
#include "vec_avx.h"
-#elif __ARM_NEON__
+#elif defined(__ARM_NEON__) || defined(__ARM_NEON)
#include "vec_neon.h"
#else
--- a/dnn/vec_neon.h
+++ b/dnn/vec_neon.h
@@ -29,8 +29,13 @@
/* NEON support for ARM machines */
#include <arm_neon.h>
+
+/* NOTE(review): DOT_PROD is not tested anywhere in this hunk; presumably the
+ * shared vec.h / nnet code uses it to select the int8 dot-product kernels
+ * (sgemv_accum8x4 / sparse_sgemv_accum8x4 below) -- confirm at the use site. */
+#define DOT_PROD
+/* Element type of the quantized (8-bit) weight arrays consumed by the int8 kernels. */
+typedef signed char qweight;
+
+
#ifndef LPCNET_TEST
-static OPUS_INLINE float32x4_t exp4_approx(float32x4_t x) {
+static inline OPUS_INLINE float32x4_t exp4_approx(float32x4_t x) {
int32x4_t i;
float32x4_t xf;
@@ -57,7 +62,7 @@
return Y;
}
-static OPUS_INLINE float celt_exp(float x)
+static inline float celt_exp(float x)
{
float out[4];
float32x4_t X, Y;
@@ -67,7 +72,7 @@
return out[0];
}
-static void softmax(float *y, const float *x, int N)
+static inline void softmax(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-3;i+=4)
@@ -81,7 +86,7 @@
y[i] = celt_exp(x[i]);
}
-static void vec_tanh(float *y, const float *x, int N)
+static inline void vec_tanh(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-3;i+=4)
@@ -103,7 +108,7 @@
}
}
-static void vec_sigmoid(float *y, const float *x, int N)
+static inline void vec_sigmoid(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-3;i+=4)
@@ -124,7 +129,7 @@
}
#endif
-static void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+static inline void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
int i, j;
for (i=0;i<rows;i+=16)
@@ -168,7 +173,7 @@
}
}
-static void sparse_sgemv_accum16(float *out, const float *w, int rows, const int *idx, const float *x)
+static inline void sparse_sgemv_accum16(float *out, const float *w, int rows, const int *idx, const float *x)
{
int i, j;
for (i=0;i<rows;i+=16)
@@ -206,4 +211,76 @@
vst1q_f32(&y[12], y12_15);
}
+}
+
+/* Fixed-point scaling shared by the int8 kernels below. Inputs are quantized
+ * as floor(.5 + 127*x); the extra factor of 128 presumably matches weights
+ * stored as round(128*w) -- TODO confirm against the weight export script.
+ * SCALE moves `out` into that integer-product domain before accumulation and
+ * SCALE_1 moves it back afterwards. */
+#define SCALE (128.f*127.f)
+#define SCALE_1 (1.f/128.f/127.f)
+
+/* Length of the on-stack buffer (signed char x[MAX_INPUTS]) that holds the
+ * quantized inputs inside the int8 kernels below. */
+#define MAX_INPUTS 2048
+
+/* Dense int8 matrix-vector product, accumulated in place:
+ *    out[r] += sum_c W[r][c] * _x[c]      (0 <= r < rows, 0 <= c < cols)
+ *
+ * w holds the matrix as consecutive 8x4 blocks of 32 qweights: within a
+ * block, row k of the 8-row group uses w[4*k .. 4*k+3] for columns j .. j+3.
+ * All column blocks of one 8-row group are stored before the next group, so
+ * the matrix is fully packed and col_stride is ignored.
+ * Requires rows % 8 == 0 and cols % 4 == 0 (not checked).
+ *
+ * Inputs are quantized to signed 8 bits as floor(.5 + 127*x), saturated to
+ * [-128, 127] so out-of-range activations cannot wrap around. out is scaled by
+ * SCALE before and by SCALE_1 after the integer accumulation. Inputs wider
+ * than MAX_INPUTS are processed in MAX_INPUTS-column windows so the on-stack
+ * quantization buffer can never overflow; the per-output accumulation order
+ * (ascending column) is the same as a single pass. */
+static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
+{
+   int i, j, c0;
+   signed char x[MAX_INPUTS];
+   (void)col_stride;
+   for (i=0;i<rows;i++) out[i] *= SCALE;
+   for (c0=0;c0<cols;c0+=MAX_INPUTS)
+   {
+      int ncols = cols-c0 < MAX_INPUTS ? cols-c0 : MAX_INPUTS;
+      for (j=0;j<ncols;j++)
+      {
+         double q = floor(.5+127*_x[c0+j]);
+         /* Saturate: converting a value outside [-128,127] to signed char is
+            implementation-defined (wraps on common compilers). */
+         x[j] = (int)(q > 127 ? 127 : (q < -128 ? -128 : q));
+      }
+      for (i=0;i<rows;i+=8)
+      {
+         float * restrict y = &out[i];
+         /* Each 8-row group owns cols/4 blocks; skip those left of this window. */
+         const qweight *wb = w + 32*((i/8)*(cols/4) + c0/4);
+         for (j=0;j<ncols;j+=4)
+         {
+            int xj0, xj1, xj2, xj3;
+            xj0 = x[j+0];
+            xj1 = x[j+1];
+            xj2 = x[j+2];
+            xj3 = x[j+3];
+            /* Each 4-term sum is at most 4*128*128, so int arithmetic is exact. */
+            y[0] += (wb[0]*xj0+wb[1]*xj1+wb[2]*xj2+wb[3]*xj3);
+            y[1] += (wb[4]*xj0+wb[5]*xj1+wb[6]*xj2+wb[7]*xj3);
+            y[2] += (wb[8]*xj0+wb[9]*xj1+wb[10]*xj2+wb[11]*xj3);
+            y[3] += (wb[12]*xj0+wb[13]*xj1+wb[14]*xj2+wb[15]*xj3);
+            y[4] += (wb[16]*xj0+wb[17]*xj1+wb[18]*xj2+wb[19]*xj3);
+            y[5] += (wb[20]*xj0+wb[21]*xj1+wb[22]*xj2+wb[23]*xj3);
+            y[6] += (wb[24]*xj0+wb[25]*xj1+wb[26]*xj2+wb[27]*xj3);
+            y[7] += (wb[28]*xj0+wb[29]*xj1+wb[30]*xj2+wb[31]*xj3);
+            wb += 32;
+         }
+      }
+   }
+   for (i=0;i<rows;i++) out[i] *= SCALE_1;
+}
+
+/* Block-sparse int8 matrix-vector product, accumulated in place:
+ *    out[r] += sum over stored blocks of W[r][c] * _x[c]
+ *
+ * For each group of 8 rows, idx holds the number of stored 8x4 blocks followed
+ * by one column-block index per block (first column = 4*index). w holds the
+ * 32 qweights of those blocks in the same order; within a block, row k of the
+ * group uses w[4*k .. 4*k+3] for columns pos .. pos+3.
+ * Requires rows % 8 == 0 and cols % 4 == 0 (not checked).
+ *
+ * Inputs are quantized to signed 8 bits as floor(.5 + 127*x), saturated to
+ * [-128, 127]; out is scaled by SCALE before and by SCALE_1 after the integer
+ * accumulation. When cols exceeds MAX_INPUTS the columns are handled in
+ * MAX_INPUTS-wide windows (re-walking idx and w for each window) so the
+ * on-stack buffer cannot overflow. Blocks whose column lies outside [0, cols)
+ * are skipped instead of reading past the quantized inputs. */
+static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
+{
+   int i, j, c0;
+   signed char x[MAX_INPUTS];
+   for (i=0;i<rows;i++) out[i] *= SCALE;
+   for (c0=0;c0<cols;c0+=MAX_INPUTS)
+   {
+      int ncols = cols-c0 < MAX_INPUTS ? cols-c0 : MAX_INPUTS;
+      const int *ib = idx;
+      const qweight *wb = w;
+      for (j=0;j<ncols;j++)
+      {
+         double q = floor(.5+127*_x[c0+j]);
+         /* Saturate: converting a value outside [-128,127] to signed char is
+            implementation-defined (wraps on common compilers). */
+         x[j] = (int)(q > 127 ? 127 : (q < -128 ? -128 : q));
+      }
+      for (i=0;i<rows;i+=8)
+      {
+         float * restrict y = &out[i];
+         int colblocks = *ib++;
+         for (j=0;j<colblocks;j++)
+         {
+            /* Column of this block relative to the current window. */
+            int pos = 4*(*ib++) - c0;
+            if (pos >= 0 && pos < ncols)
+            {
+               int xj0, xj1, xj2, xj3;
+               xj0 = x[pos+0];
+               xj1 = x[pos+1];
+               xj2 = x[pos+2];
+               xj3 = x[pos+3];
+               y[0] += (wb[0]*xj0+wb[1]*xj1+wb[2]*xj2+wb[3]*xj3);
+               y[1] += (wb[4]*xj0+wb[5]*xj1+wb[6]*xj2+wb[7]*xj3);
+               y[2] += (wb[8]*xj0+wb[9]*xj1+wb[10]*xj2+wb[11]*xj3);
+               y[3] += (wb[12]*xj0+wb[13]*xj1+wb[14]*xj2+wb[15]*xj3);
+               y[4] += (wb[16]*xj0+wb[17]*xj1+wb[18]*xj2+wb[19]*xj3);
+               y[5] += (wb[20]*xj0+wb[21]*xj1+wb[22]*xj2+wb[23]*xj3);
+               y[6] += (wb[24]*xj0+wb[25]*xj1+wb[26]*xj2+wb[27]*xj3);
+               y[7] += (wb[28]*xj0+wb[29]*xj1+wb[30]*xj2+wb[31]*xj3);
+            }
+            /* Weights are consumed for every stored block, used or not. */
+            wb += 32;
+         }
+      }
+   }
+   for (i=0;i<rows;i++) out[i] *= SCALE_1;
}
--
⑨