shithub: opus

Download patch

ref: e695355ba5ee48eaab0c5c0f65879ac412210add
parent: 06489b42ddc25b2da0b52c78a43ec12ccd9c6bb1
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Sat Dec 26 21:00:36 EST 2020

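Clean up qweight/DOT_PROD configuration in the vec headers

- vec.h: move the qweight typedef out of the shared section into the
  generic (non-AVX, non-NEON) fallback branch, next to commented-out
  DOT_PROD/USE_SU_BIAS switches.
- vec_avx.h: define DOT_PROD and USE_SU_BIAS unconditionally, add the
  qweight typedef there, and drop the disabled (#if 1 / #else)
  alternative implementation of sparse_sgemv_accum8x4.
- dump_lpcnet.py: emit the non-DOT_PROD weights with dtype 'qweight'
  instead of 'nnet_weight'.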

--- a/dnn/training_tf2/dump_lpcnet.py
+++ b/dnn/training_tf2/dump_lpcnet.py
@@ -84,7 +84,7 @@
     W = np.minimum(127, np.maximum(-128, np.round(W*128)))
     printVector(f, W.astype('int'), name, dtype='qweight')
     f.write('#else /*DOT_PROD*/\n')
-    printVector(f, W0, name, dtype='nnet_weight')
+    printVector(f, W0, name, dtype='qweight')
     f.write('#endif /*DOT_PROD*/\n')
     #idx = np.tile(np.concatenate([np.array([N]), np.arange(N)]), 3*N//16)
     printVector(f, idx, name + '_idx', dtype='int')
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -34,11 +34,6 @@
 #include <math.h>
 #include "arch.h"
 
-#ifdef DOT_PROD
-typedef signed char qweight;
-#else
-typedef float qweight;
-#endif
 
 #ifdef __AVX__
 #include "vec_avx.h"
@@ -46,9 +41,18 @@
 #include "vec_neon.h"
 #else
 
-//#define USE_SU_BIAS
 
 #define NO_OPTIMIZATIONS
+
+//#define DOT_PROD
+//#define USE_SU_BIAS
+
+#ifdef DOT_PROD
+typedef signed char qweight;
+#else
+typedef float qweight;
+#endif
+
 
 /* No AVX2/FMA support */
 #ifndef LPCNET_TEST
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -34,6 +34,9 @@
 
 #include <immintrin.h>
 
+#define DOT_PROD
+#define USE_SU_BIAS
+
 #ifdef __AVX2__
 static inline __m256 exp8_approx(__m256 X)
 {
@@ -222,9 +225,15 @@
 }
 
 #ifdef DOT_PROD
-
 #define USE_SU_BIAS
 
+#ifdef DOT_PROD
+typedef signed char qweight;
+#else
+typedef float qweight;
+#endif
+
+
 #define MAX_INPUTS (2048)
 #define MAX_OUTPUTS (8192)
 
@@ -232,7 +241,6 @@
 #define SCALE (128.f*127.f)
 #define SCALE_1 (1.f/128.f/127.f)
 
-#if 1
 
 static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
 {
@@ -285,43 +293,7 @@
    }
    for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i];
 }
-#else
-static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
-{
-   int i, j;
-   unsigned char x[MAX_INPUTS];
-   int out[MAX_OUTPUTS];
-   for (i=0;i<rows;i++) out[i] = SCALE*_out[i];
-   for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
-   for (i=0;i<rows;i+=8)
-   {
-      int * restrict y;
-      int colblocks;
-      colblocks = *idx++;
-      y = &out[i];
-      for (j=0;j<colblocks;j++)
-      {
-         int pos;
-         int xj0, xj1, xj2, xj3;
-         pos = 4 * (*idx++);
-         xj0 = x[pos+0];
-         xj1 = x[pos+1];
-         xj2 = x[pos+2];
-         xj3 = x[pos+3];
-         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
-         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
-         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
-         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
-         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
-         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
-         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
-         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
-         w += 32;
-      }
-   }
-   for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i];
-}
-#endif
+
 
 #else /*DOT_PROD*/
 static inline void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, int ignore, const int *idx, const float *x)
--