All downloads are free. Search and download functionality uses the official Maven repository.

ciir.umass.edu.learning.RankerFactory Maven / Gradle / Ivy

/*===============================================================================
 * Copyright (c) 2010-2012 University of Massachusetts.  All Rights Reserved.
 *
 * Use of the RankLib package is subject to the terms of the software license set 
 * forth in the LICENSE file included with this software, and also available at
 * http://people.cs.umass.edu/~vdang/ranklib_license.html
 *===============================================================================
 */

package ciir.umass.edu.learning;

import ciir.umass.edu.learning.boosting.AdaRank;
import ciir.umass.edu.learning.boosting.RankBoost;
import ciir.umass.edu.learning.neuralnet.LambdaRank;
import ciir.umass.edu.learning.neuralnet.ListNet;
import ciir.umass.edu.learning.neuralnet.RankNet;
import ciir.umass.edu.learning.tree.LambdaMART;
import ciir.umass.edu.learning.tree.MART;
import ciir.umass.edu.learning.tree.RFRanker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.FileUtils;
import ciir.umass.edu.utilities.RankLibError;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.HashMap;
import java.util.List;

/**
 * Factory for rankers. Every ranking algorithm that should be creatable by type,
 * or loadable from a saved model, has to be registered in this class.
 *
 * <p>Rankers can be obtained in three ways: by {@code RANKER_TYPE}, by fully-qualified
 * class name (reflection), or by parsing the header line of a saved model.
 *
 * @author vdang
 */
public class RankerFactory {

	/**
	 * Prototype instances used to spawn new rankers via {@code createNew()}.
	 * The prototype at index {@code i} is returned for the {@code RANKER_TYPE} whose
	 * ordinal equals {@code RANKER_TYPE.MART.ordinal() + i}, so the order of this array
	 * is significant.
	 */
	protected Ranker[] rFactory = new Ranker[]{new MART(), new RankBoost(), new RankNet(), new AdaRank(), new CoorAscent(), new LambdaRank(), new LambdaMART(), new ListNet(), new RFRanker(), new LinearRegRank()};

	/**
	 * Maps the upper-cased ranker name (as reported by {@code Ranker.name()}) to its type.
	 * Used to resolve the algorithm named in the first line of a saved model.
	 * NOTE(review): this is shared static state that every constructor call rewrites
	 * without synchronization -- confirm factories are never constructed concurrently.
	 */
	protected static HashMap<String, RANKER_TYPE> map = new HashMap<String, RANKER_TYPE>();

	/**
	 * Creates a factory and (re-)registers the name of every built-in ranker type
	 * in the shared name-to-type map.
	 */
	public RankerFactory()
	{
		// Same ten types, in the same order, as the original one-put-per-line registration.
		RANKER_TYPE[] builtIn = new RANKER_TYPE[]{
			RANKER_TYPE.MART,
			RANKER_TYPE.RANKNET,
			RANKER_TYPE.RANKBOOST,
			RANKER_TYPE.ADARANK,
			RANKER_TYPE.COOR_ASCENT,
			RANKER_TYPE.LAMBDARANK,
			RANKER_TYPE.LAMBDAMART,
			RANKER_TYPE.LISTNET,
			RANKER_TYPE.RANDOM_FOREST,
			RANKER_TYPE.LINEAR_REGRESSION
		};
		for (RANKER_TYPE type : builtIn) {
			map.put(createRanker(type).name().toUpperCase(), type);
		}
	}

	/**
	 * Creates a fresh, unconfigured ranker of the given type.
	 *
	 * @param type the algorithm to instantiate
	 * @return a new ranker produced by the matching prototype's {@code createNew()}
	 */
	public Ranker createRanker(RANKER_TYPE type)
	{
		// Index is relative to MART, which is the first prototype in rFactory.
		return rFactory[type.ordinal() - RANKER_TYPE.MART.ordinal()].createNew();
	}

	/**
	 * Creates a ranker of the given type and configures it for training.
	 *
	 * @param type     the algorithm to instantiate
	 * @param samples  training data, handed to {@code Ranker.setTrainingSet}
	 * @param features feature ids, handed to {@code Ranker.setFeatures}
	 * @param scorer   metric, handed to {@code Ranker.setMetricScorer}
	 * @return the configured ranker
	 */
	public Ranker createRanker(RANKER_TYPE type, List samples, int[] features, MetricScorer scorer)
	{
		return configure(createRanker(type), samples, features, scorer);
	}

	/**
	 * Creates a ranker by reflection from a fully-qualified class name, using the
	 * class's no-argument constructor.
	 *
	 * <p>If the class cannot be found, instantiated or accessed, a message and the
	 * stack trace are printed and the JVM is terminated with exit status 1.
	 *
	 * @param className fully-qualified name of a class that can be cast to {@code Ranker}
	 * @return a new instance of the named class
	 */
	@SuppressWarnings("deprecation") // Class.newInstance() is deprecated on newer JDKs; kept so the caught exception types stay the same.
	public Ranker createRanker(String className)
	{
		Ranker r = null;
		try {
			Class<?> c = Class.forName(className);
			r = (Ranker) c.newInstance();
		}
		catch (ClassNotFoundException e) {
			System.out.println("Could not find the class \"" + className + "\" you specified. Make sure the jar library is in your classpath.");
			e.printStackTrace();
			System.exit(1);
		}
		catch (InstantiationException e) {
			System.out.println("Cannot create objects from the class \"" + className + "\" you specified.");
			e.printStackTrace();
			System.exit(1);
		}
		catch (IllegalAccessException e) {
			// This exception signals an inaccessible class or constructor, not a type mismatch.
			System.out.println("The class \"" + className + "\" or its no-argument constructor is not accessible.");
			e.printStackTrace();
			System.exit(1);
		}
		return r;
	}

	/**
	 * Creates a ranker by class name (see {@link #createRanker(String)}) and
	 * configures it for training.
	 *
	 * @param className fully-qualified name of the ranker class
	 * @param samples   training data, handed to {@code Ranker.setTrainingSet}
	 * @param features  feature ids, handed to {@code Ranker.setFeatures}
	 * @param scorer    metric, handed to {@code Ranker.setMetricScorer}
	 * @return the configured ranker
	 */
	public Ranker createRanker(String className, List samples, int[] features, MetricScorer scorer)
	{
		return configure(createRanker(className), samples, features, scorer);
	}

	/**
	 * Loads a ranker from a model file. The file is read through
	 * {@code FileUtils.read} with the "ASCII" encoding and then parsed by
	 * {@link #loadRankerFromString(String)}.
	 *
	 * @param modelFile path of the saved model
	 * @return the ranker restored from the file
	 */
	public Ranker loadRankerFromFile(String modelFile)
	{
		return loadRankerFromString(FileUtils.read(modelFile, "ASCII"));
	}

	/**
	 * Restores a ranker from the full text of a saved model.
	 *
	 * <p>The first line names the algorithm (optionally prefixed with "## "); the name is
	 * looked up case-insensitively among the registered rankers, a new ranker of that
	 * type is created, and the complete text is passed to its {@code loadFromString}.
	 *
	 * @param fullText complete contents of a saved model
	 * @return the restored ranker
	 * @throws RankLibError wrapping any failure, including empty model text or an
	 *         unrecognised ranker name in the header line
	 */
	public Ranker loadRankerFromString(String fullText)
	{
		try (BufferedReader in = new BufferedReader(new StringReader(fullText))) {
			String header = in.readLine();//read the first line to get the name of the ranking algorithm
			if (header == null) {
				throw new IllegalArgumentException("Model text is empty; expected the ranker name on the first line.");
			}
			String content = header.replace("## ", "").trim();
			System.out.println("Model:\t\t" + content);

			RANKER_TYPE type = map.get(content.toUpperCase());
			if (type == null) {
				throw new IllegalArgumentException("Unknown ranker name \"" + content + "\" in model header.");
			}

			Ranker r = createRanker(type);
			r.loadFromString(fullText);
			return r;
		}
		catch (Exception ex) {
			// Single exit path for every failure, so callers keep seeing RankLibError.
			throw RankLibError.create(ex);
		}
	}

	/** Applies training set, features and metric scorer to a ranker and returns it. */
	private Ranker configure(Ranker r, List samples, int[] features, MetricScorer scorer)
	{
		r.setTrainingSet(samples);
		r.setFeatures(features);
		r.setMetricScorer(scorer);
		return r;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy