shithub: util

Download patch

ref: 462f8462b7b7b7c05a5c2edf1b1d074f0b7774e4
parent: b299aa428c810f2b65d47c8e03aad9a5e527ab62
author: eli <eli@singularity>
date: Wed Dec 25 21:13:23 EST 2024

ann: change output layer activation from sigmoid to clamped linear (activation_piece/gradient_piece)

--- a/ann.c
+++ b/ann.c
@@ -47,6 +47,8 @@
 double gradient_tanh(Neuron*);
 double activation_leaky_relu(Neuron*);
 double gradient_leaky_relu(Neuron*);
+double activation_piece(Neuron*);
+double gradient_piece(Neuron*);
 
 #define ACTIVATION activation_leaky_relu
 #define GRADIENT gradient_leaky_relu
@@ -120,6 +122,22 @@
 	return 0.01;
 }
 
+double
+activation_piece(Neuron *in)	/* clamp in->sum to [0.0, 1.0]; this patch makes it the output-layer activation */
+{
+	if (in->sum < 0.0)	/* below range: saturate at 0 */
+		return 0.0;
+	else if (in->sum > 1.0)	/* above range: saturate at 1 */
+		return 1.0;
+	return in->sum;	/* otherwise pass the weighted sum through unchanged */
+}
+
+double
+gradient_piece(Neuron*)	/* gradient paired with activation_piece; the neuron argument is not used */
+{
+	return 1.0;	/* constant 1.0 even where activation_piece saturates (its true slope is 0 there); NOTE(review): presumably deliberate so saturated outputs still receive gradient — confirm */
+}
+
 Weights*
 weightsinitdoubles(Weights *in, double *init)
 {
@@ -263,7 +281,7 @@
 		if (i < (num_layers-1))
 			ret->layers[i] = layercreate(arg, ACTIVATION, GRADIENT);
 		else
-			ret->layers[i] = layercreate(arg, activation_sigmoid, gradient_sigmoid);
+			ret->layers[i] = layercreate(arg, activation_piece, gradient_piece);
 		if (i > 0) {
 			ret->weights[i-1] = weightscreate(ret->layers[i-1]->n, ret->layers[i]->n, 1);
 			ret->deltas[i-1] = weightscreate(ret->layers[i-1]->n, ret->layers[i]->n, 0);
@@ -825,7 +843,12 @@
 				if (trainline == 0) {
 					runoutput = annrun(ann, input);
 					for (i = 0; i < noutput; i++)
-						print("%f%c", runoutput[i], (i == (noutput-1))? '\n': ' ');
+						if (runoutput[i] == 0.0)
+							print("0%c", (i == (noutput-1))? '\n': ' ');
+						else if (runoutput[i] == 1.0)
+							print("1%c", (i == (noutput-1))? '\n': ' ');
+						else
+							print("%f%c", runoutput[i], (i == (noutput-1))? '\n': ' ');
 					free(runoutput);
 				}
 
--