All downloads are free. Search and download functionalities use the official Maven repository.

ciir.umass.edu.learning.boosting.AdaRank Maven / Gradle / Ivy

The newest version!
/*===============================================================================
 * Copyright (c) 2010-2012 University of Massachusetts.  All Rights Reserved.
 *
 * Use of the RankLib package is subject to the terms of the software license set
 * forth in the LICENSE file included with this software, and also available at
 * http://people.cs.umass.edu/~vdang/ranklib_license.html
 *===============================================================================
 */

package ciir.umass.edu.learning.boosting;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.logging.Logger;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.KeyValuePair;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;

/**
 * @author vdang
 *
 * This class implements the AdaRank algorithm. Here's the paper:
 *  J. Xu and H. Li. AdaRank: a boosting algorithm for information retrieval. In Proc. of SIGIR, pages 391-398, 2007.
 */
public class AdaRank extends Ranker {
    private static final Logger logger = Logger.getLogger(AdaRank.class.getName());

    //Parameters
    /** Maximum number of boosting rounds. */
    public static int nIteration = 500;
    /** Minimum improvement of the training metric required per round to keep boosting. */
    public static double tolerance = 0.002;
    /** Whether to temporarily set aside ("enqueue") features that dominate training. */
    public static boolean trainWithEnqueue = true;
    /** The max. number of times a feature can be selected consecutively before being removed. */
    public static int maxSelCount = 5;

    /** Features removed from the candidate pool (feature id -> marker); cleared when performance changes. */
    protected HashMap<Integer, Integer> usedFeatures = new HashMap<>();
    protected double[] sweight = null;//sample weight (one entry per training rank list)
    protected List<WeakRanker> rankers = null;//selected weak rankers h_t
    protected List<Double> rweight = null;//weak rankers' weight (alpha_t)
    //to store the best model on validation data (if specified)
    protected List<WeakRanker> bestModelRankers = null;
    protected List<Double> bestModelWeights = null;

    //For the implementation of tricks
    int lastFeature = -1;//feature chosen in the previous round
    int lastFeatureConsecutiveCount = 0;
    boolean performanceChanged = false;
    List<Integer> featureQueue = null;//features enqueued for deferred re-insertion
    protected double[] backupSampleWeight = null;//sample-weight snapshot used for rollback
    protected double backupTrainScore = 0.0;
    protected double lastTrainedScore = -1.0;

    public AdaRank() {

    }

    /**
     * @param samples Training data: one {@link RankList} per query.
     * @param features Candidate feature ids the weak rankers may use.
     * @param scorer Evaluation metric to optimize (e.g. NDCG, MAP).
     */
    public AdaRank(final List<RankList> samples, final int[] features, final MetricScorer scorer) {
        super(samples, features, scorer);
    }

    /**
     * Snapshots the current ensemble as the best model seen on validation data.
     */
    private void updateBestModelOnValidation() {
        bestModelRankers.clear();
        bestModelRankers.addAll(rankers);
        bestModelWeights.clear();
        bestModelWeights.addAll(rweight);
    }

    /**
     * Selects the single feature that maximizes the sample-weighted metric score
     * over the training data, skipping enqueued and removed features.
     *
     * @return the best weak ranker, or {@code null} if no candidate feature remains.
     */
    private WeakRanker learnWeakRanker() {
        double bestScore = -1.0;
        WeakRanker bestWR = null;
        for (final int i : features) {
            if (featureQueue.contains(i) || usedFeatures.get(i) != null) {
                continue;
            }

            final WeakRanker wr = new WeakRanker(i);
            double s = 0.0;
            for (int j = 0; j < samples.size(); j++) {
                final double t = scorer.score(wr.rank(samples.get(j))) * sweight[j];
                s += t;
            }

            if (bestScore < s) {
                bestScore = s;
                bestWR = wr;
            }
        }
        return bestWR;
    }

    /**
     * Runs boosting rounds from {@code startIteration} up to {@link #nIteration}.
     *
     * Each round: pick the best weak ranker, compute its weight alpha_t, re-weight the
     * training samples, and stop early if the training metric fails to improve by at
     * least {@link #tolerance}.
     *
     * @param startIteration 1-based round to start from (continuation is used after dequeuing features).
     * @param withEnqueue if true, a feature selected twice in a row is enqueued and the last round rolled back.
     * @return the round index at which the loop stopped (used to resume training).
     */
    private int learn(final int startIteration, final boolean withEnqueue) {
        int t = startIteration;
        for (; t <= nIteration; t++) {
            printLog(new int[] { 7 }, new String[] { Integer.toString(t) });

            final WeakRanker bestWR = learnWeakRanker();
            if (bestWR == null) {
                break;
            }

            if (withEnqueue) {
                if (bestWR.getFID() == lastFeature)//this feature is selected twice in a row
                {
                    //enqueue this feature
                    featureQueue.add(lastFeature);
                    //roll back the previous weak ranker since it is based on this "too strong" feature
                    rankers.remove(rankers.size() - 1);
                    rweight.remove(rweight.size() - 1);
                    copy(backupSampleWeight, sweight);
                    bestScoreOnValidationData = 0.0;//no best model just yet
                    lastTrainedScore = backupTrainScore;
                    printLogLn(new int[] { 8, 9, 9, 9 }, new String[] { Integer.toString(bestWR.getFID()), "", "", "ROLLBACK" });
                    continue;
                } else {
                    lastFeature = bestWR.getFID();
                    //save the distribution of samples' weight in case we need to rollback
                    copy(sweight, backupSampleWeight);
                    backupTrainScore = lastTrainedScore;
                }
            }

            //alpha_t = 0.5 * ln(sum_i w_i(1+E_i) / sum_i w_i(1-E_i)) where E_i is the metric score
            double num = 0.0;
            double denom = 0.0;
            for (int i = 0; i < samples.size(); i++) {
                final double tmp = scorer.score(bestWR.rank(samples.get(i)));
                num += sweight[i] * (1.0 + tmp);
                denom += sweight[i] * (1.0 - tmp);
            }

            rankers.add(bestWR);
            final double alpha_t = 0.5 * SimpleMath.ln(num / denom);
            rweight.add(alpha_t);

            double trainedScore = 0.0;
            //update the distribution of sample weight
            double total = 0.0;//normalization constant for the new weight distribution
            for (final RankList sample : samples) {
                final double tmp = scorer.score(rank(sample));
                total += Math.exp(-alpha_t * tmp);
                trainedScore += tmp;
            }
            trainedScore /= samples.size();
            final double delta = trainedScore + tolerance - lastTrainedScore;
            String status = (delta > 0) ? "OK" : "DAMN";

            if (!withEnqueue) {
                if (trainedScore != lastTrainedScore) {
                    performanceChanged = true;
                    lastFeatureConsecutiveCount = 0;
                    //all removed features are added back to the pool
                    usedFeatures.clear();
                } else {
                    performanceChanged = false;
                    if (lastFeature == bestWR.getFID()) {
                        lastFeatureConsecutiveCount++;
                        if (lastFeatureConsecutiveCount == maxSelCount) {
                            status = "F. REM.";
                            lastFeatureConsecutiveCount = 0;
                            usedFeatures.put(lastFeature, 1);//removed this feature from the pool
                        }
                    } else {
                        lastFeatureConsecutiveCount = 0;
                        //all removed features are added back to the pool
                        usedFeatures.clear();
                    }
                }
                lastFeature = bestWR.getFID();
            }

            printLog(new int[] { 8, 9, }, new String[] { Integer.toString(bestWR.getFID()), Double.toString(SimpleMath.round(trainedScore, 4)) });
            if (validationSamples != null) {//track the best model seen on validation data
                final double scoreOnValidation = scorer.score(rank(validationSamples));
                if (scoreOnValidation > bestScoreOnValidationData) {
                    bestScoreOnValidationData = scoreOnValidation;
                    updateBestModelOnValidation();
                }
                printLog(new int[] { 9, 9 }, new String[] { Double.toString(SimpleMath.round(scoreOnValidation, 4)), status });
            } else {
                printLog(new int[] { 9, 9 }, new String[] { "", status });
            }
            flushLog();

            if (delta <= 0)//stop criteria met
            {
                rankers.remove(rankers.size() - 1);
                rweight.remove(rweight.size() - 1);
                break;
            }

            lastTrainedScore = trainedScore;
            //re-weight: emphasize queries the current ensemble still ranks poorly
            for (int i = 0; i < sweight.length; i++) {
                sweight[i] *= Math.exp(-alpha_t * scorer.score(rank(samples.get(i)))) / total;
            }
        }
        return t;
    }

    /**
     * Resets all training state: uniform sample weights, empty ensemble, empty
     * feature queue, and no best validation model.
     */
    @Override
    public void init() {
        logger.info(() -> "Initializing... ");
        //initialization
        usedFeatures.clear();
        //assign equal weight to all samples
        sweight = new double[samples.size()];
        for (int i = 0; i < sweight.length; i++) {
            sweight[i] = 1.0f / samples.size();
        }
        backupSampleWeight = new double[sweight.length];
        copy(sweight, backupSampleWeight);
        lastTrainedScore = -1.0;

        rankers = new ArrayList<>();
        rweight = new ArrayList<>();

        featureQueue = new ArrayList<>();

        bestScoreOnValidationData = 0.0;
        bestModelRankers = new ArrayList<>();
        bestModelWeights = new ArrayList<>();

    }

    /**
     * Trains the AdaRank ensemble. With {@link #trainWithEnqueue}, a first pass
     * enqueues dominant features, then training resumes once per dequeued feature.
     * If validation data is present, the model is restored to the best one seen on it.
     */
    @Override
    public void learn() {
        logger.info(() -> "Training starts...");
        printLogLn(new int[] { 7, 8, 9, 9, 9 }, new String[] { "#iter", "Sel. F.", scorer.name() + "-T", scorer.name() + "-V", "Status" });

        if (trainWithEnqueue) {
            int t = learn(1, true);
            //take care of the enqueued features
            for (int i = featureQueue.size() - 1; i >= 0; i--) {
                featureQueue.remove(i);
                t = learn(t, false);
            }
        } else {
            learn(1, false);
        }

        //if validation data is specified ==> best model on this data has been saved
        //we now restore the current model to that best model
        if (validationSamples != null && bestModelRankers.size() > 0) {
            rankers.clear();
            rweight.clear();
            rankers.addAll(bestModelRankers);
            rweight.addAll(bestModelWeights);
        }

        //print learning score
        scoreOnTrainingData = SimpleMath.round(scorer.score(rank(samples)), 4);
        logger.info(() -> "Finished sucessfully.");
        logger.info(() -> scorer.name() + " on training data: " + scoreOnTrainingData);
        if (validationSamples != null) {
            bestScoreOnValidationData = scorer.score(rank(validationSamples));
            logger.info(() -> scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
        }
    }

    /**
     * Scores a data point as the alpha-weighted sum of its selected feature values.
     *
     * @param p the data point to score.
     * @return the ensemble score; higher means ranked earlier.
     */
    @Override
    public double eval(final DataPoint p) {
        double score = 0.0;
        for (int j = 0; j < rankers.size(); j++) {
            score += rweight.get(j) * p.getFeatureValue(rankers.get(j).getFID());
        }
        return score;
    }

    @Override
    public Ranker createNew() {
        return new AdaRank();
    }

    /**
     * @return the model as space-separated "featureId:weight" pairs.
     */
    @Override
    public String toString() {
        final StringBuilder output = new StringBuilder();
        for (int i = 0; i < rankers.size(); i++) {
            output.append(rankers.get(i).getFID() + ":" + rweight.get(i) + ((i == rankers.size() - 1) ? "" : " "));
        }
        return output.toString();
    }

    /**
     * @return the serialized model: "##"-prefixed hyper-parameter header followed by
     * the {@link #toString()} pair list (readable by {@link #loadFromString(String)}).
     */
    @Override
    public String model() {
        final StringBuilder output = new StringBuilder();
        output.append("## " + name() + "\n");
        output.append("## Iteration = " + nIteration + "\n");
        output.append("## Train with enqueue: " + ((trainWithEnqueue) ? "Yes" : "No") + "\n");
        output.append("## Tolerance = " + tolerance + "\n");
        output.append("## Max consecutive selection count = " + maxSelCount + "\n");
        output.append(toString());
        return output.toString();
    }

    /**
     * Rebuilds the ensemble from a serialized model produced by {@link #model()}:
     * comment lines ("##") are skipped; the first non-empty data line holds the
     * "featureId:weight" pairs.
     *
     * @param fullText full model text.
     * @throws RankLibError if the text cannot be parsed.
     */
    @Override
    public void loadFromString(final String fullText) {
        try (BufferedReader in = new BufferedReader(new StringReader(fullText))) {
            String content = null;

            KeyValuePair kvp = null;
            while ((content = in.readLine()) != null) {
                content = content.trim();
                if (content.length() == 0) {
                    continue;
                }
                if (content.indexOf("##") == 0) {
                    continue;
                }
                kvp = new KeyValuePair(content);
                break;
            }

            assert (kvp != null);

            final List<String> keys = kvp.keys();
            final List<String> values = kvp.values();
            rweight = new ArrayList<>();
            rankers = new ArrayList<>();
            features = new int[keys.size()];
            for (int i = 0; i < keys.size(); i++) {
                features[i] = Integer.parseInt(keys.get(i));
                rankers.add(new WeakRanker(features[i]));
                rweight.add(Double.parseDouble(values.get(i)));
            }
        } catch (final Exception ex) {
            throw RankLibError.create("Error in AdaRank::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        logger.info(() -> "No. of rounds: " + nIteration);
        logger.info(() -> "Train with 'enequeue': " + ((trainWithEnqueue) ? "Yes" : "No"));
        logger.info(() -> "Tolerance: " + tolerance);
        logger.info(() -> "Max Sel. Count: " + maxSelCount);
    }

    @Override
    public String name() {
        return "AdaRank";
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy