shithub: dav1d

ref: 4c21c9312d31300401554404b188478f8f25f404
parent: 37093f98aee62eb79d8bc0d31ef29c13d3901066
author: Henrik Gramner <gramner@twoorioles.com>
date: Mon Feb 11 12:54:31 EST 2019

x86: Add w_mask_444 AVX2 asm
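
w_mask blends two intermediate predictions with a per-pixel weight derived
from their absolute difference, and also writes that weight out as a mask
that is reused for the chroma blend. In the 4:4:4 case the mask is stored at
full resolution instead of being summed down 2x or 4x as in the 4:2:2/4:2:0
variants, which is what the new third W_MASK macro argument selects. Below is
a minimal C sketch of the 8-bit computation this asm accelerates (simplified
rounding constants and illustrative names, not dav1d's exact w_mask_c
signature):

#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

static void w_mask_444_sketch(uint8_t *dst, ptrdiff_t stride,
                              const int16_t *tmp1, const int16_t *tmp2,
                              int w, int h, uint8_t *mask)
{
    do {
        for (int x = 0; x < w; x++) {
            /* weight in 38..64 from |tmp1 - tmp2|; the asm folds the clamp
             * into a saturating subtract against pw_6903 =
             * ((64 - 38) << 8) + 255 - 8, yielding 64 - m, and flips it
             * back with pb_64 before storing the 4:4:4 mask */
            int m = 38 + ((abs(tmp1[x] - tmp2[x]) + 8) >> 8);
            if (m > 64) m = 64;
            /* blend the two 16-bit intermediates back down to a pixel */
            int px = (tmp1[x] * m + tmp2[x] * (64 - m) + 512) >> 10;
            dst[x] = px < 0 ? 0 : px > 255 ? 255 : (uint8_t)px;
            mask[x] = (uint8_t)m; /* full-resolution mask, no subsampling */
        }
        tmp1 += w;
        tmp2 += w;
        mask += w;
        dst  += stride;
    } while (--h);
}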

--- a/src/x86/mc.asm
+++ b/src/x86/mc.asm
@@ -92,6 +92,7 @@
 BIDIR_JMP_TABLE mask_avx2,       4, 8, 16, 32, 64, 128
 BIDIR_JMP_TABLE w_mask_420_avx2, 4, 8, 16, 32, 64, 128
 BIDIR_JMP_TABLE w_mask_422_avx2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_mask_444_avx2, 4, 8, 16, 32, 64, 128
 BIDIR_JMP_TABLE blend_avx2,      4, 8, 16, 32
 BIDIR_JMP_TABLE blend_v_avx2, 2, 4, 8, 16, 32
 BIDIR_JMP_TABLE blend_h_avx2, 2, 4, 8, 16, 32, 32, 32
@@ -3055,7 +3056,7 @@
     add                  wq, r7
     BIDIR_FN           MASK
 
-%macro W_MASK 2 ; src_offset, mask_out
+%macro W_MASK 2-3 0 ; src_offset, mask_out, 4:4:4
     mova                 m0, [tmp1q+(%1+0)*mmsize]
     mova                 m1, [tmp2q+(%1+0)*mmsize]
     psubw                m1, m0
@@ -3071,7 +3072,13 @@
     pabsw                m3, m2
     psubusw              m3, m6, m3
     psrlw                m3, 8
+%if %3
+    packuswb            m%2, m3
+    psubb               m%2, m5, m%2
+    vpermq              m%2, m%2, q3120
+%else
     phaddw              m%2, m3
+%endif
     psllw                m3, 10
     pmulhw               m2, m3
     paddw                m1, m2
@@ -3459,6 +3466,136 @@
     vpermd               m5, m10, m5
     mova        [dstq+32*3], m0
     mova       [maskq+32*1], m5
+    dec                  hd
+    jg .w128_loop
+    RET
+
+cglobal w_mask_444, 4, 8, 8, dst, stride, tmp1, tmp2, w, h, mask, stride3
+%define base r7-w_mask_444_avx2_table
+    lea                  r7, [w_mask_444_avx2_table]
+    tzcnt                wd, wm
+    movifnidn            hd, hm
+    mov               maskq, maskmp
+    movsxd               wq, dword [r7+wq*4]
+    vpbroadcastd         m6, [base+pw_6903] ; ((64 - 38) << 8) + 255 - 8
+    vpbroadcastd         m7, [base+pw_2048]
+    vpbroadcastd         m5, [base+pb_64]
+    add                  wq, r7
+    W_MASK                0, 4, 1
+    lea            stride3q, [strideq*3]
+    jmp                  wq
+.w4:
+    vextracti128        xm1, m0, 1
+    movd   [dstq+strideq*0], xm0
+    pextrd [dstq+strideq*1], xm0, 1
+    movd   [dstq+strideq*2], xm1
+    pextrd [dstq+stride3q ], xm1, 1
+    mova       [maskq+32*0], m4
+    cmp                  hd, 8
+    jl .w4_end
+    lea                dstq, [dstq+strideq*4]
+    pextrd [dstq+strideq*0], xm0, 2
+    pextrd [dstq+strideq*1], xm0, 3
+    pextrd [dstq+strideq*2], xm1, 2
+    pextrd [dstq+stride3q ], xm1, 3
+    je .w4_end
+    W_MASK                2, 4, 1
+    lea                dstq, [dstq+strideq*4]
+    vextracti128        xm1, m0, 1
+    movd   [dstq+strideq*0], xm0
+    pextrd [dstq+strideq*1], xm0, 1
+    movd   [dstq+strideq*2], xm1
+    pextrd [dstq+stride3q ], xm1, 1
+    lea                dstq, [dstq+strideq*4]
+    pextrd [dstq+strideq*0], xm0, 2
+    pextrd [dstq+strideq*1], xm0, 3
+    pextrd [dstq+strideq*2], xm1, 2
+    pextrd [dstq+stride3q ], xm1, 3
+    mova       [maskq+32*1], m4
+.w4_end:
+    RET
+.w8_loop:
+    add               tmp1q, 32*2
+    add               tmp2q, 32*2
+    W_MASK                0, 4, 1
+    lea                dstq, [dstq+strideq*4]
+    add               maskq, 32
+.w8:
+    vextracti128        xm1, m0, 1
+    movq   [dstq+strideq*0], xm0
+    movq   [dstq+strideq*1], xm1
+    movhps [dstq+strideq*2], xm0
+    movhps [dstq+stride3q ], xm1
+    mova            [maskq], m4
+    sub                  hd, 4
+    jg .w8_loop
+    RET
+.w16_loop:
+    add               tmp1q, 32*2
+    add               tmp2q, 32*2
+    W_MASK                0, 4, 1
+    lea                dstq, [dstq+strideq*2]
+    add               maskq, 32
+.w16:
+    vpermq               m0, m0, q3120
+    mova         [dstq+strideq*0], xm0
+    vextracti128 [dstq+strideq*1], m0, 1
+    mova            [maskq], m4
+    sub                  hd, 2
+    jg .w16_loop
+    RET
+.w32_loop:
+    add               tmp1q, 32*2
+    add               tmp2q, 32*2
+    W_MASK                0, 4, 1
+    add                dstq, strideq
+    add               maskq, 32
+.w32:
+    vpermq               m0, m0, q3120
+    mova             [dstq], m0
+    mova            [maskq], m4
+    dec                  hd
+    jg .w32_loop
+    RET
+.w64_loop:
+    add               tmp1q, 32*4
+    add               tmp2q, 32*4
+    W_MASK                0, 4, 1
+    add                dstq, strideq
+    add               maskq, 32*2
+.w64:
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*0], m0
+    mova       [maskq+32*0], m4
+    W_MASK                2, 4, 1
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*1], m0
+    mova       [maskq+32*1], m4
+    dec                  hd
+    jg .w64_loop
+    RET
+.w128_loop:
+    add               tmp1q, 32*8
+    add               tmp2q, 32*8
+    W_MASK                0, 4, 1
+    add                dstq, strideq
+    add               maskq, 32*4
+.w128:
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*0], m0
+    mova       [maskq+32*0], m4
+    W_MASK                2, 4, 1
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*1], m0
+    mova       [maskq+32*1], m4
+    W_MASK                4, 4, 1
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*2], m0
+    mova       [maskq+32*2], m4
+    W_MASK                6, 4, 1
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*3], m0
+    mova       [maskq+32*3], m4
     dec                  hd
     jg .w128_loop
     RET
--- a/src/x86/mc_init_tmpl.c
+++ b/src/x86/mc_init_tmpl.c
@@ -60,6 +60,7 @@
 decl_w_mask_fn(dav1d_w_mask_420_avx2);
 decl_w_mask_fn(dav1d_w_mask_420_ssse3);
 decl_w_mask_fn(dav1d_w_mask_422_avx2);
+decl_w_mask_fn(dav1d_w_mask_444_avx2);
 decl_blend_fn(dav1d_blend_avx2);
 decl_blend_fn(dav1d_blend_ssse3);
 decl_blend_dir_fn(dav1d_blend_v_avx2);
@@ -126,6 +127,7 @@
     c->avg = dav1d_avg_avx2;
     c->w_avg = dav1d_w_avg_avx2;
     c->mask = dav1d_mask_avx2;
+    c->w_mask[0] = dav1d_w_mask_444_avx2;
     c->w_mask[1] = dav1d_w_mask_422_avx2;
     c->w_mask[2] = dav1d_w_mask_420_avx2;
     c->blend = dav1d_blend_avx2;