ref: 6f8db9392902f156125442730cf93b7de5f9aec1
parent: bfcf94de2a6dbc245f6b92f6a4d5a301845f2549
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Mon Mar 11 13:06:29 EDT 2019
Add M-best VQ search
--- a/dnn/dump_data.c
+++ b/dnn/dump_data.c
@@ -57,6 +57,37 @@
#include "ceps_codebooks.c"
+/* Number of candidate paths kept after each stage of quantize_3stage_mbest(). */
+#define SURVIVORS 5
+
+
+/* M-best nearest-neighbour search (squared Euclidean distance).
+   codebook holds nb_entries vectors of ndim floats stored back to back.
+   On return dist[0..mbest-1] holds the mbest smallest distances in ascending
+   order and index[0..mbest-1] the matching entry numbers. Slots that could not
+   be filled (nb_entries < mbest) keep dist = 1e15 and index = -1.
+   Does nothing when mbest <= 0. */
+void vq_quantize_mbest(const float *codebook, int nb_entries, const float *x, int ndim, int mbest, float *dist, int *index)
+{
+   int i, j;
+   if (mbest <= 0) return;
+   for (i=0;i<mbest;i++) {
+      dist[i] = 1e15;
+      index[i] = -1;
+   }
+
+   for (i=0;i<nb_entries;i++)
+   {
+      int pos;
+      float d=0;
+      for (j=0;j<ndim;j++) {
+         float e = x[j]-codebook[i*ndim+j];
+         d += e*e;
+      }
+      /* Not better than the current worst survivor (ties keep the earlier entry). */
+      if (!(d < dist[mbest-1])) continue;
+      /* Insertion point in the ascending list; defaults to the last slot. */
+      for (pos=0;pos<mbest-1;pos++) {
+         if (d < dist[pos]) break;
+      }
+      /* Shift worse survivors down one slot, dropping the last one. */
+      for (j=mbest-1;j>pos;j--) {
+         dist[j] = dist[j-1];
+         index[j] = index[j-1];
+      }
+      dist[pos] = d;
+      index[pos] = i;
+   }
+}
+
+
int vq_quantize(const float *codebook, int nb_entries, const float *x, int ndim, float *dist)
{
int i, j;
@@ -110,6 +141,117 @@
return id;
}
+
+/* Number of codebooks (stages) in the residual VQ searched by quantize_3stage_mbest(). */
+#define MBEST_STAGES 3
+
+/* Merge a sorted candidate list into the sorted list of surviving paths.
+   glob_dist/glob_index hold SURVIVORS paths in ascending order of distance.
+   Candidate m extends the path `parent` (indices for stages 0..stage-1) with
+   entry cand_index[m] at position `stage`; its distance is cand_dist[m], and
+   cand_dist must be sorted ascending. Only the SURVIVORS best paths are kept. */
+static void mbest_merge(float *glob_dist, int glob_index[][MBEST_STAGES], const float *cand_dist,
+      const int *cand_index, const int *parent, int stage)
+{
+   int pos, j, s;
+   int m = 0;
+   for (pos=0;pos<SURVIVORS;pos++) {
+      /* Both lists are sorted, so candidate m only has to beat the path at pos. */
+      if (!(cand_dist[m] < glob_dist[pos])) continue;
+      for (j=SURVIVORS-1;j>pos;j--) {
+         glob_dist[j] = glob_dist[j-1];
+         for (s=0;s<=stage;s++) glob_index[j][s] = glob_index[j-1][s];
+      }
+      glob_dist[pos] = cand_dist[m];
+      for (s=0;s<stage;s++) glob_index[pos][s] = parent[s];
+      glob_index[pos][stage] = cand_index[m];
+      m++;
+   }
+}
+
+/* Quantize x (NB_BANDS_1 floats) with a 3-stage residual VQ, using an M-best
+   search: after each stage the SURVIVORS paths with the lowest residual error
+   are kept, and the best path after the last stage is selected.
+   Each codebook is searched over 1024 entries of NB_BANDS_1 floats.
+   On return x holds the quantized vector (sum of the three selected entries).
+   Returns the stage-1 codebook index. */
+int quantize_3stage_mbest(float *x)
+{
+   int i, k, s, stage;
+   int id, id2, id3;
+   float ref[NB_BANDS_1];
+   const float *cb[MBEST_STAGES];
+   int path[SURVIVORS][MBEST_STAGES];  /* survivors after the previous stage */
+   int next[SURVIVORS][MBEST_STAGES];  /* survivors being built for this stage */
+   int curr_index[SURVIVORS];
+   float curr_dist[SURVIVORS];
+   float glob_dist[SURVIVORS];
+   cb[0] = ceps_codebook1;
+   cb[1] = ceps_codebook2;
+   cb[2] = ceps_codebook3;
+   RNN_COPY(ref, x, NB_BANDS_1);
+
+   vq_quantize_mbest(cb[0], 1024, x, NB_BANDS_1, SURVIVORS, curr_dist, curr_index);
+   for (k=0;k<SURVIVORS;k++) path[k][0] = curr_index[k];
+
+   for (stage=1;stage<MBEST_STAGES;stage++) {
+      for (k=0;k<SURVIVORS;k++) {
+         float diff[NB_BANDS_1];
+         /* Residual left after applying the first `stage` entries of path k. */
+         for (i=0;i<NB_BANDS_1;i++) {
+            diff[i] = x[i];
+            for (s=0;s<stage;s++) diff[i] -= cb[s][path[k][s]*NB_BANDS_1 + i];
+         }
+         vq_quantize_mbest(cb[stage], 1024, diff, NB_BANDS_1, SURVIVORS, curr_dist, curr_index);
+         if (k==0) {
+            /* The first parent's (sorted) candidates seed the survivor list. */
+            int m;
+            for (m=0;m<SURVIVORS;m++) {
+               for (s=0;s<stage;s++) next[m][s] = path[0][s];
+               next[m][stage] = curr_index[m];
+               glob_dist[m] = curr_dist[m];
+            }
+         } else {
+            mbest_merge(glob_dist, next, curr_dist, curr_index, path[k], stage);
+         }
+      }
+      for (k=0;k<SURVIVORS;k++) {
+         for (s=0;s<=stage;s++) path[k][s] = next[k][s];
+      }
+   }
+
+   id = path[0][0];
+   id2 = path[0][1];
+   id3 = path[0][2];
+   /* Reconstruct the quantized vector from the best path. */
+   for (i=0;i<NB_BANDS_1;i++) {
+      x[i] = ceps_codebook1[id*NB_BANDS_1 + i] + ceps_codebook2[id2*NB_BANDS_1 + i] + ceps_codebook3[id3*NB_BANDS_1 + i];
+   }
+   if (0) {
+      /* Debug: RMS quantization error over the NB_BANDS_1 coded coefficients. */
+      float err = 0;
+      for (i=0;i<NB_BANDS_1;i++) {
+         err += (x[i]-ref[i])*(x[i]-ref[i]);
+      }
+      printf("%f\n", sqrt(err/NB_BANDS_1));
+   }
+
+   return id;
+}
+
static int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist, int sign)
{
int i, j;
@@ -564,7 +706,7 @@
RNN_COPY(&st->xc[1][0], &st->xc[9][0], PITCH_MAX_PERIOD);
//printf("%f\n", st->features[3][0]);
st->features[3][0] = floor(.5 + st->features[3][0]*5)/5;
- quantize_2stage(&st->features[3][1]);
+ quantize_3stage_mbest(&st->features[3][1]);
/*perform_interp_relaxation(st->features, vq_mem);*/
quantize_diff(&st->features[1][0], vq_mem, &st->features[3][0], ceps_codebook_diff4, 11, 1);
#if 0
--
⑨