All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.encog.ensemble.Ensemble Maven / Gradle / Ivy

The newest version!
/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 
 * Copyright 2008-2017 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *   
 * For more information on Heaton Research copyrights, licenses 
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.ensemble;

import java.util.ArrayList;

import org.encog.ensemble.data.EnsembleDataSet;
import org.encog.ensemble.data.factories.EnsembleDataSetFactory;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ensemble.aggregator.WeightedAveraging.WeightMismatchException;

public abstract class Ensemble {

	private final int DEFAULT_MAX_ITERATIONS = 2000;
	protected EnsembleDataSetFactory dataSetFactory;
	protected EnsembleTrainFactory trainFactory;
	protected EnsembleAggregator aggregator;
	protected ArrayList members;
	protected EnsembleMLMethodFactory mlFactory;
	protected MLDataSet aggregatorDataSet;

	public class NotPossibleInThisMethod extends Exception {

		/**
		 * This means the current feature is not applicable in the specified method
		 */
		private static final long serialVersionUID = 5118253806179408868L;

	}

	public class TrainingAborted extends Exception {

		public TrainingAborted(String string) {
			super(string);
		}

		/**
		 * This means we tried training this ensemble and failed too many times in a row
		 */
		private static final long serialVersionUID = -5074472788684621859L;
		
	}

	/**
	 * Initialise ensemble components
	 */
	abstract public void initMembers();

	public EnsembleML generateNewMember() {
		GenericEnsembleML newML = new GenericEnsembleML(mlFactory.createML(this.dataSetFactory.getInputCount(), this.dataSetFactory.getOutputCount()),mlFactory.getLabel());
		newML.setTrainingSet(dataSetFactory.getNewDataSet());
		newML.setTraining(trainFactory.getTraining(newML.getMl(), newML.getTrainingSet()));
		return newML;
	}
	
	public void addNewMember() {
		members.add(generateNewMember());		
	}
	
	public void initMembersBySplits(int splits)
	{
		if ((this.dataSetFactory != null) &&
			(splits > 0) &&
			(this.dataSetFactory.hasSource()))
		{
			for (int i = 0; i < splits; i++)
			{
				GenericEnsembleML newML = new GenericEnsembleML(mlFactory.createML(this.dataSetFactory.getInputCount(), this.dataSetFactory.getOutputCount()),mlFactory.getLabel());
				newML.setTrainingSet(dataSetFactory.getNewDataSet());
				newML.setTraining(trainFactory.getTraining(newML.getMl(), newML.getTrainingSet()));
				members.add(newML);
			}
			if(aggregator.needsTraining())
				aggregatorDataSet = dataSetFactory.getNewDataSet();
		}
	}

	/**
	 * Set the training method to use for this ensemble
	 * @param newTrainFactory The training factory.
	 */
	public void setTrainingMethod(EnsembleTrainFactory newTrainFactory) {
		this.trainFactory = newTrainFactory;
		initMembers();
	}

	/**
	 * Set which training data to base the training on
	 * @param data The training data.
	 */
	public void setTrainingData(MLDataSet data) {
		dataSetFactory.setInputData(data);
		initMembers();
	}

	/**
	 * Set which dataSetFactory to use to create the correct tranining sets
	 * @param dataSetFactory The data set factory.
	 */
	public void setTrainingDataFactory(EnsembleDataSetFactory dataSetFactory) {
		this.dataSetFactory = dataSetFactory;
		initMembers();
	}

	public void trainMember(int index, double targetError, double selectionError, int maxIterations, EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
		EnsembleML current = members.get(index);
		trainMember(current,targetError,selectionError,maxIterations,DEFAULT_MAX_ITERATIONS,selectionSet,verbose);
	}

	public void trainMember(EnsembleML current, double targetError, double selectionError, int maxIterations, int maxLoops, EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
		int attempt = 0;
		do {
			long startTime = System.nanoTime();
			mlFactory.reInit(current.getMl());
			current.train(targetError, maxIterations, verbose);
			long endTime = System.nanoTime();
			if (verbose)
			{
				System.out.println("training took " + ((double)(endTime - startTime) / 1000000000.0));
				System.out.println("test MSE: " + current.getError(selectionSet) + " on " + selectionSet.size() + " data points");
			};
			attempt++;
			if (attempt > maxLoops) {
				throw new TrainingAborted("Too many attempts at training ensemble member");
			}
		} while (current.getError(selectionSet) > selectionError);
	}
	
	public void trainMember(EnsembleML current, double targetError, double selectionError, EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
		trainMember(current, targetError, selectionError, DEFAULT_MAX_ITERATIONS, DEFAULT_MAX_ITERATIONS, selectionSet, verbose);
	}

	public void trainMember(int index, double targetError, double selectionError, EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
		trainMember(index, targetError, selectionError, DEFAULT_MAX_ITERATIONS, selectionSet, verbose);
	}
	
	public void retrainAggregator() {
		EnsembleDataSet aggTrainingSet = new EnsembleDataSet(members.size() * aggregatorDataSet.getIdealSize(),aggregatorDataSet.getIdealSize());
		for (MLDataPair trainingInput:aggregatorDataSet) {
			BasicMLData trainingInstance = new BasicMLData(members.size() * aggregatorDataSet.getIdealSize());
			int index = 0;
			for(EnsembleML member:members){
				for(double val:member.compute(trainingInput.getInput()).getData()) {
					trainingInstance.add(index++, val);
				}
			}
			aggTrainingSet.add(trainingInstance,trainingInput.getIdeal());
		}
		aggregator.setTrainingSet(aggTrainingSet);
		aggregator.train();
	}
	
	/**
	 * Train the ensemble to a target accuracy
	 * @param targetError The target error.
	 * @param selectionError The selection error.
	 * @param maxIterations Max iterations.
	 * @param maxLoops Max loops.
	 * @param selectionSet Selection set.
	 * @param verbose Verbose.
     * @throws TrainingAborted Training was aborted.
     */
	public void train(double targetError, double selectionError, int maxIterations, int maxLoops, EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
		
		for (EnsembleML current : members)
		{
			trainMember(current, targetError, selectionError, maxIterations, maxLoops, selectionSet, verbose);
		}
		if(aggregator.needsTraining()) {
			retrainAggregator();
		}
	}
	public void train(double targetError, double selectionError, EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
		train(targetError, selectionError, DEFAULT_MAX_ITERATIONS, DEFAULT_MAX_ITERATIONS, selectionSet, verbose);
	}

	/**
	 * Train the ensemble to a target accuracy
	 * @param targetError The target error.
	 * @param selectionError The selection error.
	 * @param testset The test set.
	 * @throws TrainingAborted Training aborted.
	 */
	public void train(double targetError, double selectionError, EnsembleDataSet testset) throws TrainingAborted {
		train(targetError, selectionError, testset, false);
	}

	public void train(double targetError, double selectionError, int maxIterations, EnsembleDataSet testset) throws TrainingAborted {
		train(targetError, selectionError, maxIterations, DEFAULT_MAX_ITERATIONS, testset, false);
	}

	/**
	 * Extract a specific training set from the Ensemble
	 * @param setNumber The set number.
	 * @return The training set.
	 */
	public MLDataSet getTrainingSet(int setNumber) {
		return members.get(setNumber).getTrainingSet();
	}

	/**
	 * Extract a specific MLMethod
	 * @param memberNumber The member number.
	 * @return The MLMethod.
	 */
	public EnsembleML getMember(int memberNumber) {
		return members.get(memberNumber);
	}

	/**
	 * Add a member to the ensemble
	 * @param newMember The new member.
	 * @throws NotPossibleInThisMethod Not possible in this method.
	 */
	public void addMember(EnsembleML newMember) throws NotPossibleInThisMethod {
		members.add(newMember);
	}

	/**
	 * Compute the output for a specific input
	 * @param input The input.
	 * @return The data.
	 * @throws WeightMismatchException Weight mismatch exception.
	 */
	public MLData compute(MLData input) throws WeightMismatchException {
		ArrayList outputs = new ArrayList();
		for(EnsembleML member: members)
		{
			MLData computed = member.compute(input);
			outputs.add(computed);
		}
		return aggregator.evaluate(outputs);
	}

	/**
	 * @return Returns the ensemble aggregation method
	 */
	public EnsembleAggregator getAggregator() {
		return aggregator;
	}

	/**
	 * Sets the ensemble aggregation method
	 * @param aggregator The aggregator.
	 */
	public void setAggregator(EnsembleAggregator aggregator) {
		this.aggregator = aggregator;
	}

	/**
	 * Return what type of problem this Ensemble is solving
	 * @return The problem type.
	 */
	abstract public EnsembleTypes.ProblemType getProblemType();

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy