ref: 44f448b2ffa12c7920f35a34d3dd6a44e784fe66
parent: b5a5f14036322957a8669de7f955220baabaa823
author: Jan Buethe <jan.buethe@gmx.net>
date: Sat Mar 15 12:45:07 EDT 2025
BBWENet C implementation + decoder integration
--- a/Makefile.am
+++ b/Makefile.am
@@ -326,6 +326,12 @@
lossgen_demo_LDADD = $(LIBM)
endif
+if ENABLE_OSCE
+# Command-line demo of the OSCE blind bandwidth extension (dnn/bwe_demo.c); built but not installed (noinst).
+noinst_PROGRAMS += bwe_demo
+bwe_demo_SOURCES = dnn/bwe_demo.c
+bwe_demo_LDADD = $(LPCNET_OBJ) $(CELT_OBJ) $(LIBM)
+endif
+
endif
--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@
srcdir=`dirname $0`
test -n "$srcdir" && cd "$srcdir"
-dnn/download_model.sh "a86f0a9db852691d4335608733ec8384a407e585801ab9e4b490e0be297ac382"
+# Identifier of the model archive fetched by dnn/download_model.sh (presumably its checksum -- confirm); bumped together with the BBWENet weights.
+dnn/download_model.sh "4ec556dd87e63c17c4a805c40685ef3fe1fad7c8b26b123f2ede553b50158cb1"
echo "Updating build configuration files, please wait...."
--- a/configure.ac
+++ b/configure.ac
@@ -948,7 +948,9 @@
)
AS_IF([test "$enable_osce" = "yes" || test "$enable_osce_training_data" = "yes"], [
AC_DEFINE([ENABLE_OSCE], [1], [Enable Opus Speech Coding Enhancement])
+ # The blind BWE (BBWENet) is built whenever OSCE is enabled; there is no separate switch.
+ AC_DEFINE([ENABLE_OSCE_BWE], [1], [Enable Opus Speech Coding Enhancement Blind BWE])
])
AM_CONDITIONAL([ENABLE_OSCE], [test "$enable_osce" = "yes" || test "$enable_osce_training_data" = "yes"])
--- /dev/null
+++ b/dnn/bwe_demo.c
@@ -1,0 +1,114 @@
+/* Copyright (c) 2018 Mozilla */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
+ CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include <math.h>
+#include <stdio.h>
+#include <string.h>
+#include <stdlib.h>
+#include "arch.h"
+#include "lpcnet.h"
+#include "os_support.h"
+#include "cpu_support.h"
+#include "osce_features.h"
+#include "osce_structs.h"
+#include "osce.h"
+#include "silk/structs.h"
+
+/* Demo driver for the OSCE blind bandwidth extension: reads raw native-endian
+   16-bit PCM at 16 kHz and writes the 3x longer (48 kHz) output as 16-bit PCM. */
+
+/* Input samples per call to osce_bwe() (20 ms at 16 kHz; 160 is also accepted). */
+#define BWE_FRAME_SIZE 320
+/* osce_bwe() writes three output samples per input sample. */
+#define BWE_UPSAMPLING_FACTOR 3
+
+/* Print the command-line synopsis and exit with failure status. */
+static void usage(void) {
+  fprintf(stderr, "usage: bwe_demo <input.pcm> <output.pcm>\n");
+  exit(1);
+}
+
+int main(int argc, char **argv) {
+  int arch = 0;
+  int ret = 1;
+  FILE *fin = NULL, *fout = NULL;
+  silk_OSCE_BWE_struct *hOSCEBWE = NULL;
+  OSCEModel *osce = NULL;
+  opus_int16 x_in[BWE_FRAME_SIZE];
+  opus_int16 x_out[BWE_UPSAMPLING_FACTOR * BWE_FRAME_SIZE];
+
+  /* Validate arguments before allocating anything or loading models. */
+  if (argc != 3) usage();
+
+  fin = fopen(argv[1], "rb");
+  if (fin == NULL) {
+    fprintf(stderr, "Can't open %s\n", argv[1]);
+    goto cleanup;
+  }
+  fout = fopen(argv[2], "wb");
+  if (fout == NULL) {
+    fprintf(stderr, "Can't open %s\n", argv[2]);
+    goto cleanup;
+  }
+
+  arch = opus_select_arch();
+  hOSCEBWE = calloc(1, sizeof(*hOSCEBWE));
+  osce = calloc(1, sizeof(*osce));
+  if (hOSCEBWE == NULL || osce == NULL) {
+    fprintf(stderr, "Out of memory\n");
+    goto cleanup;
+  }
+
+  /* The third argument is the byte length of the weight blob, not the CPU arch;
+     with data == NULL the compiled-in weights are presumably used. */
+  if (osce_load_models(osce, NULL, 0) != 0) {
+    fprintf(stderr, "Failed to load OSCE models\n");
+    goto cleanup;
+  }
+  osce_bwe_reset(hOSCEBWE);
+
+  /* Only complete frames are processed; a trailing partial frame is dropped. */
+  while (fread(x_in, sizeof(x_in[0]), BWE_FRAME_SIZE, fin) == BWE_FRAME_SIZE) {
+    osce_bwe(osce, hOSCEBWE, x_out, x_in, BWE_FRAME_SIZE, arch);
+    if (fwrite(x_out, sizeof(x_out[0]), BWE_UPSAMPLING_FACTOR * BWE_FRAME_SIZE, fout)
+        != BWE_UPSAMPLING_FACTOR * BWE_FRAME_SIZE) {
+      fprintf(stderr, "Error writing %s\n", argv[2]);
+      goto cleanup;
+    }
+  }
+  ret = 0;
+
+cleanup:
+  free(hOSCEBWE);
+  free(osce);
+  if (fin != NULL) fclose(fin);
+  if (fout != NULL && fclose(fout) != 0) ret = 1;
+  return ret;
+}
--- a/dnn/nndsp.c
+++ b/dnn/nndsp.c
@@ -346,6 +346,7 @@
int feature_dim,
int frame_size,
int avg_pool_k,
+ int interpolate_k,
int arch
)
{
@@ -354,10 +355,12 @@
float tmp_buffer[ADASHAPE_MAX_FRAME_SIZE];
int i, k;
int tenv_size;
+ int hidden_dim = frame_size / interpolate_k;
float mean;
float *tenv;
celt_assert(frame_size % avg_pool_k == 0);
+ celt_assert(frame_size % interpolate_k == 0);
celt_assert(feature_dim + frame_size / avg_pool_k + 1 < ADASHAPE_MAX_INPUT_DIM);
tenv_size = frame_size / avg_pool_k;
@@ -396,8 +399,8 @@
#ifdef DEBUG_NNDSP
print_float_vector("alpha1_out", out_buffer, frame_size);
#endif
- /* compute leaky ReLU by hand. ToDo: try tanh activation */
- for (i = 0; i < frame_size; i ++)
+ /* compute leaky ReLU by hand. */
+ for (i = 0; i < hidden_dim; i ++)
{
float tmp = out_buffer[i] + tmp_buffer[i];
in_buffer[i] = tmp >= 0 ? tmp : 0.2 * tmp;
@@ -405,7 +408,26 @@
#ifdef DEBUG_NNDSP
print_float_vector("post_alpha1", in_buffer, frame_size);
#endif
- compute_generic_conv1d(alpha2, out_buffer, hAdaShape->conv_alpha2_state, in_buffer, frame_size, ACTIVATION_LINEAR, arch);
+ compute_generic_conv1d(alpha2, tmp_buffer, hAdaShape->conv_alpha2_state, in_buffer, hidden_dim, ACTIVATION_LINEAR, arch);
+
+#ifdef DEBUG_NNDSP
+ print_float_vector("alpha2_out", tmp_buffer, hidden_dim);
+#endif
+
+ /* upsampling by linear interpolation */
+ for (i = 0; i < hidden_dim; i ++)
+ {
+ for (k = 0; k < interpolate_k; k++)
+ {
+ float alpha = (float) (k + 1) / interpolate_k;
+ out_buffer[i * interpolate_k + k] = alpha * tmp_buffer[i] + (1.f - alpha) * hAdaShape->interpolate_state[0];
+ }
+ hAdaShape->interpolate_state[0] = tmp_buffer[i];
+ }
+
+#ifdef DEBUG_NNDSP
+ print_float_vector("interpolate_out", out_buffer, frame_size);
+#endif
/* shape signal */
for (i = 0; i < frame_size; i ++)
--- a/dnn/nndsp.h
+++ b/dnn/nndsp.h
@@ -33,11 +33,11 @@
#include <string.h>
-#define ADACONV_MAX_KERNEL_SIZE 16
-#define ADACONV_MAX_INPUT_CHANNELS 2
-#define ADACONV_MAX_OUTPUT_CHANNELS 2
-#define ADACONV_MAX_FRAME_SIZE 80
-#define ADACONV_MAX_OVERLAP_SIZE 40
+/* Upper bounds for AdaConv layers (presumably used to size AdaConv state/scratch
+ * buffers -- confirm in nndsp.c); raised so BBWENet's 3-channel AdaConv stages fit. */
+#define ADACONV_MAX_KERNEL_SIZE 32
+#define ADACONV_MAX_INPUT_CHANNELS 3
+#define ADACONV_MAX_OUTPUT_CHANNELS 3
+#define ADACONV_MAX_FRAME_SIZE 240
+#define ADACONV_MAX_OVERLAP_SIZE 120
#define ADACOMB_MAX_LAG 300
#define ADACOMB_MAX_KERNEL_SIZE 16
@@ -45,7 +45,7 @@
#define ADACOMB_MAX_OVERLAP_SIZE 40
#define ADASHAPE_MAX_INPUT_DIM 512
-#define ADASHAPE_MAX_FRAME_SIZE 160
+/* Bounds the frame_size passed to adashape_process_frame (it sizes on-stack scratch
+ * buffers there); BBWENet runs AdaShape on upsampled subframes, hence the increase. */
+#define ADASHAPE_MAX_FRAME_SIZE 240
/*#define DEBUG_NNDSP*/
#ifdef DEBUG_NNDSP
@@ -74,6 +74,7 @@
float conv_alpha1f_state[ADASHAPE_MAX_INPUT_DIM];
float conv_alpha1t_state[ADASHAPE_MAX_INPUT_DIM];
float conv_alpha2_state[ADASHAPE_MAX_FRAME_SIZE];
+ float interpolate_state[1];
} AdaShapeState;
void init_adaconv_state(AdaConvState *hAdaConv);
@@ -137,6 +138,7 @@
int feature_dim,
int frame_size,
int avg_pool_k,
+ int interpolate_k,
int arch
);
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -64,7 +64,11 @@
}
#ifdef ENABLE_OSCE
+#ifdef ENABLE_OSCE_BWE
+/* MAX_RNN_NEURONS_ALL is the maximum RNN width over all enabled models; with the BWE enabled, OSCE_BWE_MAX_RNN_NEURONS is included as well. */
+#define MAX_RNN_NEURONS_ALL IMAX(IMAX(IMAX(IMAX(FARGAN_MAX_RNN_NEURONS, PLC_MAX_RNN_UNITS), DRED_MAX_RNN_NEURONS), OSCE_MAX_RNN_NEURONS), OSCE_BWE_MAX_RNN_NEURONS)
+#else
#define MAX_RNN_NEURONS_ALL IMAX(IMAX(IMAX(FARGAN_MAX_RNN_NEURONS, PLC_MAX_RNN_UNITS), DRED_MAX_RNN_NEURONS), OSCE_MAX_RNN_NEURONS)
+#endif
#else
#define MAX_RNN_NEURONS_ALL IMAX(IMAX(FARGAN_MAX_RNN_NEURONS, PLC_MAX_RNN_UNITS), DRED_MAX_RNN_NEURONS)
#endif
--- a/dnn/osce.c
+++ b/dnn/osce.c
@@ -37,15 +37,69 @@
#include "nndsp.h"
#include "float_cast.h"
#include "arch.h"
-
+/*#define OSCE_DEBUG*/
#ifdef OSCE_DEBUG
#include <stdio.h>
/*#define WRITE_FEATURES*/
/*#define DEBUG_LACE*/
/*#define DEBUG_NOLACE*/
+/* DEBUG_BBWENET dumps intermediate signals to debug/ and overwrites the feature-net output with a reference dump if one exists; keep it opt-in like the switches above. */
+/*#define DEBUG_BBWENET*/
#define FINIT(fid, name, mode) do{if (fid == NULL) {fid = fopen(name, mode);}} while(0)
#endif
+#if 0
+/* Text dumps of arrays and LinearLayer contents for offline debugging (compiled out). */
+#include <stdio.h>
+static void print_float_array(FILE *fid, const char *name, const float *array, int n)
+{
+    int i;
+    for (i = 0; i < n; i++)
+    {
+        fprintf(fid, "%s[%d]: %f\n", name, i, array[i]);
+    }
+}
+
+static void print_int_array(FILE *fid, const char *name, const int *array, int n)
+{
+    int i;
+    for (i = 0; i < n; i++)
+    {
+        fprintf(fid, "%s[%d]: %d\n", name, i, array[i]);
+    }
+}
+
+static void print_int8_array(FILE *fid, const char *name, const opus_int8 *array, int n)
+{
+    int i;
+    for (i = 0; i < n; i++)
+    {
+        fprintf(fid, "%s[%d]: %d\n", name, i, array[i]);
+    }
+}
+
+static void print_linear_layer(FILE *fid, const char *name, const LinearLayer *layer)
+{
+    int n_in, n_out, n_total;
+    char tmp[256];
+
+    n_in = layer->nb_inputs;
+    n_out = layer->nb_outputs;
+    n_total = n_in * n_out;
+
+    fprintf(fid, "\nprinting layer %s...\n", name);
+    fprintf(fid, "%s.nb_inputs: %d\n%s.nb_outputs: %d\n", name, n_in, name, n_out);
+
+    if (layer->bias != NULL) {snprintf(tmp, sizeof(tmp), "%s.bias", name); print_float_array(fid, tmp, layer->bias, n_out);}
+    if (layer->subias != NULL) {snprintf(tmp, sizeof(tmp), "%s.subias", name); print_float_array(fid, tmp, layer->subias, n_out);}
+    if (layer->weights != NULL) {snprintf(tmp, sizeof(tmp), "%s.weights", name); print_int8_array(fid, tmp, layer->weights, n_total);}
+    if (layer->float_weights != NULL) {snprintf(tmp, sizeof(tmp), "%s.float_weights", name); print_float_array(fid, tmp, layer->float_weights, n_total);}
+    if (layer->diag != NULL) {snprintf(tmp, sizeof(tmp), "%s.diag", name); print_float_array(fid, tmp, layer->diag, n_in);}
+    if (layer->scale != NULL) {snprintf(tmp, sizeof(tmp), "%s.scale", name); print_float_array(fid, tmp, layer->scale, n_out);}
+
+}
+#endif
+
#ifdef ENABLE_OSCE_TRAINING_DATA
#include <stdio.h>
#endif
@@ -54,6 +112,7 @@
extern const WeightArray lacelayers_arrays[];
extern const WeightArray nolacelayers_arrays[];
+extern const WeightArray bbwenetlayers_arrays[];
/* LACE */
@@ -639,6 +698,7 @@
NOLACE_TDSHAPE1_FEATURE_DIM,
NOLACE_TDSHAPE1_FRAME_SIZE,
NOLACE_TDSHAPE1_AVG_POOL_K,
+ 1,
arch
);
@@ -695,6 +755,7 @@
NOLACE_TDSHAPE2_FEATURE_DIM,
NOLACE_TDSHAPE2_FRAME_SIZE,
NOLACE_TDSHAPE2_AVG_POOL_K,
+ 1,
arch
);
@@ -747,6 +808,7 @@
NOLACE_TDSHAPE3_FEATURE_DIM,
NOLACE_TDSHAPE3_FRAME_SIZE,
NOLACE_TDSHAPE3_AVG_POOL_K,
+ 1,
arch
);
@@ -786,6 +848,523 @@
#endif /* #ifndef DISABLE_NOLACE */
+
+#ifdef ENABLE_OSCE_BWE
+#ifndef DISABLE_BBWENET
+/* BBWENet feature network. Runs two conv1d layers (tanh) frame by frame over the
+ * BBWENET_FEATURE_DIM-dimensional input features, expands every frame into
+ * BBWENET_FNET_TCONV_STRIDE subframes with the "tconv" dense layer (tanh), and
+ * feeds the subframes through a GRU. After each subframe the GRU state
+ * (BBWENET_FNET_GRU_STATE_SIZE floats) is appended to output, so output receives
+ * BBWENET_FNET_TCONV_STRIDE * num_frames vectors. Layer memories live in state.
+ * The scratch buffers hold at most 4 subframes, i.e. num_frames <= 2. */
+static void bbwe_feature_net(
+    BBWENet *hBBWENET,
+    BBWENetState *state,
+    float *output,
+    const float *features,
+    int num_frames,
+    int arch
+)
+{
+    float input_buffer[4 * BBWENET_FNET_GRU_STATE_SIZE];
+    float output_buffer[4 * BBWENET_FNET_GRU_STATE_SIZE];
+    int i_subframe;
+    int i_frame;
+
+#ifdef DEBUG_BBWENET
+    static FILE *f_features=NULL, *f_conv1=NULL, *f_conv2=NULL, *f_tconv=NULL, *f_gru=NULL;
+
+    FINIT(f_features, "debug/bbwenet_features.f32", "wb");
+    FINIT(f_conv1, "debug/bbwenet_conv1.f32", "wb");
+    FINIT(f_conv2, "debug/bbwenet_conv2.f32", "wb");
+    FINIT(f_tconv, "debug/bbwenet_tconv.f32", "wb");
+    FINIT(f_gru, "debug/bbwenet_gru.f32", "wb");
+
+    fwrite(features, sizeof(*features), num_frames * BBWENET_FEATURE_DIM, f_features);
+#endif
+
+    /* adjust buffer sizes if any of this breaks */
+    celt_assert(BBWENET_FNET_GRU_STATE_SIZE == BBWENET_FNET_TCONV_OUT_CHANNELS);
+    celt_assert(BBWENET_FNET_TCONV_OUT_CHANNELS == BBWENET_FNET_CONV2_OUT_SIZE);
+    celt_assert(BBWENET_FNET_CONV2_OUT_SIZE == BBWENET_FNET_CONV1_OUT_SIZE);
+    celt_assert(BBWENET_FNET_TCONV_STRIDE == 2);
+    /* the tconv output (num_frames * stride subframes) must fit the 4-subframe buffers */
+    celt_assert(num_frames * BBWENET_FNET_TCONV_STRIDE <= 4);
+
+    /* first conv layer */
+    for (i_frame = 0; i_frame < num_frames; i_frame++)
+    {
+        compute_generic_conv1d(
+            &hBBWENET->layers.bbwenet_fnet_conv1,
+            output_buffer + i_frame * BBWENET_FNET_CONV1_OUT_SIZE,
+            state->feature_net_conv1_state,
+            features + i_frame * BBWENET_FEATURE_DIM,
+            BBWENET_FEATURE_DIM,
+            ACTIVATION_TANH,
+            arch
+        );
+#ifdef DEBUG_BBWENET
+        fwrite(output_buffer + i_frame * BBWENET_FNET_CONV1_OUT_SIZE, sizeof(float), BBWENET_FNET_CONV1_OUT_SIZE, f_conv1);
+#endif
+    }
+    OPUS_COPY(input_buffer, output_buffer, num_frames * BBWENET_FNET_CONV1_OUT_SIZE);
+
+    /* second conv layer */
+    for (i_frame = 0; i_frame < num_frames; i_frame++)
+    {
+        compute_generic_conv1d(
+            &hBBWENET->layers.bbwenet_fnet_conv2,
+            output_buffer + i_frame * BBWENET_FNET_CONV2_OUT_SIZE,
+            state->feature_net_conv2_state,
+            input_buffer + i_frame * BBWENET_FNET_CONV1_OUT_SIZE,
+            BBWENET_FNET_CONV1_OUT_SIZE,
+            ACTIVATION_TANH,
+            arch
+        );
+#ifdef DEBUG_BBWENET
+        fwrite(output_buffer + i_frame * BBWENET_FNET_CONV2_OUT_SIZE, sizeof(float), BBWENET_FNET_CONV2_OUT_SIZE, f_conv2);
+#endif
+    }
+    OPUS_COPY(input_buffer, output_buffer, num_frames * BBWENET_FNET_CONV2_OUT_SIZE);
+
+    /* tconv upsampling: one dense layer producing STRIDE subframes per frame */
+    for (i_frame = 0; i_frame < num_frames; i_frame++)
+    {
+        compute_generic_dense(
+            &hBBWENET->layers.bbwenet_fnet_tconv,
+            output_buffer + i_frame * BBWENET_FNET_TCONV_OUT_CHANNELS * BBWENET_FNET_TCONV_STRIDE,
+            input_buffer + i_frame * BBWENET_FNET_CONV2_OUT_SIZE,
+            ACTIVATION_TANH,
+            arch
+        );
+#ifdef DEBUG_BBWENET
+        fwrite(output_buffer + i_frame * BBWENET_FNET_TCONV_OUT_CHANNELS * BBWENET_FNET_TCONV_STRIDE, sizeof(float), BBWENET_FNET_TCONV_OUT_CHANNELS * BBWENET_FNET_TCONV_STRIDE, f_tconv);
+#endif
+    }
+    OPUS_COPY(input_buffer, output_buffer, num_frames * BBWENET_FNET_TCONV_OUT_CHANNELS * BBWENET_FNET_TCONV_STRIDE);
+
+    /* GRU over subframes; its state after each step is the subframe's output */
+    for (i_subframe = 0; i_subframe < BBWENET_FNET_TCONV_STRIDE * num_frames; i_subframe ++)
+    {
+        compute_generic_gru(
+            &hBBWENET->layers.bbwenet_fnet_gru_input,
+            &hBBWENET->layers.bbwenet_fnet_gru_recurrent,
+            state->feature_net_gru_state,
+            input_buffer + i_subframe * BBWENET_FNET_TCONV_OUT_CHANNELS,
+            arch
+        );
+#ifdef DEBUG_BBWENET
+        fwrite(state->feature_net_gru_state, sizeof(float), BBWENET_FNET_GRU_STATE_SIZE, f_gru);
+#endif
+        OPUS_COPY(output + i_subframe * BBWENET_FNET_GRU_STATE_SIZE, state->feature_net_gru_state, BBWENET_FNET_GRU_STATE_SIZE);
+    }
+}
+
+/* Coefficients of the three cascaded first-order allpass sections that upsamp_2x
+ * applies to the even and odd output phases. They equal the Q16 values
+ * {1746, 14986, 39083 - 65536} and {6854, 25769, 55542 - 65536} divided by 65536
+ * (presumably SILK's up2_hq tables); the last entry is stored minus one because
+ * upsamp_2x applies 1 + c[2]. */
+static const float hq_2x_even[3] = {0.026641845703125, 0.228668212890625, -0.4036407470703125};
+static const float hq_2x_odd[3] = {0.104583740234375, 0.3932037353515625, -0.152496337890625};
+
+/* 8-tap FIR filters used by interpol_3_2: every pair of input samples yields three
+ * outputs from frac_01_24, frac_17_24 and frac_09_24 (the last applied one sample
+ * later). The names presumably denote fractional delays in units of 1/24 -- confirm. */
+static const float frac_01_24[8] = {
+    0.00576782, -0.01831055, 0.01882935, 0.9328308,
+    0.09143066, -0.04196167, 0.01296997, -0.00140381
+};
+
+static const float frac_17_24[8] = {
+    -3.14331055e-03, 2.73437500e-02, -1.06414795e-01, 3.64685059e-01,
+    8.03863525e-01, -1.02233887e-01, 1.61437988e-02, -1.22070312e-04
+};
+
+static const float frac_09_24[8] = {
+    -0.00146484, 0.02313232, -0.12072754, 0.7315979,
+    0.4621277, -0.12075806, 0.0295105, -0.00326538
+};
+
+/* In-place elementwise nonlinearity x <- x * sin(log(|x| + 1e-6)) over len samples;
+ * the product is formed in double precision and stored back as float. */
+static void apply_valin_activation(float *x, int len)
+{
+    while (len-- > 0)
+    {
+        const float v = *x;
+        *x++ = v * sin(log(fabs(v) + 1e-6f));
+    }
+}
+
+
+/* History kept between calls of interpol_3_2 (ToDo: this probably should be 7, bug in python code?) */
+#define DELAY_SAMPLES 8
+
+/* 8-tap dot product x[0]*h[0] + ... + x[7]*h[7], accumulated left to right. */
+static float interpol_fir8(const float *x, const float *h)
+{
+    float acc = x[0] * h[0];
+    int j;
+    for (j = 1; j < 8; j++)
+    {
+        acc += x[j] * h[j];
+    }
+    return acc;
+}
+
+/* Resample by 3/2: every pair of input samples produces three output samples from
+ * the fractional-delay filters. num_samples must be even. The last DELAY_SAMPLES
+ * inputs are carried to the next call in state->interpol_buffer. x_in is copied
+ * before any output is written, so x_out may overlap x_in. */
+static void interpol_3_2(resamp_state *state, float *x_out, const float *x_in, int num_samples)
+{
+    float buffer[8 * BBWENET_FRAME_SIZE16 + DELAY_SAMPLES];
+    const float *window;
+    int n;
+
+    celt_assert(num_samples > 1);
+    celt_assert(num_samples < 8 * BBWENET_FRAME_SIZE16);
+    celt_assert(num_samples % 2 == 0);
+
+    /* history from the previous call followed by the new input */
+    OPUS_COPY(buffer, state->interpol_buffer, DELAY_SAMPLES);
+    OPUS_COPY(buffer + DELAY_SAMPLES, x_in, num_samples);
+
+    for (n = 0; n < num_samples; n += 2)
+    {
+        window = buffer + n;
+        *x_out++ = interpol_fir8(window, frac_01_24);
+        *x_out++ = interpol_fir8(window, frac_17_24);
+        *x_out++ = interpol_fir8(window + 1, frac_09_24);
+    }
+
+    /* keep the most recent DELAY_SAMPLES inputs for the next call */
+    OPUS_COPY(state->interpol_buffer, buffer + num_samples, DELAY_SAMPLES);
+}
+
+/* Pass one sample through three cascaded first-order allpass sections with
+ * coefficients c[0], c[1] and 1 + c[2]; s[0..2] hold the section states and are
+ * updated in place. Returns the output of the last section. */
+static float upsamp_2x_allpass3(float *s, const float *c, float x)
+{
+    float in = x, out, d;
+
+    d = (in - s[0]) * c[0];
+    out = s[0] + d;
+    s[0] = in + d;
+
+    in = out;
+    d = (in - s[1]) * c[1];
+    out = s[1] + d;
+    s[1] = in + d;
+
+    in = out;
+    d = (in - s[2]) * (1 + c[2]);
+    out = s[2] + d;
+    s[2] = in + d;
+
+    return out;
+}
+
+/* 2x upsampling: each input sample feeds an even-phase and an odd-phase allpass
+ * cascade (states in state->upsamp_buffer[0] / [1]), producing output samples
+ * 2n and 2n+1. x_in is copied first, so x_out may overlap x_in. */
+static void upsamp_2x(resamp_state *state, float *x_out, const float *x_in, int num_samples)
+{
+    float buffer [4 * BBWENET_FRAME_SIZE16];
+    int n;
+
+    celt_assert(num_samples > 1);
+    celt_assert(num_samples < 4 * BBWENET_FRAME_SIZE16);
+
+    OPUS_COPY(buffer, x_in, num_samples);
+
+    for (n = 0; n < num_samples; n++)
+    {
+        x_out[2 * n] = upsamp_2x_allpass3(state->upsamp_buffer[0], hq_2x_even, buffer[n]);
+        x_out[2 * n + 1] = upsamp_2x_allpass3(state->upsamp_buffer[1], hq_2x_odd, buffer[n]);
+    }
+}
+
+/* BBWENet signal path for num_frames input frames (2 subframes per frame):
+ *   feature net  -> one BBWENET_COND_DIM conditioning vector per subframe
+ *   AdaConv af1  -> BBWENET_AF1_OUT_CHANNELS (3) channels
+ *   per channel 2x upsampling; AdaShape on channel 2, valin activation on channel 3
+ *   AdaConv af2  -> BBWENET_AF2_OUT_CHANNELS (3) channels (mixing)
+ *   per channel 3/2 interpolation; AdaShape on channel 2, valin activation on channel 3
+ *   AdaConv af3  -> single channel written to x_out (final mixing)
+ * x_buffer1 and x_buffer2 alternate as stage outputs; per subframe a stage stores its
+ * channels back to back. Strides are always taken from the producing stage, and the
+ * asserts below pin the size relations that let the next stage read that layout.
+ * At most 4 subframes (num_frames <= 2) are supported. */
+static void bbwenet_process_frames(
+    BBWENet *hBBWENET,
+    BBWENetState *state,
+    float *x_out,
+    const float *x_in,
+    const float *features,
+    int num_frames,
+    int arch
+)
+{
+    float latent_features[4 * BBWENET_COND_DIM];
+    int i_subframe, num_subframes = 2 * num_frames, i_channel;
+    float x_buffer1[3 * 3 * 4 * 3*BBWENET_FRAME_SIZE16] = {0}; /* 3x3 channels, 4 subframes, 48 kHz */
+    float x_buffer2[3 * 3 * 4 * 3*BBWENET_FRAME_SIZE16] = {0};
+    BBWENETLayers *layers = &hBBWENET->layers;
+
+#ifdef DEBUG_BBWENET
+    static FILE *f_latent=NULL, *f_xin=NULL, *f_af1_1=NULL, *f_af1_2=NULL, *f_af1_3=NULL;
+    static FILE *f_up2_1=NULL, *f_up2_2=NULL, *f_up2_3=NULL, *f2_up_shape=NULL, *f2_up_func=NULL;
+    static FILE *f_af2_1=NULL, *f_af2_2=NULL, *f_af2_3=NULL;
+    static FILE *f_up15_1=NULL, *f_up15_2=NULL, *f_up15_3=NULL;
+    static FILE *f_up15_shape=NULL, *f_up15_func=NULL;
+    static FILE *f_af3_1=NULL;
+
+    FINIT(f_latent, "dnn/torch/osce/debugdump/feature_net_gru.f32", "rb");
+    FINIT(f_xin, "debug/bbwenet_x_in.f32", "wb");
+    FINIT(f_af1_1, "debug/bbwenet_af1_1.f32", "wb");
+    FINIT(f_af1_2, "debug/bbwenet_af1_2.f32", "wb");
+    FINIT(f_af1_3, "debug/bbwenet_af1_3.f32", "wb");
+    FINIT(f_up2_1, "debug/bbwenet_up2_1.f32", "wb");
+    FINIT(f_up2_2, "debug/bbwenet_up2_2.f32", "wb");
+    FINIT(f_up2_3, "debug/bbwenet_up2_3.f32", "wb");
+    FINIT(f2_up_func, "debug/bbwenet_up2_func.f32", "wb");
+    FINIT(f2_up_shape, "debug/bbwenet_up2_shape.f32", "wb");
+    FINIT(f_af2_1, "debug/bbwenet_af2_1.f32", "wb");
+    FINIT(f_af2_2, "debug/bbwenet_af2_2.f32", "wb");
+    FINIT(f_af2_3, "debug/bbwenet_af2_3.f32", "wb");
+    FINIT(f_up15_1, "debug/bbwenet_up15_1.f32", "wb");
+    FINIT(f_up15_2, "debug/bbwenet_up15_2.f32", "wb");
+    FINIT(f_up15_3, "debug/bbwenet_up15_3.f32", "wb");
+    FINIT(f_up15_shape, "debug/bbwenet_up15_shape.f32", "wb");
+    FINIT(f_up15_func, "debug/bbwenet_up15_func.f32", "wb");
+    FINIT(f_af3_1, "debug/bbwenet_af3_1.f32", "wb");
+    fwrite(x_in, sizeof(*x_in), num_subframes * BBWENET_AF1_FRAME_SIZE, f_xin);
+#endif
+
+    /* buffer capacity and stage-to-stage layout contracts */
+    celt_assert(num_subframes <= 4);
+    celt_assert(BBWENET_COND_DIM == BBWENET_FNET_GRU_STATE_SIZE);
+    celt_assert(BBWENET_AF1_OUT_CHANNELS == 3);
+    celt_assert(BBWENET_AF2_OUT_CHANNELS == 3);
+    celt_assert(BBWENET_AF3_OUT_CHANNELS == 1);
+    celt_assert(2 * BBWENET_AF1_FRAME_SIZE == BBWENET_TDSHAPE1_FRAME_SIZE);
+    celt_assert(BBWENET_AF2_IN_CHANNELS == BBWENET_AF1_OUT_CHANNELS);
+    celt_assert(BBWENET_AF2_FRAME_SIZE == BBWENET_TDSHAPE1_FRAME_SIZE);
+    celt_assert(3 * BBWENET_AF2_FRAME_SIZE == 2 * BBWENET_TDSHAPE2_FRAME_SIZE);
+    celt_assert(BBWENET_AF3_IN_CHANNELS == BBWENET_AF2_OUT_CHANNELS);
+    celt_assert(BBWENET_AF3_FRAME_SIZE == BBWENET_TDSHAPE2_FRAME_SIZE);
+    celt_assert(num_subframes * BBWENET_AF2_OUT_CHANNELS * BBWENET_TDSHAPE2_FRAME_SIZE <= (int) (sizeof(x_buffer2) / sizeof(x_buffer2[0])));
+
+    /* feature net */
+    bbwe_feature_net(hBBWENET, state, latent_features, features, num_frames, arch);
+#ifdef DEBUG_BBWENET
+    /* debug only: replace the computed latents with a reference dump when present */
+    if (f_latent != NULL){
+        fread(latent_features, sizeof(*latent_features), num_subframes * BBWENET_COND_DIM, f_latent);
+    }
+#endif
+
+    /* signal net
+     * first adaptive filtering stage, three output channels */
+    for (i_subframe = 0; i_subframe < num_subframes; i_subframe++)
+    {
+        float *af1_out = x_buffer1 + i_subframe * BBWENET_AF1_OUT_CHANNELS * BBWENET_AF1_FRAME_SIZE;
+
+        adaconv_process_frame(
+            &state->af1_state,
+            af1_out,
+            x_in + i_subframe * BBWENET_AF1_FRAME_SIZE,
+            latent_features + i_subframe * BBWENET_COND_DIM,
+            &layers->bbwenet_af1_kernel,
+            &layers->bbwenet_af1_gain,
+            BBWENET_COND_DIM,
+            BBWENET_AF1_FRAME_SIZE,
+            BBWENET_AF1_OVERLAP_SIZE,
+            BBWENET_AF1_IN_CHANNELS,
+            BBWENET_AF1_OUT_CHANNELS,
+            BBWENET_AF1_KERNEL_SIZE,
+            BBWENET_AF1_LEFT_PADDING,
+            BBWENET_AF1_FILTER_GAIN_A,
+            BBWENET_AF1_FILTER_GAIN_B,
+            BBWENET_AF1_SHAPE_GAIN,
+            hBBWENET->window16,
+            arch);
+
+#ifdef DEBUG_BBWENET
+        fwrite(af1_out, sizeof(float), BBWENET_AF1_FRAME_SIZE, f_af1_1);
+        fwrite(af1_out + BBWENET_AF1_FRAME_SIZE, sizeof(float), BBWENET_AF1_FRAME_SIZE, f_af1_2);
+        fwrite(af1_out + 2 * BBWENET_AF1_FRAME_SIZE, sizeof(float), BBWENET_AF1_FRAME_SIZE, f_af1_3);
+#endif
+    }
+
+    /* 1st round of non-linear extension */
+    for (i_subframe = 0; i_subframe < num_subframes; i_subframe++)
+    {
+        const float *ext_in = x_buffer1 + i_subframe * BBWENET_AF1_OUT_CHANNELS * BBWENET_AF1_FRAME_SIZE;
+        float *ext_out = x_buffer2 + i_subframe * BBWENET_AF1_OUT_CHANNELS * BBWENET_TDSHAPE1_FRAME_SIZE;
+
+        /* 2x upsampling on individual channels */
+        for (i_channel = 0; i_channel < BBWENET_AF1_OUT_CHANNELS; i_channel ++)
+        {
+            upsamp_2x(
+                &state->resampler_state[i_channel],
+                ext_out + i_channel * BBWENET_TDSHAPE1_FRAME_SIZE,
+                ext_in + i_channel * BBWENET_AF1_FRAME_SIZE,
+                BBWENET_AF1_FRAME_SIZE
+            );
+        }
+
+#ifdef DEBUG_BBWENET
+        fwrite(ext_out, sizeof(float), BBWENET_TDSHAPE1_FRAME_SIZE, f_up2_1);
+        fwrite(ext_out + BBWENET_TDSHAPE1_FRAME_SIZE, sizeof(float), BBWENET_TDSHAPE1_FRAME_SIZE, f_up2_2);
+        fwrite(ext_out + 2 * BBWENET_TDSHAPE1_FRAME_SIZE, sizeof(float), BBWENET_TDSHAPE1_FRAME_SIZE, f_up2_3);
+#endif
+
+        /* tdshape on second channel (in place) */
+        adashape_process_frame(
+            &state->tdshape1_state,
+            ext_out + BBWENET_TDSHAPE1_FRAME_SIZE,
+            ext_out + BBWENET_TDSHAPE1_FRAME_SIZE,
+            latent_features + i_subframe * BBWENET_COND_DIM,
+            &layers->bbwenet_tdshape1_alpha1_f,
+            &layers->bbwenet_tdshape1_alpha1_t,
+            &layers->bbwenet_tdshape1_alpha2,
+            BBWENET_TDSHAPE1_FEATURE_DIM,
+            BBWENET_TDSHAPE1_FRAME_SIZE,
+            BBWENET_TDSHAPE1_AVG_POOL_K,
+            BBWENET_TDSHAPE1_INTERPOLATE_K,
+            arch
+        );
+
+#ifdef DEBUG_BBWENET
+        fwrite(ext_out + BBWENET_TDSHAPE1_FRAME_SIZE, sizeof(float), BBWENET_TDSHAPE1_FRAME_SIZE, f2_up_shape);
+#endif
+
+        /* non-linear activation of third channel (in place) */
+        apply_valin_activation(ext_out + 2 * BBWENET_TDSHAPE1_FRAME_SIZE, BBWENET_TDSHAPE1_FRAME_SIZE);
+#ifdef DEBUG_BBWENET
+        fwrite(ext_out + 2 * BBWENET_TDSHAPE1_FRAME_SIZE, sizeof(float), BBWENET_TDSHAPE1_FRAME_SIZE, f2_up_func);
+#endif
+    }
+
+    /* mixing */
+    for (i_subframe = 0; i_subframe < num_subframes; i_subframe++)
+    {
+        float *af2_out = x_buffer1 + i_subframe * BBWENET_AF2_OUT_CHANNELS * BBWENET_AF2_FRAME_SIZE;
+
+        adaconv_process_frame(
+            &state->af2_state,
+            af2_out,
+            x_buffer2 + i_subframe * BBWENET_AF1_OUT_CHANNELS * BBWENET_TDSHAPE1_FRAME_SIZE,
+            latent_features + i_subframe * BBWENET_COND_DIM,
+            &layers->bbwenet_af2_kernel,
+            &layers->bbwenet_af2_gain,
+            BBWENET_COND_DIM,
+            BBWENET_AF2_FRAME_SIZE,
+            BBWENET_AF2_OVERLAP_SIZE,
+            BBWENET_AF2_IN_CHANNELS,
+            BBWENET_AF2_OUT_CHANNELS,
+            BBWENET_AF2_KERNEL_SIZE,
+            BBWENET_AF2_LEFT_PADDING,
+            BBWENET_AF2_FILTER_GAIN_A,
+            BBWENET_AF2_FILTER_GAIN_B,
+            BBWENET_AF2_SHAPE_GAIN,
+            hBBWENET->window32,
+            arch);
+#ifdef DEBUG_BBWENET
+        fwrite(af2_out, sizeof(float), BBWENET_AF2_FRAME_SIZE, f_af2_1);
+        fwrite(af2_out + BBWENET_AF2_FRAME_SIZE, sizeof(float), BBWENET_AF2_FRAME_SIZE, f_af2_2);
+        fwrite(af2_out + 2 * BBWENET_AF2_FRAME_SIZE, sizeof(float), BBWENET_AF2_FRAME_SIZE, f_af2_3);
+#endif
+    }
+
+    /* second round of extension */
+    for (i_subframe = 0; i_subframe < num_subframes; i_subframe++)
+    {
+        const float *ext_in = x_buffer1 + i_subframe * BBWENET_AF2_OUT_CHANNELS * BBWENET_AF2_FRAME_SIZE;
+        float *ext_out = x_buffer2 + i_subframe * BBWENET_AF2_OUT_CHANNELS * BBWENET_TDSHAPE2_FRAME_SIZE;
+
+        /* 1.5x interpolation on individual channels */
+        for (i_channel = 0; i_channel < BBWENET_AF2_OUT_CHANNELS; i_channel ++)
+        {
+            interpol_3_2(
+                &state->resampler_state[i_channel],
+                ext_out + i_channel * BBWENET_TDSHAPE2_FRAME_SIZE,
+                ext_in + i_channel * BBWENET_AF2_FRAME_SIZE,
+                BBWENET_AF2_FRAME_SIZE
+            );
+        }
+
+#ifdef DEBUG_BBWENET
+        fwrite(ext_out, sizeof(float), BBWENET_TDSHAPE2_FRAME_SIZE, f_up15_1);
+        fwrite(ext_out + BBWENET_TDSHAPE2_FRAME_SIZE, sizeof(float), BBWENET_TDSHAPE2_FRAME_SIZE, f_up15_2);
+        fwrite(ext_out + 2 * BBWENET_TDSHAPE2_FRAME_SIZE, sizeof(float), BBWENET_TDSHAPE2_FRAME_SIZE, f_up15_3);
+#endif
+
+        /* tdshape on second channel (in place) */
+        adashape_process_frame(
+            &state->tdshape2_state,
+            ext_out + BBWENET_TDSHAPE2_FRAME_SIZE,
+            ext_out + BBWENET_TDSHAPE2_FRAME_SIZE,
+            latent_features + i_subframe * BBWENET_COND_DIM,
+            &layers->bbwenet_tdshape2_alpha1_f,
+            &layers->bbwenet_tdshape2_alpha1_t,
+            &layers->bbwenet_tdshape2_alpha2,
+            BBWENET_TDSHAPE2_FEATURE_DIM,
+            BBWENET_TDSHAPE2_FRAME_SIZE,
+            BBWENET_TDSHAPE2_AVG_POOL_K,
+            BBWENET_TDSHAPE2_INTERPOLATE_K,
+            arch
+        );
+
+#ifdef DEBUG_BBWENET
+        fwrite(ext_out + BBWENET_TDSHAPE2_FRAME_SIZE, sizeof(float), BBWENET_TDSHAPE2_FRAME_SIZE, f_up15_shape);
+#endif
+
+        /* non-linear activation of third channel (in place) */
+        apply_valin_activation(ext_out + 2 * BBWENET_TDSHAPE2_FRAME_SIZE, BBWENET_TDSHAPE2_FRAME_SIZE);
+#ifdef DEBUG_BBWENET
+        fwrite(ext_out + 2 * BBWENET_TDSHAPE2_FRAME_SIZE, sizeof(float), BBWENET_TDSHAPE2_FRAME_SIZE, f_up15_func);
+#endif
+    }
+
+    /* final mixing */
+    for (i_subframe = 0; i_subframe < num_subframes; i_subframe++)
+    {
+        adaconv_process_frame(
+            &state->af3_state,
+            x_out + i_subframe * BBWENET_AF3_FRAME_SIZE,
+            x_buffer2 + i_subframe * BBWENET_AF2_OUT_CHANNELS * BBWENET_TDSHAPE2_FRAME_SIZE,
+            latent_features + i_subframe * BBWENET_COND_DIM,
+            &layers->bbwenet_af3_kernel,
+            &layers->bbwenet_af3_gain,
+            BBWENET_COND_DIM,
+            BBWENET_AF3_FRAME_SIZE,
+            BBWENET_AF3_OVERLAP_SIZE,
+            BBWENET_AF3_IN_CHANNELS,
+            BBWENET_AF3_OUT_CHANNELS,
+            BBWENET_AF3_KERNEL_SIZE,
+            BBWENET_AF3_LEFT_PADDING,
+            BBWENET_AF3_FILTER_GAIN_A,
+            BBWENET_AF3_FILTER_GAIN_B,
+            BBWENET_AF3_SHAPE_GAIN,
+            hBBWENET->window48,
+            arch);
+    }
+
+#ifdef DEBUG_BBWENET
+    fwrite(x_out, sizeof(float), num_subframes * BBWENET_AF3_FRAME_SIZE, f_af3_1);
+#endif
+}
+
+/* Return all BBWENet runtime state to its initial condition: zero the whole struct
+ * (feature-net, resampler and output-delay memories included), then let the AdaConv
+ * and AdaShape helpers initialize their own sub-states. */
+static void reset_bbwenet_state(BBWENetState *state)
+{
+    AdaConvState *adaconv[3];
+    int k;
+
+    OPUS_CLEAR(state, 1);
+
+    adaconv[0] = &state->af1_state;
+    adaconv[1] = &state->af2_state;
+    adaconv[2] = &state->af3_state;
+    for (k = 0; k < 3; k++)
+    {
+        init_adaconv_state(adaconv[k]);
+    }
+
+    init_adashape_state(&state->tdshape1_state);
+    init_adashape_state(&state->tdshape2_state);
+}
+
+/* Initialize a BBWENet model from a weight array: bind the layer weights and
+ * precompute the overlap windows of the three AdaConv stages.
+ * Returns the status of init_bbwenetlayers() (0 on success). */
+static int init_bbwenet(BBWENet *hBBWENET, const WeightArray *weights)
+{
+    int status;
+
+    OPUS_CLEAR(hBBWENET, 1);
+    celt_assert(weights != NULL);
+
+    status = init_bbwenetlayers(&hBBWENET->layers, weights);
+
+    /* the windows depend only on the overlap sizes, not on the weights */
+    compute_overlap_window(hBBWENET->window16, BBWENET_AF1_OVERLAP_SIZE);
+    compute_overlap_window(hBBWENET->window32, BBWENET_AF2_OVERLAP_SIZE);
+    compute_overlap_window(hBBWENET->window48, BBWENET_AF3_OVERLAP_SIZE);
+
+    return status;
+}
+
+#endif
+#endif /* ENABLE_OSCE_BWE */
+
/* API */
void osce_reset(silk_OSCE_struct *hOSCE, int method)
@@ -815,64 +1394,26 @@
hOSCE->features.reset = 2;
}
+#ifdef ENABLE_OSCE_BWE
-#if 0
-#include <stdio.h>
-static void print_float_array(FILE *fid, const char *name, const float *array, int n)
+/* Reset the blind-BWE state: clear the feature-extraction memory, re-seed
+ * last_spec, and reset all BBWENet filter, resampler and delay state. */
+void osce_bwe_reset(silk_OSCE_BWE_struct *hOSCEBWE)
{
- int i;
- for (i = 0; i < n; i++)
+ int k;
+ OPUS_CLEAR(&hOSCEBWE->features, 1);
+#if 1
+ /* Mirror the Python reference initialization (ToDo: fix eventually): seed the
+  * even entries 0, 2, ..., 2*OSCE_BWE_MAX_INSTAFREQ_BIN of last_spec with 1e-9
+  * (presumably the real parts of an interleaved complex spectrum -- confirm). */
+ for (k = 0; k <= OSCE_BWE_MAX_INSTAFREQ_BIN; k ++)
{
- fprintf(fid, "%s[%d]: %f\n", name, i, array[i]);
+ hOSCEBWE->features.last_spec[2*k] = 1e-9;
}
+#endif
+ reset_bbwenet_state(&hOSCEBWE->state.bbwenet);
}
-static void print_int_array(FILE *fid, const char *name, const int *array, int n)
-{
- int i;
- for (i = 0; i < n; i++)
- {
- fprintf(fid, "%s[%d]: %d\n", name, i, array[i]);
- }
-}
+#endif /* ENABLE_OSCE_BWE */
-static void print_int8_array(FILE *fid, const char *name, const opus_int8 *array, int n)
-{
- int i;
- for (i = 0; i < n; i++)
- {
- fprintf(fid, "%s[%d]: %d\n", name, i, array[i]);
- }
-}
-static void print_linear_layer(FILE *fid, const char *name, LinearLayer *layer)
-{
- int i, n_in, n_out, n_total;
- char tmp[256];
- n_in = layer->nb_inputs;
- n_out = layer->nb_outputs;
- n_total = n_in * n_out;
-
- fprintf(fid, "\nprinting layer %s...\n", name);
- fprintf(fid, "%s.nb_inputs: %d\n%s.nb_outputs: %d\n", name, n_in, name, n_out);
-
- if (layer->bias !=NULL){}
- if (layer->subias !=NULL){}
- if (layer->weights !=NULL){}
- if (layer->float_weights !=NULL){}
-
- if (layer->bias != NULL) {sprintf(tmp, "%s.bias", name); print_float_array(fid, tmp, layer->bias, n_out);}
- if (layer->subias != NULL) {sprintf(tmp, "%s.subias", name); print_float_array(fid, tmp, layer->subias, n_out);}
- if (layer->weights != NULL) {sprintf(tmp, "%s.weights", name); print_int8_array(fid, tmp, layer->weights, n_total);}
- if (layer->float_weights != NULL) {sprintf(tmp, "%s.float_weights", name); print_float_array(fid, tmp, layer->float_weights, n_total);}
- //if (layer->weights_idx != NULL) {sprintf(tmp, "%s.weights_idx", name); print_float_array(fid, tmp, layer->weights_idx, n_total);}
- if (layer->diag != NULL) {sprintf(tmp, "%s.diag", name); print_float_array(fid, tmp, layer->diag, n_in);}
- if (layer->scale != NULL) {sprintf(tmp, "%s.scale", name); print_float_array(fid, tmp, layer->scale, n_out);}
-
-}
-#endif
-
int osce_load_models(OSCEModel *model, const void *data, int len)
{
int ret = 0;
@@ -891,6 +1432,11 @@
if (ret == 0) {ret = init_nolace(&model->nolace, list);}
#endif
+#ifdef ENABLE_OSCE_BWE
+#ifndef DISABLE_BBWENET
+ if (ret == 0) {ret = init_bbwenet(&model->bbwenet, list);}
+#endif
+#endif /* ENABLE_OSCE_BWE */
free(list);
} else
{
@@ -905,6 +1451,11 @@
if (ret == 0) {ret = init_nolace(&model->nolace, nolacelayers_arrays);}
#endif
+#ifdef ENABLE_OSCE_BWE
+#ifndef DISABLE_BBWENET
+ if (ret == 0) {ret = init_bbwenet(&model->bbwenet, bbwenetlayers_arrays);}
+#endif
+#endif /* ENABLE_OSCE_BWE */
#endif /* USE_WEIGHTS_FILE */
}
@@ -911,6 +1462,75 @@
ret = ret ? -1 : 0;
return ret;
}
+
+#ifdef ENABLE_OSCE_BWE
+/* Blind bandwidth extension (BBWENet) of 16 kHz decoded speech to 48 kHz.
+ *
+ * Processes one 10-ms (160 samples) or 20-ms (320 samples) block and writes
+ * exactly 3 * xq16_len samples to xq48. The 48 kHz output is delayed by
+ * OSCE_BWE_OUTPUT_DELAY samples: the first OSCE_BWE_OUTPUT_DELAY samples of
+ * xq48 are the tail saved by the previous call, and the last
+ * OSCE_BWE_OUTPUT_DELAY samples produced by this call are saved in the state
+ * for the next one. Output samples are saturated to [-32767, 32767].
+ */
+void osce_bwe(
+    OSCEModel *model,                  /* I    OSCE model struct           */
+    silk_OSCE_BWE_struct *psOSCEBWE,   /* I/O  OSCE BWE state              */
+    opus_int16 xq48[],                 /* O    bandwidth-extended speech   */
+    opus_int16 xq16[],                 /* I    Decoded speech              */
+    opus_int32 xq16_len,               /* I    Length of xq16 in samples   */
+    int arch                           /* I    Run-time architecture       */
+)
+{
+    /* buffers are sized for the longest supported block (20 ms at 16 kHz) */
+    float in_buffer[320];
+    float out_buffer[3 * 320];
+    float features[2 * OSCE_BWE_FEATURE_DIM];
+    opus_int16 *delay_line = psOSCEBWE->state.bbwenet.outbut_buffer;
+    int num_frames, out_len, i;
+
+    /* currently restricting to 10 or 20-ms frames */
+    celt_assert(xq16_len == 160 || xq16_len == 320);
+
+    num_frames = xq16_len / 160;
+    out_len = 3 * xq16_len;
+
+    /* scale input to [-1, 1) */
+    for (i = 0; i < xq16_len; i++)
+    {
+        in_buffer[i] = ((float) xq16[i]) * (1.f/32768.f);
+    }
+
+    /* one feature vector of OSCE_BWE_FEATURE_DIM values per 10-ms frame */
+    osce_bwe_calculate_features(&psOSCEBWE->features, features, xq16, xq16_len);
+
+    /* process frames */
+    bbwenet_process_frames(
+        &model->bbwenet,
+        &psOSCEBWE->state.bbwenet,
+        out_buffer,
+        in_buffer,
+        features,
+        num_frames,
+        arch
+    );
+
+    /* emit the samples held back by the previous call first; this must happen
+       before the delay line is overwritten below */
+    OPUS_COPY(xq48, delay_line, OSCE_BWE_OUTPUT_DELAY);
+
+    /* saturate and convert to 16 bit; the last OSCE_BWE_OUTPUT_DELAY samples
+       go to the delay line instead of the output */
+    for (i = 0; i < out_len; i++)
+    {
+        float tmp = 32768.f * out_buffer[i];
+        opus_int16 sample;
+        if (tmp > 32767.f) tmp = 32767.f;
+        if (tmp < -32767.f) tmp = -32767.f;
+        sample = (opus_int16) float2int(tmp);
+        if (i < out_len - OSCE_BWE_OUTPUT_DELAY)
+        {
+            xq48[i + OSCE_BWE_OUTPUT_DELAY] = sample;
+        }
+        else
+        {
+            delay_line[i - (out_len - OSCE_BWE_OUTPUT_DELAY)] = sample;
+        }
+    }
+}
+
+#endif /* ENABLE_OSCE_BWE */
void osce_enhance_frame(
OSCEModel *model, /* I OSCE model struct */
--- a/dnn/osce.h
+++ b/dnn/osce.h
@@ -42,6 +42,12 @@
#include "osce_structs.h"
#include "structs.h"
+
+/* Extended decoding modes, stored in silk_DecControlStruct::osce_extended_mode
+ * (and prev_osce_extended_mode) so that silk_Decode can cross-fade when the
+ * decoder switches into or out of blind bandwidth extension. */
+#define OSCE_MODE_SILK_ONLY 1000
+#define OSCE_MODE_HYBRID 1001
+#define OSCE_MODE_CELT_ONLY 1002
+#define OSCE_MODE_SILK_BBWE 1003 /* SILK wideband output extended to 48 kHz by osce_bwe */
+
#define OSCE_METHOD_NONE 0
#ifndef DISABLE_LACE
#define OSCE_METHOD_LACE 1
@@ -61,7 +67,10 @@
#define OSCE_MAX_RNN_NEURONS 0
#endif
+#ifdef ENABLE_OSCE_BWE
+/* largest recurrent state of the BWE model: the BBWENet feature-network GRU */
+#define OSCE_BWE_MAX_RNN_NEURONS BBWENET_FNET_GRU_STATE_SIZE
+#endif /* ENABLE_OSCE_BWE */
/* API */
@@ -79,6 +88,19 @@
int osce_load_models(OSCEModel *hModel, const void *data, int len);
void osce_reset(silk_OSCE_struct *hOSCE, int method);
+
+#ifdef ENABLE_OSCE_BWE
+/* Blind bandwidth extension of a 10- or 20-ms block of 16 kHz decoded speech to
+ * 48 kHz. Writes 3 * xq16_len samples to xq48; the output is delayed by
+ * OSCE_BWE_OUTPUT_DELAY samples, which are carried over in psOSCEBWE. */
+void osce_bwe(
+ OSCEModel *model, /* I OSCE model struct */
+ silk_OSCE_BWE_struct *psOSCEBWE, /* I/O OSCE BWE state */
+ opus_int16 xq48[], /* O bandwidth-extended speech */
+ opus_int16 xq16[], /* I Decoded speech */
+ opus_int32 xq16_len, /* I Length of xq16 in samples */
+ int arch /* I Run-time architecture */
+);
+
+/* Resets the BWE state; silk_Decode calls this when switching into
+ * OSCE_MODE_SILK_BBWE from any other mode. */
+void osce_bwe_reset(silk_OSCE_BWE_struct *hOSCEBWE);
+#endif /* ENABLE_OSCE_BWE */
#endif
--- a/dnn/osce_config.h
+++ b/dnn/osce_config.h
@@ -56,5 +56,11 @@
#define OSCE_LOG_GAIN_START 92
#define OSCE_LOG_GAIN_LENGTH 1
+#define OSCE_BWE_MAX_INSTAFREQ_BIN 40 /* highest DFT bin (inclusive) used for instantaneous-frequency features */
+#define OSCE_BWE_HALF_WINDOW_SIZE 160 /* feature frame hop in input samples (10 ms at 16 kHz) */
+#define OSCE_BWE_WINDOW_SIZE (2 * (OSCE_BWE_HALF_WINDOW_SIZE)) /* analysis window: previous + current frame */
+#define OSCE_BWE_NUM_BANDS 32 /* number of filterbank bands in the log-magnitude features */
+#define OSCE_BWE_FEATURE_DIM 114 /* OSCE_BWE_NUM_BANDS + 2 * (OSCE_BWE_MAX_INSTAFREQ_BIN + 1) */
+#define OSCE_BWE_OUTPUT_DELAY 21 /* output delay of osce_bwe in 48 kHz samples (origin of the value: TODO confirm) */
#endif
--- a/dnn/osce_features.c
+++ b/dnn/osce_features.c
@@ -68,6 +68,13 @@
136, 160
};
+static const int center_bins_bwe[32] = {
+ 0, 5, 10, 15, 20, 25, 30, 35,
+ 40, 45, 50, 55, 60, 65, 70, 75,
+ 80, 85, 90, 95, 100, 105, 110, 115,
+ 120, 125, 130, 135, 140, 145, 150, 160
+};
+
static const float band_weights_clean[64] = {
0.666666666667f, 0.400000000000f, 0.333333333333f, 0.400000000000f,
0.500000000000f, 0.400000000000f, 0.333333333333f, 0.400000000000f,
@@ -95,6 +102,17 @@
0.041666666667f, 0.080000000000f
};
+static const float band_weights_bwe[32] = {
+ 0.333333333, 0.200000000, 0.200000000, 0.200000000,
+ 0.200000000, 0.200000000, 0.200000000, 0.200000000,
+ 0.200000000, 0.200000000, 0.200000000, 0.200000000,
+ 0.200000000, 0.200000000, 0.200000000, 0.200000000,
+ 0.200000000, 0.200000000, 0.200000000, 0.200000000,
+ 0.200000000, 0.200000000, 0.200000000, 0.200000000,
+ 0.200000000, 0.200000000, 0.200000000, 0.200000000,
+ 0.200000000, 0.200000000, 0.133333333, 0.181818182
+};
+
static float osce_window[OSCE_SPEC_WINDOW_SIZE] = {
0.004908718808f, 0.014725683311f, 0.024541228523f, 0.034354408400f, 0.044164277127f,
0.053969889210f, 0.063770299562f, 0.073564563600f, 0.083351737332f, 0.093130877450f,
@@ -440,6 +458,90 @@
}
+#ifdef ENABLE_OSCE_BWE
+/* Computes BBWENet input features from 16 kHz decoded speech.
+ *
+ * One feature vector of OSCE_BWE_FEATURE_DIM floats is written per
+ * OSCE_BWE_HALF_WINDOW_SIZE input samples (num_samples must be a multiple of
+ * it). Each vector is computed from a window of the previous and the current
+ * frame and contains:
+ *   - OSCE_BWE_NUM_BANDS log filterbank magnitudes,
+ *   - OSCE_BWE_MAX_INSTAFREQ_BIN + 1 real parts, then as many imaginary parts,
+ *     of the normalized cross-spectrum X_t[k] * conj(X_{t-1}[k]) for bins
+ *     k = 0..OSCE_BWE_MAX_INSTAFREQ_BIN.
+ * psFeatures carries the previous frame's samples and DFT bins across calls.
+ */
+void osce_bwe_calculate_features(
+    OSCEBWEFeatureState *psFeatures,   /* I/O  BWE feature state          */
+    float *features,                   /* O    input features             */
+    const opus_int16 xq[],             /* I    Decoded speech             */
+    int num_samples                    /* I    number of input samples    */
+)
+{
+    int n, k, num_frames, frame;
+    kiss_fft_cpx fft_buffer[OSCE_BWE_WINDOW_SIZE];
+    float spec[2 * OSCE_BWE_MAX_INSTAFREQ_BIN + 2];
+    float buffer[OSCE_BWE_WINDOW_SIZE];
+    float mag_spec[OSCE_SPEC_NUM_FREQS];
+    float *lmspec, *instafreq;
+    const opus_int16 *x;
+
+    /* OSCE_BWE_WINDOW_SIZE == 320 is a hard requirement */
+    celt_assert((num_samples % OSCE_BWE_HALF_WINDOW_SIZE == 0) && (OSCE_BWE_WINDOW_SIZE == 320));
+
+    num_frames = num_samples / OSCE_BWE_HALF_WINDOW_SIZE;
+
+    for (frame = 0; frame < num_frames; frame++)
+    {
+        lmspec = features + frame * OSCE_BWE_FEATURE_DIM;
+        instafreq = lmspec + OSCE_BWE_NUM_BANDS;
+        x = xq + frame * OSCE_BWE_HALF_WINDOW_SIZE;
+
+        /* clear features */
+        OPUS_CLEAR(lmspec, OSCE_BWE_FEATURE_DIM);
+
+        /* analysis buffer: previous frame followed by current frame, scaled to [-1, 1) */
+        OPUS_COPY(buffer, psFeatures->signal_history, OSCE_BWE_HALF_WINDOW_SIZE);
+        for (n = 0; n < OSCE_BWE_HALF_WINDOW_SIZE; n++)
+        {
+            buffer[n + OSCE_BWE_HALF_WINDOW_SIZE] = (float) x[n] / (1U<<15);
+        }
+
+        /* update signal history buffer */
+        OPUS_COPY(psFeatures->signal_history, buffer + OSCE_BWE_HALF_WINDOW_SIZE, OSCE_BWE_HALF_WINDOW_SIZE);
+
+        /* apply window */
+        for (n = 0; n < OSCE_BWE_WINDOW_SIZE; n++)
+        {
+            buffer[n] *= osce_window[n];
+        }
+
+        /* DFT */
+        forward_transform(fft_buffer, buffer);
+
+        /* instafreq: normalized X_t[k] * conj(X_{t-1}[k]); real parts first, then imaginary parts */
+        for (k = 0; k <= OSCE_BWE_MAX_INSTAFREQ_BIN; k++)
+        {
+            float aux_r, aux_i, aux_abs;
+            float re1, re2, im1, im2;
+            spec[2*k] = OSCE_BWE_WINDOW_SIZE * fft_buffer[k].r + 1e-9; /* ToDo: remove 1e-9 from python code*/
+            spec[2*k+1] = OSCE_BWE_WINDOW_SIZE * fft_buffer[k].i;
+            re1 = spec[2*k];
+            im1 = spec[2*k+1];
+            re2 = psFeatures->last_spec[2*k];
+            im2 = psFeatures->last_spec[2*k+1];
+            aux_r = re1 * re2 + im1 * im2;
+            aux_i = im1 * re2 - re1 * im2;
+            aux_abs = sqrt(aux_r * aux_r + aux_i * aux_i);
+            instafreq[k] = aux_r / (aux_abs + 1e-9);
+            instafreq[k + OSCE_BWE_MAX_INSTAFREQ_BIN + 1] = aux_i / (aux_abs + 1e-9);
+        }
+
+        /* erb-scale magnitude spectrogram */
+        for (k = 0; k < OSCE_SPEC_NUM_FREQS; k++)
+        {
+            mag_spec[k] = OSCE_BWE_WINDOW_SIZE * sqrt(fft_buffer[k].r * fft_buffer[k].r + fft_buffer[k].i * fft_buffer[k].i);
+        }
+
+        apply_filterbank(lmspec, mag_spec, center_bins_bwe, band_weights_bwe, OSCE_BWE_NUM_BANDS);
+
+        for (k = 0; k < OSCE_BWE_NUM_BANDS; k++)
+        {
+            lmspec[k] = log(lmspec[k] + 1e-9);
+        }
+
+        /* keep this frame's bins for the next frame's instafreq computation */
+        OPUS_COPY(psFeatures->last_spec, spec, 2 * OSCE_BWE_MAX_INSTAFREQ_BIN + 2);
+    }
+}
+#endif /* ENABLE_OSCE_BWE */
+
void osce_cross_fade_10ms(float *x_enhanced, float *x_in, int length)
{
int i;
@@ -449,6 +551,22 @@
{
x_enhanced[i] = osce_window[i] * x_enhanced[i] + (1.f - osce_window[i]) * x_in[i];
}
+}
-}
+/* Cross-fades two 48 kHz signals over their first 10 ms (480 samples), writing
+ * the result into x_fadein: x_fadein fades in and x_fadeout fades out.
+ * The 48 kHz weights are obtained by linearly interpolating the rising half of
+ * the 16 kHz osce_window at 1/3-sample steps. length must be at least 480;
+ * only the first 480 samples are modified. Results are rounded to nearest,
+ * halves away from zero.
+ */
+void osce_bwe_cross_fade_10ms(opus_int16 *x_fadein, opus_int16 *x_fadeout, int length)
+{
+    int i, j;
+    const float f = 1.f / 3;
+
+    celt_assert(length >= 480);
+
+    for (i = 0; i < 160; i++)
+    {
+        float diff = i == 159 ? 0.f : osce_window[i + 1] - osce_window[i];
+        float w_curr = osce_window[i];
+        for (j = 0; j < 3; j++)
+        {
+            float mixed = w_curr * x_fadein[3*i + j] + (1.f - w_curr) * x_fadeout[3*i + j];
+            /* (int)(v + 0.5) would round negative values toward +inf (-3.2 -> -2) */
+            x_fadein[3*i + j] = (opus_int16) (mixed >= 0.f ? (int) (mixed + .5f) : (int) (mixed - .5f));
+            w_curr += diff * f;
+        }
+    }
+}
\ No newline at end of file
--- a/dnn/osce_features.h
+++ b/dnn/osce_features.h
@@ -44,7 +44,15 @@
opus_int32 num_bits /* I Size of SILK payload in bits */
);
+#ifdef ENABLE_OSCE_BWE
+/* Computes BBWENet input features (OSCE_BWE_FEATURE_DIM values per
+ * OSCE_BWE_HALF_WINDOW_SIZE input samples) and updates the feature state.
+ * num_samples must be a multiple of OSCE_BWE_HALF_WINDOW_SIZE. */
+void osce_bwe_calculate_features(
+ OSCEBWEFeatureState *psFeatures, /* I/O BWE feature state */
+ float *features, /* O input features */
+ const opus_int16 xq[], /* I Decoded speech */
+ int num_samples /* I number of input samples */
+);
+#endif /* ENABLE_OSCE_BWE */
void osce_cross_fade_10ms(float *x_enhanced, float *x_in, int length);
-
+void osce_bwe_cross_fade_10ms(opus_int16 *x_fadein, opus_int16* x_fadeout, int length);
#endif
--- a/dnn/osce_structs.h
+++ b/dnn/osce_structs.h
@@ -36,6 +36,10 @@
#ifndef DISABLE_NOLACE
#include "nolace_data.h"
#endif
+#ifndef DISABLE_BBWENET
+#include "bbwenet_data.h"
+#include "resampler_structs.h"
+#endif
#include "nndsp.h"
#include "nnet.h"
@@ -50,7 +54,40 @@
int reset;
} OSCEFeatureState;
+/* History carried across calls by osce_bwe_calculate_features */
+typedef struct {
+ float signal_history[OSCE_BWE_HALF_WINDOW_SIZE]; /* previous frame's input samples, scaled by 1/32768 */
+ float last_spec[2 * OSCE_BWE_MAX_INSTAFREQ_BIN + 2]; /* previous frame's interleaved re/im DFT bins 0..OSCE_BWE_MAX_INSTAFREQ_BIN */
+} OSCEBWEFeatureState;
+#ifndef DISABLE_BBWENET
+/* BBWENet */
+/* Filter memories of one upsampling chain (passed to upsamp_2x / interpol_3_2;
+ * exact buffer layout defined by those functions — TODO confirm). */
+typedef struct {
+ float upsamp_buffer[2][3];
+ float interpol_buffer[8];
+} resamp_state;
+
+/* Per-channel BBWENet processing state */
+typedef struct {
+ float feature_net_conv1_state[BBWENET_FNET_CONV1_STATE_SIZE];
+ float feature_net_conv2_state[BBWENET_FNET_CONV2_STATE_SIZE];
+ float feature_net_gru_state[BBWENET_FNET_GRU_STATE_SIZE];
+ opus_int16 outbut_buffer[OSCE_BWE_OUTPUT_DELAY]; /* (sic) last OSCE_BWE_OUTPUT_DELAY 48 kHz samples held back by the previous osce_bwe call */
+ AdaConvState af1_state;
+ AdaConvState af2_state;
+ AdaConvState af3_state;
+ AdaShapeState tdshape1_state;
+ AdaShapeState tdshape2_state;
+ resamp_state resampler_state[3]; /* presumably one per af1 output channel — TODO confirm */
+} BBWENetState;
+
+/* BBWENet model: weights plus overlap windows for af1/af2/af3, which run at
+ * 16/32/48 kHz respectively in the reference model (bbwe_net.py) */
+typedef struct {
+ BBWENETLayers layers;
+ float window16[BBWENET_AF1_OVERLAP_SIZE];
+ float window32[BBWENET_AF2_OVERLAP_SIZE];
+ float window48[BBWENET_AF3_OVERLAP_SIZE];
+} BBWENet;
+#endif /* DISABLE_BBWENET */
+
+
#ifndef DISABLE_LACE
/* LACE */
typedef struct {
@@ -111,6 +148,9 @@
#ifndef DISABLE_NOLACE
NoLACE nolace;
#endif
+#ifndef DISABLE_BBWENET
+ BBWENet bbwenet;
+#endif
} OSCEModel;
typedef union {
@@ -121,5 +161,11 @@
NoLACEState nolace;
#endif
} OSCEState;
+
+/* Model state of the OSCE blind bandwidth extension */
+typedef struct {
+#ifndef DISABLE_BBWENET
+    BBWENetState bbwenet;
+#else
+    int dummy; /* ISO C does not allow a struct without members */
+#endif
+} OSCEBWEState;
#endif
--- a/dnn/torch/osce/export_model_weights.py
+++ b/dnn/torch/osce/export_model_weights.py
@@ -50,7 +50,7 @@
parser = argparse.ArgumentParser()
-parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
+parser.add_argument('checkpoint', type=str, help='LACE, NoLACE or BBWENet model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')
@@ -86,6 +86,17 @@
('cf1', dict(quantize=True, scale=None)),
('cf2', dict(quantize=True, scale=None)),
('af1', dict(quantize=True, scale=None))
+ ],
+ 'bbwenet' : [
+ ('feature_net.conv1', dict(quantize=False, scale=None)),
+ ('feature_net.conv2', dict(quantize=True, scale=None)),
+ ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
+ ('feature_net.tconv', dict(quantize=True, scale=None)),
+ ('tdshape1', dict(quantize=True, scale=None)),
+ ('tdshape2', dict(quantize=True, scale=None)),
+ ('af1', dict(quantize=True, scale=None)),
+ ('af2', dict(quantize=True, scale=None)),
+ ('af3', dict(quantize=True, scale=None)),
]
}
@@ -148,7 +159,8 @@
cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper() + 'Layers', add_typedef=True)
# Add custom includes and global parameters
- cwriter.header.write(f'''
+ if model_name in {'lace', 'nolace'}:
+ cwriter.header.write(f'''
#define {model_name.upper()}_PREEMPH {model.preemph}f
#define {model_name.upper()}_FRAME_SIZE {model.FRAME_SIZE}
#define {model_name.upper()}_OVERLAP_SIZE 40
@@ -161,9 +173,20 @@
#define {model_name.upper()}_COND_DIM {model.cond_dim}
#define {model_name.upper()}_HIDDEN_FEATURE_DIM {model.hidden_feature_dim}
''')
+ for i, s in enumerate(model.numbits_embedding.scale_factors):
+ cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}f\n")
- for i, s in enumerate(model.numbits_embedding.scale_factors):
- cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}f\n")
+ elif model_name in {'bbwenet'}:
+ # restricting to bbwenet with both activations for now
+ assert model.shape_extension and model.func_extension
+ assert model.activation == "ImPowI" and model.shaper == "TDShaper"
+ cwriter.header.write(f'''
+#define {model_name.upper()}_FEATURE_DIM {model.feature_dim}
+#define {model_name.upper()}_FRAME_SIZE16 {model.frame_size16}
+#define {model_name.upper()}_COND_DIM {model.cond_dim}
+''')
+
+
# dump layers
if model_name in schedules and args.quantize:
--- a/dnn/torch/osce/models/bbwe_net.py
+++ b/dnn/torch/osce/models/bbwe_net.py
@@ -22,8 +22,16 @@
s = 0.5 * s / s.max()
wavfile.write(filename, fs, (2**15 * s).astype(np.int16))
+DEBUGDUMP=False
+if DEBUGDUMP:
+ import os
+ debugdumpdir='debugdump'
+ os.makedirs(debugdumpdir, exist_ok=True)
+ def debugdump(filename, data):
+ data.detach().numpy().tofile(os.path.join(debugdumpdir, filename))
+
class FloatFeatureNet(nn.Module):
def __init__(self,
@@ -62,11 +70,14 @@
return count
- def forward(self, features, state=None):
+ def forward(self, features, state=None, debug=False):
""" features shape: (batch_size, num_frames, feature_dim) """
batch_size = features.size(0)
+ if DEBUGDUMP:
+ debugdump('features.f32', features.float())
+
if state is None:
state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
@@ -77,14 +88,23 @@
c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
else:
c = torch.tanh(self.conv1(F.pad(features, [2, 0])))
+ if DEBUGDUMP:
+ debugdump('feature_net_conv1_activated.f32', c.permute(0, 2, 1))
c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
+ if DEBUGDUMP:
+ debugdump('feature_net_conv2_activated.f32', c.permute(0, 2, 1))
c = torch.tanh(self.tconv(c))
+ if DEBUGDUMP:
+ debugdump('feature_net_tconv_activated.f32', c.permute(0, 2, 1))
c = c.permute(0, 2, 1)
c, _ = self.gru(c, state)
+ if DEBUGDUMP:
+ debugdump('feature_net_gru.f32', c)
+
return c
@@ -218,28 +238,51 @@
# split into latent_channels channels
y16 = self.af1(x, cf, debug=debug)
+ if DEBUGDUMP:
+ debugdump('bbwenet_af1_1.f32', y16[:, 0:1, :]) # first channel is bypass channel
+ debugdump('bbwenet_af1_2.f32', y16[:, 1:2, :])
+ debugdump('bbwenet_af1_3.f32', y16[:, 2:3, :])
+
# first 2x upsampling step
y32 = self.upsampler.hq_2x_up(y16)
y32_out = y32[:, 0:1, :] # first channel is bypass channel
+ if DEBUGDUMP:
+ debugdump('bbwenet_up2_1.f32', y32_out)
+ debugdump('bbwenet_up2_2.f32', y32[:, 1:2, :])
+ debugdump('bbwenet_up2_3.f32', y32[:, 2:3, :])
+
# extend frequencies
idx = 1
if self.shape_extension:
y32_shape = self.tdshape1(y32[:, idx:idx+1, :], cf)
y32_out = torch.cat((y32_out, y32_shape), dim=1)
+ if DEBUGDUMP:
+ debugdump('bbwenet_up2_shape.f32', y32_shape)
idx += 1
if self.func_extension:
y32_func = self.nlfunc(y32[:, idx:idx+1, :])
y32_out = torch.cat((y32_out, y32_func), dim=1)
+ if DEBUGDUMP:
+ debugdump('bbwenet_up2_func.f32', y32_func)
# mix-select
y32_out = self.af2(y32_out, cf)
+ if DEBUGDUMP:
+ debugdump('bbwenet_af2_1.f32', y32_out[:, 0:1, :])
+ debugdump('bbwenet_af2_2.f32', y32_out[:, 1:2, :])
+ debugdump('bbwenet_af2_3.f32', y32_out[:, 2:3, :])
# 1.5x upsampling
y48 = self.upsampler.interpolate_3_2(y32_out)
y48_out = y48[:, 0:1, :] # first channel is bypass channel
+ if DEBUGDUMP:
+ debugdump('bbwenet_up15_1.f32', y48_out)
+ debugdump('bbwenet_up15_2.f32', y48[:, 1:2, :])
+ debugdump('bbwenet_up15_3.f32', y48[:, 2:3, :])
+
# extend frequencies
idx = 1
if self.shape_extension:
@@ -246,12 +289,19 @@
y48_shape = self.tdshape2(y48[:, idx:idx+1, :], cf)
y48_out = torch.cat((y48_out, y48_shape), dim=1)
idx += 1
+ if DEBUGDUMP:
+ debugdump('bbwenet_up15_shape.f32', y48_shape)
if self.func_extension:
y48_func = self.nlfunc(y48[:, idx:idx+1, :])
y48_out = torch.cat((y48_out, y48_func), dim=1)
+ if DEBUGDUMP:
+ debugdump('bbwenet_up15_func.f32', y48_func)
# 2nd mixing
y48_out = self.af3(y48_out, cf)
+
+ if DEBUGDUMP:
+ debugdump('bbwenet_af3_1.f32', y48_out)
return y48_out
--- /dev/null
+++ b/dnn/torch/osce/scripts/bwe_extract_filterbank.py
@@ -1,0 +1,29 @@
+import argparse
+import sys
+sys.path.append('./')
+
+import torch
+from utils.spec import create_filter_bank
+import numpy as np
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('checkpoint', type=str)
+
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+
+ c = torch.load(args.checkpoint, map_location='cpu')
+
+ num_bands = c['setup']['data']['spec_num_bands']
+ fb, center_bins = create_filter_bank(num_bands, n_fft=320, fs=16000, scale='erb', round_center_bins=True, normalize=False, return_center_bins=True)
+ weights = 1/fb.sum(axis=-1)
+
+ print(f"center_bins:")
+
+ print("".join([f"{int(cb):4d}," for cb in center_bins]))
+
+ print(f"band_weights:")
+ print("".join([f" {w:1.9f}," for w in weights]))
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/scripts/print_features.py
@@ -1,0 +1,22 @@
+import argparse
+import sys
+sys.path.append('./')
+
+import torch
+from utils.bwe_features import load_inference_data
+import numpy as np
+
+parser = argparse.ArgumentParser()
+parser.add_argument('testsignal', type=str)
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+
+ _, features = load_inference_data(args.testsignal)
+
+ N = features.shape[0]
+
+ for n in range(N):
+ print(f"frame[{n}]")
+ print(f"lmspec: {features[n, :32]}")
+ print(f"freqs: {features[n,32:]}")
--- a/dnn/torch/osce/utils/spec.py
+++ b/dnn/torch/osce/utils/spec.py
@@ -65,7 +65,7 @@
RE = RE/norm[:, np.newaxis]
return torch.from_numpy(RE)
-def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
+def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False, return_center_bins=False):
f0 = 0
num_bins = n_fft // 2 + 1
@@ -111,6 +111,9 @@
if normalize:
filter_bank = filter_bank / np.sum(filter_bank, axis=1).reshape(-1, 1)
+ if return_center_bins:
+ return filter_bank, center_bins
+
return filter_bank
@@ -232,7 +235,7 @@
X = np.fft.fft(x_unfold, n=frame_size, axis=-1)
# instantaneus frequency
- X_trunc = X[..., :max_bin + 1] + 1e-9
+ X_trunc = X[..., :max_bin + 1] + 1e-9
Y = X_trunc[1:] * np.conj(X_trunc[:-1])
Y = Y / (np.abs(Y) + 1e-9)
--- a/dnn/torch/weight-exchange/wexchange/torch/torch.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -156,6 +156,9 @@
def dump_torch_tdshaper(where, shaper, name='tdshaper', quantize=False, scale=1/128):
if isinstance(where, CWriter):
+ interpolate_k = 1
+ if hasattr(shaper, 'interpolate_k'):
+ interpolate_k = shaper.interpolate_k
where.header.write(f"""
#define {name.upper()}_FEATURE_DIM {shaper.feature_dim}
#define {name.upper()}_FRAME_SIZE {shaper.frame_size}
@@ -162,6 +165,7 @@
#define {name.upper()}_AVG_POOL_K {shaper.avg_pool_k}
#define {name.upper()}_INNOVATE {1 if shaper.innovate else 0}
#define {name.upper()}_POOL_AFTER {1 if shaper.pool_after else 0}
+#define {name.upper()}_INTERPOLATE_K {interpolate_k}
"""
)
--- a/include/opus_defines.h
+++ b/include/opus_defines.h
@@ -173,6 +173,8 @@
#define OPUS_GET_DRED_DURATION_REQUEST 4051
#define OPUS_SET_DNN_BLOB_REQUEST 4052
/*#define OPUS_GET_DNN_BLOB_REQUEST 4053 */
+#define OPUS_SET_OSCE_BWE_REQUEST 4054
+#define OPUS_GET_OSCE_BWE_REQUEST 4055
/** Defines for the presence of extended APIs. */
#define OPUS_HAVE_OPUS_PROJECTION_H
@@ -798,6 +800,19 @@
*
* @hideinitializer */
#define OPUS_GET_PITCH(x) OPUS_GET_PITCH_REQUEST, __opus_check_int_ptr(x)
+
+/** Enables blind bandwidth extension for wideband signals if decoding sampling rate is 48 kHz.
+ * Extension is applied to SILK-only (and concealed) frames with a 16 kHz SILK
+ * internal sampling rate when the decoder complexity is at least 4.
+ * The request is only handled in builds with ENABLE_OSCE_BWE defined.
+ * @param[in] x <tt>opus_int32</tt>: 1 enables bandwidth extension, 0 disables it;
+ *              other values are rejected as a bad argument.
+ * The default is 0.
+ *
+ * @hideinitializer */
+ #define OPUS_SET_OSCE_BWE(x) OPUS_SET_OSCE_BWE_REQUEST, __opus_check_int(x)
+/** Gets blind bandwidth extension flag for wideband signals if decoding sampling rate is 48 kHz.
+ * The request is only handled in builds with ENABLE_OSCE_BWE defined.
+ * @param[out] x <tt>opus_int32 *</tt>: 1 if bwe enabled, 0 if disabled.
+ *
+ * @hideinitializer */
+ #define OPUS_GET_OSCE_BWE(x) OPUS_GET_OSCE_BWE_REQUEST, __opus_check_int_ptr(x)
+
/**@}*/
--- a/lpcnet_headers.mk
+++ b/lpcnet_headers.mk
@@ -37,7 +37,8 @@
dnn/osce_features.h \
dnn/nndsp.h \
dnn/lace_data.h \
-dnn/nolace_data.h
+dnn/nolace_data.h \
+dnn/bbwenet_data.h
LOSSGEN_HEAD = \
dnn/lossgen.h \
--- a/lpcnet_sources.mk
+++ b/lpcnet_sources.mk
@@ -28,7 +28,8 @@
dnn/osce_features.c \
dnn/nndsp.c \
dnn/lace_data.c \
-dnn/nolace_data.c
+dnn/nolace_data.c \
+dnn/bbwenet_data.c
LOSSGEN_SOURCES = \
dnn/lossgen.c \
--- a/silk/control.h
+++ b/silk/control.h
@@ -151,6 +151,17 @@
#ifdef ENABLE_OSCE
/* I: OSCE method */
opus_int osce_method;
+
+#ifdef ENABLE_OSCE_BWE
+ /* I: OSCE bandwidth extension method */
+ opus_int enable_osce_bwe;
+
+ /* I: extended mode */
+ opus_int osce_extended_mode;
+
+ /* O: previous extended mode */
+ opus_int prev_osce_extended_mode;
+#endif
#endif
} silk_DecControlStruct;
--- a/silk/dec_API.c
+++ b/silk/dec_API.c
@@ -36,7 +36,10 @@
#ifdef ENABLE_OSCE
#include "osce.h"
#include "osce_structs.h"
+#ifdef ENABLE_OSCE_BWE
+#include "osce_features.h"
#endif
+#endif
/************************/
/* Decoder Super Struct */
@@ -154,6 +157,9 @@
silk_decoder_state *channel_state = psDec->channel_state;
opus_int has_side;
opus_int stereo_to_mono;
+#ifdef ENABLE_OSCE_BWE
+ ALLOC(resamp_buffer, 3 * MAX_FRAME_LENGTH, opus_int16);
+#endif
SAVE_STACK;
celt_assert( decControl->nChannelsInternal == 1 || decControl->nChannelsInternal == 2 );
@@ -378,9 +384,38 @@
for( n = 0; n < silk_min( decControl->nChannelsAPI, decControl->nChannelsInternal ); n++ ) {
+#ifdef ENABLE_OSCE_BWE
+ /* Resample or extend decoded signal to API_sampleRate */
+ if (decControl->osce_extended_mode == OSCE_MODE_SILK_BBWE) {
+ silk_assert(decControl->API_sampleRate == 48000);
+
+ if (decControl->prev_osce_extended_mode != OSCE_MODE_SILK_BBWE) {
+ /* Reset the BWE state */
+ osce_bwe_reset( &channel_state[ n ].osce_bwe );
+ }
+
+ osce_bwe(&psDec->osce_model, &channel_state[ n ].osce_bwe,
+ resample_out_ptr, &samplesOut1_tmp[ n ][ 1 ], nSamplesOutDec, arch);
+
+ if (decControl->prev_osce_extended_mode == OSCE_MODE_SILK_ONLY ||
+ decControl->prev_osce_extended_mode == OSCE_MODE_HYBRID) {
+ /* cross-fade with upsampled signal */
+ silk_resampler( &channel_state[ n ].resampler_state, resamp_buffer, &samplesOut1_tmp[ n ][ 1 ], nSamplesOutDec );
+ osce_bwe_cross_fade_10ms(resample_out_ptr, resamp_buffer, 480);
+ }
+ } else {
+ ret += silk_resampler( &channel_state[ n ].resampler_state, resample_out_ptr, &samplesOut1_tmp[ n ][ 1 ], nSamplesOutDec );
+ if (decControl->prev_osce_extended_mode == OSCE_MODE_SILK_BBWE) {
+ osce_bwe(&psDec->osce_model, &channel_state[ n ].osce_bwe,
+ resamp_buffer, &samplesOut1_tmp[ n ][ 1 ], nSamplesOutDec, arch);
+ /* cross-fade with upsampled signal */
+ osce_bwe_cross_fade_10ms(resample_out_ptr, resamp_buffer, 480);
+ }
+ }
+#else
/* Resample decoded signal to API_sampleRate */
ret += silk_resampler( &channel_state[ n ].resampler_state, resample_out_ptr, &samplesOut1_tmp[ n ][ 1 ], nSamplesOutDec );
-
+#endif
/* Interleave if stereo output and stereo stream */
if( decControl->nChannelsAPI == 2 ) {
for( i = 0; i < *nSamplesOut; i++ ) {
@@ -392,6 +427,10 @@
}
}
}
+
+#ifdef ENABLE_OSCE_BWE
+ decControl->prev_osce_extended_mode = decControl->osce_extended_mode;
+#endif
/* Create two channel output from mono stream */
if( decControl->nChannelsAPI == 2 && decControl->nChannelsInternal == 1 ) {
--- a/silk/structs.h
+++ b/silk/structs.h
@@ -249,6 +249,11 @@
OSCEState state;
int method;
} silk_OSCE_struct;
+
+/* Per-channel state of the OSCE blind bandwidth extension (channel_state[n].osce_bwe) */
+typedef struct {
+ OSCEBWEFeatureState features; /* history used by osce_bwe_calculate_features */
+ OSCEBWEState state; /* BBWENet model state */
+} silk_OSCE_BWE_struct;
#endif
/* Struct for Packet Loss Concealment */
@@ -285,6 +290,9 @@
typedef struct {
#ifdef ENABLE_OSCE
silk_OSCE_struct osce;
+#ifdef ENABLE_OSCE_BWE
+ silk_OSCE_BWE_struct osce_bwe;
+#endif
#endif
#define SILK_DECODER_STATE_RESET_START prev_gain_Q16
opus_int32 prev_gain_Q16;
--- a/src/opus_decoder.c
+++ b/src/opus_decoder.c
@@ -435,7 +435,22 @@
#ifndef DISABLE_NOLACE
if (st->complexity >= 7) {st->DecControl.osce_method = OSCE_METHOD_NOLACE;}
#endif
+#ifdef ENABLE_OSCE_BWE
+ if (st->complexity >= 4 && st->DecControl.enable_osce_bwe &&
+ st->Fs == 48000 && st->DecControl.internalSampleRate == 16000 &&
+ ((mode == MODE_SILK_ONLY) || (data == NULL))) {
+ /* request WB -> FB signal extension */
+ st->DecControl.osce_extended_mode = OSCE_MODE_SILK_BBWE;
+ } else {
+ /* at this point, mode can only be MODE_SILK_ONLY or MODE_HYBRID */
+ st->DecControl.osce_extended_mode = mode == MODE_SILK_ONLY ? OSCE_MODE_SILK_ONLY : OSCE_MODE_HYBRID;
+ }
+ if (st->prev_mode == MODE_CELT_ONLY) {
+ /* Update extended mode for CELT->SILK transition */
+ st->DecControl.prev_osce_extended_mode = OSCE_MODE_CELT_ONLY;
+ }
#endif
+#endif
lost_flag = data == NULL ? 1 : 2 * !!decode_fec;
decoded_samples = 0;
@@ -563,7 +578,11 @@
/* MUST be after PLC */
MUST_SUCCEED(celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(start_band)));
+#ifdef ENABLE_OSCE_BWE
+ if (mode != MODE_SILK_ONLY && st->DecControl.osce_extended_mode != OSCE_MODE_SILK_BBWE)
+#else
if (mode != MODE_SILK_ONLY)
+#endif
{
int celt_frame_size = IMIN(F20, frame_size);
/* Make sure to discard any previous CELT state */
@@ -1016,6 +1035,28 @@
*value = st->complexity;
}
break;
+#ifdef ENABLE_OSCE_BWE
+ case OPUS_SET_OSCE_BWE_REQUEST:
+ {
+ opus_int32 value = va_arg(ap, opus_int32);
+ if(value<0 || value>1)
+ { goto bad_arg;
+ }
+ st->DecControl.enable_osce_bwe = value;
+
+ }
+ break;
+ case OPUS_GET_OSCE_BWE_REQUEST:
+ {
+ opus_int32 *value = va_arg(ap, opus_int32*);
+ if (!value)
+ {
+ goto bad_arg;
+ }
+ *value = st->DecControl.enable_osce_bwe;
+ }
+ break;
+#endif
case OPUS_GET_FINAL_RANGE_REQUEST:
{
opus_uint32 *value = va_arg(ap, opus_uint32*);
--- a/src/opus_demo.c
+++ b/src/opus_demo.c
@@ -141,6 +141,9 @@
fprintf(stderr, "-lossfile <file> : simulate packet loss, reading loss from file\n" );
fprintf(stderr, "-dred <frames> : add Deep REDundancy (in units of 10-ms frames)\n" );
fprintf(stderr, "-enc_loss : Apply loss on the encoder side (store empty packets)\n" );
+#ifdef ENABLE_OSCE_BWE
+ fprintf(stderr, "-enable_osce_bwe : enable OSCE bandwidth extension for wideband signals (48 kHz sampling rate only), raises dec_complexity to 4\n");
+#endif
}
#define FORMAT_S16_LE 0
@@ -431,6 +434,9 @@
int silk_random_switching = 0;
int silk_frame_counter = 0;
#endif
+#if defined(ENABLE_OSCE) && defined(ENABLE_OSCE_BWE)
+ int enable_osce_bwe = 0;
+#endif
#ifdef USE_WEIGHTS_FILE
int blob_len;
void *blob_data;
@@ -686,6 +692,11 @@
printf("switching encoding parameters every %dth frame\n", silk_random_switching);
args += 2;
#endif
+#if defined(ENABLE_OSCE) && defined(ENABLE_OSCE_BWE)
+ } else if( strcmp( argv[ args ], "-enable_osce_bwe" ) == 0 ) {
+ enable_osce_bwe = 1;
+ args++;
+#endif
} else {
printf( "Error: unrecognized setting: %s\n\n", argv[ args ] );
print_usage( argv );
@@ -767,6 +778,12 @@
fprintf(stderr, "Cannot create decoder: %s\n", opus_strerror(err));
goto failure;
}
+#ifdef ENABLE_OSCE_BWE
+ if (enable_osce_bwe) {
+ opus_decoder_ctl(dec, OPUS_SET_OSCE_BWE(1));
+ if (dec_complexity < 4) {dec_complexity = 4;}
+ }
+#endif
opus_decoder_ctl(dec, OPUS_SET_COMPLEXITY(dec_complexity));
}
switch(bandwidth)
--- a/tar_list.txt
+++ b/tar_list.txt
@@ -26,3 +26,6 @@
dnn/dred_rdovae_dec_data.c
dnn/lossgen_data.c
dnn/lossgen_data.h
+dnn/models/bbwenet_v1.pth
+dnn/bbwenet_data.h
+dnn/bbwenet_data.c
--
⑨