/**
 * bp372trainer.c
 *
 * A command-line program that trains a backprop network on a compiled-in
 * training set and optionally saves the resulting weights.
 */

#include "backprop.h"
#include "bp372.h"
#include "training_data.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>

// Default upper bound on the number of training epochs (override with -e).
#define MAX_EPOCHS	(1000)
// Default maximum allowed error in an output vector (override with -t).
#define TOLERANCE	(0.1)

/**
 * Trains the global backprop network on the compiled-in training set
 * (inputs X, targets Y) and optionally saves the trained weights.
 *
 * Command-line options (each takes exactly one value argument):
 *   -i <file>   load weights from a binary file after random initialization
 *   -t <tol>    convergence tolerance (default TOLERANCE)
 *   -e <n>      maximum number of training epochs (default MAX_EPOCHS)
 *   -l <rate>   network learning rate
 *   -f <file>   save trained weights; a name containing ".h" is written in
 *               HDT format, any other name in binary format
 *
 * Training stops when the largest per-pattern error returned by
 * TrainNetwork during an epoch is <= the tolerance, or when the epoch
 * limit is reached.
 *
 * Returns 0 on success, EXIT_FAILURE if a numeric option cannot be parsed.
 */
int main(int argc, char* argv[])
{
	int i, j, k;
	int converged;
	int max_epochs;
	double error, sample_error, tolerance;
	char* save_filename = NULL;
	char* end;

	// Default settings
	tolerance = TOLERANCE;
	max_epochs = MAX_EPOCHS;
	RandomizeNetwork(&network);

	// Parse options.  argv[0] is the program name, so scanning starts at 1.
	// Each recognized option consumes the following argument as its value,
	// and that value is skipped so it is never re-parsed as an option.
	for (i = 1; i < argc - 1; i++)
	{
		const char* opt = argv[i];
		char* value = argv[i + 1];

		if (strcmp("-i", opt) == 0)
		{
			printf("Loading binary data from file %s\n", value);
			LoadNetworkWeights(&network, value);
		}
		else if (strcmp("-t", opt) == 0)
		{
			tolerance = strtod(value, &end);
			if (end == value || *end != '\0')
			{
				fprintf(stderr, "invalid tolerance '%s'\n", value);
				return EXIT_FAILURE;
			}
			printf("Tolerance %f\n", tolerance);
		}
		else if (strcmp("-e", opt) == 0)
		{
			long epochs = strtol(value, &end, 0);
			if (end == value || *end != '\0')
			{
				fprintf(stderr, "invalid epoch count '%s'\n", value);
				return EXIT_FAILURE;
			}
			max_epochs = (int) epochs;
			printf("Maximum epochs %d\n", max_epochs);
		}
		else if (strcmp("-l", opt) == 0)
		{
			double lr = strtod(value, &end);
			if (end == value || *end != '\0')
			{
				fprintf(stderr, "invalid learning rate '%s'\n", value);
				return EXIT_FAILURE;
			}
			network.lr = lr;
			printf("Network learning rate %f\n", network.lr);
		}
		else if (strcmp("-f", opt) == 0)
		{
			// Saving happens after training; the last -f given wins.
			save_filename = value;
		}
		else
		{
			// Unrecognized argument: do not consume the next argument.
			continue;
		}
		i++;	// skip the option's value
	}

	// Input data to network and activate
	printf("\nInitial state\n");
	for (j = 0; j < TRAINING_SET_SIZE; j++)
	{
		InputToNetwork(&network, &X[j][0]);
		ActivateNetwork(&network);
		PrintNetworkOutput(&network);
	}

	// Train the network in "batches" or "epochs": each input/output pair
	// is applied once per epoch, and epochs are repeated until every
	// pattern's error is within tolerance or the epoch limit is reached.
	// Tracking the maximum (rather than the last pattern's error) ensures
	// no pattern is left outside tolerance when training stops.
	printf("\nBatch Training ... ");
	error = 0.0;
	converged = 0;
	for (k = 0; k < max_epochs && !converged; k++)
	{
		error = 0.0;
		for (j = 0; j < TRAINING_SET_SIZE; j++)
		{
			InputToNetwork(&network, &X[j][0]);
			ActivateNetwork(&network);
			sample_error = TrainNetwork(&network, &Y[j][0]);
			if (sample_error > error)
			{
				error = sample_error;
			}
		}
		converged = (error <= tolerance);
	}

	// Display results of training (k is the number of epochs run)
	printf("done\nMaximum pattern error = %f\n# of epochs %d\n", error, k);
	if (!converged)
	{
		printf("warning: maximum epochs reached\n");
	}

	// Input data to network and activate
	printf("\nAfter training\n");
	for (j = 0; j < TRAINING_SET_SIZE; j++)
	{
		InputToNetwork(&network, &X[j][0]);
		ActivateNetwork(&network);
		PrintNetworkOutput(&network);
	}

	// Write weight data to a file if one was requested with -f
	if (save_filename != NULL)
	{
		printf("\nSaving to file %s ", save_filename);
		// NOTE: strstr matches ".h" anywhere in the name (e.g. "w.hdt",
		// "page.html"), not only as a suffix.
		if (strstr(save_filename, ".h") != NULL)
		{
			printf("in HDT format\n");
			SaveNetworkWeightsHDT(&network, save_filename);
		}
		else
		{
			printf("in binary format\n");
			SaveNetworkWeights(&network, save_filename);
		}
	}

	return 0;
}



