/*===============================================================================
 * Copyright (c) 2010-2012 University of Massachusetts.  All Rights Reserved.
 *
 * Use of the RankLib package is subject to the terms of the software license set
 * forth in the LICENSE file included with this software, and also available at
 * http://people.cs.umass.edu/~vdang/ranklib_license.html
 *===============================================================================
 */

package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.learning.RankerFactory;
import ciir.umass.edu.learning.RankerType;
import ciir.umass.edu.learning.Sampler;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.parsing.ModelLineProducer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.SimpleMath;

public class RFRanker extends Ranker {
    private static final Logger logger = Logger.getLogger(RFRanker.class.getName());

    //Parameters
    //[a] general bagging parameters
    public static int nBag = 300;
    public static float subSamplingRate = 1.0f;//sampling of samples (*WITH* replacement)
    public static float featureSamplingRate = 0.3f;//sampling of features (*WITHOUT* replacement)
    //[b] what to do in each bag
    public static RankerType rType = RankerType.MART;//which algorithm to bag
    public static int nTrees = 1;//how many trees in each bag. If nTrees > 1, each bag will contain an ensemble of gradient boosted trees.
    public static int nTreeLeaves = 100;
    public static float learningRate = 0.1F;//or shrinkage. *ONLY* matters if nTrees > 1.
    public static int nThreshold = 256;
    public static int minLeafSupport = 1;

    //Variables
    protected Ensemble[] ensembles = null;//bag of ensembles, each can be a single tree or an ensemble of gradient boosted trees

    public RFRanker() {
    }

    public RFRanker(final List<RankList> samples, final int[] features, final MetricScorer scorer) {
        super(samples, features, scorer);
    }

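    /**
     * Allocates the bag array and configures the static parameters of the
     * per-bag learner (LambdaMART, which MART shares). Early stopping is
     * disabled: each bag is trained to completion, and variance is controlled
     * by averaging across bags instead.
     */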
    @Override
    public void init() {
        logger.info(() -> "Initializing... ");
        ensembles = new Ensemble[nBag];
        //initialize parameters for the tree(s) built in each bag
        LambdaMART.nTrees = nTrees;
        LambdaMART.nTreeLeaves = nTreeLeaves;
        LambdaMART.learningRate = learningRate;
        LambdaMART.nThreshold = nThreshold;
        LambdaMART.minLeafSupport = minLeafSupport;
        LambdaMART.nRoundToStopEarly = -1;//no early-stopping since we're doing bagging
        //turn on feature sampling
        FeatureHistogram.samplingRate = featureSamplingRate;
    }

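    /**
     * The bagging loop: each bag is a bootstrap sample of the training rank
     * lists (drawn with replacement at {@code subSamplingRate}), on which one
     * MART/LambdaMART ensemble is trained. Per-feature impact (error
     * reduction) statistics are accumulated across bags and logged at the
     * end. Note that the log header reserves an OOB column which this loop
     * does not fill in.
     */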
    @Override
    public void learn() {
        final RankerFactory rf = new RankerFactory();
        logger.info(() -> "Training starts...");
        printLogLn(new int[] { 9, 9, 11 }, new String[] { "bag", scorer.name() + "-B", scorer.name() + "-OOB" });
        double[] impacts = null;
        //start the bagging process
        for (int i = 0; i < nBag; i++) {
            final Sampler sp = new Sampler();
            //create a "bag" of samples by random sampling from the training set
            final List<RankList> bag = sp.doSampling(samples, subSamplingRate, true);
            final LambdaMART r = (LambdaMART) rf.createRanker(rType, bag, features, scorer);

            r.init();
            r.learn();
            // accumulate impacts
            if (impacts == null) {
                impacts = r.impacts;
            } else {
                for (int ftr = 0; ftr < impacts.length; ftr++) {
                    impacts[ftr] += r.impacts[ftr];
                }
            }
            printLogLn(new int[] { 9, 9 }, new String[] { "b[" + (i + 1) + "]", SimpleMath.round(r.getScoreOnTrainingData(), 4) + "" });
            ensembles[i] = r.getEnsemble();
        }
        //Finishing up
        scoreOnTrainingData = scorer.score(rank(samples));
        logger.info(() -> "Finished sucessfully.");
        logger.info(() -> scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4));
        if (validationSamples != null) {
            bestScoreOnValidationData = scorer.score(rank(validationSamples));
            logger.info(() -> scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
        }

        logger.info(() -> "-- FEATURE IMPACTS");
        if (logger.isLoggable(Level.INFO)) {
            final int[] ftrsSorted = MergeSorter.sort(impacts, false);
            for (final int ftr : ftrsSorted) {
                logger.info(" Feature " + features[ftr] + " reduced error " + impacts[ftr]);
            }
        }

    }

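    /**
     * Standard bagging aggregation: the score of a data point is the mean of
     * the outputs of the {@code nBag} per-bag ensembles.
     */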
    @Override
    public double eval(final DataPoint dp) {
        double s = 0;
        for (final Ensemble ensemble : ensembles) {
            s += ensemble.eval(dp);
        }
        return s / ensembles.length;
    }

    @Override
    public Ranker createNew() {
        return new RFRanker();
    }

    @Override
    public String toString() {
        final StringBuilder str = new StringBuilder();
        for (int i = 0; i < nBag; i++) {
            str.append(ensembles[i].toString()).append('\n');
        }
        return str.toString();
    }

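    /**
     * Serializes the model as "##"-prefixed header lines recording the
     * training hyper-parameters, followed by the concatenated per-bag
     * ensembles produced by {@link #toString()}.
     */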
    @Override
    public String model() {
        String output = "## " + name() + "\n";
        output += "## No. of bags = " + nBag + "\n";
        output += "## Sub-sampling = " + subSamplingRate + "\n";
        output += "## Feature-sampling = " + featureSamplingRate + "\n";
        output += "## No. of trees = " + nTrees + "\n";
        output += "## No. of leaves = " + nTreeLeaves + "\n";
        output += "## No. of threshold candidates = " + nThreshold + "\n";
        output += "## Learning rate = " + learningRate + "\n";
        output += "\n";
        output += toString();
        return output;
    }

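    /**
     * Rebuilds the bag of ensembles from a saved model string: the line
     * producer buffers text until an ensemble's closing {@code </ensemble>}
     * tag is seen, at which point one {@link Ensemble} is materialized. The
     * feature list is then recovered as the union of feature ids used by the
     * loaded ensembles.
     */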
    @Override
    public void loadFromString(final String fullText) {
        final List<Ensemble> ens = new ArrayList<>();

        final ModelLineProducer lineByLine = new ModelLineProducer();

        lineByLine.parse(fullText, (model, maybeEndEns) -> {
            if (maybeEndEns) {
                final String modelAsStr = model.toString();
                if (modelAsStr.endsWith("")) {
                    ens.add(new Ensemble(modelAsStr));
                    model.setLength(0);
                }
            }
        });

        final Set<Integer> uniqueFeatures = new HashSet<>();
        ensembles = new Ensemble[ens.size()];
        for (int i = 0; i < ens.size(); i++) {
            ensembles[i] = ens.get(i);
            //obtain used features
            final int[] fids = ens.get(i).getFeatures();
            for (final int fid : fids) {
                uniqueFeatures.add(fid);//Set semantics already ignore duplicates
            }
        }
        int fi = 0;
        features = new int[uniqueFeatures.size()];
        for (final int f : uniqueFeatures) {
            features[fi++] = f;
        }
    }

    @Override
    public void printParameters() {
        logger.info(() -> "No. of bags: " + nBag);
        logger.info(() -> "Sub-sampling: " + subSamplingRate);
        logger.info(() -> "Feature-sampling: " + featureSamplingRate);
        logger.info(() -> "No. of trees: " + nTrees);
        logger.info(() -> "No. of leaves: " + nTreeLeaves);
        logger.info(() -> "No. of threshold candidates: " + nThreshold);
        logger.info(() -> "Learning rate: " + learningRate);
    }

    @Override
    public String name() {
        return "Random Forests";
    }

    public Ensemble[] getEnsembles() {
        return ensembles;
    }
}