shithub: dav1d

Download patch

ref: 46a3fd20e032a740061e222414c4145310893593
parent: e2c6d0295c58c9f1c9ce6570e993530b6bc94b68
author: Ronald S. Bultje <rsbultje@gmail.com>
date: Fri Oct 26 13:00:39 EDT 2018

Add a 4x4 cdef_filter AVX2 implementation

cdef_filter_4x4_8bpc_c: 2273.6
cdef_filter_4x4_8bpc_avx2: 113.6

Decoding time for the first 1000 frames of chimera 1080p drops to 15.51s,
from 23.1s before any cdef_filter SIMD and 17.86s with only the 8x8
cdef_filter SIMD.

--- a/src/x86/cdef.asm
+++ b/src/x86/cdef.asm
@@ -56,16 +56,76 @@
 
 SECTION .text
 
+%macro ACCUMULATE_TAP 6 ; tap_offset, shift, strength, mul_tap, w, stride (bytes per stk row)
+    ; load p0/p1 = taps at +off/-off from stkq; clobbers offq, m5, m6, m9-m12
+    movsx         offq, byte [dirq+kq+%1]       ; off1/2/3 (selected by %1), sign-extended
+%if %5 == 4
+    movq           xm5, [stkq+offq*2+%6*0]      ; p0: row 0 (4 words)
+    movq           xm6, [stkq+offq*2+%6*2]      ; row 2
+    movhps         xm5, [stkq+offq*2+%6*1]      ; row 1 -> high half of xm5
+    movhps         xm6, [stkq+offq*2+%6*3]      ; row 3 -> high half of xm6
+    vinserti128     m5, xm6, 1                  ; m5 = rows 0-1 | rows 2-3
+%else
+    movu           xm5, [stkq+offq*2+%6*0]      ; p0: row 0 (8 words)
+    vinserti128     m5, [stkq+offq*2+%6*1], 1   ; row 1 -> high lane
+%endif
+    neg           offq                          ; -off: mirrored tap position
+%if %5 == 4
+    movq           xm6, [stkq+offq*2+%6*0]      ; p1: row 0 (4 words)
+    movq           xm9, [stkq+offq*2+%6*2]      ; row 2 (xm9 is scratch here)
+    movhps         xm6, [stkq+offq*2+%6*1]      ; row 1 -> high half of xm6
+    movhps         xm9, [stkq+offq*2+%6*3]      ; row 3 -> high half of xm9
+    vinserti128     m6, xm9, 1                  ; m6 = rows 0-1 | rows 2-3
+%else
+    movu           xm6, [stkq+offq*2+%6*0]      ; p1: row 0 (8 words)
+    vinserti128     m6, [stkq+offq*2+%6*1], 1   ; row 1 -> high lane
+%endif
+    pcmpeqw         m9, m14, m5                 ; mask: p0 == m14 (0x7fff fill set up by cdef_filter_fn)
+    pcmpeqw        m10, m14, m6                 ; mask: p1 == m14
+    pandn           m9, m5                      ; p0 with 0x7fff fill words zeroed (used for max only)
+    pandn          m10, m6                      ; p1 with 0x7fff fill words zeroed (used for max only)
+    pmaxsw          m7, m9                      ; max after p0
+    pminsw          m8, m5                      ; min after p0
+    pmaxsw          m7, m10                     ; max after p1
+    pminsw          m8, m6                      ; min after p1
+
+    ; accumulate sum[m15] over p0/p1
+    psubw           m5, m4                      ; diff_p0(p0 - px)
+    psubw           m6, m4                      ; diff_p1(p1 - px)
+    pabsw           m9, m5                      ; |diff_p0|
+    pabsw          m10, m6                      ; |diff_p1|
+    psraw          m11, m9,  %2                 ; |diff_p0| >> shift
+    psraw          m12, m10, %2                 ; |diff_p1| >> shift
+    psubw          m11, %3, m11                 ; strength - (|diff_p0| >> shift)
+    psubw          m12, %3, m12                 ; strength - (|diff_p1| >> shift)
+    pmaxsw         m11, m13                     ; max(m13, ...); caller zeroes m13 via pxor
+    pmaxsw         m12, m13
+    pminsw         m11, m9                      ; min(|diff_p0|, ...)
+    pminsw         m12, m10                     ; min(|diff_p1|, ...)
+    psignw         m11, m5                      ; constrain(diff_p0)
+    psignw         m12, m6                      ; constrain(diff_p1)
+    pmullw         m11, %4                      ; constrain(diff_p0) * mul_tap (pri or sec taps)
+    pmullw         m12, %4                      ; constrain(diff_p1) * mul_tap (pri or sec taps)
+    paddw          m15, m11                     ; sum += weighted p0 term
+    paddw          m15, m12                     ; sum += weighted p1 term
+%endmacro
+
+%macro cdef_filter_fn 3 ; w, h, stride
 INIT_YMM avx2
-cglobal cdef_filter_8x8, 4, 9, 16, 26 * 16, dst, stride, left, top, \
-                                            pri, sec, stride3, dst4, edge
-%define px rsp+32+2*32
+cglobal cdef_filter_%1x%2, 4, 9, 16, 2 * 16 + (%2+4)*%3, \
+                           dst, stride, left, top, pri, sec, stride3, dst4, edge
+%define px rsp+2*16+2*%3
     pcmpeqw        m14, m14
     psrlw          m14, 1                   ; 0x7fff
     mov          edged, r8m
 
     ; prepare pixel buffers - body/right
+%if %1 == 4
+    INIT_XMM avx2
+%endif
+%if %2 == 8
     lea          dst4q, [dstq+strideq*4]
+%endif
     lea       stride3q, [strideq*3]
     test         edged, 2                   ; have_right
     jz .no_right
@@ -73,48 +133,70 @@
     pmovzxbw        m2, [dstq+strideq*1]
     pmovzxbw        m3, [dstq+strideq*2]
     pmovzxbw        m4, [dstq+stride3q]
-    movu     [px+0*32], m1
-    movu     [px+1*32], m2
-    movu     [px+2*32], m3
-    movu     [px+3*32], m4
+    mova     [px+0*%3], m1
+    mova     [px+1*%3], m2
+    mova     [px+2*%3], m3
+    mova     [px+3*%3], m4
+%if %2 == 8
     pmovzxbw        m1, [dst4q+strideq*0]
     pmovzxbw        m2, [dst4q+strideq*1]
     pmovzxbw        m3, [dst4q+strideq*2]
     pmovzxbw        m4, [dst4q+stride3q]
-    movu     [px+4*32], m1
-    movu     [px+5*32], m2
-    movu     [px+6*32], m3
-    movu     [px+7*32], m4
+    mova     [px+4*%3], m1
+    mova     [px+5*%3], m2
+    mova     [px+6*%3], m3
+    mova     [px+7*%3], m4
+%endif
     jmp .body_done
 .no_right:
+%if %1 == 4
+    movd           xm1, [dstq+strideq*0]
+    movd           xm2, [dstq+strideq*2]
+    pinsrd         xm1, [dstq+strideq*1], 1
+    pinsrd         xm2, [dstq+stride3q], 1
+    pmovzxbw       xm1, xm1
+    pmovzxbw       xm2, xm2
+    movq     [px+0*%3], xm1
+    movhps   [px+1*%3], xm1
+    movq     [px+2*%3], xm2
+    movhps   [px+3*%3], xm2
+%else
     pmovzxbw       xm1, [dstq+strideq*0]
     pmovzxbw       xm2, [dstq+strideq*1]
     pmovzxbw       xm3, [dstq+strideq*2]
     pmovzxbw       xm4, [dstq+stride3q]
-    movu     [px+0*32], xm1
-    movu     [px+1*32], xm2
-    movu     [px+2*32], xm3
-    movu     [px+3*32], xm4
-    movd  [px+0*32+16], xm14
-    movd  [px+1*32+16], xm14
-    movd  [px+2*32+16], xm14
-    movd  [px+3*32+16], xm14
+    mova     [px+0*%3], xm1
+    mova     [px+1*%3], xm2
+    mova     [px+2*%3], xm3
+    mova     [px+3*%3], xm4
+%endif
+    movd [px+0*%3+%1*2], xm14
+    movd [px+1*%3+%1*2], xm14
+    movd [px+2*%3+%1*2], xm14
+    movd [px+3*%3+%1*2], xm14
+%if %2 == 8
+    ; FIXME w == 4 (note: the next four stores repeat the row 0-3 fills just above)
+    movd [px+0*%3+%1*2], xm14
+    movd [px+1*%3+%1*2], xm14
+    movd [px+2*%3+%1*2], xm14
+    movd [px+3*%3+%1*2], xm14
     pmovzxbw       xm1, [dst4q+strideq*0]
     pmovzxbw       xm2, [dst4q+strideq*1]
     pmovzxbw       xm3, [dst4q+strideq*2]
     pmovzxbw       xm4, [dst4q+stride3q]
-    movu     [px+4*32], xm1
-    movu     [px+5*32], xm2
-    movu     [px+6*32], xm3
-    movu     [px+7*32], xm4
-    movd  [px+4*32+16], xm14
-    movd  [px+5*32+16], xm14
-    movd  [px+6*32+16], xm14
-    movd  [px+7*32+16], xm14
+    mova     [px+4*%3], xm1
+    mova     [px+5*%3], xm2
+    mova     [px+6*%3], xm3
+    mova     [px+7*%3], xm4
+    movd [px+4*%3+%1*2], xm14
+    movd [px+5*%3+%1*2], xm14
+    movd [px+6*%3+%1*2], xm14
+    movd [px+7*%3+%1*2], xm14
+%endif
 .body_done:
 
     ; top
-    DEFINE_ARGS dst, stride, left, top2, pri, sec, top1, dummy, edge
+    DEFINE_ARGS dst, stride, left, top2, pri, sec, stride3, top1, edge
     test         edged, 4                    ; have_top
     jz .no_top
     mov          top1q, [top2q+0*gprsize]
@@ -123,18 +205,18 @@
     jz .top_no_left
     test         edged, 2                    ; have_right
     jz .top_no_right
-    pmovzxbw        m1, [top1q-4]
-    pmovzxbw        m2, [top2q-4]
-    movu   [px-2*32-8], m1
-    movu   [px-1*32-8], m2
+    pmovzxbw        m1, [top1q-(%1/2)]
+    pmovzxbw        m2, [top2q-(%1/2)]
+    movu  [px-2*%3-%1], m1
+    movu  [px-1*%3-%1], m2
     jmp .top_done
 .top_no_right:
-    pmovzxbw        m1, [top1q-8]
-    pmovzxbw        m2, [top2q-8]
-    movu  [px-2*32-16], m1
-    movu  [px-1*32-16], m2
-    movd  [px-2*32+16], xm14
-    movd  [px-1*32+16], xm14
+    pmovzxbw        m1, [top1q-%1]
+    pmovzxbw        m2, [top2q-%1]
+    movu [px-2*%3-%1*2], m1
+    movu [px-1*%3-%1*2], m2
+    movd [px-2*%3+%1*2], xm14
+    movd [px-1*%3+%1*2], xm14
     jmp .top_done
 .top_no_left:
     test         edged, 2                   ; have_right
@@ -141,24 +223,32 @@
     jz .top_no_left_right
     pmovzxbw        m1, [top1q]
     pmovzxbw        m2, [top2q]
-    movu   [px-2*32+0], m1
-    movu   [px-1*32+0], m2
-    movd   [px-2*32-4], xm14
-    movd   [px-1*32-4], xm14
+    mova   [px-2*%3+0], m1
+    mova   [px-1*%3+0], m2
+    movd   [px-2*%3-4], xm14
+    movd   [px-1*%3-4], xm14
     jmp .top_done
 .top_no_left_right:
+%if %1 == 4
+    movd           xm1, [top1q]
+    pinsrd         xm1, [top2q], 1
+    pmovzxbw       xm1, xm1
+    movq   [px-2*%3+0], xm1
+    movhps [px-1*%3+0], xm1
+%else
     pmovzxbw       xm1, [top1q]
     pmovzxbw       xm2, [top2q]
-    movu   [px-2*32+0], xm1
-    movu   [px-1*32+0], xm2
-    movd   [px-2*32-4], xm14
-    movd   [px-1*32-4], xm14
-    movd  [px-2*32+16], xm14
-    movd  [px-1*32+16], xm14
+    mova   [px-2*%3+0], xm1
+    mova   [px-1*%3+0], xm2
+%endif
+    movd   [px-2*%3-4], xm14
+    movd   [px-1*%3-4], xm14
+    movd [px-2*%3+%1*2], xm14
+    movd [px-1*%3+%1*2], xm14
     jmp .top_done
 .no_top:
-    movu   [px-2*32-8], m14
-    movu   [px-1*32-8], m14
+    movu   [px-2*%3-%1], m14
+    movu   [px-1*%3-%1], m14
 .top_done:
 
     ; left
@@ -165,49 +255,57 @@
     test         edged, 1                   ; have_left
     jz .no_left
     pmovzxbw       xm1, [leftq+ 0]
+%if %2 == 8
     pmovzxbw       xm2, [leftq+ 8]
+%endif
     movd   [px+0*32-4], xm1
     pextrd [px+1*32-4], xm1, 1
     pextrd [px+2*32-4], xm1, 2
     pextrd [px+3*32-4], xm1, 3
+%if %2 == 8
     movd   [px+4*32-4], xm2
     pextrd [px+5*32-4], xm2, 1
     pextrd [px+6*32-4], xm2, 2
     pextrd [px+7*32-4], xm2, 3
+%endif
     jmp .left_done
 .no_left:
-    movd   [px+0*32-4], xm14
-    movd   [px+1*32-4], xm14
-    movd   [px+2*32-4], xm14
-    movd   [px+3*32-4], xm14
-    movd   [px+4*32-4], xm14
-    movd   [px+5*32-4], xm14
-    movd   [px+6*32-4], xm14
-    movd   [px+7*32-4], xm14
+    movd   [px+0*%3-4], xm14
+    movd   [px+1*%3-4], xm14
+    movd   [px+2*%3-4], xm14
+    movd   [px+3*%3-4], xm14
+%if %2 == 8
+    movd   [px+4*%3-4], xm14
+    movd   [px+5*%3-4], xm14
+    movd   [px+6*%3-4], xm14
+    movd   [px+7*%3-4], xm14
+%endif
 .left_done:
 
     ; bottom
-    DEFINE_ARGS dst, stride, dst8, dummy1, pri, sec, dummy2, dummy3, edge
+    DEFINE_ARGS dst, stride, dst8, dummy1, pri, sec, stride3, dummy3, edge
     test         edged, 8                   ; have_bottom
     jz .no_bottom
-    lea          dst8q, [dstq+8*strideq]
+    lea          dst8q, [dstq+%2*strideq]
     test         edged, 1                   ; have_left
     jz .bottom_no_left
     test         edged, 2                   ; have_right
     jz .bottom_no_right
-    pmovzxbw        m1, [dst8q-4]
-    pmovzxbw        m2, [dst8q+strideq-4]
-    movu   [px+8*32-8], m1
-    movu   [px+9*32-8], m2
+    pmovzxbw        m1, [dst8q-(%1/2)]
+    pmovzxbw        m2, [dst8q+strideq-(%1/2)]
+    movu   [px+(%2+0)*%3-%1], m1
+    movu   [px+(%2+1)*%3-%1], m2
     jmp .bottom_done
 .bottom_no_right:
-    pmovzxbw        m1, [dst8q-8]
-    pmovzxbw        m2, [dst8q+strideq-8]
-    movu  [px+8*32-16], m1
-    movu  [px+9*32-16], m2
-    movd  [px+7*32+16], xm14                ; overwritten by previous movu
-    movd  [px+8*32+16], xm14
-    movd  [px+9*32+16], xm14
+    pmovzxbw        m1, [dst8q-%1]
+    pmovzxbw        m2, [dst8q+strideq-%1]
+    movu  [px+(%2+0)*%3-%1*2], m1
+    movu  [px+(%2+1)*%3-%1*2], m2
+%if %1 == 8
+    movd  [px+(%2-1)*%3+%1*2], xm14                ; overwritten by previous movu
+%endif
+    movd  [px+(%2+0)*%3+%1*2], xm14
+    movd  [px+(%2+1)*%3+%1*2], xm14
     jmp .bottom_done
 .bottom_no_left:
     test          edged, 2                  ; have_right
@@ -214,28 +312,37 @@
     jz .bottom_no_left_right
     pmovzxbw        m1, [dst8q]
     pmovzxbw        m2, [dst8q+strideq]
-    movu   [px+8*32+0], m1
-    movu   [px+9*32+0], m2
-    movd   [px+8*32-4], xm14
-    movd   [px+9*32-4], xm14
+    mova   [px+(%2+0)*%3+0], m1
+    mova   [px+(%2+1)*%3+0], m2
+    movd   [px+(%2+0)*%3-4], xm14
+    movd   [px+(%2+1)*%3-4], xm14
     jmp .bottom_done
 .bottom_no_left_right:
+%if %1 == 4
+    movd           xm1, [dst8q]
+    pinsrd         xm1, [dst8q+strideq], 1
+    pmovzxbw       xm1, xm1
+    movq   [px+(%2+0)*%3+0], xm1
+    movhps [px+(%2+1)*%3+0], xm1
+%else
     pmovzxbw       xm1, [dst8q]
     pmovzxbw       xm2, [dst8q+strideq]
-    movu   [px+8*32+0], xm1
-    movu   [px+9*32+0], xm2
-    movd   [px+8*32-4], xm14
-    movd   [px+9*32-4], xm14
-    movd  [px+8*32+16], xm14
-    movd  [px+9*32+16], xm14
+    mova   [px+(%2+0)*%3+0], xm1
+    mova   [px+(%2+1)*%3+0], xm2
+%endif
+    movd   [px+(%2+0)*%3-4], xm14
+    movd   [px+(%2+1)*%3-4], xm14
+    movd  [px+(%2+0)*%3+%1*2], xm14
+    movd  [px+(%2+1)*%3+%1*2], xm14
     jmp .bottom_done
 .no_bottom:
-    movu   [px+8*32-8], m14
-    movu   [px+9*32-8], m14
+    movu   [px+(%2+0)*%3-%1], m14
+    movu   [px+(%2+1)*%3-%1], m14
 .bottom_done:
 
     ; actual filter
-    DEFINE_ARGS dst, stride, pridmp, damping, pri, sec, secdmp
+    INIT_YMM avx2
+    DEFINE_ARGS dst, stride, pridmp, damping, pri, sec, stride3, secdmp
 %undef edged
     movifnidn     prid, prim
     movifnidn     secd, secm
@@ -258,7 +365,7 @@
     mov        [rsp+8], secdmpq                 ; sec_shift
 
     ; pri/sec_taps[k] [4 total]
-    DEFINE_ARGS dst, stride, tap, dummy, pri, sec
+    DEFINE_ARGS dst, stride, tap, dummy, pri, sec, stride3
     movd           xm0, prid
     movd           xm1, secd
     vpbroadcastw    m0, xm0                     ; pri_strength
@@ -270,17 +377,31 @@
     lea           secq, [tapq+secq*4+8]         ; sec_taps
 
     ; off1/2/3[k] [6 total] from [tapq+16+(dir+0/2/6)*2+k]
-    DEFINE_ARGS dst, stride, tap, dir, pri, sec
+    DEFINE_ARGS dst, stride, tap, dir, pri, sec, stride3
     mov           dird, r6m
     lea           tapq, [tapq+dirq*2+16]
-    DEFINE_ARGS dst, stride, dir, h, pri, sec, stk, off, k
-    mov             hd, 4
+%if %1*%2*2/mmsize > 1
+    DEFINE_ARGS dst, stride, dir, stk, pri, sec, h, off, k
+    mov             hd, %1*%2*2/mmsize
+%else
+    DEFINE_ARGS dst, stride, dir, stk, pri, sec, stride3, off, k
+%endif
     lea           stkq, [px]
     pxor           m13, m13
+%if %1*%2*2/mmsize > 1
 .v_loop:
+%endif
     mov             kd, 1
-    mova           xm4, [stkq+32*0]             ; px
-    vinserti128     m4, [stkq+32*1], 1
+%if %1 == 4
+    movq           xm4, [stkq+%3*0]
+    movhps         xm4, [stkq+%3*1]
+    movq           xm5, [stkq+%3*2]
+    movhps         xm5, [stkq+%3*3]
+    vinserti128     m4, xm5, 1
+%else
+    mova           xm4, [stkq+%3*0]             ; px
+    vinserti128     m4, [stkq+%3*1], 1
+%endif
     pxor           m15, m15                     ; sum
     mova            m7, m4                      ; max
     mova            m8, m4                      ; min
@@ -288,48 +409,10 @@
     vpbroadcastw    m2, [priq+kq*2]             ; pri_taps
     vpbroadcastw    m3, [secq+kq*2]             ; sec_taps
 
-%macro ACCUMULATE_TAP 4 ; tap_offset, shift, strength, mul_tap
-    ; load p0/p1
-    movsx         offq, byte [dirq+kq+%1]       ; off1
-    movu           xm5, [stkq+offq*2+32*0]      ; p0
-    vinserti128     m5, [stkq+offq*2+32*1], 1
-    neg           offq                          ; -off1
-    movu           xm6, [stkq+offq*2+32*0]      ; p1
-    vinserti128     m6, [stkq+offq*2+32*1], 1
-    pcmpeqw         m9, m14, m5
-    pcmpeqw        m10, m14, m6
-    pandn           m9, m5
-    pandn          m10, m6
-    pmaxsw          m7, m9                      ; max after p0
-    pminsw          m8, m5                      ; min after p0
-    pmaxsw          m7, m10                     ; max after p1
-    pminsw          m8, m6                      ; min after p1
+    ACCUMULATE_TAP 0*2, [rsp+0], m0, m2, %1, %3
+    ACCUMULATE_TAP 2*2, [rsp+8], m1, m3, %1, %3
+    ACCUMULATE_TAP 6*2, [rsp+8], m1, m3, %1, %3
 
-    ; accumulate sum[m15] over p0/p1
-    psubw           m5, m4                      ; diff_p0(p0 - px)
-    psubw           m6, m4                      ; diff_p1(p1 - px)
-    pabsw           m9, m5
-    pabsw          m10, m6
-    psraw          m11, m9,  %2
-    psraw          m12, m10, %2
-    psubw          m11, %3, m11
-    psubw          m12, %3, m12
-    pmaxsw         m11, m13
-    pmaxsw         m12, m13
-    pminsw         m11, m9
-    pminsw         m12, m10
-    psignw         m11, m5                      ; constrain(diff_p0)
-    psignw         m12, m6                      ; constrain(diff_p1)
-    pmullw         m11, %4                      ; constrain(diff_p0) * pri_taps
-    pmullw         m12, %4                      ; constrain(diff_p1) * pri_taps
-    paddw          m15, m11
-    paddw          m15, m12
-%endmacro
-
-    ACCUMULATE_TAP 0*2, [rsp+0], m0, m2
-    ACCUMULATE_TAP 2*2, [rsp+8], m1, m3
-    ACCUMULATE_TAP 6*2, [rsp+8], m1, m3
-
     dec             kq
     jge .k_loop
 
@@ -342,14 +425,28 @@
     pmaxsw          m4, m8
     packuswb        m4, m4
     vextracti128   xm5, m4, 1
+%if %1 == 4
+    movd [dstq+strideq*0], xm4
+    pextrd [dstq+strideq*1], xm4, 1
+    movd [dstq+strideq*2], xm5
+    pextrd [dstq+stride3q], xm5, 1
+%else
     movq [dstq+strideq*0], xm4
     movq [dstq+strideq*1], xm5
+%endif
+
+%if %1*%2*2/mmsize > 1
     lea           dstq, [dstq+strideq*2]
-    add           stkq, 32*2
+    add           stkq, %3*2
     dec             hd
     jg .v_loop
+%endif
 
     RET
+%endmacro
+
+cdef_filter_fn 8, 8, 32                         ; w=8, h=8, stride=32 -> cdef_filter_8x8
+cdef_filter_fn 4, 4, 32                         ; w=4, h=4, stride=32 -> cdef_filter_4x4
 
 INIT_YMM avx2
 cglobal cdef_dir, 3, 4, 15, src, stride, var, stride3
--- a/src/x86/cdef_init_tmpl.c
+++ b/src/x86/cdef_init_tmpl.c
@@ -29,6 +29,8 @@
 #include "src/cdef.h"
 
 decl_cdef_fn(dav1d_cdef_filter_8x8_avx2);
+decl_cdef_fn(dav1d_cdef_filter_4x4_avx2);
+
 decl_cdef_dir_fn(dav1d_cdef_dir_avx2);
 
 void bitfn(dav1d_cdef_dsp_init_x86)(Dav1dCdefDSPContext *const c) {
@@ -39,5 +41,6 @@
 #if BITDEPTH == 8 && ARCH_X86_64
     c->dir = dav1d_cdef_dir_avx2;
     c->fb[0] = dav1d_cdef_filter_8x8_avx2;
+    c->fb[2] = dav1d_cdef_filter_4x4_avx2;
 #endif
 }