/*===============================================================================
 * Copyright (c) 2010-2012 University of Massachusetts.  All Rights Reserved.
 *
 * Use of the RankLib package is subject to the terms of the software license set 
 * forth in the LICENSE file included with this software, and also available at
 * http://people.cs.umass.edu/~vdang/ranklib_license.html
 *===============================================================================
 */

package ciir.umass.edu.learning;

import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.KeyValuePair;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * @author vdang
 * 
 * This class implements the linear ranking model known as Coordinate Ascent. It was proposed in this paper:
 *  D. Metzler and W.B. Croft. Linear feature-based models for information retrieval. Information Retrieval, 10(3): 257-274, 2007.
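 * 
 * A minimal usage sketch (hypothetical variable names; the samples, feature list, and
 * metric scorer come from the usual RankLib plumbing):
 * <pre>
 *   CoorAscent ca = new CoorAscent(samples, features, scorer);
 *   ca.init();              // uniform starting weights
 *   ca.learn();             // coordinate ascent with random restarts
 *   double s = ca.eval(p);  // score a DataPoint with the learned weights
 * </pre>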
 */
public class CoorAscent extends Ranker {

	//Parameters
	public static int nRestart = 5;
	public static int nMaxIteration = 25;
	public static double stepBase = 0.05;
	public static double stepScale = 2.0;
	public static double tolerance = 0.001;
	public static boolean regularized = false;
	public static double slack = 0.001;//regularization parameter
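	//How these parameters drive the search (see learn() below): the optimizer makes nRestart
	//independent restarts; within each restart, every feature's weight is line-searched in both
	//directions (and at zero) for up to nMaxIteration steps that grow by a factor of stepScale,
	//with the initial step seeded at 0.001 or at stepBase times the current weight's magnitude;
	//a restart stops sweeping once the metric gain falls below tolerance; and when regularized
	//is true, slack scales the penalty on the distance from the uniform weight vector.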
	
	//Local variables
	public double[] weight = null;
	
	protected int current_feature = -1;//used only during learning
	protected double weight_change = -1.0;//used only during learning
	
	public CoorAscent()
	{
		
	}
	public CoorAscent(List<RankList> samples, int[] features, MetricScorer scorer)
	{
		super(samples, features, scorer);
	}
	
	public void init()
	{
		PRINT("Initializing with " + features.length + " features... ");
		weight = new double[features.length];
		Arrays.fill(weight, 1.0 / features.length);
		PRINTLN("[Done]");
	}
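	/**
	 * Runs the coordinate ascent search: for each of nRestart random restarts, features are
	 * visited in shuffled order and each weight is line-searched (increasing, decreasing, and
	 * zeroed out) while all other weights stay fixed; the best weight vector across restarts,
	 * measured on validation data when available, is kept. Per-document scores are cached so
	 * that each single-weight change is re-scored incrementally instead of from scratch.
	 */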
	public void learn()
	{
		double[] regVector = new double[weight.length];
		copy(weight, regVector);//uniform weight distribution
		
		//this holds the final best model/score
		double[] bestModel = null;
		double bestModelScore = 0.0;

		// look in both directions and with feature removed.
		final int[] sign = new int[]{1, -1, 0};
		
		PRINTLN("---------------------------");
		PRINTLN("Training starts...");
		PRINTLN("---------------------------");
		
		for(int r=0;r<nRestart;r++)
		{
			PRINTLN("[+] Random restart #" + (r+1) + "/" + nRestart + "...");
			int consecutive_fails = 0;
			
			//initialize weight vector
			for(int i=0;i<weight.length;i++)
				weight[i] = 1.0f/features.length;
			
			current_feature = -1;
			double startScore = scorer.score(rank(samples));//compute all the scores (in whatever metric specified) and store them as cache
			
			//local best (within the current restart cycle)
			double bestScore = startScore;
			double[] bestWeight = new double[weight.length];
			copy(weight, bestWeight);
			
			//keep sweeping until updating any (weight.length - 1) consecutive features fails to improve the score
			while((weight.length>1&&consecutive_fails < weight.length - 1) || (weight.length==1&&consecutive_fails==0))
			{
				PRINTLN("Shuffling features' order... [Done.]");
				PRINTLN("Optimizing weight vector... ");
				PRINTLN("------------------------------");
				PRINTLN(new int[]{7, 8, 7}, new String[]{"Feature", "weight", scorer.name()});
				PRINTLN("------------------------------");

				int[] fids = getShuffledFeatures();//contain index of elements in the variable @features
				//Try maximizing each feature individually
				for(int i=0;i<fids.length;i++)
				{
					current_feature = fids[i];//this will trigger the "else" branch in the procedure rank()
					
					double origWeight = weight[fids[i]];
					double totalStep = 0;
					double bestTotalStep = 0;
					boolean succeeds = false;//whether or not we succeed in finding a better weight value for the current feature
					for(int s=0;s<sign.length;s++)//search by both increasing and decreasing
					{
						int dir = sign[s];
						double step = 0.001 * dir;
						if(origWeight != 0.0 && Math.abs(step) > 0.5 * Math.abs(origWeight))
					    	step = stepBase * Math.abs(origWeight);
						totalStep = step;
						int numIter = nMaxIteration;
						if(dir == 0) {
							numIter = 1;
							totalStep = -origWeight;
						}
						for (int j = 0; j < numIter; j++)
						{
							double w = origWeight + totalStep;
							weight_change = step;//weight_change is used in the "else" branch in the procedure rank()
							weight[fids[i]] = w;
							double score = scorer.score(rank(samples));
							if(regularized)
							{
								double penalty = slack * getDistance(weight, regVector);
								score -= penalty;
								//PRINTLN("Penalty: " + penalty);
							}
							if(score > bestScore)//better than the local best, replace the local best with this model
							{
								bestScore = score;
								bestTotalStep = totalStep;
								succeeds = true;
								String bw = ((weight[fids[i]]>0)?"+":"") + SimpleMath.round(weight[fids[i]], 4);
								PRINTLN(new int[]{7, 8, 7}, new String[]{features[fids[i]]+"", bw+"", SimpleMath.round(bestScore, 4)+""});
							}
							if(j < nMaxIteration-1)
							{
								step *= stepScale;
								totalStep += step;
							}
						}
						if(succeeds)
							break;//no need to search the other direction (e.g. sign = '-')
						else if(s < sign.length-1)
						{
							weight_change = -totalStep;
							updateCached();//restore the cached scores to reflect the orig. weight for the current feature
							//so that we can start searching in the other direction (since the optimization in the first direction failed)
							weight[fids[i]] = origWeight;//restore the weight to its initial value
						}
					}
					if(succeeds) 
					{
						weight_change = bestTotalStep - totalStep;
						updateCached();//restore the cached scores to reflect the best weight for the current feature
						weight[fids[i]] = origWeight + bestTotalStep;
						consecutive_fails = 0;//since we found a better weight value
						double sum = normalize(weight);
						scaleCached(sum);
						copy(weight, bestWeight);						
					}
					else
					{
						consecutive_fails++;
						weight_change = -totalStep;
						updateCached();//restore the cached scores to reflect the orig. weight for the current feature since the optimization failed
						//Restore the orig. weight value
						weight[fids[i]] = origWeight;
					}
				}
				PRINTLN("------------------------------");
				
				//if we haven't made much progress then quit
				if(bestScore - startScore < tolerance)
					break;
			}
			//update the (global) best model with the best model found in this round
			if(validationSamples != null)
			{
				current_feature = -1;
				bestScore = scorer.score(rank(validationSamples));
			}			
			if(bestModel == null || bestScore > bestModelScore)
			{
				bestModelScore = bestScore;
				bestModel = bestWeight;
			}
		}
		
		copy(bestModel, weight);
		current_feature = -1;//turn off the cache mode
		scoreOnTrainingData = SimpleMath.round(scorer.score(rank(samples)), 4);
		PRINTLN("---------------------------------");
		PRINTLN("Finished sucessfully.");
		PRINTLN(scorer.name() + " on training data: " + scoreOnTrainingData);

		if(validationSamples != null)
		{
			bestScoreOnValidationData = scorer.score(rank(validationSamples));
			PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
		}
		PRINTLN("---------------------------------");
	}
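	/**
	 * Ranks rl under the current model. When current_feature == -1 the full dot product is
	 * computed and cached per document; otherwise only the delta from the latest change to
	 * current_feature is applied to the cached score. For example (hypothetical numbers): a
	 * cached score of 0.7 with a feature value of 2.0 and a weight change of +0.1 becomes
	 * 0.7 + 0.1*2.0 = 0.9.
	 */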
	public RankList rank(RankList rl)
	{
		double[] score = new double[rl.size()];
		if(current_feature == -1)
		{
			for(int i=0;i<rl.size();i++)
			{
				for(int j=0;j<features.length;j++)
					score[i] += weight[j] * rl.get(i).getFeatureValue(features[j]);
				rl.get(i).setCached(score[i]);//use the cache of a data point to store its score given the model at this state
			}
		}
		else//This branch is only active during the training process. Here we trade the "clean" code for efficiency
		{
			for(int i=0;i<rl.size();i++)
			{
				//cached score = a_1*x_1 + a_2*x_2 + ... + a_n*x_n
				//a_2 ====> a'_2
				//new score = cached score + (a'_2 - a_2)*x_2  ====> NO NEED TO RE-COMPUTE THE WHOLE THING
				score[i] = rl.get(i).getCached() + weight_change * rl.get(i).getFeatureValue(features[current_feature]);
				rl.get(i).setCached(score[i]);
			}
		}
		int[] idx = MergeSorter.sort(score, false); 
		return new RankList(rl, idx);
	}
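	/**
	 * Scores a single document as the weighted sum of its selected feature values:
	 * score(p) = sum_i weight[i] * featureValue(features[i]).
	 */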
	public double eval(DataPoint p)
	{
		double score = 0.0;
		for(int i=0;i<features.length;i++)
			score += weight[i] * p.getFeatureValue(features[i]);
		return score;
	}
	public Ranker createNew()
	{
		return new CoorAscent();
	}
	public String toString()
	{
		String output = "";
		for(int i=0;i<weight.length;i++)
			output += features[i] + ":" + weight[i] + ((i==weight.length-1)?"":" ");
		return output;
	}
	public String model()
	{
		String output = "## " + name() + "\n";
		output += "## Restart = " + nRestart + "\n";
		output += "## MaxIteration = " + nMaxIteration + "\n";
		output += "## StepBase = " + stepBase + "\n";
		output += "## StepScale = " + stepScale + "\n";
		output += "## Tolerance = " + tolerance + "\n";
		output += "## Regularized = " + regularized + "\n";
		output += "## Slack = " + slack + "\n";
		output += toString();
		return output;
	}
	public void loadFromString(String fullText)
	{
		try {
			KeyValuePair kvp = null;
			BufferedReader in = new BufferedReader(new StringReader(fullText));
			while(true)
			{
				String content = in.readLine();
				if(content == null)
					break;
				content = content.trim();
				if(content.length() == 0)
					continue;
				if(content.indexOf("##")==0)
					continue;
				//found the line containing the feature:weight pairs
				kvp = new KeyValuePair(content);
				break;
			}
			in.close();
			
			List<String> keys = kvp.keys();
			List<String> values = kvp.values();
			weight = new double[keys.size()];
			features = new int[keys.size()];
			for(int i=0;i<keys.size();i++)
			{
				features[i] = Integer.parseInt(keys.get(i));
				weight[i] = Double.parseDouble(values.get(i));
			}
		}
		catch(Exception ex)
		{
			throw RankLibError.create("Error in CoorAscent::load(): ", ex);
		}
	}
	public void printParameters()
	{
		PRINTLN("No. of random restarts: " + nRestart);
		PRINTLN("No. of iterations to search in each direction: " + nMaxIteration);
		PRINTLN("Tolerance: " + tolerance);
		if(regularized)
			PRINTLN("Reg. param: " + slack);
		else
			PRINTLN("No regularization.");
	}
	public String name()
	{
		return "Coordinate Ascent";
	}
	private void updateCached()
	{
		for(int j=0;j<samples.size();j++)
		{
			RankList rl = samples.get(j);
			for(int i=0;i<rl.size();i++)
			{
				//cached score = a_1*x_1 + a_2*x_2 + ... + a_n*x_n
				//a_2 ====> a'_2
				//new score = cached score + (a'_2 - a_2)*x_2  ====> NO NEED TO RE-COMPUTE THE WHOLE THING
				double score = rl.get(i).getCached() + weight_change * rl.get(i).getFeatureValue(features[current_feature]);
				rl.get(i).setCached(score);
			}
		}
	}
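	//normalize() divides every weight by the same sum; since cached scores are linear in the
	//weights, dividing the cache by that sum keeps it consistent without re-scoring.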
	private void scaleCached(double sum)
	{
		for(int j=0;j<samples.size();j++)
		{
			RankList rl = samples.get(j);
			for(int i=0;i<rl.size();i++)
				rl.get(i).setCached(rl.get(i).getCached()/sum);
		}
	}
	private int[] getShuffledFeatures()
	{
		int[] fids = new int[features.length];
		List<Integer> l = new ArrayList<Integer>();
		for(int i=0;i<features.length;i++)
			l.add(i);
		Collections.shuffle(l);
		for(int i=0;i<l.size();i++)
			fids[i] = l.get(i);
		return fids;
	}
	private double getDistance(double[] w1, double[] w2)
	{
		assert(w1.length == w2.length);
		double s1 = 0.0;
		double s2 = 0.0;
		for(int i=0;i<w1.length;i++)
		{
			s1 += Math.abs(w1[i]);
			s2 += Math.abs(w2[i]);
		}
		double dist = 0.0;
		for(int i=0;i<w1.length;i++)
		{
			double t = w1[i]/s1 - w2[i]/s2;
			dist += t*t;
		}
		return Math.sqrt(dist);
	}
	protected double normalize(double[] weights)
	{
		double sum = 0.0;
		for(int j=0;j<weights.length;j++)
			sum += Math.abs(weights[j]);
		if(sum > 0)
		{
			for(int j=0;j<weights.length;j++)
				weights[j] /= sum;
		}
		else
		{
			sum = 1;
			for(int j=0;j<weights.length;j++)
				weights[j] = 1.0/weights.length;
		}
		return sum;
	}
}



