shithub: dav1d

Download patch

ref: e5bca59c7bab44b9bef8403b0396b2600fa1932c
parent: 1d3f6364feadbc4694da2c3107eb9e8fe5c5443e
author: Janne Grunau <janne-vlc@jannau.net>
date: Wed Dec 12 18:07:46 EST 2018

itx: cancel the common factor 16 out of '(x * 2896) >> 12' to avoid integer overflows

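Fixes an integer overflow in inv_dct4_1d with
clusterfuzz-testcase-dav1d_fuzzer-5634807321591808 and in inv_adst16_1d
with clusterfuzz-testcase-dav1d_fuzzer-5761827623927808. Credits to oss-fuzz.

Since 2896 = 16 * 181, 2048 = 16 * 128 and 4096 = 16 * 256, the expression
'(x * 2896 + 2048) >> 12' yields exactly the same result as
'(x * 181 + 128) >> 8'. The latter keeps the intermediate product 16 times
smaller, so it overflows a 32-bit int only for much larger inputs.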

--- a/src/itx_1d.c
+++ b/src/itx_1d.c
@@ -42,8 +42,8 @@
     const int in0 = in[0 * in_s], in1 = in[1 * in_s];
     const int in2 = in[2 * in_s], in3 = in[3 * in_s];
 
-    int t0 = ((in0 + in2) * 2896 + 2048) >> 12;
-    int t1 = ((in0 - in2) * 2896 + 2048) >> 12;
+    int t0 = ((in0 + in2) * 181 + 128) >> 8;
+    int t1 = ((in0 - in2) * 181 + 128) >> 8;
     int t2 = (in1 * 1567 - in3 * 3784 + 2048) >> 12;
     int t3 = (in1 * 3784 + in3 * 1567 + 2048) >> 12;
 
@@ -75,8 +75,8 @@
     int t7  = CLIP(t7a + t6a);
         t6a = CLIP(t7a - t6a);
 
-    int t5  = ((t6a - t5a) * 2896 + 2048) >> 12;
-    int t6  = ((t6a + t5a) * 2896 + 2048) >> 12;
+    int t5  = ((t6a - t5a) * 181 + 128) >> 8;
+    int t6  = ((t6a + t5a) * 181 + 128) >> 8;
 
     out[0 * out_s] = CLIP(tmp[0] + t7);
     out[1 * out_s] = CLIP(tmp[1] + t6);
@@ -134,10 +134,10 @@
     t14  = CLIP(t14a + t13a);
     t15a = CLIP(t15  + t12);
 
-    t10a = ((t13  - t10)  * 2896 + 2048) >> 12;
-    t13a = ((t13  + t10)  * 2896 + 2048) >> 12;
-    t11  = ((t12a - t11a) * 2896 + 2048) >> 12;
-    t12  = ((t12a + t11a) * 2896 + 2048) >> 12;
+    t10a = ((t13  - t10)  * 181 + 128) >> 8;
+    t13a = ((t13  + t10)  * 181 + 128) >> 8;
+    t11  = ((t12a - t11a) * 181 + 128) >> 8;
+    t12  = ((t12a + t11a) * 181 + 128) >> 8;
 
     out[ 0 * out_s] = CLIP(tmp[0] + t15a);
     out[ 1 * out_s] = CLIP(tmp[1] + t14);
@@ -261,14 +261,14 @@
     t30a = CLIP(t30  + t25);
     t31  = CLIP(t31a + t24a);
 
-    t20  = ((t27a - t20a) * 2896 + 2048) >> 12;
-    t27  = ((t27a + t20a) * 2896 + 2048) >> 12;
-    t21a = ((t26  - t21 ) * 2896 + 2048) >> 12;
-    t26a = ((t26  + t21 ) * 2896 + 2048) >> 12;
-    t22  = ((t25a - t22a) * 2896 + 2048) >> 12;
-    t25  = ((t25a + t22a) * 2896 + 2048) >> 12;
-    t23a = ((t24  - t23 ) * 2896 + 2048) >> 12;
-    t24a = ((t24  + t23 ) * 2896 + 2048) >> 12;
+    t20  = ((t27a - t20a) * 181 + 128) >> 8;
+    t27  = ((t27a + t20a) * 181 + 128) >> 8;
+    t21a = ((t26  - t21 ) * 181 + 128) >> 8;
+    t26a = ((t26  + t21 ) * 181 + 128) >> 8;
+    t22  = ((t25a - t22a) * 181 + 128) >> 8;
+    t25  = ((t25a + t22a) * 181 + 128) >> 8;
+    t23a = ((t24  - t23 ) * 181 + 128) >> 8;
+    t24a = ((t24  + t23 ) * 181 + 128) >> 8;
 
     out[ 0 * out_s] = CLIP(tmp[ 0] + t31);
     out[ 1 * out_s] = CLIP(tmp[ 1] + t30a);
@@ -546,22 +546,22 @@
     t62  = CLIP(t62a + t49a);
     t63a = CLIP(t63  + t48);
 
-    t40a = (t40  * -2896 + t55  * 2896 + 2048) >> 12;
-    t41  = (t41a * -2896 + t54a * 2896 + 2048) >> 12;
-    t42a = (t42  * -2896 + t53  * 2896 + 2048) >> 12;
-    t43  = (t43a * -2896 + t52a * 2896 + 2048) >> 12;
-    t44a = (t44  * -2896 + t51  * 2896 + 2048) >> 12;
-    t45  = (t45a * -2896 + t50a * 2896 + 2048) >> 12;
-    t46a = (t46  * -2896 + t49  * 2896 + 2048) >> 12;
-    t47  = (t47a * -2896 + t48a * 2896 + 2048) >> 12;
-    t48  = (t47a *  2896 + t48a * 2896 + 2048) >> 12;
-    t49a = (t46  *  2896 + t49  * 2896 + 2048) >> 12;
-    t50  = (t45a *  2896 + t50a * 2896 + 2048) >> 12;
-    t51a = (t44  *  2896 + t51  * 2896 + 2048) >> 12;
-    t52  = (t43a *  2896 + t52a * 2896 + 2048) >> 12;
-    t53a = (t42  *  2896 + t53  * 2896 + 2048) >> 12;
-    t54  = (t41a *  2896 + t54a * 2896 + 2048) >> 12;
-    t55a = (t40  *  2896 + t55  * 2896 + 2048) >> 12;
+    t40a = (t40  * -181 + t55  * 181 + 128) >> 8;
+    t41  = (t41a * -181 + t54a * 181 + 128) >> 8;
+    t42a = (t42  * -181 + t53  * 181 + 128) >> 8;
+    t43  = (t43a * -181 + t52a * 181 + 128) >> 8;
+    t44a = (t44  * -181 + t51  * 181 + 128) >> 8;
+    t45  = (t45a * -181 + t50a * 181 + 128) >> 8;
+    t46a = (t46  * -181 + t49  * 181 + 128) >> 8;
+    t47  = (t47a * -181 + t48a * 181 + 128) >> 8;
+    t48  = (t47a *  181 + t48a * 181 + 128) >> 8;
+    t49a = (t46  *  181 + t49  * 181 + 128) >> 8;
+    t50  = (t45a *  181 + t50a * 181 + 128) >> 8;
+    t51a = (t44  *  181 + t51  * 181 + 128) >> 8;
+    t52  = (t43a *  181 + t52a * 181 + 128) >> 8;
+    t53a = (t42  *  181 + t53  * 181 + 128) >> 8;
+    t54  = (t41a *  181 + t54a * 181 + 128) >> 8;
+    t55a = (t40  *  181 + t55  * 181 + 128) >> 8;
 
     out[ 0 * out_s] = CLIP(tmp[ 0] + t63a);
     out[ 1 * out_s] = CLIP(tmp[ 1] + t62);
@@ -690,10 +690,10 @@
     t6             = CLIP(  t4a - t6a );
     t7             = CLIP(  t5a - t7a );
 
-    out[3 * out_s] = -(((t2 + t3) * 2896 + 2048) >> 12);
-    out[4 * out_s] =   ((t2 - t3) * 2896 + 2048) >> 12;
-    out[2 * out_s] =   ((t6 + t7) * 2896 + 2048) >> 12;
-    out[5 * out_s] = -(((t6 - t7) * 2896 + 2048) >> 12);
+    out[3 * out_s] = -(((t2 + t3) * 181 + 128) >> 8);
+    out[4 * out_s] =   ((t2 - t3) * 181 + 128) >> 8;
+    out[2 * out_s] =   ((t6 + t7) * 181 + 128) >> 8;
+    out[5 * out_s] = -(((t6 - t7) * 181 + 128) >> 8);
 }
 
 static void NOINLINE
@@ -796,14 +796,14 @@
     t14a            = CLIP(  t12 - t14  );
     t15a            = CLIP(  t13 - t15  );
 
-    out[ 7 * out_s] = -(((t2a  + t3a)  * 2896 + 2048) >> 12);
-    out[ 8 * out_s] =   ((t2a  - t3a)  * 2896 + 2048) >> 12;
-    out[ 4 * out_s] =   ((t6   + t7)   * 2896 + 2048) >> 12;
-    out[11 * out_s] = -(((t6   - t7)   * 2896 + 2048) >> 12);
-    out[ 6 * out_s] =   ((t10  + t11)  * 2896 + 2048) >> 12;
-    out[ 9 * out_s] = -(((t10  - t11)  * 2896 + 2048) >> 12);
-    out[ 5 * out_s] = -(((t14a + t15a) * 2896 + 2048) >> 12);
-    out[10 * out_s] =   ((t14a - t15a) * 2896 + 2048) >> 12;
+    out[ 7 * out_s] = -(((t2a  + t3a)  * 181 + 128) >> 8);
+    out[ 8 * out_s] =   ((t2a  - t3a)  * 181 + 128) >> 8;
+    out[ 4 * out_s] =   ((t6   + t7)   * 181 + 128) >> 8;
+    out[11 * out_s] = -(((t6   - t7)   * 181 + 128) >> 8);
+    out[ 6 * out_s] =   ((t10  + t11)  * 181 + 128) >> 8;
+    out[ 9 * out_s] = -(((t10  - t11)  * 181 + 128) >> 8);
+    out[ 5 * out_s] = -(((t14a + t15a) * 181 + 128) >> 8);
+    out[10 * out_s] =   ((t14a - t15a) * 181 + 128) >> 8;
 }
 
 #define flip_inv_adst(sz) \