
/*===============================================================================
 * Copyright (c) 2010-2012 University of Massachusetts.  All Rights Reserved.
 *
 * Use of the RankLib package is subject to the terms of the software license set 
 * forth in the LICENSE file included with this software, and also available at
 * http://people.cs.umass.edu/~vdang/ranklib_license.html
 *===============================================================================
 */

package ciir.umass.edu.learning.tree;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * @author vdang
 *
 *  This class implements LambdaMART.
 *  Q. Wu, C.J.C. Burges, K. Svore and J. Gao. Adapting Boosting for Information Retrieval Measures. 
 *  Information Retrieval, 13(3): 254-270, 2010.
 */
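/*
 * How the training signal is formed: for every pair of documents (i, j) in the
 * same ranked list with label(i) > label(j), computePseudoResponses() accumulates
 *
 *     rho    = 1 / (1 + exp(s_i - s_j))     (s = current model scores)
 *     lambda = rho * |deltaM|               (|deltaM| = change in the target metric,
 *                                            e.g. NDCG, if i and j were swapped)
 *     w      = rho * (1 - rho) * |deltaM|   (second-order term)
 *
 * Each boosting iteration fits a regression tree to the per-document lambda sums
 * (the "pseudo responses") and sets each leaf output with a Newton-Raphson step
 * (see updateTreeOutput()).
 */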
public class LambdaMART extends Ranker {
	//Parameters
	public static int nTrees = 1000;//the number of trees
	public static float learningRate = 0.1F;//or shrinkage
	public static int nThreshold = 256;
	public static int nRoundToStopEarly = 100;//If no performance gain on the *VALIDATION* data is observed in #rounds, stop the training process right away. 
	public static int nTreeLeaves = 10;
	public static int minLeafSupport = 1;
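	//These static fields are normally set by RankLib's command-line Evaluator before
	//training; the usual flags (check the docs of your RankLib version) are:
	//-tree (nTrees), -leaf (nTreeLeaves), -shrinkage (learningRate),
	//-tc (nThreshold), -mls (minLeafSupport) and -estop (nRoundToStopEarly).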
	
	//for debugging
	public static int gcCycle = 100;
	
	//Local variables
	protected float[][] thresholds = null;
	protected Ensemble ensemble = null;
	protected double[] modelScores = null;//on training data
	
	protected double[][] modelScoresOnValidation = null;
	protected int bestModelOnValidation = Integer.MAX_VALUE-2;
	
	//Training instances prepared for MART
	protected DataPoint[] martSamples = null;//Need initializing only once
	protected int[][] sortedIdx = null;//sorted list of samples in @martSamples by each feature -- Need initializing only once 
	protected FeatureHistogram hist = null;
	protected double[] pseudoResponses = null;//different for each iteration
	protected double[] weights = null;//different for each iteration
	
	public LambdaMART()
	{		
	}

	public LambdaMART(List<RankList> samples, int[] features, MetricScorer scorer)
	{
		super(samples, features, scorer);
	}
	
	public void init()
	{
		PRINT("Initializing... ");		
		//initialize samples for MART
		int dpCount = 0;
		for(int i=0;i<samples.size();i++)
			dpCount += samples.get(i).size();
		int current = 0;
		martSamples = new DataPoint[dpCount];
		modelScores = new double[dpCount];
		pseudoResponses = new double[dpCount];
		weights = new double[dpCount];
		for(int i=0;i<samples.size();i++)
		{
			RankList rl = samples.get(i);
			for(int j=0;j<rl.size();j++)
			{
				martSamples[current+j] = rl.get(j);
				modelScores[current+j] = 0.0F;
				pseudoResponses[current+j] = 0.0F;
				weights[current+j] = 0;
			}
			current += rl.size();
		}
		
		//sort (MART) samples by each feature so that we can quickly retrieve a sorted list of samples by any feature later
		sortedIdx = new int[features.length][];
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)//single-thread
			sortSamplesByFeature(0, features.length-1);
		else//multi-thread
		{
			int[] partition = p.partition(features.length);
			for(int i=0;i<partition.length-1;i++)
				p.execute(new SortWorker(this, partition[i], partition[i+1]-1));
			p.await();
		}
		
		//Create a table of candidate thresholds for each feature; the best tree splits are later selected from these candidates
		thresholds = new float[features.length][];
		for(int f=0;f<features.length;f++)
		{
			//For this feature, keep track of the list of unique values and the max/min
			List<Float> values = new ArrayList<Float>();
			float fmax = Float.NEGATIVE_INFINITY;
			float fmin = Float.MAX_VALUE;
			for(int i=0;i<martSamples.length;i++)
			{
				int k = sortedIdx[f][i];//samples sorted with respect to this feature
				float fv = martSamples[k].getFeatureValue(features[f]);
				values.add(fv);
				if(fmax < fv)
					fmax = fv;
				if(fmin > fv)
					fmin = fv;
				//skip all samples with the same feature value
				int j=i+1;
				while(j < martSamples.length)
				{
					if(martSamples[sortedIdx[f][j]].getFeatureValue(features[f]) > fv)
						break;
					j++;
				}
				i = j-1;//[i, j] gives the range of samples with the same feature value
			}
			
			if(values.size() <= nThreshold || nThreshold == -1)
			{
				thresholds[f] = new float[values.size()+1];
				for(int i=0;i<values.size();i++)
					thresholds[f][i] = values.get(i);
				thresholds[f][values.size()] = Float.MAX_VALUE;
			}
			else
			{
				float step = (Math.abs(fmax - fmin))/nThreshold;
				thresholds[f] = new float[nThreshold+1];
				thresholds[f][0] = fmin;
				for(int j=1;j<nThreshold;j++)
					thresholds[f][j] = thresholds[f][j-1] + step;
				thresholds[f][nThreshold] = Float.MAX_VALUE;
			}
		}
		
		if(validationSamples != null)
		{
			modelScoresOnValidation = new double[validationSamples.size()][];
			for(int i=0;i<validationSamples.size();i++)
			{
				modelScoresOnValidation[i] = new double[validationSamples.get(i).size()];
				Arrays.fill(modelScoresOnValidation[i], 0);
			}
		}
		
		//compute the feature histogram (used to speed up the search for the best tree split later on)
		hist = new FeatureHistogram();
		hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);
		//we no longer need the sorted indexes of samples
		sortedIdx = null;
		
		System.gc();
		PRINTLN("[Done]");
	}

	public void learn()
	{
		ensemble = new Ensemble();
		
		PRINTLN("---------------------------------");
		PRINTLN("Training starts...");
		PRINTLN("---------------------------------");
		PRINTLN(new int[]{7, 9, 9}, new String[]{"#iter", scorer.name()+"-T", scorer.name()+"-V"});
		PRINTLN("---------------------------------");
		
		//Start the gradient boosting process
		for(int m=0; m<nTrees; m++)
		{
			PRINT(new int[]{7}, new String[]{(m+1)+""});
			
			//Compute lambdas (which act as the "pseudo responses")
			//Each document is a training sample; its lambda serves as its training label
			computePseudoResponses();
			
			//update the histogram with these training labels (based on which we compute the greedy split point)
			hist.update(pseudoResponses);
			
			//Fit a regression tree
			RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, minLeafSupport);
			rt.fit();
			
			//Add this tree to the ensemble (our model)
			ensemble.add(rt, learningRate);
			
			//update the outputs of the tree (with gamma computed using the Newton-Raphson method)
			updateTreeOutput(rt);
			
			//Update the model's outputs on all training samples
			List<Split> leaves = rt.leaves();
			for(int i=0;i<leaves.size();i++)
			{
				Split s = leaves.get(i);
				int[] idx = s.getSamples();
				for(int j=0;j<idx.length;j++)
					modelScores[idx[j]] += learningRate * s.getOutput();
			}
			
			//clear references to data that is no longer used
			rt.clearSamples();
			
			//beg the garbage collector to work...
			if(m % gcCycle == 0)
				System.gc();//this call is expensive; we shouldn't do it too often
			
			//Evaluate the current model on the training data
			//(equivalent to scorer.score(rank(samples)), but faster since it reuses the cached model outputs)
			scoreOnTrainingData = computeModelScoreOnTraining();
			PRINT(new int[]{9}, new String[]{SimpleMath.round(scoreOnTrainingData, 4) + ""});
			
			//Evaluate the current model on the validation data (if available)
			if(validationSamples != null)
			{
				//Update the model's scores on all validation samples
				for(int i=0;i<modelScoresOnValidation.length;i++)
					for(int j=0;j<modelScoresOnValidation[i].length;j++)
						modelScoresOnValidation[i][j] += learningRate * rt.eval(validationSamples.get(i).get(j));
				
				//again, equivalent to scorer.score(rank(validationSamples)) but uses the cached outputs
				double score = computeModelScoreOnValidation();
				
				PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
				if(score > bestScoreOnValidationData)
				{
					bestScoreOnValidationData = score;
					bestModelOnValidation = ensemble.treeCount()-1;
				}
			}
			
			PRINTLN("");
			
			//Should we stop early?
			if(m - bestModelOnValidation > nRoundToStopEarly)
				break;
		}
		
		//Rollback to the best model observed on the validation data
		while(ensemble.treeCount() > bestModelOnValidation+1)
			ensemble.remove(ensemble.treeCount()-1);
		
		//Finishing up
		scoreOnTrainingData = scorer.score(rank(samples));
		PRINTLN("---------------------------------");
		PRINTLN("Finished sucessfully.");
		PRINTLN(scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4));
		if(validationSamples != null)
		{
			bestScoreOnValidationData = scorer.score(rank(validationSamples));
			PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
		}
		PRINTLN("---------------------------------");
	}

	public double eval(DataPoint dp)
	{
		return ensemble.eval(dp);
	}	

	public Ranker createNew()
	{
		return new LambdaMART();
	}

	public String toString()
	{
		return ensemble.toString();
	}

	public String model()
	{
		String output = "## " + name() + "\n";
		output += "## No. of trees = " + nTrees + "\n";
		output += "## No. of leaves = " + nTreeLeaves + "\n";
		output += "## No. of threshold candidates = " + nThreshold + "\n";
		output += "## Learning rate = " + learningRate + "\n";
		output += "## Stop early = " + nRoundToStopEarly + "\n";
		output += "\n";
		output += toString();
		return output;
	}
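	//For reference, model() therefore serializes a header like the following (shown
	//with the default parameter values), followed by the ensemble itself; in the
	//standard RankLib distribution the ensemble is an XML-like <ensemble> block of
	//weighted regression trees:
	//
	//  ## LambdaMART
	//  ## No. of trees = 1000
	//  ## No. of leaves = 10
	//  ## No. of threshold candidates = 256
	//  ## Learning rate = 0.1
	//  ## Stop early = 100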

	@Override
	public void loadFromString(String fullText)
	{
		try {
			String content = "";
			StringBuffer model = new StringBuffer();
			BufferedReader in = new BufferedReader(new StringReader(fullText));
			while((content = in.readLine()) != null)
			{
				content = content.trim();
				if(content.length() == 0)
					continue;
				if(content.indexOf("##")==0)
					continue;
				//actual model component
				model.append(content);
			}
			in.close();
			//load the ensemble
			ensemble = new Ensemble(model.toString());
			features = ensemble.getFeatures();
		}
		catch(Exception ex)
		{
			throw RankLibError.create("Error in LambdaMART::load(): ", ex);
		}
	}

	public void printParameters()
	{
		PRINTLN("No. of trees: " + nTrees);
		PRINTLN("No. of leaves: " + nTreeLeaves);
		PRINTLN("No. of threshold candidates: " + nThreshold);
		PRINTLN("Min leaf support: " + minLeafSupport);
		PRINTLN("Learning rate: " + learningRate);
		PRINTLN("Stop early: " + nRoundToStopEarly + " rounds without performance gain on validation data");		
	}	

	public String name()
	{
		return "LambdaMART";
	}

	public Ensemble getEnsemble()
	{
		return ensemble;
	}
	
	protected void computePseudoResponses()
	{
		Arrays.fill(pseudoResponses, 0F);
		Arrays.fill(weights, 0);
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)//single-thread
			computePseudoResponses(0, samples.size()-1, 0);
		else //multi-threading
		{
			List<LambdaComputationWorker> workers = new ArrayList<LambdaComputationWorker>();
			//divide the entire dataset into chunks of equal size for each worker thread
			int[] partition = p.partition(samples.size());
			int current = 0;
			for(int i=0;i<partition.length-1;i++)
			{
				//execute the worker
				LambdaComputationWorker wk = new LambdaComputationWorker(this, partition[i], partition[i+1]-1, current);
				workers.add(wk);//keep it so we can get back results from it later on
				p.execute(wk);
				
				if(i < partition.length-2)
					for(int j=partition[i]; j<=partition[i+1]-1;j++)
						current += samples.get(j).size();
			}
			
			//wait for all workers to complete before we move on to the next stage
			p.await();
		}
	}

	protected void computePseudoResponses(int start, int end, int current)
	{
		int cutoff = scorer.getK();
		//compute the lambda for each document (a.k.a the "pseudo response")
		for(int i=start;i<=end;i++)
		{
			RankList orig = samples.get(i);
			int[] idx = MergeSorter.sort(modelScores, current, current+orig.size()-1, false);
			RankList rl = new RankList(orig, idx, current);
			double[][] changes = scorer.swapChange(rl);
			//NOTE: j and k are indices in the list sorted by modelScore, not in the original list
			// ==> need to map back with idx[j] and idx[k]
			for(int j=0;j<rl.size();j++)
			{
				DataPoint p1 = rl.get(j);
				int mj = idx[j];
				for(int k=0;k<rl.size();k++)
				{
					if(j > cutoff && k > cutoff)//swapping this pair won't change the target measure since both documents are below the cut-off point
						break;
					DataPoint p2 = rl.get(k);
					int mk = idx[k];
					if(p1.getLabel() > p2.getLabel())
					{
						double deltaNDCG = Math.abs(changes[j][k]);
						if(deltaNDCG > 0)
						{
							double rho = 1.0 / (1 + Math.exp(modelScores[mj] - modelScores[mk]));
							double lambda = rho * deltaNDCG;
							pseudoResponses[mj] += lambda;
							pseudoResponses[mk] -= lambda;
							double delta = rho * (1.0 - rho) * deltaNDCG;
							weights[mj] += delta;
							weights[mk] += delta;
						}
					}
				}
			}
			current += orig.size();
		}
	}
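	/*
	 * Leaf outputs are set with a single Newton-Raphson step: for the documents that
	 * fall into a leaf, gamma = sum(pseudoResponses) / sum(weights), i.e. the
	 * first-order (lambda) sums divided by the second-order sums accumulated above.
	 */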

	protected void updateTreeOutput(RegressionTree rt)
	{
		List<Split> leaves = rt.leaves();
		for(int i=0;i<leaves.size();i++)
		{
			float s1 = 0F;
			float s2 = 0F;
			Split s = leaves.get(i);
			int[] idx = s.getSamples();
			for(int j=0;j<idx.length;j++)
			{
				int k = idx[j];
				s1 += pseudoResponses[k];
				s2 += weights[k];
			}
			if(s2 == 0)
				s.setOutput(0);
			else
				s.setOutput(s1/s2);
		}
	}

	protected int[] sortSamplesByFeature(DataPoint[] samples, int fid)
	{
		double[] score = new double[samples.length];
		for(int i=0;i<samples.length;i++)
			score[i] = samples[i].getFeatureValue(fid);
		int[] idx = MergeSorter.sort(score, true);
		return idx;
	}

	protected void sortSamplesByFeature(int fStart, int fEnd)
	{
		for(int i=fStart;i<=fEnd;i++)
			sortedIdx[i] = sortSamplesByFeature(martSamples, features[i]);
	}

	protected RankList rank(int rankListIndex, int current)
	{
		RankList orig = samples.get(rankListIndex);
		double[] scores = new double[orig.size()];
		for(int i=0;i<scores.length;i++)
			scores[i] = modelScores[current+i];
		int[] idx = MergeSorter.sort(scores, false);
		return new RankList(orig, idx);
	}

	protected float computeModelScoreOnTraining()
	{
		float s = 0;
		int current = 0;
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)//single-thread
			s = computeModelScoreOnTraining(0, samples.size()-1, current);
		else//multi-threading
		{
			List<Worker> workers = new ArrayList<Worker>();
			//divide the entire dataset into chunks of equal size for each worker thread
			int[] partition = p.partition(samples.size());
			for(int i=0;i<partition.length-1;i++)
			{
				//execute the worker
				Worker wk = new Worker(this, partition[i], partition[i+1]-1, current);
				workers.add(wk);//keep it so we can get back results from it later on
				p.execute(wk);
				
				if(i < partition.length-2)
					for(int j=partition[i]; j<=partition[i+1]-1;j++)
						current += samples.get(j).size();
			}
			
			//wait for all workers to complete before we move on to the next stage
			p.await();
			//gather the partial scores computed by the workers
			for(int i=0;i<workers.size();i++)
				s += workers.get(i).score;
		}
		return s/samples.size();
	}

	protected float computeModelScoreOnTraining(int start, int end, int current)
	{
		float s = 0;
		int c = current;
		for(int i=start;i<=end;i++)
		{
			s += scorer.score(rank(i, c));
			c += samples.get(i).size();
		}
		return s;
	}

	protected float computeModelScoreOnValidation()
	{
		float score = 0;
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)//single-thread
			score = computeModelScoreOnValidation(0, validationSamples.size()-1);
		else//multi-threading
		{
			List<Worker> workers = new ArrayList<Worker>();
			//divide the entire dataset into chunks of equal size for each worker thread
			int[] partition = p.partition(validationSamples.size());
			for(int i=0;i<partition.length-1;i++)
			{
				//execute the worker
				Worker wk = new Worker(this, partition[i], partition[i+1]-1);
				workers.add(wk);//keep it so we can get back results from it later on
				p.execute(wk);
			}
			
			//wait for all workers to complete before we move on to the next stage
			p.await();
			//gather the partial scores computed by the workers
			for(int i=0;i<workers.size();i++)
				score += workers.get(i).score;
		}
		return score/validationSamples.size();
	}

	protected float computeModelScoreOnValidation(int start, int end)
	{
		float score = 0;
		for(int i=start;i<=end;i++)
		{
			int[] idx = MergeSorter.sort(modelScoresOnValidation[i], false);
			score += scorer.score(new RankList(validationSamples.get(i), idx));
		}
		return score;
	}

	//Workers for multi-threaded processing
	class SortWorker implements Runnable {
		LambdaMART ranker = null;
		int start = -1;
		int end = -1;
		SortWorker(LambdaMART ranker, int start, int end)
		{
			this.ranker = ranker;
			this.start = start;
			this.end = end;
		}
		public void run()
		{
			ranker.sortSamplesByFeature(start, end);
		}
	}
	class LambdaComputationWorker implements Runnable {
		LambdaMART ranker = null;
		int rlStart = -1;
		int rlEnd = -1;
		int martStart = -1;
		LambdaComputationWorker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
		{
			this.ranker = ranker;
			this.rlStart = rlStart;
			this.rlEnd = rlEnd;
			this.martStart = martStart;
		}
		public void run()
		{
			ranker.computePseudoResponses(rlStart, rlEnd, martStart);
		}
	}
	class Worker implements Runnable {
		LambdaMART ranker = null;
		int rlStart = -1;
		int rlEnd = -1;
		int martStart = -1;
		boolean onTraining = false;//which partial score this worker computes
		float score = 0;//the partial score computed by this worker
		
		Worker(LambdaMART ranker, int rlStart, int rlEnd)//score on validation data
		{
			this.ranker = ranker;
			this.rlStart = rlStart;
			this.rlEnd = rlEnd;
		}
		Worker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)//score on training data
		{
			this.ranker = ranker;
			this.rlStart = rlStart;
			this.rlEnd = rlEnd;
			this.martStart = martStart;
			onTraining = true;
		}
		public void run()
		{
			if(onTraining)
				score = ranker.computeModelScoreOnTraining(rlStart, rlEnd, martStart);
			else
				score = ranker.computeModelScoreOnValidation(rlStart, rlEnd);
		}
	}
}
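/*
 * Usage sketch: one plausible way to train and save a model programmatically,
 * assuming RankLib's FeatureManager and NDCGScorer APIs (the file paths below are
 * hypothetical).
 *
 *   List<RankList> train = FeatureManager.readInput("train.txt");
 *   List<RankList> validation = FeatureManager.readInput("vali.txt");
 *   int[] features = FeatureManager.getFeatureFromSampleVector(train);
 *
 *   LambdaMART ranker = new LambdaMART(train, features, new NDCGScorer(10));
 *   ranker.setValidationSet(validation);
 *   ranker.init();
 *   ranker.learn();
 *   ranker.save("lambdamart.model");
 */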



