shithub: opus

ref: 62cd1c963baf1fdccb990260bb5dc53b780e45eb
parent: f5a68a41b01d7edcd629ec9b284302df1518e922
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Jul 19 20:39:24 EDT 2023

Transition to LinearLayer and remove unused code
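
This patch routes the legacy DenseLayer/GRULayer/Conv1DLayer paths through the
LinearLayer abstraction: each call now wraps the layer's weights in a temporary
LinearLayer and defers the matrix product to compute_linear() (or
compute_generic_gru() for the GRU layers), so the per-ISA sgemv_accum* and
sparse_sgemv_accum* kernels in vec.h, vec_avx.h and vec_neon.h are no longer
referenced and are removed. A minimal sketch of the dense path this patch keeps
(the elided field assignments are illustrative, not the full initializer list;
declarations come from nnet.h):

   void _lpcnet_compute_dense(const DenseLayer *layer, float *output, const float *input)
   {
      LinearLayer matrix;
      celt_assert(input != output);
      matrix.bias = layer->bias;
      /* ... remaining LinearLayer fields are filled in from the DenseLayer ... */
      compute_linear(&matrix, output, input);
      compute_activation(output, output, layer->nb_neurons, layer->activation);
   }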

--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -69,22 +69,6 @@
    return x < 0 ? 0 : x;
 }
 
-
-static void sgemv_accum(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   int i, j;
-   if (rows % 16 == 0)
-   {
-      sgemv_accum16(out, weights, rows, cols, col_stride, x);
-   } else {
-      for (i=0;i<rows;i++)
-      {
-         for (j=0;j<cols;j++)
-            out[i] += weights[j*col_stride + i]*x[j];
-      }
-   }
-}
-
 void compute_linear(const LinearLayer *linear, float *out, const float *in)
 {
    int i, M, N;
@@ -186,24 +170,8 @@
    }
 }
 
-#if 1
 void _lpcnet_compute_dense(const DenseLayer *layer, float *output, const float *input)
 {
-   int i;
-   int N, M;
-   int stride;
-   M = layer->nb_inputs;
-   N = layer->nb_neurons;
-   stride = N;
-   celt_assert(input != output);
-   for (i=0;i<N;i++)
-      output[i] = layer->bias[i];
-   sgemv_accum(output, layer->input_weights, N, M, stride, input);
-   compute_activation(output, output, N, layer->activation);
-}
-#else
-void _lpcnet_compute_dense(const DenseLayer *layer, float *output, const float *input)
-{
    LinearLayer matrix;
    celt_assert(input != output);
    matrix.bias = layer->bias;
@@ -218,7 +186,6 @@
    compute_linear(&matrix, output, input);
    compute_activation(output, output, layer->nb_neurons, layer->activation);
 }
-#endif
 
 int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table, kiss99_ctx *rng)
 {
@@ -280,61 +247,8 @@
 #endif
 #define MAX_IDX_SIZE 8192
 
-
-#if 1
 void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *state, const float *input)
 {
-   int i;
-   int N, M;
-   int stride;
-   float zrh[3*MAX_RNN_NEURONS_ALL];
-   float recur[3*MAX_RNN_NEURONS_ALL];
-   float *z;
-   float *r;
-   float *h;
-   M = gru->nb_inputs;
-   N = gru->nb_neurons;
-   z = zrh;
-   r = &zrh[N];
-   h = &zrh[2*N];
-   celt_assert(gru->nb_neurons <= MAX_RNN_NEURONS_ALL);
-   celt_assert(input != state);
-   celt_assert(gru->reset_after);
-   stride = 3*N;
-   /* Compute update gate. */
-#ifdef USE_SU_BIAS
-   for (i=0;i<3*N;i++)
-      zrh[i] = gru->subias[i] + gru_b_condition[i];
-#else
-   for (i=0;i<3*N;i++)
-      zrh[i] = gru->bias[i] + gru_b_condition[i];
-#endif
-   sparse_sgemv_accum8x4(zrh, gru->input_weights, 3*N, M, gru->input_weights_idx, input);
-#ifdef USE_SU_BIAS
-   for (i=0;i<3*N;i++)
-      recur[i] = gru->subias[3*N + i];
-#else
-   for (i=0;i<3*N;i++)
-      recur[i] = gru->bias[3*N + i];
-#endif
-   sgemv_accum8x4(recur, gru->recurrent_weights, 3*N, N, stride, state);
-   for (i=0;i<2*N;i++)
-      zrh[i] += recur[i];
-   compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
-   for (i=0;i<N;i++)
-      h[i] += recur[2*N+i]*r[i];
-   compute_activation(h, h, N, gru->activation);
-   for (i=0;i<N;i++)
-      h[i] = z[i]*state[i] + (1-z[i])*h[i];
-   for (i=0;i<N;i++)
-      state[i] = h[i];
-}
-
-#else
-
-
-void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *state, const float *input)
-{
   LinearLayer in_matrix, rec_matrix;
   int i, M, N;
   float bias[3*MAX_RNN_NEURONS_ALL];
@@ -379,54 +293,10 @@
   rec_matrix.weights_idx = NULL;
   compute_generic_gru(&in_matrix, &rec_matrix, state, input);
 }
-#endif
 
-
-#if 1
 /* The input of this GRU is after the input matrix multiply. */
 void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input)
 {
-   int i, k;
-   int N;
-   float recur[3*MAX_RNN_NEURONS_ALL];
-   float *z;
-   float *r;
-   float *h;
-   const float *bias;
-   N = gru->nb_neurons;
-   z = recur;
-   r = &recur[N];
-   h = &recur[2*N];
-   celt_assert(gru->nb_neurons <= MAX_RNN_NEURONS_ALL);
-   celt_assert(input != state);
-   celt_assert(gru->reset_after);
-#ifdef USE_SU_BIAS
-   bias = &gru->subias[3*N];
-#else
-   bias = &gru->bias[3*N];
-#endif
-   for (k=0;k<2;k++)
-   {
-      for (i=0;i<N;i++)
-         recur[k*N + i] = bias[k*N + i] + gru->diag_weights[k*N + i]*state[i] + input[k*N + i];
-   }
-   for (;k<3;k++)
-   {
-      for (i=0;i<N;i++)
-         recur[k*N + i] = bias[k*N + i] + gru->diag_weights[k*N + i]*state[i];
-   }
-   sparse_sgemv_accum8x4(recur, gru->recurrent_weights, 3*N, N, gru->idx, state);
-   compute_activation(recur, recur, 2*N, ACTIVATION_SIGMOID);
-   for (i=0;i<N;i++)
-      h[i] = h[i]*r[i] + input[2*N+i];
-   compute_activation(h, h, N, gru->activation);
-   for (i=0;i<N;i++)
-      state[i] = z[i]*state[i] + (1-z[i])*h[i];
-}
-#else
-/* The input of this GRU is after the input matrix multiply. */
-void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input)
-{
   LinearLayer in_matrix, rec_matrix;
   int i, N;
   float scale[3*MAX_RNN_NEURONS_ALL];
@@ -460,34 +330,11 @@
   rec_matrix.weights_idx = gru->idx;
   compute_generic_gru(&in_matrix, &rec_matrix, state, input);
 }
-#endif
 
-
 #define MAX_CONV_INPUTS_ALL IMAX(MAX_CONV_INPUTS, DRED_MAX_CONV_INPUTS)
 
-#if 1
 void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
 {
-   int i;
-   int N, M;
-   int stride;
-   float tmp[MAX_CONV_INPUTS_ALL];
-   celt_assert(input != output);
-   celt_assert(layer->nb_inputs*layer->kernel_size <= MAX_CONV_INPUTS_ALL);
-   OPUS_COPY(tmp, mem, layer->nb_inputs*(layer->kernel_size-1));
-   OPUS_COPY(&tmp[layer->nb_inputs*(layer->kernel_size-1)], input, layer->nb_inputs);
-   M = layer->nb_inputs*layer->kernel_size;
-   N = layer->nb_neurons;
-   stride = N;
-   for (i=0;i<N;i++)
-      output[i] = layer->bias[i];
-   sgemv_accum(output, layer->input_weights, N, M, stride, tmp);
-   compute_activation(output, output, N, layer->activation);
-   OPUS_COPY(mem, &tmp[layer->nb_inputs], layer->nb_inputs*(layer->kernel_size-1));
-}
-#else
-void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
-{
    LinearLayer matrix;
    int N, M;
    float tmp[MAX_CONV_INPUTS_ALL];
@@ -510,7 +357,6 @@
    compute_activation(output, output, N, layer->activation);
    OPUS_COPY(mem, &tmp[layer->nb_inputs], layer->nb_inputs*(layer->kernel_size-1));
 }
-#endif
 
 void compute_embedding(const EmbeddingLayer *layer, float *output, int input)
 {
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -35,12 +35,14 @@
 #include "arch.h"
 
 
-#if defined(__AVX__) || defined(__SSE2__)
+#if defined(__AVX__) || defined(__SSSE3__)
 #include "vec_avx.h"
 #elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) && !defined(DISABLE_NEON)
 #include "vec_neon.h"
 #else
 
+#include "os_support.h"
+
 #define MAX_INPUTS (2048)
 
 #define NO_OPTIMIZATIONS
@@ -352,282 +354,9 @@
     }
 }
 #endif
-static inline void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=16)
-   {
-      for (j=0;j<cols;j++)
-      {
-         const float * restrict w;
-         float * restrict y;
-         float xj;
-         w = &weights[j*col_stride + i];
-         xj = x[j];
-         y = &out[i];
-         y[0] += w[0]*xj;
-         y[1] += w[1]*xj;
-         y[2] += w[2]*xj;
-         y[3] += w[3]*xj;
-         y[4] += w[4]*xj;
-         y[5] += w[5]*xj;
-         y[6] += w[6]*xj;
-         y[7] += w[7]*xj;
-         y[8] += w[8]*xj;
-         y[9] += w[9]*xj;
-         y[10] += w[10]*xj;
-         y[11] += w[11]*xj;
-         y[12] += w[12]*xj;
-         y[13] += w[13]*xj;
-         y[14] += w[14]*xj;
-         y[15] += w[15]*xj;
-      }
-   }
-}
 
-static inline void sparse_sgemv_accum16(float *out, const float *w, int rows, const int *idx, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=16)
-   {
-      int cols;
-      cols = *idx++;
-      for (j=0;j<cols;j++)
-      {
-         float * restrict y;
-         float xj;
-         xj = x[*idx++];
-         y = &out[i];
-         y[0] += w[0]*xj;
-         y[1] += w[1]*xj;
-         y[2] += w[2]*xj;
-         y[3] += w[3]*xj;
-         y[4] += w[4]*xj;
-         y[5] += w[5]*xj;
-         y[6] += w[6]*xj;
-         y[7] += w[7]*xj;
-         y[8] += w[8]*xj;
-         y[9] += w[9]*xj;
-         y[10] += w[10]*xj;
-         y[11] += w[11]*xj;
-         y[12] += w[12]*xj;
-         y[13] += w[13]*xj;
-         y[14] += w[14]*xj;
-         y[15] += w[15]*xj;
-         w += 16;
-      }
-   }
-}
-
-#ifdef DOT_PROD
-
 #define SCALE (128.f*127.f)
 #define SCALE_1 (1.f/128.f/127.f)
-
-
-#ifdef USE_SU_BIAS
-
-static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
-{
-   int i, j;
-   unsigned char x[MAX_INPUTS];
-   (void)col_stride;
-   for (i=0;i<rows;i++) out[i] *= SCALE;
-   for (i=0;i<cols;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
-   for (i=0;i<rows;i+=8)
-   {
-      for (j=0;j<cols;j+=4)
-      {
-         float * restrict y;
-         float xj0, xj1, xj2, xj3;
-         xj0 = x[j+0];
-         xj1 = x[j+1];
-         xj2 = x[j+2];
-         xj3 = x[j+3];
-         y = &out[i];
-         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
-         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
-         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
-         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
-         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
-         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
-         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
-         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
-         w += 32;
-      }
-   }
-   for (i=0;i<rows;i++) out[i] *= SCALE_1;
-}
-
-static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
-{
-   int i, j;
-   unsigned char x[MAX_INPUTS];
-   for (i=0;i<rows;i++) out[i] *= SCALE;
-   for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
-   for (i=0;i<rows;i+=8)
-   {
-      int colblocks;
-      colblocks = *idx++;
-      for (j=0;j<colblocks;j++)
-      {
-         int pos;
-         float * restrict y;
-         int xj0, xj1, xj2, xj3;
-         pos = (*idx++);
-         xj0 = x[pos+0];
-         xj1 = x[pos+1];
-         xj2 = x[pos+2];
-         xj3 = x[pos+3];
-         y = &out[i];
-         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
-         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
-         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
-         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
-         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
-         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
-         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
-         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
-         w += 32;
-      }
-   }
-   for (i=0;i<rows;i++) out[i] *= SCALE_1;
-}
-#else /*USE_SU_BIAS*/
-
-static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
-{
-   int i, j;
-   signed char x[MAX_INPUTS];
-   (void)col_stride;
-   for (i=0;i<rows;i++) out[i] *= SCALE;
-   for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
-   for (i=0;i<rows;i+=8)
-   {
-      for (j=0;j<cols;j+=4)
-      {
-         float * restrict y;
-         float xj0, xj1, xj2, xj3;
-         xj0 = x[j+0];
-         xj1 = x[j+1];
-         xj2 = x[j+2];
-         xj3 = x[j+3];
-         y = &out[i];
-         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
-         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
-         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
-         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
-         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
-         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
-         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
-         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
-         w += 32;
-      }
-   }
-   for (i=0;i<rows;i++) out[i] *= SCALE_1;
-}
-
-static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
-{
-   int i, j;
-   signed char x[MAX_INPUTS];
-   for (i=0;i<rows;i++) out[i] *= SCALE;
-   for (i=0;i<cols;i++) x[i] = floor(.5+127*_x[i]);
-   for (i=0;i<rows;i+=8)
-   {
-      int colblocks;
-      colblocks = *idx++;
-      for (j=0;j<colblocks;j++)
-      {
-         int pos;
-         float * restrict y;
-         int xj0, xj1, xj2, xj3;
-         pos = (*idx++);
-         xj0 = x[pos+0];
-         xj1 = x[pos+1];
-         xj2 = x[pos+2];
-         xj3 = x[pos+3];
-         y = &out[i];
-         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
-         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
-         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
-         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
-         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
-         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
-         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
-         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
-         w += 32;
-      }
-   }
-   for (i=0;i<rows;i++) out[i] *= SCALE_1;
-}
-#endif /*USE_SU_BIAS*/
-
-#else /*DOT_PROD*/
-
-#define sgemv_accum8x4 sgemv_accum
-
-
-static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int ignore, const int *idx, const float *x)
-{
-   int i, j;
-   (void)ignore;
-   for (i=0;i<rows;i+=8)
-   {
-      int cols;
-      cols = *idx++;
-      for (j=0;j<cols;j++)
-      {
-         int pos;
-         float * restrict y;
-         float xj0, xj1, xj2, xj3;
-         pos = (*idx++);
-         xj0 = x[pos+0];
-         xj1 = x[pos+1];
-         xj2 = x[pos+2];
-         xj3 = x[pos+3];
-         y = &out[i];
-         y[0] += w[0]*xj0;
-         y[1] += w[1]*xj0;
-         y[2] += w[2]*xj0;
-         y[3] += w[3]*xj0;
-         y[4] += w[4]*xj0;
-         y[5] += w[5]*xj0;
-         y[6] += w[6]*xj0;
-         y[7] += w[7]*xj0;
-
-         y[0] += w[8]*xj1;
-         y[1] += w[9]*xj1;
-         y[2] += w[10]*xj1;
-         y[3] += w[11]*xj1;
-         y[4] += w[12]*xj1;
-         y[5] += w[13]*xj1;
-         y[6] += w[14]*xj1;
-         y[7] += w[15]*xj1;
-
-         y[0] += w[16]*xj2;
-         y[1] += w[17]*xj2;
-         y[2] += w[18]*xj2;
-         y[3] += w[19]*xj2;
-         y[4] += w[20]*xj2;
-         y[5] += w[21]*xj2;
-         y[6] += w[22]*xj2;
-         y[7] += w[23]*xj2;
-
-         y[0] += w[24]*xj3;
-         y[1] += w[25]*xj3;
-         y[2] += w[26]*xj3;
-         y[3] += w[27]*xj3;
-         y[4] += w[28]*xj3;
-         y[5] += w[29]*xj3;
-         y[6] += w[30]*xj3;
-         y[7] += w[31]*xj3;
-         w += 32;
-      }
-   }
-}
-#endif /*DOT_PROD*/
-
 
 #endif /*no optimizations*/
 #endif /*VEC_H*/
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -186,7 +186,15 @@
 typedef __m256i real_m256i;
 #define __m256i mm256i_emu
 
+static inline mm256i_emu mm256_setzero_si256(void) {
+  mm256i_emu ret;
+  ret.lo = _mm_setzero_si128();
+  ret.hi = ret.lo;
+  return ret;
+}
+#define _mm256_setzero_si256 mm256_setzero_si256
 
+
 static inline mm256i_emu mm256_loadu_si256(const mm256i_emu *src) {
   mm256i_emu ret;
   ret.lo = _mm_loadu_si128((const __m128i*)src);
@@ -619,64 +627,7 @@
 
 #endif
 
-static inline void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=16)
-   {
-      float *y;
-      __m256 vy0, vy8;
-      y = &out[i];
-      vy0 = _mm256_loadu_ps(&y[0]);
-      vy8 = _mm256_loadu_ps(&y[8]);
-      for (j=0;j<cols;j++)
-      {
-         __m256 vxj;
-         __m256 vw;
-         vxj = _mm256_broadcast_ss(&x[j]);
 
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
-         vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
-      }
-      _mm256_storeu_ps (&y[0], vy0);
-      _mm256_storeu_ps (&y[8], vy8);
-   }
-}
-static inline void sparse_sgemv_accum16(float *out, const float *weights, int rows, const int *idx, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=16)
-   {
-      float *y;
-      int cols;
-      __m256 vy0, vy8;
-      y = &out[i];
-      vy0 = _mm256_loadu_ps(&y[0]);
-      vy8 = _mm256_loadu_ps(&y[8]);
-      cols = *idx++;
-      for (j=0;j<cols;j++)
-      {
-         int id;
-         __m256 vxj;
-         __m256 vw;
-         id = *idx++;
-         vxj = _mm256_broadcast_ss(&x[id]);
-
-         vw = _mm256_loadu_ps(&weights[0]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-
-         vw = _mm256_loadu_ps(&weights[8]);
-         vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
-         weights += 16;
-      }
-      _mm256_storeu_ps (&y[0], vy0);
-      _mm256_storeu_ps (&y[8], vy8);
-   }
-}
-
 static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
    int i, j;
@@ -874,233 +825,15 @@
    }
 }
 
-
-#ifdef DOT_PROD
-#define USE_SU_BIAS
-
-typedef signed char qweight;
-
-#define MAX_OUTPUTS (8192)
-
-
 #define SCALE (128.f*127.f)
 #define SCALE_1 (1.f/128.f/127.f)
+#define USE_SU_BIAS
 
-#if 1
-static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
-{
-   __m256i ones;
-   int i, j;
-   unsigned char x[MAX_INPUTS];
-   (void)col_stride;
-   ones = _mm256_set1_epi16(1);
-   /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
-   vector_ps_to_epi8(x, _x, cols);
-   for (i=0;i<rows;i+=8)
-   {
-      __m256i vy0;
-      __m256 vout;
-      vout = _mm256_loadu_ps(&_out[i]);
-      vout = _mm256_mul_ps(vout, _mm256_set1_ps(SCALE));
-      vy0 = _mm256_cvtps_epi32(vout);
-      j=0;
-#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
-      for (;j<cols-12;j+=16)
-      {
-         __m256i tmp;
-         __m256i vxj;
-         __m256i vw;
-         vxj = _mm256_set1_epi32(*(int*)&x[j]);
-         vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
-         w += 32;
-         vxj = _mm256_set1_epi32(*(int*)&x[j+4]);
-         vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
-         w += 32;
-         vxj = _mm256_set1_epi32(*(int*)&x[j+8]);
-         vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
-         w += 32;
-         vxj = _mm256_set1_epi32(*(int*)&x[j+12]);
-         vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
-         w += 32;
-      }
-#endif
-      for (;j<cols;j+=4)
-      {
-         __m256i tmp;
-         __m256i vxj;
-         __m256i vw;
-         vxj = _mm256_set1_epi32(*(int*)&x[j]);
-         vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
-         w += 32;
-      }
-      vout = _mm256_cvtepi32_ps(vy0);
-      vout = _mm256_mul_ps(vout, _mm256_set1_ps(SCALE_1));
-      _mm256_storeu_ps(&_out[i], vout);
-   }
-}
+#ifdef DOT_PROD
+typedef signed char qweight;
 #else
-static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
-{
-   int i, j;
-   unsigned char x[MAX_INPUTS];
-   (void)col_stride;
-   for (i=0;i<rows;i++) out[i] *= SCALE;
-   for (i=0;i<cols;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
-   for (i=0;i<rows;i+=8)
-   {
-      for (j=0;j<cols;j+=4)
-      {
-         float *y;
-         float xj0, xj1, xj2, xj3;
-         xj0 = x[j+0];
-         xj1 = x[j+1];
-         xj2 = x[j+2];
-         xj3 = x[j+3];
-         y = &out[i];
-         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
-         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
-         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
-         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
-         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
-         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
-         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
-         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
-         w += 32;
-      }
-   }
-   for (i=0;i<rows;i++) out[i] *= SCALE_1;
-}
+typedef float qweight;
 #endif
 
-static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
-{
-   __m256i ones;
-   int i, j;
-   unsigned char x[MAX_INPUTS];
-   ones = _mm256_set1_epi16(1);
-   /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
-   vector_ps_to_epi8(x, _x, cols);
-   for (i=0;i<rows;i+=8)
-   {
-      int colblocks;
-      __m256i vy0;
-      __m256 vout;
-      colblocks = *idx++;
-      vout = _mm256_loadu_ps(&_out[i]);
-      vout = _mm256_mul_ps(vout, _mm256_set1_ps(SCALE));
-      vy0 = _mm256_cvtps_epi32(vout);
-      j=0;
-#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
-      for (;j<colblocks-3;j+=4)
-      {
-         __m256i tmp;
-         __m256i vxj;
-         __m256i vw;
-         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
-         vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
-         w += 32;
-         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
-         vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
-         w += 32;
-         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
-         vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
-         w += 32;
-         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
-         vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
-         w += 32;
-      }
-#endif
-      for (;j<colblocks;j++)
-      {
-         __m256i tmp;
-         __m256i vxj;
-         __m256i vw;
-         int pos;
-         pos = (*idx++);
-         vxj = _mm256_set1_epi32(*(int*)&x[pos]);
-         vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
-         w += 32;
-      }
-      vout = _mm256_cvtepi32_ps(vy0);
-      vout = _mm256_mul_ps(vout, _mm256_set1_ps(SCALE_1));
-      _mm256_storeu_ps(&_out[i], vout);
-   }
-}
-
-
-#else /*DOT_PROD*/
-typedef float qweight;
-#define sgemv_accum8x4 sgemv_accum
-
-static inline void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, int ignore, const int *idx, const float *x)
-{
-   int i, j;
-   (void)ignore;
-   for (i=0;i<rows;i+=8)
-   {
-      float *y;
-      int cols;
-      __m256 vy0;
-      y = &out[i];
-      vy0 = _mm256_loadu_ps(&y[0]);
-      cols = *idx++;
-      for (j=0;j<cols;j++)
-      {
-         int id;
-         __m256 vxj;
-         __m256 vw;
-         id = *idx++;
-         vxj = _mm256_broadcast_ss(&x[id]);
-         vw = _mm256_loadu_ps(&weights[0]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-
-         vxj = _mm256_broadcast_ss(&x[id+1]);
-         vw = _mm256_loadu_ps(&weights[8]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-
-         vxj = _mm256_broadcast_ss(&x[id+2]);
-         vw = _mm256_loadu_ps(&weights[16]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-
-         vxj = _mm256_broadcast_ss(&x[id+3]);
-         vw = _mm256_loadu_ps(&weights[24]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-
-         weights += 32;
-      }
-      _mm256_storeu_ps (&y[0], vy0);
-   }
-}
-#endif /*DOT_PROD*/
 
 #endif /*VEC_AVX_H*/
--- a/dnn/vec_neon.h
+++ b/dnn/vec_neon.h
@@ -32,6 +32,7 @@
 #define VEC_NEON_H
 
 #include <arm_neon.h>
+#include "os_support.h"
 
 #ifndef DISABLE_DOT_PROD
 #define DOT_PROD
@@ -43,9 +44,6 @@
 typedef float qweight;
 #endif
 
-/* Just so it compiles when those functions aren't needed. */
-static inline void sgemv16x1(float *, const float *, int , int , int , const float *) {}
-static inline void sparse_sgemv8x4(float *, const float *, const int *, int , const float *) {}
 
 #ifndef LPCNET_TEST
 static inline float32x4_t exp4_approx(float32x4_t x) {
@@ -197,7 +195,7 @@
 }
 #endif
 
-static inline void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
     int i, j;
     for (i=0;i<rows;i+=16)
@@ -206,10 +204,10 @@
 
 	/* keep y[0..15] in registers for duration of inner loop */
 
-	float32x4_t y0_3 = vld1q_f32(&y[0]);
-	float32x4_t y4_7 = vld1q_f32(&y[4]);
-	float32x4_t y8_11 = vld1q_f32(&y[8]);
-	float32x4_t y12_15 = vld1q_f32(&y[12]);
+	float32x4_t y0_3 = vdupq_n_f32(0);
+	float32x4_t y4_7 = vdupq_n_f32(0);
+	float32x4_t y8_11 = vdupq_n_f32(0);
+	float32x4_t y12_15 = vdupq_n_f32(0);
 
 	for (j=0;j<cols;j++)
 	{
@@ -241,46 +239,67 @@
     }
 }
 
-static inline void sparse_sgemv_accum16(float *out, const float *w, int rows, const int *idx, const float *x)
+/* Temporarily use unoptimized version */
+static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
 {
-    int i, j;
-    for (i=0;i<rows;i+=16)
-    {
-	int cols;
-	cols = *idx++;
-	float * restrict y;
-	y = &out[i];
+   int i, j;
+   OPUS_CLEAR(out, rows);
+   for (i=0;i<rows;i+=8)
+   {
+      int cols;
+      cols = *idx++;
+      for (j=0;j<cols;j++)
+      {
+         int pos;
+         float * restrict y;
+         float xj0, xj1, xj2, xj3;
+         pos = (*idx++);
+         xj0 = x[pos+0];
+         xj1 = x[pos+1];
+         xj2 = x[pos+2];
+         xj3 = x[pos+3];
+         y = &out[i];
+         y[0] += w[0]*xj0;
+         y[1] += w[1]*xj0;
+         y[2] += w[2]*xj0;
+         y[3] += w[3]*xj0;
+         y[4] += w[4]*xj0;
+         y[5] += w[5]*xj0;
+         y[6] += w[6]*xj0;
+         y[7] += w[7]*xj0;
 
-	/* keep y[0..15] in registers for duration of inner loop */
+         y[0] += w[8]*xj1;
+         y[1] += w[9]*xj1;
+         y[2] += w[10]*xj1;
+         y[3] += w[11]*xj1;
+         y[4] += w[12]*xj1;
+         y[5] += w[13]*xj1;
+         y[6] += w[14]*xj1;
+         y[7] += w[15]*xj1;
 
-	float32x4_t y0_3 = vld1q_f32(&y[0]);
-	float32x4_t y4_7 = vld1q_f32(&y[4]);
-	float32x4_t y8_11 = vld1q_f32(&y[8]);
-	float32x4_t y12_15 = vld1q_f32(&y[12]);
+         y[0] += w[16]*xj2;
+         y[1] += w[17]*xj2;
+         y[2] += w[18]*xj2;
+         y[3] += w[19]*xj2;
+         y[4] += w[20]*xj2;
+         y[5] += w[21]*xj2;
+         y[6] += w[22]*xj2;
+         y[7] += w[23]*xj2;
 
-	for (j=0;j<cols;j++)
-	{
-	    float32x4_t xj= vld1q_dup_f32(&x[*idx++]);
-	    float32x4_t wvec;
-
-	    wvec = vld1q_f32(&w[0]); y0_3 = vmlaq_f32(y0_3, wvec, xj);
-	    wvec = vld1q_f32(&w[4]); y4_7 = vmlaq_f32(y4_7, wvec, xj);
-	    wvec = vld1q_f32(&w[8]); y8_11 = vmlaq_f32(y8_11, wvec, xj);
-	    wvec = vld1q_f32(&w[12]); y12_15 = vmlaq_f32(y12_15, wvec, xj);
-
-	    w += 16;
-	}
-
-	/* save y[0..15] back to memory */
-
-	vst1q_f32(&y[0], y0_3);
-	vst1q_f32(&y[4], y4_7);
-	vst1q_f32(&y[8], y8_11);
-	vst1q_f32(&y[12], y12_15);
-
-    }
+         y[0] += w[24]*xj3;
+         y[1] += w[25]*xj3;
+         y[2] += w[26]*xj3;
+         y[3] += w[27]*xj3;
+         y[4] += w[28]*xj3;
+         y[5] += w[29]*xj3;
+         y[6] += w[30]*xj3;
+         y[7] += w[31]*xj3;
+         w += 32;
+      }
+   }
 }
 
+
 #define SCALE (128.f*127.f)
 #define SCALE_1 (1.f/128.f/127.f)
 
@@ -368,79 +387,5 @@
    }
 }
 
-static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
-{
-   int i, j;
-   signed char x[MAX_INPUTS];
-   const float32x4_t scale = vdupq_n_f32(SCALE);
-   const float32x4_t scale_1 = vdupq_n_f32(SCALE_1);
-   const float32x4_t const127 = vdupq_n_f32(127.);
-   (void)col_stride;
-   for (i=0;i<cols;i+=8) {
-      int32x4_t xi0, xi4;
-      int16x8_t x_short;
-      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
-      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
-      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
-      vst1_s8(&x[i], vmovn_s16(x_short));
-   }
-   for (i=0;i<rows;i+=8)
-   {
-      int32x4_t acc0, acc1;
-      acc0 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i])));
-      acc1 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i+4])));
-      for (j=0;j<cols;j+=4)
-      {
-         int8x16_t vw0, vw1, vx;
-         vx = (int8x16_t)vld1q_dup_s32((int*)&x[j]);
-         vw0 = vld1q_s8(w);
-         vw1 = vld1q_s8(&w[16]);
-         acc0 = vdotprod(acc0, vw0, vx);
-         acc1 = vdotprod(acc1, vw1, vx);
-         w += 32;
-      }
-      vst1q_f32(&_out[i], vmulq_f32(scale_1, vcvtq_f32_s32(acc0)));
-      vst1q_f32(&_out[i+4], vmulq_f32(scale_1, vcvtq_f32_s32(acc1)));
-   }
-}
-
-static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
-{
-   int i, j;
-   signed char x[MAX_INPUTS];
-   const float32x4_t scale = vdupq_n_f32(SCALE);
-   const float32x4_t scale_1 = vdupq_n_f32(SCALE_1);
-   const float32x4_t const127 = vdupq_n_f32(127.);
-   for (i=0;i<cols;i+=8) {
-      int32x4_t xi0, xi4;
-      int16x8_t x_short;
-      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
-      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
-      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
-      vst1_s8(&x[i], vmovn_s16(x_short));
-   }
-   for (i=0;i<rows;i+=8)
-   {
-      int colblocks;
-      int32x4_t acc0, acc1;
-      acc0 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i])));
-      acc1 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i+4])));
-      colblocks = *idx++;
-      for (j=0;j<colblocks;j++)
-      {
-         int pos;
-         pos = (*idx++);
-         int8x16_t vw0, vw1, vx;
-         vx = (int8x16_t)vld1q_dup_s32((int*)&x[pos]);
-         vw0 = vld1q_s8(w);
-         vw1 = vld1q_s8(&w[16]);
-         acc0 = vdotprod(acc0, vw0, vx);
-         acc1 = vdotprod(acc1, vw1, vx);
-         w += 32;
-      }
-      vst1q_f32(&_out[i], vmulq_f32(scale_1, vcvtq_f32_s32(acc0)));
-      vst1q_f32(&_out[i+4], vmulq_f32(scale_1, vcvtq_f32_s32(acc1)));
-   }
-}
 
 #endif
--