ref: 544b3e576c8edd1785914c988882b62d60652f26
parent: 98b8be09d56d03d220fff3536842c0703bae865c
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sun Nov 5 22:10:59 EST 2023
DRED: quantize r and p0 parameters with 8 bits. Only code non-degenerate symbols, which makes the encoder faster.
--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@
srcdir=`dirname $0`
test -n "$srcdir" && cd "$srcdir"
-dnn/download_model.sh c99054d
+dnn/download_model.sh 98b8be0
echo "Updating build configuration files, please wait...."
--- a/dnn/dred_rdovae.c
+++ b/dnn/dred_rdovae.c
@@ -79,9 +79,9 @@
}
-const opus_uint16 * DRED_rdovae_get_p0_pointer(void)
+const opus_uint8 * DRED_rdovae_get_p0_pointer(void)
{
- return &dred_p0_q15[0];
+ return &dred_p0_q8[0];
}
const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void)
@@ -89,9 +89,9 @@
return &dred_dead_zone_q10[0];
}
-const opus_uint16 * DRED_rdovae_get_r_pointer(void)
+const opus_uint8 * DRED_rdovae_get_r_pointer(void)
{
- return &dred_r_q15[0];
+ return &dred_r_q8[0];
}
const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void)
--- a/dnn/dred_rdovae.h
+++ b/dnn/dred_rdovae.h
@@ -58,9 +58,9 @@
void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float *qframe, const float * z);
-const opus_uint16 * DRED_rdovae_get_p0_pointer(void);
+const opus_uint8 * DRED_rdovae_get_p0_pointer(void);
const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void);
-const opus_uint16 * DRED_rdovae_get_r_pointer(void);
+const opus_uint8 * DRED_rdovae_get_r_pointer(void);
const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void);
#endif
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -63,20 +63,20 @@
quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16)
- r_q15 = np.round(r * 2**15).astype(np.uint16)
- p0_q15 = np.round(p0 * 2**15).astype(np.uint16)
+ r_q15 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
+ p0_q15 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)
print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
- print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False)
- print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False)
+ print_vector(writer.source, r_q15, 'dred_r_q8', dtype='opus_uint8', static=False)
+ print_vector(writer.source, p0_q15, 'dred_p0_q8', dtype='opus_uint8', static=False)
writer.header.write(
f"""
extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
-extern const opus_uint16 dred_r_q15[{levels * N}];
-extern const opus_uint16 dred_p0_q15[{levels * N}];
+extern const opus_uint8 dred_r_q8[{levels * N}];
+extern const opus_uint8 dred_p0_q8[{levels * N}];
"""
)
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -222,6 +222,9 @@
return clip_weights
+def n(x):
+ return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
+
# RDOVAE module and submodules
class MyConv(nn.Module):
@@ -295,17 +298,17 @@
device = x.device
# run encoding layer stack
- x = torch.tanh(self.dense_1(x))
- x = torch.cat([x, self.gru1(x)[0]], -1)
- x = torch.cat([x, self.conv1(x)], -1)
- x = torch.cat([x, self.gru2(x)[0]], -1)
- x = torch.cat([x, self.conv2(x)], -1)
- x = torch.cat([x, self.gru3(x)[0]], -1)
- x = torch.cat([x, self.conv3(x)], -1)
- x = torch.cat([x, self.gru4(x)[0]], -1)
- x = torch.cat([x, self.conv4(x)], -1)
- x = torch.cat([x, self.gru5(x)[0]], -1)
- x = torch.cat([x, self.conv5(x)], -1)
+ x = n(torch.tanh(self.dense_1(x)))
+ x = torch.cat([x, n(self.gru1(x)[0])], -1)
+ x = torch.cat([x, n(self.conv1(x))], -1)
+ x = torch.cat([x, n(self.gru2(x)[0])], -1)
+ x = torch.cat([x, n(self.conv2(x))], -1)
+ x = torch.cat([x, n(self.gru3(x)[0])], -1)
+ x = torch.cat([x, n(self.conv3(x))], -1)
+ x = torch.cat([x, n(self.gru4(x)[0])], -1)
+ x = torch.cat([x, n(self.conv4(x))], -1)
+ x = torch.cat([x, n(self.gru5(x)[0])], -1)
+ x = torch.cat([x, n(self.conv5(x))], -1)
z = self.z_dense(x)
# init state for decoder
@@ -372,18 +375,18 @@
h5_state = gru_state[:,:,384:].contiguous()
# run decoding layer stack
- x = torch.tanh(self.dense_1(z))
+ x = n(torch.tanh(self.dense_1(z)))
- x = torch.cat([x, self.gru1(x, h1_state)[0]], -1)
- x = torch.cat([x, self.conv1(x)], -1)
- x = torch.cat([x, self.gru2(x, h2_state)[0]], -1)
- x = torch.cat([x, self.conv2(x)], -1)
- x = torch.cat([x, self.gru3(x, h3_state)[0]], -1)
- x = torch.cat([x, self.conv3(x)], -1)
- x = torch.cat([x, self.gru4(x, h4_state)[0]], -1)
- x = torch.cat([x, self.conv4(x)], -1)
- x = torch.cat([x, self.gru5(x, h5_state)[0]], -1)
- x = torch.cat([x, self.conv5(x)], -1)
+ x = torch.cat([x, n(self.gru1(x, h1_state)[0])], -1)
+ x = torch.cat([x, n(self.conv1(x))], -1)
+ x = torch.cat([x, n(self.gru2(x, h2_state)[0])], -1)
+ x = torch.cat([x, n(self.conv2(x))], -1)
+ x = torch.cat([x, n(self.gru3(x, h3_state)[0])], -1)
+ x = torch.cat([x, n(self.conv3(x))], -1)
+ x = torch.cat([x, n(self.gru4(x, h4_state)[0])], -1)
+ x = torch.cat([x, n(self.conv4(x))], -1)
+ x = torch.cat([x, n(self.gru5(x, h5_state)[0])], -1)
+ x = torch.cat([x, n(self.conv5(x))], -1)
# output layer and reshaping
x10 = self.output(x)
@@ -451,7 +454,7 @@
cond_size2,
state_dim=24,
split_mode='split',
- clip_weights=True,
+ clip_weights=False,
pvq_num_pulses=82,
state_dropout_rate=0):
@@ -487,7 +490,7 @@
if not type(self.weight_clip_fn) == type(None):
self.apply(self.weight_clip_fn)
- def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 24):
+ def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
enc_stride = self.enc_stride
dec_stride = self.dec_stride
--- a/silk/dred_config.h
+++ b/silk/dred_config.h
@@ -32,7 +32,7 @@
#define DRED_EXTENSION_ID 126
/* Remove these two completely once DRED gets an extension number assigned. */
-#define DRED_EXPERIMENTAL_VERSION 6
+#define DRED_EXPERIMENTAL_VERSION 7
#define DRED_EXPERIMENTAL_BYTES 2
--- a/silk/dred_decoder.c
+++ b/silk/dred_decoder.c
@@ -43,11 +43,12 @@
return (x ^ m) - m;
}
-static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint16 *r, const opus_uint16 *p0, int dim) {
+static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
int i;
for (i=0;i<dim;i++) {
int q;
- q = ec_laplace_decode_p0(dec, p0[i], r[i]);
+ if (r[i] == 0 || p0[i] == 255) q = 0; /* Degenerate symbol: the encoder writes nothing for it, so read nothing. */
+ else q = ec_laplace_decode_p0(dec, p0[i]<<7, r[i]<<7); /* Scale the 8-bit r/p0 up by 7 bits to the precision of the former Q15 tables. */
x[i] = q*256.f/(scale[i] == 0 ? 1 : scale[i]);
}
}
@@ -54,9 +55,9 @@
int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames)
{
- const opus_uint16 *p0 = DRED_rdovae_get_p0_pointer();
+ const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer();
const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
- const opus_uint16 *r = DRED_rdovae_get_r_pointer();
+ const opus_uint8 *r = DRED_rdovae_get_r_pointer();
ec_dec ec;
int q_level;
int i;
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -217,7 +217,7 @@
}
}
-static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint16 *r, const opus_uint16 *p0, int dim) {
+static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
int i;
int q[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
float xq[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
@@ -238,16 +238,16 @@
}
for (i=0;i<dim;i++) {
/* Make the impossible actually impossible. */
- if (r[i] == 0 || p0[i] >= 32767) q[i] = 0;
- ec_laplace_encode_p0(enc, q[i], p0[i], r[i]);
+ if (r[i] == 0 || p0[i] == 255) q[i] = 0;
+ else ec_laplace_encode_p0(enc, q[i], p0[i]<<7, r[i]<<7);
}
}
int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunks, int max_bytes) {
const opus_uint16 *dead_zone = DRED_rdovae_get_dead_zone_pointer();
- const opus_uint16 *p0 = DRED_rdovae_get_p0_pointer();
+ const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer();
const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
- const opus_uint16 *r = DRED_rdovae_get_r_pointer();
+ const opus_uint8 *r = DRED_rdovae_get_r_pointer();
ec_enc ec_encoder;
int q_level;
--
⑨