shithub: libvpx

Download patch

ref: 7ef83148cfbfd0163fa22a0550d128d935fe2bad
parent: 12a14913947b510514746389319b49a188a53579
parent: c1f77a3689a6cf5e95e1c1ae35d76f4f171f5ef3
author: James Zern <jzern@google.com>
date: Wed May 5 15:57:10 EDT 2021

Merge "Implement horizontal convolution using Neon SDOT instruction"

--- a/vpx_dsp/arm/mem_neon.h
+++ b/vpx_dsp/arm/mem_neon.h
@@ -19,6 +19,24 @@
 #include "vpx/vpx_integer.h"
 #include "vpx_dsp/vpx_dsp_common.h"
 
+// Fallbacks for xN loads missing from GCC < 8 (x2) / < 9 (x3), or 32-bit Arm.
+#if defined(__GNUC__) && !defined(__clang__)  // clang assumed to have them.
+#if __GNUC__ < 8 || defined(__arm__)
+static INLINE uint8x16x2_t vld1q_u8_x2(uint8_t const *ptr) {  // 32 bytes.
+  uint8x16x2_t res = { { vld1q_u8(ptr + 0 * 16), vld1q_u8(ptr + 1 * 16) } };
+  return res;
+}
+#endif  // __GNUC__ < 8 || defined(__arm__)
+
+#if __GNUC__ < 9 || defined(__arm__)
+static INLINE uint8x16x3_t vld1q_u8_x3(uint8_t const *ptr) {  // 48 bytes.
+  uint8x16x3_t res = { { vld1q_u8(ptr + 0 * 16), vld1q_u8(ptr + 1 * 16),
+                         vld1q_u8(ptr + 2 * 16) } };
+  return res;
+}
+#endif  // __GNUC__ < 9 || defined(__arm__)
+#endif  // defined(__GNUC__) && !defined(__clang__)
+
 static INLINE int16x4_t create_s16x4_neon(const int16_t c0, const int16_t c1,
                                           const int16_t c2, const int16_t c3) {
   return vcreate_s16((uint16_t)c0 | ((uint32_t)c1 << 16) |
--- a/vpx_dsp/arm/vpx_convolve8_neon.c
+++ b/vpx_dsp/arm/vpx_convolve8_neon.c
@@ -14,6 +14,7 @@
 #include "./vpx_config.h"
 #include "./vpx_dsp_rtcd.h"
 #include "vpx/vpx_integer.h"
+#include "vpx_dsp/arm/mem_neon.h"
 #include "vpx_dsp/arm/transpose_neon.h"
 #include "vpx_dsp/arm/vpx_convolve8_neon.h"
 #include "vpx_ports/mem.h"
@@ -52,11 +53,117 @@
   vst1_u8(s, s7);
 }
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD) && \
+    (__ARM_FEATURE_DOTPROD == 1)
+
+DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
+  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,   // starts 0..3
+  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,  // starts 4..7
+  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14   // starts 8..11
+};  // vqtbl1q_s8 indices: each row = four overlapping 4-byte windows.
+
 void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
                               uint8_t *dst, ptrdiff_t dst_stride,
                               const InterpKernel *filter, int x0_q4,
                               int x_step_q4, int y0_q4, int y_step_q4, int w,
                               int h) {
+  const int8x8_t filters = vmovn_s16(vld1q_s16(filter[x0_q4]));
+  const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter[x0_q4]), 128);
+  const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp));
+  const uint8x16_t range_limit = vdupq_n_u8(128);  // int8 bias for SDOT.
+  uint8x16_t s0, s1, s2, s3;
+  // correction = 128 * sum(taps): adds back what the -128 sample bias removed.
+  assert(!((intptr_t)dst & 3));  // For the (uint32_t *)dst row stores.
+  assert(!(dst_stride & 3));     // Keeps every dst row 4-byte aligned.
+  assert(x_step_q4 == 16);       // Unscaled: src advances 1 px per output.
+  // NOTE(review): unchecked: taps must fit int8 (vmovn_s16 would wrap 128).
+  (void)x_step_q4;
+  (void)y0_q4;
+  (void)y_step_q4;
+
+  src -= 3;  // The 8-tap window for output x spans src[x - 3 .. x + 4].
+
+  if (w == 4) {  // 4 outputs per row, 4 rows per iteration.
+    const uint8x16x2_t permute_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+    do {
+      int32x4_t t0, t1, t2, t3;
+      int16x8_t t01, t23;
+      uint8x8_t d01, d23;
+      // NOTE(review): loads 16 B/row but uses 11; confirm over-read is safe.
+      s0 = vld1q_u8(src);
+      src += src_stride;
+      s1 = vld1q_u8(src);
+      src += src_stride;
+      s2 = vld1q_u8(src);
+      src += src_stride;
+      s3 = vld1q_u8(src);
+      src += src_stride;
+
+      t0 = convolve8_4_dot(s0, filters, correction, range_limit, permute_tbl);
+      t1 = convolve8_4_dot(s1, filters, correction, range_limit, permute_tbl);
+      t2 = convolve8_4_dot(s2, filters, correction, range_limit, permute_tbl);
+      t3 = convolve8_4_dot(s3, filters, correction, range_limit, permute_tbl);
+      // Saturate to int16, then round, shift right by 7, saturate to uint8.
+      t01 = vcombine_s16(vqmovn_s32(t0), vqmovn_s32(t1));
+      t23 = vcombine_s16(vqmovn_s32(t2), vqmovn_s32(t3));
+      d01 = vqrshrun_n_s16(t01, 7);
+      d23 = vqrshrun_n_s16(t23, 7);
+      // One 4-byte store per row: d01 holds rows 0-1, d23 holds rows 2-3.
+      vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(d01), 0);
+      dst += dst_stride;
+      vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(d01), 1);
+      dst += dst_stride;
+      vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(d23), 0);
+      dst += dst_stride;
+      vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(d23), 1);
+      dst += dst_stride;
+      h -= 4;  // NOTE(review): assumes h is a multiple of 4.
+    } while (h > 0);
+  } else {  // NOTE(review): assumes w is a multiple of 8.
+    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
+    const uint8_t *s;
+    uint8_t *d;
+    int width;
+    uint8x8_t d0, d1, d2, d3;
+    // Walk 8-wide x 4-tall tiles left to right, then move down 4 rows.
+    do {
+      width = w;
+      s = src;
+      d = dst;
+      do {
+        s0 = vld1q_u8(s + 0 * src_stride);
+        s1 = vld1q_u8(s + 1 * src_stride);
+        s2 = vld1q_u8(s + 2 * src_stride);
+        s3 = vld1q_u8(s + 3 * src_stride);
+        // 16 bytes loaded per row; the 8 outputs use the first 15.
+        d0 = convolve8_8_dot(s0, filters, correction, range_limit, permute_tbl);
+        d1 = convolve8_8_dot(s1, filters, correction, range_limit, permute_tbl);
+        d2 = convolve8_8_dot(s2, filters, correction, range_limit, permute_tbl);
+        d3 = convolve8_8_dot(s3, filters, correction, range_limit, permute_tbl);
+
+        vst1_u8(d + 0 * dst_stride, d0);
+        vst1_u8(d + 1 * dst_stride, d1);
+        vst1_u8(d + 2 * dst_stride, d2);
+        vst1_u8(d + 3 * dst_stride, d3);
+
+        s += 8;
+        d += 8;
+        width -= 8;
+      } while (width > 0);
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;
+    } while (h > 0);
+  }
+}
+
+#else
+
+void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+                              uint8_t *dst, ptrdiff_t dst_stride,
+                              const InterpKernel *filter, int x0_q4,
+                              int x_step_q4, int y0_q4, int y_step_q4, int w,
+                              int h) {
   const int16x8_t filters = vld1q_s16(filter[x0_q4]);
   uint8x8_t t0, t1, t2, t3;
 
@@ -304,6 +411,8 @@
     }
   }
 }
+
+#endif
 
 void vpx_convolve8_avg_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
                                   uint8_t *dst, ptrdiff_t dst_stride,
--- a/vpx_dsp/arm/vpx_convolve8_neon.h
+++ b/vpx_dsp/arm/vpx_convolve8_neon.h
@@ -72,6 +72,69 @@
   *s7 = vld1q_u8(s);
 }
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD) && \
+    (__ARM_FEATURE_DOTPROD == 1)
+
+static INLINE int32x4_t convolve8_4_dot(uint8x16_t samples,
+                                        const int8x8_t filters,
+                                        const int32x4_t correction,
+                                        const uint8x16_t range_limit,
+                                        const uint8x16x2_t permute_tbl) {
+  int8x16_t clamped_samples, permuted_samples[2];
+  int32x4_t sum; /* Lane i: 8-tap sum for the window starting at byte i. */
+
+  /* Subtract range_limit (expected 128) so samples fit int8 [-128, 127]. */
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  /* Gather four overlapping 4-byte windows per vector, one per SDOT lane. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+
+  /* Seed with 'correction' to undo the bias; taps 0-3 first, then 4-7. */
+  sum = vdotq_lane_s32(correction, permuted_samples[0], filters, 0);
+  sum = vdotq_lane_s32(sum, permuted_samples[1], filters, 1);
+
+  /* Narrowing and packing are performed by the caller. */
+  return sum;
+}
+
+static INLINE uint8x8_t convolve8_8_dot(uint8x16_t samples,
+                                        const int8x8_t filters,
+                                        const int32x4_t correction,
+                                        const uint8x16_t range_limit,
+                                        const uint8x16x3_t permute_tbl) {
+  int8x16_t clamped_samples, permuted_samples[3];
+  int32x4_t sum0, sum1; /* Outputs 0-3 and 4-7 (windows at bytes 0..7). */
+  int16x8_t sum;
+
+  /* Subtract range_limit (expected 128) so samples fit int8 [-128, 127]. */
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  /* Gather four overlapping 4-byte windows per vector, one per SDOT lane. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
+
+  /* Seed with 'correction' to undo the bias; taps 0-3 first, then 4-7. */
+  /* First 4 output values. */
+  sum0 = vdotq_lane_s32(correction, permuted_samples[0], filters, 0);
+  sum0 = vdotq_lane_s32(sum0, permuted_samples[1], filters, 1);
+  /* Second 4 output values (permuted_samples[1] now feeds taps 0-3). */
+  sum1 = vdotq_lane_s32(correction, permuted_samples[1], filters, 0);
+  sum1 = vdotq_lane_s32(sum1, permuted_samples[2], filters, 1);
+
+  /* Saturate to int16, then round, shift right by 7, saturate to uint8. */
+  sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
+  return vqrshrun_n_s16(sum, 7);
+}
+
+#endif
+
 static INLINE int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
                                     const int16x4_t s2, const int16x4_t s3,
                                     const int16x4_t s4, const int16x4_t s5,