ref: f16b43cdfa2f3f2d5af36185819bebf1ca9c806d
parent: 502204562c7458d42812dcf5c32b0cbc28150c25
author: Henrik Gramner <gramner@twoorioles.com>
date: Tue Jan 7 19:43:58 EST 2020
x86: Fix AVX2 inverse identity transform overflow/clipping

The coefficients after the first (8-bit) 1D identity transform may
require more than 16 bits of precision before downshifting in some
cases, and they may also need to be clipped to int16_t after
downshifting.
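
To see why wrapping adds are insufficient, here is a minimal C sketch of
the identity4 scaling step (illustration only, not part of the patch;
mulhrs() and idtx4() are hypothetical helper names). pmulhrsw computes
round(a*b*2 / 65536), so the scale x*5793/4096 ~= x*sqrt(2) is split
into x + mulhrs(x, 1697*8); once |x| exceeds roughly 23169 the sum no
longer fits in int16_t, hence the paddw -> paddsw changes below:

    #include <stdint.h>

    /* Rounding high multiply, matching pmulhrsw semantics
     * (assumes arithmetic right shift of negative values). */
    static int16_t mulhrs(int16_t a, int16_t b)
    {
        return (int16_t)(((int64_t)a * b * 2 + 0x8000) >> 16);
    }

    /* Identity4 1D scale: x * 5793/4096 ~= x * sqrt(2). */
    static int16_t idtx4(int16_t x)
    {
        int32_t sum = x + mulhrs(x, 1697 * 8); /* may exceed 16 bits */
        if (sum > INT16_MAX) sum = INT16_MAX;  /* paddsw saturates here, */
        if (sum < INT16_MIN) sum = INT16_MIN;  /* paddw would wrap around */
        return (int16_t)sum;
    }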
--- a/src/x86/itx.asm
+++ b/src/x86/itx.asm
@@ -60,7 +60,6 @@
pw_1697x16: times 2 dw 1697*16
pw_1697x8: times 2 dw 1697*8
pw_2896x8: times 2 dw 2896*8
-pw_5793x4: times 2 dw 5793*4
pd_2048: dd 2048
@@ -393,7 +392,7 @@
pmulhrsw m0, [cq]
vpbroadcastd m1, [o(pw_1697x8)]
pmulhrsw m1, m0
- paddw m0, m1
+ paddsw m0, m1
punpcklwd m0, m0
punpckhdq m1, m0, m0
punpckldq m0, m0
@@ -405,7 +404,7 @@
vpbroadcastd m2, [o(pw_2896x8)]
packusdw m0, m0
pmulhrsw m1, m0
- paddw m0, m1
+ paddsw m0, m1
pmulhrsw m0, m2
mova m1, m0
jmp m(iadst_4x4_internal).end
@@ -561,8 +560,8 @@
vpbroadcastd m3, [o(pw_1697x8)]
pmulhrsw m2, m3, m0
pmulhrsw m3, m1
- paddw m0, m2
- paddw m1, m3
+ paddsw m0, m2
+ paddsw m1, m3
punpckhwd m2, m0, m1
punpcklwd m0, m1
punpckhwd m1, m0, m2
@@ -572,8 +571,8 @@
vpbroadcastd m3, [o(pw_1697x8)]
pmulhrsw m2, m3, m0
pmulhrsw m3, m1
- paddw m0, m2
- paddw m1, m3
+ paddsw m0, m2
+ paddsw m1, m3
jmp m(iadst_4x4_internal).end
%macro WRITE_4X8 2 ; coefs[1-2]
@@ -626,7 +625,7 @@
punpckldq xm0, xm1
pmulhrsw xm0, xm2
pmulhrsw xm3, xm0
- paddw xm0, xm3
+ paddsw xm0, xm3
pmulhrsw xm0, xm2
pmulhrsw xm0, xm4
vpbroadcastq m0, xm0
@@ -907,8 +906,8 @@
punpckhwd m1, m2
pmulhrsw m2, m4, m0
pmulhrsw m4, m1
- paddw m0, m2
- paddw m1, m4
+ paddsw m0, m2
+ paddsw m1, m4
jmp tx2q
.pass2:
vpbroadcastd m4, [o(pw_4096)]
@@ -925,8 +924,8 @@
vpbroadcastd m3, [o(pw_2048)]
pmulhrsw m0, m1
pmulhrsw m2, m0
- paddw m0, m0
- paddw m0, m2
+ paddsw m0, m0
+ paddsw m0, m2
pmulhrsw m3, m0
punpcklwd m1, m3, m3
punpckhwd m3, m3
@@ -941,15 +940,16 @@
movd xm1, [cq+32*2]
punpcklwd xm1, [cq+32*3]
vpbroadcastd xm2, [o(pw_1697x8)]
- vpbroadcastd xm3, [o(pw_16384)]
- vpbroadcastd xm4, [o(pw_2896x8)]
+ vpbroadcastd xm3, [o(pw_2896x8)]
+ vpbroadcastd xm4, [o(pw_2048)]
punpckldq xm0, xm1
+ pcmpeqw xm1, xm1
pmulhrsw xm2, xm0
- paddw xm0, xm2
+ pcmpeqw xm1, xm0
+ pxor xm0, xm1
+ pavgw xm0, xm2
pmulhrsw xm0, xm3
- psrlw xm3, 3 ; pw_2048
pmulhrsw xm0, xm4
- pmulhrsw xm0, xm3
vpbroadcastq m0, xm0
mova m1, m0
mova m2, m0
@@ -1283,26 +1283,33 @@
mova m3, [cq+32*0]
mova m2, [cq+32*1]
mova m4, [cq+32*2]
- mova m0, [cq+32*3]
- vpbroadcastd m5, [o(pw_1697x8)]
+ mova m5, [cq+32*3]
+ vpbroadcastd m8, [o(pw_1697x8)]
+ pcmpeqw m0, m0 ; -1
punpcklwd m1, m3, m2
punpckhwd m3, m2
- punpcklwd m2, m4, m0
- punpckhwd m4, m0
- pmulhrsw m0, m5, m1
- pmulhrsw m6, m5, m2
- pmulhrsw m7, m5, m3
- pmulhrsw m5, m4
- paddw m1, m0
- paddw m2, m6
- paddw m3, m7
- paddw m4, m5
- vpbroadcastd m5, [o(pw_16384)]
+ punpcklwd m2, m4, m5
+ punpckhwd m4, m5
+ pmulhrsw m5, m8, m1
+ pmulhrsw m6, m8, m2
+ pmulhrsw m7, m8, m3
+ pmulhrsw m8, m4
+ pcmpeqw m9, m0, m1 ; we want to do a signed avg, but pavgw is
+ pxor m1, m9 ; unsigned. as long as both signs are equal
+ pcmpeqw m9, m0, m2 ; it still works, but if the input is -1 the
+ pxor m2, m9 ; pmulhrsw result will become 0 which causes
+ pcmpeqw m9, m0, m3 ; pavgw to output -32768 instead of 0 unless
+ pxor m3, m9 ; we explicitly deal with that case here.
+ pcmpeqw m0, m4
+ pxor m4, m0
+ pavgw m1, m5
+ pavgw m2, m6
+ pavgw m3, m7
+ pavgw m4, m8
punpckldq m0, m1, m2
punpckhdq m1, m2
punpckldq m2, m3, m4
punpckhdq m3, m4
- REPX {pmulhrsw x, m5}, m0, m1, m2, m3
jmp tx2q
.pass2:
vpbroadcastd m8, [o(pw_1697x16)]
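
The pcmpeqw/pxor dance in the hunk above deserves a concrete model.
A hedged C sketch of pavgw on a single lane (hypothetical helper, not
part of the patch): pavgw averages unsigned 16-bit values, and the low
16 bits of the result still equal the signed average whenever both
inputs have the same sign. x == -1 is the one input where
pmulhrsw(x, 1697*8) rounds all the way up to 0, breaking that invariant:

    #include <stdint.h>

    /* pavgw on one lane: unsigned (a + b + 1) >> 1. */
    static uint16_t pavgw(uint16_t a, uint16_t b)
    {
        return (uint16_t)(((uint32_t)a + b + 1) >> 1);
    }

    /* Equal signs: pavgw(0xFFFE, 0xFFFE) = 0xFFFE, i.e. avg(-2, -2) = -2.
     * Mixed signs: pavgw(0xFFFF, 0x0000) = 0x8000, i.e. -32768 instead
     * of the expected 0. Flipping x == -1 to 0 beforehand (pcmpeqw +
     * pxor against -1) makes both inputs 0, so the average is 0. */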
@@ -1311,11 +1318,11 @@
pmulhrsw m6, m8, m1
pmulhrsw m7, m8, m2
pmulhrsw m8, m3
- REPX {paddw x, x}, m0, m1, m2, m3
- paddw m0, m4
- paddw m1, m6
- paddw m2, m7
- paddw m3, m8
+ REPX {paddsw x, x}, m0, m1, m2, m3
+ paddsw m0, m4
+ paddsw m1, m6
+ paddsw m2, m7
+ paddsw m3, m8
jmp m(iadst_4x16_internal).end2
%macro WRITE_8X4 4-7 strideq*1, strideq*2, r3 ; coefs[1-2], tmp[1-2], off[1-3]
@@ -1353,7 +1360,7 @@
vpbroadcastd xm3, [o(pw_2048)]
pmulhrsw xm1, xm0
pmulhrsw xm2, xm1
- paddw xm1, xm2
+ paddsw xm1, xm2
pmulhrsw xm1, xm3
punpcklwd xm1, xm1
punpckldq xm0, xm1, xm1
@@ -1369,7 +1376,7 @@
vpbroadcastd xm3, [o(pw_2048)]
packusdw xm0, xm1
pmulhrsw xm0, xm2
- paddw xm0, xm0
+ paddsw xm0, xm0
pmulhrsw xm0, xm2
pmulhrsw xm0, xm3
vinserti128 m0, m0, xm0, 1
@@ -1520,15 +1527,15 @@
pmulhrsw m2, m3
punpcklwd m0, m1, m2
punpckhwd m1, m2
- paddw m0, m0
- paddw m1, m1
+ paddsw m0, m0
+ paddsw m1, m1
jmp tx2q
.pass2:
vpbroadcastd m3, [o(pw_1697x8)]
pmulhrsw m2, m3, m0
pmulhrsw m3, m1
- paddw m0, m2
- paddw m1, m3
+ paddsw m0, m2
+ paddsw m1, m3
jmp m(iadst_8x4_internal).end
%macro INV_TXFM_8X8_FN 2-3 -1 ; type1, type2, fast_thresh
@@ -1796,8 +1803,8 @@
pmulhrsw m7, m1
psrlw m1, 3 ; pw_2048
pmulhrsw m2, m7
- paddw m7, m7
- paddw m7, m2
+ paddsw m7, m7
+ paddsw m7, m2
pmulhrsw m7, m1
punpcklwd m5, m7, m7
punpckhwd m7, m7
@@ -2120,12 +2127,12 @@
%macro IDTX16 3-4 ; src/dst, tmp, pw_1697x16, [pw_16384]
pmulhrsw m%2, m%3, m%1
-%if %0 == 4 ; if we're going to downshift by 1 doing so here eliminates the paddw
+%if %0 == 4 ; if downshifting by 1
pmulhrsw m%2, m%4
%else
- paddw m%1, m%1
+ paddsw m%1, m%1
%endif
- paddw m%1, m%2
+ paddsw m%1, m%2
%endmacro
cglobal iidentity_8x16_internal, 0, 5, 13, dst, stride, c, eob, tx2
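
For context, IDTX16 implements the identity16 scale of 2*sqrt(2) as
2*x + mulhrs(x, 1697*16) (the pmulhrsw term is x*27152*2/65536, about
0.8286*x). The optional pw_16384 argument halves that term instead,
yielding x + ~0.4143*x = x*sqrt(2) and folding the subsequent downshift
by 1 into the macro. A hedged C sketch, reusing mulhrs() from the first
sketch (adds16() is a hypothetical stand-in for paddsw):

    /* Saturating 16-bit add, as paddsw does per lane. */
    static int16_t adds16(int16_t a, int16_t b)
    {
        int32_t s = a + b;
        return (int16_t)(s > INT16_MAX ? INT16_MAX :
                         s < INT16_MIN ? INT16_MIN : s);
    }

    static int16_t idtx16(int16_t x, int fold_downshift_by_1)
    {
        int16_t t = mulhrs(x, 1697 * 16);       /* ~0.8286 * x */
        if (fold_downshift_by_1)
            return adds16(x, mulhrs(t, 16384)); /* ~sqrt(2) * x */
        return adds16(adds16(x, x), t);         /* ~2*sqrt(2) * x */
    }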
@@ -2201,7 +2208,7 @@
pmulhrsw xm3, xm0
psrlw xm0, 3 ; pw_2048
pmulhrsw xm1, xm3
- paddw xm3, xm1
+ paddsw xm3, xm1
pmulhrsw xm3, xm0
punpcklwd xm3, xm3
punpckldq xm1, xm3, xm3
@@ -2228,7 +2235,7 @@
vpbroadcastd m1, [o(pw_2896x8)]
pmulhrsw m4, m0
pmulhrsw m4, m5
- paddw m0, m4
+ paddsw m0, m4
psrlw m5, 3 ; pw_2048
pmulhrsw m0, m1
pmulhrsw m0, m5
@@ -2503,10 +2510,10 @@
pmulhrsw m6, m7, m3
pmulhrsw m7, m4
REPX {pmulhrsw x, m8}, m0, m5, m6, m7
- paddw m1, m0
- paddw m2, m5
- paddw m3, m6
- paddw m4, m7
+ paddsw m1, m0
+ paddsw m2, m5
+ paddsw m3, m6
+ paddsw m4, m7
punpcklqdq m0, m1, m2
punpckhqdq m1, m2
punpcklqdq m2, m3, m4
@@ -2518,10 +2525,10 @@
pmulhrsw m5, m7, m1
pmulhrsw m6, m7, m2
pmulhrsw m7, m3
- paddw m0, m4
- paddw m1, m5
- paddw m2, m6
- paddw m3, m7
+ paddsw m0, m4
+ paddsw m1, m5
+ paddsw m2, m6
+ paddsw m3, m7
jmp m(iadst_16x4_internal).end
%macro INV_TXFM_16X8_FN 2-3 -1 ; type1, type2, fast_thresh
@@ -2581,7 +2588,7 @@
pmulhrsw m0, m4
pmulhrsw m5, m0
pmulhrsw m5, m2
- paddw m0, m5
+ paddsw m0, m5
psrlw m2, 3 ; pw_2048
pmulhrsw m0, m4
pmulhrsw m0, m2
@@ -2903,7 +2910,7 @@
vpbroadcastd m3, [o(pw_2896x8)]
pmulhrsw m3, [cq]
vpbroadcastd m0, [o(pw_8192)]
- vpbroadcastd m1, [o(pw_5793x4)]
+ vpbroadcastd m1, [o(pw_1697x16)]
vpbroadcastw m4, [o(deint_shuf)] ; pb_0_1
pcmpeqb m5, m5
pxor m6, m6
@@ -2911,8 +2918,7 @@
paddb m5, m5 ; pb_m2
pmulhrsw m3, m0
psrlw m0, 2 ; pw_2048
- psllw m3, 2
- pmulhrsw m3, m1
+ IDTX16 3, 1, 1
pmulhrsw m3, m0
mov r3d, 8
.loop:
@@ -2954,17 +2960,15 @@
punpcklwd m1, m3
vpbroadcastd m3, [o(pw_1697x16)]
punpcklwd m2, m4
- vpbroadcastd m4, [o(pw_8192)]
+ vpbroadcastd m4, [o(pw_2896x8)]
punpckldq m1, m2
- vpbroadcastd m2, [o(pw_2896x8)]
+ vpbroadcastd m2, [o(pw_2048)]
punpcklqdq m0, m1
pmulhrsw m3, m0
- paddw m0, m0
- paddw m0, m3
+ psraw m3, 1
+ pavgw m0, m3
pmulhrsw m0, m4
- psrlw m4, 2 ; pw_2048
pmulhrsw m0, m2
- pmulhrsw m0, m4
mov r3d, 8
jmp m(inv_txfm_add_identity_dct_16x4).end
%endif
@@ -3385,6 +3389,12 @@
WRITE_16X2 7, [rsp+32*2], 0, 1, strideq*2, r3
jmp m(idct_16x16_internal).end3
+%macro IDTX16B 3 ; src/dst, tmp, pw_1697x16
+ pmulhrsw m%2, m%3, m%1
+ psraw m%2, 1
+ pavgw m%1, m%2 ; signs are guaranteed to be equal
+%endmacro
+
INV_TXFM_16X16_FN identity, dct, 15
INV_TXFM_16X16_FN identity, identity
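
IDTX16B produces the same result as IDTX16 followed by a pw_8192
multiply (i.e. x*2*sqrt(2)/4), but halves the pmulhrsw term with an
arithmetic shift and then averages with pavgw. Because psraw rounds
toward minus infinity, a negative x always keeps a term <= -1; it can
never round up to 0 the way pmulhrsw does for x == -1, so both pavgw
inputs are guaranteed to share a sign and the 4x16-style fixup is
unnecessary. A C sketch using the hypothetical helpers from the earlier
sketches:

    static int16_t idtx16b(int16_t x)
    {
        int16_t t = mulhrs(x, 1697 * 16) >> 1; /* psraw: floors, keeps sign */
        return (int16_t)pavgw((uint16_t)x, (uint16_t)t);
    }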
@@ -3419,22 +3429,17 @@
vinserti128 m13, [cq+16*13], 1
mova xm14, [cq-16* 1]
vinserti128 m14, [cq+16*15], 1
- REPX {IDTX16 x, 6, 7}, 0, 15, 1, 8, 2, 9, 3, \
+ REPX {IDTX16B x, 6, 7}, 0, 15, 1, 8, 2, 9, 3, \
10, 4, 11, 5, 12, 13, 14
mova xm6, [cq-16* 4]
vinserti128 m6, [cq+16*12], 1
- mova [rsp], m1
- IDTX16 6, 1, 7
- mova xm1, [cq-16* 2]
- vinserti128 m1, [cq+16*14], 1
- pmulhrsw m7, m1
- paddw m1, m1
- paddw m7, m1
- vpbroadcastd m1, [o(pw_8192)]
- REPX {pmulhrsw x, m1}, m0, m2, m3, m4, m5, m6, m7, \
- m8, m9, m10, m11, m12, m13, m14, m15
- pmulhrsw m1, [rsp]
mova [rsp], m0
+ IDTX16B 6, 0, 7
+ mova xm0, [cq-16* 2]
+ vinserti128 m0, [cq+16*14], 1
+ pmulhrsw m7, m0
+ psraw m7, 1
+ pavgw m7, m0
jmp m(idct_16x16_internal).pass1_end3
ALIGN function_align
.pass2:
@@ -3963,7 +3968,7 @@
vinserti128 m6, m6, [cq+16* 9], 1
vinserti128 m7, m7, [cq+16*13], 1
REPX {mova [cq+32*x], m8}, -4, -2, 0, 2, 4, 6
- REPX {paddw x, m9}, m0, m1, m2, m3, m4, m5, m6, m7
+ REPX {paddsw x, m9}, m0, m1, m2, m3, m4, m5, m6, m7
call .transpose8x8
REPX {psraw x, 3 }, m0, m1, m2, m3, m4, m5, m6, m7
WRITE_8X4 0, 4, 8, 10, strideq*8, strideq*4, r4*4
@@ -4572,12 +4577,12 @@
IDCT32_PASS1_END 1, 9, 6, 7
ret
-cglobal inv_txfm_add_identity_identity_16x32, 4, 5, 12, dst, stride, c, eob
+cglobal inv_txfm_add_identity_identity_16x32, 4, 5, 13, dst, stride, c, eob
%undef cmp
lea rax, [o_base]
vpbroadcastd m9, [o(pw_2896x8)]
- vpbroadcastd m10, [o(pw_5793x4)]
- vpbroadcastd m11, [o(pw_5)]
+ vpbroadcastd m10, [o(pw_1697x16)]
+ vpbroadcastd m12, [o(pw_8192)]
cmp eobd, 43 ; if (eob > 43)
setg r4b ; iteration_count++
cmp eobd, 150 ; if (eob > 150)
@@ -4586,6 +4591,7 @@
adc r4b, al ; iteration_count++
lea r3, [strideq*3]
mov rax, cq
+ paddw m11, m12, m12 ; pw_16384
.loop:
mova xm0, [cq+64* 0]
mova xm1, [cq+64* 1]
@@ -4604,11 +4610,9 @@
vinserti128 m6, m6, [cq+64*14], 1
vinserti128 m7, m7, [cq+64*15], 1
REPX {pmulhrsw x, m9 }, m0, m1, m2, m3, m4, m5, m6, m7
- REPX {psllw x, 2 }, m0, m1, m2, m3, m4, m5, m6, m7
+ REPX {IDTX16 x, 8, 10, 11}, 0, 1, 2, 3, 4, 5, 6, 7
call m(inv_txfm_add_identity_identity_8x32).transpose8x8
- REPX {pmulhrsw x, m10}, m0, m1, m2, m3, m4, m5, m6, m7
- REPX {paddw x, m11}, m0, m1, m2, m3, m4, m5, m6, m7
- REPX {psraw x, 3 }, m0, m1, m2, m3, m4, m5, m6, m7
+ REPX {pmulhrsw x, m12}, m0, m1, m2, m3, m4, m5, m6, m7
WRITE_16X2 0, 1, 8, 0, strideq*0, strideq*1
WRITE_16X2 2, 3, 0, 1, strideq*2, r3
lea dstq, [dstq+strideq*4]
@@ -4646,7 +4650,7 @@
%undef cmp
lea rax, [o_base]
vpbroadcastd m9, [o(pw_2896x8)]
- vpbroadcastd m10, [o(pw_1697x8)]
+ vpbroadcastd m10, [o(pw_1697x16)]
vpbroadcastd m11, [o(pw_2048)]
cmp eobd, 35 ; if (eob > 35)
setg r4b ; iteration_count++
@@ -4674,24 +4678,9 @@
vinserti128 m6, m6, [cq+32*14], 1
vinserti128 m7, m7, [cq+32*15], 1
REPX {pmulhrsw x, m9 }, m0, m1, m2, m3, m4, m5, m6, m7
- REPX {psllw x, 2 }, m0, m1, m2, m3, m4, m5, m6, m7
+ REPX {paddsw x, x }, m0, m1, m2, m3, m4, m5, m6, m7
call m(inv_txfm_add_identity_identity_8x32).transpose8x8
- pmulhrsw m8, m10, m0
- paddw m0, m8
- pmulhrsw m8, m10, m1
- paddw m1, m8
- pmulhrsw m8, m10, m2
- paddw m2, m8
- pmulhrsw m8, m10, m3
- paddw m3, m8
- pmulhrsw m8, m10, m4
- paddw m4, m8
- pmulhrsw m8, m10, m5
- paddw m5, m8
- pmulhrsw m8, m10, m6
- paddw m6, m8
- pmulhrsw m8, m10, m7
- paddw m7, m8
+ REPX {IDTX16 x, 8, 10}, 0, 1, 2, 3, 4, 5, 6, 7
REPX {pmulhrsw x, m11}, m0, m1, m2, m3, m4, m5, m6, m7
WRITE_16X2 0, 1, 8, 0, strideq*0, strideq*1
WRITE_16X2 2, 3, 0, 1, strideq*2, r3