shithub: opus

Download patch

ref: e35441f2cce59cb0a8dc1dbed5755ed1ce43f8ca
parent: 5571ef1b8ebce4bb698c249bf9f63e4d076efc9b
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Tue Jun 29 00:05:48 EDT 2021

Faster activation functions for AVX

Use rational-function approximations for tanh() and sigmoid().

--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -80,8 +80,9 @@
          output[i] = relu(input[i]);
    } else if (activation == ACTIVATION_SOFTMAX) {
 #ifdef SOFTMAX_HACK
-      for (i=0;i<N;i++)
-         output[i] = input[i];
+      /* Skip the copy when called in place (output == input): RNN_COPY
+         presumably expands to memcpy(), which is undefined for overlap. */
+      if (output != input) RNN_COPY(output, input, N);
 #else
       float sum = 0;
       softmax(output, input, N);
--- /dev/null
+++ b/dnn/training_tf2/pade.py
@@ -1,0 +1,91 @@
+# Fitting a rational (Pade-like) function to approximate tanh():
+#   tanh(x) ~= x * (N0 + N1*x^2 + N2*x^4) / (D0 + D1*x^2 + D2*x^4),
+# clipped to [-1, 1]. The six coefficients are trained by gradient descent,
+# first on the MSE and then on losses that increasingly weight the maximum
+# error over the sampled range.
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
+import tensorflow.keras.backend as K
+from tensorflow.keras.optimizers import Adam, SGD
+
+def _mean_max_loss(y_true, y_pred, mean_weight):
+    # mean_weight * (mean squared error) + worst squared error along axis 1,
+    # which is the sample axis of the single (1, N, 1) training sequence.
+    err2 = K.square(y_true-y_pred)
+    return mean_weight*K.mean(err2) + K.max(err2, axis=1)
+
+# Progressively shift the emphasis from the mean to the maximum error.
+def my_loss1(y_true, y_pred):
+    return _mean_max_loss(y_true, y_pred, 1.)
+
+def my_loss2(y_true, y_pred):
+    return _mean_max_loss(y_true, y_pred, .1)
+
+def my_loss3(y_true, y_pred):
+    return _mean_max_loss(y_true, y_pred, .01)
+
+# Seed the fit with the classical Pade approximant
+#   tanh(x) ~= x * (945 + 105*x^2 + x^4) / (945 + 420*x^2 + 15*x^4)
+# (shape is ignored: both Dense kernels are (3, 1)).
+def num_init(shape, dtype=None):
+    rr = tf.constant([[945], [105], [1]], dtype=dtype)
+    #rr = tf.constant([[946.56757], [98.01368], [0.66841]], dtype=dtype)
+    print(rr)
+    return rr
+
+def den_init(shape, dtype=None):
+    rr = tf.constant([[945], [420], [15]], dtype=dtype)
+    #rr = tf.constant([[946.604], [413.342], [12.465]], dtype=dtype)
+    print(rr)
+    return rr
+
+
+# Training data: x sampled every 0.01 over [-10, 10), shaped as one sequence.
+x = np.arange(-10, 10, .01)
+x = np.reshape(x, (1, -1, 1))
+x2 = x*x
+
+# Features [1, x^2, x^4]: the bias-free 'num' and 'den' Dense layers below
+# then evaluate the two even polynomials in x.
+x2in = np.concatenate([np.ones_like(x2), x2, x2*x2], axis=2)
+yout = np.tanh(x)
+
+
+model_x = Input(shape=(None, 1,))
+model_x2 = Input(shape=(None, 3,))
+
+num = Dense(1, name='num', use_bias=False, kernel_initializer=num_init)
+den = Dense(1, name='den', use_bias=False, kernel_initializer=den_init)
+
+def ratio(x):
+    # x = [x, num(features), den(features)]; clip to the range of tanh.
+    return tf.minimum(1., tf.maximum(-1., x[0]*x[1]/x[2]))
+
+out_layer = Lambda(ratio)
+output = out_layer([model_x, num(model_x2), den(model_x2)])
+
+model = Model([model_x, model_x2], output)
+model.summary()
+
+def adam_with_decay(lr, decay, **kwargs):
+    # Same schedule as the legacy `decay` argument, lr / (1 + decay*step),
+    # which the Keras optimizers in TF >= 2.11 no longer accept.
+    schedule = tf.keras.optimizers.schedules.InverseTimeDecay(lr, decay_steps=1, decay_rate=decay)
+    return Adam(schedule, **kwargs)
+
+model.compile(adam_with_decay(0.05, 2e-5, beta_1=0.9, beta_2=0.9), loss='mean_squared_error')
+model.fit([x, x2in], yout, batch_size=1, epochs=500000, validation_split=0.0)
+
+model.compile(adam_with_decay(0.001, 1e-4, beta_2=0.9), loss=my_loss1)
+model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)
+
+model.compile(adam_with_decay(0.0001, 1e-4, beta_2=0.9), loss=my_loss2)
+model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)
+
+model.compile(adam_with_decay(0.00001, 1e-4, beta_2=0.9), loss=my_loss3)
+model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)
+
+model.save_weights('tanh.h5')
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -39,6 +39,12 @@
 #define USE_SU_BIAS
 #endif
 
+/* Without FMA, emulate the fused multiply-add with a multiply and an add. */
+#ifndef __FMA__
+#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
+#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
+#endif
+
 #ifdef __AVX2__
 static inline __m256 exp8_approx(__m256 X)
 {
@@ -61,9 +67,78 @@
    Y = _mm256_castsi256_ps(_mm256_add_epi32(I, _mm256_castps_si256(Y)));
    return Y;
 }
+
+/* Approximating tanh() using a Padé-like rational function:
+   tanh(x) ~= x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
+   subject to the +/- 1 bounds.
+   The coefficients were determined by gradient descent trying to minimize
+   the maximum deviation over the whole range (this is only possible because
+   of the bounds). The max error is around 3e-4 and is dominated by the
+   reciprocal approximation (the max error of the rational function is
+   around 6e-5).
+   The input is first clamped to [-10, 10]: for |x| above roughly 1e9 the
+   polynomials overflow and inf*rcp(inf) would produce NaN. The rational
+   function already exceeds 1 in magnitude at |x| = 10 (about 1.035), so the
+   clamp does not change the bounded output.
+   */
+static inline __m256 tanh8_approx(__m256 X)
+{
+   const __m256 N0 = _mm256_set1_ps(952.52801514f);
+   const __m256 N1 = _mm256_set1_ps(96.39235687f);
+   const __m256 N2 = _mm256_set1_ps(0.60863042f);
+   const __m256 D0 = _mm256_set1_ps(952.72399902f);
+   const __m256 D1 = _mm256_set1_ps(413.36801147f);
+   const __m256 D2 = _mm256_set1_ps(11.88600922f);
+   const __m256 max_in = _mm256_set1_ps(10.f);
+   const __m256 min_in = _mm256_set1_ps(-10.f);
+   const __m256 max_out = _mm256_set1_ps(1.f);
+   const __m256 min_out = _mm256_set1_ps(-1.f);
+   __m256 X2, num, den;
+   X = _mm256_max_ps(min_in, _mm256_min_ps(max_in, X));
+   X2 = _mm256_mul_ps(X, X);
+   num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
+   den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
+   num = _mm256_mul_ps(num, X);
+   den = _mm256_rcp_ps(den);
+   num = _mm256_mul_ps(num, den);
+   return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
+}
+
+/* Sigmoid approximation using a Padé-like rational function:
+   1/(1+exp(-x)) ~= 0.5 + x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
+   subject to the [0, 1] bounds.
+   The coefficients are directly derived by dividing the tanh() coefficients
+   by powers of two to get the correct scaling. The max error is around 1.5e-4
+   and is dominated by the reciprocal approximation (the max error of the
+   rational function is around 3e-5).
+   As in tanh8_approx(), the input is clamped (here to [-20, 20], since
+   sigmoid(x) = (1 + tanh(x/2))/2) to avoid NaN for huge |x|.
+   */
+static inline __m256 sigmoid8_approx(__m256 X)
+{
+   const __m256 N0 = _mm256_set1_ps(238.13200378f);
+   const __m256 N1 = _mm256_set1_ps(6.02452230f);
+   const __m256 N2 = _mm256_set1_ps(0.00950985f);
+   const __m256 D0 = _mm256_set1_ps(952.72399902f);
+   const __m256 D1 = _mm256_set1_ps(103.34200287f);
+   const __m256 D2 = _mm256_set1_ps(0.74287558f);
+   const __m256 half = _mm256_set1_ps(0.5f);
+   const __m256 max_in = _mm256_set1_ps(20.f);
+   const __m256 min_in = _mm256_set1_ps(-20.f);
+   const __m256 max_out = _mm256_set1_ps(1.f);
+   const __m256 min_out = _mm256_set1_ps(0.f);
+   __m256 X2, num, den;
+   X = _mm256_max_ps(min_in, _mm256_min_ps(max_in, X));
+   X2 = _mm256_mul_ps(X, X);
+   num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
+   den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
+   num = _mm256_mul_ps(num, X);
+   den = _mm256_rcp_ps(den);
+   num = _mm256_fmadd_ps(num, den, half);
+   return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
+}
+
 #else
-#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
-#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
 static inline __m128 exp4_approx(__m128 X)
 {
    const __m128 K0 = _mm_set1_ps(0.99992522f);
@@ -98,6 +173,56 @@
    Y = _mm256_insertf128_ps(Y, Ylo, 0);
    return Y;
 }
+
+/* 128-bit version of tanh8_approx(): same coefficients, input clamp and bounds. */
+static inline __m128 tanh4_approx(__m128 X)
+{
+   const __m128 N0 = _mm_set1_ps(952.52801514f);
+   const __m128 N1 = _mm_set1_ps(96.39235687f);
+   const __m128 N2 = _mm_set1_ps(0.60863042f);
+   const __m128 D0 = _mm_set1_ps(952.72399902f);
+   const __m128 D1 = _mm_set1_ps(413.36801147f);
+   const __m128 D2 = _mm_set1_ps(11.88600922f);
+   const __m128 max_in = _mm_set1_ps(10.f);
+   const __m128 min_in = _mm_set1_ps(-10.f);
+   const __m128 max_out = _mm_set1_ps(1.f);
+   const __m128 min_out = _mm_set1_ps(-1.f);
+   __m128 X2, num, den;
+   X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
+   X2 = _mm_mul_ps(X, X);
+   num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
+   den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
+   num = _mm_mul_ps(num, X);
+   den = _mm_rcp_ps(den);
+   num = _mm_mul_ps(num, den);
+   return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
+}
+
+/* 128-bit version of sigmoid8_approx(): same coefficients, input clamp and bounds. */
+static inline __m128 sigmoid4_approx(__m128 X)
+{
+   const __m128 N0 = _mm_set1_ps(238.13200378f);
+   const __m128 N1 = _mm_set1_ps(6.02452230f);
+   const __m128 N2 = _mm_set1_ps(0.00950985f);
+   const __m128 D0 = _mm_set1_ps(952.72399902f);
+   const __m128 D1 = _mm_set1_ps(103.34200287f);
+   const __m128 D2 = _mm_set1_ps(0.74287558f);
+   const __m128 half = _mm_set1_ps(0.5f);
+   const __m128 max_in = _mm_set1_ps(20.f);
+   const __m128 min_in = _mm_set1_ps(-20.f);
+   const __m128 max_out = _mm_set1_ps(1.f);
+   const __m128 min_out = _mm_set1_ps(0.f);
+   __m128 X2, num, den;
+   X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
+   X2 = _mm_mul_ps(X, X);
+   num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
+   den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
+   num = _mm_mul_ps(num, X);
+   den = _mm_rcp_ps(den);
+   num = _mm_fmadd_ps(num, den, half);
+   return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
+}
+
 #endif
 
 static inline float celt_exp(float x)
@@ -124,18 +249,15 @@
         y[i] = celt_exp(x[i]);
 }
 
+#ifdef __AVX2__
 static inline void vec_tanh(float *y, const float *x, int N)
 {
     int i;
     for (i=0;i<N-7;i+=8)
     {
-        const __m256 two = _mm256_set1_ps(2.f);
-        const __m256 one = _mm256_set1_ps(1.f);
         __m256 X, Y;
         X = _mm256_loadu_ps(&x[i]);
-        X = _mm256_mul_ps(X, two);
-        Y = exp8_approx(X);
-        Y = _mm256_mul_ps(_mm256_sub_ps(Y, one),  _mm256_rcp_ps(_mm256_add_ps(Y, one)));
+        Y = tanh8_approx(X);
         _mm256_storeu_ps(&y[i], Y);
     }
     for (;i<N;i++)
@@ -151,12 +273,9 @@
     int i;
     for (i=0;i<N-7;i+=8)
     {
-        const __m256 one = _mm256_set1_ps(1.f);
         __m256 X, Y;
         X = _mm256_loadu_ps(&x[i]);
-        Y = exp8_approx(X);
-        /* Compute as 1-1/(1+e^x) to avoid >1 values caused by the reciprocal approximation. */
-        Y = _mm256_sub_ps(one, _mm256_rcp_ps(_mm256_add_ps(Y, one)));
+        Y = sigmoid8_approx(X);
         _mm256_storeu_ps(&y[i], Y);
     }
     for (;i<N;i++)
@@ -166,6 +285,44 @@
         y[i] = (ex)/(ex+1);
     }
 }
+#else
+static inline void vec_tanh(float *y, const float *x, int N)
+{
+    int i;
+    for (i=0;i<N-3;i+=4)
+    {
+        __m128 X, Y;
+        X = _mm_loadu_ps(&x[i]);
+        Y = tanh4_approx(X);
+        _mm_storeu_ps(&y[i], Y);
+    }
+    for (;i<N;i++)
+    {
+        float ex2;
+        ex2 = celt_exp(2*x[i]);
+        y[i] = (ex2-1)/(ex2+1);
+    }
+}
+
+static inline void vec_sigmoid(float *y, const float *x, int N)
+{
+    int i;
+    for (i=0;i<N-3;i+=4)
+    {
+        __m128 X, Y;
+        X = _mm_loadu_ps(&x[i]);
+        Y = sigmoid4_approx(X);
+        _mm_storeu_ps(&y[i], Y);
+    }
+    for (;i<N;i++)
+    {
+        float ex;
+        ex = celt_exp(x[i]);
+        y[i] = (ex)/(ex+1);
+    }
+}
+
+#endif
 
 static inline void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
--