ref: e705519d406941886431300ca432d33980cb554c
parent: e1be33b9c8cb20c62b26b9e3f02d206ddf54a80e
author: Martin Storsjö <martin@martin.st>
date: Thu Nov 26 17:48:26 EST 2020

arm32: looprestoration: NEON implementation of SGR for 10 bpc

Checkasm numbers:           Cortex A7         A8       A53       A72       A73
selfguided_3x3_10bpc_neon:   919127.6   717942.8  565717.8  404748.0  372179.8
selfguided_5x5_10bpc_neon:   640310.8   511873.4  370653.3  273593.7  256403.2
selfguided_mix_10bpc_neon:  1533887.0  1252389.5  922111.1  659033.4  613410.6

Corresponding numbers for arm64, for comparison:

                                                Cortex A53       A72       A73
selfguided_3x3_10bpc_neon:                        500706.0  367199.2  345261.2
selfguided_5x5_10bpc_neon:                        361403.3  270550.0  249955.3
selfguided_mix_10bpc_neon:                        846172.4  623590.3  578404.8
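
For reference, the new sgr_box3_h_16bpc_neon/sgr_box5_h_16bpc_neon functions in the
patch below compute, for each output position, a horizontal 3- or 5-pixel box sum
and box sum of squares for the self-guided filter. A rough scalar model of one row
(a sketch only; the helper name is made up, and the left-column handling, the
LR_HAVE_LEFT/LR_HAVE_RIGHT edge padding and the two-rows-per-iteration layout of the
NEON code are omitted):

#include <stdint.h>

/* Illustrative scalar model of one row of the 3-tap horizontal pass
 * (the 5-tap variant is the same with a 5-pixel window). */
static void box3_h_row_sketch(int32_t *sumsq, int16_t *sum,
                              const uint16_t *src, int w)
{
    for (int x = 0; x < w; x++) {
        int32_t s = 0, ss = 0;
        for (int k = 0; k < 3; k++) {
            const int32_t p = src[x + k]; /* 10 bpc pixel, max 1023 */
            s  += p;                      /* box sum -> sum[] */
            ss += p * p;                  /* box sum of squares -> sumsq[] */
        }
        sum[x]   = (int16_t)s; /* 3 * 1023 fits in int16_t */
        sumsq[x] = ss;         /* 3 * 1023^2 fits in int32_t */
    }
}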

--- a/src/arm/32/looprestoration.S
+++ b/src/arm/32/looprestoration.S
@@ -1239,3 +1239,5 @@
         pop             {r4-r11,pc}
 .purgem add5
 endfunc
+
+sgr_funcs 8
--- a/src/arm/32/looprestoration16.S
+++ b/src/arm/32/looprestoration16.S
@@ -718,3 +718,568 @@
         bgt             70b
         pop             {r4,pc}
 endfunc
+
+#define SUM_STRIDE (384+16)
+
+#include "looprestoration_tmpl.S"
+
+// void dav1d_sgr_box3_h_16bpc_neon(int32_t *sumsq, int16_t *sum,
+//                                  const pixel (*left)[4],
+//                                  const pixel *src, const ptrdiff_t stride,
+//                                  const int w, const int h,
+//                                  const enum LrEdgeFlags edges);
+function sgr_box3_h_16bpc_neon, export=1
+        push            {r4-r11,lr}
+        vpush           {q4-q7}
+        ldrd            r4,  r5,  [sp, #100]
+        ldrd            r6,  r7,  [sp, #108]
+        add             r5,  r5,  #2 // w += 2
+
+        // Set up pointers for reading/writing alternate rows
+        add             r10, r0,  #(4*SUM_STRIDE)   // sumsq
+        add             r11, r1,  #(2*SUM_STRIDE)   // sum
+        add             r12, r3,  r4                // src
+        lsl             r4,  r4,  #1
+        mov             r9,       #(2*2*SUM_STRIDE) // double sum stride
+
+        // Subtract the aligned width from the output stride.
+        // With LR_HAVE_RIGHT, align to 8, without it, align to 4.
+        tst             r7,  #2 // LR_HAVE_RIGHT
+        bne             0f
+        // !LR_HAVE_RIGHT
+        add             lr,  r5,  #3
+        bic             lr,  lr,  #3
+        b               1f
+0:
+        add             lr,  r5,  #7
+        bic             lr,  lr,  #7
+1:
+        sub             r9,  r9,  lr, lsl #1
+
+        // Store the width for the vertical loop
+        mov             r8,  r5
+
+        // Subtract the number of pixels read from the input from the stride
+        add             lr,  r5,  #14
+        bic             lr,  lr,  #7
+        sub             r4,  r4,  lr, lsl #1
+
+        // Set up the src pointers to include the left edge, for LR_HAVE_LEFT, left == NULL
+        tst             r7,  #1 // LR_HAVE_LEFT
+        beq             2f
+        // LR_HAVE_LEFT
+        cmp             r2,  #0
+        bne             0f
+        // left == NULL
+        sub             r3,  r3,  #4
+        sub             r12, r12, #4
+        b               1f
+0:      // LR_HAVE_LEFT, left != NULL
+2:      // !LR_HAVE_LEFT, increase the stride.
+        // For this case we don't read the left 2 pixels from the src pointer,
+        // but shift it as if we had done that.
+        add             r4,  r4,  #4
+
+
+1:      // Loop vertically
+        vld1.16         {q0, q1}, [r3]!
+        vld1.16         {q4, q5}, [r12]!
+
+        tst             r7,  #1 // LR_HAVE_LEFT
+        beq             0f
+        cmp             r2,  #0
+        beq             2f
+        // LR_HAVE_LEFT, left != NULL
+        vld1.16         {d5}, [r2]!
+        // Move r3/r12 back to account for the last 2 pixels we loaded earlier,
+        // which we'll shift out.
+        sub             r3,  r3,  #4
+        sub             r12, r12, #4
+        vld1.16         {d13}, [r2]!
+        vext.8          q1,  q0,  q1,  #12
+        vext.8          q0,  q2,  q0,  #12
+        vext.8          q5,  q4,  q5,  #12
+        vext.8          q4,  q6,  q4,  #12
+        b               2f
+0:
+        // !LR_HAVE_LEFT, fill q2 with the leftmost pixel
+        // and shift q0 to have 2x the first pixel at the front.
+        vdup.16         q2,  d0[0]
+        vdup.16         q6,  d8[0]
+        // Move r3/r12 back to account for the last 2 pixels we loaded before,
+        // which we shifted out.
+        sub             r3,  r3,  #4
+        sub             r12, r12, #4
+        vext.8          q1,  q0,  q1,  #12
+        vext.8          q0,  q2,  q0,  #12
+        vext.8          q5,  q4,  q5,  #12
+        vext.8          q4,  q6,  q4,  #12
+
+2:
+        tst             r7,  #2 // LR_HAVE_RIGHT
+        bne             4f
+        // If we'll need to pad the right edge, load that pixel to pad with
+        // here since we can find it pretty easily from here.
+        sub             lr,  r5,  #(2 + 16 - 2 + 1)
+        lsl             lr,  lr,  #1
+        ldrh            r11, [r3,  lr]
+        ldrh            lr,  [r12, lr]
+        // Fill q14/q15 with the right padding pixel
+        vdup.16         q14, r11
+        vdup.16         q15, lr
+        // Restore r11 after using it for a temporary value
+        add             r11, r1,  #(2*SUM_STRIDE)
+3:      // !LR_HAVE_RIGHT
+        // If we'll have to pad the right edge we need to quit early here.
+        cmp             r5,  #10
+        bge             4f   // If w >= 10, all used input pixels are valid
+        cmp             r5,  #6
+        bge             5f   // If w >= 6, we can filter 4 pixels
+        b               6f
+
+4:      // Loop horizontally
+.macro add3 w
+.if \w > 4
+        vext.8          q8,  q0,  q1,  #2
+        vext.8          q10, q4,  q5,  #2
+        vext.8          q9,  q0,  q1,  #4
+        vext.8          q11, q4,  q5,  #4
+        vadd.i16        q2,  q0,  q8
+        vadd.i16        q3,  q4,  q10
+        vadd.i16        q2,  q2,  q9
+        vadd.i16        q3,  q3,  q11
+.else
+        vext.8          d16, d0,  d1,  #2
+        vext.8          d20, d8,  d9,  #2
+        vext.8          d18, d0,  d1,  #4
+        vext.8          d22, d8,  d9,  #4
+        vadd.i16        d4,  d0,  d16
+        vadd.i16        d6,  d8,  d20
+        vadd.i16        d4,  d4,  d18
+        vadd.i16        d6,  d6,  d22
+.endif
+
+        vmull.u16       q6,  d0,  d0
+        vmlal.u16       q6,  d16, d16
+        vmlal.u16       q6,  d18, d18
+        vmull.u16       q12, d8,  d8
+        vmlal.u16       q12, d20, d20
+        vmlal.u16       q12, d22, d22
+.if \w > 4
+        vmull.u16       q7,  d1,  d1
+        vmlal.u16       q7,  d17, d17
+        vmlal.u16       q7,  d19, d19
+        vmull.u16       q13, d9,  d9
+        vmlal.u16       q13, d21, d21
+        vmlal.u16       q13, d23, d23
+.endif
+.endm
+        add3            8
+        vst1.16         {q2},       [r1,  :128]!
+        vst1.16         {q3},       [r11, :128]!
+        vst1.32         {q6,  q7},  [r0,  :128]!
+        vst1.32         {q12, q13}, [r10, :128]!
+
+        subs            r5,  r5,  #8
+        ble             9f
+        tst             r7,  #2 // LR_HAVE_RIGHT
+        vmov            q0,  q1
+        vmov            q4,  q5
+        vld1.16         {q1}, [r3]!
+        vld1.16         {q5}, [r12]!
+
+        bne             4b // If we don't need to pad, just keep summing.
+        b               3b // If we need to pad, check how many pixels we have left.
+
+5:      // Produce 4 pixels, 6 <= w < 10
+        add3            4
+        vst1.16         {d4},  [r1,  :64]!
+        vst1.16         {d6},  [r11, :64]!
+        vst1.32         {q6},  [r0,  :128]!
+        vst1.32         {q12}, [r10, :128]!
+
+        subs            r5,  r5,  #4 // 2 <= w < 6
+        vext.8          q0,  q0,  q1,  #8
+        vext.8          q4,  q4,  q5,  #8
+
+6:      // Pad the right edge and produce the last few pixels.
+        // 2 <= w < 6, 2-5 pixels valid in q0/q4
+        sub             lr,  r5,  #2
+        // lr = (pixels valid - 2)
+        adr             r11, L(box3_variable_shift_tbl)
+        ldr             lr,  [r11, lr, lsl #2]
+        add             r11, r11, lr
+        bx              r11
+
+        .align 2
+L(box3_variable_shift_tbl):
+        .word 22f - L(box3_variable_shift_tbl) + CONFIG_THUMB
+        .word 33f - L(box3_variable_shift_tbl) + CONFIG_THUMB
+        .word 44f - L(box3_variable_shift_tbl) + CONFIG_THUMB
+        .word 55f - L(box3_variable_shift_tbl) + CONFIG_THUMB
+
+        // Shift q0 right, shifting out invalid pixels,
+        // shift q0 left to the original offset, shifting in padding pixels.
+22:     // 2 pixels valid
+        vext.8          q0,  q0,  q0,  #4
+        vext.8          q4,  q4,  q4,  #4
+        vext.8          q0,  q0,  q14, #12
+        vext.8          q4,  q4,  q15, #12
+        b               88f
+33:     // 3 pixels valid
+        vext.8          q0,  q0,  q0,  #6
+        vext.8          q4,  q4,  q4,  #6
+        vext.8          q0,  q0,  q14, #10
+        vext.8          q4,  q4,  q15, #10
+        b               88f
+44:     // 4 pixels valid
+        vmov            d1,  d28
+        vmov            d9,  d30
+        b               88f
+55:     // 5 pixels valid
+        vext.8          q0,  q0,  q0,  #10
+        vext.8          q4,  q4,  q4,  #10
+        vext.8          q0,  q0,  q14, #6
+        vext.8          q4,  q4,  q15, #6
+
+88:
+        // Restore r11 after using it for a temporary value above
+        add             r11, r1,  #(2*SUM_STRIDE)
+
+        add3            4
+        subs            r5,  r5,  #4
+        vst1.16         {d4},  [r1,  :64]!
+        vst1.16         {d6},  [r11, :64]!
+        vst1.32         {q6},  [r0,  :128]!
+        vst1.32         {q12}, [r10, :128]!
+        ble             9f
+        vext.8          q0,  q0,  q0,  #8
+        vext.8          q4,  q4,  q4,  #8
+        // Only one needed pixel left, but do a normal 4 pixel
+        // addition anyway
+        add3            4
+        vst1.16         {d4},  [r1,  :64]!
+        vst1.16         {d6},  [r11, :64]!
+        vst1.32         {q6},  [r0,  :128]!
+        vst1.32         {q12}, [r10, :128]!
+
+9:
+        subs            r6,  r6,  #2
+        ble             0f
+        // Jump to the next row and loop horizontally
+        add             r0,  r0,  r9, lsl #1
+        add             r10, r10, r9, lsl #1
+        add             r1,  r1,  r9
+        add             r11, r11, r9
+        add             r3,  r3,  r4
+        add             r12, r12, r4
+        mov             r5,  r8
+        b               1b
+0:
+        vpop            {q4-q7}
+        pop             {r4-r11,pc}
+.purgem add3
+endfunc
+
+// void dav1d_sgr_box5_h_16bpc_neon(int32_t *sumsq, int16_t *sum,
+//                                  const pixel (*left)[4],
+//                                  const pixel *src, const ptrdiff_t stride,
+//                                  const int w, const int h,
+//                                  const enum LrEdgeFlags edges);
+function sgr_box5_h_16bpc_neon, export=1
+        push            {r4-r11,lr}
+        vpush           {q4-q7}
+        ldrd            r4,  r5,  [sp, #100]
+        ldrd            r6,  r7,  [sp, #108]
+        add             r5,  r5,  #2 // w += 2
+
+        // Set up pointers for reading/writing alternate rows
+        add             r10, r0,  #(4*SUM_STRIDE)   // sumsq
+        add             r11, r1,  #(2*SUM_STRIDE)   // sum
+        add             r12, r3,  r4                // src
+        lsl             r4,  r4,  #1
+        mov             r9,       #(2*2*SUM_STRIDE) // double sum stride
+
+        // Subtract the aligned width from the output stride.
+        // With LR_HAVE_RIGHT, align to 8, without it, align to 4.
+        // Subtract the number of pixels read from the input from the stride.
+        tst             r7,  #2 // LR_HAVE_RIGHT
+        bne             0f
+        // !LR_HAVE_RIGHT
+        add             lr,  r5,  #3
+        bic             lr,  lr,  #3
+        add             r8,  r5,  #13
+        b               1f
+0:
+        add             lr,  r5,  #7
+        bic             lr,  lr,  #7
+        add             r8,  r5,  #15
+1:
+        sub             r9,  r9,  lr, lsl #1
+        bic             r8,  r8,  #7
+        sub             r4,  r4,  r8, lsl #1
+
+        // Store the width for the vertical loop
+        mov             r8,  r5
+
+        // Set up the src pointers to include the left edge, for LR_HAVE_LEFT, left == NULL
+        tst             r7,  #1 // LR_HAVE_LEFT
+        beq             2f
+        // LR_HAVE_LEFT
+        cmp             r2,  #0
+        bne             0f
+        // left == NULL
+        sub             r3,  r3,  #6
+        sub             r12, r12, #6
+        b               1f
+0:      // LR_HAVE_LEFT, left != NULL
+2:      // !LR_HAVE_LEFT, increase the stride.
+        // For this case we don't read the left 3 pixels from the src pointer,
+        // but shift it as if we had done that.
+        add             r4,  r4,  #6
+
+1:      // Loop vertically
+        vld1.16         {q0, q1}, [r3]!
+        vld1.16         {q4, q5}, [r12]!
+
+        tst             r7,  #1 // LR_HAVE_LEFT
+        beq             0f
+        cmp             r2,  #0
+        beq             2f
+        // LR_HAVE_LEFT, left != NULL
+        vld1.16         {d5}, [r2]!
+        // Move r3/r12 back to account for the last 3 pixels we loaded earlier,
+        // which we'll shift out.
+        sub             r3,  r3,  #6
+        sub             r12, r12, #6
+        vld1.16         {d13}, [r2]!
+        vext.8          q1,  q0,  q1,  #10
+        vext.8          q0,  q2,  q0,  #10
+        vext.8          q5,  q4,  q5,  #10
+        vext.8          q4,  q6,  q4,  #10
+        b               2f
+0:
+        // !LR_HAVE_LEFT, fill q2 with the leftmost pixel
+        // and shift q0 to have 3x the first pixel at the front.
+        vdup.16         q2,  d0[0]
+        vdup.16         q6,  d8[0]
+        // Move r3/r12 back to account for the last 3 pixels we loaded before,
+        // which we shifted out.
+        sub             r3,  r3,  #6
+        sub             r12, r12, #6
+        vext.8          q1,  q0,  q1,  #10
+        vext.8          q0,  q2,  q0,  #10
+        vext.8          q5,  q4,  q5,  #10
+        vext.8          q4,  q6,  q4,  #10
+
+2:
+        tst             r7,  #2 // LR_HAVE_RIGHT
+        bne             4f
+        // If we'll need to pad the right edge, load that pixel to pad with
+        // here since we can find it pretty easily from here.
+        sub             lr,  r5,  #(2 + 16 - 3 + 1)
+        lsl             lr,  lr,  #1
+        ldrh            r11, [r3,  lr]
+        ldrh            lr,  [r12, lr]
+        // Fill q14/q15 with the right padding pixel
+        vdup.16         q14, r11
+        vdup.16         q15, lr
+        // Restore r11 after using it for a temporary value
+        add             r11, r1,  #(2*SUM_STRIDE)
+3:      // !LR_HAVE_RIGHT
+        // If we'll have to pad the right edge we need to quit early here.
+        cmp             r5,  #11
+        bge             4f   // If w >= 11, all used input pixels are valid
+        cmp             r5,  #7
+        bge             5f   // If w >= 7, we can produce 4 pixels
+        b               6f
+
+4:      // Loop horizontally
+.macro add5 w
+.if \w > 4
+        vext.8          q8,  q0,  q1,  #2
+        vext.8          q10, q4,  q5,  #2
+        vext.8          q9,  q0,  q1,  #4
+        vext.8          q11, q4,  q5,  #4
+        vadd.i16        q2,  q0,  q8
+        vadd.i16        q3,  q4,  q10
+        vadd.i16        q2,  q2,  q9
+        vadd.i16        q3,  q3,  q11
+.else
+        vext.8          d16, d0,  d1,  #2
+        vext.8          d20, d8,  d9,  #2
+        vext.8          d18, d0,  d1,  #4
+        vext.8          d22, d8,  d9,  #4
+        vadd.i16        d4,  d0,  d16
+        vadd.i16        d6,  d8,  d20
+        vadd.i16        d4,  d4,  d18
+        vadd.i16        d6,  d6,  d22
+.endif
+
+        vmull.u16       q6,  d0,  d0
+        vmlal.u16       q6,  d16, d16
+        vmlal.u16       q6,  d18, d18
+        vmull.u16       q12, d8,  d8
+        vmlal.u16       q12, d20, d20
+        vmlal.u16       q12, d22, d22
+.if \w > 4
+        vmull.u16       q7,  d1,  d1
+        vmlal.u16       q7,  d17, d17
+        vmlal.u16       q7,  d19, d19
+        vmull.u16       q13, d9,  d9
+        vmlal.u16       q13, d21, d21
+        vmlal.u16       q13, d23, d23
+.endif
+
+.if \w > 4
+        vext.8          q8,  q0,  q1,  #6
+        vext.8          q10, q4,  q5,  #6
+        vext.8          q9,  q0,  q1,  #8
+        vext.8          q11, q4,  q5,  #8
+        vadd.i16        q2,  q2,  q8
+        vadd.i16        q3,  q3,  q10
+        vadd.i16        q2,  q2,  q9
+        vadd.i16        q3,  q3,  q11
+.else
+        vext.8          d16, d0,  d1,  #6
+        // d18 would be equal to d1; using d1 instead
+        vext.8          d20, d8,  d9,  #6
+        // d22 would be equal to d9; using d9 instead
+        vadd.i16        d4,  d4,  d16
+        vadd.i16        d6,  d6,  d20
+        vadd.i16        d4,  d4,  d1
+        vadd.i16        d6,  d6,  d9
+.endif
+
+        vmlal.u16       q6,  d16, d16
+        vmlal.u16       q6,  d1,  d1
+        vmlal.u16       q12, d20, d20
+        vmlal.u16       q12, d9,  d9
+.if \w > 4
+        vmlal.u16       q7,  d17, d17
+        vmlal.u16       q7,  d19, d19
+        vmlal.u16       q13, d21, d21
+        vmlal.u16       q13, d23, d23
+.endif
+.endm
+        add5            8
+        vst1.16         {q2},       [r1,  :128]!
+        vst1.16         {q3},       [r11, :128]!
+        vst1.32         {q6,  q7},  [r0,  :128]!
+        vst1.32         {q12, q13}, [r10, :128]!
+
+        subs            r5,  r5,  #8
+        ble             9f
+        tst             r7,  #2 // LR_HAVE_RIGHT
+        vmov            q0,  q1
+        vmov            q4,  q5
+        vld1.16         {q1}, [r3]!
+        vld1.16         {q5}, [r12]!
+        bne             4b // If we don't need to pad, just keep summing.
+        b               3b // If we need to pad, check how many pixels we have left.
+
+5:      // Produce 4 pixels, 7 <= w < 11
+        add5            4
+        vst1.16         {d4},  [r1,  :64]!
+        vst1.16         {d6},  [r11, :64]!
+        vst1.32         {q6},  [r0,  :128]!
+        vst1.32         {q12}, [r10, :128]!
+
+        subs            r5,  r5,  #4 // 3 <= w < 7
+        vext.8          q0,  q0,  q1,  #8
+        vext.8          q4,  q4,  q5,  #8
+
+6:      // Pad the right edge and produce the last few pixels.
+        // w < 7, w+1 pixels valid in q0/q4
+        sub             lr,  r5,  #1
+        // lr = pixels valid - 2
+        adr             r11, L(box5_variable_shift_tbl)
+        ldr             lr,  [r11, lr, lsl #2]
+        vmov            q1,  q14
+        vmov            q5,  q15
+        add             r11, r11, lr
+        bx              r11
+
+        .align 2
+L(box5_variable_shift_tbl):
+        .word 22f - L(box5_variable_shift_tbl) + CONFIG_THUMB
+        .word 33f - L(box5_variable_shift_tbl) + CONFIG_THUMB
+        .word 44f - L(box5_variable_shift_tbl) + CONFIG_THUMB
+        .word 55f - L(box5_variable_shift_tbl) + CONFIG_THUMB
+        .word 66f - L(box5_variable_shift_tbl) + CONFIG_THUMB
+        .word 77f - L(box5_variable_shift_tbl) + CONFIG_THUMB
+
+        // Shift q0 right, shifting out invalid pixels,
+        // shift q0 left to the original offset, shifting in padding pixels.
+22:     // 2 pixels valid
+        vext.8          q0,  q0,  q0,  #4
+        vext.8          q4,  q4,  q4,  #4
+        vext.8          q0,  q0,  q14, #12
+        vext.8          q4,  q4,  q15, #12
+        b               88f
+33:     // 3 pixels valid
+        vext.8          q0,  q0,  q0,  #6
+        vext.8          q4,  q4,  q4,  #6
+        vext.8          q0,  q0,  q14, #10
+        vext.8          q4,  q4,  q15, #10
+        b               88f
+44:     // 4 pixels valid
+        vmov            d1,  d28
+        vmov            d9,  d30
+        b               88f
+55:     // 5 pixels valid
+        vext.8          q0,  q0,  q0,  #10
+        vext.8          q4,  q4,  q4,  #10
+        vext.8          q0,  q0,  q14, #6
+        vext.8          q4,  q4,  q15, #6
+        b               88f
+66:     // 6 pixels valid
+        vext.8          q0,  q0,  q0,  #12
+        vext.8          q4,  q4,  q4,  #12
+        vext.8          q0,  q0,  q14, #4
+        vext.8          q4,  q4,  q15, #4
+        b               88f
+77:     // 7 pixels valid
+        vext.8          q0,  q0,  q0,  #14
+        vext.8          q4,  q4,  q4,  #14
+        vext.8          q0,  q0,  q14, #2
+        vext.8          q4,  q4,  q15, #2
+
+88:
+        // Restore r11 after using it for a temporary value above
+        add             r11, r1,  #(2*SUM_STRIDE)
+
+        add5            4
+        subs            r5,  r5,  #4
+        vst1.16         {d4},  [r1,  :64]!
+        vst1.16         {d6},  [r11, :64]!
+        vst1.32         {q6},  [r0,  :128]!
+        vst1.32         {q12}, [r10, :128]!
+        ble             9f
+        vext.8          q0,  q0,  q1,  #8
+        vext.8          q4,  q4,  q5,  #8
+        add5            4
+        vst1.16         {d4},  [r1,  :64]!
+        vst1.16         {d6},  [r11, :64]!
+        vst1.32         {q6},  [r0,  :128]!
+        vst1.32         {q12}, [r10, :128]!
+
+9:
+        subs            r6,  r6,  #2
+        ble             0f
+        // Jump to the next row and loop horizontally
+        add             r0,  r0,  r9, lsl #1
+        add             r10, r10, r9, lsl #1
+        add             r1,  r1,  r9
+        add             r11, r11, r9
+        add             r3,  r3,  r4
+        add             r12, r12, r4
+        mov             r5,  r8
+        b               1b
+0:
+        vpop            {q4-q7}
+        pop             {r4-r11,pc}
+.purgem add5
+endfunc
+
+sgr_funcs 16
--- a/src/arm/32/looprestoration_common.S
+++ b/src/arm/32/looprestoration_common.S
@@ -336,14 +336,17 @@
 endfunc
 
 // void dav1d_sgr_calc_ab1_neon(int32_t *a, int16_t *b,
-//                              const int w, const int h, const int strength);
+//                              const int w, const int h, const int strength,
+//                              const int bitdepth_max);
 // void dav1d_sgr_calc_ab2_neon(int32_t *a, int16_t *b,
-//                              const int w, const int h, const int strength);
+//                              const int w, const int h, const int strength,
+//                              const int bitdepth_max);
 function sgr_calc_ab1_neon, export=1
-        push            {r4-r5,lr}
+        push            {r4-r7,lr}
         vpush           {q4-q7}
-        ldr             r4,  [sp, #76]
+        ldrd            r4,  r5,  [sp, #84]
         add             r3,  r3,  #2   // h += 2
+        clz             r6,  r5
         vmov.i32        q15, #9        // n
         movw            r5,  #455
         mov             lr,  #SUM_STRIDE
@@ -351,10 +354,11 @@
 endfunc
 
 function sgr_calc_ab2_neon, export=1
-        push            {r4-r5,lr}
+        push            {r4-r7,lr}
         vpush           {q4-q7}
-        ldr             r4,  [sp, #76]
+        ldrd            r4,  r5,  [sp, #84]
         add             r3,  r3,  #3   // h += 3
+        clz             r6,  r5
         asr             r3,  r3,  #1   // h /= 2
         vmov.i32        q15, #25       // n
         mov             r5,  #164
@@ -363,7 +367,9 @@
 
 function sgr_calc_ab_neon
         movrel          r12, X(sgr_x_by_x)
+        sub             r6,  r6,  #24  // -bitdepth_min_8
         vld1.8          {q8, q9}, [r12, :128]!
+        add             r7,  r6,  r6   // -2*bitdepth_min_8
         vmov.i8         q11, #5
         vmov.i8         d10, #55       // idx of last 5
         vld1.8          {q10},    [r12, :128]
@@ -376,9 +382,7 @@
         add             r12, r2,  #7
         bic             r12, r12, #7   // aligned w
         sub             r12, lr,  r12  // increment between rows
-        vmov.i16        q13, #256
         vdup.32         q12, r4
-        vdup.32         q14, r5        // one_by_x
         sub             r0,  r0,  #(4*(SUM_STRIDE))
         sub             r1,  r1,  #(2*(SUM_STRIDE))
         mov             r4,  r2        // backup of w
@@ -386,13 +390,18 @@
         vsub.i8         q9,  q9,  q11
         vsub.i8         q10, q10, q11
 1:
-        subs            r2,  r2,  #8
         vld1.32         {q0, q1}, [r0, :128] // a
         vld1.16         {q2},     [r1, :128] // b
+        vdup.32         q13, r7        // -2*bitdepth_min_8
+        vdup.16         q14, r6        // -bitdepth_min_8
+        subs            r2,  r2,  #8
+        vrshl.s32       q0,  q0,  q13
+        vrshl.s32       q1,  q1,  q13
+        vrshl.s16       q4,  q2,  q14
         vmul.i32        q0,  q0,  q15  // a * n
         vmul.i32        q1,  q1,  q15  // a * n
-        vmull.u16       q3,  d4,  d4   // b * b
-        vmull.u16       q4,  d5,  d5   // b * b
+        vmull.u16       q3,  d8,  d8   // b * b
+        vmull.u16       q4,  d9,  d9   // b * b
         vqsub.u32       q0,  q0,  q3   // imax(a * n - b * b, 0)
         vqsub.u32       q1,  q1,  q4   // imax(a * n - b * b, 0)
         vmul.i32        q0,  q0,  q12  // p * s
@@ -417,6 +426,9 @@
         vadd.i8         d1,  d1,  d2
         vmovl.u8        q0,  d1        // x
 
+        vmov.i16        q13, #256
+        vdup.32         q14, r5        // one_by_x
+
         vmull.u16       q1,  d0,  d4   // x * BB[i]
         vmull.u16       q2,  d1,  d5   // x * BB[i]
         vmul.i32        q1,  q1,  q14  // x * BB[i] * sgr_one_by_x
@@ -437,5 +449,5 @@
         b               1b
 0:
         vpop            {q4-q7}
-        pop             {r4-r5,pc}
+        pop             {r4-r7,pc}
 endfunc
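
The sgr_calc_ab changes above rescale a (the box sum of squares) and b (the box sum)
to the 8 bpc value range before the sgr_x_by_x table lookup: rounding right shifts
by 2*bitdepth_min_8 and bitdepth_min_8, with -bitdepth_min_8 derived from
clz(bitdepth_max). A scalar sketch of that normalization (the function name and the
use of __builtin_clz are illustrative, not dav1d code):

#include <stdint.h>

/* Per-element equivalent of the vrshl.s32/vrshl.s16 by negative shift amounts.
 * For 10 bpc, bitdepth_max = 1023, clz(1023) = 22, so bitdepth_min_8 = 2;
 * at 8 bpc both shifts are by 0 and the values pass through unchanged. */
static inline void calc_ab_normalize(int32_t *a, int16_t *b, int bitdepth_max)
{
    const int bitdepth_min_8 = 24 - __builtin_clz(bitdepth_max);
    *a = (*a + ((1 << (2 * bitdepth_min_8)) >> 1)) >> (2 * bitdepth_min_8);
    *b = (int16_t)((*b + ((1 << bitdepth_min_8) >> 1)) >> bitdepth_min_8);
}
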
--- a/src/arm/32/looprestoration_tmpl.S
+++ b/src/arm/32/looprestoration_tmpl.S
@@ -29,11 +29,12 @@
 
 #define FILTER_OUT_STRIDE 384
 
-// void dav1d_sgr_finish_filter1_8bpc_neon(int16_t *tmp,
+.macro sgr_funcs bpc
+// void dav1d_sgr_finish_filter1_Xbpc_neon(int16_t *tmp,
 //                                         const pixel *src, const ptrdiff_t stride,
 //                                         const int32_t *a, const int16_t *b,
 //                                         const int w, const int h);
-function sgr_finish_filter1_8bpc_neon, export=1
+function sgr_finish_filter1_\bpc\()bpc_neon, export=1
         push            {r4-r11,lr}
         vpush           {q4-q7}
         ldrd            r4,  r5,  [sp, #100]
@@ -46,7 +47,11 @@
         mov             r12, #FILTER_OUT_STRIDE
         add             lr,  r5,  #3
         bic             lr,  lr,  #3 // Aligned width
+.if \bpc == 8
         sub             r2,  r2,  lr
+.else
+        sub             r2,  r2,  lr, lsl #1
+.endif
         sub             r12, r12, lr
         sub             r11, r11, lr
         sub             r11, r11, #4 // We read 4 extra elements from both a and b
@@ -90,12 +95,18 @@
         vadd.i32        q3,  q3,  q5
         vext.8          q7,  q12, q13, #4  // +stride
         vext.8          q10, q12, q13, #8  // +1+stride
+.if \bpc == 8
         vld1.32         {d24[0]}, [r1, :32]! // src
+.else
+        vld1.16         {d24}, [r1, :64]!    // src
+.endif
         vadd.i32        q3,  q3,  q7       // +stride
         vadd.i32        q8,  q8,  q10      // +1+stride
         vshl.i32        q3,  q3,  #2
         vmla.i32        q3,  q8,  q15      // * 3 -> b
+.if \bpc == 8
         vmovl.u8        q12, d24           // src
+.endif
         vmov            d0,  d1
         vmlal.u16       q3,  d2,  d24      // b + a * src
         vmov            d2,  d3
@@ -133,11 +144,11 @@
         pop             {r4-r11,pc}
 endfunc
 
-// void dav1d_sgr_finish_filter2_8bpc_neon(int16_t *tmp,
+// void dav1d_sgr_finish_filter2_Xbpc_neon(int16_t *tmp,
 //                                         const pixel *src, const ptrdiff_t stride,
 //                                         const int32_t *a, const int16_t *b,
 //                                         const int w, const int h);
-function sgr_finish_filter2_8bpc_neon, export=1
+function sgr_finish_filter2_\bpc\()bpc_neon, export=1
         push            {r4-r11,lr}
         vpush           {q4-q7}
         ldrd            r4,  r5,  [sp, #100]
@@ -150,7 +161,11 @@
         mov             r10, #FILTER_OUT_STRIDE
         add             r11, r5,  #7
         bic             r11, r11, #7 // Aligned width
+.if \bpc == 8
         sub             r2,  r2,  r11
+.else
+        sub             r2,  r2,  r11, lsl #1
+.endif
         sub             r10, r10, r11
         sub             r9,  r9,  r11
         sub             r9,  r9,  #4 // We read 4 extra elements from a
@@ -195,7 +210,11 @@
         vext.8          q8,  q11, q12, #4  // +stride
         vext.8          q11, q12, q13, #4
 
+.if \bpc == 8
         vld1.8          {d4}, [r1, :64]!
+.else
+        vld1.8          {q2}, [r1, :128]!
+.endif
 
         vmov.i32        q14, #5
         vmov.i32        q15, #6
@@ -207,7 +226,9 @@
         vmul.i32        q5,  q5,  q14      // * 5
         vmla.i32        q5,  q7,  q15      // * 6
 
+.if \bpc == 8
         vmovl.u8        q2,  d4
+.endif
         vmlal.u16       q4,  d0,  d4       // b + a * src
         vmlal.u16       q5,  d1,  d5       // b + a * src
         vmov            q0,  q1
@@ -255,10 +276,16 @@
         vext.8          q7,  q9,  q10, #8
         vmul.i16        q2,  q2,  q13      // * 6
         vmla.i16        q2,  q0,  q12      // * 5 -> a
+.if \bpc == 8
         vld1.8          {d22}, [r1, :64]!
+.else
+        vld1.16         {q11}, [r1, :128]!
+.endif
         vadd.i32        q8,  q8,  q6       // -1, +1
         vadd.i32        q9,  q9,  q7
+.if \bpc == 8
         vmovl.u8        q11, d22
+.endif
         vmul.i32        q4,  q4,  q15      // * 6
         vmla.i32        q4,  q8,  q14      // * 5 -> b
         vmul.i32        q5,  q5,  q15      // * 6
@@ -293,16 +320,22 @@
         pop             {r4-r11,pc}
 endfunc
 
-// void dav1d_sgr_weighted1_8bpc_neon(pixel *dst, const ptrdiff_t dst_stride,
+// void dav1d_sgr_weighted1_Xbpc_neon(pixel *dst, const ptrdiff_t dst_stride,
 //                                    const pixel *src, const ptrdiff_t src_stride,
 //                                    const int16_t *t1, const int w, const int h,
-//                                    const int wt);
-function sgr_weighted1_8bpc_neon, export=1
+//                                    const int wt, const int bitdepth_max);
+function sgr_weighted1_\bpc\()bpc_neon, export=1
         push            {r4-r9,lr}
         ldrd            r4,  r5,  [sp, #28]
         ldrd            r6,  r7,  [sp, #36]
+.if \bpc == 16
+        ldr             r8,  [sp, #44]
+.endif
         vdup.16         d31, r7
         cmp             r6,  #2
+.if \bpc == 16
+        vdup.16         q14, r8
+.endif
         add             r9,  r0,  r1
         add             r12, r2,  r3
         add             lr,  r4,  #2*FILTER_OUT_STRIDE
@@ -311,19 +344,34 @@
         lsl             r3,  r3,  #1
         add             r8,  r5,  #7
         bic             r8,  r8,  #7 // Aligned width
+.if \bpc == 8
         sub             r1,  r1,  r8
         sub             r3,  r3,  r8
+.else
+        sub             r1,  r1,  r8, lsl #1
+        sub             r3,  r3,  r8, lsl #1
+.endif
         sub             r7,  r7,  r8, lsl #1
         mov             r8,  r5
         blt             2f
 1:
+.if \bpc == 8
         vld1.8          {d0},  [r2,  :64]!
         vld1.8          {d16}, [r12, :64]!
+.else
+        vld1.16         {q0},  [r2,  :128]!
+        vld1.16         {q8},  [r12, :128]!
+.endif
         vld1.16         {q1},  [r4,  :128]!
         vld1.16         {q9},  [lr,  :128]!
         subs            r5,  r5,  #8
+.if \bpc == 8
         vshll.u8        q0,  d0,  #4     // u
         vshll.u8        q8,  d16, #4     // u
+.else
+        vshl.i16        q0,  q0,  #4     // u
+        vshl.i16        q8,  q8,  #4     // u
+.endif
         vsub.i16        q1,  q1,  q0     // t1 - u
         vsub.i16        q9,  q9,  q8     // t1 - u
         vshll.u16       q2,  d0,  #7     // u << 7
@@ -334,6 +382,7 @@
         vmlal.s16       q3,  d3,  d31    // v
         vmlal.s16       q10, d18, d31    // v
         vmlal.s16       q11, d19, d31    // v
+.if \bpc == 8
         vrshrn.i32      d4,  q2,  #11
         vrshrn.i32      d5,  q3,  #11
         vrshrn.i32      d20, q10, #11
@@ -342,6 +391,16 @@
         vqmovun.s16     d20, q10
         vst1.8          {d4},  [r0]!
         vst1.8          {d20}, [r9]!
+.else
+        vqrshrun.s32    d4,  q2,  #11
+        vqrshrun.s32    d5,  q3,  #11
+        vqrshrun.s32    d20, q10, #11
+        vqrshrun.s32    d21, q11, #11
+        vmin.u16        q2,  q2,  q14
+        vmin.u16        q10, q10, q14
+        vst1.16         {q2},  [r0]!
+        vst1.16         {q10}, [r9]!
+.endif
         bgt             1b
 
         sub             r6,  r6,  #2
@@ -358,34 +417,53 @@
         b               1b
 
 2:
+.if \bpc == 8
         vld1.8          {d0}, [r2, :64]!
+.else
+        vld1.16         {q0}, [r2, :128]!
+.endif
         vld1.16         {q1}, [r4, :128]!
         subs            r5,  r5,  #8
+.if \bpc == 8
         vshll.u8        q0,  d0,  #4     // u
+.else
+        vshl.i16        q0,  q0,  #4     // u
+.endif
         vsub.i16        q1,  q1,  q0     // t1 - u
         vshll.u16       q2,  d0,  #7     // u << 7
         vshll.u16       q3,  d1,  #7     // u << 7
         vmlal.s16       q2,  d2,  d31    // v
         vmlal.s16       q3,  d3,  d31    // v
+.if \bpc == 8
         vrshrn.i32      d4,  q2,  #11
         vrshrn.i32      d5,  q3,  #11
         vqmovun.s16     d2,  q2
         vst1.8          {d2}, [r0]!
+.else
+        vqrshrun.s32    d4,  q2,  #11
+        vqrshrun.s32    d5,  q3,  #11
+        vmin.u16        q2,  q2,  q14
+        vst1.16         {q2}, [r0]!
+.endif
         bgt             2b
 0:
         pop             {r4-r9,pc}
 endfunc
 
-// void dav1d_sgr_weighted2_8bpc_neon(pixel *dst, const ptrdiff_t stride,
+// void dav1d_sgr_weighted2_Xbpc_neon(pixel *dst, const ptrdiff_t stride,
 //                                    const pixel *src, const ptrdiff_t src_stride,
 //                                    const int16_t *t1, const int16_t *t2,
 //                                    const int w, const int h,
-//                                    const int16_t wt[2]);
-function sgr_weighted2_8bpc_neon, export=1
+//                                    const int16_t wt[2], const int bitdepth_max);
+function sgr_weighted2_\bpc\()bpc_neon, export=1
         push            {r4-r11,lr}
         ldrd            r4,  r5,  [sp, #36]
         ldrd            r6,  r7,  [sp, #44]
+.if \bpc == 8
         ldr             r8,  [sp, #52]
+.else
+        ldrd            r8,  r9,  [sp, #52]
+.endif
         cmp             r7,  #2
         add             r10, r0,  r1
         add             r11, r2,  r3
@@ -392,26 +470,44 @@
         add             r12, r4,  #2*FILTER_OUT_STRIDE
         add             lr,  r5,  #2*FILTER_OUT_STRIDE
         vld2.16         {d30[], d31[]}, [r8] // wt[0], wt[1]
+.if \bpc == 16
+        vdup.16         q14, r9
+.endif
         mov             r8,  #4*FILTER_OUT_STRIDE
         lsl             r1,  r1,  #1
         lsl             r3,  r3,  #1
         add             r9,  r6,  #7
         bic             r9,  r9,  #7 // Aligned width
+.if \bpc == 8
         sub             r1,  r1,  r9
         sub             r3,  r3,  r9
+.else
+        sub             r1,  r1,  r9, lsl #1
+        sub             r3,  r3,  r9, lsl #1
+.endif
         sub             r8,  r8,  r9, lsl #1
         mov             r9,  r6
         blt             2f
 1:
+.if \bpc == 8
         vld1.8          {d0},  [r2,  :64]!
         vld1.8          {d16}, [r11, :64]!
+.else
+        vld1.16         {q0},  [r2,  :128]!
+        vld1.16         {q8},  [r11, :128]!
+.endif
         vld1.16         {q1},  [r4,  :128]!
         vld1.16         {q9},  [r12, :128]!
         vld1.16         {q2},  [r5,  :128]!
         vld1.16         {q10}, [lr,  :128]!
         subs            r6,  r6,  #8
+.if \bpc == 8
         vshll.u8        q0,  d0,  #4     // u
         vshll.u8        q8,  d16, #4     // u
+.else
+        vshl.i16        q0,  q0,  #4     // u
+        vshl.i16        q8,  q8,  #4     // u
+.endif
         vsub.i16        q1,  q1,  q0     // t1 - u
         vsub.i16        q2,  q2,  q0     // t2 - u
         vsub.i16        q9,  q9,  q8     // t1 - u
@@ -428,6 +524,7 @@
         vmlal.s16       q11, d20, d31    // wt[1] * (t2 - u)
         vmlal.s16       q8,  d19, d30    // wt[0] * (t1 - u)
         vmlal.s16       q8,  d21, d31    // wt[1] * (t2 - u)
+.if \bpc == 8
         vrshrn.i32      d6,  q3,  #11
         vrshrn.i32      d7,  q0,  #11
         vrshrn.i32      d22, q11, #11
@@ -436,6 +533,16 @@
         vqmovun.s16     d22, q11
         vst1.8          {d6},  [r0]!
         vst1.8          {d22}, [r10]!
+.else
+        vqrshrun.s32    d6,  q3,  #11
+        vqrshrun.s32    d7,  q0,  #11
+        vqrshrun.s32    d22, q11, #11
+        vqrshrun.s32    d23, q8,  #11
+        vmin.u16        q3,  q3,  q14
+        vmin.u16        q11, q11, q14
+        vst1.16         {q3},  [r0]!
+        vst1.16         {q11}, [r10]!
+.endif
         bgt             1b
 
         subs            r7,  r7,  #2
@@ -454,11 +561,19 @@
         b               1b
 
 2:
+.if \bpc == 8
         vld1.8          {d0}, [r2, :64]!
+.else
+        vld1.16         {q0}, [r2, :128]!
+.endif
         vld1.16         {q1}, [r4, :128]!
         vld1.16         {q2}, [r5, :128]!
         subs            r6,  r6,  #8
+.if \bpc == 8
         vshll.u8        q0,  d0,  #4     // u
+.else
+        vshl.i16        q0,  q0,  #4     // u
+.endif
         vsub.i16        q1,  q1,  q0     // t1 - u
         vsub.i16        q2,  q2,  q0     // t2 - u
         vshll.u16       q3,  d0,  #7     // u << 7
@@ -467,11 +582,19 @@
         vmlal.s16       q3,  d4,  d31    // wt[1] * (t2 - u)
         vmlal.s16       q0,  d3,  d30    // wt[0] * (t1 - u)
         vmlal.s16       q0,  d5,  d31    // wt[1] * (t2 - u)
+.if \bpc == 8
         vrshrn.i32      d6,  q3,  #11
         vrshrn.i32      d7,  q0,  #11
         vqmovun.s16     d6,  q3
         vst1.8          {d6}, [r0]!
+.else
+        vqrshrun.s32    d6,  q3,  #11
+        vqrshrun.s32    d7,  q0,  #11
+        vmin.u16        q3,  q3,  q14
+        vst1.16         {q3}, [r0]!
+.endif
         bgt             1b
 0:
         pop             {r4-r11,pc}
 endfunc
+.endm
--- a/src/arm/looprestoration_init_tmpl.c
+++ b/src/arm/looprestoration_init_tmpl.c
@@ -104,7 +104,6 @@
     }
 }
 
-#if BITDEPTH == 8 || ARCH_AARCH64
 void BF(dav1d_sgr_box3_h, neon)(int32_t *sumsq, int16_t *sum,
                                 const pixel (*left)[4],
                                 const pixel *src, const ptrdiff_t stride,
@@ -283,7 +282,6 @@
         }
     }
 }
-#endif // BITDEPTH == 8
 
 COLD void bitfn(dav1d_loop_restoration_dsp_init_arm)(Dav1dLoopRestorationDSPContext *const c, int bpc) {
     const unsigned flags = dav1d_get_cpu_flags();
@@ -291,8 +289,6 @@
     if (!(flags & DAV1D_ARM_CPU_FLAG_NEON)) return;
 
     c->wiener = wiener_filter_neon;
-#if BITDEPTH == 8 || ARCH_AARCH64
     if (bpc <= 10)
         c->selfguided = sgr_filter_neon;
-#endif
 }