shithub: dav1d

Download patch

ref: c8aaddeaf3eab776d918081ac9216ad4e6a901c2
parent: bce8fae96e397405fb4ae526f3eb631bd63a2717
author: Martin Storsjö <martin@martin.st>
date: Tue Mar 3 09:27:24 EST 2020

arm64: mc: NEON implementation of w_mask for 16 bpc

Checkasm numbers:          Cortex A53       A72       A73
w_mask_420_w4_16bpc_neon:       173.6     123.5     120.3
w_mask_420_w8_16bpc_neon:       484.2     344.1     329.5
w_mask_420_w16_16bpc_neon:     1411.2    1027.4    1035.1
w_mask_420_w32_16bpc_neon:     5561.5    4093.2    3980.1
w_mask_420_w64_16bpc_neon:    13809.6    9856.5    9581.0
w_mask_420_w128_16bpc_neon:   35614.7   25553.8   24284.4
w_mask_422_w4_16bpc_neon:       159.4     112.2     114.2
w_mask_422_w8_16bpc_neon:       453.4     326.1     326.7
w_mask_422_w16_16bpc_neon:     1394.6    1062.3    1050.2
w_mask_422_w32_16bpc_neon:     5485.8    4219.6    4027.3
w_mask_422_w64_16bpc_neon:    13701.2   10079.6    9692.6
w_mask_422_w128_16bpc_neon:   35455.3   25892.5   24625.9
w_mask_444_w4_16bpc_neon:       153.0     112.3     112.7
w_mask_444_w8_16bpc_neon:       437.2     331.8     325.8
w_mask_444_w16_16bpc_neon:     1395.1    1069.1    1041.7
w_mask_444_w32_16bpc_neon:     5370.1    4213.5    4138.1
w_mask_444_w64_16bpc_neon:    13482.6   10190.5   10004.6
w_mask_444_w128_16bpc_neon:   35583.7   26911.2   25638.8

Corresponding numbers for 8 bpc for comparison:

w_mask_420_w4_8bpc_neon:        126.6      79.1      87.7
w_mask_420_w8_8bpc_neon:        343.9     195.0     211.5
w_mask_420_w16_8bpc_neon:       886.3     540.3     577.7
w_mask_420_w32_8bpc_neon:      3558.6    2152.4    2216.7
w_mask_420_w64_8bpc_neon:      8894.9    5161.2    5297.0
w_mask_420_w128_8bpc_neon:    22520.1   13514.5   13887.2
w_mask_422_w4_8bpc_neon:        112.9      68.2      77.0
w_mask_422_w8_8bpc_neon:        314.4     175.5     208.7
w_mask_422_w16_8bpc_neon:       835.5     565.0     608.3
w_mask_422_w32_8bpc_neon:      3381.3    2231.8    2287.6
w_mask_422_w64_8bpc_neon:      8499.4    5343.6    5460.8
w_mask_422_w128_8bpc_neon:    21823.3   14206.5   14249.1
w_mask_444_w4_8bpc_neon:        104.6      65.8      72.7
w_mask_444_w8_8bpc_neon:        290.4     173.7     196.6
w_mask_444_w16_8bpc_neon:       831.4     586.7     591.7
w_mask_444_w32_8bpc_neon:      3320.8    2300.6    2251.0
w_mask_444_w64_8bpc_neon:      8300.0    5480.5    5346.8
w_mask_444_w128_8bpc_neon:    21633.8   15981.3   14384.8

--- a/src/arm/64/mc16.S
+++ b/src/arm/64/mc16.S
@@ -241,6 +241,320 @@
 bidir_fn mask, w7
 
 
+.macro w_mask_fn type
+function w_mask_\type\()_16bpc_neon, export=1
+        ldr             w8,  [sp]     // bitdepth_max (stack arg); x0 dst, x1 stride, x2 tmp1, x3 tmp2, w4 w, w5 h, x6 mask, w7 sign
+        clz             w9,  w4       // clz(w): 24 for w=128 ... 29 for w=4
+        adr             x10, L(w_mask_\type\()_tbl) // jump table base
+        dup             v31.8h,  w8   // bitdepth_max
+        sub             w9,  w9,  #24 // table index: 0 (w=128) .. 5 (w=4)
+        clz             w8,  w8       // clz(bitdepth_max)
+        ldrh            w9,  [x10,  x9,  lsl #1] // offset of the entry point for this width
+        sub             x10, x10, w9,  uxtw      // x10 = entry point = table - offset
+        sub             w8,  w8,  #12 // sh = intermediate_bits + 6 = clz(bitdepth_max) - 12
+        mov             w9,  #PREP_BIAS*64
+        neg             w8,  w8       // -sh
+        mov             w11, #27615   // (64 + 1 - 38)<<mask_sh - 1 - mask_rnd = 27*1024 - 1 - 32 (mask_sh = 10)
+        dup             v30.4s,  w9   // PREP_BIAS*64
+        dup             v29.4s,  w8   // -sh
+        dup             v0.8h,   w11  // 27615 in every lane
+.if \type == 444
+        movi            v1.16b,  #64  // 64, to recover m from 64 - m
+.elseif \type == 422
+        dup             v2.8b,   w7   // sign
+        movi            v3.8b,   #129
+        sub             v3.8b,   v3.8b,   v2.8b   // 129 - sign
+.elseif \type == 420
+        dup             v2.8h,   w7   // sign
+        movi            v3.8h,   #1, lsl #8       // 256
+        sub             v3.8h,   v3.8h,   v2.8h   // 256 - sign
+.endif
+        add             x12,  x0,  x1 // x12 = dst + stride (odd rows)
+        lsl             x1,   x1,  #1 // x0 and x12 each step two rows at a time
+        br              x10           // dispatch on width
+4:
+        ld1             {v4.8h, v5.8h}, [x2], #32 // tmp1 (four rows at once)
+        ld1             {v6.8h, v7.8h}, [x3], #32 // tmp2 (four rows at once)
+        subs            w5,  w5,  #4  // h -= 4; flags tested by b.gt at the loop end
+        sabd            v20.8h,  v4.8h,   v6.8h   // abs(tmp1 - tmp2)
+        sabd            v21.8h,  v5.8h,   v7.8h   // rows 2-3
+        ssubl           v16.4s,  v6.4h,   v4.4h   // tmp2 - tmp1 (requires 17 bit)
+        ssubl2          v17.4s,  v6.8h,   v4.8h
+        ssubl           v18.4s,  v7.4h,   v5.4h
+        ssubl2          v19.4s,  v7.8h,   v5.8h
+        uqsub           v20.8h,  v0.8h,   v20.8h  // 27615 - abs(), saturating at 0
+        uqsub           v21.8h,  v0.8h,   v21.8h
+        sshll2          v7.4s,   v5.8h,   #6      // tmp1 << 6
+        sshll           v6.4s,   v5.4h,   #6
+        sshll2          v5.4s,   v4.8h,   #6
+        sshll           v4.4s,   v4.4h,   #6
+        ushr            v20.8h,  v20.8h,  #10     // 64-m = (27615 - abs()) >> mask_sh
+        ushr            v21.8h,  v21.8h,  #10
+        add             v4.4s,   v4.4s,   v30.4s  // += PREP_BIAS*64
+        add             v5.4s,   v5.4s,   v30.4s
+        add             v6.4s,   v6.4s,   v30.4s
+        add             v7.4s,   v7.4s,   v30.4s
+        uxtl            v22.4s,  v20.4h           // widen 64 - m to 32 bit for mla
+        uxtl2           v23.4s,  v20.8h
+        uxtl            v24.4s,  v21.4h
+        uxtl2           v25.4s,  v21.8h
+        mla             v4.4s,   v16.4s,  v22.4s  // (tmp2-tmp1)*(64-m)
+        mla             v5.4s,   v17.4s,  v23.4s
+        mla             v6.4s,   v18.4s,  v24.4s
+        mla             v7.4s,   v19.4s,  v25.4s
+        srshl           v4.4s,   v4.4s,   v29.4s  // (tmp1<<6 + (tmp2-tmp1)*(64-m) + (1 << (sh-1)) + PREP_BIAS*64) >> sh
+        srshl           v5.4s,   v5.4s,   v29.4s
+        srshl           v6.4s,   v6.4s,   v29.4s
+        srshl           v7.4s,   v7.4s,   v29.4s
+        sqxtun          v4.4h,   v4.4s            // iclip_pixel: saturate to [0, 65535]
+        sqxtun2         v4.8h,   v5.4s
+        sqxtun          v5.4h,   v6.4s
+        sqxtun2         v5.8h,   v7.4s
+        umin            v4.8h,   v4.8h,   v31.8h  // iclip_pixel: min(px, bitdepth_max)
+        umin            v5.8h,   v5.8h,   v31.8h
+.if \type == 444
+        xtn             v20.8b,  v20.8h           // 64 - m
+        xtn2            v20.16b, v21.8h
+        sub             v20.16b, v1.16b,  v20.16b // m = 64 - (64 - m)
+        st1             {v20.16b}, [x6], #16      // 4 rows x 4 mask bytes
+.elseif \type == 422
+        addp            v20.8h,  v20.8h,  v21.8h  // (64 - m) + (64 - n) (column wise addition)
+        xtn             v20.8b,  v20.8h           // sums are at most 2*26 = 52, fit in a byte
+        uhsub           v20.8b,  v3.8b,   v20.8b  // ((129 - sign) - ((64 - m) + (64 - n))) >> 1
+        st1             {v20.8b}, [x6], #8        // 4 rows x 2 mask bytes
+.elseif \type == 420
+        trn1            v24.2d,  v20.2d,  v21.2d  // rows 0 and 2
+        trn2            v25.2d,  v20.2d,  v21.2d  // rows 1 and 3
+        add             v24.8h,  v24.8h,  v25.8h  // (64 - my1) + (64 - my2) (row wise addition)
+        addp            v20.8h,  v24.8h,  v24.8h  // (128 - m) + (128 - n) (column wise addition)
+        sub             v20.4h,  v3.4h,   v20.4h  // (256 - sign) - ((128 - m) + (128 - n))
+        rshrn           v20.8b,  v20.8h,  #2      // ((256 - sign) - ((128 - m) + (128 - n)) + 2) >> 2
+        st1             {v20.s}[0], [x6], #4      // 2 rows x 2 mask bytes
+.endif
+        st1             {v4.d}[0],  [x0],  x1     // row 0
+        st1             {v4.d}[1],  [x12], x1     // row 1
+        st1             {v5.d}[0],  [x0],  x1     // row 2
+        st1             {v5.d}[1],  [x12], x1     // row 3
+        b.gt            4b                        // more rows left
+        ret
+8:
+        ld1             {v4.8h, v5.8h}, [x2], #32 // tmp1 (two rows at once)
+        ld1             {v6.8h, v7.8h}, [x3], #32 // tmp2 (two rows at once)
+        subs            w5,  w5,  #2  // h -= 2; flags tested by b.gt at the loop end
+        sabd            v20.8h,  v4.8h,   v6.8h   // abs(tmp1 - tmp2)
+        sabd            v21.8h,  v5.8h,   v7.8h   // row 1
+        ssubl           v16.4s,  v6.4h,   v4.4h   // tmp2 - tmp1 (requires 17 bit)
+        ssubl2          v17.4s,  v6.8h,   v4.8h
+        ssubl           v18.4s,  v7.4h,   v5.4h
+        ssubl2          v19.4s,  v7.8h,   v5.8h
+        uqsub           v20.8h,  v0.8h,   v20.8h  // 27615 - abs(), saturating at 0
+        uqsub           v21.8h,  v0.8h,   v21.8h
+        sshll2          v7.4s,   v5.8h,   #6      // tmp1 << 6
+        sshll           v6.4s,   v5.4h,   #6
+        sshll2          v5.4s,   v4.8h,   #6
+        sshll           v4.4s,   v4.4h,   #6
+        ushr            v20.8h,  v20.8h,  #10     // 64-m = (27615 - abs()) >> mask_sh
+        ushr            v21.8h,  v21.8h,  #10
+        add             v4.4s,   v4.4s,   v30.4s  // += PREP_BIAS*64
+        add             v5.4s,   v5.4s,   v30.4s
+        add             v6.4s,   v6.4s,   v30.4s
+        add             v7.4s,   v7.4s,   v30.4s
+        uxtl            v22.4s,  v20.4h           // widen 64 - m to 32 bit for mla
+        uxtl2           v23.4s,  v20.8h
+        uxtl            v24.4s,  v21.4h
+        uxtl2           v25.4s,  v21.8h
+        mla             v4.4s,   v16.4s,  v22.4s  // (tmp2-tmp1)*(64-m)
+        mla             v5.4s,   v17.4s,  v23.4s
+        mla             v6.4s,   v18.4s,  v24.4s
+        mla             v7.4s,   v19.4s,  v25.4s
+        srshl           v4.4s,   v4.4s,   v29.4s  // (tmp1<<6 + (tmp2-tmp1)*(64-m) + (1 << (sh-1)) + PREP_BIAS*64) >> sh
+        srshl           v5.4s,   v5.4s,   v29.4s
+        srshl           v6.4s,   v6.4s,   v29.4s
+        srshl           v7.4s,   v7.4s,   v29.4s
+        sqxtun          v4.4h,   v4.4s            // iclip_pixel: saturate to [0, 65535]
+        sqxtun2         v4.8h,   v5.4s
+        sqxtun          v5.4h,   v6.4s
+        sqxtun2         v5.8h,   v7.4s
+        umin            v4.8h,   v4.8h,   v31.8h  // iclip_pixel: min(px, bitdepth_max)
+        umin            v5.8h,   v5.8h,   v31.8h
+.if \type == 444
+        xtn             v20.8b,  v20.8h           // 64 - m
+        xtn2            v20.16b, v21.8h
+        sub             v20.16b, v1.16b,  v20.16b // m = 64 - (64 - m)
+        st1             {v20.16b}, [x6], #16      // 2 rows x 8 mask bytes
+.elseif \type == 422
+        addp            v20.8h,  v20.8h,  v21.8h  // (64 - m) + (64 - n) (column wise addition)
+        xtn             v20.8b,  v20.8h           // sums are at most 2*26 = 52, fit in a byte
+        uhsub           v20.8b,  v3.8b,   v20.8b  // ((129 - sign) - ((64 - m) + (64 - n))) >> 1
+        st1             {v20.8b}, [x6], #8        // 2 rows x 4 mask bytes
+.elseif \type == 420
+        add             v20.8h,  v20.8h,  v21.8h  // (64 - my1) + (64 - my2) (row wise addition)
+        addp            v20.8h,  v20.8h,  v20.8h  // (128 - m) + (128 - n) (column wise addition)
+        sub             v20.4h,  v3.4h,   v20.4h  // (256 - sign) - ((128 - m) + (128 - n))
+        rshrn           v20.8b,  v20.8h,  #2      // ((256 - sign) - ((128 - m) + (128 - n)) + 2) >> 2
+        st1             {v20.s}[0], [x6], #4      // 1 row x 4 mask bytes
+.endif
+        st1             {v4.8h}, [x0],  x1        // row 0
+        st1             {v5.8h}, [x12], x1        // row 1
+        b.gt            8b                        // more rows left
+        ret
+1280:
+640:
+320:
+160:
+        mov             w11, w4       // x11 = w (zero-extended), used for the 422 mask step
+        sub             x1,  x1,  w4,  uxtw #1    // 2*stride - w*2: dst step after the inner loop
+.if \type == 444
+        add             x10, x6,  w4,  uxtw       // odd mask row: mask + w
+.elseif \type == 422
+        add             x10, x6,  x11, lsr #1     // odd mask row: mask + w/2
+.endif
+        add             x9,  x3,  w4,  uxtw #1    // tmp2, odd row
+        add             x7,  x2,  w4,  uxtw #1    // tmp1, odd row
+161:
+        mov             w8,  w4       // columns left in this row pair
+16:
+        ld1             {v4.8h,   v5.8h},  [x2], #32 // tmp1, even row
+        ld1             {v16.8h,  v17.8h}, [x3], #32 // tmp2, even row
+        ld1             {v6.8h,   v7.8h},  [x7], #32 // tmp1, odd row
+        ld1             {v18.8h,  v19.8h}, [x9], #32 // tmp2, odd row
+        subs            w8,  w8,  #16 // 16 columns per iteration
+        sabd            v20.8h,  v4.8h,   v16.8h  // abs(tmp1 - tmp2)
+        sabd            v21.8h,  v5.8h,   v17.8h
+        ssubl           v22.4s,  v16.4h,  v4.4h   // tmp2 - tmp1 (requires 17 bit)
+        ssubl2          v23.4s,  v16.8h,  v4.8h
+        ssubl           v24.4s,  v17.4h,  v5.4h
+        ssubl2          v25.4s,  v17.8h,  v5.8h
+        uqsub           v20.8h,  v0.8h,   v20.8h  // 27615 - abs(), saturating at 0
+        uqsub           v21.8h,  v0.8h,   v21.8h
+        sshll2          v27.4s,  v5.8h,   #6      // tmp1 << 6
+        sshll           v26.4s,  v5.4h,   #6
+        sshll2          v5.4s,   v4.8h,   #6
+        sshll           v4.4s,   v4.4h,   #6
+        ushr            v20.8h,  v20.8h,  #10     // 64-m = (27615 - abs()) >> mask_sh
+        ushr            v21.8h,  v21.8h,  #10
+        add             v4.4s,   v4.4s,   v30.4s  // += PREP_BIAS*64
+        add             v5.4s,   v5.4s,   v30.4s
+        add             v26.4s,  v26.4s,  v30.4s
+        add             v27.4s,  v27.4s,  v30.4s
+        uxtl            v16.4s,  v20.4h           // widen 64 - m to 32 bit for mla
+        uxtl2           v17.4s,  v20.8h
+        uxtl            v28.4s,  v21.4h
+        mla             v4.4s,   v22.4s,  v16.4s  // (tmp2-tmp1)*(64-m)
+        uxtl2           v16.4s,  v21.8h           // v16 reused once the mla above has read it
+        mla             v5.4s,   v23.4s,  v17.4s
+        mla             v26.4s,  v24.4s,  v28.4s
+        mla             v27.4s,  v25.4s,  v16.4s
+        srshl           v4.4s,   v4.4s,   v29.4s  // (tmp1<<6 + (tmp2-tmp1)*(64-m) + (1 << (sh-1)) + PREP_BIAS*64) >> sh
+        srshl           v5.4s,   v5.4s,   v29.4s
+        srshl           v26.4s,  v26.4s,  v29.4s
+        srshl           v27.4s,  v27.4s,  v29.4s
+        sqxtun          v4.4h,   v4.4s            // iclip_pixel: saturate to [0, 65535]
+        sqxtun2         v4.8h,   v5.4s
+        sqxtun          v5.4h,   v26.4s
+        sqxtun2         v5.8h,   v27.4s
+
+        // Odd row of the pair (tmp1 in v6/v7, tmp2 in v18/v19), interleaved with the even row's clamp
+        sabd            v22.8h,  v6.8h,   v18.8h  // abs(tmp1 - tmp2)
+        sabd            v23.8h,  v7.8h,   v19.8h
+
+        umin            v4.8h,   v4.8h,   v31.8h  // iclip_pixel: min(px, bitdepth_max), even row
+        umin            v5.8h,   v5.8h,   v31.8h
+
+        ssubl           v16.4s,  v18.4h,  v6.4h   // tmp2 - tmp1 (requires 17 bit)
+        ssubl2          v17.4s,  v18.8h,  v6.8h
+        ssubl           v18.4s,  v19.4h,  v7.4h
+        ssubl2          v19.4s,  v19.8h,  v7.8h
+        uqsub           v22.8h,  v0.8h,   v22.8h  // 27615 - abs(), saturating at 0
+        uqsub           v23.8h,  v0.8h,   v23.8h
+        sshll           v24.4s,  v6.4h,   #6      // tmp1 << 6
+        sshll2          v25.4s,  v6.8h,   #6
+        sshll           v26.4s,  v7.4h,   #6
+        sshll2          v27.4s,  v7.8h,   #6
+        ushr            v22.8h,  v22.8h,  #10     // 64-m = (27615 - abs()) >> mask_sh
+        ushr            v23.8h,  v23.8h,  #10
+        add             v24.4s,  v24.4s,  v30.4s  // += PREP_BIAS*64
+        add             v25.4s,  v25.4s,  v30.4s
+        add             v26.4s,  v26.4s,  v30.4s
+        add             v27.4s,  v27.4s,  v30.4s
+        uxtl            v6.4s,   v22.4h           // widen 64 - m to 32 bit for mla
+        uxtl2           v7.4s,   v22.8h
+        uxtl            v28.4s,  v23.4h
+        mla             v24.4s,  v16.4s,  v6.4s   // (tmp2-tmp1)*(64-m)
+        uxtl2           v6.4s,   v23.8h           // v6 reused once the mla above has read it
+        mla             v25.4s,  v17.4s,  v7.4s
+        mla             v26.4s,  v18.4s,  v28.4s
+        mla             v27.4s,  v19.4s,  v6.4s
+        srshl           v24.4s,  v24.4s,  v29.4s  // (tmp1<<6 + (tmp2-tmp1)*(64-m) + (1 << (sh-1)) + PREP_BIAS*64) >> sh
+        srshl           v25.4s,  v25.4s,  v29.4s
+        srshl           v26.4s,  v26.4s,  v29.4s
+        srshl           v27.4s,  v27.4s,  v29.4s
+        sqxtun          v6.4h,   v24.4s           // iclip_pixel: saturate to [0, 65535]
+        sqxtun2         v6.8h,   v25.4s
+        sqxtun          v7.4h,   v26.4s
+        sqxtun2         v7.8h,   v27.4s
+        umin            v6.8h,   v6.8h,   v31.8h  // iclip_pixel: min(px, bitdepth_max), odd row
+        umin            v7.8h,   v7.8h,   v31.8h
+.if \type == 444
+        xtn             v20.8b,  v20.8h           // 64 - m
+        xtn2            v20.16b, v21.8h
+        xtn             v21.8b,  v22.8h
+        xtn2            v21.16b, v23.8h
+        sub             v20.16b, v1.16b,  v20.16b // m = 64 - (64 - m)
+        sub             v21.16b, v1.16b,  v21.16b
+        st1             {v20.16b}, [x6],  #16     // even mask row
+        st1             {v21.16b}, [x10], #16     // odd mask row
+.elseif \type == 422
+        addp            v20.8h,  v20.8h,  v21.8h  // (64 - m) + (64 - n) (column wise addition)
+        addp            v21.8h,  v22.8h,  v23.8h
+        xtn             v20.8b,  v20.8h           // sums are at most 2*26 = 52, fit in a byte
+        xtn             v21.8b,  v21.8h
+        uhsub           v20.8b,  v3.8b,   v20.8b  // ((129 - sign) - ((64 - m) + (64 - n))) >> 1
+        uhsub           v21.8b,  v3.8b,   v21.8b
+        st1             {v20.8b}, [x6],  #8       // even mask row
+        st1             {v21.8b}, [x10], #8       // odd mask row
+.elseif \type == 420
+        add             v20.8h,  v20.8h,  v22.8h  // (64 - my1) + (64 - my2) (row wise addition)
+        add             v21.8h,  v21.8h,  v23.8h
+        addp            v20.8h,  v20.8h,  v21.8h  // (128 - m) + (128 - n) (column wise addition)
+        sub             v20.8h,  v3.8h,   v20.8h  // (256 - sign) - ((128 - m) + (128 - n))
+        rshrn           v20.8b,  v20.8h,  #2      // ((256 - sign) - ((128 - m) + (128 - n)) + 2) >> 2
+        st1             {v20.8b}, [x6], #8        // one mask row per row pair
+.endif
+        st1             {v4.8h, v5.8h}, [x0],  #32 // even row, 16 px
+        st1             {v6.8h, v7.8h}, [x12], #32 // odd row, 16 px
+        b.gt            16b                       // more columns in this row pair
+        subs            w5,  w5,  #2  // h -= 2; flags tested by b.gt 161b below
+        add             x2,  x2,  w4,  uxtw #1    // skip the odd row (read via x7)
+        add             x3,  x3,  w4,  uxtw #1    // skip the odd row (read via x9)
+        add             x7,  x7,  w4,  uxtw #1    // skip the even row (read via x2)
+        add             x9,  x9,  w4,  uxtw #1    // skip the even row (read via x3)
+.if \type == 444
+        add             x6,  x6,  w4,  uxtw       // skip the mask row written via x10
+        add             x10, x10, w4,  uxtw       // skip the mask row written via x6
+.elseif \type == 422
+        add             x6,  x6,  x11, lsr #1     // skip the mask row written via x10
+        add             x10, x10, x11, lsr #1     // skip the mask row written via x6
+.endif
+        add             x0,  x0,  x1              // advance to the next row pair
+        add             x12, x12, x1
+        b.gt            161b                      // more row pairs left
+        ret
+L(w_mask_\type\()_tbl):
+        .hword L(w_mask_\type\()_tbl) - 1280b
+        .hword L(w_mask_\type\()_tbl) -  640b
+        .hword L(w_mask_\type\()_tbl) -  320b
+        .hword L(w_mask_\type\()_tbl) -  160b
+        .hword L(w_mask_\type\()_tbl) -    8b
+        .hword L(w_mask_\type\()_tbl) -    4b
+endfunc
+.endm
+
+w_mask_fn 444
+w_mask_fn 422
+w_mask_fn 420
+
+
 function blend_16bpc_neon, export=1
         adr             x6,  L(blend_tbl)
         clz             w3,  w3
--- a/src/arm/mc_init_tmpl.c
+++ b/src/arm/mc_init_tmpl.c
@@ -104,13 +104,9 @@
     c->blend = BF(dav1d_blend, neon);
     c->blend_h = BF(dav1d_blend_h, neon);
     c->blend_v = BF(dav1d_blend_v, neon);
-#endif
-#if BITDEPTH == 8
     c->w_mask[0] = BF(dav1d_w_mask_444, neon);
     c->w_mask[1] = BF(dav1d_w_mask_422, neon);
     c->w_mask[2] = BF(dav1d_w_mask_420, neon);
-#endif
-#if BITDEPTH == 8 || ARCH_AARCH64
     c->warp8x8 = BF(dav1d_warp_affine_8x8, neon);
     c->warp8x8t = BF(dav1d_warp_affine_8x8t, neon);
 #endif