shithub: moonfish

ref: 2d4418c24ad95172831580fe71bb00259eb716d0
dir: /tools/learn.c/

View raw version
/* moonfish is licensed under the AGPL (v3 or later) */
/* copyright 2024 zamfofex */

#include <errno.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#include "../moonfish.h"

/* number of leading entries of moonfish_values that this tool tunes and prints */
/* NOTE(review): presumably must not exceed the size of moonfish_values declared in moonfish.h — confirm */
#define moonfish_count 192

/* Reads the next "<FEN> <score>" line of "file" into "line" and returns the
 * score, i.e. the number after the last space of the line.
 * "line" must have room for at least 2048 bytes.
 * When the end of "file" is reached, the file is rewound and its first line
 * is read instead, so callers can cycle through the data indefinitely.
 * Exits the program on read errors, on an empty file, on lines too long for
 * the buffer, and on lines whose score is missing, not a number, or outside
 * of [-10000, 10000]. */
static double moonfish_next_line(char *line, FILE *file)
{
	char *arg, *end;
	double score;
	
	if (fgets(line, 2048, file) == NULL) {
		
		/* it is "file" that failed to be read, so check its end-of-file indicator (not stdin's) */
		if (!feof(file)) {
			perror("fgets");
			exit(1);
		}
		
		/* "rewind" returns nothing, so its errors can only be detected through "errno" */
		errno = 0;
		rewind(file);
		if (errno) {
			perror("rewind");
			exit(1);
		}
		
		/* a failed "fgets" leaves "line" untouched, so actually read the first line again */
		if (fgets(line, 2048, file) == NULL) {
			if (feof(file)) fprintf(stderr, "empty FEN file\n");
			else perror("fgets");
			exit(1);
		}
	}
	
	/* a line without a newline that is not the last line of the file did not fit in the buffer */
	if (strchr(line, '\n') == NULL && !feof(file)) {
		fprintf(stderr, "FEN line too long\n");
		exit(1);
	}
	
	arg = strrchr(line, ' ');
	if (arg == NULL) {
		fprintf(stderr, "malformed FEN line\n");
		exit(1);
	}
	
	errno = 0;
	score = strtod(arg + 1, &end);
	/* "end == arg + 1" rejects a missing score; the negated range test also rejects NaN */
	if (errno || end == arg + 1 || (*end != 0 && *end != '\n') || !(score >= -10000 && score <= 10000)) {
		fprintf(stderr, "unexpected score\n");
		exit(1);
	}
	
	return score;
}

/* Evaluates the position "fen" with moonfish_score and compares it with the
 * target score "score0". Then, for each of the first moonfish_count entries
 * of moonfish_values, it does the following:
 *   - nudges the entry by 1.0 / 256 / 256 / 8 (i.e. 2^-19),
 *   - re-evaluates the position,
 *   - adds (change in score) * 256 * 256 * (score - score0) to the
 *     matching entry of "gradient",
 *   - restores the entry.
 * Since the score change divided by the nudge approximates the derivative,
 * each entry gains (score - score0) * d(score)/d(value) / 8, a scaled
 * finite-difference gradient of the squared error.
 * Returns |score - score0| for the unmodified values.
 * NOTE(review): assumes moonfish_from_fen fully resets "chess", since the
 * same struct is reused for every re-evaluation — confirm. */
static double moonfish_gradient(double *gradient, double score0, char *fen)
{
	struct moonfish_chess chess;
	double base, diff, saved;
	int k;
	
	moonfish_chess(&chess);
	moonfish_from_fen(&chess, fen);
	base = moonfish_score(&chess);
	diff = base - score0;
	
	for (k = 0 ; k < moonfish_count ; k++) {
		/* perturb one parameter, measure its effect on the score, then put it back */
		saved = moonfish_values[k];
		moonfish_values[k] += 1.0 / 256 / 256 / 8;
		moonfish_from_fen(&chess, fen);
		gradient[k] += (moonfish_score(&chess) - base) * 256 * 256 * diff;
		moonfish_values[k] = saved;
	}
	
	return diff < 0 ? -diff : diff;
}

/* Runs one training step on the next 2048 lines read from "file".
 * The step proceeds as follows:
 *   - clears the first moonfish_count entries of "gradient",
 *   - accumulates moonfish_gradient over every line of the batch,
 *   - subtracts the batch average of the gradient from moonfish_values.
 * Returns the sum of the absolute errors reported by moonfish_gradient for
 * the batch. These are measured before the values are updated, since the
 * update happens only after the whole batch has been processed. */
static double moonfish_step(FILE *file, double *gradient)
{
	/* buffer handed to moonfish_next_line, which reads up to 2048 bytes into it */
	static char line[2048];
	
	double total, target;
	int count, j;
	
	for (j = 0 ; j < moonfish_count ; j++) gradient[j] = 0;
	
	total = 0;
	for (count = 2048 ; count > 0 ; count--) {
		target = moonfish_next_line(line, file);
		total += moonfish_gradient(gradient, target, line);
	}
	
	/* average over the 2048 positions of the batch */
	for (j = 0 ; j < moonfish_count ; j++) moonfish_values[j] -= gradient[j] / 2048;
	
	return total;
}

/* Tunes moonfish_values against the "<FEN> <score>" lines of the file named
 * on the command line. It runs 0x1000 training steps. After each step it
 * prints the current values (comma-separated, on their own line), the
 * iteration number, and the batch error. */
int main(int argc, char **argv)
{
	static double gradient[moonfish_count];
	
	FILE *file;
	double error;
	int iteration, i;
	
	if (argc != 2) {
		if (argc > 0) fprintf(stderr, "usage: %s <file-name>\n", argv[0]);
		return 1;
	}
	
	file = fopen(argv[1], "r");
	if (file == NULL) {
		perror("fopen");
		return 1;
	}
	
	for (iteration = 1 ; iteration <= 0x1000 ; iteration++) {
		
		error = moonfish_step(file, gradient);
		
		printf("\n");
		for (i = 0 ; i < moonfish_count ; i++) printf("%.0f,", moonfish_values[i]);
		printf("\n");
		
		printf("iteration: %d\n", iteration);
		printf("current error: ");
		
		/* abbreviate large errors with an M (millions) or K (thousands) suffix */
		if (error > 10000 * 1000) printf("%.0fM\n", error / 1000 / 1000);
		else if (error > 10000) printf("%.0fK\n", error / 1000);
		else printf("%.0f\n", error);
	}
	
	return 0;
}