ref: fa7b432eed4e9fc11d5f2bfe10c4f54d89dd1788
parent: d98c59fb9ae60acae66da0198e29ea694b68b184
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sat May 27 21:53:20 EDT 2023
Initial blob loading support
--- a/dnn/Makefile.am
+++ b/dnn/Makefile.am
@@ -62,7 +62,7 @@
dump_data_LDADD = $(LIBM)
dump_data_CFLAGS = $(AM_CFLAGS)
-dump_weights_blob_SOURCES = nnet_data.c plc_data.c write_lpcnet_weights.c
+# write_lpcnet_weights.c #includes nnet_data.c and plc_data.c directly, so they are not listed here.
+dump_weights_blob_SOURCES = write_lpcnet_weights.c
dump_weights_blob_LDADD = $(LIBM)
dump_weights_blob_CFLAGS = $(AM_CFLAGS) -DDUMP_BINARY_WEIGHTS
--- a/dnn/autogen.sh
+++ b/dnn/autogen.sh
@@ -6,7 +6,7 @@
test -n "$srcdir" && cd "$srcdir"
#SHA1 of the first commit compatible with the current model
-commit=399be7c
+commit=859bfae
./download_model.sh $commit
echo "Updating build configuration files for lpcnet, please wait...."
--- a/dnn/include/lpcnet.h
+++ b/dnn/include/lpcnet.h
@@ -199,4 +199,7 @@
LPCNET_EXPORT void lpcnet_plc_fec_clear(LPCNetPLCState *st);
+/** Loads LPCNet model weights into a state from an in-memory weights blob.
+  * Needed in builds with USE_WEIGHTS_FILE, where state initialization does not
+  * load the built-in weights.
+  * NOTE(review): confirm whether the model keeps pointers into data after this
+  * call returns; the demo keeps the blob alive until it exits.
+  * @param [in,out] st <tt>LPCNetState*</tt>: State whose model is initialized
+  * @param [in] data <tt>const unsigned char*</tt>: Weights blob
+  * @param [in] len <tt>int</tt>: Size of data in bytes
+  * @retval 0 Success
+  * @retval -1 The model could not be initialized from the blob
+  */
+LPCNET_EXPORT int lpcnet_load_model(LPCNetState *st, const unsigned char *data, int len);
+/** Loads LPCNet PLC model weights into a PLC state from an in-memory weights blob.
+  * Needed in builds with USE_WEIGHTS_FILE, where PLC state initialization does not
+  * load the built-in weights.
+  * NOTE(review): confirm whether the model keeps pointers into data after this
+  * call returns.
+  * @param [in,out] st <tt>LPCNetPLCState*</tt>: PLC state whose model is initialized
+  * @param [in] data <tt>const unsigned char*</tt>: Weights blob
+  * @param [in] len <tt>int</tt>: Size of data in bytes
+  * @retval 0 Success
+  * @retval -1 The model could not be initialized from the blob
+  */
+LPCNET_EXPORT int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len);
+
#endif
--- a/dnn/lpcnet.c
+++ b/dnn/lpcnet.c
@@ -183,9 +183,29 @@
lpcnet->sampling_logit_table[i] = -log((1-prob)/prob);
}
kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string));
+#ifndef USE_WEIGHTS_FILE
ret = init_lpcnet_model(&lpcnet->model, lpcnet_arrays);
+#else
+ ret = 0;
+#endif
celt_assert(ret == 0);
return ret;
+}
+
+/* Initializes st->model from a weights blob of len bytes at data.
+   Returns 0 on success, or -1 if the blob is empty, cannot be parsed, or
+   init_lpcnet_model() rejects the parsed arrays.
+   NOTE(review): assumes parse_weights() returns a negative value on failure and
+   does not leave *list allocated in that case -- confirm. */
+LPCNET_EXPORT int lpcnet_load_model(LPCNetState *st, const unsigned char *data, int len) {
+  WeightArray *list = NULL;
+  int ret;
+  if (data == NULL || len <= 0) return -1;
+  /* Never hand a list from a failed parse to init_lpcnet_model(). */
+  if (parse_weights(&list, data, len) < 0) return -1;
+  ret = init_lpcnet_model(&st->model, list);
+  free(list);
+  return ret == 0 ? 0 : -1;
}
--- a/dnn/lpcnet_demo.c
+++ b/dnn/lpcnet_demo.c
@@ -34,6 +34,81 @@
#include "lpcnet.h"
#include "freq.h"
+#ifdef USE_WEIGHTS_FILE
+# include <limits.h>
+# if __unix__
+# include <fcntl.h>
+# include <sys/mman.h>
+# include <unistd.h>
+# include <sys/stat.h>
+/* When available, mmap() is preferable to reading the file, as it leads to
+ better resource utilization, especially if multiple processes are using the same
+ file (mapping will be shared in cache). */
+/* Maps filename read-only and stores its size in *len. Returns NULL with *len
+   set to 0 if the file cannot be opened, is empty, does not fit in an int, or
+   cannot be mapped. Release the result with free_blob(). */
+static unsigned char *load_blob(const char *filename, int *len) {
+  int fd;
+  void *data;
+  struct stat st;
+  *len = 0;
+  fd = open(filename, O_RDONLY);
+  if (fd < 0) return NULL;
+  /* fstat() the open descriptor so the size describes the file actually mapped. */
+  if (fstat(fd, &st) != 0 || st.st_size <= 0 || st.st_size > INT_MAX) {
+    close(fd);
+    return NULL;
+  }
+  data = mmap(NULL, (size_t)st.st_size, PROT_READ, MAP_SHARED, fd, 0);
+  /* The mapping stays valid after the descriptor is closed. */
+  close(fd);
+  /* mmap() reports failure with MAP_FAILED, not NULL. */
+  if (data == MAP_FAILED) return NULL;
+  *len = (int)st.st_size;
+  return data;
+}
+/* Unmaps a blob returned by load_blob(). */
+static void free_blob(unsigned char *blob, int len) {
+  if (blob != NULL) munmap(blob, (size_t)len);
+}
+# else
+/* Reads filename into a malloc()ed buffer and stores its size in *len. Returns
+   NULL with *len set to 0 if the file cannot be opened or sized, is empty, does
+   not fit in an int, cannot be allocated, or is only partially read.
+   Release the result with free_blob(). */
+static unsigned char *load_blob(const char *filename, int *len) {
+  FILE *file;
+  unsigned char *data;
+  long size;
+  *len = 0;
+  /* Binary mode: text mode may translate bytes of the weights on some platforms. */
+  file = fopen(filename, "rb");
+  if (file == NULL) return NULL;
+  if (fseek(file, 0L, SEEK_END) != 0) goto fail;
+  size = ftell(file);
+  if (size <= 0 || size > INT_MAX || fseek(file, 0L, SEEK_SET) != 0) goto fail;
+  data = malloc((size_t)size);
+  if (data == NULL) goto fail;
+  if (fread(data, 1, (size_t)size, file) != (size_t)size) {
+    free(data);
+    goto fail;
+  }
+  fclose(file);
+  *len = (int)size;
+  return data;
+fail:
+  /* Every failure after fopen() must still close the file. */
+  fclose(file);
+  return NULL;
+}
+/* Frees a blob returned by load_blob(). */
+static void free_blob(unsigned char *blob, int len) {
+  free(blob);
+  (void)len;
+}
+# endif
+#endif
+
#define MODE_ENCODE 0
#define MODE_DECODE 1
#define MODE_FEATURES 2
@@ -64,6 +139,11 @@
FILE *plc_file = NULL;
const char *plc_options;
int plc_flags=-1;
+#ifdef USE_WEIGHTS_FILE
+ int len;
+ unsigned char *data;
+ const char *filename = "weights_blob.bin";
+#endif
if (argc < 4) usage();
if (strcmp(argv[1], "-encode") == 0) mode=MODE_ENCODE;
else if (strcmp(argv[1], "-decode") == 0) mode=MODE_DECODE;
@@ -109,7 +189,15 @@
fprintf(stderr, "Can't open %s\n", argv[3]);
exit(1);
}
-
+#ifdef USE_WEIGHTS_FILE
+  data = load_blob(filename, &len);
+  if (data == NULL) {
+    fprintf(stderr, "Can't load %s\n", filename);
+    exit(1);
+  }
+  /* NOTE(review): of the modes visible in this patch only -synthesis calls
+     lpcnet_load_model(); confirm whether the other modes need a model loaded too. */
+#endif
if (mode == MODE_ENCODE) {
LPCNetEncState *net;
net = lpcnet_encoder_create();
@@ -152,6 +240,12 @@
} else if (mode == MODE_SYNTHESIS) {
LPCNetState *net;
net = lpcnet_create();
+#ifdef USE_WEIGHTS_FILE
+  if (lpcnet_load_model(net, data, len) != 0) {
+    fprintf(stderr, "Can't load model from %s\n", filename);
+    exit(1);
+  }
+#endif
while (1) {
float in_features[NB_TOTAL_FEATURES];
float features[NB_FEATURES];
@@ -207,5 +301,8 @@
}
fclose(fin);
fclose(fout);
+#ifdef USE_WEIGHTS_FILE
+ free_blob(data, len);
+#endif
return 0;
}
--- a/dnn/lpcnet_plc.c
+++ b/dnn/lpcnet_plc.c
@@ -68,9 +68,29 @@
return -1;
}
st->remove_dc = !!(options&LPCNET_PLC_DC_FILTER);
+#ifndef USE_WEIGHTS_FILE
ret = init_plc_model(&st->model, lpcnet_plc_arrays);
+#else
+ ret = 0;
+#endif
celt_assert(ret == 0);
return ret;
+}
+
+/* Initializes the PLC model st->model from a weights blob of len bytes at data.
+   Returns 0 on success, or -1 if the blob is empty, cannot be parsed, or
+   init_plc_model() rejects the parsed arrays.
+   NOTE(review): assumes parse_weights() returns a negative value on failure and
+   does not leave *list allocated in that case -- confirm. */
+LPCNET_EXPORT int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len) {
+  WeightArray *list = NULL;
+  int ret;
+  if (data == NULL || len <= 0) return -1;
+  /* Never hand a list from a failed parse to init_plc_model(). */
+  if (parse_weights(&list, data, len) < 0) return -1;
+  ret = init_plc_model(&st->model, list);
+  free(list);
+  return ret == 0 ? 0 : -1;
}
LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options) {
--- a/dnn/lpcnet_private.h
+++ b/dnn/lpcnet_private.h
@@ -131,4 +131,6 @@
void process_single_frame(LPCNetEncState *st, FILE *ffeat);
void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features);
+
+/* Parses the len-byte weights blob at data and stores the resulting WeightArray
+   list in *list; callers pass it to their model init function and then release
+   it with free().
+   NOTE(review): the return-value convention (e.g. negative on malformed input)
+   is not documented here -- confirm and document it. */
+int parse_weights(WeightArray **list, const unsigned char *data, int len);
#endif
--- a/dnn/write_lpcnet_weights.c
+++ b/dnn/write_lpcnet_weights.c
@@ -31,8 +31,14 @@
#include <stdio.h>
#include "nnet.h"
-extern const WeightArray lpcnet_arrays[];
-extern const WeightArray lpcnet_plc_arrays[];
+/* This tool needs the compiled-in weight tables from nnet_data.c and plc_data.c, so
+   it #includes them directly instead of linking them. They must be built without
+   USE_WEIGHTS_FILE, but that macro comes from config.h: undefine HAVE_CONFIG_H so the
+   included files do not pull config.h in (presumably they include it under
+   #ifdef HAVE_CONFIG_H -- confirm), and undefine USE_WEIGHTS_FILE in case config.h
+   was already included earlier in this file. */
+#undef HAVE_CONFIG_H
+#ifdef USE_WEIGHTS_FILE
+#undef USE_WEIGHTS_FILE
+#endif
+#include "nnet_data.c"
+#include "plc_data.c"
void write_weights(const WeightArray *list, FILE *fout)
{
--
⑨