shithub: opus

Download patch

ref: 6f8db9392902f156125442730cf93b7de5f9aec1
parent: bfcf94de2a6dbc245f6b92f6a4d5a301845f2549
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Mon Mar 11 13:06:29 EDT 2019

Add M-best VQ search

--- a/dnn/dump_data.c
+++ b/dnn/dump_data.c
@@ -57,6 +57,37 @@
 
 #include "ceps_codebooks.c"
 
+#define SURVIVORS 5 /* candidates kept per stage: passed as mbest and used to size the survivor arrays in quantize_3stage_mbest() */
+/* Exhaustive search of codebook (nb_entries rows of ndim floats) for the mbest rows with the smallest squared distance to x. */
+/* Writes dist[0..mbest-1] in ascending order and the matching row numbers to index[]; unfilled slots keep dist=1e15 and index untouched. */
+void vq_quantize_mbest(const float *codebook, int nb_entries, const float *x, int ndim, int mbest, float *dist, int *index)
+{
+  int i, j;
+  for (i=0;i<mbest;i++) dist[i] = 1e15; /* large sentinel: any smaller distance displaces it */
+  /* Scan every codebook row, keeping dist[]/index[] sorted by insertion. */
+  for (i=0;i<nb_entries;i++)
+  {
+    float d=0; /* squared Euclidean distance between x and row i */
+    for (j=0;j<ndim;j++)
+      d += (x[j]-codebook[i*ndim+j])*(x[j]-codebook[i*ndim+j]);
+    if (d<dist[mbest-1]) /* strictly better than the current worst survivor */
+    {
+      int pos;
+      for (j=0;j<mbest-1;j++) { /* first slot holding a larger distance; ends at the last slot if none does */
+        if (d < dist[j]) break;
+      }
+      pos = j;
+      for (j=mbest-1;j>=pos+1;j--) { /* shift worse survivors down one slot; the previous last one is dropped */
+        dist[j] = dist[j-1];
+        index[j] = index[j-1];
+      }
+      dist[pos] = d;
+      index[pos] = i;
+    }
+  }
+}
+
+
 int vq_quantize(const float *codebook, int nb_entries, const float *x, int ndim, float *dist)
 {
   int i, j;
@@ -110,6 +141,117 @@
     return id;
 }
 
+/* Three-stage M-best VQ of x[0..NB_BANDS_1-1]: overwrites x with the sum of the chosen ceps_codebook1/2/3 rows and returns the stage-1 index. */
+int quantize_3stage_mbest(float *x)
+{
+    int i, k;
+    int id, id2, id3;
+    float ref[NB_BANDS_1]; /* copy of the input; only read by the disabled error printout below */
+    int curr_index[SURVIVORS]; /* row numbers returned by the most recent vq_quantize_mbest() call */
+    int index1[SURVIVORS][3]; /* indexN[k][s]: codebook(s+1) row of survivor k after stage N; index1 only uses column 0 */
+    int index2[SURVIVORS][3];
+    int index3[SURVIVORS][3];
+    float curr_dist[SURVIVORS]; /* distances returned by the most recent vq_quantize_mbest() call, ascending */
+    float glob_dist[SURVIVORS]; /* distances of the merged survivors for the current stage, ascending */
+    RNN_COPY(ref, x, NB_BANDS_1); /* presumably copies NB_BANDS_1 floats from x into ref -- RNN_COPY is defined elsewhere, confirm */
+    vq_quantize_mbest(ceps_codebook1, 1024, x, NB_BANDS_1, SURVIVORS, curr_dist, curr_index); /* stage 1: best rows of codebook 1 for x */
+    for (k=0;k<SURVIVORS;k++) {
+      index1[k][0] = curr_index[k];
+    }
+    for (k=0;k<SURVIVORS;k++) { /* stage 2: extend every stage-1 survivor with its best codebook-2 rows */
+      int m;
+      float diff[NB_BANDS_1];
+      for (i=0;i<NB_BANDS_1;i++) { /* residual left after this survivor's stage-1 row */
+        diff[i] = x[i] - ceps_codebook1[index1[k][0]*NB_BANDS_1 + i];
+      }
+      vq_quantize_mbest(ceps_codebook2, 1024, diff, NB_BANDS_1, SURVIVORS, curr_dist, curr_index);
+      if (k==0) { /* the first survivor seeds the global list */
+        for (m=0;m<SURVIVORS;m++) {
+          index2[m][0] = index1[k][0];
+          index2[m][1] = curr_index[m];
+          glob_dist[m] = curr_dist[m];
+        }
+        //printf("%f ", glob_dist[0]);
+      } else if (curr_dist[0] < glob_dist[SURVIVORS-1]) { /* later survivors are merged only if their best beats the global worst */
+        m=0;
+        int pos;
+        for (pos=0;pos<SURVIVORS;pos++) { /* merge the sorted candidates into the sorted global list in place */
+          if (curr_dist[m] < glob_dist[pos]) {
+            int j;
+            for (j=SURVIVORS-1;j>=pos+1;j--) { /* open slot pos; the last global entry is dropped */
+              glob_dist[j] = glob_dist[j-1];
+              index2[j][0] = index2[j-1][0];
+              index2[j][1] = index2[j-1][1];
+            }
+            glob_dist[pos] = curr_dist[m];
+            index2[pos][0] = index1[k][0];
+            index2[pos][1] = curr_index[m];
+            m++; /* candidate m is placed; the next one is compared from slot pos+1 onward */
+          }
+        }
+      }
+    }
+    for (k=0;k<SURVIVORS;k++) { /* stage 3: same procedure on the stage-2 survivors with codebook 3 */
+      int m;
+      float diff[NB_BANDS_1];
+      for (i=0;i<NB_BANDS_1;i++) { /* residual left after this survivor's stage-1 and stage-2 rows */
+        diff[i] = x[i] - ceps_codebook1[index2[k][0]*NB_BANDS_1 + i] - ceps_codebook2[index2[k][1]*NB_BANDS_1 + i];
+      }
+      vq_quantize_mbest(ceps_codebook3, 1024, diff, NB_BANDS_1, SURVIVORS, curr_dist, curr_index);
+      if (k==0) { /* glob_dist is reused: the first survivor reseeds it for this stage */
+        for (m=0;m<SURVIVORS;m++) {
+          index3[m][0] = index2[k][0];
+          index3[m][1] = index2[k][1];
+          index3[m][2] = curr_index[m];
+          glob_dist[m] = curr_dist[m];
+        }
+        //printf("%f ", glob_dist[0]);
+      } else if (curr_dist[0] < glob_dist[SURVIVORS-1]) {
+        m=0;
+        int pos;
+        for (pos=0;pos<SURVIVORS;pos++) { /* same in-place sorted merge as in stage 2 */
+          if (curr_dist[m] < glob_dist[pos]) {
+            int j;
+            for (j=SURVIVORS-1;j>=pos+1;j--) {
+              glob_dist[j] = glob_dist[j-1];
+              index3[j][0] = index3[j-1][0];
+              index3[j][1] = index3[j-1][1];
+              index3[j][2] = index3[j-1][2];
+            }
+            glob_dist[pos] = curr_dist[m];
+            index3[pos][0] = index2[k][0];
+            index3[pos][1] = index2[k][1];
+            index3[pos][2] = curr_index[m];
+            m++;
+          }
+        }
+      }
+    }
+    id = index3[0][0]; /* slot 0 holds the path with the smallest final distance */
+    id2 = index3[0][1];
+    id3 = index3[0][2];
+    //printf("%f ", glob_dist[0]);
+    for (i=0;i<NB_BANDS_1;i++) { /* NOTE(review): x is fully overwritten a few lines below, so these two residual updates have no lasting effect */
+        x[i] -= ceps_codebook1[id*NB_BANDS_1 + i];
+    }
+    for (i=0;i<NB_BANDS_1;i++) {
+        x[i] -= ceps_codebook2[id2*NB_BANDS_1 + i];
+    }
+    //id3 = vq_quantize(ceps_codebook3, 1024, x, NB_BANDS_1, NULL);
+    for (i=0;i<NB_BANDS_1;i++) { /* write the three-row reconstruction back into x */
+        x[i] = ceps_codebook1[id*NB_BANDS_1 + i] + ceps_codebook2[id2*NB_BANDS_1 + i] + ceps_codebook3[id3*NB_BANDS_1 + i];
+    }
+    if (0) { /* disabled debug output: error between the reconstruction and the saved input */
+        float err = 0;
+        for (i=0;i<NB_BANDS_1;i++) {
+            err += (x[i]-ref[i])*(x[i]-ref[i]);
+        }
+        printf("%f\n", sqrt(err/NB_BANDS)); /* NOTE(review): divides by NB_BANDS although the sum runs over NB_BANDS_1 terms -- confirm intended */
+    }
+    /* Only the stage-1 index is returned; id2 and id3 are not passed back to the caller. */
+    return id;
+}
+
 static int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist, int sign)
 {
   int i, j;
@@ -564,7 +706,7 @@
   RNN_COPY(&st->xc[1][0], &st->xc[9][0], PITCH_MAX_PERIOD);
   //printf("%f\n", st->features[3][0]);
   st->features[3][0] = floor(.5 + st->features[3][0]*5)/5;
-  quantize_2stage(&st->features[3][1]);
+  quantize_3stage_mbest(&st->features[3][1]);
   /*perform_interp_relaxation(st->features, vq_mem);*/
   quantize_diff(&st->features[1][0], vq_mem, &st->features[3][0], ceps_codebook_diff4, 11, 1);
 #if 0
--