shithub: opus

Download patch

ref: ec671ed90e7ec05efff65115c9d20dcdd519090b
parent: 15fb1b3c774d2c7b1fb5c8c00058f2103ee68698
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Wed Nov 28 09:57:22 EST 2018

Quick and dirty AVX2 implementation of gemm_accum

Brings us very close to real-time

--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -77,13 +77,99 @@
   return x < 0 ? 0 : x;
 }
 
-static void gemm_accum(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+/* gemm_accum16: out[i] += sum_j weights[j*col_stride + i]*x[j] for
+ * 0 <= i < rows, i.e. a column-major matrix (column j starts at
+ * weights[j*col_stride]) times x, processed 16 rows at a time.
+ * rows must be a multiple of 16; cols is unconstrained. */
+#ifdef __AVX2__
+#include <immintrin.h>
+/* _mm256_fmadd_ps is an FMA3 intrinsic, not part of AVX2, and GCC/Clang
+ * reject it unless FMA is enabled too (e.g. -mfma or -march=haswell).
+ * Fall back to a separate multiply and add when __FMA__ is not defined. */
+#ifdef __FMA__
+#define GEMM_FMADD256(a, b, c) _mm256_fmadd_ps(a, b, c)
+#else
+#define GEMM_FMADD256(a, b, c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
+#endif
+static void gemm_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
   int i, j;
-   for (i=0;i<rows;i++)
+   for (i=0;i<rows;i+=16)
   {
+      float * restrict y;
+      __m256 vy0, vy8;
+      /* Hold out[i..i+15] in two 8-lane accumulators for the whole column loop. */
+      y = &out[i];
+      vy0 = _mm256_loadu_ps(&y[0]);
+      vy8 = _mm256_loadu_ps(&y[8]);
      for (j=0;j<cols;j++)
-         out[i] += weights[j*col_stride + i]*x[j];
+      {
+         __m256 vxj;
+         __m256 vw;
+         /* Broadcast x[j] and multiply-accumulate it with rows i..i+15 of column j. */
+         vxj = _mm256_broadcast_ss(&x[j]);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+         vy0 = GEMM_FMADD256(vw, vxj, vy0);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
+         vy8 = GEMM_FMADD256(vw, vxj, vy8);
+      }
+      _mm256_storeu_ps (&y[0], vy0);
+      _mm256_storeu_ps (&y[8], vy8);
+   }
+}
+#else
+/* Portable fallback with the same access pattern, unrolled over the 16 rows. */
+static void gemm_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int i, j;
+   for (i=0;i<rows;i+=16)
+   {
+      for (j=0;j<cols;j++)
+      {
+         const float * restrict w;
+         float * restrict y;
+         float xj;
+         w = &weights[j*col_stride + i];
+         xj = x[j];
+         y = &out[i];
+         y[0] += w[0]*xj;
+         y[1] += w[1]*xj;
+         y[2] += w[2]*xj;
+         y[3] += w[3]*xj;
+         y[4] += w[4]*xj;
+         y[5] += w[5]*xj;
+         y[6] += w[6]*xj;
+         y[7] += w[7]*xj;
+         y[8] += w[8]*xj;
+         y[9] += w[9]*xj;
+         y[10] += w[10]*xj;
+         y[11] += w[11]*xj;
+         y[12] += w[12]*xj;
+         y[13] += w[13]*xj;
+         y[14] += w[14]*xj;
+         y[15] += w[15]*xj;
+      }
+   }
+}
+#endif
+
+/* gemm_accum: out[i] += sum_j weights[j*col_stride + i]*x[j] for
+ * 0 <= i < rows.  The 16-row kernel visits every column, so only rows
+ * has to be a multiple of 16 to use it; cols is unconstrained. */
+static void gemm_accum(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int i, j;
+   if (rows % 16 == 0)
+   {
+      gemm_accum16(out, weights, rows, cols, col_stride, x);
+   } else {
+      for (i=0;i<rows;i++)
+      {
+         for (j=0;j<cols;j++)
+            out[i] += weights[j*col_stride + i]*x[j];
+      }
   }
 }
 
--