All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.encog.ensemble.GenericEnsembleML Maven / Gradle / Ivy

The newest version!
/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 
 * Copyright 2008-2017 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *   
 * For more information on Heaton Research copyrights, licenses 
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.ensemble;

import org.encog.ensemble.data.EnsembleDataSet;
import org.encog.ml.MLMethod;
import org.encog.ml.data.MLData;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.BasicNetwork;
import org.encog.util.EngineArray;

public class GenericEnsembleML implements EnsembleML {

	private EnsembleDataSet trainingSet;
	private BasicNetwork ml;
	private MLTrain trainer;
	private String label;
	private final int DEFAULT_MAX_ITERATIONS = 2000;

	public GenericEnsembleML(MLMethod fromML, String description) {
		setMl(fromML);
		this.label = description;
	}

	@Override
	public void setTrainingSet(EnsembleDataSet dataSet) {
		this.trainingSet = dataSet;
	}

	@Override
	public EnsembleDataSet getTrainingSet() {
		return trainingSet;
	}

	@Override
	public void train(double targetError, int maxIterations, boolean verbose) {
		double error = 0;
		double previouserror = 1;
		double errordelta = 1;
		int iteration = 0;
		do {
			trainer.iteration();
			iteration++;
			if (iteration > 1) {
				previouserror = error;
			}
			error = trainer.getError();
			if (iteration > 1) {
				errordelta = previouserror - error;
			}
			if (verbose) System.out.println(iteration + " " + error);
		} while ((error > targetError) &&
				 trainer.canContinue() &&
				 //errordelta / previouserror < 2 &&
				 iteration < maxIterations);
		trainer.finishTraining();
	}

	@Override
	public void train(double targetError) {
		train(targetError, false);
	}

	@Override
	public void train(double targetError, int maxIterations) {
		train(targetError, maxIterations, false);
	}

	@Override
	public void train(double targetError, boolean verbose) {
		train(targetError, DEFAULT_MAX_ITERATIONS, verbose);
		
	}

	@Override
	public void setMl(MLMethod newMl) {
		ml = (BasicNetwork) newMl;
	}

	@Override
	public MLMethod getMl() {
		return ml;
	}

	@Override
	public int classify(MLData input) {
		return ml.classify(input);
	}

	@Override
	public MLData compute(MLData input) {
		return ml.compute(input);
	}

	@Override
	public int getInputCount() {
		return ml.getInputCount();
	}

	@Override
	public int getOutputCount() {
		return ml.getOutputCount();
	}

	public int winner(MLData output) {
		return EngineArray.maxIndex(output.getData());
	}

	@Override
	public void setTraining(MLTrain train) {
		trainer = train;
	}

	@Override
	public MLTrain getTraining() {
		return trainer;
	}

	@Override
	public void trainStep() {
		trainer.iteration();
	}

	@Override
	public String getLabel() {
		return label;
	}

	@Override
	public double getError(EnsembleDataSet testset) {
		return ml.calculateError(testset);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy