shithub: dav1d

Download patch

ref: 91568b2a57de65aad9f9f73fcc8cb71fdc507e9b
parent: 75e88fab368c21cc0089222e5e08c9a15b369885
author: Victorien Le Couviour--Tuffet <victorien.lecouviour.tuffet@gmail.com>
date: Thu Mar 21 13:11:00 EDT 2019

x86: cdef_dir: optimize best cost finding for SSE

Port Kyle Siefring's AVX2 change 65ee1233cf86f03e029d0520f7cc5a3e152d3bbd
to SSE4.1, and apply the same best-cost search to the SSSE3 version.

---------------------
x86_64:
------------------------------------------
before: cdef_dir_8bpc_ssse3: 110.3
 after: cdef_dir_8bpc_ssse3: 105.9
   new: cdef_dir_8bpc_sse4:   96.4
------------------------------------------

---------------------
x86_32:
------------------------------------------
before: cdef_dir_8bpc_ssse3: 120.6
 after: cdef_dir_8bpc_ssse3: 110.7
   new: cdef_dir_8bpc_sse4:  106.5
------------------------------------------

--- a/src/x86/cdef_init_tmpl.c
+++ b/src/x86/cdef_init_tmpl.c
@@ -41,6 +41,7 @@
 decl_cdef_fn(dav1d_cdef_filter_4x4_ssse3);
 
 decl_cdef_dir_fn(dav1d_cdef_dir_avx2);
+decl_cdef_dir_fn(dav1d_cdef_dir_sse4);
 decl_cdef_dir_fn(dav1d_cdef_dir_ssse3);
 
 void bitfn(dav1d_cdef_dsp_init_x86)(Dav1dCdefDSPContext *const c) {
@@ -58,6 +59,7 @@
     if (!(flags & DAV1D_X86_CPU_FLAG_SSE41)) return;
 
 #if BITDEPTH == 8
+    c->dir = dav1d_cdef_dir_sse4;
     c->fb[0] = dav1d_cdef_filter_8x8_sse4;
     c->fb[1] = dav1d_cdef_filter_4x8_sse4;
     c->fb[2] = dav1d_cdef_filter_4x4_sse4;
--- a/src/x86/cdef_sse.asm
+++ b/src/x86/cdef_sse.asm
@@ -40,9 +40,10 @@
 pw_0x7FFF: times 8 dw 0x7FFF
 pw_0x8000: times 8 dw 0x8000
 %endif
-pd_0to7: dd 0, 4, 2, 6, 1, 5, 3, 7
-div_table: dw 840, 840, 420, 420, 280, 280, 210, 210, 168, 168, 140, 140, 120, 120, 105, 105
-           dw 420, 420, 210, 210, 140, 140, 105, 105, 105, 105, 105, 105, 105, 105, 105, 105
+div_table_sse4: dd 840, 420, 280, 210, 168, 140, 120, 105
+                dd 420, 210, 140, 105, 105, 105, 105, 105
+div_table_ssse3: dw 840, 840, 420, 420, 280, 280, 210, 210, 168, 168, 140, 140, 120, 120, 105, 105
+                 dw 420, 420, 210, 210, 140, 140, 105, 105, 105, 105, 105, 105, 105, 105, 105, 105
 shufb_lohi: db 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15
 shufw_6543210x: db 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, 14, 15
 tap_table: ; masks for 8-bit shift emulation
@@ -746,18 +747,22 @@
 %endmacro
 
 %macro MULLD 2
- %if ARCH_X86_32
-  %define m15 m1
- %endif
+ %if cpuflag(sse4) ; MULLD dst, src: per dword lane, dst = low 32 bits of dst * src
+    pmulld          %1, %2                  ; SSE4.1 native 32-bit low multiply, no scratch register
+ %else ; SSSE3 emulation via 16-bit multiplies; exact only if each dword of src repeats one 16-bit multiplier in both words (as div_table_ssse3 does); clobbers m15
+  %if ARCH_X86_32 ; x86_32 only has xmm0-7, so there is no m15
+   %define m15 m1 ; m1 becomes the clobbered scratch; this %define is not undone at %endmacro
+  %endif
     pmulhuw        m15, %1, %2
     pmullw          %1, %2
     pslld          m15, 16
     paddd           %1, m15
+ %endif ; cpuflag(sse4)
 %endmacro
 
-INIT_XMM ssse3
-%if ARCH_X86_64
-cglobal cdef_dir, 3, 4, 16, src, stride, var, stride3
+%macro CDEF_DIR 0
+ %if ARCH_X86_64
+cglobal cdef_dir, 3, 5, 16, 32, src, stride, var, stride3
     lea       stride3q, [strideq*3]
     movq            m1, [srcq+strideq*0]
     movhps          m1, [srcq+strideq*1]
@@ -812,7 +817,7 @@
     pmaddwd         m9, m9
     phaddd          m9, m8
     SWAP            m8, m9
-    MULLD           m8, [div_table+48]
+    MULLD           m8, [div_table%+SUFFIX+48]
 
     pslldq          m9, m1, 2
     psrldq         m10, m1, 14
@@ -846,8 +851,8 @@
     punpcklwd       m9, m10
     pmaddwd        m11, m11
     pmaddwd         m9, m9
-    MULLD          m11, [div_table+16]
-    MULLD           m9, [div_table+0]
+    MULLD          m11, [div_table%+SUFFIX+16]
+    MULLD           m9, [div_table%+SUFFIX+0]
     paddd           m9, m11                 ; cost[0a-d]
 
     pslldq         m10, m0, 14
@@ -882,8 +887,8 @@
     punpcklwd      m10, m11
     pmaddwd        m12, m12
     pmaddwd        m10, m10
-    MULLD          m12, [div_table+16]
-    MULLD          m10, [div_table+0]
+    MULLD          m12, [div_table%+SUFFIX+16]
+    MULLD          m10, [div_table%+SUFFIX+0]
     paddd          m10, m12                 ; cost[4a-d]
     phaddd          m9, m10                 ; cost[0a/b,4a/b]
 
@@ -908,14 +913,14 @@
     paddw           m4, m6
     paddw           m5, m15                 ; partial_sum_alt[3] right
     paddw           m4, m14                 ; partial_sum_alt[3] left
-    pshuflw         m5, m5, q3012
-    punpckhwd       m6, m4, m5
-    punpcklwd       m4, m5
-    pmaddwd         m6, m6
+    pshuflw         m6, m5, q3012
+    punpckhwd       m5, m4
+    punpcklwd       m4, m6
+    pmaddwd         m5, m5
     pmaddwd         m4, m4
-    MULLD           m6, [div_table+48]
-    MULLD           m4, [div_table+32]
-    paddd           m4, m6                  ; cost[7a-d]
+    MULLD           m5, [div_table%+SUFFIX+48]
+    MULLD           m4, [div_table%+SUFFIX+32]
+    paddd           m4, m5                  ; cost[7a-d]
 
     pslldq          m5, m10, 6
     psrldq          m6, m10, 10
@@ -928,14 +933,14 @@
     paddw           m5, m11
     paddw           m6, m12
     paddw           m5, m13
-    pshuflw         m6, m6, q3012
-    punpckhwd       m7, m5, m6
-    punpcklwd       m5, m6
-    pmaddwd         m7, m7
+    pshuflw         m7, m6, q3012
+    punpckhwd       m6, m5
+    punpcklwd       m5, m7
+    pmaddwd         m6, m6
     pmaddwd         m5, m5
-    MULLD           m7, [div_table+48]
-    MULLD           m5, [div_table+32]
-    paddd           m5, m7                  ; cost[5a-d]
+    MULLD           m6, [div_table%+SUFFIX+48]
+    MULLD           m5, [div_table%+SUFFIX+32]
+    paddd           m5, m6                  ; cost[5a-d]
 
     pslldq          m6, m1, 2
     psrldq          m7, m1, 14
@@ -948,14 +953,14 @@
     paddw           m6, m10
     paddw           m7, m13                 ; partial_sum_alt[3] right
     paddw           m6, m12                 ; partial_sum_alt[3] left
-    pshuflw         m7, m7, q3012
-    punpckhwd      m10, m6, m7
-    punpcklwd       m6, m7
-    pmaddwd        m10, m10
+    pshuflw        m10, m7, q3012
+    punpckhwd       m7, m6
+    punpcklwd       m6, m10
+    pmaddwd         m7, m7
     pmaddwd         m6, m6
-    MULLD          m10, [div_table+48]
-    MULLD           m6, [div_table+32]
-    paddd           m6, m10                 ; cost[1a-d]
+    MULLD           m7, [div_table%+SUFFIX+48]
+    MULLD           m6, [div_table%+SUFFIX+32]
+    paddd           m6, m7                  ; cost[1a-d]
 
     pshufd          m0, m0, q1032
     pshufd          m1, m1, q1032
@@ -973,63 +978,64 @@
     paddw          m10, m14
     paddw          m11, m2
     paddw          m10, m3
-    pshuflw        m11, m11, q3012
-    punpckhwd      m12, m10, m11
-    punpcklwd      m10, m11
-    pmaddwd        m12, m12
+    pshuflw        m12, m11, q3012
+    punpckhwd      m11, m10
+    punpcklwd      m10, m12
+    pmaddwd        m11, m11
     pmaddwd        m10, m10
-    MULLD          m12, [div_table+48]
-    MULLD          m10, [div_table+32]
-    paddd          m10, m12                 ; cost[3a-d]
+    MULLD          m11, [div_table%+SUFFIX+48]
+    MULLD          m10, [div_table%+SUFFIX+32]
+    paddd          m10, m11                 ; cost[3a-d]
 
-    phaddd          m0, m9, m8              ; cost[0,4,2,6]
-    phaddd          m6, m5
-    phaddd         m10, m4
-    phaddd          m1, m6, m10             ; cost[1,5,3,7]
+    phaddd          m9, m8                  ; cost[0,4,2,6]
+    phaddd          m6, m10
+    phaddd          m5, m4
+    phaddd          m6, m5                  ; cost[1,3,5,7]
+    pshufd          m4, m9, q3120
 
-    pcmpgtd         m2, m1, m0              ; [1/5/3/7] > [0/4/2/6]
-    pand            m3, m2, m1
-    pandn           m4, m2, m0
-    por             m3, m4                  ; higher 4 values
-    pshufd          m1, m1, q2301
-    pshufd          m0, m0, q2301
-    pand            m1, m2, m1
-    pandn           m4, m2, m0
-    por             m0, m4, m1              ; 4 values at idx^4 offset
-    pand           m14, m2, [pd_0to7+16]
-    pandn          m15, m2, [pd_0to7]
-    por            m15, m14
+    ; now find the best cost
+  %if cpuflag(sse4)
+    pmaxsd          m9, m6
+    pshufd          m0, m9, q1032
+    pmaxsd          m0, m9
+    pshufd          m1, m0, q2301
+    pmaxsd          m0, m1                  ; best cost
+  %else
+    pcmpgtd         m0, m9, m6
+    pand            m9, m0
+    pandn           m0, m6
+    por             m9, m0
+    pshufd          m1, m9, q1032
+    pcmpgtd         m0, m9, m1
+    pand            m9, m0
+    pandn           m0, m1
+    por             m9, m0
+    pshufd          m1, m9, q2301
+    pcmpgtd         m0, m9, m1
+    pand            m9, m0
+    pandn           m0, m1
+    por             m0, m9
+  %endif
 
-    punpckhqdq      m4, m3, m0
-    punpcklqdq      m3, m0
-    pcmpgtd         m5, m4, m3              ; [2or3-6or7] > [0or1/4or5]
-    punpcklqdq      m5, m5
-    pand            m6, m5, m4
-    pandn           m7, m5, m3
-    por             m6, m7                  ; { highest 2 values, complements at idx^4 }
-    movhlps        m14, m15
-    pand           m14, m5, m14
-    pandn          m13, m5, m15
-    por            m15, m13, m14
-
-    pshufd          m7, m6, q3311
-    pcmpgtd         m8, m7, m6              ; [4or5or6or7] > [0or1or2or3]
-    punpcklqdq      m8, m8
-    pand            m9, m8, m7
-    pandn          m10, m8, m6
-    por             m9, m10                 ; max
-    movhlps        m10, m9                  ; complement at idx^4
-    psubd           m9, m10
-    psrld           m9, 10
-    movd        [varq], m9
-    pshufd         m14, m15, q1111
-    pand           m14, m8, m14
-    pandn          m13, m8, m15
-    por            m15, m13, m14
-    movd           eax, m15
-%else
+    ; get direction and variance
+    punpckhdq       m1, m4, m6
+    punpckldq       m4, m6
+    psubd           m2, m0, m1
+    psubd           m3, m0, m4
+    mova    [rsp+0x00], m2                  ; emulate ymm in stack
+    mova    [rsp+0x10], m3
+    pcmpeqd         m1, m0                  ; compute best cost mask
+    pcmpeqd         m4, m0
+    packssdw        m4, m1
+    pmovmskb       eax, m4                  ; get byte-idx from mask
+    tzcnt          eax, eax
+    mov            r1d, [rsp+rax*2]         ; get idx^4 complement from emulated ymm
+    shr            eax, 1                   ; get direction by converting byte-idx to word-idx
+    shr            r1d, 10
+    mov         [varq], r1d
+ %else
 cglobal cdef_dir, 3, 5, 16, 96, src, stride, var, stride3
- %define PIC_reg r4
+  %define PIC_reg r4
     LEA        PIC_reg, PIC_base_offset
 
     pxor            m0, m0
@@ -1092,7 +1098,7 @@
     pmaddwd         m0, m0
 
     phaddd          m2, m0
-    MULLD           m2, [PIC_sym(div_table)+48]
+    MULLD           m2, [PIC_sym(div_table%+SUFFIX)+48]
     mova    [esp+0x30], m2
 
     mova            m1, [esp+0x10]
@@ -1130,8 +1136,8 @@
     punpcklwd       m0, m1
     pmaddwd         m2, m2
     pmaddwd         m0, m0
-    MULLD           m2, [PIC_sym(div_table)+16]
-    MULLD           m0, [PIC_sym(div_table)+0]
+    MULLD           m2, [PIC_sym(div_table%+SUFFIX)+16]
+    MULLD           m0, [PIC_sym(div_table%+SUFFIX)+0]
     paddd           m0, m2                  ; cost[0a-d]
     mova    [esp+0x40], m0
 
@@ -1171,8 +1177,8 @@
     punpcklwd       m0, m1
     pmaddwd         m2, m2
     pmaddwd         m0, m0
-    MULLD           m2, [PIC_sym(div_table)+16]
-    MULLD           m0, [PIC_sym(div_table)+0]
+    MULLD           m2, [PIC_sym(div_table%+SUFFIX)+16]
+    MULLD           m0, [PIC_sym(div_table%+SUFFIX)+0]
     paddd           m0, m2                  ; cost[4a-d]
     phaddd          m1, [esp+0x40], m0      ; cost[0a/b,4a/b]
     phaddd          m1, [esp+0x30]          ; cost[0,4,2,6]
@@ -1208,8 +1214,8 @@
     punpcklwd       m0, m1
     pmaddwd         m2, m2
     pmaddwd         m0, m0
-    MULLD           m2, [PIC_sym(div_table)+48]
-    MULLD           m0, [PIC_sym(div_table)+32]
+    MULLD           m2, [PIC_sym(div_table%+SUFFIX)+48]
+    MULLD           m0, [PIC_sym(div_table%+SUFFIX)+32]
     paddd           m0, m2                  ; cost[7a-d]
     mova    [esp+0x40], m0
 
@@ -1224,44 +1230,44 @@
     paddw           m0, m1
     paddw           m7, m4
     paddw           m0, m2
-    pshuflw         m7, m7, q3012
-    punpckhwd       m2, m0, m7
-    punpcklwd       m0, m7
-    pmaddwd         m2, m2
+    pshuflw         m2, m7, q3012
+    punpckhwd       m7, m0
+    punpcklwd       m0, m2
+    pmaddwd         m7, m7
     pmaddwd         m0, m0
-    MULLD           m2, [PIC_sym(div_table)+48]
-    MULLD           m0, [PIC_sym(div_table)+32]
-    paddd           m0, m2                  ; cost[5a-d]
+    MULLD           m7, [PIC_sym(div_table%+SUFFIX)+48]
+    MULLD           m0, [PIC_sym(div_table%+SUFFIX)+32]
+    paddd           m0, m7                  ; cost[5a-d]
     mova    [esp+0x50], m0
 
-    mova            m1, [esp+0x10]
+    mova            m7, [esp+0x10]
     mova            m2, [esp+0x20]
-    pslldq          m0, m1, 2
-    psrldq          m1, 14
+    pslldq          m0, m7, 2
+    psrldq          m7, 14
     pslldq          m4, m2, 4
     psrldq          m2, 12
     pslldq          m5, m3, 6
     psrldq          m6, m3, 10
     paddw           m0, [esp+0x00]
-    paddw           m1, m2
+    paddw           m7, m2
     paddw           m4, m5
-    paddw           m1, m6                  ; partial_sum_alt[3] right
+    paddw           m7, m6                  ; partial_sum_alt[3] right
     paddw           m0, m4                  ; partial_sum_alt[3] left
-    pshuflw         m1, m1, q3012
-    punpckhwd       m2, m0, m1
-    punpcklwd       m0, m1
-    pmaddwd         m2, m2
+    pshuflw         m2, m7, q3012
+    punpckhwd       m7, m0
+    punpcklwd       m0, m2
+    pmaddwd         m7, m7
     pmaddwd         m0, m0
-    MULLD           m2, [PIC_sym(div_table)+48]
-    MULLD           m0, [PIC_sym(div_table)+32]
-    paddd           m0, m2                  ; cost[1a-d]
-    phaddd          m0, [esp+0x50]
-    mova    [esp+0x50], m0
+    MULLD           m7, [PIC_sym(div_table%+SUFFIX)+48]
+    MULLD           m0, [PIC_sym(div_table%+SUFFIX)+32]
+    paddd           m0, m7                  ; cost[1a-d]
+    SWAP            m0, m4
 
     pshufd          m0, [esp+0x00], q1032
     pshufd          m1, [esp+0x10], q1032
     pshufd          m2, [esp+0x20], q1032
     pshufd          m3, m3, q1032
+    mova    [esp+0x00], m4
 
     pslldq          m4, m0, 6
     psrldq          m0, 10
@@ -1274,70 +1280,76 @@
     paddw           m5, m6
     paddw           m0, m2
     paddw           m4, m5
-    pshuflw         m0, m0, q3012
-    punpckhwd      m2, m4, m0
-    punpcklwd      m4, m0
-    pmaddwd        m2, m2
-    pmaddwd        m4, m4
-    MULLD          m2, [PIC_sym(div_table)+48]
-    MULLD          m4, [PIC_sym(div_table)+32]
-    paddd          m4, m2                   ; cost[3a-d]
-    phaddd         m4, [esp+0x40]
+    pshuflw         m2, m0, q3012
+    punpckhwd       m0, m4
+    punpcklwd       m4, m2
+    pmaddwd         m0, m0
+    pmaddwd         m4, m4
+    MULLD           m0, [PIC_sym(div_table%+SUFFIX)+48]
+    MULLD           m4, [PIC_sym(div_table%+SUFFIX)+32]
+    paddd           m4, m0                   ; cost[3a-d]
 
-    mova            m1, [esp+0x50]
+    mova            m1, [esp+0x00]
+    mova            m2, [esp+0x50]
     mova            m0, [esp+0x30]          ; cost[0,4,2,6]
-    phaddd          m1, m4                  ; cost[1,5,3,7]
+    phaddd          m1, m4
+    phaddd          m2, [esp+0x40]          ; cost[1,3,5,7]
+    phaddd          m1, m2
+    pshufd          m2, m0, q3120
 
-    pcmpgtd         m2, m1, m0              ; [1/5/3/7] > [0/4/2/6]
-    pand            m3, m2, m1
-    pandn           m4, m2, m0
-    por             m3, m4                  ; higher 4 values
-    pshufd          m1, m1, q2301
-    pshufd          m0, m0, q2301
-    pand            m1, m2, m1
-    pandn           m4, m2, m0
-    por             m0, m4, m1              ; 4 values at idx^4 offset
-    pand            m5, m2, [PIC_sym(pd_0to7)+16]
-    pandn           m6, m2, [PIC_sym(pd_0to7)]
-    por             m6, m5
+    ; now find the best cost
+  %if cpuflag(sse4)
+    pmaxsd          m0, m1
+    pshufd          m3, m0, q1032
+    pmaxsd          m3, m0
+    pshufd          m0, m3, q2301
+    pmaxsd          m0, m3
+  %else
+    pcmpgtd         m3, m0, m1
+    pand            m0, m3
+    pandn           m3, m1
+    por             m0, m3
+    pshufd          m4, m0, q1032
+    pcmpgtd         m3, m0, m4
+    pand            m0, m3
+    pandn           m3, m4
+    por             m0, m3
+    pshufd          m4, m0, q2301
+    pcmpgtd         m3, m0, m4
+    pand            m0, m3
+    pandn           m3, m4
+    por             m0, m3
+  %endif
 
-    punpckhqdq      m4, m3, m0
-    punpcklqdq      m3, m0
-    pcmpgtd         m0, m4, m3              ; [2or3-6or7] > [0or1/4or5]
-    punpcklqdq      m0, m0
-    pand            m1, m0, m4
-    pandn           m7, m0, m3
-    por             m1, m7                  ; { highest 2 values, complements at idx^4 }
-    movhlps         m5, m6
-    pand            m5, m0, m5
-    pandn           m3, m0, m6
-    por             m6, m3, m5
+    ; get direction and variance
+    punpckhdq       m3, m2, m1
+    punpckldq       m2, m1
+    psubd           m1, m0, m3
+    psubd           m4, m0, m2
+    mova    [esp+0x00], m1                  ; emulate ymm in stack
+    mova    [esp+0x10], m4
+    pcmpeqd         m3, m0                  ; compute best cost mask
+    pcmpeqd         m2, m0
+    packssdw        m2, m3
+    pmovmskb       eax, m2                  ; get byte-idx from mask
+    tzcnt          eax, eax
+    mov            r1d, [esp+eax*2]         ; get idx^4 complement from emulated ymm
+    shr            eax, 1                   ; get direction by converting byte-idx to word-idx
+    shr            r1d, 10
+    mov         [vard], r1d
+ %endif
 
-    pshufd          m7, m1, q3311
-    pcmpgtd         m2, m7, m1              ; [4or5or6or7] > [0or1or2or3]
-    punpcklqdq      m2, m2
-    pand            m0, m2, m7
-    pandn           m7, m2, m1
-    por             m0, m7                  ; max
-    movhlps         m7, m0                  ; complement at idx^4
-    psubd           m0, m7
-    psrld           m0, 10
-    movd        [varq], m0
-    pshufd          m5, m6, q1111
-    pand            m5, m2, m5
-    pandn           m3, m2, m6
-    por             m6, m3, m5
-    movd           eax, m6
-%endif
-
     RET
+%endmacro
 
 INIT_XMM sse4
 CDEF_FILTER 8, 8, 32
 CDEF_FILTER 4, 8, 32
 CDEF_FILTER 4, 4, 32
+CDEF_DIR
 
 INIT_XMM ssse3
 CDEF_FILTER 8, 8, 32
 CDEF_FILTER 4, 8, 32
 CDEF_FILTER 4, 4, 32
+CDEF_DIR