shithub: opus

ref: 116bcb38fb7bddb75a52b2da52ef536aadd4f3e1
parent: 3e223e6015d2740fcfd4c50d4fdc2d20349fa565
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Sat Jul 10 10:08:01 EDT 2021

Adding SSE 4.1 for older platforms

AVX without AVX2 should now work again too.
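
The change emulates the 256-bit AVX intrinsics used by dnn/vec_avx.h with pairs
of 128-bit SSE operations, so the same code now compiles with only SSE 4.1
available. A minimal, standalone sketch of that two-halves pattern follows; the
names (m256_emu, emu_*) and the build line are illustrative only and are not
part of the patch. Build with e.g. gcc -O2 -msse4.1 sketch.c.

/* Sketch: represent a 256-bit float vector as two 128-bit SSE halves and
 * emulate each AVX intrinsic by applying the SSE equivalent to both halves. */
#include <smmintrin.h>  /* SSE up to 4.1 */
#include <stdio.h>

typedef struct { __m128 lo; __m128 hi; } m256_emu;

static inline m256_emu emu_loadu_ps(const float *p) {
  m256_emu r; r.lo = _mm_loadu_ps(p); r.hi = _mm_loadu_ps(p + 4); return r;
}

static inline m256_emu emu_add_ps(m256_emu a, m256_emu b) {
  m256_emu r; r.lo = _mm_add_ps(a.lo, b.lo); r.hi = _mm_add_ps(a.hi, b.hi); return r;
}

static inline void emu_storeu_ps(float *p, m256_emu v) {
  _mm_storeu_ps(p, v.lo); _mm_storeu_ps(p + 4, v.hi);
}

int main(void) {
  float a[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  float b[8] = {8, 7, 6, 5, 4, 3, 2, 1};
  float c[8];
  int i;
  emu_storeu_ps(c, emu_add_ps(emu_loadu_ps(a), emu_loadu_ps(b)));
  for (i = 0; i < 8; i++) printf("%g ", c[i]);  /* prints "8" eight times */
  printf("\n");
  return 0;
}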

--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -35,7 +35,7 @@
 #include "arch.h"
 
 
-#ifdef __AVX__
+#if defined(__AVX__) || defined(__SSE4_1__)
 #include "vec_avx.h"
 #elif defined(__ARM_NEON__) || defined(__ARM_NEON)
 #include "vec_neon.h"
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -39,7 +39,234 @@
 #define USE_SU_BIAS
 #endif
 
-#ifndef __FMA__
+/* If we don't have AVX available, emulate what we need with SSE up to 4.1. */
+#ifndef __AVX__
+
+typedef struct {
+  __m128 lo;
+  __m128 hi;
+} mm256_emu;
+#define __m256 mm256_emu
+
+static inline mm256_emu mm256_loadu_ps(const float *src) {
+  mm256_emu ret;
+  ret.lo = _mm_loadu_ps(&src[0]);
+  ret.hi = _mm_loadu_ps(&src[4]);
+  return ret;
+}
+#define _mm256_loadu_ps(src) mm256_loadu_ps(src)
+
+
+static inline void mm256_storeu_ps(float *dst, mm256_emu src) {
+  _mm_storeu_ps(dst, src.lo);
+  _mm_storeu_ps(&dst[4], src.hi);
+}
+#define _mm256_storeu_ps(dst, src) mm256_storeu_ps(dst, src)
+
+
+static inline mm256_emu mm256_setzero_ps() {
+  mm256_emu ret;
+  ret.lo = _mm_setzero_ps();
+  ret.hi = ret.lo;
+  return ret;
+}
+#define _mm256_setzero_ps mm256_setzero_ps
+
+static inline mm256_emu mm256_broadcast_ss(const float *x) {
+  mm256_emu ret;
+  ret.lo = _mm_set1_ps(*x);
+  ret.hi = ret.lo;
+  return ret;
+}
+#define _mm256_broadcast_ss(x) mm256_broadcast_ss(x)
+
+static inline mm256_emu mm256_set1_ps(float x) {
+  mm256_emu ret;
+  ret.lo = _mm_set1_ps(x);
+  ret.hi = ret.lo;
+  return ret;
+}
+#define _mm256_set1_ps(x) mm256_set1_ps(x)
+
+
+
+static inline mm256_emu mm256_mul_ps(mm256_emu a, mm256_emu b) {
+  mm256_emu ret;
+  ret.lo = _mm_mul_ps(a.lo, b.lo);
+  ret.hi = _mm_mul_ps(a.hi, b.hi);
+  return ret;
+}
+#define _mm256_mul_ps(a,b) mm256_mul_ps(a,b)
+
+static inline mm256_emu mm256_add_ps(mm256_emu a, mm256_emu b) {
+  mm256_emu ret;
+  ret.lo = _mm_add_ps(a.lo, b.lo);
+  ret.hi = _mm_add_ps(a.hi, b.hi);
+  return ret;
+}
+#define _mm256_add_ps(a,b) mm256_add_ps(a,b)
+
+
+static inline mm256_emu mm256_max_ps(mm256_emu a, mm256_emu b) {
+  mm256_emu ret;
+  ret.lo = _mm_max_ps(a.lo, b.lo);
+  ret.hi = _mm_max_ps(a.hi, b.hi);
+  return ret;
+}
+#define _mm256_max_ps(a,b) mm256_max_ps(a,b)
+
+static inline mm256_emu mm256_min_ps(mm256_emu a, mm256_emu b) {
+  mm256_emu ret;
+  ret.lo = _mm_min_ps(a.lo, b.lo);
+  ret.hi = _mm_min_ps(a.hi, b.hi);
+  return ret;
+}
+#define _mm256_min_ps(a,b) mm256_min_ps(a,b)
+
+static inline mm256_emu mm256_rcp_ps(mm256_emu a) {
+  mm256_emu ret;
+  ret.lo = _mm_rcp_ps(a.lo);
+  ret.hi = _mm_rcp_ps(a.hi);
+  return ret;
+}
+#define _mm256_rcp_ps(a) mm256_rcp_ps(a)
+
+
+static inline __m128 mm256_extractf128_ps(mm256_emu x, int i) {
+    return (i==0) ? x.lo : x.hi;
+}
+#define _mm256_extractf128_ps(x,i) mm256_extractf128_ps(x,i)
+
+static inline mm256_emu mm256_insertf128_ps(mm256_emu dst, __m128 src, int i) {
+    if (i==0) dst.lo = src;
+    else dst.hi = src;
+    return dst;
+}
+#define _mm256_insertf128_ps(dst,src,i) mm256_insertf128_ps(dst,src,i)
+
+#endif /* __AVX__ */
+
+
+
+/* If we don't have AVX2 available, emulate what we need with SSE up to 4.1. */
+#ifndef __AVX2__
+
+typedef struct {
+  __m128i lo;
+  __m128i hi;
+} mm256i_emu;
+typedef __m256i real_m256i;
+#define __m256i mm256i_emu
+
+
+static inline mm256i_emu mm256_loadu_si256(const mm256i_emu *src) {
+  mm256i_emu ret;
+  ret.lo = _mm_loadu_si128((const __m128i*)src);
+  ret.hi = _mm_loadu_si128((const __m128i*)(&((const char *)src)[16]));
+  return ret;
+}
+#define _mm256_loadu_si256(src) mm256_loadu_si256(src)
+
+
+static inline void mm256_storeu_si256(mm256i_emu *dst, mm256i_emu src) {
+  _mm_storeu_si128((__m128i*)dst, src.lo);
+  _mm_storeu_si128((__m128i*)(&((char *)dst)[16]), src.hi);
+}
+#define _mm256_storeu_si256(dst, src) mm256_storeu_si256(dst, src)
+
+
+static inline mm256i_emu mm256_set1_epi32(int x) {
+  mm256i_emu ret;
+  ret.lo = _mm_set1_epi32(x);
+  ret.hi = ret.lo;
+  return ret;
+}
+#define _mm256_set1_epi32(x) mm256_set1_epi32(x)
+
+static inline mm256i_emu mm256_set1_epi16(int x) {
+  mm256i_emu ret;
+  ret.lo = _mm_set1_epi16(x);
+  ret.hi = ret.lo;
+  return ret;
+}
+#define _mm256_set1_epi16(x) mm256_set1_epi16(x)
+
+
+static inline mm256i_emu mm256_add_epi32(mm256i_emu a, mm256i_emu b) {
+  mm256i_emu ret;
+  ret.lo = _mm_add_epi32(a.lo, b.lo);
+  ret.hi = _mm_add_epi32(a.hi, b.hi);
+  return ret;
+}
+#define _mm256_add_epi32(a,b) mm256_add_epi32(a,b)
+
+static inline mm256i_emu mm256_madd_epi16(mm256i_emu a, mm256i_emu b) {
+  mm256i_emu ret;
+  ret.lo = _mm_madd_epi16(a.lo, b.lo);
+  ret.hi = _mm_madd_epi16(a.hi, b.hi);
+  return ret;
+}
+#define _mm256_madd_epi16(a,b) mm256_madd_epi16(a,b)
+
+static inline mm256i_emu mm256_maddubs_epi16(mm256i_emu a, mm256i_emu b) {
+  mm256i_emu ret;
+  ret.lo = _mm_maddubs_epi16(a.lo, b.lo);
+  ret.hi = _mm_maddubs_epi16(a.hi, b.hi);
+  return ret;
+}
+#define _mm256_maddubs_epi16(a,b) mm256_maddubs_epi16(a,b)
+
+
+
+/* Emulating the conversion functions is tricky because they use __m256i but are defined in AVX.
+   So we need to make a special case for when only AVX is available. */
+#ifdef __AVX__
+
+typedef union {
+  mm256i_emu fake;
+  real_m256i real;
+} mm256_union;
+
+static inline __m256 mm256_cvtepi32_ps(mm256i_emu a) {
+  mm256_union src;
+  src.fake = a;
+  return _mm256_cvtepi32_ps(src.real);
+}
+#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)
+
+static inline mm256i_emu mm256_cvtps_epi32(__m256 a) {
+  mm256_union ret;
+  ret.real =   _mm256_cvtps_epi32(a);
+  return ret.fake;
+}
+#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)
+
+
+#else
+
+static inline mm256_emu mm256_cvtepi32_ps(mm256i_emu a) {
+  mm256_emu ret;
+  ret.lo = _mm_cvtepi32_ps(a.lo);
+  ret.hi = _mm_cvtepi32_ps(a.hi);
+  return ret;
+}
+#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)
+
+static inline mm256i_emu mm256_cvtps_epi32(mm256_emu a) {
+  mm256i_emu ret;
+  ret.lo = _mm_cvtps_epi32(a.lo);
+  ret.hi = _mm_cvtps_epi32(a.hi);
+  return ret;
+}
+#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)
+
+#endif /* __AVX__ */
+
+
+#endif /* __AVX2__ */
+
+/* In case we don't have FMA, make it a mul and an add. */
+#if !(defined(__FMA__) && defined(__AVX__))
 #define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
 #define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
 #endif
@@ -67,6 +294,72 @@
    return Y;
 }
 
+static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
+    int i;
+   __m256 const127 = _mm256_set1_ps(127.f);
+    for (i=0;i<len;i+=8) {
+       __m256 xf;
+       __m256i xi;
+       xf = _mm256_loadu_ps(&_x[i]);
+       //xf = _mm256_mul_ps(xf, const127);
+       //xf = _mm256_add_ps(xf, const127);
+       xf = _mm256_fmadd_ps(xf, const127, const127);
+       xi = _mm256_cvtps_epi32(xf);
+       xi = _mm256_packus_epi32(xi,  _mm256_setzero_si256());
+       xi = _mm256_permute4x64_epi64(xi, 0xD8);
+       xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
+       xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
+       //xi = _mm256_permute4x64_epi64(xi, 0x);
+       _mm256_storeu_si256 ((__m256i *)&x[i], xi);
+   }
+}
+
+#else
+static inline __m128 exp4_approx(__m128 X)
+{
+   const __m128 K0 = _mm_set1_ps(0.99992522f);
+   const __m128 K1 = _mm_set1_ps(0.69583354f);
+   const __m128 K2 = _mm_set1_ps(0.22606716f);
+   const __m128 K3 = _mm_set1_ps(0.078024523f);
+   const __m128 log2_E = _mm_set1_ps(1.44269504);
+   const __m128 max_in = _mm_set1_ps(50.f);
+   const __m128 min_in = _mm_set1_ps(-50.f);
+   const __m128i mask = _mm_set1_epi32(0x7fffffff);
+   __m128 XF, Y;
+   __m128i I;
+   X = _mm_mul_ps(X, log2_E);
+   X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
+   XF = _mm_floor_ps(X);
+   I = _mm_cvtps_epi32(XF);
+   X = _mm_sub_ps(X, XF);
+   Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
+   I = _mm_slli_epi32(I, 23);
+   Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
+   return Y;
+}
+static inline __m256 exp8_approx(__m256 X)
+{
+   __m256 Y;
+   __m128 Xhi, Xlo, Yhi, Ylo;
+   Xhi = _mm256_extractf128_ps(X, 1);
+   Xlo = _mm256_extractf128_ps(X, 0);
+   Yhi = exp4_approx(Xhi);
+   Ylo = exp4_approx(Xlo);
+   Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
+   Y = _mm256_insertf128_ps(Y, Ylo, 0);
+   return Y;
+}
+
+static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
+    int i;
+    for (i=0;i<len;i++) x[i] = 127+floor(.5+127*_x[i]);
+}
+
+#endif
+
+
+#ifdef __AVX__
+
 /* Approximating tanh() using a Padé-like rational function:
    tanh(x) ~= x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
    subject to the +/- 1 bounds.
@@ -125,42 +418,28 @@
    return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
 }
 
-#else
-static inline __m128 exp4_approx(__m128 X)
+static inline float tanh_approx(float x)
 {
-   const __m128 K0 = _mm_set1_ps(0.99992522f);
-   const __m128 K1 = _mm_set1_ps(0.69583354f);
-   const __m128 K2 = _mm_set1_ps(0.22606716f);
-   const __m128 K3 = _mm_set1_ps(0.078024523f);
-   const __m128 log2_E = _mm_set1_ps(1.44269504);
-   const __m128 max_in = _mm_set1_ps(50.f);
-   const __m128 min_in = _mm_set1_ps(-50.f);
-   const __m128i mask = _mm_set1_epi32(0x7fffffff);
-   __m128 XF, Y;
-   __m128i I;
-   X = _mm_mul_ps(X, log2_E);
-   X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
-   XF = _mm_floor_ps(X);
-   I = _mm_cvtps_epi32(XF);
-   X = _mm_sub_ps(X, XF);
-   Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
-   I = _mm_slli_epi32(I, 23);
-   Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
-   return Y;
+   float out[8];
+   __m256 X, Y;
+   X = _mm256_set1_ps(x);
+   Y = tanh8_approx(X);
+   _mm256_storeu_ps(out, Y);
+   return out[0];
 }
-static inline __m256 exp8_approx(__m256 X)
+
+static inline float sigmoid_approx(float x)
 {
-   __m256 Y;
-   __m128 Xhi, Xlo, Yhi, Ylo;
-   Xhi = _mm256_extractf128_ps(X, 1);
-   Xlo = _mm256_extractf128_ps(X, 0);
-   Yhi = exp4_approx(Xhi);
-   Ylo = exp4_approx(Xlo);
-   Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
-   Y = _mm256_insertf128_ps(Y, Ylo, 0);
-   return Y;
+   float out[8];
+   __m256 X, Y;
+   X = _mm256_set1_ps(x);
+   Y = sigmoid8_approx(X);
+   _mm256_storeu_ps(out, Y);
+   return out[0];
 }
 
+#else
+
 static inline __m128 tanh4_approx(__m128 X)
 {
    const __m128 N0 = _mm_set1_ps(952.52801514f);
@@ -202,34 +481,34 @@
    return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
 }
 
-#endif
-
-static inline float celt_exp(float x)
+static inline float tanh_approx(float x)
 {
-   float out[8];
-   __m256 X, Y;
-   X = _mm256_set1_ps(x);
-   Y = exp8_approx(X);
-   _mm256_storeu_ps(out, Y);
+   float out[4];
+   __m128 X, Y;
+   X = _mm_set1_ps(x);
+   Y = tanh4_approx(X);
+   _mm_storeu_ps(out, Y);
    return out[0];
 }
 
-static inline float tanh_approx(float x)
+static inline float sigmoid_approx(float x)
 {
-   float out[8];
-   __m256 X, Y;
-   X = _mm256_set1_ps(x);
-   Y = tanh8_approx(X);
-   _mm256_storeu_ps(out, Y);
+   float out[4];
+   __m128 X, Y;
+   X = _mm_set1_ps(x);
+   Y = sigmoid4_approx(X);
+   _mm_storeu_ps(out, Y);
    return out[0];
 }
 
-static inline float sigmoid_approx(float x)
+#endif
+
+static inline float celt_exp(float x)
 {
    float out[8];
    __m256 X, Y;
    X = _mm256_set1_ps(x);
-   Y = sigmoid8_approx(X);
+   Y = exp8_approx(X);
    _mm256_storeu_ps(out, Y);
    return out[0];
 }
@@ -248,7 +527,7 @@
         y[i] = celt_exp(x[i]);
 }
 
-#ifdef __AVX2__
+#ifdef __AVX__
 static inline void vec_tanh(float *y, const float *x, int N)
 {
     int i;
@@ -395,22 +674,7 @@
    (void)col_stride;
    ones = _mm256_set1_epi16(1);
    //for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
-   __m256 const127 = _mm256_set1_ps(127.f);
-   for (i=0;i<cols;i+=8) {
-       __m256 xf;
-       __m256i xi;
-       xf = _mm256_loadu_ps(&_x[i]);
-       //xf = _mm256_mul_ps(xf, const127);
-       //xf = _mm256_add_ps(xf, const127);
-       xf = _mm256_fmadd_ps(xf, const127, const127);
-       xi = _mm256_cvtps_epi32(xf);
-       xi = _mm256_packus_epi32(xi,  _mm256_setzero_si256());
-       xi = _mm256_permute4x64_epi64(xi, 0xD8);
-       xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
-       xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
-       //xi = _mm256_permute4x64_epi64(xi, 0x);
-       _mm256_storeu_si256 ((__m256i *)&x[i], xi);
-   }
+   vector_ps_to_epi8(x, _x, cols);
    for (i=0;i<rows;i+=8)
    {
       __m256i vy0;
@@ -509,22 +773,7 @@
    unsigned char x[MAX_INPUTS];
    ones = _mm256_set1_epi16(1);
    //for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
-   __m256 const127 = _mm256_set1_ps(127.f);
-   for (i=0;i<cols;i+=8) {
-       __m256 xf;
-       __m256i xi;
-       xf = _mm256_loadu_ps(&_x[i]);
-       //xf = _mm256_mul_ps(xf, const127);
-       //xf = _mm256_add_ps(xf, const127);
-       xf = _mm256_fmadd_ps(xf, const127, const127);
-       xi = _mm256_cvtps_epi32(xf);
-       xi = _mm256_packus_epi32(xi,  _mm256_setzero_si256());
-       xi = _mm256_permute4x64_epi64(xi, 0xD8);
-       xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
-       xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
-       //xi = _mm256_permute4x64_epi64(xi, 0x);
-       _mm256_storeu_si256 ((__m256i *)&x[i], xi);
-   }
+   vector_ps_to_epi8(x, _x, cols);
    for (i=0;i<rows;i+=8)
    {
       int colblocks;
--
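
For the AVX-without-AVX2 case mentioned in the commit message, the conversion
intrinsics are the awkward part: _mm256_cvtepi32_ps and _mm256_cvtps_epi32 are
AVX (not AVX2) and take a genuine __m256i, while the patch redefines __m256i to
the two-half struct. The patch bridges the two layouts with a union, as in this
standalone sketch (illustrative names, not the patch's code; build with e.g.
gcc -O2 -mavx sketch.c).

/* Sketch of the union trick from the AVX-only branch: the emulated two-half
 * integer vector shares storage with a real 256-bit register type so that the
 * AVX conversion intrinsic can still be called. */
#include <immintrin.h>
#include <stdio.h>

typedef struct { __m128i lo; __m128i hi; } m256i_emu;  /* emulated layout */

typedef union {
  m256i_emu fake;   /* two 128-bit halves used by the emulated code */
  __m256i   real;   /* genuine 256-bit type expected by the intrinsic */
} m256i_union;

/* Emulated _mm256_cvtepi32_ps: reinterpret, then call the real AVX intrinsic. */
static inline __m256 emu_cvtepi32_ps(m256i_emu a) {
  m256i_union u;
  u.fake = a;
  return _mm256_cvtepi32_ps(u.real);
}

int main(void) {
  m256i_emu x;
  float out[8];
  int i;
  x.lo = _mm_set_epi32(3, 2, 1, 0);   /* elements 0..3 */
  x.hi = _mm_set_epi32(7, 6, 5, 4);   /* elements 4..7 */
  _mm256_storeu_ps(out, emu_cvtepi32_ps(x));
  for (i = 0; i < 8; i++) printf("%g ", out[i]);  /* 0 1 2 3 4 5 6 7 */
  printf("\n");
  return 0;
}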