shithub: libvpx

Download patch

ref: cf8039c25f3e275ef4f2ca350fcd42680b9413d3
parent: 6fbc354c97ca45754a76674dba5f6b43c2b0c15a
parent: b383a17fa4c36a4242816ba6a1c57dca46d042d6
author: Johann Koenig <johannkoenig@google.com>
date: Wed Nov 8 11:28:40 EST 2017

Merge "Support building AVX-512 and implement sadx4 for AVX-512"

--- a/build/make/Makefile
+++ b/build/make/Makefile
@@ -139,6 +139,8 @@
 $(BUILD_PFX)%_avx.c.o: CFLAGS += -mavx
 $(BUILD_PFX)%_avx2.c.d: CFLAGS += -mavx2
 $(BUILD_PFX)%_avx2.c.o: CFLAGS += -mavx2
+$(BUILD_PFX)%_avx512.c.d: CFLAGS += -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl
+$(BUILD_PFX)%_avx512.c.o: CFLAGS += -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl
 
 # POWER
 $(BUILD_PFX)%_vsx.c.d: CFLAGS += -maltivec -mvsx
--- a/build/make/configure.sh
+++ b/build/make/configure.sh
@@ -403,6 +403,23 @@
   fi
 }
 
+# Tests whether gcc accepts -m$2 -m$3 ... together; the feature name is $1.
+check_gcc_machine_options() {
+  feature="$1"
+  shift
+  flags="-m$1"
+  shift
+  for opt in $*; do
+    flags="$flags -m$opt"
+  done
+  # Disable $feature in RTCD if gcc rejects the flags; else soft_enable it.
+  if enabled gcc && ! disabled "$feature" && ! check_cflags $flags; then
+    RTCD_OPTIONS="${RTCD_OPTIONS}--disable-$feature "
+  else
+    soft_enable "$feature"
+  fi
+}
+
 write_common_config_banner() {
   print_webm_license config.mk "##" ""
   echo '# This file automatically generated by configure. Do not edit!' >> config.mk
@@ -1238,6 +1255,13 @@
           msvs_arch_dir=x86-msvs
           vc_version=${tgt_cc##vs}
           case $vc_version in
+            7|8|9|10|11|12|13|14)
+              echo "${tgt_cc} does not support avx512, disabling....."
+              RTCD_OPTIONS="${RTCD_OPTIONS}--disable-avx512 "
+              soft_disable avx512
+              ;;
+          esac
+          case $vc_version in
             7|8|9|10)
               echo "${tgt_cc} does not support avx/avx2, disabling....."
               RTCD_OPTIONS="${RTCD_OPTIONS}--disable-avx --disable-avx2 "
@@ -1281,8 +1305,12 @@
         elif disabled $ext; then
           disable_exts="yes"
         else
-          # use the shortened version for the flag: sse4_1 -> sse4
-          check_gcc_machine_option ${ext%_*} $ext
+          if [ "$ext" = "avx512" ]; then
+            check_gcc_machine_options $ext avx512f avx512cd avx512bw avx512dq avx512vl
+          else
+            # use the shortened version for the flag: sse4_1 -> sse4
+            check_gcc_machine_option ${ext%_*} $ext
+          fi
         fi
       done
 
--- a/build/make/rtcd.pl
+++ b/build/make/rtcd.pl
@@ -391,10 +391,10 @@
 
 &require("c");
 if ($opts{arch} eq 'x86') {
-  @ALL_ARCHS = filter(qw/mmx sse sse2 sse3 ssse3 sse4_1 avx avx2/);
+  @ALL_ARCHS = filter(qw/mmx sse sse2 sse3 ssse3 sse4_1 avx avx2 avx512/);
   x86;
 } elsif ($opts{arch} eq 'x86_64') {
-  @ALL_ARCHS = filter(qw/mmx sse sse2 sse3 ssse3 sse4_1 avx avx2/);
+  @ALL_ARCHS = filter(qw/mmx sse sse2 sse3 ssse3 sse4_1 avx avx2 avx512/);
   @REQUIRES = filter(keys %required ? keys %required : qw/mmx sse sse2/);
   &require(@REQUIRES);
   x86;
--- a/configure
+++ b/configure
@@ -244,6 +244,7 @@
     sse4_1
     avx
     avx2
+    avx512
 "
 
 ARCH_EXT_LIST_LOONGSON="
--- a/test/sad_test.cc
+++ b/test/sad_test.cc
@@ -896,6 +896,14 @@
 INSTANTIATE_TEST_CASE_P(AVX2, SADx4Test, ::testing::ValuesIn(x4d_avx2_tests));
 #endif  // HAVE_AVX2
 
+#if HAVE_AVX512
+const SadMxNx4Param x4d_avx512_tests[] = {
+  SadMxNx4Param(64, 64, &vpx_sad64x64x4d_avx512),
+};
+INSTANTIATE_TEST_CASE_P(AVX512, SADx4Test,
+                        ::testing::ValuesIn(x4d_avx512_tests));
+#endif  // HAVE_AVX512
+
 //------------------------------------------------------------------------------
 // MIPS functions
 #if HAVE_MSA
--- a/test/test_libvpx.cc
+++ b/test/test_libvpx.cc
@@ -53,6 +53,9 @@
   }
   if (!(simd_caps & HAS_AVX)) append_negative_gtest_filter(":AVX.*:AVX/*");
   if (!(simd_caps & HAS_AVX2)) append_negative_gtest_filter(":AVX2.*:AVX2/*");
+  if (!(simd_caps & HAS_AVX512)) {
+    append_negative_gtest_filter(":AVX512.*:AVX512/*");
+  }
 #endif  // ARCH_X86 || ARCH_X86_64
 
 #if !CONFIG_SHARED
--- a/vp9/common/vp9_rtcd_defs.pl
+++ b/vp9/common/vp9_rtcd_defs.pl
@@ -30,6 +30,7 @@
   $ssse3_x86_64 = 'ssse3';
   $avx_x86_64 = 'avx';
   $avx2_x86_64 = 'avx2';
+  $avx512_x86_64 = 'avx512';
 }
 
 #
--- a/vpx_dsp/vpx_dsp.mk
+++ b/vpx_dsp/vpx_dsp.mk
@@ -327,6 +327,7 @@
 DSP_SRCS-$(HAVE_SSE4_1) += x86/sad_sse4.asm
 DSP_SRCS-$(HAVE_AVX2)   += x86/sad4d_avx2.c
 DSP_SRCS-$(HAVE_AVX2)   += x86/sad_avx2.c
+DSP_SRCS-$(HAVE_AVX512) += x86/sad4d_avx512.c
 
 DSP_SRCS-$(HAVE_SSE)    += x86/sad4d_sse2.asm
 DSP_SRCS-$(HAVE_SSE)    += x86/sad_sse2.asm
--- a/vpx_dsp/vpx_dsp_rtcd_defs.pl
+++ b/vpx_dsp/vpx_dsp_rtcd_defs.pl
@@ -20,6 +20,7 @@
   $ssse3_x86_64 = 'ssse3';
   $avx_x86_64 = 'avx';
   $avx2_x86_64 = 'avx2';
+  $avx512_x86_64 = 'avx512';
 }
 
 #
@@ -872,7 +873,7 @@
 # Multi-block SAD, comparing a reference to N independent blocks
 #
 add_proto qw/void vpx_sad64x64x4d/, "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[], int ref_stride, uint32_t *sad_array";
-specialize qw/vpx_sad64x64x4d avx2 neon msa sse2 vsx mmi/;
+specialize qw/vpx_sad64x64x4d avx512 avx2 neon msa sse2 vsx mmi/;
 
 add_proto qw/void vpx_sad64x32x4d/, "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[], int ref_stride, uint32_t *sad_array";
 specialize qw/vpx_sad64x32x4d neon msa sse2 vsx mmi/;
--- /dev/null
+++ b/vpx_dsp/x86/sad4d_avx512.c
@@ -1,0 +1,83 @@
+/*
+ *  Copyright (c) 2017 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+#include <immintrin.h>  // AVX512
+#include "./vpx_dsp_rtcd.h"
+#include "vpx/vpx_integer.h"
+
+void vpx_sad64x64x4d_avx512(const uint8_t *src, int src_stride,
+                            const uint8_t *const ref[4], int ref_stride,
+                            uint32_t res[4]) {
+  // Writes to res[k] the sum of absolute differences between the 64x64
+  // block at src and the 64x64 block at ref[k], for k = 0..3. All four
+  // references share ref_stride.
+  //
+  // _mm512_sad_epu8 produces one partial sum per 64-bit lane, at most
+  // 8 * 255 per row; over 64 rows that stays below 2^17, so 32-bit adds
+  // cannot overflow and the upper 32 bits of every 64-bit lane stay zero.
+  __m512i sum_ref0 = _mm512_setzero_si512();
+  __m512i sum_ref1 = _mm512_setzero_si512();
+  __m512i sum_ref2 = _mm512_setzero_si512();
+  __m512i sum_ref3 = _mm512_setzero_si512();
+  const uint8_t *ref0 = ref[0];
+  const uint8_t *ref1 = ref[1];
+  const uint8_t *ref2 = ref[2];
+  const uint8_t *ref3 = ref[3];
+  int i;
+
+  for (i = 0; i < 64; i++) {
+    // Load one 64-byte row of src and of each reference (unaligned).
+    const __m512i src_reg = _mm512_loadu_si512((const __m512i *)src);
+    const __m512i ref0_reg = _mm512_loadu_si512((const __m512i *)ref0);
+    const __m512i ref1_reg = _mm512_loadu_si512((const __m512i *)ref1);
+    const __m512i ref2_reg = _mm512_loadu_si512((const __m512i *)ref2);
+    const __m512i ref3_reg = _mm512_loadu_si512((const __m512i *)ref3);
+
+    // Add this row's SAD to each reference's running sum.
+    sum_ref0 = _mm512_add_epi32(sum_ref0, _mm512_sad_epu8(ref0_reg, src_reg));
+    sum_ref1 = _mm512_add_epi32(sum_ref1, _mm512_sad_epu8(ref1_reg, src_reg));
+    sum_ref2 = _mm512_add_epi32(sum_ref2, _mm512_sad_epu8(ref2_reg, src_reg));
+    sum_ref3 = _mm512_add_epi32(sum_ref3, _mm512_sad_epu8(ref3_reg, src_reg));
+
+    src += src_stride;
+    ref0 += ref_stride;
+    ref1 += ref_stride;
+    ref2 += ref_stride;
+    ref3 += ref_stride;
+  }
+  {
+    __m512i sum_mlow, sum_mhigh;
+    __m256i sum256;
+    __m128i sum128;
+
+    // Each 64-bit lane holds its sum in the low 4 bytes (high 4 bytes are
+    // zero). Shifting sum_ref1/sum_ref3 left by 4 bytes within each 128-bit
+    // lane moves their sums into those zero high halves.
+    sum_ref1 = _mm512_bslli_epi128(sum_ref1, 4);
+    sum_ref3 = _mm512_bslli_epi128(sum_ref3, 4);
+
+    // Each 64-bit lane now packs {ref0, ref1} or {ref2, ref3} sums.
+    sum_ref0 = _mm512_or_si512(sum_ref0, sum_ref1);
+    sum_ref2 = _mm512_or_si512(sum_ref2, sum_ref3);
+
+    // Pair up the low and high 64-bit lanes of every 128-bit lane and add
+    // them, so each 128-bit lane holds partial {ref0, ref1, ref2, ref3}.
+    sum_mlow = _mm512_unpacklo_epi64(sum_ref0, sum_ref2);
+    sum_mhigh = _mm512_unpackhi_epi64(sum_ref0, sum_ref2);
+    sum_mlow = _mm512_add_epi32(sum_mlow, sum_mhigh);
+
+    // Fold the four 128-bit lanes together: 512 -> 256 -> 128 bits.
+    sum256 = _mm256_add_epi32(_mm512_castsi512_si256(sum_mlow),
+                              _mm512_extracti32x8_epi32(sum_mlow, 1));
+    sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum256),
+                           _mm256_extracti128_si256(sum256, 1));
+
+    _mm_storeu_si128((__m128i *)res, sum128);
+  }
+}
--- a/vpx_ports/x86.h
+++ b/vpx_ports/x86.h
@@ -151,16 +151,17 @@
 #endif
 #endif
 
-#define HAS_MMX 0x01
-#define HAS_SSE 0x02
-#define HAS_SSE2 0x04
-#define HAS_SSE3 0x08
-#define HAS_SSSE3 0x10
-#define HAS_SSE4_1 0x20
-#define HAS_AVX 0x40
-#define HAS_AVX2 0x80
+#define HAS_MMX 0x001
+#define HAS_SSE 0x002
+#define HAS_SSE2 0x004
+#define HAS_SSE3 0x008
+#define HAS_SSSE3 0x010
+#define HAS_SSE4_1 0x020
+#define HAS_AVX 0x040
+#define HAS_AVX2 0x080
+#define HAS_AVX512 0x100
 #ifndef BIT
-#define BIT(n) (1 << n)
+#define BIT(n) (1u << n)
 #endif
 
 static INLINE int x86_simd_caps(void) {
@@ -209,6 +210,12 @@
         cpuid(7, 0, reg_eax, reg_ebx, reg_ecx, reg_edx);
 
         if (reg_ebx & BIT(5)) flags |= HAS_AVX2;
+
+        // bits 16 (AVX-512F) & 17 (AVX-512DQ) & 28 (AVX-512CD) &
+        // 30 (AVX-512BW) & 31 (AVX-512VL). TODO: check OS ZMM state (XCR0).
+        if ((reg_ebx & (BIT(16) | BIT(17) | BIT(28) | BIT(30) | BIT(31))) ==
+            (BIT(16) | BIT(17) | BIT(28) | BIT(30) | BIT(31)))
+          flags |= HAS_AVX512;
       }
     }
   }