shithub: libvpx

Download patch

ref: 3a4b18bc679f97590df9c6fc2004d8b5beed7a32
parent: 2f6fce3e5a30c9016aa816e2d57678d88729af37
author: Ronald S. Bultje <rbultje@google.com>
date: Wed Jan 30 04:30:46 EST 2013

don't code the branch for the predicted seg_id if that flag is false.

Change-Id: Icb6e21dc0c2d9918faa33c8bf70943660df7ad88

--- a/vp9/common/vp9_blockd.h
+++ b/vp9/common/vp9_blockd.h
@@ -358,6 +358,7 @@
 
   // Probability Tree used to code Segment number
   vp9_prob mb_segment_tree_probs[MB_FEATURE_TREE_PROBS];
+  vp9_prob mb_segment_mispred_tree_probs[MAX_MB_SEGMENTS];
 
 #if CONFIG_NEW_MVREF
   vp9_prob mb_mv_ref_probs[MAX_REF_FRAMES][MAX_MV_REF_CANDIDATES-1];
--- a/vp9/decoder/vp9_decodemv.c
+++ b/vp9/decoder/vp9_decodemv.c
@@ -87,6 +87,33 @@
   }
 }
 
+// This function reads the current macro block's segment id from the bitstream
+// It should only be called if a segment map update is indicated.
+static void read_mb_segid_except(VP9_COMMON *cm,
+                                 vp9_reader *r, MB_MODE_INFO *mi,
+                                 MACROBLOCKD *xd, int mb_row, int mb_col) {
+  int pred_seg_id = vp9_get_pred_mb_segid(cm, xd,
+                                          mb_row * cm->mb_cols + mb_col);
+  const vp9_prob *p = xd->mb_segment_tree_probs;
+  vp9_prob p1 = xd->mb_segment_mispred_tree_probs[pred_seg_id];
+
+  /* Is segmentation enabled */
+  if (xd->segmentation_enabled && xd->update_mb_segmentation_map) {
+    /* If so then read the segment id. */
+    if (vp9_read(r, p1)) {
+      if (pred_seg_id < 2)
+        mi->segment_id = 2 + vp9_read(r, p[2]);
+      else
+        mi->segment_id = 2 + (pred_seg_id == 2);
+    } else {
+      if (pred_seg_id >= 2)
+        mi->segment_id = vp9_read(r, p[1]);
+      else
+        mi->segment_id = pred_seg_id == 0;
+    }
+  }
+}
+
 #if CONFIG_NEW_MVREF
 int vp9_read_mv_ref_id(vp9_reader *r,
                        vp9_prob * ref_id_probs) {
@@ -602,7 +629,7 @@
         }
         // Else .... decode it explicitly
         else {
-          read_mb_segid(bc, mbmi, xd);
+          read_mb_segid_except(cm, bc, mbmi, xd, mb_row, mb_col);
         }
       }
       // Normal unpredicted coding mode
--- a/vp9/decoder/vp9_decodframe.c
+++ b/vp9/decoder/vp9_decodframe.c
@@ -1458,6 +1458,22 @@
           pc->segment_pred_probs[i] = 255;
         }
       }
+
+      if (pc->temporal_update) {
+        int count[4];
+        const vp9_prob *p = xd->mb_segment_tree_probs;
+        vp9_prob *p_mod = xd->mb_segment_mispred_tree_probs;
+
+        count[0] =        p[0]  *        p[1];
+        count[1] =        p[0]  * (256 - p[1]);
+        count[2] = (256 - p[0]) *        p[2];
+        count[3] = (256 - p[0]) * (256 - p[2]);
+
+        p_mod[0] = get_binary_prob(count[1], count[2] + count[3]);
+        p_mod[1] = get_binary_prob(count[0], count[2] + count[3]);
+        p_mod[2] = get_binary_prob(count[0] + count[1], count[3]);
+        p_mod[3] = get_binary_prob(count[0] + count[1], count[2]);
+      }
     }
     // Is the segment data being updated
     xd->update_mb_segmentation_data = (unsigned char)vp9_read_bit(&header_bc);
--- a/vp9/encoder/vp9_bitstream.c
+++ b/vp9/encoder/vp9_bitstream.c
@@ -577,6 +577,28 @@
   }
 }
 
+static void write_mb_segid_except(VP9_COMMON *cm,
+                                  vp9_writer *bc,
+                                  const MB_MODE_INFO *mi,
+                                  const MACROBLOCKD *xd,
+                                  int mb_row, int mb_col) {
+  // Encode the MB segment id.
+  int seg_id = mi->segment_id;
+  int pred_seg_id = vp9_get_pred_mb_segid(cm, xd,
+                                          mb_row * cm->mb_cols + mb_col);
+  const vp9_prob *p = xd->mb_segment_tree_probs;
+  const vp9_prob p1 = xd->mb_segment_mispred_tree_probs[pred_seg_id];
+
+  if (xd->segmentation_enabled && xd->update_mb_segmentation_map) {
+    vp9_write(bc, seg_id >= 2, p1);
+    if (pred_seg_id >= 2 && seg_id < 2) {
+      vp9_write(bc, seg_id == 1, p[1]);
+    } else if (pred_seg_id < 2 && seg_id >= 2) {
+      vp9_write(bc, seg_id == 3, p[2]);
+    }
+  }
+}
+
 // This function encodes the reference frame
 static void encode_ref_frame(vp9_writer *const bc,
                              VP9_COMMON *const cm,
@@ -720,7 +742,7 @@
 
       // If the mb segment id wasn't predicted code explicitly
       if (!prediction_flag)
-        write_mb_segid(bc, mi, &cpi->mb.e_mbd);
+        write_mb_segid_except(pc, bc, mi, &cpi->mb.e_mbd, mb_row, mb_col);
     } else {
       // Normal unpredicted coding
       write_mb_segid(bc, mi, &cpi->mb.e_mbd);
--- a/vp9/encoder/vp9_segmentation.c
+++ b/vp9/encoder/vp9_segmentation.c
@@ -143,11 +143,74 @@
   return cost;
 }
 
+// Based on set of segment counts calculate a probability tree
+static void calc_segtree_probs_pred(MACROBLOCKD *xd,
+                                    int (*segcounts)[MAX_MB_SEGMENTS],
+                                    vp9_prob *segment_tree_probs,
+                                    vp9_prob *mod_probs) {
+  int count[4];
+
+  assert(!segcounts[0][0] && !segcounts[1][1] &&
+         !segcounts[2][2] && !segcounts[3][3]);
+
+  // Total count for all segments
+  count[0] = segcounts[3][0] + segcounts[1][0] + segcounts[2][0];
+  count[1] = segcounts[2][1] + segcounts[0][1] + segcounts[3][1];
+  count[2] = segcounts[0][2] + segcounts[3][2] + segcounts[1][2];
+  count[3] = segcounts[1][3] + segcounts[2][3] + segcounts[0][3];
+
+  // Work out probabilities of each segment
+  segment_tree_probs[0] = get_binary_prob(count[0] + count[1],
+                                          count[2] + count[3]);
+  segment_tree_probs[1] = get_binary_prob(count[0], count[1]);
+  segment_tree_probs[2] = get_binary_prob(count[2], count[3]);
+
+  // now work out modified counts that the decoder would have
+  count[0] =        segment_tree_probs[0]  *        segment_tree_probs[1];
+  count[1] =        segment_tree_probs[0]  * (256 - segment_tree_probs[1]);
+  count[2] = (256 - segment_tree_probs[0]) *        segment_tree_probs[2];
+  count[3] = (256 - segment_tree_probs[0]) * (256 - segment_tree_probs[2]);
+
+  // Work out modified probabilities depending on what segment was predicted
+  mod_probs[0] = get_binary_prob(count[1], count[2] + count[3]);
+  mod_probs[1] = get_binary_prob(count[0], count[2] + count[3]);
+  mod_probs[2] = get_binary_prob(count[0] + count[1], count[3]);
+  mod_probs[3] = get_binary_prob(count[0] + count[1], count[2]);
+}
+
+// Based on set of segment counts and probabilities calculate a cost estimate
+static int cost_segmap_pred(MACROBLOCKD *xd,
+                            int (*segcounts)[MAX_MB_SEGMENTS],
+                            vp9_prob *probs, vp9_prob *mod_probs) {
+  int pred_seg, cost = 0;
+
+  for (pred_seg = 0; pred_seg < MAX_MB_SEGMENTS; pred_seg++) {
+    int count1, count2;
+
+    // Cost the top node of the tree
+    count1 = segcounts[pred_seg][0] + segcounts[pred_seg][1];
+    count2 = segcounts[pred_seg][2] + segcounts[pred_seg][3];
+    cost += count1 * vp9_cost_zero(mod_probs[pred_seg]) +
+            count2 * vp9_cost_one(mod_probs[pred_seg]);
+
+    // Now add the cost of each individual segment branch
+    if (pred_seg >= 2 && count1) {
+      cost += segcounts[pred_seg][0] * vp9_cost_zero(probs[1]) +
+              segcounts[pred_seg][1] * vp9_cost_one(probs[1]);
+    } else if (pred_seg < 2 && count2 > 0) {
+      cost += segcounts[pred_seg][2] * vp9_cost_zero(probs[2]) +
+              segcounts[pred_seg][3] * vp9_cost_one(probs[2]);
+    }
+  }
+
+  return cost;
+}
+
 static void count_segs(VP9_COMP *cpi,
                        MODE_INFO *mi,
                        int *no_pred_segcounts,
                        int (*temporal_predictor_count)[2],
-                       int *t_unpred_seg_counts,
+                       int (*t_unpred_seg_counts)[MAX_MB_SEGMENTS],
                        int mb_size, int mb_row, int mb_col) {
   VP9_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &cpi->mb.e_mbd;
@@ -166,8 +229,8 @@
   // Temporal prediction not allowed on key frames
   if (cm->frame_type != KEY_FRAME) {
     // Test to see if the segment id matches the predicted value.
-    const int seg_predicted =
-        (segment_id == vp9_get_pred_mb_segid(cm, xd, segmap_index));
+    const int pred_seg_id = vp9_get_pred_mb_segid(cm, xd, segmap_index);
+    const int seg_predicted = (segment_id == pred_seg_id);
 
     // Get the segment id prediction context
     const int pred_context = vp9_get_pred_context(cm, xd, PRED_SEG_ID);
@@ -179,7 +242,7 @@
 
     if (!seg_predicted)
       // Update the "unpredicted" segment count
-      t_unpred_seg_counts[segment_id]++;
+      t_unpred_seg_counts[pred_seg_id][segment_id]++;
   }
 }
 
@@ -195,10 +258,11 @@
 
   int temporal_predictor_count[PREDICTION_PROBS][2];
   int no_pred_segcounts[MAX_MB_SEGMENTS];
-  int t_unpred_seg_counts[MAX_MB_SEGMENTS];
+  int t_unpred_seg_counts[MAX_MB_SEGMENTS][MAX_MB_SEGMENTS];
 
   vp9_prob no_pred_tree[MB_FEATURE_TREE_PROBS];
   vp9_prob t_pred_tree[MB_FEATURE_TREE_PROBS];
+  vp9_prob t_pred_tree_mod[MAX_MB_SEGMENTS];
   vp9_prob t_nopred_prob[PREDICTION_PROBS];
 
   const int mis = cm->mode_info_stride;
@@ -270,8 +334,10 @@
   if (cm->frame_type != KEY_FRAME) {
     // Work out probability tree for coding those segments not
     // predicted using the temporal method and the cost.
-    calc_segtree_probs(xd, t_unpred_seg_counts, t_pred_tree);
-    t_pred_cost = cost_segmap(xd, t_unpred_seg_counts, t_pred_tree);
+    calc_segtree_probs_pred(xd, t_unpred_seg_counts, t_pred_tree,
+                            t_pred_tree_mod);
+    t_pred_cost = cost_segmap_pred(xd, t_unpred_seg_counts, t_pred_tree,
+                                   t_pred_tree_mod);
 
     // Add in the cost of the signalling for each prediction context
     for (i = 0; i < PREDICTION_PROBS; i++) {
@@ -291,6 +357,8 @@
     cm->temporal_update = 1;
     vpx_memcpy(xd->mb_segment_tree_probs,
                t_pred_tree, sizeof(t_pred_tree));
+    vpx_memcpy(xd->mb_segment_mispred_tree_probs,
+               t_pred_tree_mod, sizeof(t_pred_tree_mod));
     vpx_memcpy(&cm->segment_pred_probs,
                t_nopred_prob, sizeof(t_nopred_prob));
   } else {