ref: d7f0abcd1978f97ec70b352538d31d51b3e936e7
parent: faf3fe3d24397e0c3e53b5c05d17949da8e76659
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Thu Nov 29 15:09:36 EST 2018
Delaying the softmax() to avoid the pow(). Now at 5x real-time, with all the low-hanging fruit done.
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -39,6 +39,8 @@
#include "nnet.h"
#include "nnet_data.h"
+#define SOFTMAX_HACK
+
#ifdef __AVX2__
#include <immintrin.h>
static __m256 exp8_approx(__m256 X)
@@ -340,6 +342,10 @@
for (i=0;i<N;i++)
output[i] = relu(input[i]);
} else if (activation == ACTIVATION_SOFTMAX) {
+#ifdef SOFTMAX_HACK
+ for (i=0;i<N;i++)
+ output[i] = input[i];
+#else
float sum = 0;
softmax(output, input, N);
for (i=0;i<N;i++) {
@@ -348,6 +354,7 @@
sum = 1.f/(sum+1e-30);
for (i=0;i<N;i++)
output[i] = sum*output[i];
+#endif
} else {
celt_assert(activation == ACTIVATION_LINEAR);
for (i=0;i<N;i++)
@@ -619,6 +626,17 @@
float tmp[DUAL_FC_OUT_SIZE];
celt_assert(N <= DUAL_FC_OUT_SIZE);
sum = 0;
+#ifdef SOFTMAX_HACK
+ for (i=0;i<N;i++)
+ {
+ tmp[i] = pdf[i] * (1.f+exp_boost);
+ }
+ softmax(tmp, tmp, N);
+ for (i=0;i<N;i++)
+ {
+ sum += tmp[i];
+ }
+#else
/* Decrease the temperature of the sampling. */
for (i=0;i<N;i++)
{
@@ -625,6 +643,7 @@
tmp[i] = pow(pdf[i], 1.f+exp_boost);
sum += tmp[i];
}
+#endif
norm = 1.f/sum;
/* Convert tmp to a CDF while subtracting the floor */
tmp[0] = MAX16(0, norm*tmp[0] - pdf_floor);
--
⑨