shithub: opus

Download patch

ref: f5a68a41b01d7edcd629ec9b284302df1518e922
parent: 8423ef1de25d8006be837504c22080e0319a1d26
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sat Jul 15 22:21:49 EDT 2023

Add generic linear layer

The new LinearLayer/compute_linear() abstraction should be able to express all the previous GRU variants, and more.

--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -85,6 +85,73 @@
    }
 }
 
+/* Generic affine layer: out = W*in + bias, plus an optional diagonal term.
+   The weight matrix is chosen as follows:
+     - float_weights (if non-NULL) take precedence; weights_idx selects the
+       block-sparse 8x4 kernel, otherwise the dense column-major kernel is used;
+     - otherwise the 8-bit weights (with per-output scale) are used, again
+       sparse when weights_idx is non-NULL;
+     - with no weights at all the product is zero and out is just the bias.
+   On USE_SU_BIAS builds the 8-bit path adds subias instead of bias.
+   A NULL bias (or subias on that path) adds nothing.
+   The diag term is only meaningful for GRU recurrent weights: it requires
+   nb_outputs == 3*nb_inputs and adds diag[k*M+i]*in[i] to gate slice k. */
+void compute_linear(const LinearLayer *linear, float *out, const float *in)
+{
+   int i, k;
+   int nb_in, nb_out;
+   const float *b;
+   nb_in = linear->nb_inputs;
+   nb_out = linear->nb_outputs;
+   b = linear->bias;
+   if (linear->float_weights != NULL) {
+      if (linear->weights_idx == NULL) {
+         sgemv16x1(out, linear->float_weights, nb_out, nb_in, nb_out, in);
+      } else {
+         sparse_sgemv8x4(out, linear->float_weights, linear->weights_idx, nb_out, in);
+      }
+   } else if (linear->weights != NULL) {
+      if (linear->weights_idx == NULL) {
+         cgemv8x4(out, linear->weights, linear->scale, nb_out, nb_in, in);
+      } else {
+         sparse_cgemv8x4(out, linear->weights, linear->weights_idx, linear->scale, nb_out, nb_in, in);
+      }
+#ifdef USE_SU_BIAS
+      /* SU biases are only used for the integer matrices, and only on SU archs. */
+      b = linear->subias;
+#endif
+   } else {
+      /* Weightless layer: the output reduces to the bias. */
+      OPUS_CLEAR(out, nb_out);
+   }
+   if (b != NULL) {
+      for (i=0;i<nb_out;i++) out[i] += b[i];
+   }
+   if (linear->diag != NULL) {
+      /* Diag is only used for GRU recurrent weights: one diagonal per gate slice. */
+      celt_assert(3*nb_in == nb_out);
+      for (k=0;k<3;k++) {
+         for (i=0;i<nb_in;i++) out[k*nb_in + i] += linear->diag[k*nb_in + i]*in[i];
+      }
+   }
+}
+
+#define MAX_RNN_NEURONS_ALL IMAX(IMAX(MAX_RNN_NEURONS, PLC_MAX_RNN_NEURONS), DRED_MAX_RNN_NEURONS)
+
+
+/* One GRU step expressed with two LinearLayers (input side and recurrent side).
+   Both layers produce 3*N values laid out as [z | r | h]:
+     z, r  = sigmoid(input_part + recurrent_part)          (first 2*N entries)
+     h~    = tanh(input_part_h + r .* recurrent_part_h)
+     state = z .* state + (1 - z) .* h~
+   Biases are whatever compute_linear() adds for each layer.
+   state (N = recurrent_weights->nb_inputs values) is updated in place;
+   in must not alias state. */
+void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in)
+{
+  int i;
+  int N;
+  float zrh[3*MAX_RNN_NEURONS_ALL];
+  float recur[3*MAX_RNN_NEURONS_ALL];
+  float *z, *r, *h;
+  celt_assert(3*recurrent_weights->nb_inputs == recurrent_weights->nb_outputs);
+  celt_assert(input_weights->nb_outputs == recurrent_weights->nb_outputs);
+  celt_assert(recurrent_weights->nb_outputs <= 3*MAX_RNN_NEURONS_ALL);
+  celt_assert(in != state);
+  N = recurrent_weights->nb_inputs;
+  z = zrh;
+  r = z + N;
+  h = r + N;
+  /* Input-side and recurrent-side contributions for all three gates. */
+  compute_linear(input_weights, zrh, in);
+  compute_linear(recurrent_weights, recur, state);
+  /* Update and reset gates. */
+  for (i=0;i<2*N;i++) zrh[i] += recur[i];
+  compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
+  /* Candidate state: the reset gate scales the recurrent product (after the matrix multiply). */
+  for (i=0;i<N;i++) h[i] += recur[2*N+i]*r[i];
+  compute_activation(h, h, N, ACTIVATION_TANH);
+  /* Interpolate between the previous state and the candidate, in place. */
+  for (i=0;i<N;i++) state[i] = z[i]*state[i] + (1-z[i])*h[i];
+}
+
 void compute_activation(float *output, const float *input, int N, int activation)
 {
    int i;
@@ -119,6 +186,7 @@
    }
 }
 
+#if 1
 void _lpcnet_compute_dense(const DenseLayer *layer, float *output, const float *input)
 {
    int i;
@@ -133,8 +201,25 @@
    sgemv_accum(output, layer->input_weights, N, M, stride, input);
    compute_activation(output, output, N, layer->activation);
 }
+#else
+/* Dense layer expressed as a float, non-sparse LinearLayer followed by the
+   layer's activation. input and output must not alias. */
+void _lpcnet_compute_dense(const DenseLayer *layer, float *output, const float *input)
+{
+   LinearLayer lin;
+   int n_out;
+   celt_assert(input != output);
+   n_out = layer->nb_neurons;
+   lin.nb_inputs = layer->nb_inputs;
+   lin.nb_outputs = n_out;
+   lin.float_weights = layer->input_weights;
+   lin.bias = layer->bias;
+   lin.weights = NULL;
+   lin.weights_idx = NULL;
+   lin.subias = NULL;
+   lin.scale = NULL;
+   lin.diag = NULL;
+   compute_linear(&lin, output, input);
+   compute_activation(output, output, n_out, layer->activation);
+}
+#endif
 
-
 int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table, kiss99_ctx *rng)
 {
    int b, j, N, M, C, stride;
@@ -188,9 +273,15 @@
 
 }
 
+/* Name of the bias field that pairs with quantized weights: the SU bias on
+   USE_SU_BIAS builds, the regular bias otherwise. */
+#if defined(USE_SU_BIAS)
+#define bias_type subias
+#else
+#define bias_type bias
+#endif
+/* NOTE(review): not referenced in the code visible in this patch; presumably a
+   bound on sparse index tables — confirm before relying on it. */
+#define MAX_IDX_SIZE 8192
 
-#define MAX_RNN_NEURONS_ALL IMAX(IMAX(MAX_RNN_NEURONS, PLC_MAX_RNN_NEURONS), DRED_MAX_RNN_NEURONS)
 
+#if 1
 void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *state, const float *input)
 {
    int i;
@@ -239,7 +330,59 @@
       state[i] = h[i];
 }
 
+#else
 
+
+/* GRU-B step built on compute_generic_gru().
+   The input matrix (sparse when gru->input_weights_idx is set) has
+   gru_b_condition added to its bias; the recurrent matrix is dense.
+   The float/8-bit choice follows DOT_PROD, which is what selects qweight:
+   without DOT_PROD (DISABLE_DOT_PROD, or x86 builds lacking SSSE3/AVX2) the
+   weights are float, with it they are 8-bit and use a uniform SCALE_1 scale. */
+void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *state, const float *input)
+{
+  LinearLayer in_matrix, rec_matrix;
+  int i, M, N;
+  float bias[3*MAX_RNN_NEURONS_ALL];
+#ifdef DOT_PROD
+  float scale[3*MAX_RNN_NEURONS_ALL];
+#endif
+  M = gru->nb_inputs;
+  N = gru->nb_neurons;
+  /* bias[] (and scale[]) are fixed-size scratch buffers filled with 3*N values. */
+  celt_assert(N <= MAX_RNN_NEURONS_ALL);
+
+  in_matrix.bias = bias;
+  in_matrix.subias = bias;
+  in_matrix.diag = NULL;
+  in_matrix.nb_inputs = M;
+  in_matrix.nb_outputs = 3*N;
+  in_matrix.weights_idx = gru->input_weights_idx;
+
+  rec_matrix.bias = &gru->bias[3*N];
+  /* NOTE(review): guard in case subias is NULL on builds that do not provide it;
+     offsetting a NULL pointer is undefined. */
+  rec_matrix.subias = gru->subias != NULL ? &gru->subias[3*N] : NULL;
+  rec_matrix.diag = NULL;
+  rec_matrix.nb_inputs = N;
+  rec_matrix.nb_outputs = 3*N;
+  rec_matrix.weights_idx = NULL;
+
+#ifdef DOT_PROD
+  for (i=0;i<3*N;i++) bias[i] = gru->bias_type[i] + gru_b_condition[i];
+  for (i=0;i<3*N;i++) scale[i] = SCALE_1;
+  in_matrix.scale = scale;
+  in_matrix.weights = gru->input_weights;
+  in_matrix.float_weights = NULL;
+  rec_matrix.scale = scale;
+  rec_matrix.weights = gru->recurrent_weights;
+  rec_matrix.float_weights = NULL;
+#else
+  for (i=0;i<3*N;i++) bias[i] = gru->bias[i] + gru_b_condition[i];
+  in_matrix.scale = NULL;
+  in_matrix.weights = NULL;
+  in_matrix.float_weights = gru->input_weights;
+  rec_matrix.scale = NULL;
+  rec_matrix.weights = NULL;
+  rec_matrix.float_weights = gru->recurrent_weights;
+#endif
+  compute_generic_gru(&in_matrix, &rec_matrix, state, input);
+}
+#endif
+
+
+#if 1
 /* The input of this GRU is after the input matrix multiply. */
 void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input)
 {
@@ -280,9 +423,49 @@
    for (i=0;i<N;i++)
       state[i] = z[i]*state[i] + (1-z[i])*h[i];
 }
+#else
+/* The input of this GRU is after the input matrix multiply. */
+/* Sparse GRU step built on compute_generic_gru().
+   `input` already holds the 3*N input-side gate values, so the input "layer"
+   has no weights and compute_linear() just yields its bias, i.e. a copy of input.
+   The recurrent matrix is block-sparse (gru->idx) with a per-gate diagonal
+   (gru->diag_weights). The float/8-bit choice follows DOT_PROD, which is what
+   selects qweight (it is unset on x86 builds lacking SSSE3/AVX2 even when
+   DISABLE_DOT_PROD is not defined). */
+void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input)
+{
+  LinearLayer in_matrix, rec_matrix;
+  int N;
+#ifdef DOT_PROD
+  int i;
+  float scale[3*MAX_RNN_NEURONS_ALL];
+#endif
+  N = gru->nb_neurons;
+  celt_assert(N <= MAX_RNN_NEURONS_ALL);
+
+  /* Weightless input layer: out = 0 + bias = input. */
+  in_matrix.bias = input;
+  in_matrix.subias = input;
+  in_matrix.diag = NULL;
+  in_matrix.nb_inputs = N;
+  in_matrix.nb_outputs = 3*N;
+  in_matrix.scale = NULL;
+  in_matrix.float_weights = NULL;
+  in_matrix.weights = NULL;
+  in_matrix.weights_idx = NULL;
+
+  rec_matrix.bias = &gru->bias[3*N];
+  /* NOTE(review): guard in case subias is NULL on builds that do not provide it. */
+  rec_matrix.subias = gru->subias != NULL ? &gru->subias[3*N] : NULL;
+  rec_matrix.diag = gru->diag_weights;
+  rec_matrix.nb_inputs = N;
+  rec_matrix.nb_outputs = 3*N;
+  rec_matrix.weights_idx = gru->idx;
+#ifdef DOT_PROD
+  for (i=0;i<3*N;i++) scale[i] = SCALE_1;
+  rec_matrix.scale = scale;
+  rec_matrix.weights = gru->recurrent_weights;
+  rec_matrix.float_weights = NULL;
+#else
+  rec_matrix.scale = NULL;
+  rec_matrix.weights = NULL;
+  rec_matrix.float_weights = gru->recurrent_weights;
+#endif
+  compute_generic_gru(&in_matrix, &rec_matrix, state, input);
+}
+#endif
+
+
 #define MAX_CONV_INPUTS_ALL IMAX(MAX_CONV_INPUTS, DRED_MAX_CONV_INPUTS)
 
+#if 1
 void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
 {
    int i;
@@ -302,6 +485,32 @@
    compute_activation(output, output, N, layer->activation);
    OPUS_COPY(mem, &tmp[layer->nb_inputs], layer->nb_inputs*(layer->kernel_size-1));
 }
+#else
+/* 1D convolution as a dense float LinearLayer applied to the window
+   [history | current frame]. mem holds the previous nb_inputs*(kernel_size-1)
+   input values and is advanced by one frame before returning. */
+void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
+{
+   LinearLayer matrix;
+   int N, M;
+   int hist;
+   float tmp[MAX_CONV_INPUTS_ALL];
+   celt_assert(input != output);
+   celt_assert(layer->nb_inputs*layer->kernel_size <= MAX_CONV_INPUTS_ALL);
+   hist = layer->nb_inputs*(layer->kernel_size-1);
+   M = hist + layer->nb_inputs;
+   N = layer->nb_neurons;
+   /* Build the window: stored history followed by the new frame. */
+   OPUS_COPY(tmp, mem, hist);
+   OPUS_COPY(tmp + hist, input, layer->nb_inputs);
+   matrix.nb_inputs = M;
+   matrix.nb_outputs = N;
+   matrix.float_weights = layer->input_weights;
+   matrix.bias = layer->bias;
+   matrix.weights = NULL;
+   matrix.weights_idx = NULL;
+   matrix.subias = NULL;
+   matrix.scale = NULL;
+   matrix.diag = NULL;
+   compute_linear(&matrix, output, tmp);
+   compute_activation(output, output, N, layer->activation);
+   /* New history = window without its oldest frame. */
+   OPUS_COPY(mem, tmp + layer->nb_inputs, hist);
+}
+#endif
 
 void compute_embedding(const EmbeddingLayer *layer, float *output, int input)
 {
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -61,6 +61,18 @@
   char name[44];
 } WeightHead;
 
+/* Generic sparse affine transformation, evaluated by compute_linear():
+   out = W*in + bias (+ optional diagonal term). */
+typedef struct {
+  const float *bias;          /* nb_outputs values added to the product; NULL for none. */
+  const float *subias;        /* Replaces bias on the 8-bit path on USE_SU_BIAS builds. */
+  const opus_int8 *weights;   /* 8-bit weights in 8x4 blocks; used only when float_weights is NULL. */
+  const float *float_weights; /* Float weights; take precedence over weights. Dense ones are
+                                 column-major with nb_outputs values per input column. */
+  const int *weights_idx;     /* Block-sparse index: per 8 output rows, a block count followed by
+                                 each block's first input column. NULL means dense. */
+  const float *diag;          /* Optional per-gate diagonal, GRU recurrent weights only
+                                 (requires nb_outputs == 3*nb_inputs). */
+  const float *scale;         /* Per-output scale applied to the 8-bit products. */
+  int nb_inputs;
+  int nb_outputs;
+} LinearLayer;
 
 typedef struct {
   const float *bias;
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -56,6 +56,230 @@
 typedef float qweight;
 #endif
 
+/* Portable dense float matrix-vector product: out = W*x, where W is column-major
+   (element (row i, column j) at weights[j*col_stride + i]). Rows are handled in
+   groups of 16, so rows is expected to be a multiple of 16. */
+static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int row0, col, k;
+   OPUS_CLEAR(out, rows);
+   for (row0=0;row0<rows;row0+=16)
+   {
+      float *acc = &out[row0];
+      for (col=0;col<cols;col++)
+      {
+         const float *wcol = &weights[col*col_stride + row0];
+         const float xv = x[col];
+         for (k=0;k<16;k++) acc[k] += wcol[k]*xv;
+      }
+   }
+}
+
+/* Portable block-sparse float matrix-vector product: out = W*x.
+   For each group of 8 output rows, idx supplies the number of non-zero blocks
+   followed by each block's first input column. Each block holds 32 weights as
+   4 consecutive input columns of 8 row weights each. */
+static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
+{
+   int row0, blk, c, k;
+   OPUS_CLEAR(out, rows);
+   for (row0=0;row0<rows;row0+=8)
+   {
+      float *acc = &out[row0];
+      int nblocks = *idx++;
+      for (blk=0;blk<nblocks;blk++)
+      {
+         const float *xb = &x[*idx++];
+         for (c=0;c<4;c++)
+         {
+            const float xv = xb[c];
+            for (k=0;k<8;k++) acc[k] += w[8*c + k]*xv;
+         }
+         w += 32;
+      }
+   }
+}
+
+#ifdef USE_SU_BIAS
+/* Portable block-sparse 8-bit matrix-vector product (unsigned-input variant):
+   out[i] = scale[i] * sum_j W[i][j]*q(_x[j]), with q(x) = 127 + floor(.5 + 127*x)
+   stored as unsigned char. Per 8 output rows, idx gives a block count followed by
+   each block's first input column; each block is 32 int8 weights, 4 per row
+   (row r uses w[4r..4r+3]). Assumes cols <= MAX_INPUTS (not checked here). */
+static inline void sparse_cgemv8x4(float *out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
+{
+   int i, j;
+   unsigned char x[MAX_INPUTS];
+   for (i=0;i<rows;i++) out[i] = 0;
+   /* Convert through int, as cgemv8x4 does: converting an out-of-range double
+      directly to unsigned char is undefined behavior, int -> unsigned char is not. */
+   for (i=0;i<cols;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
+   for (i=0;i<rows;i+=8)
+   {
+      int colblocks;
+      colblocks = *idx++;
+      for (j=0;j<colblocks;j++)
+      {
+         int pos;
+         float * restrict y;
+         int xj0, xj1, xj2, xj3;
+         pos = (*idx++);
+         xj0 = x[pos+0];
+         xj1 = x[pos+1];
+         xj2 = x[pos+2];
+         xj3 = x[pos+3];
+         y = &out[i];
+         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+         w += 32;
+      }
+   }
+   for (i=0;i<rows;i++) out[i] *= scale[i];
+}
+/* Portable dense 8-bit matrix-vector product (unsigned-input variant):
+   out[i] = scale[i] * sum_j W[i][j]*q(_x[j]), with q(x) = 127 + (int)floor(.5 + 127*x).
+   Weights are consumed 32 at a time (8 rows x 4 consecutive columns, 4 weights per
+   row), columns advancing fastest within each group of 8 rows.
+   Assumes rows % 8 == 0, cols % 4 == 0 and cols <= MAX_INPUTS. */
+static inline void cgemv8x4(float *out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
+{
+   int row0, col, r;
+   unsigned char x[MAX_INPUTS];
+   for (r=0;r<rows;r++) out[r] = 0;
+   for (col=0;col<cols;col++) x[col] = 127+(int)floor(.5+127*_x[col]);
+   for (row0=0;row0<rows;row0+=8)
+   {
+      float *acc = &out[row0];
+      for (col=0;col<cols;col+=4)
+      {
+         const float x0 = x[col], x1 = x[col+1], x2 = x[col+2], x3 = x[col+3];
+         for (r=0;r<8;r++)
+         {
+            const opus_int8 *wr = &w[4*r];
+            acc[r] += (wr[0]*x0+wr[1]*x1+wr[2]*x2+wr[3]*x3);
+         }
+         w += 32;
+      }
+   }
+   for (r=0;r<rows;r++) out[r] *= scale[r];
+}
+#else
+/* Portable block-sparse 8-bit matrix-vector product (signed-input variant):
+   out[i] = scale[i] * sum_j W[i][j]*q(_x[j]), with q(x) = (int)floor(.5 + 127*x)
+   stored as opus_int8. Per 8 output rows, idx gives a block count followed by
+   each block's first input column; each block is 32 int8 weights, 4 per row
+   (row r uses w[4r..4r+3]). Assumes cols <= MAX_INPUTS. */
+static inline void sparse_cgemv8x4(float *out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
+{
+   int row0, blk, r, c;
+   opus_int8 x[MAX_INPUTS];
+   for (r=0;r<rows;r++) out[r] = 0;
+   for (c=0;c<cols;c++) x[c] = (int)floor(.5+127*_x[c]);
+   for (row0=0;row0<rows;row0+=8)
+   {
+      float *acc = &out[row0];
+      int nblocks = *idx++;
+      for (blk=0;blk<nblocks;blk++)
+      {
+         const opus_int8 *xb = &x[*idx++];
+         int x0 = xb[0], x1 = xb[1], x2 = xb[2], x3 = xb[3];
+         for (r=0;r<8;r++)
+         {
+            const opus_int8 *wr = &w[4*r];
+            acc[r] += (wr[0]*x0+wr[1]*x1+wr[2]*x2+wr[3]*x3);
+         }
+         w += 32;
+      }
+   }
+   for (r=0;r<rows;r++) out[r] *= scale[r];
+}
+/* Portable dense 8-bit matrix-vector product (signed-input variant):
+   out[i] = scale[i] * sum_j W[i][j]*q(_x[j]), with q(x) = (int)floor(.5 + 127*x)
+   stored as opus_int8. Weights are consumed 32 at a time (8 rows x 4 consecutive
+   columns, 4 weights per row), columns advancing fastest within each group of
+   8 rows. Assumes rows % 8 == 0, cols % 4 == 0 and cols <= MAX_INPUTS. */
+static inline void cgemv8x4(float *out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
+{
+   int row0, col, r;
+   opus_int8 x[MAX_INPUTS];
+   for (r=0;r<rows;r++) out[r] = 0;
+   for (col=0;col<cols;col++) x[col] = (int)floor(.5+127*_x[col]);
+   for (row0=0;row0<rows;row0+=8)
+   {
+      float *acc = &out[row0];
+      for (col=0;col<cols;col+=4)
+      {
+         const float x0 = x[col], x1 = x[col+1], x2 = x[col+2], x3 = x[col+3];
+         for (r=0;r<8;r++)
+         {
+            const opus_int8 *wr = &w[4*r];
+            acc[r] += (wr[0]*x0+wr[1]*x1+wr[2]*x2+wr[3]*x3);
+         }
+         w += 32;
+      }
+   }
+   for (r=0;r<rows;r++) out[r] *= scale[r];
+}
+#endif
 
 /* No AVX2/FMA support */
 #ifndef LPCNET_TEST
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -35,6 +35,10 @@
 #include <immintrin.h>
 #include <math.h>
 
+
+/* Capacity of the on-stack quantized-input buffers (x[MAX_INPUTS]) in the 8-bit
+   gemv kernels below; those kernels fill cols entries, so cols must not exceed it. */
+#define MAX_INPUTS (2048)
+
+
 /* Use 8-bit dot products unless disabled or if stuck with SSE2. */
 #if (defined(__AVX2__) || defined(__SSSE3__)) && !defined(DISABLE_DOT_PROD)
 #define DOT_PROD
@@ -673,13 +677,209 @@
    }
 }
 
+/* AVX dense float matrix-vector product: out = W*x with W column-major
+   (element (row i, column j) at weights[j*col_stride + i]). Each group of 16 rows
+   is accumulated in two 8-lane registers and stored once, so out need not be
+   cleared first; rows is expected to be a multiple of 16. */
+static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int row0, col;
+   for (row0=0;row0<rows;row0+=16)
+   {
+      __m256 acc_lo = _mm256_setzero_ps();
+      __m256 acc_hi = _mm256_setzero_ps();
+      for (col=0;col<cols;col++)
+      {
+         const float *wcol = &weights[col*col_stride + row0];
+         const __m256 xv = _mm256_broadcast_ss(&x[col]);
+         acc_lo = _mm256_fmadd_ps(_mm256_loadu_ps(wcol), xv, acc_lo);
+         acc_hi = _mm256_fmadd_ps(_mm256_loadu_ps(wcol + 8), xv, acc_hi);
+      }
+      _mm256_storeu_ps(&out[row0], acc_lo);
+      _mm256_storeu_ps(&out[row0 + 8], acc_hi);
+   }
+}
+
+/* AVX block-sparse float matrix-vector product: out = W*x.
+   Per 8 output rows, idx gives a block count followed by each block's first input
+   column; each block holds 32 weights as 4 consecutive input columns of 8 row
+   weights. The 8 outputs are accumulated in one register and stored once. */
+static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
+{
+   int row0, blk, c;
+   for (row0=0;row0<rows;row0+=8)
+   {
+      __m256 acc = _mm256_setzero_ps();
+      int nblocks = *idx++;
+      for (blk=0;blk<nblocks;blk++)
+      {
+         const float *xb = &x[*idx++];
+         for (c=0;c<4;c++)
+         {
+            acc = _mm256_fmadd_ps(_mm256_loadu_ps(&weights[8*c]), _mm256_broadcast_ss(&xb[c]), acc);
+         }
+         weights += 32;
+      }
+      _mm256_storeu_ps(&out[row0], acc);
+   }
+}
+
+/* AVX2/SSSE3 block-sparse 8-bit matrix-vector product:
+   _out[i] = scale[i] * sum_j W[i][j]*q(_x[j]).
+   Inputs are quantized to unsigned bytes by vector_ps_to_epi8() (the commented-out
+   scalar reference gives q(x) = 127 + floor(.5 + 127*x)), matching the unsigned
+   first operand of _mm256_maddubs_epi16.
+   Per 8 output rows, idx gives a block count followed by each block's first input
+   column; each block is 32 int8 weights, 4 per row, so 32-bit lane k of the
+   accumulator ends up holding row i+k. Assumes cols <= MAX_INPUTS. */
+static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
+{
+   __m256i ones;
+   int i, j;
+   unsigned char x[MAX_INPUTS];
+   ones = _mm256_set1_epi16(1);
+   /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
+   vector_ps_to_epi8(x, _x, cols);
+   for (i=0;i<rows;i+=8)
+   {
+      int colblocks;
+      __m256i vy0;
+      __m256 vout;
+      colblocks = *idx++;
+      vy0 = _mm256_setzero_si256();
+      j=0;
+#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
+      /* Four blocks per iteration; the loop below handles the remaining 0-3. */
+      for (;j<colblocks-3;j+=4)
+      {
+         __m256i tmp;
+         __m256i vxj;
+         __m256i vw;
+         /* Broadcast the block's 4 quantized inputs to every 32-bit lane.
+            NOTE(review): reads 4 bytes of x through an int*; this relies on the
+            compiler tolerating the type pun (strict aliasing / alignment). */
+         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         /* u8*s8 products, adjacent pairs summed into 16 bits (saturating;
+            presumably the weights are constrained so this cannot saturate — TODO
+            confirm), then pairs summed into 32 bits: one 4-term dot product per row. */
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+      }
+#endif
+      /* Remaining blocks, one at a time. */
+      for (;j<colblocks;j++)
+      {
+         __m256i tmp;
+         __m256i vxj;
+         __m256i vw;
+         int pos;
+         pos = (*idx++);
+         vxj = _mm256_set1_epi32(*(int*)&x[pos]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+      }
+      /* Integer row sums -> float, then apply the per-row scale. */
+      vout = _mm256_cvtepi32_ps(vy0);
+      vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
+      _mm256_storeu_ps(&_out[i], vout);
+   }
+}
+/* AVX2/SSSE3 dense 8-bit matrix-vector product:
+   _out[i] = scale[i] * sum_j W[i][j]*q(_x[j]).
+   Inputs are quantized to unsigned bytes by vector_ps_to_epi8() (the commented-out
+   scalar reference gives q(x) = 127 + floor(.5 + 127*x)). Weights are consumed 32 at
+   a time (8 rows x 4 consecutive columns, 4 weights per row), columns advancing
+   fastest within each group of 8 rows. Assumes rows % 8 == 0, cols % 4 == 0 and
+   cols <= MAX_INPUTS. */
+static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
+{
+   __m256i ones;
+   int i, j;
+   unsigned char x[MAX_INPUTS];
+   ones = _mm256_set1_epi16(1);
+   /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
+   vector_ps_to_epi8(x, _x, cols);
+   for (i=0;i<rows;i+=8)
+   {
+      __m256i vy0;
+      __m256 vout;
+      vy0 = _mm256_setzero_si256();
+      j=0;
+#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
+      /* 16 columns (four 8x4 blocks) per iteration; with cols a multiple of 4,
+         j<cols-12 guarantees x[j..j+15] are in range. */
+      for (;j<cols-12;j+=16)
+      {
+         __m256i tmp;
+         __m256i vxj;
+         __m256i vw;
+         /* NOTE(review): 4-byte read of x through an int* (type pun), as in the
+            sparse kernel. */
+         vxj = _mm256_set1_epi32(*(int*)&x[j]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         /* u8*s8 pair sums (16-bit, saturating), then pairs to 32 bits: lane k
+            accumulates row i+k. */
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[j+4]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[j+8]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[j+12]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+      }
+#endif
+      /* Tail: remaining 4-column blocks. */
+      for (;j<cols;j+=4)
+      {
+         __m256i tmp;
+         __m256i vxj;
+         __m256i vw;
+         vxj = _mm256_set1_epi32(*(int*)&x[j]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+      }
+      /* Integer row sums -> float, then apply the per-row scale. */
+      vout = _mm256_cvtepi32_ps(vy0);
+      vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
+      _mm256_storeu_ps(&_out[i], vout);
+   }
+}
+
+
 #ifdef DOT_PROD
 #define USE_SU_BIAS
 
 typedef signed char qweight;
 
-
-#define MAX_INPUTS (2048)
 #define MAX_OUTPUTS (8192)
 
 
--- a/dnn/vec_neon.h
+++ b/dnn/vec_neon.h
@@ -43,6 +43,9 @@
 typedef float qweight;
 #endif
 
+/* Plain-C float kernels used by compute_linear() on NEON builds (not NEON-optimized).
+   They compute the full result rather than being empty stubs, so out is never left
+   uninitialized, and every parameter is named as C (before C23) requires in a
+   function definition. */
+
+/* out = W*x with W column-major: element (row i, column j) at weights[j*col_stride + i]. */
+static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int i, j;
+   for (i=0;i<rows;i++) out[i] = 0;
+   for (j=0;j<cols;j++)
+   {
+      const float *w = &weights[j*col_stride];
+      float xj = x[j];
+      for (i=0;i<rows;i++) out[i] += w[i]*xj;
+   }
+}
+
+/* Block-sparse out = W*x. Per 8 output rows, idx gives a block count followed by each
+   block's first input column; each block stores 4 input columns of 8 row weights. */
+static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
+{
+   int i, j, k, r;
+   for (i=0;i<rows;i++) out[i] = 0;
+   for (i=0;i<rows;i+=8)
+   {
+      int colblocks = *idx++;
+      for (j=0;j<colblocks;j++)
+      {
+         int pos = *idx++;
+         for (k=0;k<4;k++)
+         {
+            float xk = x[pos+k];
+            for (r=0;r<8;r++) out[i+r] += w[8*k+r]*xk;
+         }
+         w += 32;
+      }
+   }
+}
 
 #ifndef LPCNET_TEST
 static inline float32x4_t exp4_approx(float32x4_t x) {
@@ -294,6 +297,76 @@
   return vpadalq_s16(acc, vpaddq_s16(vmull_s8(vget_low_s8(a), vget_low_s8(b)),  vmull_high_s8(a, b)));
 }
 #endif
+
+/* NEON dense 8-bit matrix-vector product:
+   _out[i] = scale[i] * sum_j W[i][j]*q(_x[j]), with q(x) = 127*x rounded to nearest
+   (vcvtnq) and narrowed to int8 (vmovn keeps the low bits, so |x| is presumably
+   expected to stay within [-1, 1] — confirm).
+   Weights are consumed 32 at a time (8 rows x 4 consecutive columns, 4 per row),
+   columns advancing fastest within each group of 8 rows; the first 16 bytes feed
+   rows i..i+3 (acc0) and the next 16 rows i+4..i+7 (acc1). */
+static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
+{
+   int i, j;
+   opus_int8 x[MAX_INPUTS];
+   const float32x4_t const127 = vdupq_n_f32(127.);
+   /* Quantize 8 inputs per iteration. NOTE(review): reads _x in groups of 8, so cols
+      is presumably a multiple of 8 (or _x padded) — confirm. */
+   for (i=0;i<cols;i+=8) {
+      int32x4_t xi0, xi4;
+      int16x8_t x_short;
+      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
+      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
+      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
+      vst1_s8(&x[i], vmovn_s16(x_short));
+   }
+   for (i=0;i<rows;i+=8)
+   {
+      int32x4_t acc0, acc1;
+      acc0 = vdupq_n_s32(0);
+      acc1 = vdupq_n_s32(0);
+      for (j=0;j<cols;j+=4)
+      {
+         int8x16_t vw0, vw1, vx;
+         /* Broadcast the 4 quantized inputs x[j..j+3] to all four 32-bit lanes. */
+         vx = (int8x16_t)vld1q_dup_s32((int*)&x[j]);
+         vw0 = vld1q_s8(w);
+         vw1 = vld1q_s8(&w[16]);
+         acc0 = vdotprod(acc0, vw0, vx);
+         acc1 = vdotprod(acc1, vw1, vx);
+         w += 32;
+      }
+      /* Integer row sums -> float, then apply the per-row scale. */
+      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
+      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
+   }
+}
+
+/* NEON block-sparse 8-bit matrix-vector product:
+   _out[i] = scale[i] * sum_j W[i][j]*q(_x[j]), with q(x) = 127*x rounded to nearest
+   (vcvtnq) and narrowed to int8 (vmovn keeps the low bits, so |x| is presumably
+   expected to stay within [-1, 1] — confirm).
+   Per 8 output rows, idx gives a block count followed by each block's first input
+   column; each block is 32 int8 weights whose first 16 bytes feed rows i..i+3 and
+   next 16 bytes rows i+4..i+7. */
+static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
+{
+   int i, j;
+   opus_int8 x[MAX_INPUTS];
+   const float32x4_t const127 = vdupq_n_f32(127.);
+   /* Quantize 8 inputs per iteration. NOTE(review): reads _x in groups of 8, so cols
+      is presumably a multiple of 8 (or _x padded) — confirm. */
+   for (i=0;i<cols;i+=8) {
+      int32x4_t xi0, xi4;
+      int16x8_t x_short;
+      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
+      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
+      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
+      vst1_s8(&x[i], vmovn_s16(x_short));
+   }
+   for (i=0;i<rows;i+=8)
+   {
+      int colblocks;
+      int32x4_t acc0, acc1;
+      acc0 = vdupq_n_s32(0);
+      acc1 = vdupq_n_s32(0);
+      colblocks = *idx++;
+      for (j=0;j<colblocks;j++)
+      {
+         /* All declarations precede statements, keeping this valid C89 like the
+            rest of the file. */
+         int pos;
+         int8x16_t vw0, vw1, vx;
+         pos = (*idx++);
+         /* Broadcast the block's 4 quantized inputs to all four 32-bit lanes. */
+         vx = (int8x16_t)vld1q_dup_s32((int*)&x[pos]);
+         vw0 = vld1q_s8(w);
+         vw1 = vld1q_s8(&w[16]);
+         acc0 = vdotprod(acc0, vw0, vx);
+         acc1 = vdotprod(acc1, vw1, vx);
+         w += 32;
+      }
+      /* Integer row sums -> float, then apply the per-row scale. */
+      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
+      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
+   }
+}
 
 static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
 {
--