shithub: dav1d

Download patch

ref: b9d4630c6dfb9bfeac7c3fa5aa59217670595f7b
parent: 79c4aa95cd1f0fd849e130aa282c632d51fb70da
author: Ronald S. Bultje <rsbultje@gmail.com>
date: Sun Sep 1 07:18:46 EDT 2019

Split out film grain block functions into a DSPContext

--- a/src/decode.c
+++ b/src/decode.c
@@ -42,6 +42,7 @@
 #include "src/decode.h"
 #include "src/dequant_tables.h"
 #include "src/env.h"
+#include "src/film_grain.h"
 #include "src/log.h"
 #include "src/qm.h"
 #include "src/recon.h"
@@ -3190,6 +3191,7 @@
             dav1d_loop_filter_dsp_init_##bd##bpc(&dsp->lf); \
             dav1d_loop_restoration_dsp_init_##bd##bpc(&dsp->lr); \
             dav1d_mc_dsp_init_##bd##bpc(&dsp->mc); \
+            dav1d_film_grain_dsp_init_##bd##bpc(&dsp->fg); \
             break
 #if CONFIG_8BPC
         case 8:
--- /dev/null
+++ b/src/fg_apply.h
@@ -1,0 +1,41 @@
+/*
+ * Copyright © 2018, VideoLAN and dav1d authors
+ * Copyright © 2018, Two Orioles, LLC
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef DAV1D_SRC_FG_APPLY_H
+#define DAV1D_SRC_FG_APPLY_H
+
+#include "dav1d/picture.h"
+
+#include "common/bitdepth.h"
+
+#include "src/film_grain.h"
+
+bitfn_decls(void dav1d_apply_grain, const Dav1dFilmGrainDSPContext *const dsp,
+                                    Dav1dPicture *const out,
+                                    const Dav1dPicture *const in);
+
+#endif /* DAV1D_SRC_FG_APPLY_H */
--- /dev/null
+++ b/src/fg_apply_tmpl.c
@@ -1,0 +1,176 @@
+/*
+ * Copyright © 2018, Niklas Haas
+ * Copyright © 2018, VideoLAN and dav1d authors
+ * Copyright © 2018, Two Orioles, LLC
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include "config.h"
+
+#include <stdint.h>
+
+#include "dav1d/picture.h"
+
+#include "common.h"
+#include "common/intops.h"
+#include "common/bitdepth.h"
+
+#include "fg_apply.h"
+
+// Builds the scaling LUT for one plane from `num` (x, y) control points.
+// x is an 8-bit position shifted up by (bitdepth - 8); entries between points
+// are linearly interpolated. With no points (num <= 0) the LUT is zeroed.
+static void generate_scaling(const int bitdepth,
+                             const uint8_t points[][2], const int num,
+                             uint8_t scaling[SCALING_SIZE])
+{
+    const int shift_x = bitdepth - 8;
+    const int scaling_size = 1 << bitdepth;
+    const int pad = 1 << shift_x;
+
+    if (num <= 0) { // no points: points[num - 1] below would be out of bounds
+        for (int i = 0; i < scaling_size; i++) scaling[i] = 0;
+        return;
+    }
+
+    // Fill up the preceding entries with the initial value
+    for (int i = 0; i < points[0][0] << shift_x; i++)
+        scaling[i] = points[0][1];
+
+    // Linearly interpolate the values in the middle (16.16 fixed-point slope)
+    for (int i = 0; i < num - 1; i++) {
+        const int bx = points[i][0], by = points[i][1];
+        const int dx = points[i+1][0] - bx, dy = points[i+1][1] - by;
+        const int delta = dy * ((0x10000 + (dx >> 1)) / dx); // assumes dx > 0 -- TODO confirm
+        for (int x = 0; x < dx; x++)
+            scaling[(bx + x) << shift_x] = by + ((x * delta + 0x8000) >> 16);
+    }
+
+    // Fill up the remaining entries with the final value
+    for (int i = points[num - 1][0] << shift_x; i < scaling_size; i++)
+        scaling[i] = points[num - 1][1];
+
+    if (pad <= 1) return; // no gaps between samples; otherwise interpolate them below
+
+    const int rnd = pad >> 1;
+    for (int i = 0; i < num - 1; i++) {
+        const int bx = points[i][0] << shift_x;
+        const int dx = (points[i+1][0] << shift_x) - bx;
+        for (int x = 0; x < dx; x += pad) {
+            const int range = scaling[bx + x + pad] - scaling[bx + x];
+            for (int n = 1; n < pad; n++)
+                scaling[bx + x + n] = scaling[bx + x] + ((range * n + rnd) >> shift_x);
+        }
+    }
+}
+
+#ifndef UNIT_TEST
+// Applies film grain to `in` and writes the result to `out`, driven by
+// out->frame_hdr->film_grain.data and the kernels in `dsp`. Planes that get
+// no grain are copied unchanged. `in` and `out` must have equal strides
+// (asserted below).
+void bitfn(dav1d_apply_grain)(const Dav1dFilmGrainDSPContext *const dsp,
+                              Dav1dPicture *const out,
+                              const Dav1dPicture *const in)
+{
+    const Dav1dFilmGrainData *const data = &out->frame_hdr->film_grain.data;
+    const int ss_y = in->p.layout == DAV1D_PIXEL_LAYOUT_I420;
+
+    entry grain_lut[3][GRAIN_HEIGHT][GRAIN_WIDTH];
+    uint8_t scaling[3][SCALING_SIZE];
+#if BITDEPTH != 8
+    const int bitdepth_max = (1 << out->p.bpc) - 1;
+#endif
+
+    // Generate grain LUTs as needed
+    dsp->generate_grain_y(grain_lut[0], data HIGHBD_TAIL_SUFFIX); // always needed
+    if (data->num_uv_points[0] || data->chroma_scaling_from_luma)
+        dsp->generate_grain_uv[in->p.layout - 1](grain_lut[1], grain_lut[0],
+                                                 data, 0 HIGHBD_TAIL_SUFFIX);
+    if (data->num_uv_points[1] || data->chroma_scaling_from_luma)
+        dsp->generate_grain_uv[in->p.layout - 1](grain_lut[2], grain_lut[0],
+                                                 data, 1 HIGHBD_TAIL_SUFFIX);
+
+    // Generate scaling LUTs as needed
+    if (data->num_y_points) {
+        generate_scaling(in->p.bpc, data->y_points, data->num_y_points, scaling[0]);
+    } else if (data->chroma_scaling_from_luma) {
+        // Chroma reads scaling[0] in this mode even without luma points;
+        // zero it (noise becomes 0) instead of using indeterminate stack bytes.
+        memset(scaling[0], 0, sizeof(scaling[0]));
+    }
+    if (data->num_uv_points[0])
+        generate_scaling(in->p.bpc, data->uv_points[0], data->num_uv_points[0], scaling[1]);
+    if (data->num_uv_points[1])
+        generate_scaling(in->p.bpc, data->uv_points[1], data->num_uv_points[1], scaling[2]);
+
+    // Copy over the non-modified planes
+    // TODO: eliminate in favor of per-plane refs
+    assert(out->stride[0] == in->stride[0]);
+    if (!data->num_y_points)
+        memcpy(out->data[0], in->data[0], out->p.h * out->stride[0]);
+
+    if (in->p.layout != DAV1D_PIXEL_LAYOUT_I400) {
+        assert(out->stride[1] == in->stride[1]);
+        // Round the chroma height up so the last row of an odd-height 4:2:0
+        // picture is copied too, matching the bh rounding in the row loop.
+        const ptrdiff_t uv_sz = ((out->p.h + ss_y) >> ss_y) * out->stride[1];
+        for (int i = 0; i < 2; i++)
+            if (!data->num_uv_points[i] && !data->chroma_scaling_from_luma)
+                memcpy(out->data[1+i], in->data[1+i], uv_sz);
+    }
+
+    // Synthesize grain for the affected planes
+    const int rows = (out->p.h + 31) >> 5;
+    const int is_id = out->seq_hdr->mtrx == DAV1D_MC_IDENTITY;
+    for (int row = 0; row < rows; row++) {
+        const pixel *const luma_src =
+            ((pixel *) in->data[0]) + row * BLOCK_SIZE * PXSTRIDE(in->stride[0]);
+        const int luma_bh = imin(out->p.h - row * BLOCK_SIZE, BLOCK_SIZE);
+
+        if (data->num_y_points)
+            dsp->fgy_32x32xn(((pixel *) out->data[0]) + row * BLOCK_SIZE * PXSTRIDE(out->stride[0]),
+                             luma_src, out->stride[0], data,
+                             out->p.w, scaling[0], grain_lut[0], luma_bh, row HIGHBD_TAIL_SUFFIX);
+
+        // Chroma block height and row offset, halved (rounded up) for 4:2:0
+        const int bh = (luma_bh + ss_y) >> ss_y;
+        const ptrdiff_t uv_off = row * BLOCK_SIZE * PXSTRIDE(out->stride[1]) >> ss_y;
+        for (int pl = 0; pl < 2; pl++) {
+            // With chroma_scaling_from_luma both planes use the luma LUT;
+            // otherwise a plane is processed only if it has its own points.
+            if (!data->chroma_scaling_from_luma && !data->num_uv_points[pl])
+                continue;
+            const uint8_t *const lut =
+                data->chroma_scaling_from_luma ? scaling[0] : scaling[1 + pl];
+            dsp->fguv_32x32xn[in->p.layout - 1](((pixel *) out->data[1 + pl]) + uv_off,
+                                                ((const pixel *) in->data[1 + pl]) + uv_off,
+                                                in->stride[1], luma_src,
+                                                in->stride[0], out->p.w, bh, data,
+                                                grain_lut[1 + pl], lut,
+                                                pl, row, is_id HIGHBD_TAIL_SUFFIX);
+        }
+    }
+}
+#endif
--- a/src/film_grain.h
+++ b/src/film_grain.h
@@ -28,9 +28,57 @@
 #ifndef DAV1D_SRC_FILM_GRAIN_H
 #define DAV1D_SRC_FILM_GRAIN_H
 
-#include "dav1d/dav1d.h"
+#include "common/bitdepth.h"
 
-bitfn_decls(void dav1d_apply_grain, Dav1dPicture *const out,
-                                    const Dav1dPicture *const in);
+#include "src/levels.h"
+
+#define GRAIN_WIDTH 82  // columns of one grain LUT (entry buf[GRAIN_HEIGHT][GRAIN_WIDTH])
+#define GRAIN_HEIGHT 73 // rows of one grain LUT
+#define BLOCK_SIZE 32   // grain is applied in BLOCK_SIZE x BLOCK_SIZE luma blocks
+#if !defined(BITDEPTH) || BITDEPTH == 8
+#define SCALING_SIZE 256 // scaling LUT entries: one per 8-bit pixel value
+typedef int8_t entry;    // grain sample type for 8 bpc (also used when BITDEPTH is undefined)
+#else
+#define SCALING_SIZE 4096 // scaling LUT entries: sized for up to 12-bit pixel values
+typedef int16_t entry;    // grain sample type for builds with BITDEPTH != 8
+#endif
+
+#define decl_generate_grain_y_fn(name) \
+void (name)(entry buf[GRAIN_HEIGHT][GRAIN_WIDTH], \
+            const Dav1dFilmGrainData *const data HIGHBD_DECL_SUFFIX)
+typedef decl_generate_grain_y_fn(*generate_grain_y_fn);
+
+#define decl_generate_grain_uv_fn(name) \
+void (name)(entry buf[GRAIN_HEIGHT][GRAIN_WIDTH], \
+            const entry buf_y[GRAIN_HEIGHT][GRAIN_WIDTH], \
+            const Dav1dFilmGrainData *const data, const int uv HIGHBD_DECL_SUFFIX)
+typedef decl_generate_grain_uv_fn(*generate_grain_uv_fn);
+
+#define decl_fgy_32x32xn_fn(name) \
+void (name)(pixel *dst_row, const pixel *src_row, ptrdiff_t stride, \
+            const Dav1dFilmGrainData *data, \
+            int pw, const uint8_t scaling[SCALING_SIZE], \
+            const entry grain_lut[GRAIN_HEIGHT][GRAIN_WIDTH], \
+            int bh, int row_num HIGHBD_DECL_SUFFIX)
+typedef decl_fgy_32x32xn_fn(*fgy_32x32xn_fn);
+
+#define decl_fguv_32x32xn_fn(name) \
+void (name)(pixel *dst_row, const pixel *src_row, ptrdiff_t stride, \
+            const pixel *luma_row, ptrdiff_t luma_stride, int pw, int bh, \
+            const Dav1dFilmGrainData *data, \
+            const entry grain_lut[GRAIN_HEIGHT][GRAIN_WIDTH], \
+            const uint8_t scaling[SCALING_SIZE], \
+            int uv_pl, int row_num, int is_id HIGHBD_DECL_SUFFIX)
+typedef decl_fguv_32x32xn_fn(*fguv_32x32xn_fn);
+
+typedef struct Dav1dFilmGrainDSPContext { // film grain function table, filled by dav1d_film_grain_dsp_init
+    generate_grain_y_fn generate_grain_y;      // fills the luma grain LUT
+    generate_grain_uv_fn generate_grain_uv[3]; // chroma grain LUT; indexed by pixel layout - 1 (I420/I422/I444)
+
+    fgy_32x32xn_fn fgy_32x32xn;                // adds luma grain to one row of 32x32 blocks
+    fguv_32x32xn_fn fguv_32x32xn[3];           // chroma counterpart; same layout - 1 indexing
+} Dav1dFilmGrainDSPContext;
+
+bitfn_decls(void dav1d_film_grain_dsp_init, Dav1dFilmGrainDSPContext *c);
 
 #endif /* DAV1D_SRC_FILM_GRAIN_H */
--- a/src/film_grain_tmpl.c
+++ b/src/film_grain_tmpl.c
@@ -26,38 +26,16 @@
  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  */
 
-#include "config.h"
-
-#include <stdint.h>
-
-#include "common.h"
+#include "common/attributes.h"
 #include "common/intops.h"
-#include "common/bitdepth.h"
-#include "tables.h"
 
 #include "film_grain.h"
+#include "tables.h"
 
-#if BITDEPTH == 8
-typedef int8_t entry;
-#else
-typedef int16_t entry;
-#endif
+#define SUB_GRAIN_WIDTH 44
+#define SUB_GRAIN_HEIGHT 38
 
-enum {
-    GRAIN_WIDTH  = 82,
-    GRAIN_HEIGHT = 73,
-    SUB_GRAIN_WIDTH = 44,
-    SUB_GRAIN_HEIGHT = 38,
-    SUB_GRAIN_OFFSET = 6,
-    BLOCK_SIZE = 32,
-#if BITDEPTH == 8
-    SCALING_SIZE = 256
-#else
-    SCALING_SIZE = 4096
-#endif
-};
-
-static inline int get_random_number(const int bits, unsigned *state) {
+static inline int get_random_number(const int bits, unsigned *const state) {
     const int r = *state;
     unsigned bit = ((r >> 0) ^ (r >> 1) ^ (r >> 3) ^ (r >> 12)) & 1;
     *state = (r >> 1) | (bit << 15);
@@ -69,13 +47,14 @@
     return (x + ((1 << shift) >> 1)) >> shift;
 }
 
-static void generate_grain_y(const Dav1dPicture *const in,
-                             entry buf[GRAIN_HEIGHT][GRAIN_WIDTH])
+static void generate_grain_y_c(entry buf[GRAIN_HEIGHT][GRAIN_WIDTH],
+                               const Dav1dFilmGrainData *const data
+                               HIGHBD_DECL_SUFFIX)
 {
-    const Dav1dFilmGrainData *data = &in->frame_hdr->film_grain.data;
+    const int bitdepth_min_8 = bitdepth_from_max(bitdepth_max) - 8;
     unsigned seed = data->seed;
-    const int shift = 12 - in->p.bpc + data->grain_scale_shift;
-    const int grain_ctr = 128 << (in->p.bpc - 8);
+    const int shift = 4 - bitdepth_min_8 + data->grain_scale_shift;
+    const int grain_ctr = 128 << bitdepth_min_8;
     const int grain_min = -grain_ctr, grain_max = grain_ctr - 1;
 
     for (int y = 0; y < GRAIN_HEIGHT; y++) {
@@ -100,25 +79,24 @@
                 }
             }
 
-            int grain = buf[y][x] + round2(sum, data->ar_coeff_shift);
+            const int grain = buf[y][x] + round2(sum, data->ar_coeff_shift);
             buf[y][x] = iclip(grain, grain_min, grain_max);
         }
     }
 }
 
-static void generate_grain_uv(const Dav1dPicture *const in, int uv,
-                              entry buf[GRAIN_HEIGHT][GRAIN_WIDTH],
-                              entry buf_y[GRAIN_HEIGHT][GRAIN_WIDTH])
+static NOINLINE void
+generate_grain_uv_c(entry buf[GRAIN_HEIGHT][GRAIN_WIDTH],
+                    const entry buf_y[GRAIN_HEIGHT][GRAIN_WIDTH],
+                    const Dav1dFilmGrainData *const data, const int uv,
+                    const int subx, const int suby HIGHBD_DECL_SUFFIX)
 {
-    const Dav1dFilmGrainData *data = &in->frame_hdr->film_grain.data;
+    const int bitdepth_min_8 = bitdepth_from_max(bitdepth_max) - 8;
     unsigned seed = data->seed ^ (uv ? 0x49d8 : 0xb524);
-    const int shift = 12 - in->p.bpc + data->grain_scale_shift;
-    const int grain_ctr = 128 << (in->p.bpc - 8);
+    const int shift = 4 - bitdepth_min_8 + data->grain_scale_shift;
+    const int grain_ctr = 128 << bitdepth_min_8;
     const int grain_min = -grain_ctr, grain_max = grain_ctr - 1;
 
-    const int subx = in->p.layout != DAV1D_PIXEL_LAYOUT_I444;
-    const int suby = in->p.layout == DAV1D_PIXEL_LAYOUT_I420;
-
     const int chromaW = subx ? SUB_GRAIN_WIDTH  : GRAIN_WIDTH;
     const int chromaH = suby ? SUB_GRAIN_HEIGHT : GRAIN_HEIGHT;
 
@@ -166,56 +144,18 @@
     }
 }
 
-static void generate_scaling(const int bitdepth,
-                             const uint8_t points[][2], int num,
-                             uint8_t scaling[SCALING_SIZE])
-{
-    const int shift_x = bitdepth - 8;
-    const int scaling_size = 1 << bitdepth;
-    const int pad = 1 << shift_x;
-
-    // Fill up the preceding entries with the initial value
-    for (int i = 0; i < points[0][0] << shift_x; i++)
-        scaling[i] = points[0][1];
-
-    // Linearly interpolate the values in the middle
-    for (int i = 0; i < num - 1; i++) {
-        const int bx = points[i][0];
-        const int by = points[i][1];
-        const int ex = points[i+1][0];
-        const int ey = points[i+1][1];
-        const int dx = ex - bx;
-        const int dy = ey - by;
-        const int delta = dy * ((0x10000 + (dx >> 1)) / dx);
-        for (int x = 0; x < dx; x++) {
-            const int v = by + ((x * delta + 0x8000) >> 16);
-            scaling[(bx + x) << shift_x] = v;
-        }
-    }
-
-    // Fill up the remaining entries with the final value
-    for (int i = points[num - 1][0] << shift_x; i < scaling_size; i++)
-        scaling[i] = points[num - 1][1];
-
-    if (pad > 1) {
-        const int rnd = pad >> 1;
-        for (int i = 0; i < num - 1; i++) {
-            const int bx = points[i][0] << shift_x;
-            const int ex = points[i+1][0] << shift_x;
-            const int dx = ex - bx;
-            for (int x = 0; x < dx; x += pad) {
-                const int range = scaling[bx + x + pad] - scaling[bx + x];
-                for (int n = 1; n < pad; n++) {
-                    scaling[bx + x + n] = scaling[bx + x] + ((range * n + rnd) >> shift_x);
-                }
-            }
-        }
-    }
-}
+#define gnuv_ss_fn(nm, ss_x, ss_y) \
+static decl_generate_grain_uv_fn(generate_grain_uv_##nm##_c) { \
+    generate_grain_uv_c(buf, buf_y, data, uv, ss_x, ss_y HIGHBD_TAIL_SUFFIX); \
+}
 
+gnuv_ss_fn(420, 1, 1);
+gnuv_ss_fn(422, 1, 0);
+gnuv_ss_fn(444, 0, 0);
+
 // samples from the correct block of a grain LUT, while taking into account the
 // offsets provided by the offsets cache
-static inline entry sample_lut(entry grain_lut[GRAIN_HEIGHT][GRAIN_WIDTH],
+static inline entry sample_lut(const entry grain_lut[GRAIN_HEIGHT][GRAIN_WIDTH],
                                int offsets[2][2], int subx, int suby,
                                int bx, int by, int x, int y)
 {
@@ -226,13 +166,15 @@
                     [offx + x + (BLOCK_SIZE >> subx) * bx];
 }
 
-static void apply_to_row_y(Dav1dPicture *const out, const Dav1dPicture *const in,
-                           entry grain_lut[GRAIN_HEIGHT][GRAIN_WIDTH],
-                           uint8_t scaling[SCALING_SIZE], int row_num)
+static void fgy_32x32xn_c(pixel *const dst_row, const pixel *const src_row,
+                          const ptrdiff_t stride,
+                          const Dav1dFilmGrainData *const data, const int pw,
+                          const uint8_t scaling[SCALING_SIZE],
+                          const entry grain_lut[GRAIN_HEIGHT][GRAIN_WIDTH],
+                          const int bh, const int row_num HIGHBD_DECL_SUFFIX)
 {
-    const Dav1dFilmGrainData *const data = &out->frame_hdr->film_grain.data;
     const int rows = 1 + (data->overlap_flag && row_num > 0);
-    const int bitdepth_min_8 = in->p.bpc - 8;
+    const int bitdepth_min_8 = bitdepth_from_max(bitdepth_max) - 8;
     const int grain_ctr = 128 << bitdepth_min_8;
     const int grain_min = -grain_ctr, grain_max = grain_ctr - 1;
 
@@ -242,7 +184,11 @@
         max_value = 235 << bitdepth_min_8;
     } else {
         min_value = 0;
-        max_value = (1U << in->p.bpc) - 1;
+#if BITDEPTH == 8
+        max_value = 0xff;
+#else
+        max_value = bitdepth_max;
+#endif
     }
 
     // seed[0] contains the current row, seed[1] contains the previous
@@ -253,18 +199,13 @@
         seed[i] ^= (((row_num - i) * 173 + 105) & 0xFF);
     }
 
-    const ptrdiff_t stride = out->stride[0];
     assert(stride % (BLOCK_SIZE * sizeof(pixel)) == 0);
-    assert(stride == in->stride[0]);
-    pixel *const src_row = (pixel *)  in->data[0] + PXSTRIDE(stride) * row_num * BLOCK_SIZE;
-    pixel *const dst_row = (pixel *) out->data[0] + PXSTRIDE(stride) * row_num * BLOCK_SIZE;
 
     int offsets[2 /* col offset */][2 /* row offset */];
 
     // process this row in BLOCK_SIZE^2 blocks
-    const int bh = imin(out->p.h - row_num * BLOCK_SIZE, BLOCK_SIZE);
-    for (int bx = 0; bx < out->p.w; bx += BLOCK_SIZE) {
-        const int bw = imin(BLOCK_SIZE, out->p.w - bx);
+    for (int bx = 0; bx < pw; bx += BLOCK_SIZE) {
+        const int bw = imin(BLOCK_SIZE, pw - bx);
 
         if (data->overlap_flag && bx) {
             // shift previous offsets left
@@ -282,11 +223,11 @@
 
         static const int w[2][2] = { { 27, 17 }, { 17, 27 } };
 
-#define add_noise_y(x, y, grain)                                                \
-            pixel *src = src_row + (y) * PXSTRIDE(stride) + (bx + (x));         \
-            pixel *dst = dst_row + (y) * PXSTRIDE(stride) + (bx + (x));         \
-            int noise = round2(scaling[ *src ] * (grain), data->scaling_shift); \
-            *dst = iclip(*src + noise, min_value, max_value);
+#define add_noise_y(x, y, grain)                                                  \
+        const pixel *const src = src_row + (y) * PXSTRIDE(stride) + (x) + bx;     \
+        pixel *const dst = dst_row + (y) * PXSTRIDE(stride) + (x) + bx;           \
+        const int noise = round2(scaling[ *src ] * (grain), data->scaling_shift); \
+        *dst = iclip(*src + noise, min_value, max_value);
 
         for (int y = ystart; y < bh; y++) {
             // Non-overlapped image region (straightforward)
@@ -338,14 +279,18 @@
     }
 }
 
-static void apply_to_row_uv(Dav1dPicture *const out, const Dav1dPicture *const in,
-                            entry grain_lut[GRAIN_HEIGHT][GRAIN_WIDTH],
-                            uint8_t scaling[SCALING_SIZE], int uv, int row_num)
+static NOINLINE void
+fguv_32x32xn_c(pixel *const dst_row, const pixel *const src_row,
+               const ptrdiff_t stride, const pixel *const luma_row,
+               const ptrdiff_t luma_stride, const int pw, const int bh,
+               const Dav1dFilmGrainData *const data,
+               const entry grain_lut[GRAIN_HEIGHT][GRAIN_WIDTH],
+               const uint8_t scaling[SCALING_SIZE],
+               const int uv, const int row_num, const int is_id,
+               const int sx, const int sy HIGHBD_DECL_SUFFIX)
 {
-    const Dav1dFilmGrainData *const data = &out->frame_hdr->film_grain.data;
     const int rows = 1 + (data->overlap_flag && row_num > 0);
-    const int bitdepth_max = (1 << in->p.bpc) - 1;
-    const int bitdepth_min_8 = in->p.bpc - 8;
+    const int bitdepth_min_8 = bitdepth_from_max(bitdepth_max) - 8;
     const int grain_ctr = 128 << bitdepth_min_8;
     const int grain_min = -grain_ctr, grain_max = grain_ctr - 1;
 
@@ -352,19 +297,16 @@
     int min_value, max_value;
     if (data->clip_to_restricted_range) {
         min_value = 16 << bitdepth_min_8;
-        if (out->seq_hdr->mtrx == DAV1D_MC_IDENTITY) {
-            max_value = 235 << bitdepth_min_8;
-        } else {
-            max_value = 240 << bitdepth_min_8;
-        }
+        max_value = (is_id ? 235 : 240) << bitdepth_min_8;
     } else {
         min_value = 0;
+#if BITDEPTH == 8
+        max_value = 0xff;
+#else
         max_value = bitdepth_max;
+#endif
     }
 
-    const int sx = in->p.layout != DAV1D_PIXEL_LAYOUT_I444;
-    const int sy = in->p.layout == DAV1D_PIXEL_LAYOUT_I420;
-
     // seed[0] contains the current row, seed[1] contains the previous
     unsigned seed[2];
     for (int i = 0; i < rows; i++) {
@@ -373,21 +315,13 @@
         seed[i] ^= (((row_num - i) * 173 + 105) & 0xFF);
     }
 
-    const ptrdiff_t stride = out->stride[1];
     assert(stride % (BLOCK_SIZE * sizeof(pixel)) == 0);
-    assert(stride == in->stride[1]);
 
-    const int by = row_num * (BLOCK_SIZE >> sy);
-    pixel *const dst_row = (pixel *) out->data[1 + uv] + PXSTRIDE(stride) * by;
-    pixel *const src_row = (pixel *)  in->data[1 + uv] + PXSTRIDE(stride) * by;
-    pixel *const luma_row = (pixel *) in->data[0] + PXSTRIDE(in->stride[0]) * row_num * BLOCK_SIZE;
-
     int offsets[2 /* col offset */][2 /* row offset */];
 
     // process this row in BLOCK_SIZE^2 blocks (subsampled)
-    const int bh = (imin(out->p.h - row_num * BLOCK_SIZE, BLOCK_SIZE) + sy) >> sy;
-    for (int bx = 0; bx < (out->p.w + sx) >> sx; bx += BLOCK_SIZE >> sx) {
-        const int bw = (imin(BLOCK_SIZE, out->p.w - (bx << sx)) + sx) >> sx;
+    for (int bx = 0; bx < (pw + sx) >> sx; bx += BLOCK_SIZE >> sx) {
+        const int bw = (imin(BLOCK_SIZE, pw - (bx << sx)) + sx) >> sx;
         if (data->overlap_flag && bx) {
             // shift previous offsets left
             for (int i = 0; i < rows; i++)
@@ -407,25 +341,23 @@
             { { 23, 22 } },
         };
 
-#define add_noise_uv(x, y, grain)                                               \
-            const int lx = (bx + x) << sx;                                      \
-            const int ly = y << sy;                                             \
-            pixel *luma = luma_row + ly * PXSTRIDE(in->stride[0]) + lx;         \
-            pixel avg = luma[0];                                                \
-            if (sx && lx + 1 < out->p.w)                                        \
-                avg = (avg + luma[1] + 1) >> 1;                                 \
-                                                                                \
-            pixel *src = src_row + (y) * PXSTRIDE(stride) + (bx + (x));         \
-            pixel *dst = dst_row + (y) * PXSTRIDE(stride) + (bx + (x));         \
-            int val = avg;                                                      \
-            if (!data->chroma_scaling_from_luma) {                              \
-                int combined = avg * data->uv_luma_mult[uv] +                   \
-                               *src * data->uv_mult[uv];                        \
-                val = iclip_pixel( (combined >> 6) +                            \
-                                   (data->uv_offset[uv] * (1 << bitdepth_min_8)) );   \
-            }                                                                   \
-                                                                                \
-            int noise = round2(scaling[ val ] * (grain), data->scaling_shift);  \
+#define add_noise_uv(x, y, grain)                                                    \
+            const int lx = (bx + x) << sx;                                           \
+            const int ly = y << sy;                                                  \
+            const pixel *const luma = luma_row + ly * PXSTRIDE(luma_stride) + lx;    \
+            pixel avg = luma[0];                                                     \
+            if (sx && lx + 1 < pw)                                                   \
+                avg = (avg + luma[1] + 1) >> 1;                                      \
+            const pixel *const src = src_row + (y) * PXSTRIDE(stride) + (bx + (x));  \
+            pixel *const dst = dst_row + (y) * PXSTRIDE(stride) + (bx + (x));        \
+            int val = avg;                                                           \
+            if (!data->chroma_scaling_from_luma) {                                   \
+                const int combined = avg * data->uv_luma_mult[uv] +                  \
+                               *src * data->uv_mult[uv];                             \
+                val = iclip_pixel( (combined >> 6) +                                 \
+                                   (data->uv_offset[uv] * (1 << bitdepth_min_8)) );  \
+            }                                                                        \
+            const int noise = round2(scaling[ val ] * (grain), data->scaling_shift); \
             *dst = iclip(*src + noise, min_value, max_value);
 
         for (int y = ystart; y < bh; y++) {
@@ -478,61 +410,25 @@
     }
 }
 
-void bitfn(dav1d_apply_grain)(Dav1dPicture *const out,
-                              const Dav1dPicture *const in)
-{
-    const Dav1dFilmGrainData *const data = &out->frame_hdr->film_grain.data;
+#define fguv_ss_fn(nm, ss_x, ss_y) \
+static decl_fguv_32x32xn_fn(fguv_32x32xn_##nm##_c) { \
+    fguv_32x32xn_c(dst_row, src_row, stride, luma_row, luma_stride, pw, bh, \
+                   data, grain_lut, scaling, uv_pl, row_num, is_id, ss_x, ss_y \
+                   HIGHBD_TAIL_SUFFIX); \
+}
 
-    entry grain_lut[3][GRAIN_HEIGHT][GRAIN_WIDTH];
-    uint8_t scaling[3][SCALING_SIZE];
+fguv_ss_fn(420, 1, 1);
+fguv_ss_fn(422, 1, 0);
+fguv_ss_fn(444, 0, 0);
 
-    // Generate grain LUTs as needed
-    generate_grain_y(out, grain_lut[0]); // always needed
-    if (data->num_uv_points[0] || data->chroma_scaling_from_luma)
-        generate_grain_uv(out, 0, grain_lut[1], grain_lut[0]);
-    if (data->num_uv_points[1] || data->chroma_scaling_from_luma)
-        generate_grain_uv(out, 1, grain_lut[2], grain_lut[0]);
+COLD void bitfn(dav1d_film_grain_dsp_init)(Dav1dFilmGrainDSPContext *const c) {
+    c->generate_grain_y = generate_grain_y_c;
+    c->generate_grain_uv[DAV1D_PIXEL_LAYOUT_I420 - 1] = generate_grain_uv_420_c;
+    c->generate_grain_uv[DAV1D_PIXEL_LAYOUT_I422 - 1] = generate_grain_uv_422_c;
+    c->generate_grain_uv[DAV1D_PIXEL_LAYOUT_I444 - 1] = generate_grain_uv_444_c;
 
-    // Generate scaling LUTs as needed
-    if (data->num_y_points)
-        generate_scaling(in->p.bpc, data->y_points, data->num_y_points, scaling[0]);
-    if (data->num_uv_points[0])
-        generate_scaling(in->p.bpc, data->uv_points[0], data->num_uv_points[0], scaling[1]);
-    if (data->num_uv_points[1])
-        generate_scaling(in->p.bpc, data->uv_points[1], data->num_uv_points[1], scaling[2]);
-
-    // Copy over the non-modified planes
-    // TODO: eliminate in favor of per-plane refs
-    if (!data->num_y_points) {
-        assert(out->stride[0] == in->stride[0]);
-        memcpy(out->data[0], in->data[0], out->p.h * out->stride[0]);
-    }
-
-    if (in->p.layout != DAV1D_PIXEL_LAYOUT_I400) {
-        for (int i = 0; i < 2; i++) {
-            if (!data->num_uv_points[i] && !data->chroma_scaling_from_luma) {
-                const int suby = in->p.layout == DAV1D_PIXEL_LAYOUT_I420;
-                assert(out->stride[1] == in->stride[1]);
-                memcpy(out->data[1+i], in->data[1+i],
-                       (out->p.h >> suby) * out->stride[1]);
-            }
-        }
-    }
-
-    // Synthesize grain for the affected planes
-    int rows = (out->p.h + 31) >> 5;
-    for (int row = 0; row < rows; row++) {
-        if (data->num_y_points)
-            apply_to_row_y(out, in, grain_lut[0], scaling[0], row);
-
-        if (data->chroma_scaling_from_luma) {
-            apply_to_row_uv(out, in, grain_lut[1], scaling[0], 0, row);
-            apply_to_row_uv(out, in, grain_lut[2], scaling[0], 1, row);
-        } else {
-            if (data->num_uv_points[0])
-                apply_to_row_uv(out, in, grain_lut[1], scaling[1], 0, row);
-            if (data->num_uv_points[1])
-                apply_to_row_uv(out, in, grain_lut[2], scaling[2], 1, row);
-        }
-    }
+    c->fgy_32x32xn = fgy_32x32xn_c;
+    c->fguv_32x32xn[DAV1D_PIXEL_LAYOUT_I420 - 1] = fguv_32x32xn_420_c;
+    c->fguv_32x32xn[DAV1D_PIXEL_LAYOUT_I422 - 1] = fguv_32x32xn_422_c;
+    c->fguv_32x32xn[DAV1D_PIXEL_LAYOUT_I444 - 1] = fguv_32x32xn_444_c;
 }
--- a/src/internal.h
+++ b/src/internal.h
@@ -42,6 +42,7 @@
 #include "src/cdf.h"
 #include "src/data.h"
 #include "src/env.h"
+#include "src/film_grain.h"
 #include "src/intra_edge.h"
 #include "src/ipred.h"
 #include "src/itx.h"
@@ -57,6 +58,7 @@
 #include "src/thread.h"
 
 typedef struct Dav1dDSPContext {
+    Dav1dFilmGrainDSPContext fg;
     Dav1dIntraPredDSPContext ipred;
     Dav1dMCDSPContext mc;
     Dav1dInvTxfmDSPContext itx;
--- a/src/lib.c
+++ b/src/lib.c
@@ -37,6 +37,7 @@
 #include "common/mem.h"
 #include "common/validate.h"
 
+#include "src/fg_apply.h"
 #include "src/internal.h"
 #include "src/log.h"
 #include "src/obu.h"
@@ -44,7 +45,6 @@
 #include "src/ref.h"
 #include "src/thread_task.h"
 #include "src/wedge.h"
-#include "src/film_grain.h"
 
 static COLD void init_internal(void) {
     dav1d_init_wedge_masks();
@@ -290,13 +290,13 @@
     switch (out->p.bpc) {
 #if CONFIG_8BPC
     case 8:
-        dav1d_apply_grain_8bpc(out, in);
+        dav1d_apply_grain_8bpc(&c->dsp[0].fg, out, in);
         break;
 #endif
 #if CONFIG_16BPC
     case 10:
     case 12:
-        dav1d_apply_grain_16bpc(out, in);
+        dav1d_apply_grain_16bpc(&c->dsp[(out->p.bpc >> 1) - 4].fg, out, in);
         break;
 #endif
     default:
--- a/src/meson.build
+++ b/src/meson.build
@@ -55,6 +55,7 @@
 libdav1d_tmpl_sources = files(
     'cdef_apply_tmpl.c',
     'cdef_tmpl.c',
+    'fg_apply_tmpl.c',
     'film_grain_tmpl.c',
     'ipred_prepare_tmpl.c',
     'ipred_tmpl.c',