All downloads are free. Search and download functionality uses the official Maven repository.

ciir.umass.edu.learning.neuralnet.ListNet Maven / Gradle / Ivy

/*===============================================================================
 * Copyright (c) 2010-2012 University of Massachusetts.  All Rights Reserved.
 *
 * Use of the RankLib package is subject to the terms of the software license set 
 * forth in the LICENSE file included with this software, and also available at
 * http://people.cs.umass.edu/~vdang/ranklib_license.html
 *===============================================================================
 */

package ciir.umass.edu.learning.neuralnet;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

public class ListNet extends RankNet {
	
	//Parameters (static: shared by every ListNet instance)
	/** Number of training epochs run by {@link #learn()}; also written to the header of {@link #model()}. */
	public static int nIteration = 1500;
	/** Step size copied into {@code Neuron.learningRate} by {@link #init()}. */
	public static double learningRate = 0.00001; 
	/** Number of hidden layers; kept at 0 because, per the original note, ListNet does not work with a hidden layer. */
	public static int nHiddenLayer = 0;//FIXED, it doesn't work with hidden layer
	
	/**
	 * Creates a ListNet with no training data attached.
	 * This is the constructor used by {@link #createNew()}.
	 */
	public ListNet() {
		super();
	}
	/**
	 * Creates a ListNet bound to a training set.
	 *
	 * @param samples  training data (raw {@code List}; presumably {@code RankList}
	 *                 elements, since {@code learn()} ranks it — TODO confirm)
	 * @param features feature ids; their count is passed to the network setup in {@code init()}
	 * @param scorer   metric used to report training and validation scores
	 */
	public ListNet(List samples, int[] features, MetricScorer scorer) {
		super(samples, features, scorer);
	}
	
	/**
	 * Feed-forward pass over the items of a ranked list.
	 *
	 * NOTE(review): this span is corrupted in this copy of the file. The loop
	 * header below is fused with an unrelated condition (the text between a '<'
	 * and a later '>' appears to have been lost during extraction), so the rest
	 * of this method and the start of the following method are missing. The
	 * closing lines appear to belong to a different method that records the
	 * previous epoch's error. Restore the missing code from upstream before building.
	 */
	protected float[] feedForward(RankList rl)
	{
		float[] labels = new float[rl.size()]; // one slot per item in rl
		for(int i=0;i lastError && Neuron.learningRate > 0.0000001)
			//Neuron.learningRate *= 0.9;
		// Remember this error for comparison on the next call (learning-rate decay above is commented out).
		lastError = error;
	}
	
	/**
	 * Prepares the model for training. It sets up the network via
	 * {@code setInputOutput(features.length, 1, 1)} and {@code wire()}, then copies
	 * this class's {@code learningRate} into the static {@code Neuron.learningRate}.
	 *
	 * NOTE(review): the loop over {@code validationSamples} below is corrupted in
	 * this copy (its bound and body appear to have been lost during extraction);
	 * restore it from upstream before building.
	 */
	public void init()
	{
		PRINT("Initializing... ");
		
		//Set up the network
		setInputOutput(features.length, 1, 1);
		wire();
		
		if(validationSamples != null)
			for(int i=0;i());
		
		// Static field: affects every Neuron, not just this model's.
		Neuron.learningRate = learningRate;
		PRINTLN("[Done]");
	}
	/**
	 * Trains the network for {@code nIteration} epochs and prints one table row per epoch.
	 * The columns are "#epoch", "C.E. Loss", then the scorer's metric on training
	 * ("-T") and validation ("-V") data.
	 *
	 * When validation data is present, the model is snapshotted each time the
	 * validation score apparently beats {@code bestScoreOnValidationData}. That
	 * best snapshot is restored after the last epoch. Finally, the method
	 * re-scores and prints the metric on the training data (and on the
	 * validation data, if any).
	 *
	 * NOTE(review): the per-epoch loop is corrupted in this copy. The inner loop
	 * header is fused with the later validation-score comparison, so the training
	 * step and the per-epoch loss/score computation are missing. Restore them from
	 * upstream before building.
	 */
	public void learn()
	{
		PRINTLN("-----------------------------------------");
		PRINTLN("Training starts...");
		PRINTLN("--------------------------------------------------");
		PRINTLN(new int[]{7, 14, 9, 9}, new String[]{"#epoch", "C.E. Loss", scorer.name()+"-T", scorer.name()+"-V"});
		PRINTLN("--------------------------------------------------");
		
		// One table row per epoch, numbered from 1.
		for(int i=1;i<=nIteration;i++)
		{
			// NOTE(review): corrupted line — inner loop header fused with the validation-score check.
			for(int j=0;j bestScoreOnValidationData)
					{
						// New best validation score: remember it and snapshot the model.
						bestScoreOnValidationData = score;
						saveBestModelOnValidation();
					}
					PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4)+""});
				}
			}
			PRINTLN("");
		}
		
		//if validation data is specified ==> best model on this data has been saved
		//we now restore the current model to that best model
		if(validationSamples != null)
			restoreBestModelOnValidation();
		
		// Stored rounded to 4 decimals (the validation score below is stored unrounded).
		scoreOnTrainingData = SimpleMath.round(scorer.score(rank(samples)), 4);
		PRINTLN("--------------------------------------------------");
		// NOTE(review): "sucessfully" is misspelled in this log message.
		PRINTLN("Finished sucessfully.");
		PRINTLN(scorer.name() + " on training data: " + scoreOnTrainingData);
		if(validationSamples != null)
		{
			// Re-score validation with the restored best model.
			bestScoreOnValidationData = scorer.score(rank(validationSamples));
			PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
		}
		PRINTLN("---------------------------------");
	}
	/**
	 * Scores a single data point. ListNet reuses the inherited RankNet evaluation unchanged.
	 *
	 * @param p the data point to score
	 * @return the score computed by {@code RankNet.eval(p)}
	 */
	@Override
	public double eval(DataPoint p)
	{
		return super.eval(p);
	}
	/**
	 * Factory for a fresh, untrained ListNet with no data attached.
	 *
	 * @return a new {@code ListNet} created with the no-arg constructor
	 */
	@Override
	public Ranker createNew()
	{
		return new ListNet();
	}
	/**
	 * Returns the string form of this model. ListNet adds nothing beyond the
	 * representation inherited from RankNet.
	 *
	 * @return {@code super.toString()}
	 */
	@Override
	public String toString()
	{
		return super.toString();
	}
	public String model()
	{
		String output = "## " + name() + "\n";
		output += "## Epochs = " + nIteration + "\n";
		output += "## No. of features = " + features.length + "\n";
		
		//print used features
		for(int i=0;i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy