shithub: dav1d

ref: 8ab69afb72053a91ccd57f4e5bec97e886fe6328
parent: ea9fc9d921ca9956838122ddf2457c651f926dc3
author: Martin Storsjö <martin@martin.st>
date: Thu Sep 19 06:34:14 EDT 2019

arm64: ipred: NEON implementation of paeth prediction

Relative speedups over the C code:
                            Cortex A53    A72    A73
intra_pred_paeth_w4_8bpc_neon:    8.36   6.55   7.27
intra_pred_paeth_w8_8bpc_neon:   15.24  11.36  11.34
intra_pred_paeth_w16_8bpc_neon:  16.63  13.20  14.17
intra_pred_paeth_w32_8bpc_neon:  10.83   9.21   9.87
intra_pred_paeth_w64_8bpc_neon:   8.37   7.07   7.45
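
For reference, the Paeth predictor picks, for each output pixel, whichever
of the left, top and topleft neighbours is closest to left + top - topleft.
A minimal scalar sketch of the 8 bpc case (paeth_px is an illustrative
name, not dav1d's):

    #include <stdlib.h> /* abs */

    static int paeth_px(const int left, const int top, const int topleft) {
        const int base   = left + top - topleft;
        const int ldiff  = abs(left    - base);
        const int tdiff  = abs(top     - base);
        const int tldiff = abs(topleft - base);
        /* Ties resolve in the order left, top, topleft. */
        if (ldiff <= tdiff && ldiff <= tldiff)
            return left;
        return tdiff <= tldiff ? top : topleft;
    }

In the assembly below, base is computed 16 lanes at a time with
usubl/uaddw and saturated to the pixel range with sqxtun; this is safe
because an out-of-range base lies beyond all three neighbours by the same
excess, so clamping shrinks all three absolute differences equally and
leaves the comparisons unchanged. The nested ternary becomes branchless
selects: cmhs builds the "tldiff >= tdiff" and "min(tdiff, tldiff) >=
ldiff" masks, bsl picks top vs. topleft, and bit then overrides the
result with left where it wins.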

--- a/src/arm/64/ipred.S
+++ b/src/arm/64/ipred.S
@@ -690,3 +690,180 @@
         .hword L(ipred_dc_tbl) - L(ipred_dc_w8)
         .hword L(ipred_dc_tbl) - L(ipred_dc_w4)
 endfunc
+
+// void ipred_paeth_neon(pixel *dst, const ptrdiff_t stride,
+//                       const pixel *const topleft,
+//                       const int width, const int height, const int a,
+//                       const int max_width, const int max_height);
+function ipred_paeth_neon, export=1
+        clz             w9,  w3                   // clz(width): 25 (w=64) .. 29 (w=4)
+        adr             x5,  L(ipred_paeth_tbl)
+        sub             w9,  w9,  #25             // table index: 0 (w=64) .. 4 (w=4)
+        ldrh            w9,  [x5, w9, uxtw #1]
+        ld1r            {v4.16b},  [x2]           // splat the topleft pixel
+        add             x8,  x2,  #1              // x8 points at the top edge
+        sub             x2,  x2,  #4              // x2 points at the left edge, in groups of 4
+        sub             x5,  x5,  w9, uxtw        // branch target for this width
+        mov             x7,  #-4                  // step to the next 4 left-edge pixels
+        add             x6,  x0,  x1              // x6 = second output row
+        lsl             x1,  x1,  #1              // 2*stride; x0/x6 each advance 2 rows
+        br              x5
+40:
+        ld1r            {v5.4s},  [x8]
+        usubl           v6.8h,   v5.8b,   v4.8b   // top - topleft
+4:
+        ld4r            {v0.8b, v1.8b, v2.8b, v3.8b},  [x2], x7  // left edge for rows 3..0 in v0..v3
+        zip1            v0.2s,   v0.2s,   v1.2s
+        zip1            v2.2s,   v2.2s,   v3.2s
+        uaddw           v16.8h,  v6.8h,   v0.8b
+        uaddw           v17.8h,  v6.8h,   v2.8b
+        sqxtun          v16.8b,  v16.8h           // base
+        sqxtun2         v16.16b, v17.8h
+        zip1            v0.2d,   v0.2d,   v2.2d
+        uabd            v20.16b, v5.16b,  v16.16b // tdiff
+        uabd            v22.16b, v4.16b,  v16.16b // tldiff
+        uabd            v16.16b, v0.16b,  v16.16b // ldiff
+        umin            v18.16b, v20.16b, v22.16b // min(tdiff, tldiff)
+        cmhs            v20.16b, v22.16b, v20.16b // tldiff >= tdiff
+        cmhs            v16.16b, v18.16b, v16.16b // min(tdiff, tldiff) >= ldiff
+        bsl             v20.16b, v5.16b,  v4.16b  // tdiff <= tldiff ? top : topleft
+        bit             v20.16b, v0.16b,  v16.16b // ldiff <= min ? left : ...
+        st1             {v20.s}[3], [x0], x1
+        st1             {v20.s}[2], [x6], x1
+        subs            w4,  w4,  #4
+        st1             {v20.s}[1], [x0], x1
+        st1             {v20.s}[0], [x6], x1
+        b.gt            4b
+        ret
+80:
+        ld1r            {v5.2d},  [x8]
+        usubl           v6.8h,   v5.8b,   v4.8b   // top - topleft
+8:
+        ld4r            {v0.8b, v1.8b, v2.8b, v3.8b},  [x2], x7  // left edge for rows 3..0 in v0..v3
+        uaddw           v16.8h,  v6.8h,   v0.8b
+        uaddw           v17.8h,  v6.8h,   v1.8b
+        uaddw           v18.8h,  v6.8h,   v2.8b
+        uaddw           v19.8h,  v6.8h,   v3.8b
+        sqxtun          v16.8b,  v16.8h           // base
+        sqxtun2         v16.16b, v17.8h
+        sqxtun          v18.8b,  v18.8h
+        sqxtun2         v18.16b, v19.8h
+        zip1            v2.2d,   v2.2d,   v3.2d
+        zip1            v0.2d,   v0.2d,   v1.2d
+        uabd            v21.16b, v5.16b,  v18.16b // tdiff
+        uabd            v20.16b, v5.16b,  v16.16b
+        uabd            v23.16b, v4.16b,  v18.16b // tldiff
+        uabd            v22.16b, v4.16b,  v16.16b
+        uabd            v17.16b, v2.16b,  v18.16b // ldiff
+        uabd            v16.16b, v0.16b,  v16.16b
+        umin            v19.16b, v21.16b, v23.16b // min(tdiff, tldiff)
+        umin            v18.16b, v20.16b, v22.16b
+        cmhs            v21.16b, v23.16b, v21.16b // tldiff >= tdiff
+        cmhs            v20.16b, v22.16b, v20.16b
+        cmhs            v17.16b, v19.16b, v17.16b // min(tdiff, tldiff) >= ldiff
+        cmhs            v16.16b, v18.16b, v16.16b
+        bsl             v21.16b, v5.16b,  v4.16b  // tdiff <= tldiff ? top : topleft
+        bsl             v20.16b, v5.16b,  v4.16b
+        bit             v21.16b, v2.16b,  v17.16b // ldiff <= min ? left : ...
+        bit             v20.16b, v0.16b,  v16.16b
+        st1             {v21.d}[1], [x0], x1
+        st1             {v21.d}[0], [x6], x1
+        subs            w4,  w4,  #4
+        st1             {v20.d}[1], [x0], x1
+        st1             {v20.d}[0], [x6], x1
+        b.gt            8b
+        ret
+160:
+320:
+640:
+        ld1             {v5.16b},  [x8], #16      // first 16 pixels of the top edge
+        mov             w9,  w3                   // back up the width
+        // Set up pointers for four rows in parallel; x0, x6, x5, x10
+        add             x5,  x0,  x1
+        add             x10, x6,  x1
+        lsl             x1,  x1,  #1              // x1 is now 4*stride
+        sub             x1,  x1,  w3, uxtw        // 4*stride - width: advance to the next 4 rows
+1:
+        ld4r            {v0.16b, v1.16b, v2.16b, v3.16b},  [x2], x7  // left edge for rows 3..0 in v0..v3
+2:
+        usubl           v6.8h,   v5.8b,   v4.8b   // top - topleft
+        usubl2          v7.8h,   v5.16b,  v4.16b
+        uaddw           v24.8h,  v6.8h,   v0.8b
+        uaddw           v25.8h,  v7.8h,   v0.8b
+        uaddw           v26.8h,  v6.8h,   v1.8b
+        uaddw           v27.8h,  v7.8h,   v1.8b
+        uaddw           v28.8h,  v6.8h,   v2.8b
+        uaddw           v29.8h,  v7.8h,   v2.8b
+        uaddw           v30.8h,  v6.8h,   v3.8b
+        uaddw           v31.8h,  v7.8h,   v3.8b
+        sqxtun          v17.8b,  v26.8h           // base
+        sqxtun2         v17.16b, v27.8h
+        sqxtun          v16.8b,  v24.8h
+        sqxtun2         v16.16b, v25.8h
+        sqxtun          v19.8b,  v30.8h
+        sqxtun2         v19.16b, v31.8h
+        sqxtun          v18.8b,  v28.8h
+        sqxtun2         v18.16b, v29.8h
+        uabd            v23.16b, v5.16b,  v19.16b // tdiff
+        uabd            v22.16b, v5.16b,  v18.16b
+        uabd            v21.16b, v5.16b,  v17.16b
+        uabd            v20.16b, v5.16b,  v16.16b
+        uabd            v27.16b, v4.16b,  v19.16b // tldiff
+        uabd            v26.16b, v4.16b,  v18.16b
+        uabd            v25.16b, v4.16b,  v17.16b
+        uabd            v24.16b, v4.16b,  v16.16b
+        uabd            v19.16b, v3.16b,  v19.16b // ldiff
+        uabd            v18.16b, v2.16b,  v18.16b
+        uabd            v17.16b, v1.16b,  v17.16b
+        uabd            v16.16b, v0.16b,  v16.16b
+        umin            v31.16b, v23.16b, v27.16b // min(tdiff, tldiff)
+        umin            v30.16b, v22.16b, v26.16b
+        umin            v29.16b, v21.16b, v25.16b
+        umin            v28.16b, v20.16b, v24.16b
+        cmhs            v23.16b, v27.16b, v23.16b // tldiff >= tdiff
+        cmhs            v22.16b, v26.16b, v22.16b
+        cmhs            v21.16b, v25.16b, v21.16b
+        cmhs            v20.16b, v24.16b, v20.16b
+        cmhs            v19.16b, v31.16b, v19.16b // min(tdiff, tldiff) >= ldiff
+        cmhs            v18.16b, v30.16b, v18.16b
+        cmhs            v17.16b, v29.16b, v17.16b
+        cmhs            v16.16b, v28.16b, v16.16b
+        bsl             v23.16b, v5.16b,  v4.16b  // tdiff <= tldiff ? top : topleft
+        bsl             v22.16b, v5.16b,  v4.16b
+        bsl             v21.16b, v5.16b,  v4.16b
+        bsl             v20.16b, v5.16b,  v4.16b
+        bit             v23.16b, v3.16b,  v19.16b // ldiff <= min ? left : ...
+        bit             v22.16b, v2.16b,  v18.16b
+        bit             v21.16b, v1.16b,  v17.16b
+        bit             v20.16b, v0.16b,  v16.16b
+        subs            w3,  w3,  #16
+        st1             {v23.16b}, [x0],  #16
+        st1             {v22.16b}, [x6],  #16
+        st1             {v21.16b}, [x5],  #16
+        st1             {v20.16b}, [x10], #16
+        b.le            8f
+        ld1             {v5.16b},  [x8], #16
+        b               2b
+8:
+        subs            w4,  w4,  #4
+        b.le            9f
+        // End of horizontal loop, move pointers to next four rows
+        sub             x8,  x8,  w9, uxtw        // rewind x8 to the start of the top edge
+        add             x0,  x0,  x1
+        add             x6,  x6,  x1
+        // Load the top row as early as possible
+        ld1             {v5.16b},  [x8], #16
+        add             x5,  x5,  x1
+        add             x10, x10, x1
+        mov             w3,  w9                   // restore the width
+        b               1b
+9:
+        ret
+
+L(ipred_paeth_tbl):
+        .hword L(ipred_paeth_tbl) - 640b
+        .hword L(ipred_paeth_tbl) - 320b
+        .hword L(ipred_paeth_tbl) - 160b
+        .hword L(ipred_paeth_tbl) -  80b
+        .hword L(ipred_paeth_tbl) -  40b
+endfunc
--- a/src/arm/ipred_init_tmpl.c
+++ b/src/arm/ipred_init_tmpl.c
@@ -33,6 +33,7 @@
 decl_angular_ipred_fn(dav1d_ipred_dc_left_neon);
 decl_angular_ipred_fn(dav1d_ipred_h_neon);
 decl_angular_ipred_fn(dav1d_ipred_v_neon);
+decl_angular_ipred_fn(dav1d_ipred_paeth_neon);
 
 COLD void bitfn(dav1d_intra_pred_dsp_init_arm)(Dav1dIntraPredDSPContext *const c) {
     const unsigned flags = dav1d_get_cpu_flags();
@@ -46,5 +47,6 @@
     c->intra_pred[LEFT_DC_PRED]  = dav1d_ipred_dc_left_neon;
     c->intra_pred[HOR_PRED]      = dav1d_ipred_h_neon;
     c->intra_pred[VERT_PRED]     = dav1d_ipred_v_neon;
+    c->intra_pred[PAETH_PRED]    = dav1d_ipred_paeth_neon;
 #endif
 }
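
As an aside, speedup tables like the one above come from dav1d's checkasm
harness, which both verifies the NEON routines against the C reference and
times them. Something along these lines, run from a meson build directory
with tests enabled (the exact path and flags may differ between versions):

    ./tests/checkasm --bench

checkasm prints per-function timings, with each
intra_pred_paeth_wN_8bpc_neon variant listed next to its C counterpart;
numbers like those in the commit message are derived from that output.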