ref: 6b582edbed29fd4a59993adf81b47255b04b71ee
parent: be392e38572760dda8d484285e55fcd5e35ddbfa
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Fri Dec 25 23:01:59 EST 2020
WIP: remove scalar code from AVX2 code
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -257,27 +257,12 @@
__m256i vxj;
__m256i vw;
int pos;
- int xj0, xj1, xj2, xj3;
pos = 4 * (*idx++);
vxj = _mm256_set1_epi32(*(int*)&x[pos]);
- xj0 = x[pos+0];
- xj1 = x[pos+1];
- xj2 = x[pos+2];
- xj3 = x[pos+3];
-
vw = _mm256_loadu_si256((const __m256i *)w); //_mm256_lddqu_si256?
tmp = _mm256_maddubs_epi16(vxj, vw); //swap?
tmp = _mm256_madd_epi16(tmp, ones);
vy0 = _mm256_add_epi32(vy0, tmp);
-
- y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
- y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
- y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
- y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
- y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
- y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
- y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
- y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
w += 32;
}
_mm256_storeu_si256 ((__m256i *)&y[0], vy0);
--
⑨