/**           __  __
 *    _____ _/ /_/ /_    Computational Intelligence Library (CIlib)
 *   / ___/ / / / __ \   (c) CIRG @ UP
 *  / /__/ / / / /_/ /   http://cilib.net
 *  \___/_/_/_/_.___/
 */
package net.sourceforge.cilib.problem.nn;

import com.google.common.annotations.VisibleForTesting;
import net.sourceforge.cilib.algorithm.AbstractAlgorithm;
import net.sourceforge.cilib.io.DataTable;
import net.sourceforge.cilib.io.DataTableBuilder;
import net.sourceforge.cilib.io.DelimitedTextFileReader;
import net.sourceforge.cilib.io.StandardPatternDataTable;
import net.sourceforge.cilib.io.exception.CIlibIOException;
import net.sourceforge.cilib.io.pattern.StandardPattern;
import net.sourceforge.cilib.io.transform.ShuffleOperator;
import net.sourceforge.cilib.io.transform.TypeConversionOperator;
import net.sourceforge.cilib.nn.architecture.visitors.OutputErrorVisitor;
import net.sourceforge.cilib.nn.domain.*;
import net.sourceforge.cilib.problem.AbstractProblem;
import net.sourceforge.cilib.problem.solution.Fitness;
import net.sourceforge.cilib.type.DomainRegistry;
import net.sourceforge.cilib.type.types.Numeric;
import net.sourceforge.cilib.type.types.Type;
import net.sourceforge.cilib.type.types.container.Vector;

/**
 * Class represents a {@link NNTrainingProblem} where the goal is to optimise
 * the set of weights of a neural network so that it best fits a given static
 * dataset (regression, classification, and so on).
 */
public class NNDataTrainingProblem extends NNTrainingProblem {
    private static final long serialVersionUID = -8765101028460476990L;

    private DataTableBuilder dataTableBuilder;
    private DomainInitialisationStrategy domainInitialisationStrategy;
    private SolutionConversionStrategy solutionConversionStrategy;
    private int previousShuffleIteration;
    private boolean initialised;

    /**
     * Default constructor.
     */
    public NNDataTrainingProblem() {
        super();
        dataTableBuilder = new DataTableBuilder(new DelimitedTextFileReader());
        domainInitialisationStrategy = new WeightBasedDomainInitialisationStrategy();
        solutionConversionStrategy = new WeightSolutionConversionStrategy();
        previousShuffleIteration = -1;
        initialised = false;
    }
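    /*
     * A minimal usage sketch (illustrative, not part of the original class):
     * the problem is pointed at a data source and initialised before fitness
     * evaluations take place. The percentage setters are assumed to be
     * inherited from NNTrainingProblem, and the file path is hypothetical.
     *
     *     NNDataTrainingProblem problem = new NNDataTrainingProblem();
     *     problem.setSourceURL("data/iris.csv");        // hypothetical path
     *     problem.setTrainingSetPercentage(0.7);        // assumed inherited setter
     *     problem.setValidationSetPercentage(0.1);      // assumed inherited setter
     *     problem.setGeneralisationSetPercentage(0.2);  // assumed inherited setter
     *     problem.initialise();
     */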

    /**
     * Initialises the problem by reading in the data and constructing the training,
     * validation and generalisation sets. Also initialises (constructs) the neural network.
     */
    @Override
    public void initialise() {
        if (initialised) {
            return;
        }
        try {
            dataTableBuilder.addDataOperator(new TypeConversionOperator());
            dataTableBuilder.addDataOperator(patternConversionOperator);
            dataTableBuilder.buildDataTable();
            DataTable dataTable = dataTableBuilder.getDataTable();

            ShuffleOperator initialShuffler = new ShuffleOperator();
            initialShuffler.operate(dataTable);

            int trainingSize = (int) (dataTable.size() * trainingSetPercentage);
            int validationSize = (int) (dataTable.size() * validationSetPercentage);
            int generalisationSize = dataTable.size() - trainingSize - validationSize;

            trainingSet = new StandardPatternDataTable();
            validationSet = new StandardPatternDataTable();
            generalisationSet = new StandardPatternDataTable();

            for (int i = 0; i < trainingSize; i++) {
                trainingSet.addRow((StandardPattern) dataTable.getRow(i));
            }

            for (int i = trainingSize; i < validationSize + trainingSize; i++) {
                validationSet.addRow((StandardPattern) dataTable.getRow(i));
            }

            for (int i = validationSize + trainingSize; i < generalisationSize + validationSize + trainingSize; i++) {
                generalisationSet.addRow((StandardPattern) dataTable.getRow(i));
            }

            neuralNetwork.initialise();
        } catch (CIlibIOException exception) {
            exception.printStackTrace();
        }
        initialised = true;
    }
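    /*
     * Illustrative split arithmetic (a sketch; the percentages are assumed
     * example values, not defaults): with 1000 patterns,
     * trainingSetPercentage = 0.7 and validationSetPercentage = 0.1,
     *
     *     trainingSize       = (int) (1000 * 0.7) = 700
     *     validationSize     = (int) (1000 * 0.1) = 100
     *     generalisationSize = 1000 - 700 - 100   = 200
     *
     * so every pattern of the shuffled table lands in exactly one of the
     * three sets.
     */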

    /**
     * {@inheritDoc}
     */
    @Override
    public AbstractProblem getClone() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Calculates the fitness of the given solution by applying the weights it
     * represents to the neural network and evaluating the training set in order
     * to calculate the MSE (which is minimised).
     *
     * @param solution the weights representing a solution.
     * @return a new MinimisationFitness wrapping the MSE training error.
     */
    @Override
    protected Fitness calculateFitness(Type solution) {
        if (trainingSet == null) {
            this.initialise();
        }

        // Re-shuffle the training set only once per algorithm iteration.
        int currentIteration = AbstractAlgorithm.get().getIterations();
        if (currentIteration != previousShuffleIteration) {
            try {
                shuffler.operate(trainingSet);
                previousShuffleIteration = currentIteration;
            } catch (CIlibIOException exception) {
                exception.printStackTrace();
            }
        }

        neuralNetwork.getArchitecture().accept(solutionConversionStrategy.interpretSolution(solution));

        // Accumulate the squared output error over every training pattern.
        double errorTraining = 0.0;
        OutputErrorVisitor visitor = new OutputErrorVisitor();
        Vector error = null;
        for (StandardPattern pattern : trainingSet) {
            // The forward pass feeds the pattern through the network so that
            // the error visitor can read the resulting outputs.
            neuralNetwork.evaluatePattern(pattern);
            visitor.setInput(pattern);
            neuralNetwork.getArchitecture().accept(visitor);
            error = visitor.getOutput();
            for (Numeric real : error) {
                errorTraining += real.doubleValue() * real.doubleValue();
            }
        }
        // Average over all patterns and output components to obtain the MSE.
        errorTraining /= trainingSet.getNumRows() * error.size();

        return objective.evaluate(errorTraining);
    }
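    /*
     * Sketch of the error computed above: with P training patterns, K network
     * outputs and e(p, k) the component-wise error reported by the
     * OutputErrorVisitor for pattern p,
     *
     *     MSE = (1 / (P * K)) * sum over p, k of e(p, k)^2
     *
     * which is the value handed to the objective for evaluation.
     */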

    /**
     * {@inheritDoc}
     */
    @Override
    public DomainRegistry getDomain() {
        if (!initialised) {
            this.initialise();
        }
        return initialiseDomain();
    }

    @VisibleForTesting
    protected DomainRegistry initialiseDomain() {
        solutionConversionStrategy.initialise(neuralNetwork);
        return domainInitialisationStrategy.initialiseDomain(neuralNetwork);
    }

    /**
     * Gets the datatable builder.
     *
     * @return the datatable builder.
     */
    public DataTableBuilder getDataTableBuilder() {
        return dataTableBuilder;
    }

    /**
     * Sets the datatable builder.
     *
     * @param dataTableBuilder the new datatable builder.
     */
    public void setDataTableBuilder(DataTableBuilder dataTableBuilder) {
        this.dataTableBuilder = dataTableBuilder;
    }

    /**
     * Gets the source URL of the datatable builder.
     *
     * @return the source URL of the datatable builder.
     */
    public String getSourceURL() {
        return dataTableBuilder.getSourceURL();
    }

    /**
     * Sets the source URL of the datatable builder.
     *
     * @param sourceURL the new source URL of the datatable builder.
     */
    public void setSourceURL(String sourceURL) {
        dataTableBuilder.setSourceURL(sourceURL);
    }

    /**
     * Gets the domain initialisation strategy.
     *
     * @return the domain initialisation strategy.
     */
    public DomainInitialisationStrategy getDomainInitialisationStrategy() {
        return domainInitialisationStrategy;
    }

    /**
     * Sets the domain initialisation strategy.
     *
     * @param domainInitialisationStrategy the new domain initialisation strategy.
     */
    public void setDomainInitialisationStrategy(DomainInitialisationStrategy domainInitialisationStrategy) {
        this.domainInitialisationStrategy = domainInitialisationStrategy;
    }

    /**
     * Gets the solution conversion strategy.
     *
     * @return the solution conversion strategy.
     */
    public SolutionConversionStrategy getSolutionConversionStrategy() {
        return solutionConversionStrategy;
    }

    /**
     * Sets the solution conversion strategy.
     *
     * @param solutionConversionStrategy the new solution conversion strategy.
     */
    public void setSolutionConversionStrategy(SolutionConversionStrategy solutionConversionStrategy) {
        this.solutionConversionStrategy = solutionConversionStrategy;
    }
}