All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ciir.umass.edu.learning.tree.LambdaMART Maven / Gradle / Ivy

The newest version!
/*===============================================================================
 * Copyright (c) 2010-2012 University of Massachusetts.  All Rights Reserved.
 *
 * Use of the RankLib package is subject to the terms of the software license set
 * forth in the LICENSE file included with this software, and also available at
 * http://people.cs.umass.edu/~vdang/ranklib_license.html
 *===============================================================================
 */

package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.parsing.ModelLineProducer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.SimpleMath;

/**
 * @author vdang
 *
 *  This class implements LambdaMART.
 *  Q. Wu, C.J.C. Burges, K. Svore and J. Gao. Adapting Boosting for Information Retrieval Measures.
 *  Journal of Information Retrieval, 2007.
 */
public class LambdaMART extends Ranker {
    private static final Logger logger = Logger.getLogger(LambdaMART.class.getName());

    //Parameters
    //NOTE: these are public, mutable statics, i.e. one shared configuration for every LambdaMART instance in the JVM.
    public static int nTrees = 1000;//the number of trees (upper bound on the boosting rounds run by learn())
    public static float learningRate = 0.1F;//or shrinkage; every tree's output is scaled by this before it is added to the model scores
    public static int nThreshold = 256;//max #candidate split thresholds per feature; -1 means "use every unique feature value" (see init())
    public static int nRoundToStopEarly = 100;//If no performance gain on the *VALIDATION* data is observed in #rounds, stop the training process right away.
    public static int nTreeLeaves = 10;//handed to every RegressionTree built in learn(); reported as "No. of leaves"
    public static int minLeafSupport = 1;//handed to every RegressionTree built in learn(); presumably the min #samples per leaf -- confirm in RegressionTree

    //Local variables
    protected float[][] thresholds = null;//thresholds[f] = candidate split values for features[f]; the last entry is always Float.MAX_VALUE
    protected Ensemble ensemble = null;//the model itself: filled by learn() or rebuilt by loadFromString()
    protected double[] modelScores = null;//on training data

    protected double[][] modelScoresOnValidation = null;//[validation list][document] cached model outputs; only allocated when validation data exists
    //Index (0-based) of the last tree of the best model seen on the validation data.
    //The huge initial value keeps the early-stopping test and the rollback loop in learn() inactive until a
    //validation score has been recorded; "- 2" leaves room for the "+ 1" in the rollback test without overflowing.
    protected int bestModelOnValidation = Integer.MAX_VALUE - 2;

    //Training instances prepared for MART
    protected DataPoint[] martSamples = null;//Need initializing only once
    protected int[][] sortedIdx = null;//sorted list of samples in @martSamples by each feature -- Need initializing only once
    protected FeatureHistogram hist = null;//built once in init(), refreshed with the new pseudo responses at every boosting round
    protected double[] pseudoResponses = null;//different for each iteration
    protected double[] weights = null;//different for each iteration
    protected double[] impacts = null; // accumulated impact of each feature

    /**
     * Creates a ranker without supplying training samples, features or a scorer.
     */
    public LambdaMART() {
        super();
    }

    /**
     * Creates a ranker over the given training data. All arguments are forwarded unchanged to the
     * {@link Ranker} constructor.
     *
     * @param samples  the training rank lists
     * @param features the ids of the features to learn from
     * @param scorer   the metric used to evaluate (and, in LambdaMART, to weight) the rankings
     */
    public LambdaMART(final List<RankList> samples, final int[] features, final MetricScorer scorer) {
        super(samples, features, scorer);
    }

    /**
     * Prepares everything that is computed once per training run:
     * <ol>
     *   <li>flattens all training rank lists into {@link #martSamples} and allocates the per-document arrays;</li>
     *   <li>sorts the samples by every feature (single- or multi-threaded, depending on the thread pool size);</li>
     *   <li>builds the table of candidate split thresholds for every feature;</li>
     *   <li>allocates the cached validation scores, if validation data is present;</li>
     *   <li>constructs the feature histogram and then releases the sorted indexes.</li>
     * </ol>
     */
    @Override
    public void init() {
        logger.info(() -> "Initializing... ");

        //initialize samples for MART
        int dpCount = 0;
        for (final RankList rl : samples) {
            dpCount += rl.size();
        }
        martSamples = new DataPoint[dpCount];
        //Newly allocated double arrays are already zero-filled, so no explicit reset is required here.
        modelScores = new double[dpCount];
        pseudoResponses = new double[dpCount];
        weights = new double[dpCount];
        impacts = new double[features.length];

        //Documents of the same rank list occupy a contiguous range of martSamples, in list order.
        int current = 0;
        for (final RankList rl : samples) {
            for (int j = 0; j < rl.size(); j++) {
                martSamples[current + j] = rl.get(j);
            }
            current += rl.size();
        }

        //sort (MART) samples by each feature so that we can quickly retrieve a sorted list of samples by any feature later on.
        sortedIdx = new int[features.length][];
        final MyThreadPool p = MyThreadPool.getInstance();
        if (p.size() == 1) {
            sortSamplesByFeature(0, features.length - 1);
        } else {
            //multi-thread: each worker sorts the samples for one contiguous range of features
            final int[] partition = p.partition(features.length);
            for (int i = 0; i < partition.length - 1; i++) {
                p.execute(new SortWorker(this, partition[i], partition[i + 1] - 1));
            }
            p.await();
        }

        //Create a table of candidate thresholds (for each feature). Later on, we will select the best tree split from these candidates
        thresholds = new float[features.length][];
        for (int f = 0; f < features.length; f++) {
            //For this feature, keep track of the list of unique values and the max/min
            final List<Float> values = new ArrayList<>();
            float fmax = Float.NEGATIVE_INFINITY;
            float fmin = Float.MAX_VALUE;
            for (int i = 0; i < martSamples.length; i++) {
                final int k = sortedIdx[f][i];//get samples sorted with respect to this feature
                final float fv = martSamples[k].getFeatureValue(features[f]);
                values.add(fv);
                if (fmax < fv) {
                    fmax = fv;
                }
                if (fmin > fv) {
                    fmin = fv;
                }
                //skip all samples with the same feature value (this scan relies on sortedIdx[f] being in ascending order)
                int j = i + 1;
                while (j < martSamples.length) {
                    if (martSamples[sortedIdx[f][j]].getFeatureValue(features[f]) > fv) {
                        break;
                    }
                    j++;
                }
                i = j - 1;//[i, j] gives the range of samples with the same feature value
            }

            if (values.size() <= nThreshold || nThreshold == -1) {
                //Few enough unique values (or no limit requested): every unique value is a candidate
                thresholds[f] = new float[values.size() + 1];
                for (int i = 0; i < values.size(); i++) {
                    thresholds[f][i] = values.get(i);
                }
                thresholds[f][values.size()] = Float.MAX_VALUE;
            } else {
                //Too many unique values: use nThreshold evenly spaced candidates starting at fmin
                final float step = (Math.abs(fmax - fmin)) / nThreshold;
                thresholds[f] = new float[nThreshold + 1];
                thresholds[f][0] = fmin;
                for (int j = 1; j < nThreshold; j++) {
                    thresholds[f][j] = thresholds[f][j - 1] + step;
                }
                thresholds[f][nThreshold] = Float.MAX_VALUE;
            }
        }

        if (validationSamples != null) {
            //One zero-initialised score slot per validation document
            modelScoresOnValidation = new double[validationSamples.size()][];
            for (int i = 0; i < validationSamples.size(); i++) {
                modelScoresOnValidation[i] = new double[validationSamples.get(i).size()];
            }
        }

        //compute the feature histogram (this is used to speed up the procedure of finding the best tree split later on)
        hist = new FeatureHistogram();
        hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds, impacts);
        //we no longer need the sorted indexes of samples
        sortedIdx = null;
    }

    /**
     * Runs the gradient boosting loop for at most {@link #nTrees} rounds.
     * <p>
     * Each round computes the lambdas (pseudo responses), fits a regression tree to them, replaces the leaf
     * outputs with the Newton-Raphson estimates, and updates the cached model scores on the training data
     * (and on the validation data, when present). When validation data is present, training stops once more
     * than {@link #nRoundToStopEarly} rounds pass without a new best validation score, and the ensemble is
     * rolled back to the best model seen. Final scores and per-feature impacts are logged at the end.
     */
    @Override
    public void learn() {
        ensemble = new Ensemble();

        logger.info(() -> "Training starts...");
        if (validationSamples != null) {
            printLogLn(new int[] { 7, 9, 9 }, new String[] { "#iter", scorer.name() + "-T", scorer.name() + "-V" });
        } else {
            printLogLn(new int[] { 7, 9 }, new String[] { "#iter", scorer.name() + "-T" });
        }

        //Start the gradient boosting process
        for (int m = 0; m < nTrees; m++) {
            printLog(new int[] { 7 }, new String[] { Integer.toString(m + 1) });

            //Compute lambdas (which act as the "pseudo responses")
            //Create training instances for MART:
            //  - Each document is a training sample
            //  - The lambda for this document serves as its training label
            computePseudoResponses();

            //update the histogram with these training labels (the feature histogram will be used to find the best tree split)
            hist.update(pseudoResponses);

            //Fit a regression tree
            final RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);
            rt.fit();

            //Add this tree to the ensemble (our model)
            ensemble.add(rt, learningRate);

            //update the outputs of the tree (with gamma computed using the Newton-Raphson method)
            updateTreeOutput(rt);

            //Update the model's outputs on all training samples
            final List<Split> leaves = rt.leaves();
            for (final Split leaf : leaves) {
                //every sample that falls into this leaf receives the same (shrunken) increment
                final double increment = learningRate * leaf.getOutput();
                for (final int sampleIdx : leaf.getSamples()) {
                    modelScores[sampleIdx] += increment;
                }
            }

            //clear references to data that is no longer used
            rt.clearSamples();

            //Evaluate the current model
            scoreOnTrainingData = computeModelScoreOnTraining();
            //**** NOTE ****
            //The above function to evaluate the current model on the training data is equivalent to a single call:
            //
            //		scoreOnTrainingData = scorer.score(rank(samples));
            //
            //However, this function is more efficient since it uses the cached outputs of the model (as opposed to re-evaluating the model
            //on the entire training set).

            printLog(new int[] { 9 }, new String[] { Double.toString(SimpleMath.round(scoreOnTrainingData, 4)) });

            //Evaluate the current model on the validation data (if available)
            if (validationSamples != null) {
                //Update the model's scores on all validation samples
                for (int i = 0; i < modelScoresOnValidation.length; i++) {
                    for (int j = 0; j < modelScoresOnValidation[i].length; j++) {
                        modelScoresOnValidation[i][j] += learningRate * rt.eval(validationSamples.get(i).get(j));
                    }
                }

                //again, equivalent to scoreOnValidation=scorer.score(rank(validationSamples)), but more efficient since we use the cached models' outputs
                final double score = computeModelScoreOnValidation();

                printLog(new int[] { 9 }, new String[] { Double.toString(SimpleMath.round(score, 4)) });
                if (score > bestScoreOnValidationData) {
                    bestScoreOnValidationData = score;
                    bestModelOnValidation = ensemble.treeCount() - 1;
                }
            }
            flushLog();

            //Should we stop early?
            //(without validation data bestModelOnValidation keeps its huge initial value, so this never triggers)
            if (m - bestModelOnValidation > nRoundToStopEarly) {
                break;
            }
        }

        //Rollback to the best model observed on the validation data
        while (ensemble.treeCount() > bestModelOnValidation + 1) {
            ensemble.remove(ensemble.treeCount() - 1);
        }

        //Finishing up
        scoreOnTrainingData = scorer.score(rank(samples));
        logger.info(() -> "Finished sucessfully.");
        logger.info(() -> scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4));
        if (validationSamples != null) {
            bestScoreOnValidationData = scorer.score(rank(validationSamples));
            logger.info(() -> scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
        }

        logger.info(() -> "-- FEATURE IMPACTS");
        final int[] ftrsSorted = MergeSorter.sort(this.impacts, false);
        for (final int ftr : ftrsSorted) {
            logger.info(() -> " Feature " + features[ftr] + " reduced error " + impacts[ftr]);
        }
    }

    /**
     * Scores a single document by delegating to the tree ensemble.
     *
     * @param dp the document to score
     * @return the ensemble's output for {@code dp}
     */
    @Override
    public double eval(final DataPoint dp) {
        return this.ensemble.eval(dp);
    }

    /**
     * @return a new {@code LambdaMART} built with the no-argument constructor
     */
    @Override
    public Ranker createNew() {
        final Ranker fresh = new LambdaMART();
        return fresh;
    }

    /**
     * @return the string form of the underlying ensemble (requires that an ensemble has been learned or loaded)
     */
    @Override
    public String toString() {
        final String ensembleText = this.ensemble.toString();
        return ensembleText;
    }

    /**
     * Renders the model as text: a block of {@code ##}-prefixed header lines recording the training
     * parameters, an empty line, and then the output of {@link #toString()}.
     *
     * @return the textual form of this model
     */
    @Override
    public String model() {
        final StringBuilder text = new StringBuilder();
        text.append("## ").append(name()).append('\n');
        text.append("## No. of trees = ").append(nTrees).append('\n');
        text.append("## No. of leaves = ").append(nTreeLeaves).append('\n');
        text.append("## No. of threshold candidates = ").append(nThreshold).append('\n');
        text.append("## Learning rate = ").append(learningRate).append('\n');
        text.append("## Stop early = ").append(nRoundToStopEarly).append('\n');
        text.append('\n');
        text.append(toString());
        return text.toString();
    }

    /**
     * Rebuilds this ranker from the textual form of a model.
     * The text is run through a {@link ModelLineProducer}; whatever it collects is handed to the
     * {@link Ensemble} constructor, and the feature list is then taken from the resulting ensemble.
     *
     * @param fullText the complete model text
     */
    @Override
    public void loadFromString(final String fullText) {
        final ModelLineProducer producer = new ModelLineProducer();
        //the callback deliberately does nothing; only the text gathered by the producer is used below
        producer.parse(fullText, (ignoredModel, ignoredEndEns) -> {
        });

        //load the ensemble
        final String ensembleText = producer.getModel().toString();
        ensemble = new Ensemble(ensembleText);
        features = ensemble.getFeatures();
    }

    /**
     * Logs the current values of the static training parameters at INFO level, one line per parameter.
     */
    @Override
    public void printParameters() {
        logger.info(() -> "No. of trees: ".concat(Integer.toString(nTrees)));
        logger.info(() -> "No. of leaves: ".concat(Integer.toString(nTreeLeaves)));
        logger.info(() -> "No. of threshold candidates: ".concat(Integer.toString(nThreshold)));
        logger.info(() -> "Min leaf support: ".concat(Integer.toString(minLeafSupport)));
        logger.info(() -> "Learning rate: ".concat(Float.toString(learningRate)));
        logger.info(() -> "Stop early: "
                .concat(Integer.toString(nRoundToStopEarly))
                .concat(" rounds without performance gain on validation data"));
    }

    /**
     * @return the display name of this ranker, {@code "LambdaMART"}
     */
    @Override
    public String name() {
        final String rankerName = "LambdaMART";
        return rankerName;
    }

    /**
     * @return the current tree ensemble; {@code null} until a model has been learned or loaded
     */
    public Ensemble getEnsemble() {
        return this.ensemble;
    }

    /**
     * Recomputes the pseudo responses (lambdas) and the matching weights for every training document.
     * <p>
     * Both arrays are reset first. With a single-threaded pool the whole data set is processed in one call;
     * otherwise the rank lists are partitioned across the pool and each worker is handed a contiguous range
     * of rank lists together with the offset of that range's first document in the flattened sample arrays.
     * The method returns only after all workers have finished.
     */
    protected void computePseudoResponses() {
        Arrays.fill(pseudoResponses, 0.0);
        Arrays.fill(weights, 0.0);

        final MyThreadPool p = MyThreadPool.getInstance();
        if (p.size() == 1) {
            computePseudoResponses(0, samples.size() - 1, 0);
            return;
        }

        //multi-threading: divide the entire dataset into chunks for each worker thread
        final int[] partition = p.partition(samples.size());
        int current = 0;//offset (in documents) of the first rank list of the chunk being dispatched
        for (int i = 0; i < partition.length - 1; i++) {
            //The worker is not retained here: nothing in this method reads anything back from it.
            p.execute(new LambdaComputationWorker(this, partition[i], partition[i + 1] - 1, current));

            //advance the offset past this chunk (not needed after the last chunk has been dispatched)
            if (i < partition.length - 2) {
                for (int j = partition[i]; j <= partition[i + 1] - 1; j++) {
                    current += samples.get(j).size();
                }
            }
        }

        //wait for all workers to complete before we move on to the next stage
        p.await();
    }

    /**
     * Accumulates the lambdas and weights for the rank lists {@code start..end} (both inclusive).
     * <p>
     * For every pair of documents in a list where the first has the strictly larger label and swapping the
     * two changes the metric, the pair's lambda is added to the first document's pseudo response and
     * subtracted from the second's, and the pair's second-order term is added to the weight of both.
     *
     * @param start   index of the first rank list to process
     * @param end     index of the last rank list to process
     * @param current offset of the first document of list {@code start} within the flattened sample arrays
     */
    protected void computePseudoResponses(final int start, final int end, int current) {
        final int cutoff = scorer.getK();
        for (int listIdx = start; listIdx <= end; listIdx++) {
            final RankList orig = samples.get(listIdx);
            final int listSize = orig.size();
            //idx holds positions into modelScores (not list-local ones) ordered by the current model score
            final int[] idx = MergeSorter.sort(modelScores, current, current + listSize - 1, false);
            final RankList rl = new RankList(orig, idx, current);
            final double[][] changes = scorer.swapChange(rl);
            final int rankedSize = rl.size();

            //NOTE: j, k are positions in the list sorted by model score ==> map back through idx[j] and idx[k]
            for (int j = 0; j < rankedSize; j++) {
                final DataPoint first = rl.get(j);
                final int mj = idx[j];
                for (int k = 0; k < rankedSize; k++) {
                    if (j > cutoff && k > cutoff) {
                        break;
                    }
                    final DataPoint second = rl.get(k);
                    final int mk = idx[k];
                    //only pairs whose first document is strictly more relevant contribute
                    if (!(first.getLabel() > second.getLabel())) {
                        continue;
                    }
                    final double deltaNDCG = Math.abs(changes[j][k]);
                    if (!(deltaNDCG > 0)) {
                        continue;
                    }
                    final double rho = 1.0 / (1 + Math.exp(modelScores[mj] - modelScores[mk]));
                    final double lambda = rho * deltaNDCG;
                    final double delta = rho * (1.0 - rho) * deltaNDCG;
                    pseudoResponses[mj] += lambda;
                    pseudoResponses[mk] -= lambda;
                    weights[mj] += delta;
                    weights[mk] += delta;
                }
            }
            current += listSize;
        }
    }

    /**
     * Replaces the output of every leaf of {@code rt} with the ratio
     * (sum of pseudo responses) / (sum of weights) over the samples that fall into the leaf.
     * A leaf whose weights sum to zero gets output 0.
     * <p>
     * The sums are accumulated in {@code double}: {@link #pseudoResponses} and {@link #weights} are
     * {@code double[]}, and adding them into {@code float} accumulators with {@code +=} would silently
     * narrow on every addition. The ratio is narrowed once, when it is handed to the leaf.
     *
     * @param rt the freshly fitted tree whose leaf outputs are to be updated
     */
    protected void updateTreeOutput(final RegressionTree rt) {
        final List<Split> leaves = rt.leaves();
        for (final Split leaf : leaves) {
            double responseSum = 0.0;
            double weightSum = 0.0;
            for (final int k : leaf.getSamples()) {
                responseSum += pseudoResponses[k];
                weightSum += weights[k];
            }
            if (weightSum == 0) {
                leaf.setOutput(0);
            } else {
                leaf.setOutput((float) (responseSum / weightSum));
            }
        }
    }

    /**
     * Orders the given samples by the value of one feature.
     *
     * @param samples the samples to order (this parameter hides the field of the same name)
     * @param fid     the id of the feature to sort on
     * @return the sample indexes as ordered by {@link MergeSorter}, called with its boolean flag set to {@code true}
     */
    protected int[] sortSamplesByFeature(final DataPoint[] samples, final int fid) {
        final double[] featureValues = new double[samples.length];
        Arrays.setAll(featureValues, pos -> samples[pos].getFeatureValue(fid));
        return MergeSorter.sort(featureValues, true);
    }

    /**
     * Re-ranks one training list using the cached model outputs in {@link #modelScores}, rather than
     * re-evaluating the model as the inherited {@code rank(...)} does.
     *
     * @param rankListIndex index of the list within the training samples
     * @param current       offset of that list's first document within {@link #modelScores}
     * @return a new {@link RankList} holding the documents of the original list in the order produced by
     *         sorting their cached scores
     */
    protected RankList rank(final int rankListIndex, final int current) {
        final RankList orig = samples.get(rankListIndex);
        //take a private copy of this list's slice of the cached scores
        final double[] listScores = new double[orig.size()];
        System.arraycopy(modelScores, current, listScores, 0, listScores.length);
        final int[] order = MergeSorter.sort(listScores, false);
        return new RankList(orig, order);
    }

    protected float computeModelScoreOnTraining() {
        /*float s = 0;
        int current = 0;
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
        	s = computeModelScoreOnTraining(0, samples.size()-1, current);
        else
        {
        	List workers = new ArrayList();
        	//divide the entire dataset into chunks of equal size for each worker thread
        	int[] partition = p.partition(samples.size());
        	for(int i=0;i workers = new ArrayList();
        	//divide the entire dataset into chunks of equal size for each worker thread
        	int[] partition = p.partition(validationSamples.size());
        	for(int i=0;i




© 2015 - 2025 Weber Informatics LLC | Privacy Policy