shithub: dav1d

ref: 19ce77e0b9ec60e0a3a748ae9c2c2b417d3b75b5
parent: 58a4ba072214e2b4a39d1d54821a426c2269e233
author: Henrik Gramner <gramner@twoorioles.com>
date: Tue Jan 28 12:42:06 EST 2020

x86: Add cdef_filter_4x4 AVX-512 (Ice Lake) asm
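
The new 8bpc cdef_filter_4x4 kernel leans on the Ice Lake AVX-512
extensions: vpermb (VBMI) gathers the 4x4 block plus its top, left and
bottom borders into a single zmm register through lut_perm_4x4,
vpshufbitqmb (BITALG) masks out-of-frame taps against edge_mask,
vgf2p8affineqb (GFNI) performs the variable right shift via the gf_shr
table, and vpdpbusd (VNNI) accumulates the weighted taps. The lr_bak
buffer in cdef_apply_tmpl.c and the left buffer in the checkasm test
become 16-byte aligned (ALIGN_STK_16) for the new code, and the
repetitive per-size declarations in cdef_init_tmpl.c are collapsed into
a decl_cdef_size_fn macro.

For orientation, a minimal scalar sketch of the per-tap arithmetic the
kernel vectorizes, paraphrased from the comments in the asm below; the
helper names and the simplified shift handling are illustrative, not
the actual reference code in src/cdef_tmpl.c:

    #include <stdlib.h>

    static inline int imin(const int a, const int b) { return a < b ? a : b; }
    static inline int imax(const int a, const int b) { return a > b ? a : b; }

    /* Clamp a neighbor/center difference against the tap strength:
     * imin(abs(diff), imax(0, strength - (abs(diff) >> shift))), with
     * the sign of diff reapplied. In the asm this is the
     * psubusb/pminub pair, with vgf2p8affineqb supplying the variable
     * right shift for all 16 lanes at once. */
    static int constrain(const int diff, const int strength, const int shift) {
        const int adiff = abs(diff);
        const int clipped = imin(adiff, imax(0, strength - (adiff >> shift)));
        return diff < 0 ? -clipped : clipped;
    }

    /* The weighted sum of constrained taps (vpdpbusd) is then rounded
     * back onto the center pixel, roughly:
     *     out = px + ((8 + sum - (sum < 0)) >> 4);
     * the asm keeps everything in one dword per pixel by folding the
     * rounding bias into pd_268435568 = (1 << 28) + (7 << 4). */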

--- a/src/cdef_apply_tmpl.c
+++ b/src/cdef_apply_tmpl.c
@@ -120,7 +120,7 @@
         if (edges & CDEF_HAVE_BOTTOM) // backup pre-filter data for next iteration
             backup2lines(f->lf.cdef_line[!tf], ptrs, f->cur.stride, layout);
 
-        pixel lr_bak[2 /* idx */][3 /* plane */][8 /* y */][2 /* x */];
+        ALIGN_STK_16(pixel, lr_bak, 2 /* idx */, [3 /* plane */][8 /* y */][2 /* x */]);
         pixel *iptrs[3] = { ptrs[0], ptrs[1], ptrs[2] };
         edges &= ~CDEF_HAVE_LEFT;
         edges |= CDEF_HAVE_RIGHT;
--- a/src/x86/cdef.asm
+++ b/src/x86/cdef.asm
@@ -27,34 +27,72 @@
 
 %if ARCH_X86_64
 
-SECTION_RODATA 32
-pd_47130256: dd 4, 7, 1, 3, 0, 2, 5, 6
-div_table: dd 840, 420, 280, 210, 168, 140, 120, 105
-           dd 420, 210, 140, 105
-shufw_6543210x: db 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, 14, 15
-shufb_lohi: db 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15
-pw_128: times 2 dw 128
-pw_2048: times 2 dw 2048
-tap_table: ; masks for 8 bit shifts
-           db 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01
-           ; weights
-           db 4, 2, 3, 3, 2, 1
-           db -1 * 16 + 1, -2 * 16 + 2
-           db  0 * 16 + 1, -1 * 16 + 2
-           db  0 * 16 + 1,  0 * 16 + 2
-           db  0 * 16 + 1,  1 * 16 + 2
-           db  1 * 16 + 1,  2 * 16 + 2
-           db  1 * 16 + 0,  2 * 16 + 1
-           db  1 * 16 + 0,  2 * 16 + 0
-           db  1 * 16 + 0,  2 * 16 - 1
-           ; the last 6 are repeats of the first 6 so we don't need to & 7
-           db -1 * 16 + 1, -2 * 16 + 2
-           db  0 * 16 + 1, -1 * 16 + 2
-           db  0 * 16 + 1,  0 * 16 + 2
-           db  0 * 16 + 1,  1 * 16 + 2
-           db  1 * 16 + 1,  2 * 16 + 2
-           db  1 * 16 + 0,  2 * 16 + 1
+%macro DUP4 1-*
+    %rep %0
+        times 4 db %1
+        %rotate 1
+    %endrep
+%endmacro
 
+%macro DIRS 16 ; cdef_directions[]
+    %rep 4 + 16 + 4 ; 6 7   0 1 2 3 4 5 6 7   0 1
+        ; masking away unused bits allows us to use a single vpaddd {1to16}
+        ; instruction instead of having to do vpbroadcastd + paddb
+        db %13 & 0x3f, -%13 & 0x3f
+        %rotate 1
+    %endrep
+%endmacro
+
+SECTION_RODATA 64
+
+lut_perm_4x4:  db 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79
+               db 16, 17,  0,  1,  2,  3,  4,  5, 18, 19,  8,  9, 10, 11, 12, 13
+               db 20, 21, 80, 81, 82, 83, 84, 85, 22, 23, 32, 33, 34, 35, 36, 37
+               db 98, 99,100,101,102,103,104,105, 50, 51, 52, 53, 54, 55, 56, 57
+edge_mask:     dq 0x00003c3c3c3c0000, 0x00003f3f3f3f0000 ; 0000, 0001
+               dq 0x0000fcfcfcfc0000, 0x0000ffffffff0000 ; 0010, 0011
+               dq 0x00003c3c3c3c3c3c, 0x00003f3f3f3f3f3f ; 0100, 0101
+               dq 0x0000fcfcfcfcfcfc, 0x0000ffffffffffff ; 0110, 0111
+               dq 0x3c3c3c3c3c3c0000, 0x3f3f3f3f3f3f0000 ; 1000, 1001
+               dq 0xfcfcfcfcfcfc0000, 0xffffffffffff0000 ; 1010, 1011
+               dq 0x3c3c3c3c3c3c3c3c, 0x3f3f3f3f3f3f3f3f ; 1100, 1101
+               dq 0xfcfcfcfcfcfcfcfc, 0xffffffffffffffff ; 1110, 1111
+px_idx:      DUP4 18, 19, 20, 21, 26, 27, 28, 29, 34, 35, 36, 37, 42, 43, 44, 45
+cdef_dirs:   DIRS -7,-14,  1, -6,  1,  2,  1, 10,  9, 18,  8, 17,  8, 16,  8, 15
+gf_shr:        dq 0x0102040810204080, 0x0102040810204080 ; >> 0, >> 0
+               dq 0x0204081020408000, 0x0408102040800000 ; >> 1, >> 2
+               dq 0x0810204080000000, 0x1020408000000000 ; >> 3, >> 4
+               dq 0x2040800000000000, 0x4080000000000000 ; >> 5, >> 6
+end_perm:      db  1,  5,  9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61
+pri_tap:       db 64, 64, 32, 32, 48, 48, 48, 48         ; left-shifted by 4
+sec_tap:       db 32, 32, 16, 16
+pd_268435568:  dd 268435568
+div_table:     dd 840, 420, 280, 210, 168, 140, 120, 105, 420, 210, 140, 105
+pd_47130256:   dd  4,  7,  1,  3,  0,  2,  5,  6
+shufw_6543210x:db 12, 13, 10, 11,  8,  9,  6,  7,  4,  5,  2,  3,  0,  1, 14, 15
+shufb_lohi:    db  0,  8,  1,  9,  2, 10,  3, 11,  4, 12,  5, 13,  6, 14,  7, 15
+pw_128:        times 2 dw 128
+pw_2048:       times 2 dw 2048
+tap_table:     ; masks for 8 bit shifts
+               db 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01
+               ; weights
+               db  4,  2,  3,  3,  2,  1
+               db -1 * 16 + 1, -2 * 16 + 2
+               db  0 * 16 + 1, -1 * 16 + 2
+               db  0 * 16 + 1,  0 * 16 + 2
+               db  0 * 16 + 1,  1 * 16 + 2
+               db  1 * 16 + 1,  2 * 16 + 2
+               db  1 * 16 + 0,  2 * 16 + 1
+               db  1 * 16 + 0,  2 * 16 + 0
+               db  1 * 16 + 0,  2 * 16 - 1
+               ; the last 6 are repeats of the first 6 so we don't need to & 7
+               db -1 * 16 + 1, -2 * 16 + 2
+               db  0 * 16 + 1, -1 * 16 + 2
+               db  0 * 16 + 1,  0 * 16 + 2
+               db  0 * 16 + 1,  1 * 16 + 2
+               db  1 * 16 + 1,  2 * 16 + 2
+               db  1 * 16 + 0,  2 * 16 + 1
+
 SECTION .text
 
 %macro ACCUMULATE_TAP 7 ; tap_offset, shift, mask, strength, mul_tap, w, stride
@@ -679,4 +717,170 @@
     psrld          xm2, 10
     movd        [varq], xm2
     RET
+
+%if WIN64
+DECLARE_REG_TMP 5, 6
+%else
+DECLARE_REG_TMP 8, 5
+%endif
+
+; lut:
+; t0 t1 t2 t3 t4 t5 t6 t7
+; T0 T1 T2 T3 T4 T5 T6 T7
+; L0 L1 00 01 02 03 04 05
+; L2 L3 10 11 12 13 14 15
+; L4 L5 20 21 22 23 24 25
+; L6 L7 30 31 32 33 34 35
+; 4e 4f 40 41 42 43 44 45
+; 5e 5f 50 51 52 53 54 55
+
+INIT_ZMM avx512icl
+cglobal cdef_filter_4x4, 4, 8, 13, dst, stride, left, top, pri, sec, dir, damping, edge
+%define base r7-edge_mask
+    movq         xmm0, [dstq+strideq*0]
+    movhps       xmm0, [dstq+strideq*1]
+    lea            r7, [edge_mask]
+    movq         xmm1, [topq+strideq*0-2]
+    movhps       xmm1, [topq+strideq*1-2]
+    mov           r6d, edgem
+    vinserti32x4  ym0, ymm0, [leftq], 1
+    lea            r2, [strideq*3]
+    vinserti32x4  ym1, ymm1, [dstq+strideq*2], 1
+    mova           m5, [base+lut_perm_4x4]
+    vinserti32x4   m0, [dstq+r2], 2
+    test          r6b, 0x08      ; avoid buffer overread
+    jz .main
+    lea            r3, [dstq+strideq*4-4]
+    vinserti32x4   m1, [r3+strideq*0], 2
+    vinserti32x4   m0, [r3+strideq*1], 3
+.main:
+    movifnidn    prid, prim
+    mov           t0d, dirm
+    mova           m3, [base+px_idx]
+    mov           r3d, dampingm
+    vpermi2b       m5, m0, m1    ; lut
+    vpbroadcastd   m0, [base+pd_268435568] ; (1 << 28) + (7 << 4)
+    pxor           m7, m7
+    lea            r3, [r7+r3*8] ; gf_shr + (damping - 30) * 8
+    vpermb         m6, m3, m5    ; px
+    cmp           r6d, 0x0f
+    jne .mask_edges              ; mask edges only if required
+    test         prid, prid
+    jz .sec_only
+    vpaddd         m1, m3, [base+cdef_dirs+(t0+2)*4] {1to16} ; dir
+    vpermb         m1, m1, m5    ; k0p0 k0p1 k1p0 k1p1
+%macro CDEF_FILTER_4x4_PRI 0
+    vpcmpub        k1, m6, m1, 6 ; px > pN
+    psubb          m2, m1, m6
+    lzcnt         r6d, prid
+    vpsubb     m2{k1}, m6, m1    ; abs(diff)
+    vpbroadcastb   m4, prim
+    and          prid, 1
+    vgf2p8affineqb m9, m2, [r3+r6*8] {1to8}, 0 ; abs(diff) >> shift
+    movifnidn     t1d, secm
+    vpbroadcastd  m10, [base+pri_tap+priq*4]
+    vpsubb    m10{k1}, m7, m10   ; apply_sign(pri_tap)
+    psubusb        m4, m9        ; imax(0, pri_strength - (abs(diff) >> shift))
+    pminub         m2, m4
+    vpdpbusd       m0, m2, m10   ; sum
+%endmacro
+    CDEF_FILTER_4x4_PRI
+    test          t1d, t1d       ; sec
+    jz .end_no_clip
+    call .sec
+.end_clip:
+    pminub         m4, m6, m1
+    pmaxub         m1, m6
+    pminub         m5, m2, m3
+    pmaxub         m2, m3
+    pminub         m4, m5
+    pmaxub         m2, m1
+    psrldq         m1, m4, 2
+    psrldq         m3, m2, 2
+    pminub         m1, m4
+    vpcmpw         k1, m0, m7, 1
+    vpshldd        m6, m0, 8
+    pmaxub         m2, m3
+    pslldq         m3, m1, 1
+    psubw          m7, m0
+    paddusw        m0, m6     ; clip >0xff
+    vpsubusw   m0{k1}, m6, m7 ; clip <0x00
+    pslldq         m4, m2, 1
+    pminub         m1, m3
+    pmaxub         m2, m4
+    pmaxub         m0, m1
+    pminub         m0, m2
+    jmp .end
+.sec_only:
+    movifnidn     t1d, secm
+    call .sec
+.end_no_clip:
+    vpshldd        m6, m0, 8  ; (px << 8) + ((sum > -8) << 4)
+    paddw          m0, m6     ; (px << 8) + ((sum + (sum > -8) + 7) << 4)
+.end:
+    mova          xm1, [base+end_perm]
+    vpermb         m0, m1, m0 ; output in bits 8-15 of each dword
+    movd   [dstq+strideq*0], xm0
+    pextrd [dstq+strideq*1], xm0, 1
+    pextrd [dstq+strideq*2], xm0, 2
+    pextrd [dstq+r2       ], xm0, 3
+    RET
+.mask_edges_sec_only:
+    movifnidn     t1d, secm
+    call .mask_edges_sec
+    jmp .end_no_clip
+ALIGN function_align
+.mask_edges:
+    vpbroadcastq   m8, [base+edge_mask+r6*8]
+    test         prid, prid
+    jz .mask_edges_sec_only
+    vpaddd         m2, m3, [base+cdef_dirs+(t0+2)*4] {1to16}
+    vpshufbitqmb   k1, m8, m2 ; index in-range
+    mova           m1, m6
+    vpermb     m1{k1}, m2, m5
+    CDEF_FILTER_4x4_PRI
+    test          t1d, t1d
+    jz .end_no_clip
+    call .mask_edges_sec
+    jmp .end_clip
+.mask_edges_sec:
+    vpaddd         m4, m3, [base+cdef_dirs+(t0+4)*4] {1to16}
+    vpaddd         m9, m3, [base+cdef_dirs+(t0+0)*4] {1to16}
+    vpshufbitqmb   k1, m8, m4
+    mova           m2, m6
+    vpermb     m2{k1}, m4, m5
+    vpshufbitqmb   k1, m8, m9
+    mova           m3, m6
+    vpermb     m3{k1}, m9, m5
+    jmp .sec_main
+ALIGN function_align
+.sec:
+    vpaddd         m2, m3, [base+cdef_dirs+(t0+4)*4] {1to16} ; dir + 2
+    vpaddd         m3,     [base+cdef_dirs+(t0+0)*4] {1to16} ; dir - 2
+    vpermb         m2, m2, m5 ; k0s0 k0s1 k1s0 k1s1
+    vpermb         m3, m3, m5 ; k0s2 k0s3 k1s2 k1s3
+.sec_main:
+    vpbroadcastd   m8, [base+sec_tap]
+    vpcmpub        k1, m6, m2, 6
+    psubb          m4, m2, m6
+    vpbroadcastb  m12, t1d
+    lzcnt         t1d, t1d
+    vpsubb     m4{k1}, m6, m2
+    vpcmpub        k2, m6, m3, 6
+    vpbroadcastq  m11, [r3+t1*8]
+    gf2p8affineqb m10, m4, m11, 0
+    psubb          m5, m3, m6
+    mova           m9, m8
+    vpsubb     m8{k1}, m7, m8
+    psubusb       m10, m12, m10
+    vpsubb     m5{k2}, m6, m3
+    pminub         m4, m10
+    vpdpbusd       m0, m4, m8
+    gf2p8affineqb m11, m5, m11, 0
+    vpsubb     m9{k2}, m7, m9
+    psubusb       m12, m11
+    pminub         m5, m12
+    vpdpbusd       m0, m5, m9
+    ret
+
 %endif ; ARCH_X86_64
--- a/src/x86/cdef_init_tmpl.c
+++ b/src/x86/cdef_init_tmpl.c
@@ -28,21 +28,17 @@
 #include "src/cpu.h"
 #include "src/cdef.h"
 
-decl_cdef_fn(dav1d_cdef_filter_8x8_avx2);
-decl_cdef_fn(dav1d_cdef_filter_8x8_sse4);
-decl_cdef_fn(dav1d_cdef_filter_8x8_ssse3);
-decl_cdef_fn(dav1d_cdef_filter_8x8_sse2);
+#define decl_cdef_size_fn(sz) \
+    decl_cdef_fn(dav1d_cdef_filter_##sz##_avx512icl); \
+    decl_cdef_fn(dav1d_cdef_filter_##sz##_avx2); \
+    decl_cdef_fn(dav1d_cdef_filter_##sz##_sse4); \
+    decl_cdef_fn(dav1d_cdef_filter_##sz##_ssse3); \
+    decl_cdef_fn(dav1d_cdef_filter_##sz##_sse2)
 
-decl_cdef_fn(dav1d_cdef_filter_4x8_avx2);
-decl_cdef_fn(dav1d_cdef_filter_4x8_sse4);
-decl_cdef_fn(dav1d_cdef_filter_4x8_ssse3);
-decl_cdef_fn(dav1d_cdef_filter_4x8_sse2);
+decl_cdef_size_fn(4x4);
+decl_cdef_size_fn(4x8);
+decl_cdef_size_fn(8x8);
 
-decl_cdef_fn(dav1d_cdef_filter_4x4_avx2);
-decl_cdef_fn(dav1d_cdef_filter_4x4_sse4);
-decl_cdef_fn(dav1d_cdef_filter_4x4_ssse3);
-decl_cdef_fn(dav1d_cdef_filter_4x4_sse2);
-
 decl_cdef_dir_fn(dav1d_cdef_dir_avx2);
 decl_cdef_dir_fn(dav1d_cdef_dir_sse4);
 decl_cdef_dir_fn(dav1d_cdef_dir_ssse3);
@@ -76,12 +72,21 @@
     c->fb[2] = dav1d_cdef_filter_4x4_sse4;
 #endif
 
+#if ARCH_X86_64
     if (!(flags & DAV1D_X86_CPU_FLAG_AVX2)) return;
 
-#if BITDEPTH == 8 && ARCH_X86_64
+#if BITDEPTH == 8
     c->dir = dav1d_cdef_dir_avx2;
     c->fb[0] = dav1d_cdef_filter_8x8_avx2;
     c->fb[1] = dav1d_cdef_filter_4x8_avx2;
     c->fb[2] = dav1d_cdef_filter_4x4_avx2;
+#endif
+
+    if (!(flags & DAV1D_X86_CPU_FLAG_AVX512ICL)) return;
+
+#if BITDEPTH == 8
+    c->fb[2] = dav1d_cdef_filter_4x4_avx512icl;
+#endif
+
 #endif
 }
--- a/tests/checkasm/cdef.c
+++ b/tests/checkasm/cdef.c
@@ -48,7 +48,7 @@
     ALIGN_STK_64(pixel, c_src,   16 * 10 + 16, ), *const c_dst = c_src + 8;
     ALIGN_STK_64(pixel, a_src,   16 * 10 + 16, ), *const a_dst = a_src + 8;
     ALIGN_STK_64(pixel, top_buf, 16 *  2 + 16, ), *const top = top_buf + 8;
-    pixel left[8][2];
+    ALIGN_STK_16(pixel, left, 8,[2]);
     const ptrdiff_t stride = 16 * sizeof(pixel);
 
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel (*left)[2],