ref: 02a55f299a829e48e10d1591ea449a225e94f1db
dir: /ann/main.c/
#include <u.h>
#include <libc.h>
#include <ctype.h>
#include "ann.h"
/*
 * Print the command syntax, using the program name from argv[0],
 * to standard error and terminate via exits("usage").
 * Callers rely on this never returning.
 */
void
usage(char **argv)
{
	char *prog;

	prog = argv[0];
	fprint(2, "usage: %s [-train] filename [num_layers num_input_layer ... num_output_layer]\n", prog);
	exits("usage");
}
/*
 * Load or create a network, then read numbers from standard input and
 * either run the network on them or (with -train) train it.
 *
 * Arguments (see usage()):
 *	[-train] filename [num_layers size_0 ... size_{num_layers-1}]
 * Any first argument beginning with "-t" selects training mode.
 * If filename exists it is loaded with annload; a layer specification,
 * when also given, must match the loaded network exactly.  If filename
 * does not exist, a layer specification is required and the network is
 * built with anncreatev.
 *
 * Input (fd 0, via readline): whitespace-separated numbers made of
 * digits, '.' and '-'.  Records may span lines and several records may
 * share a line.  Every ninput values form an input record: the network
 * is run on it and the noutput results are printed on one line.  With
 * -train, each input record is followed by an output record of noutput
 * values, and the pair is handed to anntrain.
 *
 * On EOF the network is written back to filename with annsave (this
 * happens in run mode too).
 */
void
main(int argc, char **argv)
{
	Ann *ann;
	char *filename;
	int train;
	Dir *dir;
	int num_layers = 0;
	int *layers = nil;
	int i;
	char *line, *p, *end;
	double *input;
	double *output = nil;
	double *runoutput;
	int ninput, noutput;
	int offset;
	double f;
	int trainline;	/* 0: filling input[], 1: filling output[] */
	int nline;	/* values expected in the current record */
	int nval;	/* values stored so far in the current record */

	train = 0;
	if (argc < 2)
		usage(argv);
	filename = argv[1];
	if (argv[1][0] == '-' && argv[1][1] == 't') {
		if (argc < 3)
			usage(argv);
		train = 1;
		filename = argv[2];
	}

	ann = nil;
	dir = dirstat(filename);
	if (dir != nil) {
		free(dir);
		ann = annload(filename);
		if (ann == nil)
			exits("load");
	}

	/* optional layer spec: num_layers followed by exactly num_layers sizes */
	if (argc >= (train + 3)) {
		num_layers = atoi(argv[train + 2]);
		if (num_layers < 2 || argc != (train + 3 + num_layers))
			usage(argv);
		layers = calloc(num_layers, sizeof(int));
		if (layers == nil) {
			fprint(2, "out of memory\n");
			exits("calloc");
		}
		for (i = 0; i < num_layers; i++) {
			layers[i] = atoi(argv[train + 3 + i]);
			/* every layer needs at least one neuron; atoi yields 0 for non-numbers */
			if (layers[i] < 1)
				usage(argv);
		}
	}

	if (num_layers > 0) {
		if (ann != nil) {
			/* an existing network must have exactly the requested shape */
			if (ann->n != num_layers) {
				fprint(2, "num_layers: %d != %d\n", ann->n, num_layers);
				exits("num_layers");
			}
			for (i = 0; i < num_layers; i++) {
				if (layers[i] != ann->layers[i]->n) {
					fprint(2, "num_layer_%d: %d != %d\n", i, layers[i], ann->layers[i]->n);
					exits("num_layer");
				}
			}
		} else {
			/*
			 * NOTE(review): layers is deliberately not freed; whether
			 * anncreatev keeps the pointer is not visible here.
			 */
			ann = anncreatev(num_layers, layers);
			if (ann == nil)
				exits("anncreatev");
		}
	}
	if (ann == nil) {
		fprint(2, "file not found: %s\n", filename);
		exits("file not found");
	}

	ninput = ann->layers[0]->n;
	noutput = ann->layers[ann->n - 1]->n;
	input = calloc(ninput, sizeof(double));
	if (train == 1)
		output = calloc(noutput, sizeof(double));
	if (input == nil || (train == 1 && output == nil)) {
		fprint(2, "out of memory\n");
		exits("calloc");
	}

	trainline = 0;
	nline = ninput;
	nval = 0;
	/*
	 * NOTE(review): buffers returned by readline are not freed here;
	 * whether they must be depends on readline's contract — confirm.
	 */
	while ((line = readline(0)) != nil) {
		p = line;
		for (;;) {
			while (isspace((unsigned char)*p))
				p++;
			if (*p == '\0')
				break;	/* the current record may continue on the next line */

			/* a token may contain only digits, '.' and '-' ... */
			offset = 0;
			while (isdigit((unsigned char)p[offset]) || p[offset] == '.' || p[offset] == '-')
				offset++;
			if (!isspace((unsigned char)p[offset]) && p[offset] != '\0') {
				fprint(2, "input error: %s\n", p);
				exits("input");
			}
			/*
			 * ... and must parse completely as one number, so tokens
			 * like "-", "." or "1-2" are rejected rather than truncated.
			 */
			f = strtod(p, &end);
			if (end != p + offset) {
				fprint(2, "input error: %s\n", p);
				exits("input");
			}
			p += offset;

			if (trainline == 0)
				input[nval++] = f;
			else
				output[nval++] = f;
			if (nval < nline)
				continue;

			/* record complete: keep scanning the same line for the next one */
			nval = 0;
			if (trainline == 0) {
				runoutput = annrun(ann, input);
				for (i = 0; i < noutput; i++)
					print("%f%c", runoutput[i], (i == (noutput-1))? '\n': ' ');
				free(runoutput);
			}
			if (train == 1) {
				if (trainline == 0) {
					trainline = 1;
					nline = noutput;
				} else {
					anntrain(ann, input, output);
					trainline = 0;
					nline = ninput;
				}
			}
		}
	}
	if (nval != 0)
		fprint(2, "incomplete record at end of input: %d of %d values ignored\n", nval, nline);

	if (annsave(filename, ann) != 0)
		exits("save");
	exits(nil);
}