// ciir.umass.edu.learning.neuralnet.ListNet -- (artifact header from a Maven/Gradle/Ivy repository page; commented out so the file parses)
/*===============================================================================
* Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved.
*
* Use of the RankLib package is subject to the terms of the software license set
* forth in the LICENSE file included with this software, and also available at
* http://people.cs.umass.edu/~vdang/ranklib_license.html
*===============================================================================
*/
package ciir.umass.edu.learning.neuralnet;
import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
public class ListNet extends RankNet {
//Parameters
//NOTE(review): mutable public statics -- they configure training process-wide,
//not per-instance; presumably the command-line driver sets them before init()
//is called -- verify against the caller.
public static int nIteration = 1500;
// Step size; copied into Neuron.learningRate inside init() (see below).
public static double learningRate = 0.00001;
public static int nHiddenLayer = 0;//FIXED, it doesn't work with hidden layer
/**
 * No-arg constructor: produces an uninitialized ranker. Used by createNew()
 * (and hence by the framework's model-loading path).
 */
public ListNet()
{
}
/**
 * Constructs a ListNet ranker ready for training.
 * Restores the generic element type that was lost when this file was scraped
 * (the raw {@code List} here was originally {@code List<RankList>} upstream);
 * callers passing a {@code List<RankList>} are unaffected.
 *
 * @param samples  training data: one RankList per query
 * @param features indices of the features the net will use
 * @param scorer   metric used to report training/validation performance
 */
public ListNet(List<RankList> samples, int[] features, MetricScorer scorer)
{
    super(samples, features, scorer);
}
protected float[] feedForward(RankList rl)
{
float[] labels = new float[rl.size()];
for(int i=0;i lastError && Neuron.learningRate > 0.0000001)
//Neuron.learningRate *= 0.9;
lastError = error;
}
public void init()
{
PRINT("Initializing... ");
//Set up the network
setInputOutput(features.length, 1, 1);
wire();
if(validationSamples != null)
for(int i=0;i());
Neuron.learningRate = learningRate;
PRINTLN("[Done]");
}
public void learn()
{
PRINTLN("-----------------------------------------");
PRINTLN("Training starts...");
PRINTLN("--------------------------------------------------");
PRINTLN(new int[]{7, 14, 9, 9}, new String[]{"#epoch", "C.E. Loss", scorer.name()+"-T", scorer.name()+"-V"});
PRINTLN("--------------------------------------------------");
for(int i=1;i<=nIteration;i++)
{
for(int j=0;j bestScoreOnValidationData)
{
bestScoreOnValidationData = score;
saveBestModelOnValidation();
}
PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4)+""});
}
}
PRINTLN("");
}
//if validation data is specified ==> best model on this data has been saved
//we now restore the current model to that best model
if(validationSamples != null)
restoreBestModelOnValidation();
scoreOnTrainingData = SimpleMath.round(scorer.score(rank(samples)), 4);
PRINTLN("--------------------------------------------------");
PRINTLN("Finished sucessfully.");
PRINTLN(scorer.name() + " on training data: " + scoreOnTrainingData);
if(validationSamples != null)
{
bestScoreOnValidationData = scorer.score(rank(validationSamples));
PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
}
PRINTLN("---------------------------------");
}
/**
 * Scores a single data point via the inherited RankNet forward pass.
 * NOTE(review): this override adds no behavior over super.eval(p); presumably
 * kept for documentation/symmetry with the other overrides.
 */
public double eval(DataPoint p)
{
return super.eval(p);
}
/**
 * Factory method required by the Ranker framework: hands back a fresh,
 * untrained ListNet instance.
 */
public Ranker createNew()
{
    ListNet freshRanker = new ListNet();
    return freshRanker;
}
/**
 * Textual dump of the network weights, delegated unchanged to RankNet.
 * NOTE(review): this override adds no behavior over super.toString().
 */
public String toString()
{
return super.toString();
}
public String model()
{
String output = "## " + name() + "\n";
output += "## Epochs = " + nIteration + "\n";
output += "## No. of features = " + features.length + "\n";
//print used features
for(int i=0;i l = new ArrayList();
while((content = in.readLine()) != null)
{
content = content.trim();
if(content.length() == 0)
continue;
if(content.indexOf("##")==0)
continue;
l.add(content);
}
in.close();
//load the network
//the first line contains features information
String[] tmp = l.get(0).split(" ");
features = new int[tmp.length];
for(int i=0;i