All downloads are free. The search and download features use the official Maven repository.

edu.uci.jforestsx.applications.ClassificationApp Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package edu.uci.jforestsx.applications;

import java.io.InputStream;
import java.util.Properties;
import java.util.Random;

import edu.uci.jforestsx.config.TrainingConfig;
import edu.uci.jforestsx.dataset.Dataset;
import edu.uci.jforestsx.dataset.DatasetLoader;
import edu.uci.jforestsx.eval.AUC;
import edu.uci.jforestsx.eval.Accuracy;
import edu.uci.jforestsx.eval.BalancedYoundenIndex;
import edu.uci.jforestsx.eval.EvaluationMetric;
import edu.uci.jforestsx.eval.RMSE;
import edu.uci.jforestsx.learning.LearningModule;
import edu.uci.jforestsx.learning.LearningProgressListener;
import edu.uci.jforestsx.learning.boosting.GradientBoosting;
import edu.uci.jforestsx.learning.classification.GradientBoostingBinaryClassifier;
import edu.uci.jforestsx.learning.trees.Ensemble;
import edu.uci.jforestsx.learning.trees.decision.RandomForest;
import edu.uci.jforestsx.learning.trees.regression.RegressionTreeLearner;
import edu.uci.jforestsx.sample.Sample;
import edu.uci.jforestsx.util.ConfigHolder;
import edu.uci.jforestsx.util.Constants;
import edu.uci.jforestsx.util.IOUtils;
import edu.uci.jforestsx.util.Timer;
import edu.uci.jforestsx.util.concurrency.BlockingThreadPoolExecutor;

/**
 * Driver that trains an ensemble from a set of configuration properties.
 *
 * <p>
 * {@link #run(Properties)} is the entry point: it loads the configuration,
 * loads the training (and optionally validation) data set, builds the training
 * and validation samples, creates the chain of learning modules named in the
 * configuration and finally asks the top-most module to learn an
 * {@link Ensemble}.
 * </p>
 *
 * <p>
 * Most steps are exposed as {@code protected} hooks ({@link #initIOUtils()},
 * {@link #loadConfig()}, {@link #newDataset()}, {@link #initDataset(Dataset)},
 * {@link #createSample(Dataset, boolean)}, ...) so that subclasses can
 * customize them. Note that the constructor calls {@link #initIOUtils()}, so an
 * override of that hook runs before the subclass' own field initializers.
 * </p>
 *
 * <p>
 * Instances keep mutable state between calls and the class performs no
 * synchronization of its own.
 * </p>
 *
 * @author Yasser Ganjisaffar
 */

public class ClassificationApp {

	/** Data set backing the training sample; created lazily by {@link #run(Properties)}. */
	protected Dataset trainDataset;
	/** Data set backing the validation sample, or {@code null} when no validation file is configured. */
	protected Dataset validDataset;
	/** First module of the learner chain built by {@link #createLearner()}. */
	protected LearningModule topLearner;

	/** Sample handed to the learner as training data. */
	protected Sample trainSet;
	/** Sample handed to the learner as validation data; may be {@code null}. */
	protected Sample validSet;

	/** Helper used to open input streams for data sets and feature-name files. */
	protected IOUtils ioUtils;
	/** Metric instance resolved from {@code trainingConfig.evaluationMetric}. */
	protected EvaluationMetric evaluationMetric;
	/** Random generator seeded with {@code trainingConfig.randomSeed}; used for sub-sampling. */
	protected Random rnd;

	/** Typed view of the training configuration, populated by {@link #loadConfig()}. */
	protected TrainingConfig trainingConfig;
	/** Holder wrapping the raw properties passed to {@link #run(Properties)}. */
	protected ConfigHolder configHolder;

	/** Optional listener attached to every learning module; {@code null} means "none". */
	protected LearningProgressListener progressListener = null;

	/**
	 * Creates the application and makes sure {@link #ioUtils} is available.
	 */
	public ClassificationApp() {
		initIOUtils();
	}

	/**
	 * Sets the listener that will be attached to every learning module created
	 * by subsequent calls to {@link #run(Properties)}.
	 *
	 * @param progressListener
	 *            the listener, or {@code null} to attach none
	 */
	public void setProgressListener(LearningProgressListener progressListener) {
		this.progressListener = progressListener;
	}

	/**
	 * Creates the default {@link IOUtils} unless one has already been assigned.
	 * Called from the constructor.
	 */
	protected void initIOUtils() {
		if (ioUtils == null) {
			ioUtils = new IOUtils();
		}
	}

	/**
	 * Builds {@link #trainingConfig} from the current {@link #configHolder}.
	 */
	protected void loadConfig() {
		trainingConfig = new TrainingConfig();
		trainingConfig.init(configHolder);
	}

	/**
	 * Prepares the thread pool and the data sets before learning starts.
	 *
	 * <p>
	 * Initializes the shared {@link BlockingThreadPoolExecutor} with the
	 * configured number of threads, runs the {@link #initDataset(Dataset)} hook
	 * on the training data set and on the data set behind the validation sample
	 * (if any), and loads external feature names when a file is configured.
	 * </p>
	 *
	 * @throws Exception
	 *             if any of the underlying steps fails
	 */
	protected void init() throws Exception {
		BlockingThreadPoolExecutor.init(trainingConfig.numThreads);

		initDataset(trainDataset);
		if (validSet != null) {
			initDataset(validSet.dataset);
		}

		if (trainingConfig.featureNamesFilename != null) {
			InputStream featureNamesStream = ioUtils.getInputStream(trainingConfig.featureNamesFilename);
			try {
				trainDataset.loadFeatureNamesFromExternalResource(featureNamesStream);
			} finally {
				// The stream was opened here, so release it here. Closing a stream
				// that the callee may already have closed is harmless.
				if (featureNamesStream != null) {
					featureNamesStream.close();
				}
			}
		}
	}

	/**
	 * Creates and initializes the learning module with the given name.
	 *
	 * <p>
	 * Recognized names are {@code GradientBoostingBinaryClassifier},
	 * {@code GradientBoosting}, {@code RegressionTree} and {@code RandomForest}.
	 * When no validation data set is loaded, the number of training instances
	 * is also used as the validation capacity.
	 * </p>
	 *
	 * @param name
	 *            algorithm name as it appears in the configuration
	 * @return the initialized module
	 * @throws Exception
	 *             if the name is {@code null} or not recognized, or if module
	 *             initialization fails
	 */
	protected LearningModule getLearningModule(String name) throws Exception {
		int maxNumTrainInstances = trainDataset.numInstances;
		int maxNumValidInstances = (validDataset != null ? validDataset.numInstances : trainDataset.numInstances);
		// Constant-first equals() so that a null name reaches the descriptive
		// "Unknown algorithm" error below instead of a bare NullPointerException.
		if ("GradientBoostingBinaryClassifier".equals(name)) {
			GradientBoostingBinaryClassifier learner = new GradientBoostingBinaryClassifier();
			learner.init(configHolder, maxNumTrainInstances, maxNumValidInstances, evaluationMetric);
			return learner;
		} else if ("GradientBoosting".equals(name)) {
			GradientBoosting learner = new GradientBoosting();
			learner.init(configHolder, maxNumTrainInstances, maxNumValidInstances, evaluationMetric);
			return learner;
		} else if ("RegressionTree".equals(name)) {
			RegressionTreeLearner learner = new RegressionTreeLearner();
			learner.init(trainDataset, configHolder, maxNumTrainInstances);
			return learner;
		} else if ("RandomForest".equals(name)) {
			RandomForest learner = new RandomForest();
			learner.init(trainDataset, configHolder, maxNumTrainInstances, maxNumValidInstances, evaluationMetric);
			return learner;
		} else {
			throw new Exception("Unknown algorithm: " + name);
		}
	}

	/**
	 * Creates the evaluation metric with the given name.
	 *
	 * <p>
	 * Recognized names are {@code AUC}, {@code RMSE}, {@code Accuracy} and
	 * {@code BalancedYoundenIndex}.
	 * </p>
	 *
	 * @param name
	 *            metric name as it appears in the configuration
	 * @return a new metric instance
	 * @throws Exception
	 *             if the name is {@code null} or not recognized
	 */
	protected EvaluationMetric getEvaluationMetric(String name) throws Exception {
		// Constant-first equals() so that a null name reaches the descriptive
		// "Unknown evaluation metric" error instead of a NullPointerException.
		if ("AUC".equals(name)) {
			return new AUC();
		} else if ("RMSE".equals(name)) {
			return new RMSE();
		} else if ("Accuracy".equals(name)) {
			return new Accuracy();
		} else if ("BalancedYoundenIndex".equals(name)) {
			return new BalancedYoundenIndex();
		} else {
			throw new Exception("Unknown evaluation metric: " + name);
		}
	}

	/**
	 * Builds the chain of learning modules described by
	 * {@code trainingConfig.learningAlgorithm}.
	 *
	 * <p>
	 * The configured value is split on {@code '-'}; the first part becomes
	 * {@link #topLearner} and each following part is installed as the
	 * sub-module of the part before it. The progress listener, if set, is
	 * attached to every module.
	 * </p>
	 *
	 * @throws Exception
	 *             if any part names an unknown algorithm or fails to initialize
	 */
	protected void createLearner() throws Exception {
		String[] parts = trainingConfig.learningAlgorithm.split("-");
		topLearner = getLearningModule(parts[0]);
		if (progressListener != null) {
			topLearner.setProgressListener(progressListener);
		}
		LearningModule curModule = topLearner;
		for (int i = 1; i < parts.length; i++) {
			LearningModule newModule = getLearningModule(parts[i]);
			if (progressListener != null) {
				newModule.setProgressListener(progressListener);
			}
			curModule.setSubModule(newModule);
			curModule = newModule;
		}
	}

	/**
	 * Reads a data set from an already opened stream. The stream is not closed
	 * by this method itself.
	 *
	 * @param in
	 *            stream to read from
	 * @param dataset
	 *            data set to fill
	 * @throws Exception
	 *             if loading fails
	 */
	protected void loadDataset(InputStream in, Dataset dataset) throws Exception {
		DatasetLoader.load(in, dataset);
	}

	/**
	 * Loads the data set stored at {@code uri} into {@code dataset}.
	 *
	 * <p>
	 * If {@code dataset} already carries the same {@code uri} it is considered
	 * loaded: nothing is read and {@code needsInitialization} is cleared.
	 * Otherwise the content is read, the {@code uri} is recorded on the data set
	 * and {@code needsInitialization} is set.
	 * </p>
	 *
	 * @param uri
	 *            location of the data set
	 * @param dataset
	 *            data set to fill
	 * @throws Exception
	 *             if the stream cannot be opened, read or closed
	 */
	public void loadDataset(String uri, Dataset dataset) throws Exception {
		if (dataset != null && dataset.uri != null && dataset.uri.equals(uri)) {
			// This data set is already loaded.
			dataset.needsInitialization = false;
			return;
		}
		InputStream in = ioUtils.getInputStream(uri);
		try {
			loadDataset(in, dataset);
			dataset.uri = uri;
			dataset.needsInitialization = true;
		} finally {
			// Close even when loading fails so the underlying resource is not leaked.
			if (in != null) {
				in.close();
			}
		}
	}

	/**
	 * Evaluates {@code scores} on {@code sample} with the current
	 * {@link #evaluationMetric}, passing {@code 1.0} as the third argument of
	 * {@code Sample.evaluate}.
	 *
	 * @param scores
	 *            scores to evaluate
	 * @param sample
	 *            sample to evaluate against
	 * @return the value returned by {@code Sample.evaluate}
	 * @throws Exception
	 *             if evaluation fails
	 */
	protected double getMeasurement(double[] scores, Sample sample) throws Exception {
		return sample.evaluate(scores, evaluationMetric, 1.0);
	}

	/**
	 * Factory hook for the data set type used by this application.
	 *
	 * @return a new, empty data set
	 */
	protected Dataset newDataset() {
		return new Dataset();
	}

	/**
	 * Hook invoked from {@link #init()} for each data set in use. The default
	 * implementation does nothing.
	 *
	 * @param dataset
	 *            data set to initialize
	 * @throws Exception
	 *             declared for the benefit of overriding implementations
	 */
	protected void initDataset(Dataset dataset) throws Exception {
		// Will be overridden by subclasses
	}

	/**
	 * Factory hook for the sample type used by this application.
	 *
	 * @param dataset
	 *            data set the sample is a view of
	 * @param trainSample
	 *            {@code true} for the training sample, {@code false} for the
	 *            validation sample; ignored by the default implementation
	 * @return a new sample over {@code dataset}
	 */
	protected Sample createSample(Dataset dataset, boolean trainSample) {
		return new Sample(dataset);
	}

	/**
	 * Returns the number of training instances used when sizing shared
	 * structures in {@link #run(Properties)}.
	 *
	 * @return {@code trainDataset.numInstances}
	 */
	protected int getMaxTrainInstances() {
		return trainDataset.numInstances;
	}

	/**
	 * Runs the complete training pipeline for the given configuration.
	 *
	 * <p>
	 * Failures are not propagated: an invalid configuration prints an error
	 * line to standard output, any exception has its stack trace printed, and
	 * in both cases {@code null} is returned.
	 * </p>
	 *
	 * @param configProperties
	 *            raw configuration properties
	 * @return the learned ensemble, or {@code null} if the configuration is
	 *         invalid or an error occurred
	 */
	public Ensemble run(Properties configProperties) {
		try {
			configHolder = new ConfigHolder(configProperties);
			loadConfig();
			if (!trainingConfig.validate(ioUtils)) {
				System.out.println("Error: " + trainingConfig.getErrorMessage());
				return null;
			}
			rnd = new Random(trainingConfig.randomSeed);

			System.out.println("Loading datasets...");
			if (trainDataset == null) {
				trainDataset = newDataset();
			}
			loadDataset(trainingConfig.trainFilename, trainDataset);
			int maxInstances = getMaxTrainInstances();

			if (trainingConfig.validFilename != null) {
				if (validDataset == null) {
					validDataset = newDataset();
				}
				loadDataset(trainingConfig.validFilename, validDataset);
				if (validDataset.numInstances > maxInstances) {
					maxInstances = validDataset.numInstances;
				}
			} else {
				validDataset = null;
			}
			System.out.println("Finished loading datasets.");

			Constants.init(maxInstances);

			Sample allTrainSample = createSample(trainDataset, true);
			trainSet = allTrainSample.getRandomSubSample(trainingConfig.trainFraction, rnd);

			if (validDataset != null) {
				validSet = createSample(validDataset, false);
				if (trainingConfig.validFraction < 1.0) {
					validSet = validSet.getRandomSubSample(trainingConfig.validFraction, rnd);
				}
			} else if (trainingConfig.validOutOfTrain) {
				validSet = allTrainSample.getOutOfSample(trainSet);
			} else {
				// No validation data for this run: drop any sample left over from a
				// previous run() so it is not silently reused by init()/learn().
				validSet = null;
			}

			init();

			evaluationMetric = getEvaluationMetric(trainingConfig.evaluationMetric);

			createLearner();

			Timer timer = new Timer();
			timer.start();
			Ensemble ensemble = topLearner.learn(trainSet, validSet);
			System.out.println("Time taken to build model: " + (timer.getElapsedMillis() / 1000.0) + " seconds.");
			return ensemble;

		} catch (Exception e) {
			// Existing contract: report the failure and signal it with a null result.
			e.printStackTrace();
		}

		return null;
	}

	/**
	 * Returns the random seed of the current training configuration. Requires
	 * that the configuration has been loaded (e.g. by {@link #run(Properties)}).
	 *
	 * @return {@code trainingConfig.randomSeed}
	 */
	public int getTrainingRandomSeed() {
		return trainingConfig.randomSeed;
	}

	/**
	 * Shuts down the shared {@link BlockingThreadPoolExecutor} via
	 * {@code shutdownNow()} if it exists and is not already shut down.
	 */
	public static void shutdown() {
		BlockingThreadPoolExecutor executor = BlockingThreadPoolExecutor.getInstance();
		if (executor != null && !executor.isShutdown()) {
			executor.shutdownNow();
		}
	}

	/**
	 * @return the evaluation metric resolved by the last run, or {@code null}
	 *         if none has been resolved yet
	 */
	public EvaluationMetric getEvaluationMetric() {
		return evaluationMetric;
	}

	/**
	 * Returns the validation measurement reported by the top-level learner.
	 * Requires that a learner has been created (e.g. by {@link #run(Properties)}).
	 *
	 * @return the value of {@code topLearner.getValidationMeasurement()}
	 * @throws Exception
	 *             if the learner fails to provide the measurement
	 */
	public double getValidMeasurement() throws Exception {
		return topLearner.getValidationMeasurement();
	}

	/**
	 * @return the training sample of the last run, or {@code null}
	 */
	public Sample getTrainSample() {
		return trainSet;
	}

	/**
	 * @return the validation sample of the last run, or {@code null}
	 */
	public Sample getValidSample() {
		return validSet;
	}

	/**
	 * @return the configuration holder of the last run, or {@code null}
	 */
	public ConfigHolder getConfigHolder() {
		return configHolder;
	}

	/**
	 * @return the I/O helper used to open input streams
	 */
	public IOUtils getIOUtils() {
		return ioUtils;
	}

	/**
	 * @return the configured progress listener, or {@code null} if none is set
	 */
	public LearningProgressListener getProgressListener() {
		return progressListener;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy