shithub: libvpx

Download patch

ref: f39b0f192fed13cc788f76877dd2a6495cfe8dd4
parent: a7333b0a5b9f9d435bba7b1dce72632ae70c0330
author: Ronald S. Bultje <rbultje@google.com>
date: Wed Oct 10 13:18:22 EDT 2012

Also use transform-size selection for I8X8_PRED macroblocks.

Change-Id: Iecb282fc89f9b5145ef31c5eda294ad42bc32a5d

--- a/vp8/decoder/decodemv.c
+++ b/vp8/decoder/decodemv.c
@@ -172,10 +172,10 @@
 
 #if CONFIG_TX_SELECT
   if (cm->txfm_mode == TX_MODE_SELECT && m->mbmi.mb_skip_coeff == 0 &&
-      m->mbmi.mode <= TM_PRED) {
+      m->mbmi.mode <= I8X8_PRED) {
     // FIXME(rbultje) code ternary symbol once all experiments are merged
     m->mbmi.txfm_size = vp8_read(bc, cm->prob_tx[0]);
-    if (m->mbmi.txfm_size != TX_4X4)
+    if (m->mbmi.txfm_size != TX_4X4 && m->mbmi.mode != I8X8_PRED)
       m->mbmi.txfm_size += vp8_read(bc, cm->prob_tx[1]);
   } else
 #endif
@@ -1281,11 +1281,11 @@
 
 #if CONFIG_TX_SELECT
   if (cm->txfm_mode == TX_MODE_SELECT && mbmi->mb_skip_coeff == 0 &&
-      ((mbmi->ref_frame == INTRA_FRAME && mbmi->mode <= TM_PRED) ||
+      ((mbmi->ref_frame == INTRA_FRAME && mbmi->mode <= I8X8_PRED) ||
        (mbmi->ref_frame != INTRA_FRAME && mbmi->mode != SPLITMV))) {
     // FIXME(rbultje) code ternary symbol once all experiments are merged
     mbmi->txfm_size = vp8_read(bc, cm->prob_tx[0]);
-    if (mbmi->txfm_size != TX_4X4)
+    if (mbmi->txfm_size != TX_4X4 && mbmi->mode != I8X8_PRED)
       mbmi->txfm_size += vp8_read(bc, cm->prob_tx[1]);
   } else
 #endif
--- a/vp8/decoder/decodframe.c
+++ b/vp8/decoder/decodframe.c
@@ -364,10 +364,8 @@
   if (mode == I8X8_PRED) {
     for (i = 0; i < 4; i++) {
       int ib = vp8_i8x8_block[i];
-#if !CONFIG_HYBRIDTRANSFORM8X8
       const int iblock[4] = {0, 1, 4, 5};
       int j;
-#endif
       int i8x8mode;
       BLOCKD *b;
 
@@ -381,30 +379,40 @@
 
       b = &xd->block[ib];
       i8x8mode = b->bmi.as_mode.first;
-      RECON_INVOKE(RTCD_VTABLE(recon), intra8x8_predict)
-      (b, i8x8mode, b->predictor);
+      RECON_INVOKE(RTCD_VTABLE(recon), intra8x8_predict)(b, i8x8mode,
+                                                         b->predictor);
 
+      if (xd->mode_info_context->mbmi.txfm_size == TX_8X8) {
 #if CONFIG_HYBRIDTRANSFORM8X8
-      vp8_ht_dequant_idct_add_8x8_c(b->bmi.as_mode.tx_type,
-                                    q, dq, pre, dst, 16, stride);
-      q += 64;
+        vp8_ht_dequant_idct_add_8x8_c(b->bmi.as_mode.tx_type,
+                                      q, dq, pre, dst, 16, stride);
+        q += 64;
 #else
-      vp8_dequant_idct_add_8x8_c(q, dq, pre, dst, 16, stride);
-      q += 64;
+        vp8_dequant_idct_add_8x8_c(q, dq, pre, dst, 16, stride);
+        q += 64;
 #endif
+      } else {
+        for (j = 0; j < 4; j++) {
+          b = &xd->block[ib + iblock[j]];
+          vp8_dequant_idct_add_c(b->qcoeff, b->dequant, b->predictor,
+                                 *(b->base_dst) + b->dst, 16, b->dst_stride);
+        }
+      }
 
       b = &xd->block[16 + i];
-      RECON_INVOKE(RTCD_VTABLE(recon), intra_uv4x4_predict)
-      (b, i8x8mode, b->predictor);
-      DEQUANT_INVOKE(&pbi->dequant, idct_add)
-      (b->qcoeff, b->dequant,  b->predictor,
-       *(b->base_dst) + b->dst, 8, b->dst_stride);
+      RECON_INVOKE(RTCD_VTABLE(recon), intra_uv4x4_predict)(b, i8x8mode,
+                                                            b->predictor);
+      DEQUANT_INVOKE(&pbi->dequant, idct_add)(b->qcoeff, b->dequant,
+                                              b->predictor,
+                                              *(b->base_dst) + b->dst, 8,
+                                              b->dst_stride);
       b = &xd->block[20 + i];
-      RECON_INVOKE(RTCD_VTABLE(recon), intra_uv4x4_predict)
-      (b, i8x8mode, b->predictor);
-      DEQUANT_INVOKE(&pbi->dequant, idct_add)
-      (b->qcoeff, b->dequant,  b->predictor,
-       *(b->base_dst) + b->dst, 8, b->dst_stride);
+      RECON_INVOKE(RTCD_VTABLE(recon), intra_uv4x4_predict)(b, i8x8mode,
+                                                            b->predictor);
+      DEQUANT_INVOKE(&pbi->dequant, idct_add)(b->qcoeff, b->dequant,
+                                              b->predictor,
+                                              *(b->base_dst) + b->dst, 8,
+                                              b->dst_stride);
     }
   } else if (mode == B_PRED) {
     for (i = 0; i < 16; i++) {
--- a/vp8/encoder/bitstream.c
+++ b/vp8/encoder/bitstream.c
@@ -239,6 +239,18 @@
   }
 }
 
+static __inline int get_prob(int num, int den) {  // rounded 255*num/den, in [1, 255]
+  int p;
+  if (den <= 0)
+    return 128;  // no samples counted: return the midpoint value 128
+  p = (num * 255 + (den >> 1)) / den;  // round-to-nearest of 255 * num / den
+  if (p > 255)
+    return 255;
+  else if (p < 1)
+    return 1;  // clamp so the result always lies in [1, 255]
+  return p;
+}
+
 void update_skip_probs(VP8_COMP *cpi) {
   VP8_COMMON *const pc = & cpi->common;
   int prob_skip_false[3] = {0, 0, 0};
@@ -1289,7 +1301,7 @@
         }
 
 #if CONFIG_TX_SELECT
-        if (((rf == INTRA_FRAME && mode <= TM_PRED) ||
+        if (((rf == INTRA_FRAME && mode <= I8X8_PRED) ||
              (rf != INTRA_FRAME && mode != SPLITMV)) &&
             pc->txfm_mode == TX_MODE_SELECT &&
             !((pc->mb_no_coeff_skip && mi->mb_skip_coeff) ||
@@ -1298,7 +1310,7 @@
           TX_SIZE sz = mi->txfm_size;
           // FIXME(rbultje) code ternary symbol once all experiments are merged
           vp8_write(w, sz != TX_4X4, pc->prob_tx[0]);
-          if (sz != TX_4X4)
+          if (sz != TX_4X4 && mode != I8X8_PRED)
             vp8_write(w, sz != TX_8X8, pc->prob_tx[1]);
         }
 #endif
@@ -1467,7 +1479,7 @@
           write_uv_mode(bc, m->mbmi.uv_mode, c->kf_uv_mode_prob[ym]);
 
 #if CONFIG_TX_SELECT
-        if (ym <= TM_PRED && c->txfm_mode == TX_MODE_SELECT &&
+        if (ym <= I8X8_PRED && c->txfm_mode == TX_MODE_SELECT &&
             !((c->mb_no_coeff_skip && m->mbmi.mb_skip_coeff) ||
               (segfeature_active(xd, segment_id, SEG_LVL_EOB) &&
                get_segdata(xd, segment_id, SEG_LVL_EOB) == 0))) {
@@ -1474,7 +1486,7 @@
           TX_SIZE sz = m->mbmi.txfm_size;
           // FIXME(rbultje) code ternary symbol once all experiments are merged
           vp8_write(bc, sz != TX_4X4, c->prob_tx[0]);
-          if (sz != TX_4X4)
+          if (sz != TX_4X4 && ym <= TM_PRED)
             vp8_write(bc, sz != TX_8X8, c->prob_tx[1]);
         }
 #endif
@@ -2626,32 +2638,13 @@
 
 #if CONFIG_TX_SELECT
   {
-    int cnt = cpi->txfm_count[0] + cpi->txfm_count[1] + cpi->txfm_count[2];
-    if (cnt && pc->txfm_mode == TX_MODE_SELECT) {
-      int prob = (255 * (cpi->txfm_count[1] + cpi->txfm_count[2]) + (cnt >> 1)) / cnt;
-      if (prob <= 1) {
-        pc->prob_tx[0] = 1;
-      } else if (prob >= 255) {
-        pc->prob_tx[0] = 255;
-      } else {
-        pc->prob_tx[0] = prob;
-      }
-      pc->prob_tx[0] = 256 - pc->prob_tx[0];
+    if (pc->txfm_mode == TX_MODE_SELECT) {
+      pc->prob_tx[0] = get_prob(cpi->txfm_count[0] + cpi->txfm_count_8x8p[0],
+                                cpi->txfm_count[0] + cpi->txfm_count[1] + cpi->txfm_count[2] +
+                                cpi->txfm_count_8x8p[0] + cpi->txfm_count_8x8p[1]);
+      pc->prob_tx[1] = get_prob(cpi->txfm_count[1], cpi->txfm_count[1] + cpi->txfm_count[2]);
     } else {
       pc->prob_tx[0] = 128;
-    }
-    cnt -= cpi->txfm_count[0];
-    if (cnt && pc->txfm_mode == TX_MODE_SELECT) {
-      int prob = (255 * cpi->txfm_count[2] + (cnt >> 1)) / cnt;
-      if (prob <= 1) {
-        pc->prob_tx[1] = 1;
-      } else if (prob >= 255) {
-        pc->prob_tx[1] = 255;
-      } else {
-        pc->prob_tx[1] = prob;
-      }
-      pc->prob_tx[1] = 256 - pc->prob_tx[1];
-    } else {
       pc->prob_tx[1] = 128;
     }
     vp8_write_literal(bc, pc->txfm_mode, 2);
--- a/vp8/encoder/encodeframe.c
+++ b/vp8/encoder/encodeframe.c
@@ -1396,6 +1396,7 @@
   vpx_memset(cpi->comp_pred_count, 0, sizeof(cpi->comp_pred_count));
 #if CONFIG_TX_SELECT
   vpx_memset(cpi->txfm_count, 0, sizeof(cpi->txfm_count));
+  vpx_memset(cpi->txfm_count_8x8p, 0, sizeof(cpi->txfm_count_8x8p));
   vpx_memset(cpi->rd_tx_select_diff, 0, sizeof(cpi->rd_tx_select_diff));
 #endif
   {
@@ -1598,13 +1599,14 @@
 
 #if CONFIG_TX_SELECT
     if (cpi->common.txfm_mode == TX_MODE_SELECT) {
-      const int count4x4 = cpi->txfm_count[TX_4X4];
+      const int count4x4 = cpi->txfm_count[TX_4X4] + cpi->txfm_count_8x8p[TX_4X4];
       const int count8x8 = cpi->txfm_count[TX_8X8];
+      const int count8x8_8x8p = cpi->txfm_count_8x8p[TX_8X8];
       const int count16x16 = cpi->txfm_count[TX_16X16];
 
       if (count4x4 == 0 && count16x16 == 0) {
         cpi->common.txfm_mode = ALLOW_8X8;
-      } else if (count8x8 == 0 && count16x16 == 0) {
+      } else if (count8x8 == 0 && count16x16 == 0 && count8x8_8x8p == 0) {
         cpi->common.txfm_mode = ONLY_4X4;
       } else if (count8x8 == 0 && count4x4 == 0) {
         cpi->common.txfm_mode = ALLOW_16X16;
@@ -1946,6 +1948,8 @@
            get_segdata(&x->e_mbd, segment_id, SEG_LVL_EOB) == 0))) {
       if (mbmi->mode != B_PRED && mbmi->mode != I8X8_PRED) {
         cpi->txfm_count[mbmi->txfm_size]++;
+      } else if (mbmi->mode == I8X8_PRED) {
+        cpi->txfm_count_8x8p[mbmi->txfm_size]++;
       }
     } else
 #endif
@@ -2138,6 +2142,8 @@
       if (mbmi->mode != B_PRED && mbmi->mode != I8X8_PRED &&
           mbmi->mode != SPLITMV) {
         cpi->txfm_count[mbmi->txfm_size]++;
+      } else if (mbmi->mode == I8X8_PRED) {
+        cpi->txfm_count_8x8p[mbmi->txfm_size]++;
       }
     } else
 #endif
--- a/vp8/encoder/encodeintra.c
+++ b/vp8/encoder/encodeintra.c
@@ -238,7 +238,8 @@
 
 void vp8_encode_intra8x8(const VP8_ENCODER_RTCD *rtcd,
                          MACROBLOCK *x, int ib) {
-  BLOCKD *b = &x->e_mbd.block[ib];
+  MACROBLOCKD *xd = &x->e_mbd;
+  BLOCKD *b = &xd->block[ib];
   BLOCK *be = &x->block[ib];
   const int iblock[4] = {0, 1, 4, 5};
   int i;
@@ -255,8 +256,7 @@
   }
 #endif
 
-  {
-    MACROBLOCKD *xd = &x->e_mbd;
+  if (x->e_mbd.mode_info_context->mbmi.txfm_size == TX_8X8) {
     int idx = (ib & 0x02) ? (ib + 2) : ib;
 
     // generate residual blocks
@@ -274,13 +274,22 @@
     x->quantize_b_8x8(x->block + idx, xd->block + idx);
     vp8_idct_idct8(xd->block[idx].dqcoeff, xd->block[ib].diff, 32);
 #endif
-
-    // reconstruct submacroblock
+  } else {
     for (i = 0; i < 4; i++) {
       b = &xd->block[ib + iblock[i]];
-      vp8_recon_b_c(b->predictor, b->diff, *(b->base_dst) + b->dst,
-                    b->dst_stride);
+      be = &x->block[ib + iblock[i]];
+      ENCODEMB_INVOKE(&rtcd->encodemb, subb)(be, b, 16);
+      x->vp8_short_fdct4x4(be->src_diff, be->coeff, 32);
+      x->quantize_b(be, b);
+      vp8_inverse_transform_b(IF_RTCD(&rtcd->common->idct), b, 32);
     }
+  }
+
+  // reconstruct submacroblock
+  for (i = 0; i < 4; i++) {
+    b = &xd->block[ib + iblock[i]];
+    vp8_recon_b_c(b->predictor, b->diff, *(b->base_dst) + b->dst,
+                  b->dst_stride);
   }
 }
 
--- a/vp8/encoder/onyx_int.h
+++ b/vp8/encoder/onyx_int.h
@@ -470,7 +470,8 @@
   int single_pred_count[COMP_PRED_CONTEXTS];
 #if CONFIG_TX_SELECT
   // FIXME contextualize
-  int txfm_count[TX_SIZE_MAX + 1];
+  int txfm_count[TX_SIZE_MAX];
+  int txfm_count_8x8p[TX_SIZE_MAX - 1];
   int64_t rd_tx_select_diff[NB_TXFM_MODES];
   int rd_tx_select_threshes[4][NB_TXFM_MODES];
 #endif
--- a/vp8/encoder/rdopt.c
+++ b/vp8/encoder/rdopt.c
@@ -1565,30 +1565,63 @@
 
       vp8_subtract_4b_c(be, b, 16);
 
+      if (xd->mode_info_context->mbmi.txfm_size == TX_8X8) {
 #if CONFIG_HYBRIDTRANSFORM8X8
-      txfm_map(b, pred_mode_conv(mode));
-      vp8_fht_c(be->src_diff, (x->block + idx)->coeff, 32,
-                b->bmi.as_mode.tx_type, 8);
+        txfm_map(b, pred_mode_conv(mode));
+        vp8_fht_c(be->src_diff, (x->block + idx)->coeff, 32,
+                  b->bmi.as_mode.tx_type, 8);
 
 #else
-      x->vp8_short_fdct8x8(be->src_diff, (x->block + idx)->coeff, 32);
+        x->vp8_short_fdct8x8(be->src_diff, (x->block + idx)->coeff, 32);
 #endif
+        x->quantize_b_8x8(x->block + idx, xd->block + idx);
 
-      x->quantize_b_8x8(x->block + idx, xd->block + idx);
+        // compute quantization mse of 8x8 block
+        distortion = vp8_block_error_c((x->block + idx)->coeff,
+                                       (xd->block + idx)->dqcoeff, 64);
+        ta0 = *(a + vp8_block2above_8x8[idx]);
+        tl0 = *(l + vp8_block2left_8x8 [idx]);
 
-      // compute quantization mse of 8x8 block
-      distortion = vp8_block_error_c((x->block + idx)->coeff,
-                                     (xd->block + idx)->dqcoeff, 64)>>2;
-      ta0 = *(a + vp8_block2above_8x8[idx]);
-      tl0 = *(l + vp8_block2left_8x8 [idx]);
+        rate_t = cost_coeffs(x, xd->block + idx, PLANE_TYPE_Y_WITH_DC,
+                             &ta0, &tl0, TX_8X8);
 
-      rate_t = cost_coeffs(x, xd->block + idx, PLANE_TYPE_Y_WITH_DC,
-                           &ta0, &tl0, TX_8X8);
+        rate += rate_t;
+        ta1 = ta0;
+        tl1 = tl0;
+      } else {
+        x->vp8_short_fdct8x4(be->src_diff, be->coeff, 32);
+        x->vp8_short_fdct8x4((be + 4)->src_diff, (be + 4)->coeff, 32);
 
-      rate += rate_t;
-      ta1 = ta0;
-      tl1 = tl0;
+        x->quantize_b_pair(x->block + ib, x->block + ib + 1,
+                           xd->block + ib, xd->block + ib + 1);
+        x->quantize_b_pair(x->block + ib + 4, x->block + ib + 5,
+                           xd->block + ib + 4, xd->block + ib + 5);
 
+        distortion = vp8_block_error_c((x->block + ib)->coeff,
+                                       (xd->block + ib)->dqcoeff, 16);
+        distortion += vp8_block_error_c((x->block + ib + 1)->coeff,
+                                        (xd->block + ib + 1)->dqcoeff, 16);
+        distortion += vp8_block_error_c((x->block + ib + 4)->coeff,
+                                        (xd->block + ib + 4)->dqcoeff, 16);
+        distortion += vp8_block_error_c((x->block + ib + 5)->coeff,
+                                        (xd->block + ib + 5)->dqcoeff, 16);
+
+        ta0 = *(a + vp8_block2above[ib]);
+        ta1 = *(a + vp8_block2above[ib + 1]);
+        tl0 = *(l + vp8_block2left[ib]);  // left ctx uses the left map, as the 8x8 path does
+        tl1 = *(l + vp8_block2left[ib + 4]);  // second 4x4 row of this 8x8 block
+        rate_t = cost_coeffs(x, xd->block + ib, PLANE_TYPE_Y_WITH_DC,
+                             &ta0, &tl0, TX_4X4);
+        rate_t += cost_coeffs(x, xd->block + ib + 1, PLANE_TYPE_Y_WITH_DC,
+                              &ta1, &tl0, TX_4X4);
+        rate_t += cost_coeffs(x, xd->block + ib + 4, PLANE_TYPE_Y_WITH_DC,
+                              &ta0, &tl1, TX_4X4);
+        rate_t += cost_coeffs(x, xd->block + ib + 5, PLANE_TYPE_Y_WITH_DC,
+                              &ta1, &tl1, TX_4X4);
+        rate += rate_t;
+      }
+
+      distortion >>= 2;
       this_rd = RDCOST(x->rdmult, x->rddiv, rate, distortion);
       if (this_rd < best_rd) {
         *bestrate = rate;
@@ -1617,17 +1650,18 @@
 #endif
   vp8_encode_intra8x8(IF_RTCD(&cpi->rtcd), x, ib);
 
-#if CONFIG_HYBRIDTRANSFORM8X8
-  *(a + vp8_block2above_8x8[idx])     = besta0;
-  *(a + vp8_block2above_8x8[idx] + 1) = besta1;
-  *(l + vp8_block2left_8x8 [idx])     = bestl0;
-  *(l + vp8_block2left_8x8 [idx] + 1) = bestl1;
-#else
-  *(a + vp8_block2above[ib])   = besta0;
-  *(a + vp8_block2above[ib + 1]) = besta1;
-  *(l + vp8_block2above[ib])   = bestl0;
-  *(l + vp8_block2above[ib + 4]) = bestl1;
-#endif
+  if (xd->mode_info_context->mbmi.txfm_size == TX_8X8) {
+    *(a + vp8_block2above_8x8[idx])     = besta0;
+    *(a + vp8_block2above_8x8[idx] + 1) = besta1;
+    *(l + vp8_block2left_8x8 [idx])     = bestl0;
+    *(l + vp8_block2left_8x8 [idx] + 1) = bestl1;
+  } else {
+    *(a + vp8_block2above[ib])     = besta0;
+    *(a + vp8_block2above[ib + 1]) = besta1;
+    *(l + vp8_block2left[ib])      = bestl0;  // left ctx: block2left, not block2above
+    *(l + vp8_block2left[ib + 4])  = bestl1;
+  }
+
   return best_rd;
 }
 
@@ -3395,7 +3429,7 @@
     int other_cost = 0;
     int compmode_cost = 0;
     int mode_excluded = 0;
-    int64_t txfm_cache[NB_TXFM_MODES];
+    int64_t txfm_cache[NB_TXFM_MODES] = { 0 };
 
     // These variables hold are rolling total cost and distortion for this mode
     rate2 = 0;
@@ -3579,13 +3613,16 @@
         }
         break;
         case I8X8_PRED: {
-          int64_t tmp_rd;
-          mbmi->txfm_size = TX_8X8; // FIXME wrong in case of hybridtransform8x8
-          tmp_rd = rd_pick_intra8x8mby_modes(cpi, x, &rate, &rate_y,
-                                             &distortion, best_yrd);
-          rate2 += rate;
-          distortion2 += distortion;
-
+#if CONFIG_TX_SELECT
+          int cost0 = vp8_cost_bit(cm->prob_tx[0], 0);
+          int cost1 = vp8_cost_bit(cm->prob_tx[0], 1);
+          int64_t tmp_rd_4x4s, tmp_rd_8x8s;
+#endif
+          int64_t tmp_rd_4x4, tmp_rd_8x8, tmp_rd;
+          int r4x4, tok4x4, d4x4, r8x8, tok8x8, d8x8;
+          mbmi->txfm_size = TX_4X4;
+          tmp_rd_4x4 = rd_pick_intra8x8mby_modes(cpi, x, &r4x4, &tok4x4,
+                                                 &d4x4, best_yrd);
           mode8x8[0][0] = x->e_mbd.mode_info_context->bmi[0].as_mode.first;
           mode8x8[0][1] = x->e_mbd.mode_info_context->bmi[2].as_mode.first;
           mode8x8[0][2] = x->e_mbd.mode_info_context->bmi[8].as_mode.first;
@@ -3596,7 +3633,73 @@
           mode8x8[1][2] = x->e_mbd.mode_info_context->bmi[8].as_mode.second;
           mode8x8[1][3] = x->e_mbd.mode_info_context->bmi[10].as_mode.second;
 #endif
+          mbmi->txfm_size = TX_8X8;
+          tmp_rd_8x8 = rd_pick_intra8x8mby_modes(cpi, x, &r8x8, &tok8x8,
+                                                 &d8x8, best_yrd);
+          txfm_cache[ONLY_4X4]  = tmp_rd_4x4;
+          txfm_cache[ALLOW_8X8] = tmp_rd_8x8;
+#if CONFIG_TX16X16
+          txfm_cache[ALLOW_16X16] = tmp_rd_8x8;
+#endif
+#if CONFIG_TX_SELECT
+          tmp_rd_4x4s = tmp_rd_4x4 + RDCOST(x->rdmult, x->rddiv, cost0, 0);
+          tmp_rd_8x8s = tmp_rd_8x8 + RDCOST(x->rdmult, x->rddiv, cost1, 0);
+          txfm_cache[TX_MODE_SELECT] = tmp_rd_4x4s < tmp_rd_8x8s ? tmp_rd_4x4s : tmp_rd_8x8s;
+          if (cm->txfm_mode == TX_MODE_SELECT) {
+            if (tmp_rd_4x4s < tmp_rd_8x8s) {
+              rate = r4x4 + cost0;
+              rate_y = tok4x4 + cost0;
+              distortion = d4x4;
+              mbmi->txfm_size = TX_4X4;
+              tmp_rd = tmp_rd_4x4s;
+            } else {
+              rate = r8x8 + cost1;
+              rate_y = tok8x8 + cost1;
+              distortion = d8x8;
+              mbmi->txfm_size = TX_8X8;
+              tmp_rd = tmp_rd_8x8s;
 
+              mode8x8[0][0] = x->e_mbd.mode_info_context->bmi[0].as_mode.first;
+              mode8x8[0][1] = x->e_mbd.mode_info_context->bmi[2].as_mode.first;
+              mode8x8[0][2] = x->e_mbd.mode_info_context->bmi[8].as_mode.first;
+              mode8x8[0][3] = x->e_mbd.mode_info_context->bmi[10].as_mode.first;
+#if CONFIG_COMP_INTRA_PRED
+              mode8x8[1][0] = x->e_mbd.mode_info_context->bmi[0].as_mode.second;
+              mode8x8[1][1] = x->e_mbd.mode_info_context->bmi[2].as_mode.second;
+              mode8x8[1][2] = x->e_mbd.mode_info_context->bmi[8].as_mode.second;
+              mode8x8[1][3] = x->e_mbd.mode_info_context->bmi[10].as_mode.second;
+#endif
+            }
+          } else
+#endif
+          if (cm->txfm_mode == ONLY_4X4) {
+            rate = r4x4;
+            rate_y = tok4x4;
+            distortion = d4x4;
+            mbmi->txfm_size = TX_4X4;
+            tmp_rd = tmp_rd_4x4;
+          } else {
+            rate = r8x8;
+            rate_y = tok8x8;
+            distortion = d8x8;
+            mbmi->txfm_size = TX_8X8;
+            tmp_rd = tmp_rd_8x8;
+
+            mode8x8[0][0] = x->e_mbd.mode_info_context->bmi[0].as_mode.first;
+            mode8x8[0][1] = x->e_mbd.mode_info_context->bmi[2].as_mode.first;
+            mode8x8[0][2] = x->e_mbd.mode_info_context->bmi[8].as_mode.first;
+            mode8x8[0][3] = x->e_mbd.mode_info_context->bmi[10].as_mode.first;
+#if CONFIG_COMP_INTRA_PRED
+            mode8x8[1][0] = x->e_mbd.mode_info_context->bmi[0].as_mode.second;
+            mode8x8[1][1] = x->e_mbd.mode_info_context->bmi[2].as_mode.second;
+            mode8x8[1][2] = x->e_mbd.mode_info_context->bmi[8].as_mode.second;
+            mode8x8[1][3] = x->e_mbd.mode_info_context->bmi[10].as_mode.second;
+#endif
+          }
+
+          rate2 += rate;
+          distortion2 += distortion;
+
           /* TODO: uv rate maybe over-estimated here since there is UV intra
                    mode coded in I8X8_PRED prediction */
           if (tmp_rd < best_yrd) {
@@ -4026,8 +4129,7 @@
       if (!mode_excluded && this_rd != INT64_MAX) {
         for (i = 0; i < NB_TXFM_MODES; i++) {
           int64_t adj_rd;
-          if (this_mode != B_PRED && this_mode != I8X8_PRED &&
-              this_mode != SPLITMV) {
+          if (this_mode != B_PRED && this_mode != SPLITMV) {
             adj_rd = this_rd + txfm_cache[i] - txfm_cache[cm->txfm_mode];
           } else {
             adj_rd = this_rd;
@@ -4264,6 +4366,8 @@
   mbmi->mode_rdopt = I8X8_PRED;
 #endif
 
+  // FIXME(rbultje) support transform-size selection
+  mbmi->txfm_size = (cm->txfm_mode == ONLY_4X4) ? TX_4X4 : TX_8X8;
   error8x8 = rd_pick_intra8x8mby_modes(cpi, x, &rate8x8, &rate8x8_tokenonly,
                                        &dist8x8, error16x16);
   mode8x8[0][0]= xd->mode_info_context->bmi[0].as_mode.first;
@@ -4365,8 +4469,9 @@
              sizeof(x->mb_context[xd->mb_index].txfm_rd_diff));
 #endif
     } else {
+      // FIXME(rbultje) support transform-size selection
       mbmi->mode = I8X8_PRED;
-      mbmi->txfm_size = TX_8X8;
+      mbmi->txfm_size = (cm->txfm_mode == ONLY_4X4) ? TX_4X4 : TX_8X8;
       set_i8x8_block_modes(x, mode8x8);
       rate = rate8x8 + rateuv;
       dist = dist8x8 + (distuv >> 2);