// All downloads are free. Search and download functionality uses the official Maven repository.
//
// java.main.ivory.ltr.GreedyLearn (Maven / Gradle / Ivy)

/*
 * Ivory: A Hadoop toolkit for web-scale information retrieval
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You may
 * obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0 
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package ivory.ltr;



import ivory.core.ConfigurationException;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;


/**
 * @author Don Metzler
 *
 */
public class GreedyLearn {

	private static final double TOLERANCE = 0.0001;

	public void train(String featFile, String modelOutputFile, int numModels, String metricClassName, boolean pruneCorrelated, double correlationThreshold, boolean logFeatures, boolean productFeatures, boolean quotientFeatures, int numThreads) throws IOException, InterruptedException, ExecutionException, ConfigurationException, InstantiationException, IllegalAccessException, ClassNotFoundException {		
		// read training instances
		Instances trainInstances = new Instances(featFile);

		// get feature map (mapping from feature names to feature number)
		Map featureMap = trainInstances.getFeatureMap();

		// construct initial model
		Model initialModel = new Model();

		// initialize feature pools
		Map> featurePool = new HashMap>();
		featurePool.put(initialModel, new ArrayList());

		// add simple features to feature pools
		for(String featureName : featureMap.keySet()) {
			featurePool.get(initialModel).add(new SimpleFeature(featureMap.get(featureName), featureName));
		}

		// eliminate document-independent features
		List constantFeatures = new ArrayList();
		for(int i = 0; i < featurePool.size(); i++) {
			Feature f = featurePool.get(initialModel).get(i);
			if(trainInstances.featureIsConstant(f)) {
				System.err.println("Feature " + f.getName() + " is constant -- removing from feature pool!");
				constantFeatures.add(f);
			}
		}
		featurePool.get(initialModel).removeAll(constantFeatures);

		// initialize score tables
		Map scoreTable = new HashMap();
		scoreTable.put(initialModel, new ScoreTable(trainInstances));

		// initialize model queue
		List models = new ArrayList();
		models.add(initialModel);

		// set up threading
		ExecutorService threadPool = Executors.newFixedThreadPool(numThreads);

		Map>> featureBatches = new HashMap>>();
		featureBatches.put(initialModel, new ArrayList>());

		for(int i = 0; i < numThreads; i++) {
			featureBatches.get(initialModel).add(new ArrayList());
		}

		for(int i = 0; i < featurePool.get(initialModel).size(); i++) {
			featureBatches.get(initialModel).get(i % numThreads).add(featurePool.get(initialModel).get(i));
		}

		// greedily add features
		double curMetric = 0.0;
		double prevMetric = Double.NEGATIVE_INFINITY;
		int iter = 1;

		while(curMetric - prevMetric > TOLERANCE ) {

			Map modelFeaturePairMeasures = new HashMap();

			// update models
			for(Model model : models) {
				List>> futures = new ArrayList>>();
				for(int i = 0; i < numThreads; i++) {
					// construct measure
					Measure metric = (Measure)Class.forName(metricClassName).newInstance();

					// line searcher
					LineSearch search = new LineSearch(model, featureBatches.get(model).get(i), scoreTable.get(model), metric);
					
					Future> future = threadPool.submit(search);
					futures.add(future);
				}

				for(int i = 0; i < numThreads; i++) {
					Map featAlphaMeasureMap = futures.get(i).get();
					for(Feature f : featAlphaMeasureMap.keySet()) {
						AlphaMeasurePair featAlphaMeasure = featAlphaMeasureMap.get(f);
						modelFeaturePairMeasures.put(new ModelFeaturePair(model, f), featAlphaMeasure);
					}
				}				
			}

			// sort model-feature pairs
			List modelFeaturePairs = new ArrayList(modelFeaturePairMeasures.keySet());
			Collections.sort(modelFeaturePairs, new ModelFeatureComparator(modelFeaturePairMeasures));

			// preserve current list of models
			List oldModels = new ArrayList(models);

			// add best model feature pairs to pool
			models = new ArrayList();

			//Lidan: here consider top-K features, rather than just the best one

			for(int i = 0; i < numModels; i++) {
				Model model = modelFeaturePairs.get(i).model;
				Feature feature = modelFeaturePairs.get(i).feature;
				String bestFeatureName = feature.getName();
				AlphaMeasurePair bestAlphaMeasure = modelFeaturePairMeasures.get(modelFeaturePairs.get(i));

				System.err.println("Model = " + model);
				System.err.println("Best feature: " + bestFeatureName);
				System.err.println("Best alpha: " + bestAlphaMeasure.alpha);
				System.err.println("Best measure: " + bestAlphaMeasure.measure);

				Model newModel = new Model(model);
				models.add(newModel);

				ArrayList> newFeatureBatch = new ArrayList>();
				for(ArrayList fb : featureBatches.get(model)) {
					newFeatureBatch.add(new ArrayList(fb));
				}
				featureBatches.put(newModel, newFeatureBatch);
				featurePool.put(newModel, new ArrayList(featurePool.get(model)));

				// add auxiliary features (for atomic features only)
				if(featureMap.containsKey(bestFeatureName)) {
					int bestFeatureIndex = featureMap.get(bestFeatureName);

					// add log features, if requested
					if(logFeatures) {
						Feature logFeature = new LogFeature(bestFeatureIndex, "log(" + bestFeatureName + ")");
						featureBatches.get(newModel).get(bestFeatureIndex % numThreads).add(logFeature);
						featurePool.get(newModel).add(logFeature);
					}

					// add product features, if requested
					if(productFeatures) {
						for(String featureNameB : featureMap.keySet()) {
							int indexB = featureMap.get(featureNameB);
							Feature prodFeature = new ProductFeature(bestFeatureIndex, indexB, bestFeatureName + "*" + featureNameB);
							featureBatches.get(newModel).get(indexB % numThreads).add(prodFeature);
							featurePool.get(newModel).add(prodFeature);
						}
					}

					// add quotient features, if requested
					if(quotientFeatures) {
						for(String featureNameB : featureMap.keySet()) {
							int indexB = featureMap.get(featureNameB);
							Feature divFeature = new QuotientFeature(bestFeatureIndex, indexB, bestFeatureName + "/" + featureNameB);
							featureBatches.get(newModel).get(indexB % numThreads).add(divFeature);
							featurePool.get(newModel).add(divFeature);
						}
					}
				}

				// prune highly correlated features
				if(pruneCorrelated) {
					if(!newModel.containsFeature(feature)) {
						List correlatedFeatures = new ArrayList();

						for(Feature f : featurePool.get(newModel)) {
							if(f == feature) {
								continue;
							}
							double correl = trainInstances.getCorrelation(f, feature);
							if(correl > correlationThreshold) {
								System.err.println("Pruning highly correlated feature: " + f.getName());
								correlatedFeatures.add(f);
							}
						}

						for(ArrayList batch : featureBatches.get(newModel)) {
							batch.removeAll(correlatedFeatures);
						}

						featurePool.get(newModel).removeAll(correlatedFeatures);
					}
				}

				// update score table
				if(iter == 0) {
					scoreTable.put(newModel, scoreTable.get(model).translate(feature, 1.0, 1.0));
					newModel.addFeature(feature, 1.0);
				}
				else {
					scoreTable.put(newModel, scoreTable.get(model).translate(feature, bestAlphaMeasure.alpha, 1.0 / (1.0 + bestAlphaMeasure.alpha)));
					newModel.addFeature(feature, bestAlphaMeasure.alpha);
				}
			}

			for(Model model : oldModels) {
				featurePool.remove(model);
				featureBatches.remove(model);
				scoreTable.remove(model);
			}

			// update metrics
			prevMetric = curMetric;
			curMetric = modelFeaturePairMeasures.get(modelFeaturePairs.get(0)).measure;

			iter++;
		}

		// serialize model
		System.out.println("Final Model: " + models.get(0));
		models.get(0).write(modelOutputFile);

		threadPool.shutdown();
	}

	public class ModelFeaturePair {
		public Model model;
		public Feature feature;

		public ModelFeaturePair(Model m, Feature f) {
			model = m;
			feature = f;
		}
	}

	public class ModelFeatureComparator implements Comparator {

		private Map lookup = null;

		public ModelFeatureComparator(Map lookup) {
			this.lookup = lookup;
		}

		public int compare(ModelFeaturePair o1, ModelFeaturePair o2) {
			if(lookup.get(o1).measure > lookup.get(o2).measure) {
				return -1;
			}
			else if(lookup.get(o1).measure < lookup.get(o2).measure) {
				return 1;
			}
			return 0;
		}		
	}	

	@SuppressWarnings("static-access")
	public static void main(String[] args) throws InterruptedException, ExecutionException {
		Options options = new Options();

		options.addOption( OptionBuilder.withArgName("input").hasArg().withDescription("Input file that contains training instances.").isRequired().create("input") );
		options.addOption( OptionBuilder.withArgName("model").hasArg().withDescription("Model file to create.").isRequired().create("model") );		
		options.addOption( OptionBuilder.withArgName("numModels").hasArg().withDescription("Number of models to consider each iteration (default=1).").create("numModels") );
		options.addOption( OptionBuilder.withArgName("className").hasArg().withDescription("Java class name of metric to optimize for (default=ivory.ltr.NDCGMeasure)").create("metric") );
		options.addOption( OptionBuilder.withArgName("threshold").hasArg().withDescription("Feature correlation threshold for pruning (disabled by default).").create("pruneCorrelated") );
		options.addOption( OptionBuilder.withArgName("log").withDescription("Include log features (default=false).").create("log") );
		options.addOption( OptionBuilder.withArgName("product").withDescription("Include product features (default=false).").create("product") );
		options.addOption( OptionBuilder.withArgName("quotient").withDescription("Include quotient features (default=false).").create("quotient") );
		options.addOption( OptionBuilder.withArgName("numThreads").hasArg().withDescription("Number of threads to utilize (default=1).").create("numThreads") );

		HelpFormatter formatter = new HelpFormatter();
		CommandLineParser parser = new GnuParser();

		String trainFile = null;
		String modelOutputFile = null;

		int numModels = 1;

		String metricClassName = "ivory.ltr.NDCGMeasure";
		
		boolean pruneCorrelated = false;
		double correlationThreshold = 1.0;
		
		boolean logFeatures = false;
		boolean productFeatures = false;
		boolean quotientFeatures = false;

		int numThreads = 1;

		// parse the command-line arguments
		try {
			CommandLine line = parser.parse( options, args);

			if(line.hasOption("input")) {
				trainFile = line.getOptionValue("input");
			}

			if(line.hasOption("model")) {
				modelOutputFile = line.getOptionValue("model");
			}

			if(line.hasOption("numModels")) {
				numModels = Integer.parseInt(line.getOptionValue("numModels"));
			}

			if(line.hasOption("metric")) {
				metricClassName = line.getOptionValue("metric");
			}

			if(line.hasOption("pruneCorrelated")) {
				pruneCorrelated = true;
				correlationThreshold = Double.parseDouble(line.getOptionValue("pruneCorrelated"));
			}

			if(line.hasOption("numThreads")) {
				numThreads = Integer.parseInt(line.getOptionValue("numThreads"));
			}

			if(line.hasOption("log")) {
				logFeatures = true;
			}

			if(line.hasOption("product")) {
				productFeatures = true;
			}

			if(line.hasOption("quotient")) {
				quotientFeatures = true;
			}
		}
		catch(ParseException exp) {
			System.err.println(exp.getMessage());
		}		

		// were all of the required parameters specified?
		if(trainFile == null || modelOutputFile == null) {
			formatter.printHelp("GreedyLearn", options, true);
			System.exit(-1);
		}

		// learn the model
		try {
			GreedyLearn learn = new GreedyLearn();
			learn.train(trainFile, modelOutputFile, numModels, metricClassName, pruneCorrelated, correlationThreshold, logFeatures, productFeatures, quotientFeatures, numThreads);
		} catch (IOException e) {
			e.printStackTrace();
		} catch (ConfigurationException e) {
			e.printStackTrace();
		} catch (InstantiationException e) {
			e.printStackTrace();
		} catch (IllegalAccessException e) {
			e.printStackTrace();
		} catch (ClassNotFoundException e) {
			e.printStackTrace();
		}
	}
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy