shithub: opus

Download patch

ref: 2112f3dd76f558a8e68cd437ab0e15b02ee094b2
parent: c1b357ed47656c69a33c336d4522fd051bcd229d
author: Jan Buethe <jbuethe@amazon.de>
date: Wed Oct 19 06:58:24 EDT 2022

some fixes

--- a/dnn/nfec_enc.c
+++ b/dnn/nfec_enc.c
@@ -2,8 +2,6 @@
 #include "nnet.h"
 #include "nfec_enc_data.h"
 
-
-
 void nfec_encode_dframe(struct NFECEncState *enc_state, float *latents, float *initial_state, const float *input)
 {
     float buffer[ENC_DENSE1_OUT_SIZE + ENC_DENSE2_OUT_SIZE + ENC_DENSE3_OUT_SIZE + ENC_DENSE4_OUT_SIZE + ENC_DENSE5_OUT_SIZE + ENC_DENSE6_OUT_SIZE + ENC_DENSE7_OUT_SIZE + ENC_DENSE8_OUT_SIZE + GDENSE1_OUT_SIZE];
--- a/dnn/nfec_enc_demo.c
+++ b/dnn/nfec_enc_demo.c
@@ -5,7 +5,7 @@
 
 void usage()
 {
-    printf("nfec_enc_demo <features>");
+    printf("nfec_enc_demo <features> <latents path> <states path>\n");
     exit(1);
 }
 
@@ -17,9 +17,11 @@
     float latents[80];
     float initial_state[24];
     int index = 0;
-    FILE *fid;
+    FILE *fid, *latents_fid, *states_fid;
 
-    if (argc < 2)
+    memset(&enc_state, 0, sizeof(enc_state));
+
+    if (argc < 4)
     {
         usage();
     }
@@ -31,16 +33,37 @@
         usage();
     }
 
+    latents_fid = fopen(argv[2], "wb");
+    if (latents_fid == NULL)
+    {
+        fprintf(stderr, "could not open latents file %s\n", argv[2]);
+        usage();
+    }
+
+    states_fid = fopen(argv[3], "wb"); /* binary output: receives initial_state after each encoded dframe */
+    if (states_fid == NULL)
+    {
+        fprintf(stderr, "could not open states file %s\n", argv[3]);
+        usage();
+    }
+
+
     while (fread(feature_buffer, sizeof(float), 32, fid) == 32)
     {
-        memcpy(dframe[16 * index++], feature_buffer, 16*sizeof(float));
+        memcpy(&dframe[16 * index++], feature_buffer, 16*sizeof(float));
 
         if (index == 2)
         {
             nfec_encode_dframe(&enc_state, latents, initial_state, dframe);
             index = 0;
+            fwrite(latents, sizeof(float), NFEC_LATENT_DIM, latents_fid);
+            fwrite(initial_state, sizeof(float), GDENSE2_OUT_SIZE, states_fid);
         }
     }
+
+    fclose(fid);
+    fclose(states_fid);
+    fclose(latents_fid);
 }
 
-/* gcc -DDISABLE_DOT_PROD nfec_enc_demo.c nfec_enc.c nnet.c nfec_enc_data.c -o nfec_enc_demo */
\ No newline at end of file
+/* gcc -DDISABLE_DOT_PROD -DDISABLE_NEON nfec_enc_demo.c nfec_enc.c nnet.c nfec_enc_data.c kiss99.c -o nfec_enc_demo */
\ No newline at end of file
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -37,7 +37,7 @@
 
 #if defined(__AVX__) || defined(__SSE2__)
 #include "vec_avx.h"
-#elif defined(__ARM_NEON__) || defined(__ARM_NEON)
+#elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) && !defined(DISABLE_NEON)
 #include "vec_neon.h"
 #else
 
--