shithub: util

Download patch

ref: 1bdf8fcfd4d4208415b405fa5e0e9aa501a782c8
parent: bfa694f3ddb1c7784f82e9017814d10b73ee4ce8
author: eli <eli@singularity>
date: Sat Sep 20 18:06:52 EDT 2025

train rate setting

--- a/ann/ann.c
+++ b/ann/ann.c
@@ -257,6 +257,18 @@
 }
 
 double
+torque(double input)	/* returns log((1+input)/(1-input)), saturating at -17.0 / 17.0 near the ends of (-1, 1) */
+{
+	if (input < -.9999999)	/* at or below roughly -1: skip the log and return the lower bound */
+		return -17.0;
+
+	if (input > .9999999)	/* at or above roughly 1: (1.0 - input) would approach zero, so return the upper bound */
+		return 17.0;
+
+	return log((1.0 + input) / (1.0 - input));	/* equals 2*atanh(input); a NaN input would fail both comparisons above and reach log() */
+}
+
+double
 anntrain(Ann *ann, double *inputs, double *outputs)
 {
 	double *error = annrun(ann, inputs);
@@ -272,12 +284,7 @@
 		error[o] -= outputs[o];
 		error[o] = -error[o];
 		ret += pow(error[o], 2.0) * 0.5;
-		if (error[o] < -.9999999)
-			error[o] = -17.0;
-		else if (error[o] > .9999999)
-			error[o] = 17.0;
-		else
-			error[o] = log((1.0 + error[o]) / (1.0 - error[o]));
+		error[o] = torque(error[o]);
 	}
 	D = ann->deltas[ann->n-2];
 	weightsinitdoubles(D, error);
@@ -373,6 +380,7 @@
 		error[o] -= outputs[o];
 		error[o] = -error[o];
 		ret += pow(error[o], 2.0) * 0.5;
+		error[o] = torque(error[o]);
 	}
 	D = ann->deltas[ann->n-2];
 	weightsinitdoubles(D, error);
@@ -454,6 +462,7 @@
 		error[o] -= outputs[o];
 		error[o] = -error[o];
 		ret += pow(error[o], 2.0) * 0.5;
+		error[o] = torque(error[o]);
 	}
 	D = ann->deltas[ann->n-2];
 	weightsinitdoubles(D, error);
--- a/ann/main.c
+++ b/ann/main.c
@@ -6,7 +6,7 @@
 void
 usage(char **argv)
 {
-	fprint(2, "usage: %s [-train] filename [num_layers num_input_layer ... num_output_layer]\n", argv[0]);
+	fprint(2, "usage: %s [-train [-rate n]] filename [num_layers num_input_layer ... num_output_layer]\n", argv[0]);
 	exits("usage");
 }
 
@@ -30,6 +30,7 @@
 	double f;
 	int trainline;
 	int nline;
+	double rate = 0.7;
 
 	train = 0;
 
@@ -43,7 +44,23 @@
 			usage(argv);
 
 		train = 1;
-		filename = argv[2];
+
+		if (argv[train + 1][0] == '-') {
+			if (argv[train + 1][1] == 'r') {
+				if (argc < 5)
+					usage(argv);
+
+				train++;
+
+				rate = atof(argv[train + 1]);
+				if (rate == 0.0)
+					usage(argv);
+
+				train++;
+			}
+		}
+
+		filename = argv[train + 1];
 	}
 
 	ann = nil;
@@ -92,10 +109,12 @@
 		exits("file not found");
 	}
 
+	ann->rate = rate;
+
 	ninput = ann->layers[0]->n;
 	noutput = ann->layers[ann->n - 1]->n;
 	input = calloc(ninput, sizeof(double));
-	if (train == 1)
+	if (train != 0)
 		output = calloc(noutput, sizeof(double));
 
 	trainline = 0;
@@ -135,21 +154,17 @@
 				if (trainline == 0) {
 					runoutput = annrun(ann, input);
 					for (i = 0; i < noutput; i++)
-/*						if (runoutput[i] == 0.0)
-							print("0%c", (i == (noutput-1))? '\n': ' ');
-						else if (runoutput[i] == 1.0)
-							print("1%c", (i == (noutput-1))? '\n': ' ');
-						else */
 							print("%f%c", runoutput[i], (i == (noutput-1))? '\n': ' ');
 					free(runoutput);
 				}
 
-				if (train == 1) {
+				if (train != 0) {
 					if (trainline == 0) {
 						trainline = 1;
 						nline = noutput;
 					} else {
 						anntrain(ann, input, output);
+
 						trainline = 0;
 						nline = ninput;
 					}
@@ -159,7 +174,7 @@
 		}
 	} while(line != nil);
 
-	if (train == 1 && annsave(filename, ann) != 0)
+	if (train != 0 && annsave(filename, ann) != 0)
 		exits("save");
 
 	exits(nil);
--