shithub: dav1d

Download patch

ref: 70fb01d8c76578fad89b606edf34446025b79dad
parent: e2aa2d1446d7c8af0d6d00b3e8d91563aeac36bd
author: Ronald S. Bultje <rsbultje@gmail.com>
date: Mon Dec 3 09:54:56 EST 2018

Make per-width versions of cfl_ac

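Also use aligned reads and writes in sub_loop, and integrate sum_loop into
the per-width main loops, which now accumulate the sum while storing the ac
values. The checkasm test now covers every w_pad/h_pad combination instead
of a single random one.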

before:
cfl_ac_420_w4_8bpc_c: 367.4
cfl_ac_420_w4_8bpc_avx2: 72.8
cfl_ac_420_w8_8bpc_c: 621.6
cfl_ac_420_w8_8bpc_avx2: 85.1
cfl_ac_420_w16_8bpc_c: 983.4
cfl_ac_420_w16_8bpc_avx2: 141.0

after:
cfl_ac_420_w4_8bpc_c: 376.2
cfl_ac_420_w4_8bpc_avx2: 28.5
cfl_ac_420_w8_8bpc_c: 607.2
cfl_ac_420_w8_8bpc_avx2: 29.9
cfl_ac_420_w16_8bpc_c: 962.1
cfl_ac_420_w16_8bpc_avx2: 48.8

--- a/src/x86/ipred.asm
+++ b/src/x86/ipred.asm
@@ -67,6 +67,18 @@
 ipred_h_shuf: db  7,  7,  7,  7,  3,  3,  3,  3,  5,  5,  5,  5,  1,  1,  1,  1
               db  6,  6,  6,  6,  2,  2,  2,  2,  4,  4,  4,  4,  0,  0,  0,  0
 
+cfl_ac_w16_pad_shuffle: ; w=16, w_pad=1; the 32-byte mask for w_pad=n starts at +32*(n-1)
+                        db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 ; lane 0: identity. NOTE(review): 32-byte masks are loaded with mova, so this label must be 32-byte aligned - confirm
+                        ; w=8, w_pad=1 (in-lane words 4-7 repeat word 3); also lane 1 of the w=16, w_pad=1 mask
+cfl_ac_w8_pad1_shuffle: db 0, 1, 2, 3, 4, 5
+                        times 5 db 6, 7
+                        ; w=16, w_pad=2: lane 0 identity, lane 1 repeats in-lane word 7
+                        db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+                        times 8 db 14, 15
+                        ; w=16, w_pad=3: lane 0 keeps words 0-3 then repeats word 3, lane 1 repeats in-lane word 3
+                        db 0, 1, 2, 3, 4, 5
+                        times 13 db 6, 7
+
 pb_1:   times 4 db 1
 pb_2:   times 4 db 2
 pb_128: times 4 db 128
@@ -102,6 +114,7 @@
 JMP_TABLE ipred_cfl,      avx2, h4, h8, h16, h32, w4, w8, w16, w32, \
                                 s4-8*4, s8-8*4, s16-8*4, s32-8*4
 JMP_TABLE ipred_cfl_left, avx2, h4, h8, h16, h32
+JMP_TABLE ipred_cfl_ac_420, avx2, w16_pad1, w16_pad2, w16_pad3 ; indexed by w_pad (1-3) in ipred_cfl_ac_420 .w16_wpad
 JMP_TABLE pal_pred,       avx2, w4, w8, w16, w32, w64
 
 cextern filter_intra_taps
@@ -1784,99 +1797,185 @@
     movifnidn           acq, acmp
     jmp                  wq
 
-cglobal ipred_cfl_ac_420, 6, 10, 5, ac, y, stride, wpad, hpad, w, h
-    shl               wpadd, 2
+cglobal ipred_cfl_ac_420, 4, 9, 5, ac, y, stride, wpad, hpad, w, h, sz, ac_bak ; vector regs used: m0-m4
+    movifnidn         hpadd, hpadm
+    movifnidn            wd, wm
+    mov                  hd, hm
+    mov                 szd, wd  ; sz = w * h = number of ac entries
+    mov             ac_bakq, acq  ; start of ac, kept for the final subtraction pass
+    imul                szd, hd
     shl               hpadd, 2
-    mov                 r9d, hm
-    mov                 r6d, wd
-    movsxd               wq, wd
-    add                  yq, strideq
-    mov                  r7, acq
-    sub                 r6d, wpadd
-    sub                 r9d, hpadd
-    mov                 r8d, r9d
-    vpbroadcastd        xm2, [pb_2]
-.dec_rows:
-    mov                  r3, yq
-    xor                  r4, r4
-    sub                  r3, strideq
-.dec_cols:
-    movq                xm0, [r3+r4*2]
-    movq                xm1, [yq+r4*2]
+    sub                  hd, hpadd  ; h = ac rows computed from luma; the last hpad*4 rows are padding
+    vpbroadcastd         m2, [pb_2]  ; pmaddubsw weights: word = 2 * (sum of a horizontal luma pair)
+    pxor                 m4, m4  ; m4 = per-word running sum of every stored ac value
+    cmp                  wd, 8
+    jg .w16
+    je .w8
+    ; fall-through: w < 8 (w == 4)
+
+    DEFINE_ARGS ac, y, stride, wpad, hpad, stride3, h, sz, ac_bak ; w's register is reused as stride3
+.w4:
+    lea            stride3q, [strideq*3]
+.w4_loop:  ; 2 ac rows (4 luma rows) per iteration
+    movq                xm0, [yq]  ; luma row 0
+    movq                xm1, [yq+strideq]  ; luma row 1
+    movhps              xm0, [yq+strideq*2]  ; luma row 2 -> high qword
+    movhps              xm1, [yq+stride3q]  ; luma row 3 -> high qword
     pmaddubsw           xm0, xm2
     pmaddubsw           xm1, xm2
     paddw               xm0, xm1
-    movq          [r7+r4*2], xm0
-    add                  r4, 4
-    cmp                 r6d, r4d
-    jg .dec_cols
-    lea                  r7, [r7+wq*2]
+    mova              [acq], xm0  ; store two 4-wide ac rows
+    paddw               xm4, xm0  ; accumulate for the average
+    lea                  yq, [yq+strideq*4]
+    add                 acq, 16
+    sub                  hd, 2
+    jg .w4_loop
+    test              hpadd, hpadd
+    jz .calc_avg
+    vpermq               m0, m0, q1111  ; copy the last ac row (qword 1) into all 4 qwords
+.w4_hpad_loop:  ; 4 padded rows per 32-byte store
+    mova              [acq], m0
+    paddw                m4, m0
+    add                 acq, 32
+    sub               hpadd, 4  ; hpad was scaled by 4 above, so it counts rows
+    jg .w4_hpad_loop
+    jmp .calc_avg
+
+.w8:
+    lea            stride3q, [strideq*3]
+    test              wpadd, wpadd
+    jnz .w8_wpad
+.w8_loop:  ; 2 ac rows per iteration, one per 128-bit lane
+    mova                xm0, [yq]  ; mova: luma rows must be 16-byte aligned
+    mova                xm1, [yq+strideq]
+    vinserti128          m0, [yq+strideq*2], 1
+    vinserti128          m1, [yq+stride3q], 1
+    pmaddubsw            m0, m2
+    pmaddubsw            m1, m2
+    paddw                m0, m1
+    mova              [acq], m0
+    paddw                m4, m0
+    lea                  yq, [yq+strideq*4]
+    add                 acq, 32
+    sub                  hd, 2
+    jg .w8_loop
+    test              hpadd, hpadd
+    jz .calc_avg
+    jmp .w8_hpad
+.w8_wpad:  ; any nonzero w_pad is handled as 1 (the only nonzero value checkasm uses for w == 8)
+    vbroadcasti128       m3, [cfl_ac_w8_pad1_shuffle]  ; in-lane words 4-7 := word 3
+.w8_wpad_loop:
+    movq                xm0, [yq]  ; 8 luma pixels -> 4 valid ac values
+    movq                xm1, [yq+strideq]
+    vinserti128          m0, [yq+strideq*2], 1  ; reads 16 bytes; the unused half is overwritten by the shuffle
+    vinserti128          m1, [yq+stride3q], 1
+    pmaddubsw            m0, m2
+    pmaddubsw            m1, m2
+    paddw                m0, m1
+    pshufb               m0, m3  ; replicate the last valid column
+    mova              [acq], m0
+    paddw                m4, m0
+    lea                  yq, [yq+strideq*4]
+    add                 acq, 32
+    sub                  hd, 2
+    jg .w8_wpad_loop
+    test              hpadd, hpadd
+    jz .calc_avg
+.w8_hpad:
+    vpermq               m0, m0, q3232  ; copy the last ac row (lane 1) into both lanes
+.w8_hpad_loop:
+    mova              [acq], m0
+    paddw                m4, m0
+    add                 acq, 32
+    sub               hpadd, 2  ; 2 padded rows per store
+    jg .w8_hpad_loop
+    jmp .calc_avg
+
+.w16:
+    test              wpadd, wpadd
+    jnz .w16_wpad
+.w16_loop:  ; 1 ac row (2 luma rows) per iteration
+    mova                 m0, [yq]  ; mova: luma rows must be 32-byte aligned
+    mova                 m1, [yq+strideq]
+    pmaddubsw            m0, m2
+    pmaddubsw            m1, m2
+    paddw                m0, m1
+    mova              [acq], m0
+    paddw                m4, m0
     lea                  yq, [yq+strideq*2]
-    dec                 r8d
-    jg .dec_rows
-    cmp                 r6d, wd
-    je .wpad_end
-    mov                  r7, acq
-    lea                  r1, [r6q+r6q]
-.wpad_rows:
-    vpbroadcastw        xm0, [r7+r1-2]
-    mov                 r2q, r6q
-.wpad_cols:
-    movq         [r7+r2q*2], xm0
-    add                 r2q, 4
-    cmp                  wd, r2d
-    jg .wpad_cols
-    lea                  r7, [r7+wq*2]
-    dec                 r9d
-    jg .wpad_rows
-.wpad_end:
-    bsf                 r3d, hm
-    shlx                r6d, wd, r3d
-    neg                  wd
-    bsf                 r3d, r6d
-    movsxd               wq, wd
-    add                  wq, wq
-    movsxd              r2q, r6d
-    lea                 r2q, [acq+r2q*2]
-.hpad_loop:
-    cmp                 r2q, r7
-    jbe .hpad_end
-    mov                  r1, [r7+wq]
-    add                  r7, 8
-    mov              [r7-8], r1
-    jmp .hpad_loop
-.hpad_end:
-    mov                  r1, acq
-    pxor                 m1, m1
-    vpbroadcastd         m3, [pw_1]
-.sum_loop:
-    movdqu               m0, [r1]
-    add                  r1, 32
-    cmp                 r2q, r1
-    pmaddwd              m0, m3
-    paddd                m1, m0
-    ja .sum_loop
-    vextracti128        xm0, m1, 1
-    sar                 r6d, 1
-    movd                xm4, r6d
-    mov                 r6d, r3d
+    add                 acq, 32
+    dec                  hd
+    jg .w16_loop
+    test              hpadd, hpadd
+    jz .calc_avg
+    jmp .w16_hpad_loop
+.w16_wpad:
+    DEFINE_ARGS ac, y, stride, wpad, hpad, iptr, h, sz, ac_bak ; w's register is reused as a table/jump pointer
+    lea               iptrq, [ipred_cfl_ac_420_avx2_table]
+    shl               wpadd, 2  ; w_pad*4 = byte index of its jump table entry
+    mova                 m3, [iptrq+cfl_ac_w16_pad_shuffle- \
+                              ipred_cfl_ac_420_avx2_table+wpadq*8-32] ; mask at cfl_ac_w16_pad_shuffle + 32*(w_pad-1)
+    movsxd            wpadq, [iptrq+wpadq+4]  ; jump table offset of .w16_pad<w_pad>
+    add               iptrq, wpadq  ; iptr = .w16_pad<w_pad>, also used as the loop target below
+    jmp iptrq
+.w16_pad3:  ; w_pad == 3: 8 luma bytes -> 4 valid ac values, qword-broadcast
+    vpbroadcastq         m0, [yq]
+    vpbroadcastq         m1, [yq+strideq]
+    jmp .w16_wpad_end
+.w16_pad2:  ; w_pad == 2: 16 luma bytes -> 8 valid ac values, broadcast to both lanes
+    vbroadcasti128       m0, [yq]
+    vbroadcasti128       m1, [yq+strideq]
+    jmp .w16_wpad_end
+.w16_pad1:  ; w_pad == 1: full 32-byte rows, 12 of the 16 ac values are valid
+    mova                 m0, [yq]
+    mova                 m1, [yq+strideq]
+    ; fall-through into the shared tail
+.w16_wpad_end:  ; shared tail: combine rows, replicate the last valid column, store
+    pmaddubsw            m0, m2
+    pmaddubsw            m1, m2
+    paddw                m0, m1
+    pshufb               m0, m3
+    mova              [acq], m0
+    paddw                m4, m0
+    lea                  yq, [yq+strideq*2]
+    add                 acq, 32
+    dec                  hd
+    jz .w16_wpad_done
+    jmp iptrq  ; loop back into the per-w_pad load code
+.w16_wpad_done:
+    test              hpadd, hpadd
+    jz .calc_avg
+.w16_hpad_loop:  ; repeat the last ac row for each of the hpad*4 padded rows
+    mova              [acq], m0
+    paddw                m4, m0
+    add                 acq, 32
+    dec               hpadd  ; one padded row per store
+    jg .w16_hpad_loop
+    ; fall-through into .calc_avg
+
+.calc_avg:  ; avg = (sum + sz/2) >> log2(sz), then subtract avg from every ac entry
+    vpbroadcastd         m2, [pw_1]
+    pmaddwd              m0, m4, m2  ; pairwise word sums -> 8 dwords; signed, so each m4 word must stay < 32768
+    vextracti128        xm1, m0, 1
+    tzcnt               r1d, szd  ; log2(sz), valid only if sz is a power of two; r1 (y) is free now
     paddd               xm0, xm1
+    movd                xm2, r1d  ; shift count
+    movd                xm3, szd
     punpckhqdq          xm1, xm0, xm0
-    paddd               xm1, xm0
-    vbroadcastss        xm0, xm4
-    psrlq               xm2, xm1, 32
-    movq                xm4, r6q
-    paddd               xm0, xm2
     paddd               xm0, xm1
-    psrld               xm0, xm4
+    psrad               xm3, 1  ; rounding bias sz / 2
+    psrlq               xm1, xm0, 32  ; move dword 1 (remaining partial sum) down to dword 0
+    paddd               xm0, xm3
+    paddd               xm0, xm1  ; dword 0 = total sum + sz / 2
+    psrad               xm0, xm2  ; dword 0 = rounded average
     vpbroadcastw         m0, xm0
 .sub_loop:
-    movdqu               m1, [acq]
-    add                 acq, 32
+    mova                 m1, [ac_bakq]  ; 16 ac values per iteration; ac must be 32-byte aligned
     psubw                m1, m0
-    movdqu         [acq-32], m1
-    cmp                 r2q, acq
-    ja .sub_loop
+    mova          [ac_bakq], m1
+    add             ac_bakq, 32
+    sub                 szd, 16  ; assumes sz is a multiple of 16
+    jg .sub_loop
     RET
 
 cglobal pal_pred, 4, 6, 5, dst, stride, pal, idx, w, h
--- a/tests/checkasm/ipred.c
+++ b/tests/checkasm/ipred.c
@@ -120,17 +120,18 @@
             {
                 for (int h = imax(w / 4, 4); h <= imin(w * 4, (32 >> ss_ver)); h <<= 1) {
                     const ptrdiff_t stride = 32 * sizeof(pixel);
-                    const int w_pad = rand() & ((w >> 2) - 1);
-                    const int h_pad = rand() & ((h >> 2) - 1);
+                    for (int w_pad = (w >> 2) - 1; w_pad >= 0; w_pad--) { // test every w_pad in [0, (w >> 2) - 1] ...
+                        for (int h_pad = (h >> 2) - 1; h_pad >= 0; h_pad--) { // ... with every h_pad in [0, (h >> 2) - 1]
+                            for (int y = 0; y < (h << ss_ver); y++) // fresh random luma for each combination
+                                for (int x = 0; x < (w << ss_hor); x++)
+                                    luma[y * 32 + x] = rand() & ((1 << BITDEPTH) - 1);
 
-                    for (int y = 0; y < (h << ss_ver); y++)
-                        for (int x = 0; x < (w << ss_hor); x++)
-                            luma[y * 32 + x] = rand() & ((1 << BITDEPTH) - 1);
-
-                    call_ref(c_dst, luma, stride, w_pad, h_pad, w, h);
-                    call_new(a_dst, luma, stride, w_pad, h_pad, w, h);
-                    if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
-                        fail();
+                            call_ref(c_dst, luma, stride, w_pad, h_pad, w, h);
+                            call_new(a_dst, luma, stride, w_pad, h_pad, w, h);
+                            if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst))) // all w * h entries, padded ones included
+                                fail();
+                        }
+                    }
 
                     bench_new(a_dst, luma, stride, 0, 0, w, h);
                 }