shithub: dav1d

Download patch

ref: 64f9db558bec70accbbb3f58c81f5caf7859646f
parent: d4a7c6470b80a50a29cbfd6f791f891277ce0a62
author: Henrik Gramner <gramner@twoorioles.com>
date: Thu Feb 20 10:25:46 EST 2020

x86: Add mc w_mask 4:4:4 AVX-512 (Ice Lake) asm

--- a/src/x86/mc.asm
+++ b/src/x86/mc.asm
@@ -67,6 +67,10 @@
                 db  1,  5,  9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61
                 db 66, 70, 74, 78, 82, 86, 90, 94, 98,102,106,110,114,118,122,126
                 db 65, 69, 73, 77, 81, 85, 89, 93, 97,101,105,109,113,117,121,125
+wm_444_mask:    db  1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 ; vpermb byte indices used by the AVX-512 w_mask_444: odd source bytes 1..63 first,
+                db 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63
+                db  0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 ; then even source bytes 0..62 (this row, at offset 32, is also loaded separately for the ymm paths)
+                db 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62
 bilin_h_perm16: db  1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  7
                 db  9,  8, 10,  9, 11, 10, 12, 11, 13, 12, 14, 13, 15, 14, 16, 15
                 db 33, 32, 34, 33, 35, 34, 36, 35, 37, 36, 38, 37, 39, 38, 40, 39
@@ -244,6 +248,7 @@
 BIDIR_JMP_TABLE mask_avx512icl,            4, 8, 16, 32, 64, 128
 BIDIR_JMP_TABLE w_mask_420_avx512icl,      4, 8, 16, 32, 64, 128
 BIDIR_JMP_TABLE w_mask_422_avx512icl,      4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_mask_444_avx512icl,      4, 8, 16, 32, 64, 128
 
 SECTION .text
 
@@ -4411,6 +4416,9 @@
 %if cpuflag(avx512icl)
     vpshldw             m%2, m3, 8
     psllw                m3, m%2, 10
+%if %5
+    psubb               m%2, m5, m%2
+%endif
 %else
     psrlw                m3, 8
 %if %5
@@ -4429,136 +4437,6 @@
     packuswb            m%1, m1
 %endmacro
 
-cglobal w_mask_444, 4, 8, 8, dst, stride, tmp1, tmp2, w, h, mask, stride3
-%define base r7-w_mask_444_avx2_table
-    lea                  r7, [w_mask_444_avx2_table]
-    tzcnt                wd, wm
-    movifnidn            hd, hm
-    mov               maskq, maskmp
-    movsxd               wq, dword [r7+wq*4]
-    vpbroadcastd         m6, [base+pw_6903] ; ((64 - 38) << 8) + 255 - 8
-    vpbroadcastd         m7, [base+pw_2048]
-    vpbroadcastd         m5, [base+pb_64]
-    add                  wq, r7
-    W_MASK                0, 4, 0, 1, 1
-    lea            stride3q, [strideq*3]
-    jmp                  wq
-.w4:
-    vextracti128        xm1, m0, 1
-    movd   [dstq+strideq*0], xm0
-    pextrd [dstq+strideq*1], xm0, 1
-    movd   [dstq+strideq*2], xm1
-    pextrd [dstq+stride3q ], xm1, 1
-    mova       [maskq+32*0], m4
-    cmp                  hd, 8
-    jl .w4_end
-    lea                dstq, [dstq+strideq*4]
-    pextrd [dstq+strideq*0], xm0, 2
-    pextrd [dstq+strideq*1], xm0, 3
-    pextrd [dstq+strideq*2], xm1, 2
-    pextrd [dstq+stride3q ], xm1, 3
-    je .w4_end
-    W_MASK                0, 4, 2, 3, 1
-    lea                dstq, [dstq+strideq*4]
-    vextracti128        xm1, m0, 1
-    movd   [dstq+strideq*0], xm0
-    pextrd [dstq+strideq*1], xm0, 1
-    movd   [dstq+strideq*2], xm1
-    pextrd [dstq+stride3q ], xm1, 1
-    lea                dstq, [dstq+strideq*4]
-    pextrd [dstq+strideq*0], xm0, 2
-    pextrd [dstq+strideq*1], xm0, 3
-    pextrd [dstq+strideq*2], xm1, 2
-    pextrd [dstq+stride3q ], xm1, 3
-    mova       [maskq+32*1], m4
-.w4_end:
-    RET
-.w8_loop:
-    add               tmp1q, 32*2
-    add               tmp2q, 32*2
-    W_MASK                0, 4, 0, 1, 1
-    lea                dstq, [dstq+strideq*4]
-    add               maskq, 32
-.w8:
-    vextracti128        xm1, m0, 1
-    movq   [dstq+strideq*0], xm0
-    movq   [dstq+strideq*1], xm1
-    movhps [dstq+strideq*2], xm0
-    movhps [dstq+stride3q ], xm1
-    mova            [maskq], m4
-    sub                  hd, 4
-    jg .w8_loop
-    RET
-.w16_loop:
-    add               tmp1q, 32*2
-    add               tmp2q, 32*2
-    W_MASK                0, 4, 0, 1, 1
-    lea                dstq, [dstq+strideq*2]
-    add               maskq, 32
-.w16:
-    vpermq               m0, m0, q3120
-    mova         [dstq+strideq*0], xm0
-    vextracti128 [dstq+strideq*1], m0, 1
-    mova            [maskq], m4
-    sub                  hd, 2
-    jg .w16_loop
-    RET
-.w32_loop:
-    add               tmp1q, 32*2
-    add               tmp2q, 32*2
-    W_MASK                0, 4, 0, 1, 1
-    add                dstq, strideq
-    add               maskq, 32
-.w32:
-    vpermq               m0, m0, q3120
-    mova             [dstq], m0
-    mova            [maskq], m4
-    dec                  hd
-    jg .w32_loop
-    RET
-.w64_loop:
-    add               tmp1q, 32*4
-    add               tmp2q, 32*4
-    W_MASK                0, 4, 0, 1, 1
-    add                dstq, strideq
-    add               maskq, 32*2
-.w64:
-    vpermq               m0, m0, q3120
-    mova        [dstq+32*0], m0
-    mova       [maskq+32*0], m4
-    W_MASK                0, 4, 2, 3, 1
-    vpermq               m0, m0, q3120
-    mova        [dstq+32*1], m0
-    mova       [maskq+32*1], m4
-    dec                  hd
-    jg .w64_loop
-    RET
-.w128_loop:
-    add               tmp1q, 32*8
-    add               tmp2q, 32*8
-    W_MASK                0, 4, 0, 1, 1
-    add                dstq, strideq
-    add               maskq, 32*4
-.w128:
-    vpermq               m0, m0, q3120
-    mova        [dstq+32*0], m0
-    mova       [maskq+32*0], m4
-    W_MASK                0, 4, 2, 3, 1
-    vpermq               m0, m0, q3120
-    mova        [dstq+32*1], m0
-    mova       [maskq+32*1], m4
-    W_MASK                0, 4, 4, 5, 1
-    vpermq               m0, m0, q3120
-    mova        [dstq+32*2], m0
-    mova       [maskq+32*2], m4
-    W_MASK                0, 4, 6, 7, 1
-    vpermq               m0, m0, q3120
-    mova        [dstq+32*3], m0
-    mova       [maskq+32*3], m4
-    dec                  hd
-    jg .w128_loop
-    RET
-
 cglobal blend, 3, 7, 7, dst, ds, tmp, w, h, mask
 %define base r6-blend_avx2_table
     lea                  r6, [blend_avx2_table]
@@ -5493,6 +5371,136 @@
     jg .w128_loop
     RET
 
+cglobal w_mask_444, 4, 8, 8, dst, stride, tmp1, tmp2, w, h, mask, stride3 ; AVX2 w_mask 4:4:4: writes blended pixels to dst and one mask byte per pixel to mask
+%define base r7-w_mask_444_avx2_table
+    lea                  r7, [w_mask_444_avx2_table] ; r7 = jump table; constants below are addressed as base+symbol
+    tzcnt                wd, wm ; trailing-zero count of the width selects the jump-table entry
+    movifnidn            hd, hm
+    mov               maskq, maskmp
+    movsxd               wq, dword [r7+wq*4] ; table entries are 32-bit offsets relative to the table
+    vpbroadcastd         m6, [base+pw_6903] ; ((64 - 38) << 8) + 255 - 8
+    vpbroadcastd         m5, [base+pb_64]
+    vpbroadcastd         m7, [base+pw_2048]
+    add                  wq, r7 ; wq = absolute address of the .wN entry point
+    W_MASK                0, 4, 0, 1, 1
+    lea            stride3q, [strideq*3]
+    jmp                  wq ; every .wN entry expects the first W_MASK result in m0 (pixels) and m4 (mask)
+.w4:
+    vextracti128        xm1, m0, 1
+    movd   [dstq+strideq*0], xm0 ; rows 0-3: one dword (4 pixels) per row
+    pextrd [dstq+strideq*1], xm0, 1
+    movd   [dstq+strideq*2], xm1
+    pextrd [dstq+stride3q ], xm1, 1
+    mova       [maskq+32*0], m4
+    cmp                  hd, 8
+    jl .w4_end ; h < 8: only the four rows above are written
+    lea                dstq, [dstq+strideq*4] ; lea leaves the flags from cmp hd, 8 intact
+    pextrd [dstq+strideq*0], xm0, 2 ; rows 4-7 come from the remaining dwords of the same vector
+    pextrd [dstq+strideq*1], xm0, 3
+    pextrd [dstq+strideq*2], xm1, 2
+    pextrd [dstq+stride3q ], xm1, 3
+    je .w4_end ; still the flags of cmp hd, 8: h == 8 is done here
+    W_MASK                0, 4, 2, 3, 1
+    lea                dstq, [dstq+strideq*4]
+    vextracti128        xm1, m0, 1
+    movd   [dstq+strideq*0], xm0 ; rows 8-11
+    pextrd [dstq+strideq*1], xm0, 1
+    movd   [dstq+strideq*2], xm1
+    pextrd [dstq+stride3q ], xm1, 1
+    lea                dstq, [dstq+strideq*4]
+    pextrd [dstq+strideq*0], xm0, 2 ; rows 12-15
+    pextrd [dstq+strideq*1], xm0, 3
+    pextrd [dstq+strideq*2], xm1, 2
+    pextrd [dstq+stride3q ], xm1, 3
+    mova       [maskq+32*1], m4
+.w4_end:
+    RET
+.w8_loop:
+    add               tmp1q, 32*2 ; each iteration consumes 64 bytes from each tmp buffer
+    add               tmp2q, 32*2
+    W_MASK                0, 4, 0, 1, 1
+    lea                dstq, [dstq+strideq*4]
+    add               maskq, 32
+.w8:
+    vextracti128        xm1, m0, 1
+    movq   [dstq+strideq*0], xm0 ; four rows of 8 pixels per iteration
+    movq   [dstq+strideq*1], xm1
+    movhps [dstq+strideq*2], xm0
+    movhps [dstq+stride3q ], xm1
+    mova            [maskq], m4
+    sub                  hd, 4
+    jg .w8_loop
+    RET
+.w16_loop:
+    add               tmp1q, 32*2
+    add               tmp2q, 32*2
+    W_MASK                0, 4, 0, 1, 1
+    lea                dstq, [dstq+strideq*2]
+    add               maskq, 32
+.w16:
+    vpermq               m0, m0, q3120 ; reorder qwords so each 128-bit half holds one full row
+    mova         [dstq+strideq*0], xm0 ; two rows of 16 pixels per iteration
+    vextracti128 [dstq+strideq*1], m0, 1
+    mova            [maskq], m4
+    sub                  hd, 2
+    jg .w16_loop
+    RET
+.w32_loop:
+    add               tmp1q, 32*2
+    add               tmp2q, 32*2
+    W_MASK                0, 4, 0, 1, 1
+    add                dstq, strideq
+    add               maskq, 32
+.w32:
+    vpermq               m0, m0, q3120
+    mova             [dstq], m0 ; one row of 32 pixels per iteration
+    mova            [maskq], m4
+    dec                  hd
+    jg .w32_loop
+    RET
+.w64_loop:
+    add               tmp1q, 32*4
+    add               tmp2q, 32*4
+    W_MASK                0, 4, 0, 1, 1
+    add                dstq, strideq
+    add               maskq, 32*2
+.w64:
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*0], m0 ; one row of 64 pixels per iteration, in two 32-byte halves
+    mova       [maskq+32*0], m4
+    W_MASK                0, 4, 2, 3, 1
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*1], m0
+    mova       [maskq+32*1], m4
+    dec                  hd
+    jg .w64_loop
+    RET
+.w128_loop:
+    add               tmp1q, 32*8
+    add               tmp2q, 32*8
+    W_MASK                0, 4, 0, 1, 1
+    add                dstq, strideq
+    add               maskq, 32*4
+.w128:
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*0], m0 ; one row of 128 pixels per iteration, in four 32-byte pieces
+    mova       [maskq+32*0], m4
+    W_MASK                0, 4, 2, 3, 1
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*1], m0
+    mova       [maskq+32*1], m4
+    W_MASK                0, 4, 4, 5, 1
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*2], m0
+    mova       [maskq+32*2], m4
+    W_MASK                0, 4, 6, 7, 1
+    vpermq               m0, m0, q3120
+    mova        [dstq+32*3], m0
+    mova       [maskq+32*3], m4
+    dec                  hd
+    jg .w128_loop
+    RET
+
 INIT_ZMM avx512icl
 PREP_BILIN
 PREP_8TAP
@@ -5861,6 +5869,155 @@
     add               maskq, 64
     mova        [dstq+64*0], m0
     mova        [dstq+64*1], m1
+    add                dstq, strideq
+    dec                  hd
+    jg .w128_loop
+    RET
+
+cglobal w_mask_444, 4, 8, 12, dst, stride, tmp1, tmp2, w, h, mask, stride3 ; AVX-512 (Ice Lake) w_mask 4:4:4: writes blended pixels to dst and one mask byte per pixel to mask
+%define base r7-w_mask_444_avx512icl_table
+    lea                  r7, [w_mask_444_avx512icl_table] ; r7 = jump table; constants below are addressed as base+symbol
+    tzcnt                wd, wm ; trailing-zero count of the width selects the jump-table entry
+    movifnidn            hd, hm
+    movsxd               wq, dword [r7+wq*4] ; table entries are 32-bit offsets relative to the table
+    vpbroadcastd         m6, [base+pw_6903] ; ((64 - 38) << 8) + 255 - 8
+    vpbroadcastd         m5, [base+pb_64]
+    vpbroadcastd         m7, [base+pw_2048]
+    mova                 m8, [base+wm_444_mask] ; byte indices applied to every mask vector with vpermb before it is stored
+    add                  wq, r7 ; wq = absolute address of the .wN entry point
+    mov               maskq, maskmp
+    lea            stride3q, [strideq*3]
+    jmp                  wq ; unlike the AVX2 version, W_MASK is run inside each .wN path, not before the dispatch
+.w4:
+    cmp                  hd, 8
+    jg .w4_h16 ; h > 8 uses a full zmm and a scatter store
+    WRAP_YMM W_MASK       0, 4, 0, 1, 1
+    vinserti128         ym8, [wm_444_mask+32], 1 ; ymm index vector: odd indices 1..31 in the low lane, even indices 0..30 in the high lane
+    vpermb              ym4, ym8, ym4
+    mova            [maskq], ym4
+    vextracti128       xmm1, m0, 1
+    movd   [dstq+strideq*0], xm0 ; rows 0-3: one dword (4 pixels) per row
+    pextrd [dstq+strideq*1], xm0, 1
+    movd   [dstq+strideq*2], xmm1
+    pextrd [dstq+stride3q ], xmm1, 1
+    jl .w4_end ; reuses the flags of cmp hd, 8 above (NOTE(review): assumes W_MASK leaves flags untouched): h < 8 is done
+    lea                dstq, [dstq+strideq*4]
+    pextrd [dstq+strideq*0], xm0, 2 ; rows 4-7 from the remaining dwords
+    pextrd [dstq+strideq*1], xm0, 3
+    pextrd [dstq+strideq*2], xmm1, 2
+    pextrd [dstq+stride3q ], xmm1, 3
+.w4_end:
+    RET
+.w4_h16:
+    vpbroadcastd         m9, strided
+    pmulld               m9, [bidir_sctr_w4] ; m9 = stride * per-dword factor from bidir_sctr_w4 (defined elsewhere; presumably the destination row of each dword)
+    W_MASK                0, 4, 0, 1, 1
+    vpermb               m4, m8, m4
+    kxnorw               k1, k1, k1 ; k1 = all ones: scatter all 16 dwords
+    mova            [maskq], m4
+    vpscatterdd [dstq+m9]{k1}, m0 ; dword i of m0 is stored at dstq + m9[i]
+    RET
+.w8:
+    cmp                  hd, 4
+    jne .w8_h8 ; anything other than h == 4 goes through the 8-rows-per-iteration zmm loop
+    WRAP_YMM W_MASK       0, 4, 0, 1, 1
+    vinserti128         ym8, [wm_444_mask+32], 1 ; same ymm index setup as in .w4
+    vpermb              ym4, ym8, ym4
+    mova            [maskq], ym4
+    vextracti128       xmm1, ym0, 1
+    movq   [dstq+strideq*0], xm0
+    movq   [dstq+strideq*1], xmm1
+    movhps [dstq+strideq*2], xm0
+    movhps [dstq+stride3q ], xmm1
+    RET
+.w8_loop:
+    add               tmp1q, 128 ; each iteration consumes 128 bytes from each tmp buffer and produces 64 mask bytes
+    add               tmp2q, 128
+    add               maskq, 64
+    lea                dstq, [dstq+strideq*4] ; the loop body already advanced dst by 4 rows; this adds the other 4
+.w8_h8:
+    W_MASK                0, 4, 0, 1, 1
+    vpermb               m4, m8, m4
+    mova            [maskq], m4
+    vextracti32x4      xmm1, ym0, 1
+    vextracti32x4      xmm2, m0, 2
+    vextracti32x4      xmm3, m0, 3
+    movq   [dstq+strideq*0], xm0 ; rows 0-3: low qword of each 128-bit lane
+    movq   [dstq+strideq*1], xmm1
+    movq   [dstq+strideq*2], xmm2
+    movq   [dstq+stride3q ], xmm3
+    lea                dstq, [dstq+strideq*4]
+    movhps [dstq+strideq*0], xm0 ; rows 4-7: high qword of each 128-bit lane
+    movhps [dstq+strideq*1], xmm1
+    movhps [dstq+strideq*2], xmm2
+    movhps [dstq+stride3q ], xmm3
+    sub                  hd, 8
+    jg .w8_loop
+    RET
+.w16_loop:
+    add               tmp1q, 128
+    add               tmp2q, 128
+    add               maskq, 64
+    lea                dstq, [dstq+strideq*4]
+.w16:
+    W_MASK                0, 4, 0, 1, 1
+    vpermb               m4, m8, m4
+    vpermq               m0, m0, q3120
+    mova            [maskq], m4
+    mova          [dstq+strideq*0], xm0 ; four rows per iteration, taken from 128-bit lanes 0, 2, 1, 3 in that order
+    vextracti32x4 [dstq+strideq*1], m0, 2
+    vextracti32x4 [dstq+strideq*2], ym0, 1
+    vextracti32x4 [dstq+stride3q ], m0, 3
+    sub                  hd, 4
+    jg .w16_loop
+    RET
+.w32:
+    pmovzxbq             m9, [warp_8x8_shufA] ; zero-extend 8 bytes of warp_8x8_shufA (defined elsewhere) into the qword indices used by vpermq below
+.w32_loop:
+    W_MASK                0, 4, 0, 1, 1
+    vpermb               m4, m8, m4
+    add               tmp1q, 128
+    add               tmp2q, 128
+    vpermq               m0, m9, m0 ; qword shuffle so each 256-bit half holds one output row
+    mova            [maskq], m4
+    add               maskq, 64
+    mova          [dstq+strideq*0], ym0 ; two rows of 32 pixels per iteration
+    vextracti32x8 [dstq+strideq*1], m0, 1
+    lea                dstq, [dstq+strideq*2]
+    sub                  hd, 2
+    jg .w32_loop
+    RET
+.w64:
+    pmovzxbq             m9, [warp_8x8_shufA] ; same qword permute indices as in .w32
+.w64_loop:
+    W_MASK                0, 4, 0, 1, 1
+    vpermb               m4, m8, m4
+    add               tmp1q, 128
+    add               tmp2q, 128
+    vpermq               m0, m9, m0
+    mova            [maskq], m4
+    add               maskq, 64
+    mova             [dstq], m0 ; one row of 64 pixels per iteration
+    add                dstq, strideq
+    dec                  hd
+    jg .w64_loop
+    RET
+.w128:
+    pmovzxbq            m11, [warp_8x8_shufA] ; permute indices live in m11 here because m9 receives the second mask vector below
+.w128_loop:
+    W_MASK                0, 4, 0, 1, 1
+    W_MASK               10, 9, 2, 3, 1
+    vpermb               m4, m8, m4
+    vpermb               m9, m8, m9
+    add               tmp1q, 256 ; one row of 128 pixels consumes 256 bytes from each tmp buffer
+    add               tmp2q, 256
+    vpermq               m0, m11, m0
+    vpermq              m10, m11, m10
+    mova       [maskq+64*0], m4
+    mova       [maskq+64*1], m9
+    add               maskq, 128
+    mova        [dstq+64*0], m0
+    mova        [dstq+64*1], m10
     add                dstq, strideq
     dec                  hd
     jg .w128_loop
--- a/src/x86/mc_init_tmpl.c
+++ b/src/x86/mc_init_tmpl.c
@@ -94,6 +94,7 @@
 decl_w_mask_fn(dav1d_w_mask_420_ssse3);
 decl_w_mask_fn(dav1d_w_mask_422_avx512icl);
 decl_w_mask_fn(dav1d_w_mask_422_avx2);
+decl_w_mask_fn(dav1d_w_mask_444_avx512icl);
 decl_w_mask_fn(dav1d_w_mask_444_avx2);
 decl_blend_fn(dav1d_blend_avx2);
 decl_blend_fn(dav1d_blend_ssse3);
@@ -238,6 +239,7 @@
     c->avg = dav1d_avg_avx512icl;
     c->w_avg = dav1d_w_avg_avx512icl;
     c->mask = dav1d_mask_avx512icl;
+    c->w_mask[0] = dav1d_w_mask_444_avx512icl;
     c->w_mask[1] = dav1d_w_mask_422_avx512icl;
     c->w_mask[2] = dav1d_w_mask_420_avx512icl;
 #endif