shithub: libvpx

Download patch

ref: 3ee1a21a4291fb35f7cb7dc8835ceae1539ccb71
parent: 4df9e7883c0ffb752715c87931baf58b7caee76c
author: Deb Mukherjee <debargha@google.com>
date: Thu Jun 6 07:14:04 EDT 2013

Coding updates for tx-size selection

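Changes the coding of transform sizes: the three global tx-size
probabilities are replaced by six block-size-dependent probabilities
stored in the frame context, with both forward (explicitly signaled
in the frame header) and backward (adaptive) probability updates.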

Results:
derf300: +0.241%

Context-based coding of transform sizes will follow in a separate
patch.

Change-Id: I97241d60a926f014fee2de21fa4446ca56495756

--- a/vp9/common/vp9_blockd.h
+++ b/vp9/common/vp9_blockd.h
@@ -122,6 +122,15 @@
 
 #define WHT_UPSCALE_FACTOR 2
 
+#define TX_SIZE_PROBS  6  // (TX_SIZE_MAX_SB * (TX_SIZE_MAX_SB - 1) / 2)
+
+#if TX_SIZE_PROBS == 6
+#define get_tx_probs_offset(b) ((b) < BLOCK_SIZE_MB16X16 ? 0 : \
+                                (b) < BLOCK_SIZE_SB32X32 ? 1 : 3)
+#else
+#define get_tx_probs_offset(b) 0
+#endif
+
 /* For keyframes, intra block modes are predicted by the (already decoded)
    modes for the Y blocks to the left and above us; for interframes, there
    is a single probability table. */
--- a/vp9/common/vp9_entropymode.c
+++ b/vp9/common/vp9_entropymode.c
@@ -149,6 +149,54 @@
   { 235, 248 },
 };
 
+void tx_counts_to_branch_counts(unsigned int *tx_count_32x32p,
+                                unsigned int *tx_count_16x16p,
+                                unsigned int *tx_count_8x8p,
+                                unsigned int (*ct)[2]) {
+#if TX_SIZE_PROBS == 6
+  ct[0][0] = tx_count_8x8p[TX_4X4];
+  ct[0][1] = tx_count_8x8p[TX_8X8];
+  ct[1][0] = tx_count_16x16p[TX_4X4];
+  ct[1][1] = tx_count_16x16p[TX_8X8] + tx_count_16x16p[TX_16X16];
+  ct[2][0] = tx_count_16x16p[TX_8X8];
+  ct[2][1] = tx_count_16x16p[TX_16X16];
+  ct[3][0] = tx_count_32x32p[TX_4X4];
+  ct[3][1] = tx_count_32x32p[TX_8X8] + tx_count_32x32p[TX_16X16] +
+             tx_count_32x32p[TX_32X32];
+  ct[4][0] = tx_count_32x32p[TX_8X8];
+  ct[4][1] = tx_count_32x32p[TX_16X16] + tx_count_32x32p[TX_32X32];
+  ct[5][0] = tx_count_32x32p[TX_16X16];
+  ct[5][1] = tx_count_32x32p[TX_32X32];
+#else
+  ct[0][0] = tx_count_32x32p[TX_4X4] +
+             tx_count_16x16p[TX_4X4] +
+             tx_count_8x8p[TX_4X4];
+  ct[0][1] = tx_count_32x32p[TX_8X8] +
+             tx_count_32x32p[TX_16X16] +
+             tx_count_32x32p[TX_32X32] +
+             tx_count_16x16p[TX_8X8] +
+             tx_count_16x16p[TX_16X16] +
+             tx_count_8x8p[TX_8X8];
+  ct[1][0] = tx_count_32x32p[TX_8X8] +
+             tx_count_16x16p[TX_8X8];
+  ct[1][1] = tx_count_32x32p[TX_16X16] +
+             tx_count_32x32p[TX_32X32] +
+             tx_count_16x16p[TX_16X16];
+  ct[2][0] = tx_count_32x32p[TX_16X16];
+  ct[2][1] = tx_count_32x32p[TX_32X32];
+#endif
+}
+
+#if TX_SIZE_PROBS == 6
+const vp9_prob vp9_default_tx_probs[TX_SIZE_PROBS] = {
+  96, 96, 96, 96, 96, 96
+};
+#else
+const vp9_prob vp9_default_tx_probs[TX_SIZE_PROBS] = {
+  96, 96, 96
+};
+#endif
+
 void vp9_init_mbmode_probs(VP9_COMMON *x) {
   vpx_memcpy(x->fc.uv_mode_prob, default_if_uv_probs,
              sizeof(default_if_uv_probs));
@@ -171,6 +219,8 @@
              sizeof(default_comp_ref_p));
   vpx_memcpy(x->fc.single_ref_prob, default_single_ref_p,
              sizeof(default_single_ref_p));
+  vpx_memcpy(x->fc.tx_probs, vp9_default_tx_probs,
+             sizeof(vp9_default_tx_probs));
 }
 
 #if VP9_SWITCHABLE_FILTERS == 3
@@ -372,6 +422,20 @@
                         fc->switchable_interp_count[i],
                         fc->pre_switchable_interp_prob[i],
                         fc->switchable_interp_prob[i], 0);
+    }
+  }
+  if (cm->txfm_mode == TX_MODE_SELECT) {
+    unsigned int branch_ct[TX_SIZE_PROBS][2];
+    tx_counts_to_branch_counts(cm->fc.tx_count_32x32p,
+                               cm->fc.tx_count_16x16p,
+                               cm->fc.tx_count_8x8p, branch_ct);
+    for (i = 0; i < TX_SIZE_PROBS; ++i) {
+      int factor;
+      int count = branch_ct[i][0] + branch_ct[i][1];
+      vp9_prob prob = get_binary_prob(branch_ct[i][0], branch_ct[i][1]);
+      count = count > MODE_COUNT_SAT ? MODE_COUNT_SAT : count;
+      factor = (MODE_MAX_UPDATE_FACTOR * count / MODE_COUNT_SAT);
+      cm->fc.tx_probs[i] = weighted_prob(cm->fc.pre_tx_probs[i], prob, factor);
     }
   }
 }
--- a/vp9/common/vp9_entropymode.h
+++ b/vp9/common/vp9_entropymode.h
@@ -75,4 +75,11 @@
 extern const  vp9_prob vp9_switchable_interp_prob[VP9_SWITCHABLE_FILTERS + 1]
                                                  [VP9_SWITCHABLE_FILTERS - 1];
 
+extern const vp9_prob vp9_default_tx_probs[TX_SIZE_PROBS];
+
+extern void tx_counts_to_branch_counts(unsigned int *tx_count_32x32p,
+                                       unsigned int *tx_count_16x16p,
+                                       unsigned int *tx_count_8x8p,
+                                       unsigned int (*ct)[2]);
+
 #endif  // VP9_COMMON_VP9_ENTROPYMODE_H_
--- a/vp9/common/vp9_onyxc_int.h
+++ b/vp9/common/vp9_onyxc_int.h
@@ -89,6 +89,11 @@
   unsigned int comp_inter_count[COMP_INTER_CONTEXTS][2];
   unsigned int single_ref_count[REF_CONTEXTS][2][2];
   unsigned int comp_ref_count[REF_CONTEXTS][2];
+  vp9_prob tx_probs[TX_SIZE_PROBS];
+  vp9_prob pre_tx_probs[TX_SIZE_PROBS];
+  unsigned int tx_count_32x32p[TX_SIZE_MAX_SB];
+  unsigned int tx_count_16x16p[TX_SIZE_MAX_SB - 1];
+  unsigned int tx_count_8x8p[TX_SIZE_MAX_SB - 2];
 } FRAME_CONTEXT;
 
 typedef enum {
@@ -238,9 +243,6 @@
   MV_REFERENCE_FRAME comp_fixed_ref;
   MV_REFERENCE_FRAME comp_var_ref[2];
   COMPPREDMODE_TYPE comp_pred_mode;
-
-  // FIXME contextualize
-  vp9_prob prob_tx[TX_SIZE_MAX_SB - 1];
 
   vp9_prob mbskip_pred_probs[MBSKIP_CONTEXTS];
 
--- a/vp9/decoder/vp9_decodemv.c
+++ b/vp9/decoder/vp9_decodemv.c
@@ -64,13 +64,21 @@
 }
 
 static TX_SIZE select_txfm_size(VP9_COMMON *cm, vp9_reader *r,
-                                int allow_16x16, int allow_32x32) {
-  TX_SIZE txfm_size = vp9_read(r, cm->prob_tx[0]);  // TX_4X4 or >TX_4X4
-  if (txfm_size != TX_4X4 && allow_16x16) {
-    txfm_size += vp9_read(r, cm->prob_tx[1]);       // TX_8X8 or >TX_8X8
-    if (txfm_size != TX_8X8 && allow_32x32)
-      txfm_size += vp9_read(r, cm->prob_tx[2]);     // TX_16X16 or >TX_16X16
+                                BLOCK_SIZE_TYPE bsize) {
+  int tx_probs_offset = get_tx_probs_offset(bsize);
+  TX_SIZE txfm_size = vp9_read(r, cm->fc.tx_probs[tx_probs_offset]);
+  if (txfm_size != TX_4X4 && bsize >= BLOCK_SIZE_MB16X16) {
+    txfm_size += vp9_read(r, cm->fc.tx_probs[tx_probs_offset + 1]);
+    if (txfm_size != TX_8X8 && bsize >= BLOCK_SIZE_SB32X32)
+      txfm_size += vp9_read(r, cm->fc.tx_probs[tx_probs_offset + 2]);
   }
+  if (bsize >= BLOCK_SIZE_SB32X32) {
+    cm->fc.tx_count_32x32p[txfm_size]++;
+  } else if (bsize >= BLOCK_SIZE_MB16X16) {
+    cm->fc.tx_count_16x16p[txfm_size]++;
+  } else {
+    cm->fc.tx_count_8x8p[txfm_size]++;
+  }
   return txfm_size;
 }
 
@@ -96,9 +104,7 @@
 
   if (cm->txfm_mode == TX_MODE_SELECT &&
       m->mbmi.sb_type >= BLOCK_SIZE_SB8X8) {
-    const int allow_16x16 = m->mbmi.sb_type >= BLOCK_SIZE_MB16X16;
-    const int allow_32x32 = m->mbmi.sb_type >= BLOCK_SIZE_SB32X32;
-    m->mbmi.txfm_size = select_txfm_size(cm, r, allow_16x16, allow_32x32);
+    m->mbmi.txfm_size = select_txfm_size(cm, r,  m->mbmi.sb_type);
   } else if (cm->txfm_mode >= ALLOW_32X32 &&
              m->mbmi.sb_type >= BLOCK_SIZE_SB32X32) {
     m->mbmi.txfm_size = TX_32X32;
@@ -537,9 +543,7 @@
   if (cm->txfm_mode == TX_MODE_SELECT &&
       (mbmi->mb_skip_coeff == 0 || mbmi->ref_frame[0] == INTRA_FRAME) &&
       bsize >= BLOCK_SIZE_SB8X8) {
-    const int allow_16x16 = bsize >= BLOCK_SIZE_MB16X16;
-    const int allow_32x32 = bsize >= BLOCK_SIZE_SB32X32;
-    mbmi->txfm_size = select_txfm_size(cm, r, allow_16x16, allow_32x32);
+    mbmi->txfm_size = select_txfm_size(cm, r, bsize);
   } else if (bsize >= BLOCK_SIZE_SB32X32 &&
              cm->txfm_mode >= ALLOW_32X32) {
     mbmi->txfm_size = TX_32X32;
--- a/vp9/decoder/vp9_decodframe.c
+++ b/vp9/decoder/vp9_decodframe.c
@@ -57,11 +57,16 @@
     pc->txfm_mode = vp9_read_literal(r, 2);
     if (pc->txfm_mode == ALLOW_32X32)
       pc->txfm_mode += vp9_read_bit(r);
-
     if (pc->txfm_mode == TX_MODE_SELECT) {
-      pc->prob_tx[0] = vp9_read_prob(r);
-      pc->prob_tx[1] = vp9_read_prob(r);
-      pc->prob_tx[2] = vp9_read_prob(r);
+      int i;
+      for (i = 0; i < TX_SIZE_PROBS; ++i) {
+        if (vp9_read(r, VP9_DEF_UPDATE_PROB))
+           pc->fc.tx_probs[i] =
+               vp9_read_prob_diff_update(r, pc->fc.tx_probs[i]);
+      }
+    } else {
+      vpx_memcpy(pc->fc.tx_probs, vp9_default_tx_probs,
+                 sizeof(vp9_default_tx_probs));
     }
   }
 }
@@ -779,6 +784,7 @@
   fc->pre_nmvc = fc->nmvc;
   vp9_copy(fc->pre_switchable_interp_prob, fc->switchable_interp_prob);
   vp9_copy(fc->pre_inter_mode_probs, fc->inter_mode_probs);
+  vp9_copy(fc->pre_tx_probs, fc->tx_probs);
 
   vp9_zero(fc->coef_counts);
   vp9_zero(fc->eob_branch_counts);
@@ -792,6 +798,9 @@
   vp9_zero(fc->comp_inter_count);
   vp9_zero(fc->single_ref_count);
   vp9_zero(fc->comp_ref_count);
+  vp9_zero(fc->tx_count_8x8p);
+  vp9_zero(fc->tx_count_16x16p);
+  vp9_zero(fc->tx_count_32x32p);
 }
 
 static void decode_tile(VP9D_COMP *pbi, vp9_reader *r) {
@@ -1050,9 +1059,9 @@
 
   pc->fc = pc->frame_contexts[pc->frame_context_idx];
 
-  setup_txfm_mode(pc, xd->lossless, &header_bc);
-
   update_frame_context(&pc->fc);
+
+  setup_txfm_mode(pc, xd->lossless, &header_bc);
 
   read_coef_probs(pbi, &header_bc);
 
--- a/vp9/encoder/vp9_bitstream.c
+++ b/vp9/encoder/vp9_bitstream.c
@@ -575,12 +575,12 @@
       !(rf != INTRA_FRAME &&
         (skip_coeff || vp9_segfeature_active(xd, segment_id, SEG_LVL_SKIP)))) {
     TX_SIZE sz = mi->txfm_size;
-    // FIXME(rbultje) code ternary symbol once all experiments are merged
-    vp9_write(bc, sz != TX_4X4, pc->prob_tx[0]);
+    int tx_probs_offset = get_tx_probs_offset(mi->sb_type);
+    vp9_write(bc, sz != TX_4X4, pc->fc.tx_probs[tx_probs_offset]);
     if (mi->sb_type >= BLOCK_SIZE_MB16X16 && sz != TX_4X4) {
-      vp9_write(bc, sz != TX_8X8, pc->prob_tx[1]);
+      vp9_write(bc, sz != TX_8X8, pc->fc.tx_probs[tx_probs_offset + 1]);
       if (mi->sb_type >= BLOCK_SIZE_SB32X32 && sz != TX_8X8)
-        vp9_write(bc, sz != TX_16X16, pc->prob_tx[2]);
+        vp9_write(bc, sz != TX_16X16, pc->fc.tx_probs[tx_probs_offset + 2]);
     }
   }
 
@@ -706,12 +706,12 @@
 
   if (m->mbmi.sb_type >= BLOCK_SIZE_SB8X8 && c->txfm_mode == TX_MODE_SELECT) {
     TX_SIZE sz = m->mbmi.txfm_size;
-    // FIXME(rbultje) code ternary symbol once all experiments are merged
-    vp9_write(bc, sz != TX_4X4, c->prob_tx[0]);
+    int tx_probs_offset = get_tx_probs_offset(m->mbmi.sb_type);
+    vp9_write(bc, sz != TX_4X4, c->fc.tx_probs[tx_probs_offset]);
     if (m->mbmi.sb_type >= BLOCK_SIZE_MB16X16 && sz != TX_4X4) {
-      vp9_write(bc, sz != TX_8X8, c->prob_tx[1]);
+      vp9_write(bc, sz != TX_8X8, c->fc.tx_probs[tx_probs_offset + 1]);
       if (m->mbmi.sb_type >= BLOCK_SIZE_SB32X32 && sz != TX_8X8)
-        vp9_write(bc, sz != TX_16X16, c->prob_tx[2]);
+        vp9_write(bc, sz != TX_16X16, c->fc.tx_probs[tx_probs_offset + 2]);
     }
   }
 
@@ -1217,7 +1217,7 @@
 }
 
 
-static void encode_txfm(VP9_COMP *cpi, vp9_writer *w) {
+static void encode_txfm_probs(VP9_COMP *cpi, vp9_writer *w) {
   VP9_COMMON *const cm = &cpi->common;
 
   // Mode
@@ -1227,35 +1227,19 @@
 
   // Probabilities
   if (cm->txfm_mode == TX_MODE_SELECT) {
-    cm->prob_tx[0] = get_prob(cpi->txfm_count_32x32p[TX_4X4] +
-                              cpi->txfm_count_16x16p[TX_4X4] +
-                              cpi->txfm_count_8x8p[TX_4X4],
-                              cpi->txfm_count_32x32p[TX_4X4] +
-                              cpi->txfm_count_32x32p[TX_8X8] +
-                              cpi->txfm_count_32x32p[TX_16X16] +
-                              cpi->txfm_count_32x32p[TX_32X32] +
-                              cpi->txfm_count_16x16p[TX_4X4] +
-                              cpi->txfm_count_16x16p[TX_8X8] +
-                              cpi->txfm_count_16x16p[TX_16X16] +
-                              cpi->txfm_count_8x8p[TX_4X4] +
-                              cpi->txfm_count_8x8p[TX_8X8]);
-    cm->prob_tx[1] = get_prob(cpi->txfm_count_32x32p[TX_8X8] +
-                              cpi->txfm_count_16x16p[TX_8X8],
-                              cpi->txfm_count_32x32p[TX_8X8] +
-                              cpi->txfm_count_32x32p[TX_16X16] +
-                              cpi->txfm_count_32x32p[TX_32X32] +
-                              cpi->txfm_count_16x16p[TX_8X8] +
-                              cpi->txfm_count_16x16p[TX_16X16]);
-    cm->prob_tx[2] = get_prob(cpi->txfm_count_32x32p[TX_16X16],
-                              cpi->txfm_count_32x32p[TX_16X16] +
-                              cpi->txfm_count_32x32p[TX_32X32]);
-    vp9_write_prob(w, cm->prob_tx[0]);
-    vp9_write_prob(w, cm->prob_tx[1]);
-    vp9_write_prob(w, cm->prob_tx[2]);
+    int i;
+    unsigned int ct[TX_SIZE_PROBS][2];
+    tx_counts_to_branch_counts(cm->fc.tx_count_32x32p,
+                               cm->fc.tx_count_16x16p,
+                               cm->fc.tx_count_8x8p, ct);
+
+    for (i = 0; i < TX_SIZE_PROBS; i++) {
+      vp9_cond_prob_diff_update(w, &cm->fc.tx_probs[i],
+                                VP9_DEF_UPDATE_PROB, ct[i]);
+    }
   } else {
-    cm->prob_tx[0] = 128;
-    cm->prob_tx[1] = 128;
-    cm->prob_tx[2] = 128;
+    vpx_memcpy(cm->fc.tx_probs, vp9_default_tx_probs,
+               sizeof(vp9_default_tx_probs));
   }
 }
 
@@ -1440,11 +1424,6 @@
     active_section = 7;
 #endif
 
-  if (xd->lossless)
-    pc->txfm_mode = ONLY_4X4;
-  else
-    encode_txfm(cpi, &header_bc);
-
   vp9_clear_system_state();  // __asm emms;
 
   vp9_copy(pc->fc.pre_coef_probs, pc->fc.coef_probs);
@@ -1460,6 +1439,13 @@
   vp9_copy(pc->fc.pre_comp_ref_prob, pc->fc.comp_ref_prob);
   vp9_copy(pc->fc.pre_single_ref_prob, pc->fc.single_ref_prob);
   cpi->common.fc.pre_nmvc = cpi->common.fc.nmvc;
+  vp9_copy(cpi->common.fc.pre_tx_probs, cpi->common.fc.tx_probs);
+
+  if (xd->lossless) {
+    pc->txfm_mode = ONLY_4X4;
+  } else {
+    encode_txfm_probs(cpi, &header_bc);
+  }
 
   update_coef_probs(cpi, &header_bc);
 
--- a/vp9/encoder/vp9_encodeframe.c
+++ b/vp9/encoder/vp9_encodeframe.c
@@ -1464,12 +1464,15 @@
 
   vp9_zero(cpi->y_mode_count)
   vp9_zero(cpi->y_uv_mode_count)
-  vp9_zero(cpi->common.fc.inter_mode_counts)
+  vp9_zero(cm->fc.inter_mode_counts)
   vp9_zero(cpi->partition_count);
   vp9_zero(cpi->intra_inter_count);
   vp9_zero(cpi->comp_inter_count);
   vp9_zero(cpi->single_ref_count);
   vp9_zero(cpi->comp_ref_count);
+  vp9_zero(cm->fc.tx_count_32x32p);
+  vp9_zero(cm->fc.tx_count_16x16p);
+  vp9_zero(cm->fc.tx_count_8x8p);
 
   // Note: this memset assumes above_context[0], [1] and [2]
   // are allocated as part of the same buffer.
@@ -1560,9 +1563,6 @@
   init_encode_frame_mb_context(cpi);
 
   vpx_memset(cpi->rd_comp_pred_diff, 0, sizeof(cpi->rd_comp_pred_diff));
-  vpx_memset(cpi->txfm_count_32x32p, 0, sizeof(cpi->txfm_count_32x32p));
-  vpx_memset(cpi->txfm_count_16x16p, 0, sizeof(cpi->txfm_count_16x16p));
-  vpx_memset(cpi->txfm_count_8x8p, 0, sizeof(cpi->txfm_count_8x8p));
   vpx_memset(cpi->rd_tx_select_diff, 0, sizeof(cpi->rd_tx_select_diff));
   vpx_memset(cpi->rd_tx_select_threshes, 0, sizeof(cpi->rd_tx_select_threshes));
 
@@ -1841,11 +1841,6 @@
                     ALLOW_32X32 : TX_MODE_SELECT;
 #endif
     cpi->common.txfm_mode = txfm_type;
-    if (txfm_type != TX_MODE_SELECT) {
-      cpi->common.prob_tx[0] = 128;
-      cpi->common.prob_tx[1] = 128;
-      cpi->common.prob_tx[2] = 128;
-    }
     cpi->common.comp_pred_mode = pred_type;
     encode_frame_internal(cpi);
 
@@ -1885,15 +1880,15 @@
     }
 
     if (cpi->common.txfm_mode == TX_MODE_SELECT) {
-      const int count4x4 = cpi->txfm_count_16x16p[TX_4X4] +
-                           cpi->txfm_count_32x32p[TX_4X4] +
-                           cpi->txfm_count_8x8p[TX_4X4];
-      const int count8x8_lp = cpi->txfm_count_32x32p[TX_8X8] +
-                              cpi->txfm_count_16x16p[TX_8X8];
-      const int count8x8_8x8p = cpi->txfm_count_8x8p[TX_8X8];
-      const int count16x16_16x16p = cpi->txfm_count_16x16p[TX_16X16];
-      const int count16x16_lp = cpi->txfm_count_32x32p[TX_16X16];
-      const int count32x32 = cpi->txfm_count_32x32p[TX_32X32];
+      const int count4x4 = cm->fc.tx_count_16x16p[TX_4X4] +
+                           cm->fc.tx_count_32x32p[TX_4X4] +
+                           cm->fc.tx_count_8x8p[TX_4X4];
+      const int count8x8_lp = cm->fc.tx_count_32x32p[TX_8X8] +
+                              cm->fc.tx_count_16x16p[TX_8X8];
+      const int count8x8_8x8p = cm->fc.tx_count_8x8p[TX_8X8];
+      const int count16x16_16x16p = cm->fc.tx_count_16x16p[TX_16X16];
+      const int count16x16_lp = cm->fc.tx_count_32x32p[TX_16X16];
+      const int count32x32 = cm->fc.tx_count_32x32p[TX_32X32];
 
       if (count4x4 == 0 && count16x16_lp == 0 && count16x16_16x16p == 0 &&
           count32x32 == 0) {
@@ -2077,11 +2072,11 @@
         !(mbmi->ref_frame[0] != INTRA_FRAME && (mbmi->mb_skip_coeff ||
           vp9_segfeature_active(xd, segment_id, SEG_LVL_SKIP)))) {
       if (bsize >= BLOCK_SIZE_SB32X32) {
-        cpi->txfm_count_32x32p[mbmi->txfm_size]++;
+        cm->fc.tx_count_32x32p[mbmi->txfm_size]++;
       } else if (bsize >= BLOCK_SIZE_MB16X16) {
-        cpi->txfm_count_16x16p[mbmi->txfm_size]++;
+        cm->fc.tx_count_16x16p[mbmi->txfm_size]++;
       } else {
-        cpi->txfm_count_8x8p[mbmi->txfm_size]++;
+        cm->fc.tx_count_8x8p[mbmi->txfm_size]++;
       }
     } else {
       int x, y;
--- a/vp9/encoder/vp9_onyx_if.c
+++ b/vp9/encoder/vp9_onyx_if.c
@@ -1298,8 +1298,6 @@
   cpi->frames_till_gf_update_due    = 0;
   cpi->gf_overspend_bits            = 0;
   cpi->non_gf_bitrate_adjustment    = 0;
-  for (i = 0; i < TX_SIZE_MAX_SB - 1; i++)
-    cm->prob_tx[i]               = 128;
 
   // Set reference frame sign bias for ALTREF frame to 1 (for now)
   cpi->common.ref_frame_sign_bias[ALTREF_FRAME] = 1;
--- a/vp9/encoder/vp9_onyx_int.h
+++ b/vp9/encoder/vp9_onyx_int.h
@@ -89,6 +89,7 @@
   int inter_mode_counts[INTER_MODE_CONTEXTS][VP9_INTER_MODES - 1][2];
   vp9_prob inter_mode_probs[INTER_MODE_CONTEXTS][VP9_INTER_MODES - 1];
 
+  vp9_prob tx_probs[TX_SIZE_PROBS];
 } CODING_CONTEXT;
 
 typedef struct {
@@ -326,9 +327,7 @@
   unsigned int comp_ref_count[REF_CONTEXTS][2];
 
   // FIXME contextualize
-  int txfm_count_32x32p[TX_SIZE_MAX_SB];
-  int txfm_count_16x16p[TX_SIZE_MAX_SB - 1];
-  int txfm_count_8x8p[TX_SIZE_MAX_SB - 2];
+
   int64_t rd_tx_select_diff[NB_TXFM_MODES];
   int rd_tx_select_threshes[4][NB_TXFM_MODES];
 
--- a/vp9/encoder/vp9_ratectrl.c
+++ b/vp9/encoder/vp9_ratectrl.c
@@ -143,6 +143,7 @@
 
   vp9_copy(cc->coef_probs, cm->fc.coef_probs);
   vp9_copy(cc->switchable_interp_prob, cm->fc.switchable_interp_prob);
+  vp9_copy(cc->tx_probs, cm->fc.tx_probs);
 }
 
 void vp9_restore_coding_context(VP9_COMP *cpi) {
@@ -180,6 +181,7 @@
 
   vp9_copy(cm->fc.coef_probs, cc->coef_probs);
   vp9_copy(cm->fc.switchable_interp_prob, cc->switchable_interp_prob);
+  vp9_copy(cm->fc.tx_probs, cc->tx_probs);
 }
 
 void vp9_setup_key_frame(VP9_COMP *cpi) {
--- a/vp9/encoder/vp9_rdopt.c
+++ b/vp9/encoder/vp9_rdopt.c
@@ -420,6 +420,7 @@
                                      int *d, int *distortion,
                                      int *s, int *skip,
                                      int64_t txfm_cache[NB_TXFM_MODES],
+                                     BLOCK_SIZE_TYPE bs,
                                      TX_SIZE max_txfm_size) {
   VP9_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
@@ -429,13 +430,15 @@
   int n, m;
   int s0, s1;
 
+  int tx_probs_offset = get_tx_probs_offset(bs);
+
   for (n = TX_4X4; n <= max_txfm_size; n++) {
     r[n][1] = r[n][0];
     for (m = 0; m <= n - (n == max_txfm_size); m++) {
       if (m == n)
-        r[n][1] += vp9_cost_zero(cm->prob_tx[m]);
+        r[n][1] += vp9_cost_zero(cm->fc.tx_probs[tx_probs_offset + m]);
       else
-        r[n][1] += vp9_cost_one(cm->prob_tx[m]);
+        r[n][1] += vp9_cost_one(cm->fc.tx_probs[tx_probs_offset + m]);
     }
   }
 
@@ -608,6 +611,8 @@
   MACROBLOCKD *xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mode_info_context->mbmi;
 
+  assert(bs == mbmi->sb_type);
+
   if (mbmi->ref_frame[0] > INTRA_FRAME)
     vp9_subtract_sby(x, bs);
 
@@ -637,7 +642,8 @@
   super_block_yrd_for_txfm(cm, x, &r[TX_4X4][0], &d[TX_4X4], &s[TX_4X4], bs,
                            TX_4X4);
 
-  choose_txfm_size_from_rd(cpi, x, r, rate, d, distortion, s, skip, txfm_cache,
+  choose_txfm_size_from_rd(cpi, x, r, rate, d, distortion, s,
+                           skip, txfm_cache, bs,
                            TX_32X32 - (bs < BLOCK_SIZE_SB32X32)
                            - (bs < BLOCK_SIZE_MB16X16));
 }