net.sourceforge.cilib.gd.GradientDescentBackpropagationTraining

/**           __  __
 *    _____ _/ /_/ /_    Computational Intelligence Library (CIlib)
 *   / ___/ / / / __ \   (c) CIRG @ UP
 *  / /__/ / / / /_/ /   http://cilib.net
 *  \___/_/_/_/_.___/
 */
package net.sourceforge.cilib.gd;

import java.util.ArrayList;
import java.util.List;
import net.sourceforge.cilib.algorithm.AbstractAlgorithm;
import net.sourceforge.cilib.algorithm.SingularAlgorithm;
import net.sourceforge.cilib.controlparameter.ConstantControlParameter;
import net.sourceforge.cilib.controlparameter.ControlParameter;
import net.sourceforge.cilib.io.StandardPatternDataTable;
import net.sourceforge.cilib.io.exception.CIlibIOException;
import net.sourceforge.cilib.io.pattern.StandardPattern;
import net.sourceforge.cilib.nn.NeuralNetwork;
import net.sourceforge.cilib.nn.architecture.visitors.BackPropagationVisitor;
import net.sourceforge.cilib.nn.architecture.visitors.OutputErrorVisitor;
import net.sourceforge.cilib.problem.nn.NNTrainingProblem;
import net.sourceforge.cilib.problem.solution.MinimisationFitness;
import net.sourceforge.cilib.problem.solution.OptimisationSolution;
import net.sourceforge.cilib.type.types.Numeric;
import net.sourceforge.cilib.type.types.container.Vector;

/**
 * Class implementing the gradient descent backpropagation training algorithm. The
 * error measure used is the mean squared error (MSE), and both learning rate and
 * momentum parameters are supported.
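 * <p>
 * A minimal usage sketch (illustrative only; it assumes {@code problem} is a fully
 * configured {@link net.sourceforge.cilib.problem.nn.NNTrainingProblem} built
 * elsewhere, and trains for a fixed number of epochs):
 * <pre>{@code
 * GradientDescentBackpropagationTraining training = new GradientDescentBackpropagationTraining();
 * training.setLearningRate(ConstantControlParameter.of(0.05));
 * training.setMomentum(ConstantControlParameter.of(0.9));
 * training.setOptimisationProblem(problem);
 * training.performInitialisation();
 * while (training.getIterations() < 1000) {
 *     training.performIteration();
 * }
 * OptimisationSolution best = training.getBestSolution();
 * }</pre>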
 */
public class GradientDescentBackpropagationTraining extends AbstractAlgorithm implements SingularAlgorithm {
    private static final long serialVersionUID = 7984749431187521004L;

    private ControlParameter learningRate;
    private ControlParameter momentum;
    private double errorTraining;
    private BackPropagationVisitor bpVisitor;
    private double[][] previousWeightChanges;

    /**
     * Default constructor.
     */
    public GradientDescentBackpropagationTraining() {
        learningRate = ConstantControlParameter.of(0.1);
        momentum = ConstantControlParameter.of(0.9);
        bpVisitor = new BackPropagationVisitor();
    }

    /**
     * Copy constructor.
     * @param copy the instance to copy.
     */
    public GradientDescentBackpropagationTraining(GradientDescentBackpropagationTraining copy) {
        learningRate = copy.learningRate.getClone();
        momentum = copy.momentum.getClone();
        bpVisitor = new BackPropagationVisitor(); // the visitor is stateful; a fresh instance avoids sharing state between copies
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void algorithmInitialisation() {
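        // Delegates to the problem, which loads the data set and partitions it
        // into training and generalisation sets.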
        NNTrainingProblem problem = (NNTrainingProblem) getOptimisationProblem();
        problem.initialise();
    }

    /**
     * {@inheritDoc}
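     * One iteration performs a single pass (epoch) through the shuffled training
     * set, backpropagating after every pattern (online/stochastic updates).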
     */
    @Override
    public void algorithmIteration() {
        try {
            NNTrainingProblem problem = (NNTrainingProblem) getOptimisationProblem();
            NeuralNetwork neuralNetwork = problem.getNeuralNetwork();
            StandardPatternDataTable trainingSet = problem.getTrainingSet();
            problem.getShuffler().operate(trainingSet);
            bpVisitor.setLearningRate(this.learningRate.getParameter());
            bpVisitor.setMomentum(this.momentum.getParameter());

            errorTraining = 0.0;
            OutputErrorVisitor visitor = new OutputErrorVisitor();
            Vector error = null;
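            // Accumulate the squared output errors over the epoch; the total is
            // averaged below over (patterns x output dimensions) to give the MSE.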
            for (StandardPattern pattern : trainingSet) {
                neuralNetwork.evaluatePattern(pattern);
                visitor.setInput(pattern);
                neuralNetwork.getArchitecture().accept(visitor);
                error = visitor.getOutput();
                for (Numeric real : error) {
                    errorTraining += real.doubleValue() * real.doubleValue();
                }

                // backpropagate
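                // The visitor applies the standard delta rule with momentum,
                //   dw(t) = -learningRate * dE/dw + momentum * dw(t-1),
                // with previousWeightChanges carrying dw(t-1) between patterns.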
                bpVisitor.setPreviousPattern(pattern);
                bpVisitor.setPreviousWeightUpdates(previousWeightChanges);
                neuralNetwork.getArchitecture().accept(bpVisitor);
                previousWeightChanges = bpVisitor.getPreviousWeightUpdates();

            }
            // Guard against an empty training set, in which case error is still null.
            if (error != null) {
                errorTraining /= trainingSet.getNumRows() * error.size();
            }

        } catch (CIlibIOException ex) {
            // A failed shuffle leaves the epoch in an inconsistent state; fail fast
            // instead of logging and continuing silently.
            throw new RuntimeException(ex);
        }
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public GradientDescentBackpropagationTraining getClone() {
        return new GradientDescentBackpropagationTraining(this);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public OptimisationSolution getBestSolution() {
        NNTrainingProblem problem = (NNTrainingProblem) getOptimisationProblem();
        NeuralNetwork neuralNetwork = problem.getNeuralNetwork();
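        // The solution is the network's current weight vector, with the last
        // epoch's training MSE wrapped as a minimisation fitness.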
        return new OptimisationSolution(neuralNetwork.getWeights(), new MinimisationFitness(errorTraining));
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public List<OptimisationSolution> getSolutions() {
        List<OptimisationSolution> list = new ArrayList<OptimisationSolution>();
        NNTrainingProblem problem = (NNTrainingProblem) getOptimisationProblem();
        NeuralNetwork neuralNetwork = problem.getNeuralNetwork();
        list.add(new OptimisationSolution(neuralNetwork.getWeights(), new MinimisationFitness(errorTraining)));
        return list;
    }

    /**
     * Gets the learning rate control parameter.
     * @return the learning rate control parameter.
     */
    public ControlParameter getLearningRate() {
        return learningRate;
    }

    /**
     * Sets the learning rate control parameter.
     * @param learningRate the new learning rate control parameter.
     */
    public void setLearningRate(ControlParameter learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Gets the momentum control parameter.
     * @return the momentum control parameter.
     */
    public ControlParameter getMomentum() {
        return momentum;
    }

    /**
     * Sets the momentum control parameter.
     * @param momentum the new momentum control parameter.
     */
    public void setMomentum(ControlParameter momentum) {
        this.momentum = momentum;
    }
}