/*===============================================================================
* Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved.
*
* Use of the RankLib package is subject to the terms of the software license set
* forth in the LICENSE file included with this software, and also available at
* http://people.cs.umass.edu/~vdang/ranklib_license.html
*===============================================================================
*/
package ciir.umass.edu.learning.tree;
import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* @author vdang
*
* This class implements LambdaMART.
* Q. Wu, C.J.C. Burges, K. Svore and J. Gao. Adapting Boosting for Information Retrieval Measures.
* Information Retrieval Journal, 13(3), 2010.
*/
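/*
 * Example usage (a minimal sketch, not part of the original file; it assumes
 * RankLib's FeatureManager in ciir.umass.edu.features and NDCGScorer in
 * ciir.umass.edu.metric, and LETOR-formatted input files):
 *
 *   List<RankList> train = FeatureManager.readInput("train.txt");
 *   List<RankList> validation = FeatureManager.readInput("vali.txt");
 *   int[] features = FeatureManager.getFeatureFromSampleVector(train);//use every feature seen in the data
 *
 *   LambdaMART ranker = new LambdaMART(train, features, new NDCGScorer());
 *   ranker.setValidationSet(validation);//enables the early stopping used in learn()
 *   ranker.init();
 *   ranker.learn();
 *   System.out.println(ranker.model());//serialized ensemble; see model() below
 */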
public class LambdaMART extends Ranker {
//Parameters
public static int nTrees = 1000;//the number of trees
public static float learningRate = 0.1F;//or shrinkage
public static int nThreshold = 256;
public static int nRoundToStopEarly = 100;//stop training early if no performance gain on the *VALIDATION* data is observed in this many consecutive rounds
public static int nTreeLeaves = 10;
public static int minLeafSupport = 1;
//for debugging
public static int gcCycle = 100;
//Local variables
protected float[][] thresholds = null;
protected Ensemble ensemble = null;
protected double[] modelScores = null;//on training data
protected double[][] modelScoresOnValidation = null;
protected int bestModelOnValidation = Integer.MAX_VALUE-2;
//Training instances prepared for MART
protected DataPoint[] martSamples = null;//Need initializing only once
protected int[][] sortedIdx = null;//sorted list of samples in @martSamples by each feature -- Need initializing only once
protected FeatureHistogram hist = null;
protected double[] pseudoResponses = null;//different for each iteration
protected double[] weights = null;//different for each iteration
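/*
 * Note that the parameters above are plain public static fields, so they are set
 * once per process rather than per instance, e.g. (illustrative values only):
 *
 *   LambdaMART.nTrees = 500;
 *   LambdaMART.learningRate = 0.05F;
 *   LambdaMART.nTreeLeaves = 15;
 */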
public LambdaMART()
{
}
public LambdaMART(List<RankList> samples, int[] features, MetricScorer scorer)
{
super(samples, features, scorer);
}
public void init()
{
PRINT("Initializing... ");
//initialize samples for MART
int dpCount = 0;
for(int i=0;i<samples.size();i++)
	dpCount += samples.get(i).size();
int current = 0;
martSamples = new DataPoint[dpCount];
modelScores = new double[dpCount];
pseudoResponses = new double[dpCount];
weights = new double[dpCount];
for(int i=0;i<samples.size();i++)
{
	RankList rl = samples.get(i);
	for(int j=0;j<rl.size();j++)
	{
		martSamples[current+j] = rl.get(j);
		modelScores[current+j] = 0.0F;
		pseudoResponses[current+j] = 0.0F;
		weights[current+j] = 0;
	}
	current += rl.size();
}
//sort samples by each feature so that we can quickly retrieve a sorted list of samples by any feature later on
sortedIdx = new int[features.length][];
MyThreadPool p = MyThreadPool.getInstance();
if(p.size() == 1)//single-thread
	sortSamplesByFeature(0, features.length-1);
else//multi-thread
{
	int[] partition = p.partition(features.length);
	for(int i=0;i<partition.length-1;i++)
		p.execute(new SortWorker(this, partition[i], partition[i+1]-1));
	p.await();
}
//Create a table of candidate thresholds for each feature; the best tree split will be selected from these candidates
thresholds = new float[features.length][];
for(int f=0;f<features.length;f++)
{
//For this feature, keep track of the list of unique values and the max/min
List<Float> values = new ArrayList<Float>();
float fmax = Float.NEGATIVE_INFINITY;
float fmin = Float.MAX_VALUE;
for(int i=0;i<martSamples.length;i++)
{
int k = sortedIdx[f][i];//get samples sorted with respect to this feature
float fv = martSamples[k].getFeatureValue(features[f]);
values.add(fv);
if(fmax < fv)
	fmax = fv;
if(fmin > fv)
	fmin = fv;
//skip all samples with the same feature value
int j=i+1;
while(j < martSamples.length)
{
if(martSamples[sortedIdx[f][j]].getFeatureValue(features[f]) > fv)
break;
j++;
}
i = j-1;//[i, j] gives the range of samples with the same feature value
}
if(values.size() <= nThreshold || nThreshold == -1)
{
thresholds[f] = new float[values.size()+1];
for(int i=0;i<values.size();i++)
	thresholds[f][i] = values.get(i);
thresholds[f][values.size()] = Float.MAX_VALUE;
}
else
{
	float step = (Math.abs(fmax - fmin))/nThreshold;
	thresholds[f] = new float[nThreshold+1];
	thresholds[f][0] = fmin;
	for(int j=1;j<nThreshold;j++)
		thresholds[f][j] = thresholds[f][j-1] + step;
	thresholds[f][nThreshold] = Float.MAX_VALUE;
}
}
if(validationSamples != null)
{
	modelScoresOnValidation = new double[validationSamples.size()][];
	for(int i=0;i<validationSamples.size();i++)
	{
		modelScoresOnValidation[i] = new double[validationSamples.get(i).size()];
		Arrays.fill(modelScoresOnValidation[i], 0);
	}
}
//compute the feature histogram (used to speed up finding the best tree split later on)
hist = new FeatureHistogram();
hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);
//we no longer need the sorted indexes of samples
sortedIdx = null;
System.gc();
PRINTLN("[Done]");
}
public void learn()
{
ensemble = new Ensemble();
PRINTLN("---------------------------------");
PRINTLN("Training starts...");
PRINTLN("---------------------------------");
PRINTLN(new int[]{7, 9, 9}, new String[]{"#iter", scorer.name()+"-T", scorer.name()+"-V"});
PRINTLN("---------------------------------");
//Start the gradient boosting process
for(int m=0; m<nTrees; m++)
{
PRINT(new int[]{7}, new String[]{(m+1)+""});
//Compute lambdas, which act as the "pseudo responses": each document is a training
//sample and its lambda serves as its (regression) training label
computePseudoResponses();
//update the histogram with these pseudo responses
hist.update(pseudoResponses);
//Fit a regression tree to the pseudo responses
RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, minLeafSupport);
rt.fit();
//Add this tree to the ensemble (our model)
ensemble.add(rt, learningRate);
//update the outputs of the tree (with gamma computed using the Newton-Raphson method)
updateTreeOutput(rt);
//Update the model's outputs on all training samples
List<Split> leaves = rt.leaves();
for(int i=0;i<leaves.size();i++)
{
	Split s = leaves.get(i);
	int[] idx = s.getSamples();
	for(int j=0;j<idx.length;j++)
		modelScores[idx[j]] += learningRate * s.getOutput();
}
//clear references to data that is no longer used
rt.clearSamples();
//beg the garbage collector to work...
if(m % gcCycle == 0)
	System.gc();//this call is expensive, so we shouldn't do it too often
//Evaluate the current model on the training data; equivalent to
//scoreOnTrainingData = scorer.score(rank(samples)), but more efficient because it
//reuses the cached model outputs instead of re-evaluating the model on the entire training set
scoreOnTrainingData = computeModelScoreOnTraining();
PRINT(new int[]{9}, new String[]{SimpleMath.round(scoreOnTrainingData, 4) + ""});
//Evaluate the current model on the validation data (if available)
if(validationSamples != null)
{
	//Update the cached model scores on all validation samples
	for(int i=0;i<modelScoresOnValidation.length;i++)
		for(int j=0;j<modelScoresOnValidation[i].length;j++)
			modelScoresOnValidation[i][j] += learningRate * rt.eval(validationSamples.get(i).get(j));
	//again, equivalent to scorer.score(rank(validationSamples)), but uses the cached outputs
	double score = computeModelScoreOnValidation();
	PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
	if(score > bestScoreOnValidationData)
{
bestScoreOnValidationData = score;
bestModelOnValidation = ensemble.treeCount()-1;
}
}
PRINTLN("");
//Should we stop early?
if(m - bestModelOnValidation > nRoundToStopEarly)
break;
}
//Rollback to the best model observed on the validation data
while(ensemble.treeCount() > bestModelOnValidation+1)
ensemble.remove(ensemble.treeCount()-1);
//Finishing up
scoreOnTrainingData = scorer.score(rank(samples));
PRINTLN("---------------------------------");
PRINTLN("Finished sucessfully.");
PRINTLN(scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4));
if(validationSamples != null)
{
bestScoreOnValidationData = scorer.score(rank(validationSamples));
PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
}
PRINTLN("---------------------------------");
}
public double eval(DataPoint dp)
{
return ensemble.eval(dp);
}
public Ranker createNew()
{
return new LambdaMART();
}
public String toString()
{
return ensemble.toString();
}
public String model()
{
String output = "## " + name() + "\n";
output += "## No. of trees = " + nTrees + "\n";
output += "## No. of leaves = " + nTreeLeaves + "\n";
output += "## No. of threshold candidates = " + nThreshold + "\n";
output += "## Learning rate = " + learningRate + "\n";
output += "## Stop early = " + nRoundToStopEarly + "\n";
output += "\n";
output += toString();
return output;
}
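/*
 * For reference, model() produces a "##"-commented header followed by the XML
 * serialization of the trees from Ensemble.toString(); with the defaults above
 * the header reads (illustrative):
 *
 *   ## LambdaMART
 *   ## No. of trees = 1000
 *   ## No. of leaves = 10
 *   ## No. of threshold candidates = 256
 *   ## Learning rate = 0.1
 *   ## Stop early = 100
 *
 * loadFromString() below skips exactly these "##" lines when reading a model back in.
 */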
@Override
public void loadFromString(String fullText)
{
try {
String content = "";
StringBuffer model = new StringBuffer();
BufferedReader in = new BufferedReader(new StringReader(fullText));
while((content = in.readLine()) != null)
{
content = content.trim();
if(content.length() == 0)
continue;
if(content.indexOf("##")==0)
continue;
//actual model component
model.append(content);
}
in.close();
//load the ensemble
ensemble = new Ensemble(model.toString());
features = ensemble.getFeatures();
}
catch(Exception ex)
{
throw RankLibError.create("Error in LambdaMART::load(): ", ex);
}
}
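/*
 * Example round-trip (a sketch; FileUtils.read is RankLib's file helper in
 * ciir.umass.edu.utilities, and "lambdamart.model" is a hypothetical path):
 *
 *   LambdaMART m = new LambdaMART();
 *   m.loadFromString(FileUtils.read("lambdamart.model", "ASCII"));
 *   double s = m.eval(someDataPoint);//score an unseen document with the loaded ensemble
 */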
public void printParameters()
{
PRINTLN("No. of trees: " + nTrees);
PRINTLN("No. of leaves: " + nTreeLeaves);
PRINTLN("No. of threshold candidates: " + nThreshold);
PRINTLN("Min leaf support: " + minLeafSupport);
PRINTLN("Learning rate: " + learningRate);
PRINTLN("Stop early: " + nRoundToStopEarly + " rounds without performance gain on validation data");
}
public String name()
{
return "LambdaMART";
}
public Ensemble getEnsemble()
{
return ensemble;
}
protected void computePseudoResponses()
{
Arrays.fill(pseudoResponses, 0F);
Arrays.fill(weights, 0);
MyThreadPool p = MyThreadPool.getInstance();
if(p.size() == 1)//single-thread
computePseudoResponses(0, samples.size()-1, 0);
else //multi-threading
{
List<LambdaComputationWorker> workers = new ArrayList<LambdaComputationWorker>();
//divide the entire dataset into chunks of equal size for each worker thread
int[] partition = p.partition(samples.size());
int current = 0;
for(int i=0;i<partition.length-1;i++)
{
	//dispatch the worker
	LambdaComputationWorker wk = new LambdaComputationWorker(this, partition[i], partition[i+1]-1, current);
	workers.add(wk);//keep it so we can get back results from it later on
	p.execute(wk);
	if(i < partition.length-2)
		for(int j=partition[i]; j<=partition[i+1]-1; j++)
			current += samples.get(j).size();
}
//wait for all workers to complete before we move on to the next stage
p.await();
}
}
protected void computePseudoResponses(int start, int end, int current)
{
int cutoff = scorer.getK();
//compute the lambda for each document (a.k.a. its "pseudo response")
for(int i=start;i<=end;i++)
{
RankList orig = samples.get(i);
int[] idx = MergeSorter.sort(modelScores, current, current+orig.size()-1, false);
RankList rl = new RankList(orig, idx, current);
double[][] changes = scorer.swapChange(rl);
//NOTE: the entries of "changes" are indexed by rank in the sorted list, not the original one => need to map back with idx[j] and idx[k]
for(int j=0;j<rl.size();j++)
{
DataPoint p1 = rl.get(j);
int mj = idx[j];
for(int k=0;k<rl.size();k++)
{
if(j > cutoff && k > cutoff)//swapping this pair cannot change the target measure since both documents are below the cut-off point
break;
DataPoint p2 = rl.get(k);
int mk = idx[k];
if(p1.getLabel() > p2.getLabel())
{
double deltaNDCG = Math.abs(changes[j][k]);
if(deltaNDCG > 0)
{
double rho = 1.0 / (1 + Math.exp(modelScores[mj] - modelScores[mk]));
double lambda = rho * deltaNDCG;
pseudoResponses[mj] += lambda;
pseudoResponses[mk] -= lambda;
double delta = rho * (1.0 - rho) * deltaNDCG;
weights[mj] += delta;
weights[mk] += delta;
}
}
}
}
current += orig.size();
}
}
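/*
 * Worked example of the update above (illustrative numbers, not from the original
 * source): for documents i, j with label(i) > label(j) and equal current scores,
 *
 *   rho    = 1 / (1 + exp(modelScores[i] - modelScores[j])) = 0.5
 *   lambda = rho * |deltaNDCG|
 *   w      = rho * (1 - rho) * |deltaNDCG|
 *
 * so a swap worth |deltaNDCG| = 0.2 adds +0.1 to pseudoResponses[i], -0.1 to
 * pseudoResponses[j], and 0.05 to both weights; updateTreeOutput() then uses these
 * sums as the numerator and denominator of its Newton-Raphson leaf outputs.
 */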
protected void updateTreeOutput(RegressionTree rt)
{
List<Split> leaves = rt.leaves();
for(int i=0;i<leaves.size();i++)
{
	float s1 = 0F;//sum of lambdas (first derivatives) in this leaf
	float s2 = 0F;//sum of the second-derivative weights in this leaf
	Split s = leaves.get(i);
	int[] idx = s.getSamples();
	for(int j=0;j<idx.length;j++)
	{
		int k = idx[j];
		s1 += pseudoResponses[k];
		s2 += weights[k];
	}
	if(s2 == 0)
		s.setOutput(0);
	else
		s.setOutput(s1/s2);//one Newton-Raphson step
}
}
protected int[] sortSamplesByFeature(DataPoint[] samples, int fid)
{
	double[] score = new double[samples.length];
	for(int i=0;i<samples.length;i++)
		score[i] = samples[i].getFeatureValue(fid);
	return MergeSorter.sort(score, true);
}
protected void sortSamplesByFeature(int fStart, int fEnd)
{
	for(int i=fStart;i<=fEnd;i++)
		sortedIdx[i] = sortSamplesByFeature(martSamples, features[i]);
}
//Re-rank the rankListIndex-th training list using the cached model scores starting at offset current
protected RankList rank(int rankListIndex, int current)
{
	RankList orig = samples.get(rankListIndex);
	double[] scores = new double[orig.size()];
	for(int i=0;i<scores.length;i++)
		scores[i] = modelScores[current+i];
	int[] idx = MergeSorter.sort(scores, false);
	return new RankList(orig, idx);
}
protected double computeModelScoreOnTraining()
{
	double s = 0;
	int current = 0;
	MyThreadPool p = MyThreadPool.getInstance();
	if(p.size() == 1)//single-thread
		s = computeModelScoreOnTraining(0, samples.size()-1, 0);
	else//multi-thread
	{
		List<Worker> workers = new ArrayList<Worker>();
//divide the entire dataset into chunks of equal size for each worker thread
int[] partition = p.partition(samples.size());
for(int i=0;i<partition.length-1;i++)
{
	//dispatch the worker
	Worker wk = new Worker(this, partition[i], partition[i+1]-1, current);
	workers.add(wk);//keep it so we can get back results from it later on
	p.execute(wk);
	if(i < partition.length-2)
		for(int j=partition[i]; j<=partition[i+1]-1; j++)
			current += samples.get(j).size();
}
//wait for all workers to complete before we move on to the next stage
p.await();
for(int i=0;i<workers.size();i++)
	s += workers.get(i).score;
}
return s/samples.size();
}
protected double computeModelScoreOnTraining(int start, int end, int current)
{
	double s = 0;
	int c = current;
	for(int i=start;i<=end;i++)
	{
		s += scorer.score(rank(i, c));
		c += samples.get(i).size();
	}
	return s;
}
protected double computeModelScoreOnValidation()
{
	double score = 0;
	MyThreadPool p = MyThreadPool.getInstance();
	if(p.size() == 1)//single-thread
		score = computeModelScoreOnValidation(0, validationSamples.size()-1);
	else//multi-thread
	{
		List<Worker> workers = new ArrayList<Worker>();
//divide the entire dataset into chunks of equal size for each worker thread
int[] partition = p.partition(validationSamples.size());
for(int i=0;i<partition.length-1;i++)
{
	//dispatch the worker
	Worker wk = new Worker(this, partition[i], partition[i+1]-1);
	workers.add(wk);//keep it so we can get back results from it later on
	p.execute(wk);
}
//wait for all workers to complete before we move on to the next stage
p.await();
for(int i=0;i<workers.size();i++)
	score += workers.get(i).score;
}
return score/validationSamples.size();
}
protected double computeModelScoreOnValidation(int start, int end)
{
	double score = 0;
	for(int i=start;i<=end;i++)
	{
		int[] idx = MergeSorter.sort(modelScoresOnValidation[i], false);
		score += scorer.score(new RankList(validationSamples.get(i), idx));
	}
	return score;
}
//Code for multi-threaded processing
class SortWorker implements Runnable {
	LambdaMART ranker = null;
	int start = -1;
	int end = -1;
	SortWorker(LambdaMART ranker, int start, int end)
	{
		this.ranker = ranker;
		this.start = start;
		this.end = end;
	}
	public void run()
	{
		ranker.sortSamplesByFeature(start, end);
	}
}
class LambdaComputationWorker implements Runnable {
	LambdaMART ranker = null;
	int rlStart = -1;
	int rlEnd = -1;
	int martStart = -1;
	LambdaComputationWorker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
	{
		this.ranker = ranker;
		this.rlStart = rlStart;
		this.rlEnd = rlEnd;
		this.martStart = martStart;
	}
	public void run()
	{
		ranker.computePseudoResponses(rlStart, rlEnd, martStart);
	}
}
class Worker implements Runnable {
	LambdaMART ranker = null;
	int rlStart = -1;
	int rlEnd = -1;
	int martStart = -1;//-1 indicates a validation-scoring worker
	double score = 0.0;
	Worker(LambdaMART ranker, int rlStart, int rlEnd)
	{
		this.ranker = ranker;
		this.rlStart = rlStart;
		this.rlEnd = rlEnd;
	}
	Worker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
	{
		this(ranker, rlStart, rlEnd);
		this.martStart = martStart;
	}
	public void run()
	{
		if(martStart >= 0)
			score = ranker.computeModelScoreOnTraining(rlStart, rlEnd, martStart);
		else
			score = ranker.computeModelScoreOnValidation(rlStart, rlEnd);
	}
}
}