ref: 116bcb38fb7bddb75a52b2da52ef536aadd4f3e1
parent: 3e223e6015d2740fcfd4c50d4fdc2d20349fa565
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Sat Jul 10 10:08:01 EDT 2021
Adding SSE 4.1 support for older platforms

AVX without AVX2 should now work again too.
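
The emulation represents one 256-bit vector as a pair of 128-bit SSE vectors and
applies each wrapped _mm256_* operation to both halves. A minimal standalone
sketch of that pattern (illustrative only, not part of the patch; the m256_emu /
emu_*_ps names are made up here), compilable with just -msse4.1:

    /* Emulate an 8-float AVX add with two SSE 4-float adds. */
    #include <stdio.h>
    #include <smmintrin.h>   /* SSE4.1 and earlier */

    typedef struct { __m128 lo; __m128 hi; } m256_emu;

    static m256_emu emu_loadu_ps(const float *p) {
        m256_emu v;
        v.lo = _mm_loadu_ps(p);      /* elements 0..3 */
        v.hi = _mm_loadu_ps(p + 4);  /* elements 4..7 */
        return v;
    }

    static m256_emu emu_add_ps(m256_emu a, m256_emu b) {
        m256_emu r;
        r.lo = _mm_add_ps(a.lo, b.lo);
        r.hi = _mm_add_ps(a.hi, b.hi);
        return r;
    }

    static void emu_storeu_ps(float *p, m256_emu v) {
        _mm_storeu_ps(p, v.lo);
        _mm_storeu_ps(p + 4, v.hi);
    }

    int main(void) {
        float a[8] = {0, 1, 2, 3, 4, 5, 6, 7};
        float b[8] = {8, 7, 6, 5, 4, 3, 2, 1};
        float c[8];
        int i;
        emu_storeu_ps(c, emu_add_ps(emu_loadu_ps(a), emu_loadu_ps(b)));
        for (i = 0; i < 8; i++) printf("%g ", c[i]);  /* prints 8 eight times */
        printf("\n");
        return 0;
    }

vec_avx.h does the same thing below, but also #defines the real intrinsic names
to the wrappers so the existing AVX/AVX2 code paths compile unchanged.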
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -35,7 +35,7 @@
#include "arch.h"
-#ifdef __AVX__
+#if defined(__AVX__) || defined(__SSE4_1__)
#include "vec_avx.h"
#elif defined(__ARM_NEON__) || defined(__ARM_NEON)
#include "vec_neon.h"
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -39,7 +39,234 @@
#define USE_SU_BIAS
#endif
-#ifndef __FMA__
+/* If we don't have AVX available, emulate what we need with SSE up to 4.1. */
+#ifndef __AVX__
+
+typedef struct {
+ __m128 lo;
+ __m128 hi;
+} mm256_emu;
+#define __m256 mm256_emu
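+/* From here on, __m256 names the two-halves emulation struct above, so the
+   AVX-flavoured code further down compiles against it unchanged. */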
+
+static inline mm256_emu mm256_loadu_ps(const float *src) {
+ mm256_emu ret;
+ ret.lo = _mm_loadu_ps(&src[0]);
+ ret.hi = _mm_loadu_ps(&src[4]);
+ return ret;
+}
+#define _mm256_loadu_ps(src) mm256_loadu_ps(src)
+
+
+static inline void mm256_storeu_ps(float *dst, mm256_emu src) {
+ _mm_storeu_ps(dst, src.lo);
+ _mm_storeu_ps(&dst[4], src.hi);
+}
+#define _mm256_storeu_ps(dst, src) mm256_storeu_ps(dst, src)
+
+
+static inline mm256_emu mm256_setzero_ps() {
+ mm256_emu ret;
+ ret.lo = _mm_setzero_ps();
+ ret.hi = ret.lo;
+ return ret;
+}
+#define _mm256_setzero_ps mm256_setzero_ps
+
+static inline mm256_emu mm256_broadcast_ss(const float *x) {
+ mm256_emu ret;
+ ret.lo = _mm_set1_ps(*x);
+ ret.hi = ret.lo;
+ return ret;
+}
+#define _mm256_broadcast_ss(x) mm256_broadcast_ss(x)
+
+static inline mm256_emu mm256_set1_ps(float x) {
+ mm256_emu ret;
+ ret.lo = _mm_set1_ps(x);
+ ret.hi = ret.lo;
+ return ret;
+}
+#define _mm256_set1_ps(x) mm256_set1_ps(x)
+
+
+
+static inline mm256_emu mm256_mul_ps(mm256_emu a, mm256_emu b) {
+ mm256_emu ret;
+ ret.lo = _mm_mul_ps(a.lo, b.lo);
+ ret.hi = _mm_mul_ps(a.hi, b.hi);
+ return ret;
+}
+#define _mm256_mul_ps(a,b) mm256_mul_ps(a,b)
+
+static inline mm256_emu mm256_add_ps(mm256_emu a, mm256_emu b) {
+ mm256_emu ret;
+ ret.lo = _mm_add_ps(a.lo, b.lo);
+ ret.hi = _mm_add_ps(a.hi, b.hi);
+ return ret;
+}
+#define _mm256_add_ps(a,b) mm256_add_ps(a,b)
+
+
+static inline mm256_emu mm256_max_ps(mm256_emu a, mm256_emu b) {
+ mm256_emu ret;
+ ret.lo = _mm_max_ps(a.lo, b.lo);
+ ret.hi = _mm_max_ps(a.hi, b.hi);
+ return ret;
+}
+#define _mm256_max_ps(a,b) mm256_max_ps(a,b)
+
+static inline mm256_emu mm256_min_ps(mm256_emu a, mm256_emu b) {
+ mm256_emu ret;
+ ret.lo = _mm_min_ps(a.lo, b.lo);
+ ret.hi = _mm_min_ps(a.hi, b.hi);
+ return ret;
+}
+#define _mm256_min_ps(a,b) mm256_min_ps(a,b)
+
+static inline mm256_emu mm256_rcp_ps(mm256_emu a) {
+ mm256_emu ret;
+ ret.lo = _mm_rcp_ps(a.lo);
+ ret.hi = _mm_rcp_ps(a.hi);
+ return ret;
+}
+#define _mm256_rcp_ps(a) mm256_rcp_ps(a)
+
+
+static inline __m128 mm256_extractf128_ps(mm256_emu x, int i) {
+ return (i==0) ? x.lo : x.hi;
+}
+#define _mm256_extractf128_ps(x,i) mm256_extractf128_ps(x,i)
+
+static inline mm256_emu mm256_insertf128_ps(mm256_emu dst, __m128 src, int i) {
+ if (i==0) dst.lo = src;
+ else dst.hi = src;
+ return dst;
+}
+#define _mm256_insertf128_ps(dst,src,i) mm256_insertf128_ps(dst,src,i)
+
+#endif /* __AVX__ */
+
+
+
+/* If we don't have AVX2 available, emulate what we need with SSE up to 4.1. */
+#ifndef __AVX2__
+
+typedef struct {
+ __m128i lo;
+ __m128i hi;
+} mm256i_emu;
+typedef __m256i real_m256i;
+#define __m256i mm256i_emu
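+/* real_m256i keeps a name for the compiler's native __m256i; the AVX-only
+   conversion wrappers below still need the real type. */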
+
+
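+/* The two halves of an emulated __m256i are contiguous in memory, so the
+   second half lives 16 bytes past the start of the struct. */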
+static inline mm256i_emu mm256_loadu_si256(const mm256i_emu *src) {
+ mm256i_emu ret;
+ ret.lo = _mm_loadu_si128((const __m128i*)src);
+ ret.hi = _mm_loadu_si128((const __m128i*)(&((const char *)src)[16]));
+ return ret;
+}
+#define _mm256_loadu_si256(src) mm256_loadu_si256(src)
+
+
+static inline void mm256_storeu_si256(mm256i_emu *dst, mm256i_emu src) {
+ _mm_storeu_si128((__m128i*)dst, src.lo);
+ _mm_storeu_si128((__m128i*)(&((char *)dst)[16]), src.hi);
+}
+#define _mm256_storeu_si256(dst, src) mm256_storeu_si256(dst, src)
+
+
+static inline mm256i_emu mm256_set1_epi32(int x) {
+ mm256i_emu ret;
+ ret.lo = _mm_set1_epi32(x);
+ ret.hi = ret.lo;
+ return ret;
+}
+#define _mm256_set1_epi32(x) mm256_set1_epi32(x)
+
+static inline mm256i_emu mm256_set1_epi16(int x) {
+ mm256i_emu ret;
+ ret.lo = _mm_set1_epi16(x);
+ ret.hi = ret.lo;
+ return ret;
+}
+#define _mm256_set1_epi16(x) mm256_set1_epi16(x)
+
+
+static inline mm256i_emu mm256_add_epi32(mm256i_emu a, mm256i_emu b) {
+ mm256i_emu ret;
+ ret.lo = _mm_add_epi32(a.lo, b.lo);
+ ret.hi = _mm_add_epi32(a.hi, b.hi);
+ return ret;
+}
+#define _mm256_add_epi32(a,b) mm256_add_epi32(a,b)
+
+static inline mm256i_emu mm256_madd_epi16(mm256i_emu a, mm256i_emu b) {
+ mm256i_emu ret;
+ ret.lo = _mm_madd_epi16(a.lo, b.lo);
+ ret.hi = _mm_madd_epi16(a.hi, b.hi);
+ return ret;
+}
+#define _mm256_madd_epi16(a,b) mm256_madd_epi16(a,b)
+
+static inline mm256i_emu mm256_maddubs_epi16(mm256i_emu a, mm256i_emu b) {
+ mm256i_emu ret;
+ ret.lo = _mm_maddubs_epi16(a.lo, b.lo);
+ ret.hi = _mm_maddubs_epi16(a.hi, b.hi);
+ return ret;
+}
+#define _mm256_maddubs_epi16(a,b) mm256_maddubs_epi16(a,b)
+
+
+
+/* Emulating the conversion functions is tricky because they use __m256i but are provided by AVX (not AVX2).
+   So we need a special case for when only AVX is available. */
+#ifdef __AVX__
+
+typedef union {
+ mm256i_emu fake;
+ real_m256i real;
+} mm256_union;
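+/* Same 32 bytes, two views: the emulated pair of __m128i and the native __m256i. */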
+
+static inline __m256 mm256_cvtepi32_ps(mm256i_emu a) {
+ mm256_union src;
+ src.fake = a;
+ return _mm256_cvtepi32_ps(src.real);
+}
+#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)
+
+static inline mm256i_emu mm256_cvtps_epi32(__m256 a) {
+ mm256_union ret;
+ ret.real = _mm256_cvtps_epi32(a);
+ return ret.fake;
+}
+#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)
+
+
+#else
+
+static inline mm256_emu mm256_cvtepi32_ps(mm256i_emu a) {
+ mm256_emu ret;
+ ret.lo = _mm_cvtepi32_ps(a.lo);
+ ret.hi = _mm_cvtepi32_ps(a.hi);
+ return ret;
+}
+#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)
+
+static inline mm256i_emu mm256_cvtps_epi32(mm256_emu a) {
+ mm256i_emu ret;
+ ret.lo = _mm_cvtps_epi32(a.lo);
+ ret.hi = _mm_cvtps_epi32(a.hi);
+ return ret;
+}
+#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)
+
+#endif /* __AVX__ */
+
+
+#endif /* __AVX2__ */
+
+/* In case we don't have FMA, make it a mul and an add. */
+#if !(defined(__FMA__) && defined(__AVX__))
#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
#endif
@@ -67,6 +294,72 @@
return Y;
}
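+/* Quantize floats (nominally in [-1,1]) to unsigned bytes as roughly
+   127 + 127*x, eight values per iteration, with saturation. */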
+static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
+ int i;
+ __m256 const127 = _mm256_set1_ps(127.f);
+ for (i=0;i<len;i+=8) {
+ __m256 xf;
+ __m256i xi;
+ xf = _mm256_loadu_ps(&_x[i]);
+ //xf = _mm256_mul_ps(xf, const127);
+ //xf = _mm256_add_ps(xf, const127);
+ xf = _mm256_fmadd_ps(xf, const127, const127);
+ xi = _mm256_cvtps_epi32(xf);
+ xi = _mm256_packus_epi32(xi, _mm256_setzero_si256());
+ xi = _mm256_permute4x64_epi64(xi, 0xD8);
+ xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
+ xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
+ //xi = _mm256_permute4x64_epi64(xi, 0x);
+ _mm256_storeu_si256 ((__m256i *)&x[i], xi);
+ }
+}
+
+#else
+static inline __m128 exp4_approx(__m128 X)
+{
+ const __m128 K0 = _mm_set1_ps(0.99992522f);
+ const __m128 K1 = _mm_set1_ps(0.69583354f);
+ const __m128 K2 = _mm_set1_ps(0.22606716f);
+ const __m128 K3 = _mm_set1_ps(0.078024523f);
+ const __m128 log2_E = _mm_set1_ps(1.44269504);
+ const __m128 max_in = _mm_set1_ps(50.f);
+ const __m128 min_in = _mm_set1_ps(-50.f);
+ const __m128i mask = _mm_set1_epi32(0x7fffffff);
+ __m128 XF, Y;
+ __m128i I;
+ X = _mm_mul_ps(X, log2_E);
+ X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
+ XF = _mm_floor_ps(X);
+ I = _mm_cvtps_epi32(XF);
+ X = _mm_sub_ps(X, XF);
+ Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
+ I = _mm_slli_epi32(I, 23);
+ Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
+ return Y;
+}
+static inline __m256 exp8_approx(__m256 X)
+{
+ __m256 Y;
+ __m128 Xhi, Xlo, Yhi, Ylo;
+ Xhi = _mm256_extractf128_ps(X, 1);
+ Xlo = _mm256_extractf128_ps(X, 0);
+ Yhi = exp4_approx(Xhi);
+ Ylo = exp4_approx(Xlo);
+ Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
+ Y = _mm256_insertf128_ps(Y, Ylo, 0);
+ return Y;
+}
+
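+/* Scalar fallback: same 127 + round(127*x) quantization, one value at a time. */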
+static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
+ int i;
+ for (i=0;i<len;i++) x[i] = 127+floor(.5+127*_x[i]);
+}
+
+#endif
+
+
+#ifdef __AVX__
+
/* Approximating tanh() using a Padé-like rational function:
tanh(x) ~= x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
subject to the +/- 1 bounds.
@@ -125,42 +418,28 @@
return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}
-#else
-static inline __m128 exp4_approx(__m128 X)
+static inline float tanh_approx(float x)
{
- const __m128 K0 = _mm_set1_ps(0.99992522f);
- const __m128 K1 = _mm_set1_ps(0.69583354f);
- const __m128 K2 = _mm_set1_ps(0.22606716f);
- const __m128 K3 = _mm_set1_ps(0.078024523f);
- const __m128 log2_E = _mm_set1_ps(1.44269504);
- const __m128 max_in = _mm_set1_ps(50.f);
- const __m128 min_in = _mm_set1_ps(-50.f);
- const __m128i mask = _mm_set1_epi32(0x7fffffff);
- __m128 XF, Y;
- __m128i I;
- X = _mm_mul_ps(X, log2_E);
- X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
- XF = _mm_floor_ps(X);
- I = _mm_cvtps_epi32(XF);
- X = _mm_sub_ps(X, XF);
- Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
- I = _mm_slli_epi32(I, 23);
- Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
- return Y;
+ float out[8];
+ __m256 X, Y;
+ X = _mm256_set1_ps(x);
+ Y = tanh8_approx(X);
+ _mm256_storeu_ps(out, Y);
+ return out[0];
}
-static inline __m256 exp8_approx(__m256 X)
+
+static inline float sigmoid_approx(float x)
{
- __m256 Y;
- __m128 Xhi, Xlo, Yhi, Ylo;
- Xhi = _mm256_extractf128_ps(X, 1);
- Xlo = _mm256_extractf128_ps(X, 0);
- Yhi = exp4_approx(Xhi);
- Ylo = exp4_approx(Xlo);
- Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
- Y = _mm256_insertf128_ps(Y, Ylo, 0);
- return Y;
+ float out[8];
+ __m256 X, Y;
+ X = _mm256_set1_ps(x);
+ Y = sigmoid8_approx(X);
+ _mm256_storeu_ps(out, Y);
+ return out[0];
}
+#else
+
static inline __m128 tanh4_approx(__m128 X)
{
const __m128 N0 = _mm_set1_ps(952.52801514f);
@@ -202,34 +481,34 @@
return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}
-#endif
-
-static inline float celt_exp(float x)
+static inline float tanh_approx(float x)
{
- float out[8];
- __m256 X, Y;
- X = _mm256_set1_ps(x);
- Y = exp8_approx(X);
- _mm256_storeu_ps(out, Y);
+ float out[4];
+ __m128 X, Y;
+ X = _mm_set1_ps(x);
+ Y = tanh4_approx(X);
+ _mm_storeu_ps(out, Y);
return out[0];
}
-static inline float tanh_approx(float x)
+static inline float sigmoid_approx(float x)
{
- float out[8];
- __m256 X, Y;
- X = _mm256_set1_ps(x);
- Y = tanh8_approx(X);
- _mm256_storeu_ps(out, Y);
+ float out[4];
+ __m128 X, Y;
+ X = _mm_set1_ps(x);
+ Y = sigmoid4_approx(X);
+ _mm_storeu_ps(out, Y);
return out[0];
}
-static inline float sigmoid_approx(float x)
+#endif
+
+static inline float celt_exp(float x)
{
float out[8];
__m256 X, Y;
X = _mm256_set1_ps(x);
- Y = sigmoid8_approx(X);
+ Y = exp8_approx(X);
_mm256_storeu_ps(out, Y);
return out[0];
}
@@ -248,7 +527,7 @@
y[i] = celt_exp(x[i]);
}
-#ifdef __AVX2__
+#ifdef __AVX__
static inline void vec_tanh(float *y, const float *x, int N)
{
int i;
@@ -395,22 +674,7 @@
(void)col_stride;
ones = _mm256_set1_epi16(1);
//for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
- __m256 const127 = _mm256_set1_ps(127.f);
- for (i=0;i<cols;i+=8) {
- __m256 xf;
- __m256i xi;
- xf = _mm256_loadu_ps(&_x[i]);
- //xf = _mm256_mul_ps(xf, const127);
- //xf = _mm256_add_ps(xf, const127);
- xf = _mm256_fmadd_ps(xf, const127, const127);
- xi = _mm256_cvtps_epi32(xf);
- xi = _mm256_packus_epi32(xi, _mm256_setzero_si256());
- xi = _mm256_permute4x64_epi64(xi, 0xD8);
- xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
- xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
- //xi = _mm256_permute4x64_epi64(xi, 0x);
- _mm256_storeu_si256 ((__m256i *)&x[i], xi);
- }
+ vector_ps_to_epi8(x, _x, cols);
for (i=0;i<rows;i+=8)
{
__m256i vy0;
@@ -509,22 +773,7 @@
unsigned char x[MAX_INPUTS];
ones = _mm256_set1_epi16(1);
//for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
- __m256 const127 = _mm256_set1_ps(127.f);
- for (i=0;i<cols;i+=8) {
- __m256 xf;
- __m256i xi;
- xf = _mm256_loadu_ps(&_x[i]);
- //xf = _mm256_mul_ps(xf, const127);
- //xf = _mm256_add_ps(xf, const127);
- xf = _mm256_fmadd_ps(xf, const127, const127);
- xi = _mm256_cvtps_epi32(xf);
- xi = _mm256_packus_epi32(xi, _mm256_setzero_si256());
- xi = _mm256_permute4x64_epi64(xi, 0xD8);
- xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
- xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
- //xi = _mm256_permute4x64_epi64(xi, 0x);
- _mm256_storeu_si256 ((__m256i *)&x[i], xi);
- }
+ vector_ps_to_epi8(x, _x, cols);
for (i=0;i<rows;i+=8)
{
int colblocks;
--