ref: e35441f2cce59cb0a8dc1dbed5755ed1ce43f8ca
parent: 5571ef1b8ebce4bb698c249bf9f63e4d076efc9b
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Tue Jun 29 00:05:48 EDT 2021
Faster activation functions for AVX

Using rational function approximation for tanh() and sigmoid.
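
As a sanity check, the tanh() approximation can be reproduced with a few
lines of NumPy (a rough sketch, not part of the patch; the constants are the
ones hard-coded in vec_avx.h below):

    import numpy as np
    # tanh(x) ~= x*(N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4), clamped to +/-1
    N0, N1, N2 = 952.52801514, 96.39235687, 0.60863042
    D0, D1, D2 = 952.72399902, 413.36801147, 11.88600922
    x = np.linspace(-10, 10, 200001)
    x2 = x*x
    y = np.clip(x*(N0 + N1*x2 + N2*x2*x2) / (D0 + D1*x2 + D2*x2*x2), -1., 1.)
    print(np.abs(y - np.tanh(x)).max())  # ~6e-5; _mm256_rcp_ps adds the rest
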
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -80,8 +80,9 @@
output[i] = relu(input[i]);
} else if (activation == ACTIVATION_SOFTMAX) {
#ifdef SOFTMAX_HACK
- for (i=0;i<N;i++)
- output[i] = input[i];
+ RNN_COPY(output, input, N);
+ /*for (i=0;i<N;i++)
+ output[i] = input[i];*/
#else
float sum = 0;
softmax(output, input, N);
--- /dev/null
+++ b/dnn/training_tf2/pade.py
@@ -1,0 +1,70 @@
+# Optimizing the coefficients of a rational function approximation to tanh()
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
+import tensorflow.keras.backend as K
+from tensorflow.keras.optimizers import Adam, SGD
+
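+# Losses mixing the mean squared error with the per-sample max squared error;
+# each successive stage puts relatively more weight on the max term.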
+def my_loss1(y_true, y_pred):
+ return 1*K.mean(K.square(y_true-y_pred)) + 1*K.max(K.square(y_true-y_pred), axis=1)
+
+def my_loss2(y_true, y_pred):
+ return .1*K.mean(K.square(y_true-y_pred)) + 1*K.max(K.square(y_true-y_pred), axis=1)
+
+def my_loss3(y_true, y_pred):
+ return .01*K.mean(K.square(y_true-y_pred)) + 1*K.max(K.square(y_true-y_pred), axis=1)
+
+# Using these initializers to seed the approximation
+# with a reasonable starting point
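+# ([945, 105, 1] / [945, 420, 15] is the classic approximant
+# tanh(x) ~= x*(945 + 105*x^2 + x^4)/(945 + 420*x^2 + 15*x^4),
+# obtained by truncating the continued fraction of tanh().)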
+def num_init(shape, dtype=None):
+ rr = tf.constant([[945], [105], [1]], dtype=dtype)
+ #rr = tf.constant([[946.56757], [98.01368], [0.66841]], dtype=dtype)
+ print(rr)
+ return rr
+
+def den_init(shape, dtype=None):
+ rr = tf.constant([[945], [420], [15]], dtype=dtype)
+ #rr = tf.constant([[946.604], [413.342], [12.465]], dtype=dtype)
+ print(rr)
+ return rr
+
+
+x = np.arange(-10, 10, .01)
+N = len(x)
+x = np.reshape(x, (1, -1, 1))
+x2 = x*x
+
+x2in = np.concatenate([x2*0 + 1, x2, x2*x2], axis=2)
+yout = np.tanh(x)
+
+
+model_x = Input(shape=(None, 1,))
+model_x2 = Input(shape=(None, 3,))
+
+num = Dense(1, name='num', use_bias=False, kernel_initializer=num_init)
+den = Dense(1, name='den', use_bias=False, kernel_initializer=den_init)
+
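+# Clamp x*num(x^2)/den(x^2) to tanh()'s [-1, 1] output range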
+def ratio(x):
+ return tf.minimum(1., tf.maximum(-1., x[0]*x[1]/x[2]))
+
+out_layer = Lambda(ratio)
+output = out_layer([model_x, num(model_x2), den(model_x2)])
+
+model = Model([model_x, model_x2], output)
+model.summary()
+
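+# Train in stages: plain MSE first, then my_loss1..my_loss3, which weight the
+# max error more heavily, each stage with a smaller learning rate.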
+model.compile(Adam(0.05, beta_1=0.9, beta_2=0.9, decay=2e-5), loss='mean_squared_error')
+model.fit([x, x2in], yout, batch_size=1, epochs=500000, validation_split=0.0)
+
+model.compile(Adam(0.001, beta_2=0.9, decay=1e-4), loss=my_loss1)
+model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)
+
+model.compile(Adam(0.0001, beta_2=0.9, decay=1e-4), loss=my_loss2)
+model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)
+
+model.compile(Adam(0.00001, beta_2=0.9, decay=1e-4), loss=my_loss3)
+model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)
+
+model.save_weights('tanh.h5')
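
Note that the sigmoid coefficients in vec_avx.h below follow from the identity
sigmoid(x) = (1 + tanh(x/2))/2: substituting x/2 into the tanh() approximation
divides N0, N1, N2 by 4, 16, 64 and D1, D2 by 4, 16, with D0 unchanged. A
rough sketch of that derivation (not part of the patch):

    # tanh() coefficients as used in vec_avx.h
    n = [952.52801514, 96.39235687, 0.60863042]
    d = [952.72399902, 413.36801147, 11.88600922]
    # sigmoid(x) = 0.5 + 0.5*tanh(x/2): substitute y = x/2, collect powers of x
    sig_n = [n[0]/4, n[1]/16, n[2]/64]  # -> 238.132, 6.02452, 0.00950985
    sig_d = [d[0], d[1]/4, d[2]/16]     # -> 952.724, 103.342, 0.742876
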
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -39,6 +39,11 @@
#define USE_SU_BIAS
#endif
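+/* Without FMA, emulate fused multiply-add with a separate multiply and add. */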
+#ifndef __FMA__
+#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
+#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
+#endif
+
#ifdef __AVX2__
static inline __m256 exp8_approx(__m256 X)
{
@@ -61,9 +66,66 @@
Y = _mm256_castsi256_ps(_mm256_add_epi32(I, _mm256_castps_si256(Y)));
return Y;
}
+
+/* Approximating tanh() using a Padé-like rational function:
+ tanh(x) ~= x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
+ subject to the +/- 1 bounds.
+ The coefficients were determined by gradient descent trying to minimize
+ the maximum deviation over the whole range (this is only possible because
+ of the bounds). The max error is around 3e-4 and is dominated by the
+ reciprocal approximation (the max error of the rational function is
+ around 6e-5).
+ */
+static inline __m256 tanh8_approx(__m256 X)
+{
+ const __m256 N0 = _mm256_set1_ps(952.52801514f);
+ const __m256 N1 = _mm256_set1_ps(96.39235687f);
+ const __m256 N2 = _mm256_set1_ps(0.60863042f);
+ const __m256 D0 = _mm256_set1_ps(952.72399902f);
+ const __m256 D1 = _mm256_set1_ps(413.36801147f);
+ const __m256 D2 = _mm256_set1_ps(11.88600922f);
+ const __m256 max_out = _mm256_set1_ps(1.f);
+ const __m256 min_out = _mm256_set1_ps(-1.f);
+ __m256 X2, num, den;
+ X2 = _mm256_mul_ps(X, X);
+ num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
+ den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
+ num = _mm256_mul_ps(num, X);
+ den = _mm256_rcp_ps(den);
+ num = _mm256_mul_ps(num, den);
+ return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
+}
+
+/* Sigmoid approximation using a Padé-like rational function:
+ 1/(1+exp(-x)) ~= 0.5 + x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
+ subject to the [0, 1] bounds.
+ The coefficients are directly derived by dividing the tanh() coefficients
+ by powers of two to get the correct scaling. The max error is around 1.5e-4
+ and is dominated by the reciprocal approximation (the max error of the
+ rational function is around 3e-5).
+ */
+static inline __m256 sigmoid8_approx(__m256 X)
+{
+ const __m256 N0 = _mm256_set1_ps(238.13200378f);
+ const __m256 N1 = _mm256_set1_ps(6.02452230f);
+ const __m256 N2 = _mm256_set1_ps(0.00950985f);
+ const __m256 D0 = _mm256_set1_ps(952.72399902f);
+ const __m256 D1 = _mm256_set1_ps(103.34200287f);
+ const __m256 D2 = _mm256_set1_ps(0.74287558f);
+ const __m256 half = _mm256_set1_ps(0.5);
+ const __m256 max_out = _mm256_set1_ps(1.f);
+ const __m256 min_out = _mm256_set1_ps(0.f);
+ __m256 X2, num, den;
+ X2 = _mm256_mul_ps(X, X);
+ num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
+ den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
+ num = _mm256_mul_ps(num, X);
+ den = _mm256_rcp_ps(den);
+ num = _mm256_fmadd_ps(num, den, half);
+ return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
+}
+
#else
-#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
-#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
static inline __m128 exp4_approx(__m128 X)
{
const __m128 K0 = _mm_set1_ps(0.99992522f);
@@ -98,6 +160,48 @@
Y = _mm256_insertf128_ps(Y, Ylo, 0);
return Y;
}
+
+static inline __m128 tanh4_approx(__m128 X)
+{
+ const __m128 N0 = _mm_set1_ps(952.52801514f);
+ const __m128 N1 = _mm_set1_ps(96.39235687f);
+ const __m128 N2 = _mm_set1_ps(0.60863042f);
+ const __m128 D0 = _mm_set1_ps(952.72399902f);
+ const __m128 D1 = _mm_set1_ps(413.36801147f);
+ const __m128 D2 = _mm_set1_ps(11.88600922f);
+ const __m128 max_out = _mm_set1_ps(1.f);
+ const __m128 min_out = _mm_set1_ps(-1.f);
+ __m128 X2, num, den;
+ X2 = _mm_mul_ps(X, X);
+ num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
+ den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
+ num = _mm_mul_ps(num, X);
+ den = _mm_rcp_ps(den);
+ num = _mm_mul_ps(num, den);
+ return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
+}
+
+static inline __m128 sigmoid4_approx(__m128 X)
+{
+ const __m128 N0 = _mm_set1_ps(238.13200378f);
+ const __m128 N1 = _mm_set1_ps(6.02452230f);
+ const __m128 N2 = _mm_set1_ps(0.00950985f);
+ const __m128 D0 = _mm_set1_ps(952.72399902f);
+ const __m128 D1 = _mm_set1_ps(103.34200287f);
+ const __m128 D2 = _mm_set1_ps(0.74287558f);
+ const __m128 half = _mm_set1_ps(0.5);
+ const __m128 max_out = _mm_set1_ps(1.f);
+ const __m128 min_out = _mm_set1_ps(0.f);
+ __m128 X2, num, den;
+ X2 = _mm_mul_ps(X, X);
+ num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
+ den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
+ num = _mm_mul_ps(num, X);
+ den = _mm_rcp_ps(den);
+ num = _mm_fmadd_ps(num, den, half);
+ return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
+}
+
#endif
static inline float celt_exp(float x)
@@ -124,18 +228,15 @@
y[i] = celt_exp(x[i]);
}
+#ifdef __AVX2__
static inline void vec_tanh(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-7;i+=8)
{
- const __m256 two = _mm256_set1_ps(2.f);
- const __m256 one = _mm256_set1_ps(1.f);
__m256 X, Y;
X = _mm256_loadu_ps(&x[i]);
- X = _mm256_mul_ps(X, two);
- Y = exp8_approx(X);
- Y = _mm256_mul_ps(_mm256_sub_ps(Y, one), _mm256_rcp_ps(_mm256_add_ps(Y, one)));
+ Y = tanh8_approx(X);
_mm256_storeu_ps(&y[i], Y);
}
for (;i<N;i++)
@@ -151,12 +252,9 @@
int i;
for (i=0;i<N-7;i+=8)
{
- const __m256 one = _mm256_set1_ps(1.f);
__m256 X, Y;
X = _mm256_loadu_ps(&x[i]);
- Y = exp8_approx(X);
- /* Compute as 1-1/(1+e^x) to avoid >1 values caused by the reciprocal approximation. */
- Y = _mm256_sub_ps(one, _mm256_rcp_ps(_mm256_add_ps(Y, one)));
+ Y = sigmoid8_approx(X);
_mm256_storeu_ps(&y[i], Y);
}
for (;i<N;i++)
@@ -166,6 +264,44 @@
y[i] = (ex)/(ex+1);
}
}
+#else
+static inline void vec_tanh(float *y, const float *x, int N)
+{
+ int i;
+ for (i=0;i<N-3;i+=4)
+ {
+ __m128 X, Y;
+ X = _mm_loadu_ps(&x[i]);
+ Y = tanh4_approx(X);
+ _mm_storeu_ps(&y[i], Y);
+ }
+ for (;i<N;i++)
+ {
+ float ex2;
+ ex2 = celt_exp(2*x[i]);
+ y[i] = (ex2-1)/(ex2+1);
+ }
+}
+
+static inline void vec_sigmoid(float *y, const float *x, int N)
+{
+ int i;
+ for (i=0;i<N-3;i+=4)
+ {
+ __m128 X, Y;
+ X = _mm_loadu_ps(&x[i]);
+ Y = sigmoid4_approx(X);
+ _mm_storeu_ps(&y[i], Y);
+ }
+ for (;i<N;i++)
+ {
+ float ex;
+ ex = celt_exp(x[i]);
+ y[i] = (ex)/(ex+1);
+ }
+}
+
+#endif
static inline void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
--