shithub: opus

Download patch

ref: b3198a09dafffa0282487c329148b02bad27ef62
parent: e6347180367c55ae1cf0813c30084c6c4d8b0b86
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Thu Jan 31 09:55:43 EST 2019

Add frame weighting and pitch period-doubling prevention

--- a/dnn/dump_data.c
+++ b/dnn/dump_data.c
@@ -160,12 +160,17 @@
     int sub;
     static float xc[10][PITCH_MAX_PERIOD+1];
     static float ener[10][PITCH_MAX_PERIOD];
+    static float frame_weight[10];
     static float frame_max_corr[PITCH_MAX_PERIOD];
+    static float ener_follow;
     /* Cross-correlation on half-frames. */
     for (sub=0;sub<2;sub++) {
       int off = sub*FRAME_SIZE/2;
       celt_pitch_xcorr(&st->exc_buf[PITCH_MAX_PERIOD+off], st->exc_buf+off, xcorr, FRAME_SIZE/2, PITCH_MAX_PERIOD);
       ener0 = celt_inner_prod(&st->exc_buf[PITCH_MAX_PERIOD+off], &st->exc_buf[PITCH_MAX_PERIOD+off], FRAME_SIZE/2);
+      ener_follow = MAX16(.7*ener_follow, ener0);
+      frame_weight[2+2*pcount+sub] = ener0/(1+ener_follow);
+      //printf("%f\n", frame_weight[2+2*pcount+sub]);
       for (i=0;i<PITCH_MAX_PERIOD;i++) {
         ener[2+2*pcount+sub][i] = (1 + ener0 + celt_inner_prod(&st->exc_buf[i+off], &st->exc_buf[i+off], FRAME_SIZE/2));
         xc[2+2*pcount+sub][i] = 2*xcorr[i] / ener[2+2*pcount+sub][i];
@@ -195,18 +200,23 @@
       for(sub=0;sub<8;sub++) {
         float max_path_all = -1e15;
         best_i = 0;
+        for (i=0;i<PITCH_MAX_PERIOD-2*PITCH_MIN_PERIOD;i++) {
+          float xc_half = MAX16(MAX16(xc[2+sub][(PITCH_MAX_PERIOD+i)/2], xc[2+sub][(PITCH_MAX_PERIOD+i+2)/2]), xc[2+sub][(PITCH_MAX_PERIOD+i-1)/2]);
+          if (xc[2+sub][i] < xc_half*1.1) xc[2+sub][i] *= .8;
+        }
         for (i=0;i<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;i++) {
           int j;
           float max_prev;
-          max_prev = st->pitch_max_path_all - 1.5f;
+          period = PITCH_MAX_PERIOD-i;
+          max_prev = st->pitch_max_path_all - 6.f;
           pitch_prev[sub][i] = st->best_i;
           for (j=IMIN(0, 4-i);j<=4 && i+j<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;j++) {
             if (st->pitch_max_path[0][i+j] > max_prev) {
-              max_prev = st->pitch_max_path[0][i+j] - .05f*abs(j);
+              max_prev = st->pitch_max_path[0][i+j] - .02f*abs(j)*abs(j);
               pitch_prev[sub][i] = i+j;
             }
           }
-          st->pitch_max_path[1][i] = max_prev + xc[2+sub][i];
+          st->pitch_max_path[1][i] = max_prev + frame_weight[2+sub]*xc[2+sub][i];
           if (st->pitch_max_path[1][i] > max_path_all) {
             max_path_all = st->pitch_max_path[1][i];
             best_i = i;
--- a/dnn/train_lpcnet.py
+++ b/dnn/train_lpcnet.py
@@ -103,7 +103,7 @@
 del in_exc
 
 # dump models to disk as we go
-checkpoint = ModelCheckpoint('lpcnet20h_384_10_G16_{epoch:02d}.h5')
+checkpoint = ModelCheckpoint('lpcnet24b_384_10_G16_{epoch:02d}.h5')
 
 #model.load_weights('lpcnet9b_384_10_G16_01.h5')
 model.compile(optimizer=Adam(0.001, amsgrad=True, decay=5e-5), loss='sparse_categorical_crossentropy')
--