shithub: dav1d

ref: 370200cd99bbc24c4d5c30e51ec2ef9bcf7f986e
parent: 8bc5487093271b239c5d4577e91628845b6309cc
author: Martin Storsjö <martin@martin.st>
date: Tue Jun 16 10:06:02 EDT 2020

arm64: msac: Add a special cased implementation of decode_hi_tok

The speedup (over the normal version, which just calls the existing
assembly version of symbol_adapt4 repeatedly; see the C sketch after
the table below) is not very impressive on bigger cores, but looks
decent on small cores. It is an improvement in any case.

                             Cortex A53    A72    A73
msac_decode_hi_tok_c:             175.7  136.2  138.1
msac_decode_hi_tok_neon:          146.8  129.4  125.9
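
For context, the "normal version" mentioned above is the C fallback,
which chains up to four 3-symbol calls to the existing symbol_adapt4
decoder, escaping to the next group whenever the bracket value 3 comes
back. A sketch of that shape, reconstructed from this description
rather than copied from the tree (MsacContext and
dav1d_msac_decode_symbol_adapt4 are the existing dav1d API names):

    #include "src/msac.h"

    /* Sketch of the scalar fallback: decode up to four 3-symbol groups,
     * moving to the next group whenever the escape value 3 is read. */
    static unsigned decode_hi_tok_sketch(MsacContext *const s,
                                         uint16_t *const cdf)
    {
        unsigned tok_br = dav1d_msac_decode_symbol_adapt4(s, cdf, 3);
        unsigned tok = 3 + tok_br;
        if (tok_br == 3) {
            tok_br = dav1d_msac_decode_symbol_adapt4(s, cdf, 3);
            tok = 6 + tok_br;
            if (tok_br == 3) {
                tok_br = dav1d_msac_decode_symbol_adapt4(s, cdf, 3);
                tok = 9 + tok_br;
                if (tok_br == 3)
                    tok = 12 + dav1d_msac_decode_symbol_adapt4(s, cdf, 3);
            }
        }
        return tok;
    }

The NEON routine added below instead loads the msac state (RNG, CNT,
DIF) once before the loop and only stores it back when the escape loop
exits, which avoids reloading and re-storing the state on every call.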

--- a/src/arm/64/msac.S
+++ b/src/arm/64/msac.S
@@ -274,6 +274,128 @@
         b               L(renorm)
 endfunc
 
+function msac_decode_hi_tok_neon, export=1
+        ld1             {v0.4h},  [x1]            // cdf
+        add             x16, x0,  #RNG
+        movi            v31.4h, #0x7f, lsl #8     // 0x7f00
+        movrel          x17, coeffs, 30-2*3
+        mvni            v30.4h, #0x3f             // 0xffc0
+        ldrh            w9,  [x1, #6]             // count = cdf[n_symbols]
+        ld1r            {v3.4h},  [x16]           // rng
+        movrel          x16, bits
+        ld1             {v29.4h}, [x17]           // EC_MIN_PROB * (n_symbols - ret)
+        add             x17, x0,  #DIF + 6
+        ld1             {v16.8h}, [x16]
+        mov             w13, #-24
+        and             v17.8b,  v0.8b,   v30.8b  // cdf & 0xffc0
+        ldr             w10, [x0, #ALLOW_UPDATE_CDF]
+        ld1r            {v1.8h},  [x17]           // dif >> (EC_WIN_SIZE - 16)
+        sub             sp,  sp,  #48
+        ldr             w6,  [x0, #CNT]
+        ldr             x7,  [x0, #DIF]
+1:
+        and             v7.8b,   v3.8b,   v31.8b  // rng & 0x7f00
+        sqdmulh         v6.4h,   v17.4h,  v7.4h   // ((cdf >> EC_PROB_SHIFT) * (r - 128)) >> 1
+        add             v4.4h,   v17.4h,  v29.4h  // v = cdf + EC_MIN_PROB * (n_symbols - ret)
+        add             v4.4h,   v6.4h,   v4.4h   // v = ((cdf >> EC_PROB_SHIFT) * r) >> 1 + EC_MIN_PROB * (n_symbols - ret)
+        str             h3,  [sp, #14]            // store original u = s->rng
+        cmhs            v2.8h,   v1.8h,   v4.8h   // c >= v
+        str             q4,  [sp, #16]            // store v values to allow indexed access
+        and             v6.16b,  v2.16b,  v16.16b // One bit per halfword set in the mask
+        addv            h6,  v6.8h                // Aggregate mask bits
+        umov            w3,  v6.h[0]
+        add             w13, w13, #5
+        rbit            w3,  w3
+        add             x8,  sp,  #16
+        clz             w15, w3                   // ret
+
+        cbz             w10, 2f
+        // update_cdf
+        movi            v5.8b, #0xff
+        mov             w4,  #-5
+        urhadd          v4.4h,   v5.4h,   v2.4h   // i >= val ? -1 : 32768
+        sub             w4,  w4,  w9, lsr #4      // -((count >> 4) + 5)
+        sub             v4.4h,   v4.4h,   v0.4h   // (32768 - cdf[i]) or (-1 - cdf[i])
+        dup             v6.4h,    w4              // -rate
+
+        sub             w9,  w9,  w9, lsr #5      // count - (count == 32)
+        sub             v0.4h,   v0.4h,   v2.4h   // cdf + (i >= val ? 1 : 0)
+        sshl            v4.4h,   v4.4h,   v6.4h   // ({32768,-1} - cdf[i]) >> rate
+        add             w9,  w9,  #1              // count + (count < 32)
+        add             v0.4h,   v0.4h,   v4.4h   // cdf + (32768 - cdf[i]) >> rate
+        st1             {v0.4h},  [x1]
+        and             v17.8b,  v0.8b,   v30.8b  // cdf & 0xffc0
+        strh            w9,  [x1, #6]
+
+2:
+        add             x8,  x8,  w15, uxtw #1
+        ldrh            w3,  [x8]              // v
+        ldurh           w4,  [x8, #-2]         // u
+        sub             w4,  w4,  w3           // rng = u - v
+        clz             w5,  w4                // clz(rng)
+        eor             w5,  w5,  #16          // d = clz(rng) ^ 16
+        mvn             x7,  x7                // ~dif
+        add             x7,  x7,  x3, lsl #48  // ~dif + (v << 48)
+        lsl             w4,  w4,  w5           // rng << d
+        subs            w6,  w6,  w5           // cnt -= d
+        lsl             x7,  x7,  x5           // (~dif + (v << 48)) << d
+        str             w4,  [x0, #RNG]
+        dup             v3.4h,   w4
+        mvn             x7,  x7                // ~dif
+        b.hs            9f
+
+        // refill
+        ldp             x3,  x4,  [x0]         // BUF_POS, BUF_END
+        add             x5,  x3,  #8
+        cmp             x5,  x4
+        b.gt            2f
+
+        ldr             x3,  [x3]              // next_bits
+        add             w8,  w6,  #23          // shift_bits = cnt + 23
+        add             w6,  w6,  #16          // cnt += 16
+        rev             x3,  x3                // next_bits = bswap(next_bits)
+        sub             x5,  x5,  x8, lsr #3   // buf_pos -= shift_bits >> 3
+        and             w8,  w8,  #24          // shift_bits &= 24
+        lsr             x3,  x3,  x8           // next_bits >>= shift_bits
+        sub             w8,  w8,  w6           // shift_bits -= 16 + cnt
+        str             x5,  [x0, #BUF_POS]
+        lsl             x3,  x3,  x8           // next_bits <<= shift_bits
+        mov             w4,  #48
+        sub             w6,  w4,  w8           // cnt = cnt + 64 - shift_bits
+        eor             x7,  x7,  x3           // dif ^= next_bits
+        b               9f
+
+2:      // refill_eob
+        mov             w14, #40
+        sub             w5,  w14, w6           // c = 40 - cnt
+3:
+        cmp             x3,  x4
+        b.ge            4f
+        ldrb            w8,  [x3], #1
+        lsl             x8,  x8,  x5
+        eor             x7,  x7,  x8
+        subs            w5,  w5,  #8
+        b.ge            3b
+
+4:      // refill_eob_end
+        str             x3,  [x0, #BUF_POS]
+        sub             w6,  w14, w5           // cnt = 40 - c
+
+9:
+        lsl             w15, w15, #1
+        sub             w15, w15, #5
+        lsr             x12, x7,  #48
+        adds            w13, w13, w15          // carry = tok_br < 3 || tok == 15
+        dup             v1.8h,   w12
+        b.cc            1b                     // loop if !carry
+        add             w13, w13, #30
+        str             w6,  [x0, #CNT]
+        add             sp,  sp,  #48
+        str             x7,  [x0, #DIF]
+        lsr             w0,  w13, #1
+        ret
+endfunc
+
 function msac_decode_bool_equi_neon, export=1
         ldp             w5,  w6,  [x0, #RNG]   // + CNT
         sub             sp,  sp,  #48
--- a/src/arm/msac.h
+++ b/src/arm/msac.h
@@ -34,6 +34,7 @@
                                               size_t n_symbols);
 unsigned dav1d_msac_decode_symbol_adapt16_neon(MsacContext *s, uint16_t *cdf,
                                                size_t n_symbols);
+unsigned dav1d_msac_decode_hi_tok_neon(MsacContext *s, uint16_t *cdf);
 unsigned dav1d_msac_decode_bool_adapt_neon(MsacContext *s, uint16_t *cdf);
 unsigned dav1d_msac_decode_bool_equi_neon(MsacContext *s);
 unsigned dav1d_msac_decode_bool_neon(MsacContext *s, unsigned f);
@@ -42,6 +43,7 @@
 #define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt4_neon
 #define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt8_neon
 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_neon
+#define dav1d_msac_decode_hi_tok         dav1d_msac_decode_hi_tok_neon
 #define dav1d_msac_decode_bool_adapt     dav1d_msac_decode_bool_adapt_neon
 #define dav1d_msac_decode_bool_equi      dav1d_msac_decode_bool_equi_neon
 #define dav1d_msac_decode_bool           dav1d_msac_decode_bool_neon
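
With the prototype and #define above, generic call sites keep using the
unprefixed dav1d_msac_decode_hi_tok name and resolve to the NEON symbol
on AArch64 builds with assembly enabled. A hypothetical caller, for
illustration only (read_hi_tok and its arguments are made-up names):

    #include "src/msac.h"   /* on ARM targets this pulls in src/arm/msac.h */

    static unsigned read_hi_tok(MsacContext *const msac, uint16_t *const cdf) {
        /* Expands to dav1d_msac_decode_hi_tok_neon(msac, cdf) here. */
        return dav1d_msac_decode_hi_tok(msac, cdf);
    }
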
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -247,6 +247,7 @@
         c.bool_adapt     = dav1d_msac_decode_bool_adapt_neon;
         c.bool_equi      = dav1d_msac_decode_bool_equi_neon;
         c.bool           = dav1d_msac_decode_bool_neon;
+        c.hi_tok         = dav1d_msac_decode_hi_tok_neon;
     }
 #elif ARCH_X86 && HAVE_ASM
     if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_SSE2) {