shithub: libvpx

Download patch

ref: 15b5a6a2c7bda61d62536a95b3f84390efc88350
parent: 7d61f8fe53228dce8edcdd0e5245021e58f616dc
author: Deb Mukherjee <debargha@google.com>
date: Mon Jul 22 10:47:57 EDT 2013

Flexible support for various pattern searches

Adds a few pattern searches to achieve various tradeoffs
between motion estimation complexity and performance.
The search framework is unified across these searches so that a
common pattern search function is used for all. This also makes
it easier to experiment with various patterns, or combinations
thereof, at different scales in the future.

The new pattern search is multi-scale and is capable of using
different patterns at different scales.

The new hex search uses 8 points at the smallest scale
and 6 points at other scales.
Two other pattern searches - big-diamond and square are
also added. Big diamond uses 4 points at the smallest scale and
8 points in diamond shape at the larger scales.
Square is very similar conceptually to the default n-step search
but is somewhat faster since it keeps only one survivor across
all scales.

Psnr/speed-up results on derf300:

hex: -1.6% psnr, 6-8% speed-up
big-diamond: -0.96% psnr, 4-5% speed-up
square: -0.93% psnr, 4-5% speed-up

Change-Id: I02a7ef5193f762601e0994e2c99399a3535a43d2

--- a/vp9/encoder/vp9_mbgraph.c
+++ b/vp9/encoder/vp9_mbgraph.c
@@ -46,8 +46,9 @@
   ref_full.as_mv.row = ref_mv->as_mv.row >> 3;
 
   /*cpi->sf.search_method == HEX*/
-  best_err = vp9_hex_search(x, &ref_full, dst_mv, step_param, x->errorperbit,
-                            &v_fn_ptr, NULL, NULL, NULL, NULL, ref_mv);
+  best_err = vp9_hex_search(x, &ref_full, step_param, x->errorperbit,
+                            0, &v_fn_ptr,
+                            0, ref_mv, dst_mv);
 
   // Try sub-pixel MC
   // if (bestsme > error_thresh && bestsme < INT_MAX)
--- a/vp9/encoder/vp9_mcomp.c
+++ b/vp9/encoder/vp9_mcomp.c
@@ -1245,8 +1245,10 @@
   {\
     if (thissad < bestsad)\
     {\
-      thissad += mvsad_err_cost(&this_mv, &fcenter_mv, mvjsadcost, mvsadcost, \
-                                sad_per_bit);\
+      if (use_mvcost) \
+        thissad += mvsad_err_cost(&this_mv, &fcenter_mv, \
+                                  mvjsadcost, mvsadcost, \
+                                  sad_per_bit);\
       if (thissad < bestsad)\
       {\
         bestsad = thissad;\
@@ -1255,46 +1257,53 @@
     }\
   }
 
-static const MV next_chkpts[6][3] = {
-  {{ -2, 0}, { -1, -2}, {1, -2}},
-  {{ -1, -2}, {1, -2}, {2, 0}},
-  {{1, -2}, {2, 0}, {1, 2}},
-  {{2, 0}, {1, 2}, { -1, 2}},
-  {{1, 2}, { -1, 2}, { -2, 0}},
-  {{ -1, 2}, { -2, 0}, { -1, -2}}
-};
+#define get_next_chkpts(list, i, n)   \
+    list[0] = ((i) == 0 ? (n) - 1 : (i) - 1);  \
+    list[1] = (i);                             \
+    list[2] = ((i) == (n) - 1 ? 0 : (i) + 1);
 
-int vp9_hex_search
-(
-  MACROBLOCK *x,
-  int_mv *ref_mv,
-  int_mv *best_mv,
-  int search_param,
-  int sad_per_bit,
-  const vp9_variance_fn_ptr_t *vfp,
-  int *mvjsadcost, int *mvsadcost[2],
-  int *mvjcost, int *mvcost[2],
-  int_mv *center_mv
-) {
-  const MACROBLOCKD* const xd = &x->e_mbd;
-  MV hex[6] = { { -1, -2}, {1, -2}, {2, 0}, {1, 2}, { -1, 2}, { -2, 0} };
-  MV neighbors[4] = {{0, -1}, { -1, 0}, {1, 0}, {0, 1}};
-  int i, j;
+#define MAX_PATTERN_SCALES         11  // max number of search scales
+#define MAX_PATTERN_CANDIDATES      8  // max number of candidates per scale
+#define PATTERN_CANDIDATES_REF      3  // number of refinement candidates
 
+// Generic pattern search function that searches over multiple scales.
+// Each scale can have a different number of candidates and shape of
+// candidates as indicated in the num_candidates and candidates arrays
+// passed into this function
+static int vp9_pattern_search(MACROBLOCK *x,
+                              int_mv *ref_mv,
+                              int search_param,
+                              int sad_per_bit,
+                              int do_init_search,
+                              int do_refine,
+                              const vp9_variance_fn_ptr_t *vfp,
+                              int use_mvcost,
+                              int_mv *center_mv, int_mv *best_mv,
+                              const int num_candidates[MAX_PATTERN_SCALES],
+                              const MV candidates[MAX_PATTERN_SCALES]
+                                                 [MAX_PATTERN_CANDIDATES]) {
+  const MACROBLOCKD* const xd = &x->e_mbd;
+  static const int search_param_to_steps[MAX_MVSEARCH_STEPS] = {
+    10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
+  };
+  int i, j, s, t;
   uint8_t *what = x->plane[0].src.buf;
   int what_stride = x->plane[0].src.stride;
   int in_what_stride = xd->plane[0].pre[0].stride;
   int br, bc;
   int_mv this_mv;
-  unsigned int bestsad = 0x7fffffff;
-  unsigned int thissad;
+  int bestsad = INT_MAX;
+  int thissad;
   uint8_t *base_offset;
   uint8_t *this_offset;
   int k = -1;
   int all_in;
   int best_site = -1;
-
   int_mv fcenter_mv;
+  int best_init_s = search_param_to_steps[search_param];
+  int *mvjsadcost = x->nmvjointsadcost;
+  int *mvsadcost[2] = {x->nmvsadcost[0], x->nmvsadcost[1]};
+
   fcenter_mv.as_mv.row = center_mv->as_mv.row >> 3;
   fcenter_mv.as_mv.col = center_mv->as_mv.col >> 3;
 
@@ -1306,7 +1315,7 @@
 
   // Work out the start point for the search
   base_offset = (uint8_t *)(xd->plane[0].pre[0].buf);
-  this_offset = base_offset + (br * (xd->plane[0].pre[0].stride)) + bc;
+  this_offset = base_offset + (br * in_what_stride) + bc;
   this_mv.as_mv.row = br;
   this_mv.as_mv.col = bc;
   bestsad = vfp->sdf(what, what_stride, this_offset,
@@ -1314,109 +1323,310 @@
             + mvsad_err_cost(&this_mv, &fcenter_mv, mvjsadcost, mvsadcost,
                              sad_per_bit);
 
-  // hex search
-  // j=0
-  CHECK_BOUNDS(2)
-
-  if (all_in) {
-    for (i = 0; i < 6; i++) {
-      this_mv.as_mv.row = br + hex[i].row;
-      this_mv.as_mv.col = bc + hex[i].col;
-      this_offset = base_offset + (this_mv.as_mv.row * in_what_stride) + this_mv.as_mv.col;
-      thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride, bestsad);
-      CHECK_BETTER
+  // Search all possible scales up to the search param around the center
+  // point, and pick the scale of the best point as the starting scale of
+  // further steps around it.
+  if (do_init_search) {
+    s = best_init_s;
+    best_init_s = -1;
+    for (t = 0; t <= s; ++t) {
+      best_site = -1;
+      CHECK_BOUNDS((1 << t))
+      if (all_in) {
+        for (i = 0; i < num_candidates[t]; i++) {
+          this_mv.as_mv.row = br + candidates[t][i].row;
+          this_mv.as_mv.col = bc + candidates[t][i].col;
+          this_offset = base_offset + (this_mv.as_mv.row * in_what_stride) +
+              this_mv.as_mv.col;
+          thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
+                             bestsad);
+          CHECK_BETTER
+        }
+      } else {
+        for (i = 0; i < num_candidates[t]; i++) {
+          this_mv.as_mv.row = br + candidates[t][i].row;
+          this_mv.as_mv.col = bc + candidates[t][i].col;
+          CHECK_POINT
+          this_offset = base_offset + (this_mv.as_mv.row * in_what_stride) +
+                        this_mv.as_mv.col;
+          thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
+                             bestsad);
+          CHECK_BETTER
+        }
+      }
+      if (best_site == -1) {
+        continue;
+      } else {
+        best_init_s = t;
+        k = best_site;
+      }
     }
-  } else {
-    for (i = 0; i < 6; i++) {
-      this_mv.as_mv.row = br + hex[i].row;
-      this_mv.as_mv.col = bc + hex[i].col;
-      CHECK_POINT
-      this_offset = base_offset + (this_mv.as_mv.row * in_what_stride) + this_mv.as_mv.col;
-      thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride, bestsad);
-      CHECK_BETTER
+    if (best_init_s != -1) {
+      br += candidates[best_init_s][k].row;
+      bc += candidates[best_init_s][k].col;
     }
   }
 
-  if (best_site == -1)
-    goto cal_neighbors;
-  else {
-    br += hex[best_site].row;
-    bc += hex[best_site].col;
-    k = best_site;
-  }
-
-  for (j = 1; j < 127; j++) {
+  // If the center point is still the best, just skip this and move to
+  // the refinement step.
+  if (best_init_s != -1) {
+    s = best_init_s;
     best_site = -1;
-    CHECK_BOUNDS(2)
+    do {
+      // No need to search all points the 1st time if initial search was used
+      if (!do_init_search || s != best_init_s) {
+        CHECK_BOUNDS((1 << s))
+        if (all_in) {
+          for (i = 0; i < num_candidates[s]; i++) {
+            this_mv.as_mv.row = br + candidates[s][i].row;
+            this_mv.as_mv.col = bc + candidates[s][i].col;
+            this_offset = base_offset + (this_mv.as_mv.row * in_what_stride) +
+                this_mv.as_mv.col;
+            thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
+                               bestsad);
+            CHECK_BETTER
+          }
+        } else {
+          for (i = 0; i < num_candidates[s]; i++) {
+            this_mv.as_mv.row = br + candidates[s][i].row;
+            this_mv.as_mv.col = bc + candidates[s][i].col;
+            CHECK_POINT
+            this_offset = base_offset + (this_mv.as_mv.row * in_what_stride) +
+                          this_mv.as_mv.col;
+            thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
+                               bestsad);
+            CHECK_BETTER
+          }
+        }
 
-    if (all_in) {
-      for (i = 0; i < 3; i++) {
-        this_mv.as_mv.row = br + next_chkpts[k][i].row;
-        this_mv.as_mv.col = bc + next_chkpts[k][i].col;
-        this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) + this_mv.as_mv.col;
-        thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride, bestsad);
-        CHECK_BETTER
+        if (best_site == -1) {
+          continue;
+        } else {
+          br += candidates[s][best_site].row;
+          bc += candidates[s][best_site].col;
+          k = best_site;
+        }
       }
-    } else {
-      for (i = 0; i < 3; i++) {
-        this_mv.as_mv.row = br + next_chkpts[k][i].row;
-        this_mv.as_mv.col = bc + next_chkpts[k][i].col;
-        CHECK_POINT
-        this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) + this_mv.as_mv.col;
-        thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride, bestsad);
-        CHECK_BETTER
-      }
-    }
 
-    if (best_site == -1)
-      break;
-    else {
-      br += next_chkpts[k][best_site].row;
-      bc += next_chkpts[k][best_site].col;
-      k += 5 + best_site;
-      if (k >= 12) k -= 12;
-      else if (k >= 6) k -= 6;
-    }
+      do {
+        int next_chkpts_indices[PATTERN_CANDIDATES_REF];
+        best_site = -1;
+        CHECK_BOUNDS((1 << s))
+
+        get_next_chkpts(next_chkpts_indices, k, num_candidates[s]);
+        if (all_in) {
+          for (i = 0; i < PATTERN_CANDIDATES_REF; i++) {
+            this_mv.as_mv.row = br +
+                candidates[s][next_chkpts_indices[i]].row;
+            this_mv.as_mv.col = bc +
+                candidates[s][next_chkpts_indices[i]].col;
+            this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) +
+                this_mv.as_mv.col;
+            thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
+                               bestsad);
+            CHECK_BETTER
+          }
+        } else {
+          for (i = 0; i < PATTERN_CANDIDATES_REF; i++) {
+            this_mv.as_mv.row = br +
+                candidates[s][next_chkpts_indices[i]].row;
+            this_mv.as_mv.col = bc +
+                candidates[s][next_chkpts_indices[i]].col;
+            CHECK_POINT
+            this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) +
+                          this_mv.as_mv.col;
+            thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
+                               bestsad);
+            CHECK_BETTER
+          }
+        }
+
+        if (best_site != -1) {
+          k = next_chkpts_indices[best_site];
+          br += candidates[s][k].row;
+          bc += candidates[s][k].col;
+        }
+      } while (best_site != -1);
+    } while (s--);
   }
 
-  // check 4 1-away neighbors
-cal_neighbors:
-  for (j = 0; j < 32; j++) {
-    best_site = -1;
-    CHECK_BOUNDS(1)
+  // Check 4 1-away neighbors if do_refine is true.
+  // For most well-designed schemes do_refine will not be necessary.
+  if (do_refine) {
+    static const MV neighbors[4] = {
+      {0, -1}, { -1, 0}, {1, 0}, {0, 1},
+    };
+    for (j = 0; j < 16; j++) {
+      best_site = -1;
+      CHECK_BOUNDS(1)
+      if (all_in) {
+        for (i = 0; i < 4; i++) {
+          this_mv.as_mv.row = br + neighbors[i].row;
+          this_mv.as_mv.col = bc + neighbors[i].col;
+          this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) +
+              this_mv.as_mv.col;
+          thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
+                             bestsad);
+          CHECK_BETTER
+        }
+      } else {
+        for (i = 0; i < 4; i++) {
+          this_mv.as_mv.row = br + neighbors[i].row;
+          this_mv.as_mv.col = bc + neighbors[i].col;
+          CHECK_POINT
+          this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) +
+                        this_mv.as_mv.col;
+          thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
+                             bestsad);
+          CHECK_BETTER
+        }
+          }
 
-    if (all_in) {
-      for (i = 0; i < 4; i++) {
-        this_mv.as_mv.row = br + neighbors[i].row;
-        this_mv.as_mv.col = bc + neighbors[i].col;
-        this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) + this_mv.as_mv.col;
-        thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride, bestsad);
-        CHECK_BETTER
+      if (best_site == -1) {
+        break;
+      } else {
+        br += neighbors[best_site].row;
+        bc += neighbors[best_site].col;
       }
-    } else {
-      for (i = 0; i < 4; i++) {
-        this_mv.as_mv.row = br + neighbors[i].row;
-        this_mv.as_mv.col = bc + neighbors[i].col;
-        CHECK_POINT
-        this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) + this_mv.as_mv.col;
-        thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride, bestsad);
-        CHECK_BETTER
-      }
     }
-
-    if (best_site == -1)
-      break;
-    else {
-      br += neighbors[best_site].row;
-      bc += neighbors[best_site].col;
-    }
   }
 
   best_mv->as_mv.row = br;
   best_mv->as_mv.col = bc;
 
-  return bestsad;
+  this_offset = base_offset + (best_mv->as_mv.row * (in_what_stride)) +
+      best_mv->as_mv.col;
+  this_mv.as_mv.row = best_mv->as_mv.row << 3;
+  this_mv.as_mv.col = best_mv->as_mv.col << 3;
+  if (bestsad == INT_MAX)
+    return INT_MAX;
+  // '+' binds tighter than '?:', so the mv cost term must be parenthesized.
+  return vfp->vf(what, what_stride, this_offset, in_what_stride,
+                 (unsigned int *)(&bestsad)) +
+         (use_mvcost ? mv_err_cost(&this_mv, center_mv, x->nmvjointcost,
+                                   x->mvcost, x->errorperbit) : 0);
 }
+
+
+int vp9_hex_search(MACROBLOCK *x,
+                   int_mv *ref_mv,
+                   int search_param,
+                   int sad_per_bit,
+                   int do_init_search,
+                   const vp9_variance_fn_ptr_t *vfp,
+                   int use_mvcost,
+                   int_mv *center_mv, int_mv *best_mv) {
+  // Scale 0 uses the 8 closest points; every other scale uses 6 points in
+  // a hex shape, with offsets doubling from one scale to the next.
+  static const int hex_num_candidates[MAX_PATTERN_SCALES] = {
+    8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6
+  };
+  // Note that the largest candidate step at each scale is 2^scale
+  static const MV hex_candidates[MAX_PATTERN_SCALES][MAX_PATTERN_CANDIDATES] = {
+    {{-1, -1}, {0, -1}, {1, -1}, {1, 0}, {1, 1}, { 0, 1}, { -1, 1}, {-1, 0}},
+    {{-1, -2}, {1, -2}, {2, 0}, {1, 2}, { -1, 2}, { -2, 0}},
+    {{-2, -4}, {2, -4}, {4, 0}, {2, 4}, { -2, 4}, { -4, 0}},
+    {{-4, -8}, {4, -8}, {8, 0}, {4, 8}, { -4, 8}, { -8, 0}},
+    {{-8, -16}, {8, -16}, {16, 0}, {8, 16}, { -8, 16}, { -16, 0}},
+    {{-16, -32}, {16, -32}, {32, 0}, {16, 32}, { -16, 32}, { -32, 0}},
+    {{-32, -64}, {32, -64}, {64, 0}, {32, 64}, { -32, 64}, { -64, 0}},
+    {{-64, -128}, {64, -128}, {128, 0}, {64, 128}, { -64, 128}, { -128, 0}},
+    {{-128, -256}, {128, -256}, {256, 0}, {128, 256}, { -128, 256}, { -256, 0}},
+    {{-256, -512}, {256, -512}, {512, 0}, {256, 512}, { -256, 512}, { -512, 0}},
+    {{-512, -1024}, {512, -1024}, {1024, 0}, {512, 1024}, { -512, 1024},
+      { -1024, 0}},
+  };
+  return
+      vp9_pattern_search(x, ref_mv, search_param, sad_per_bit,
+                         do_init_search, 0 /* do_refine */, vfp, use_mvcost,
+                         center_mv, best_mv,
+                         hex_num_candidates, hex_candidates);
+}
+
+int vp9_bigdia_search(MACROBLOCK *x,
+                      int_mv *ref_mv,
+                      int search_param,
+                      int sad_per_bit,
+                      int do_init_search,
+                      const vp9_variance_fn_ptr_t *vfp,
+                      int use_mvcost,
+                      int_mv *center_mv,
+                      int_mv *best_mv) {
+  // Scale 0 uses the 4 closest points; every other scale uses 8 points in
+  // a diamond shape, with offsets doubling from one scale to the next.
+  static const int bigdia_num_candidates[MAX_PATTERN_SCALES] = {
+    4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
+  };
+  // Note that the largest candidate step at each scale is 2^scale
+  static const MV bigdia_candidates[MAX_PATTERN_SCALES]
+                                   [MAX_PATTERN_CANDIDATES] = {
+    {{0, -1}, {1, 0}, { 0, 1}, {-1, 0}},
+    {{-1, -1}, {0, -2}, {1, -1}, {2, 0}, {1, 1}, {0, 2}, {-1, 1}, {-2, 0}},
+    {{-2, -2}, {0, -4}, {2, -2}, {4, 0}, {2, 2}, {0, 4}, {-2, 2}, {-4, 0}},
+    {{-4, -4}, {0, -8}, {4, -4}, {8, 0}, {4, 4}, {0, 8}, {-4, 4}, {-8, 0}},
+    {{-8, -8}, {0, -16}, {8, -8}, {16, 0}, {8, 8}, {0, 16}, {-8, 8}, {-16, 0}},
+    {{-16, -16}, {0, -32}, {16, -16}, {32, 0}, {16, 16}, {0, 32},
+      {-16, 16}, {-32, 0}},
+    {{-32, -32}, {0, -64}, {32, -32}, {64, 0}, {32, 32}, {0, 64},
+      {-32, 32}, {-64, 0}},
+    {{-64, -64}, {0, -128}, {64, -64}, {128, 0}, {64, 64}, {0, 128},
+      {-64, 64}, {-128, 0}},
+    {{-128, -128}, {0, -256}, {128, -128}, {256, 0}, {128, 128}, {0, 256},
+      {-128, 128}, {-256, 0}},
+    {{-256, -256}, {0, -512}, {256, -256}, {512, 0}, {256, 256}, {0, 512},
+      {-256, 256}, {-512, 0}},
+    {{-512, -512}, {0, -1024}, {512, -512}, {1024, 0}, {512, 512}, {0, 1024},
+      {-512, 512}, {-1024, 0}},
+  };
+  return
+      vp9_pattern_search(x, ref_mv, search_param, sad_per_bit,
+                         do_init_search, 0 /* do_refine */, vfp, use_mvcost,
+                         center_mv, best_mv,
+                         bigdia_num_candidates, bigdia_candidates);
+}
+
+int vp9_square_search(MACROBLOCK *x,
+                      int_mv *ref_mv,
+                      int search_param,
+                      int sad_per_bit,
+                      int do_init_search,
+                      const vp9_variance_fn_ptr_t *vfp,
+                      int use_mvcost,
+                      int_mv *center_mv,
+                      int_mv *best_mv) {
+  // Every scale uses the 8 surrounding points of a square; offsets double
+  static const int square_num_candidates[MAX_PATTERN_SCALES] = {
+    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
+  };
+  // Note that the largest candidate step at each scale is 2^scale
+  static const MV square_candidates[MAX_PATTERN_SCALES]
+                                   [MAX_PATTERN_CANDIDATES] = {
+    {{-1, -1}, {0, -1}, {1, -1}, {1, 0}, {1, 1}, {0, 1}, {-1, 1}, {-1, 0}},
+    {{-2, -2}, {0, -2}, {2, -2}, {2, 0}, {2, 2}, {0, 2}, {-2, 2}, {-2, 0}},
+    {{-4, -4}, {0, -4}, {4, -4}, {4, 0}, {4, 4}, {0, 4}, {-4, 4}, {-4, 0}},
+    {{-8, -8}, {0, -8}, {8, -8}, {8, 0}, {8, 8}, {0, 8}, {-8, 8}, {-8, 0}},
+    {{-16, -16}, {0, -16}, {16, -16}, {16, 0}, {16, 16}, {0, 16},
+      {-16, 16}, {-16, 0}},
+    {{-32, -32}, {0, -32}, {32, -32}, {32, 0}, {32, 32}, {0, 32},
+      {-32, 32}, {-32, 0}},
+    {{-64, -64}, {0, -64}, {64, -64}, {64, 0}, {64, 64}, {0, 64},
+      {-64, 64}, {-64, 0}},
+    {{-128, -128}, {0, -128}, {128, -128}, {128, 0}, {128, 128}, {0, 128},
+      {-128, 128}, {-128, 0}},
+    {{-256, -256}, {0, -256}, {256, -256}, {256, 0}, {256, 256}, {0, 256},
+      {-256, 256}, {-256, 0}},
+    {{-512, -512}, {0, -512}, {512, -512}, {512, 0}, {512, 512}, {0, 512},
+      {-512, 512}, {-512, 0}},
+    {{-1024, -1024}, {0, -1024}, {1024, -1024}, {1024, 0}, {1024, 1024},
+      {0, 1024}, {-1024, 1024}, {-1024, 0}},
+  };
+  return
+      vp9_pattern_search(x, ref_mv, search_param, sad_per_bit,
+                         do_init_search, 0 /* do_refine */, vfp, use_mvcost,
+                         center_mv, best_mv,
+                         square_num_candidates, square_candidates);
+}
+
 #undef CHECK_BOUNDS
 #undef CHECK_POINT
 #undef CHECK_BETTER
--- a/vp9/encoder/vp9_mcomp.h
+++ b/vp9/encoder/vp9_mcomp.h
@@ -40,12 +40,32 @@
                            int_mv *ref_mv, int_mv *dst_mv);
 
 int vp9_hex_search(MACROBLOCK *x,
-                   int_mv *ref_mv, int_mv *best_mv,
-                   int search_param, int error_per_bit,
+                   int_mv *ref_mv,
+                   int search_param,
+                   int error_per_bit,
+                   int do_init_search,
                    const vp9_variance_fn_ptr_t *vf,
-                   int *mvjsadcost, int *mvsadcost[2],
-                   int *mvjcost, int *mvcost[2],
-                   int_mv *center_mv);
+                   int use_mvcost,
+                   int_mv *center_mv,
+                   int_mv *best_mv);
+int vp9_bigdia_search(MACROBLOCK *x,
+                      int_mv *ref_mv,
+                      int search_param,
+                      int error_per_bit,
+                      int do_init_search,
+                      const vp9_variance_fn_ptr_t *vf,
+                      int use_mvcost,
+                      int_mv *center_mv,
+                      int_mv *best_mv);
+int vp9_square_search(MACROBLOCK *x,
+                      int_mv *ref_mv,
+                      int search_param,
+                      int error_per_bit,
+                      int do_init_search,
+                      const vp9_variance_fn_ptr_t *vf,
+                      int use_mvcost,
+                      int_mv *center_mv,
+                      int_mv *best_mv);
 
 typedef int (fractional_mv_step_fp) (MACROBLOCK *x, int_mv
   *bestmv, int_mv *ref_mv, int error_per_bit, const vp9_variance_fn_ptr_t *vfp,
--- a/vp9/encoder/vp9_onyx_if.c
+++ b/vp9/encoder/vp9_onyx_if.c
@@ -830,6 +830,7 @@
         sf->disable_splitmv =
             (MIN(cpi->common.width, cpi->common.height) >= 720)? 1 : 0;
         sf->auto_mv_step_size = 1;
+        sf->search_method = SQUARE;
       }
       if (speed == 3) {
         sf->comp_inter_joint_search_thresh = BLOCK_SIZE_TYPES;
@@ -849,6 +850,7 @@
         sf->skip_encode_sb = 1;
         sf->disable_splitmv = 1;
         sf->auto_mv_step_size = 1;
+        sf->search_method = BIGDIA;
       }
       if (speed == 4) {
         sf->comp_inter_joint_search_thresh = BLOCK_SIZE_TYPES;
@@ -872,6 +874,7 @@
         // sf->reference_masking = 1;
 
         sf->disable_splitmv = 1;
+        sf->search_method = HEX;
       }
       /*
       if (speed == 2) {
--- a/vp9/encoder/vp9_onyx_int.h
+++ b/vp9/encoder/vp9_onyx_int.h
@@ -192,7 +192,9 @@
 typedef enum {
   DIAMOND = 0,
   NSTEP = 1,
-  HEX = 2
+  HEX = 2,
+  BIGDIA = 3,
+  SQUARE = 4
 } SEARCH_METHODS;
 
 typedef enum {
--- a/vp9/encoder/vp9_rdopt.c
+++ b/vp9/encoder/vp9_rdopt.c
@@ -1880,6 +1880,7 @@
           int thissme, bestsme = INT_MAX;
           int sadpb = x->sadperbit4;
           int_mv mvp_full;
+          int max_mv;
 
           /* Is the best so far sufficiently good that we cant justify doing
            * and new motion search. */
@@ -1896,19 +1897,16 @@
                 x->e_mbd.mode_info_context->bmi[i - 2].as_mv[0].as_int;
             }
           }
+          if (i == 0)
+            max_mv = x->max_mv_context[mbmi->ref_frame[0]];
+          else
+            max_mv = MAX(abs(bsi->mvp.as_mv.row), abs(bsi->mvp.as_mv.col)) >> 3;
           if (cpi->sf.auto_mv_step_size && cpi->common.show_frame) {
             // Take wtd average of the step_params based on the last frame's
             // max mv magnitude and the best ref mvs of the current block for
             // the given reference.
-            if (i == 0)
-              step_param = (vp9_init_search_range(
-                  cpi, x->max_mv_context[mbmi->ref_frame[0]]) +
-                  cpi->mv_step_param) >> 1;
-            else
-              step_param = (vp9_init_search_range(
-                  cpi, MAX(abs(bsi->mvp.as_mv.row),
-                           abs(bsi->mvp.as_mv.col)) >> 3) +
-                  cpi->mv_step_param) >> 1;
+            step_param = (vp9_init_search_range(cpi, max_mv) +
+                          cpi->mv_step_param) >> 1;
           } else {
             step_param = cpi->mv_step_param;
           }
@@ -1920,9 +1918,26 @@
 
           // adjust src pointer for this block
           mi_buf_shift(x, i);
-          bestsme = vp9_full_pixel_diamond(cpi, x, &mvp_full, step_param,
-                                           sadpb, further_steps, 0, v_fn_ptr,
-                                           bsi->ref_mv, &mode_mv[NEWMV]);
+          if (cpi->sf.search_method == HEX) {
+            bestsme = vp9_hex_search(x, &mvp_full,
+                                     step_param,
+                                     sadpb, 1, v_fn_ptr, 1,
+                                     bsi->ref_mv, &mode_mv[NEWMV]);
+          } else if (cpi->sf.search_method == SQUARE) {
+            bestsme = vp9_square_search(x, &mvp_full,
+                                        step_param,
+                                        sadpb, 1, v_fn_ptr, 1,
+                                        bsi->ref_mv, &mode_mv[NEWMV]);
+          } else if (cpi->sf.search_method == BIGDIA) {
+            bestsme = vp9_bigdia_search(x, &mvp_full,
+                                        step_param,
+                                        sadpb, 1, v_fn_ptr, 1,
+                                        bsi->ref_mv, &mode_mv[NEWMV]);
+          } else {
+            bestsme = vp9_full_pixel_diamond(cpi, x, &mvp_full, step_param,
+                                             sadpb, further_steps, 0, v_fn_ptr,
+                                             bsi->ref_mv, &mode_mv[NEWMV]);
+          }
 
           // Should we do a full search (best quality only)
           if (cpi->compressor_speed == 0) {
@@ -2497,10 +2512,30 @@
   // Further step/diamond searches as necessary
   further_steps = (cpi->sf.max_step_search_steps - 1) - step_param;
 
-  bestsme = vp9_full_pixel_diamond(cpi, x, &mvp_full, step_param,
-                                   sadpb, further_steps, 1,
-                                   &cpi->fn_ptr[block_size],
-                                   &ref_mv, tmp_mv);
+  if (cpi->sf.search_method == HEX) {
+    bestsme = vp9_hex_search(x, &mvp_full,
+                             step_param,
+                             sadpb, 1,
+                             &cpi->fn_ptr[block_size], 1,
+                             &ref_mv, tmp_mv);
+  } else if (cpi->sf.search_method == SQUARE) {
+    bestsme = vp9_square_search(x, &mvp_full,
+                                step_param,
+                                sadpb, 1,
+                                &cpi->fn_ptr[block_size], 1,
+                                &ref_mv, tmp_mv);
+  } else if (cpi->sf.search_method == BIGDIA) {
+    bestsme = vp9_bigdia_search(x, &mvp_full,
+                                step_param,
+                                sadpb, 1,
+                                &cpi->fn_ptr[block_size], 1,
+                                &ref_mv, tmp_mv);
+  } else {
+    bestsme = vp9_full_pixel_diamond(cpi, x, &mvp_full, step_param,
+                                     sadpb, further_steps, 1,
+                                     &cpi->fn_ptr[block_size],
+                                     &ref_mv, tmp_mv);
+  }
 
   x->mv_col_min = tmp_col_min;
   x->mv_col_max = tmp_col_max;
@@ -2508,7 +2543,7 @@
   x->mv_row_max = tmp_row_max;
 
   if (bestsme < INT_MAX) {
-    int dis; /* TODO: use dis in distortion calculation later. */
+    int dis;  /* TODO: use dis in distortion calculation later. */
     unsigned int sse;
     cpi->find_fractional_mv_step(x, tmp_mv, &ref_mv,
                                  x->errorperbit,
--- a/vp9/encoder/vp9_temporal_filter.c
+++ b/vp9/encoder/vp9_temporal_filter.c
@@ -154,10 +154,10 @@
   // TODO Check that the 16x16 vf & sdf are selected here
   // Ignore mv costing by sending NULL pointer instead of cost arrays
   ref_mv = &x->e_mbd.mode_info_context->bmi[0].as_mv[0];
-  bestsme = vp9_hex_search(x, &best_ref_mv1_full, ref_mv,
-                           step_param, sadpb, &cpi->fn_ptr[BLOCK_16X16],
-                           NULL, NULL, NULL, NULL,
-                           &best_ref_mv1);
+  bestsme = vp9_hex_search(x, &best_ref_mv1_full,
+                           step_param, sadpb, 1,
+                           &cpi->fn_ptr[BLOCK_16X16],
+                           0, &best_ref_mv1, ref_mv);
 
 #if ALT_REF_SUBPEL_ENABLED
   // Try sub-pixel MC?