All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.enterprisemath.math.nn.BackpropagationUtils Maven / Gradle / Ivy

The newest version!
package com.enterprisemath.math.nn;

import com.enterprisemath.utils.DomainUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Utilities for back propagation.
 * <p>
 * This class is for internal use only! Modifying or removing it is not considered a breaking change.
 *
 * @author radek.hecl
 */
class BackpropagationUtils {

    /**
     * Learning rate applied to every weight and bias step.
     */
    private static final double LEARNING_RATE = 0.2;

    /**
     * Fraction of the previous step that is added to the current step (momentum term).
     */
    private static final double MOMENTUM = 0.1;

    /**
     * Number of iterations (epochs) over which the error is averaged for the convergence checks.
     */
    private static final int WINDOW_SIZE = 1000;

    /**
     * Number of epochs within one window whose error must exceed the previous window's average
     * before the weights are perturbed.
     */
    private static final int REBALANCE_THRESHOLD = 700;

    /**
     * Prevents construction.
     */
    private BackpropagationUtils() {
    }

    /**
     * Trains a feed-forward network with a single sigmoid hidden layer using on-line back propagation
     * with momentum.
     * <p>
     * Every iteration (epoch) shuffles the records and updates the weights after each record. Training
     * stops when any of these happens:
     * <ul>
     * <li>{@code maxIterations} epochs have run;</li>
     * <li>the summed squared error of an epoch is no longer greater than {@code allowedError};</li>
     * <li>the average epoch error over a window of 1000 epochs improves by less than
     * {@code minError1000Delta} compared with the previous window.</li>
     * </ul>
     * If more than 700 epochs of a window have an error above the previous window's average (compared
     * at 3-decimal precision), the training is considered stuck. The weights and biases are then
     * perturbed with Gaussian noise and the momentum is cleared.
     * <p>
     * A snapshot of the network is taken whenever the epoch error reaches a new minimum, before each
     * perturbation, and at the end. Each snapshot is evaluated on {@code records}, and the one with the
     * smallest summed squared error is returned.
     *
     * @param inputs input nodes identification; must not clash with the internal node names
     *        {@code input_bias}, {@code hidden_bias}, {@code hidden0}, {@code hidden1}, ...
     * @param numHiddenNodes number of hidden nodes
     * @param outputs output nodes identification; must not clash with the internal node names
     *        {@code input_bias}, {@code hidden_bias}, {@code hidden0}, {@code hidden1}, ...
     * @param records training records; each must contain a value for every input and output key
     * @param maxIterations maximum number of iterations (epochs)
     * @param minError1000Delta minimal required decrease of the average epoch error between two
     *        consecutive 1000-epoch windows
     * @param allowedError allowed summed squared error of one epoch
     * @return the snapshot with the smallest summed squared error on {@code records}, or {@code null}
     *         if no snapshot scores below {@link Double#MAX_VALUE} (for example, all errors are NaN)
     */
    public static Network train(List<String> inputs, int numHiddenNodes, List<String> outputs,
            List<SupervisedTrainingRecord> records, int maxIterations, double minError1000Delta,
            double allowedError) {
        List<SupervisedTrainingRecord> recs = DomainUtils.softCopyList(records);
        int numInputs = inputs.size();
        int numOutputs = outputs.size();
        int numNodes = numInputs + numHiddenNodes + numOutputs;
        // node layout in all arrays: [inputs | hidden | outputs]
        double[][] weights = new double[numNodes][numNodes];
        double[][] deltaWeights = new double[numNodes][numNodes];
        double[] biases = new double[numNodes];
        double[] deltaBiases = new double[numNodes];
        double[] values = new double[numNodes];
        double[] gradients = new double[numNodes];
        double[] expectedValues = new double[numOutputs];

        // The no-arg constructor seeds from a per-instance uniquifier and System.nanoTime(), so trainings
        // started within the same millisecond do not share their initial weights and shuffles.
        Random random = new Random();
        for (int i = 0; i < numNodes; ++i) {
            for (int j = 0; j < numNodes; ++j) {
                weights[i][j] = random.nextDouble();
            }
        }

        BestNetworkTracker bestNetwork = new BestNetworkTracker(records);
        double minError = Double.POSITIVE_INFINITY;
        double error = Double.POSITIVE_INFINITY;
        double windowError = 0;
        double targetWindowError = Double.POSITIVE_INFINITY;
        double divergenceLimit = Double.POSITIVE_INFINITY;
        int numAboveDivergenceLimit = 0;
        for (int iteration = 0; iteration < maxIterations && error > allowedError; ++iteration) {
            Collections.shuffle(recs, random);
            error = 0;
            for (SupervisedTrainingRecord rec : recs) {
                for (int o = 0; o < numOutputs; ++o) {
                    expectedValues[o] = rec.getOutputs().get(outputs.get(o));
                }
                for (int i = 0; i < numInputs; ++i) {
                    values[i] = rec.getInputs().get(inputs.get(i));
                }
                forwardPass(values, weights, biases, numInputs, numHiddenNodes);
                error += backwardPass(values, expectedValues, gradients, weights, deltaWeights, biases,
                        deltaBiases, numInputs, numHiddenNodes);
            }

            // keep a snapshot whenever the epoch error reaches a new minimum
            if (error < minError) {
                bestNetwork.offer(createNetwork(inputs, numHiddenNodes, outputs, weights, biases));
                minError = error;
            }

            windowError += error;
            if (truncateTo3Decimals(error) > divergenceLimit) {
                ++numAboveDivergenceLimit;
            }
            // (iteration + 1) makes every window, including the first one, exactly WINDOW_SIZE epochs long
            if ((iteration + 1) % WINDOW_SIZE == 0) {
                double averageError = windowError / WINDOW_SIZE;
                if (numAboveDivergenceLimit > REBALANCE_THRESHOLD) {
                    // training is stuck: keep the current state as a candidate and kick the weights
                    bestNetwork.offer(createNetwork(inputs, numHiddenNodes, outputs, weights, biases));
                    for (int i = 0; i < numNodes; ++i) {
                        for (int j = 0; j < numNodes; ++j) {
                            weights[i][j] += random.nextGaussian();
                        }
                        biases[i] += random.nextGaussian();
                    }
                    deltaWeights = new double[numNodes][numNodes];
                    deltaBiases = new double[numNodes];
                    divergenceLimit = Double.POSITIVE_INFINITY;
                    targetWindowError = Double.POSITIVE_INFINITY;
                } else {
                    divergenceLimit = truncateTo3Decimals(averageError);
                    if (targetWindowError < averageError) {
                        // the improvement over the previous window is smaller than minError1000Delta
                        break;
                    }
                    targetWindowError = averageError - minError1000Delta;
                }
                // both branches start the next window with fresh accumulators
                windowError = 0;
                numAboveDivergenceLimit = 0;
            }
        }
        bestNetwork.offer(createNetwork(inputs, numHiddenNodes, outputs, weights, biases));

        return bestNetwork.getBest();
    }

    /**
     * Propagates the input values through the hidden and output layers.
     * The inputs must already be stored in {@code values[0 .. numInputs)}.
     * The sigmoid activation of every hidden and output node is written back into {@code values}.
     *
     * @param values node values laid out as [inputs | hidden | outputs]
     * @param weights weights, {@code weights[from][to]}
     * @param biases biases; each is subtracted from its node's weighted input
     * @param numInputs number of input nodes
     * @param numHiddenNodes number of hidden nodes
     */
    private static void forwardPass(double[] values, double[][] weights, double[] biases,
            int numInputs, int numHiddenNodes) {
        int firstOutput = numInputs + numHiddenNodes;
        for (int h = numInputs; h < firstOutput; ++h) {
            values[h] = sigmoidActivation(values, weights, biases, h, 0, numInputs);
        }
        for (int o = firstOutput; o < values.length; ++o) {
            values[o] = sigmoidActivation(values, weights, biases, o, numInputs, firstOutput);
        }
    }

    /**
     * Computes the sigmoid of the weighted input of one node.
     * The weighted input is the sum over the source nodes {@code [fromInclusive, toExclusive)}, minus the
     * node's bias.
     *
     * @param values node values
     * @param weights weights, {@code weights[from][to]}
     * @param biases biases
     * @param node index of the node being activated
     * @param fromInclusive first source node index
     * @param toExclusive end of the source node range
     * @return activation in the open interval (0, 1) for finite input
     */
    private static double sigmoidActivation(double[] values, double[][] weights, double[] biases,
            int node, int fromInclusive, int toExclusive) {
        double weightedInput = -biases[node];
        for (int k = fromInclusive; k < toExclusive; ++k) {
            weightedInput += weights[k][node] * values[k];
        }
        return 1.0 / (1.0 + Math.exp(-weightedInput));
    }

    /**
     * Back-propagates the error of one record and applies the momentum weight and bias updates.
     * All gradients are computed from the forward-pass weights before any weight is changed, and every
     * weight and bias receives exactly one step per record.
     *
     * @param values node values produced by {@link #forwardPass}
     * @param expectedValues expected output values, in output order
     * @param gradients scratch array of length {@code values.length} for the error gradients
     * @param weights weights, updated in place
     * @param deltaWeights previous weight steps (momentum memory), updated in place
     * @param biases biases, updated in place
     * @param deltaBiases previous bias steps (momentum memory), updated in place
     * @param numInputs number of input nodes
     * @param numHiddenNodes number of hidden nodes
     * @return sum of squared output errors of this record (measured before the update)
     */
    private static double backwardPass(double[] values, double[] expectedValues, double[] gradients,
            double[][] weights, double[][] deltaWeights, double[] biases, double[] deltaBiases,
            int numInputs, int numHiddenNodes) {
        int firstOutput = numInputs + numHiddenNodes;
        int numNodes = values.length;

        // output layer gradients
        double sumOfSquaredErrors = 0;
        for (int o = firstOutput; o < numNodes; ++o) {
            double absoluteError = expectedValues[o - firstOutput] - values[o];
            sumOfSquaredErrors += absoluteError * absoluteError;
            gradients[o] = values[o] * (1.0 - values[o]) * absoluteError;
        }
        // hidden layer gradients, propagated through the hidden->output weights of the forward pass
        for (int h = numInputs; h < firstOutput; ++h) {
            double propagated = 0;
            for (int o = firstOutput; o < numNodes; ++o) {
                propagated += weights[h][o] * gradients[o];
            }
            gradients[h] = values[h] * (1.0 - values[h]) * propagated;
        }
        // hidden -> output weights and output biases
        for (int o = firstOutput; o < numNodes; ++o) {
            for (int h = numInputs; h < firstOutput; ++h) {
                applyWeightStep(weights, deltaWeights, h, o, values[h] * gradients[o]);
            }
            applyBiasStep(biases, deltaBiases, o, gradients[o]);
        }
        // input -> hidden weights (every input, starting at index 0) and hidden biases
        for (int h = numInputs; h < firstOutput; ++h) {
            for (int i = 0; i < numInputs; ++i) {
                applyWeightStep(weights, deltaWeights, i, h, values[i] * gradients[h]);
            }
            applyBiasStep(biases, deltaBiases, h, gradients[h]);
        }
        return sumOfSquaredErrors;
    }

    /**
     * Applies one momentum step to {@code weights[from][to]} and remembers the step.
     *
     * @param weights weights
     * @param deltaWeights previous steps
     * @param from source node index
     * @param to target node index
     * @param gradient error gradient with respect to the weight (without the learning rate)
     */
    private static void applyWeightStep(double[][] weights, double[][] deltaWeights, int from, int to,
            double gradient) {
        double delta = LEARNING_RATE * gradient + MOMENTUM * deltaWeights[from][to];
        weights[from][to] += delta;
        deltaWeights[from][to] = delta;
    }

    /**
     * Applies one momentum step to the bias of a node and remembers the step.
     * The bias is subtracted from the weighted input, so the step uses the negated gradient.
     *
     * @param biases biases
     * @param deltaBiases previous steps
     * @param node node index
     * @param gradient error gradient of the node (without the learning rate)
     */
    private static void applyBiasStep(double[] biases, double[] deltaBiases, int node, double gradient) {
        double delta = -LEARNING_RATE * gradient + MOMENTUM * deltaBiases[node];
        biases[node] += delta;
        deltaBiases[node] = delta;
    }

    /**
     * Truncates a non-negative value to 3 decimal places.
     * Uses {@link Math#floor} instead of an {@code (int)} cast, so large values are not clamped at
     * {@code Integer.MAX_VALUE / 1000}.
     *
     * @param value value to truncate
     * @return truncated value
     */
    private static double truncateTo3Decimals(double value) {
        return Math.floor(value * 1000) / 1000d;
    }

    /**
     * Creates a network snapshot from the current weights and biases.
     *
     * @param inputs names of the input nodes
     * @param numHiddenNodes number of hidden nodes
     * @param outputs names of the output nodes
     * @param weights weights, {@code weights[from][to]} in the [inputs | hidden | outputs] layout
     * @param biases biases in the same layout
     * @return created network
     */
    private static Network createNetwork(List<String> inputs, int numHiddenNodes, List<String> outputs,
            double[][] weights, double[] biases) {
        int firstHidden = inputs.size();
        int firstOutput = firstHidden + numHiddenNodes;
        FFSHLNetwork.Builder res = new FFSHLNetwork.Builder();
        // neurons
        for (String input : inputs) {
            res.addInput(IdentityNeuron.create(input));
        }
        res.addInput(ConstantNeuron.create("input_bias"));
        for (int h = 0; h < numHiddenNodes; ++h) {
            res.addHidden(SigmoidNeuron.create("hidden" + h));
        }
        res.addHidden(ConstantNeuron.create("hidden_bias"));
        for (String output : outputs) {
            res.addOutput(SigmoidNeuron.create(output));
        }
        // synapses from input to hidden
        for (int i = 0; i < inputs.size(); ++i) {
            for (int h = 0; h < numHiddenNodes; ++h) {
                res.addSynapse(Synapse.create(inputs.get(i), "hidden" + h, weights[i][firstHidden + h]));
            }
        }
        // bias synapses carry -bias because training subtracts the bias from the weighted input
        // (presumably ConstantNeuron emits 1 — TODO confirm)
        for (int h = 0; h < numHiddenNodes; ++h) {
            res.addSynapse(Synapse.create("input_bias", "hidden" + h, -biases[firstHidden + h]));
        }
        // synapses from hidden to output
        for (int h = 0; h < numHiddenNodes; ++h) {
            for (int o = 0; o < outputs.size(); ++o) {
                res.addSynapse(Synapse.create("hidden" + h, outputs.get(o),
                        weights[firstHidden + h][firstOutput + o]));
            }
        }
        for (int o = 0; o < outputs.size(); ++o) {
            res.addSynapse(Synapse.create("hidden_bias", outputs.get(o), -biases[firstOutput + o]));
        }

        return res.build();
    }

    /**
     * Computes the summed squared error of a network over the records.
     *
     * @param network network to evaluate
     * @param records records with inputs and expected outputs
     * @return sum over all records and expected output keys of (actual - expected)^2
     */
    private static double sumOfSquaredErrors(Network network, List<SupervisedTrainingRecord> records) {
        double error = 0;
        for (SupervisedTrainingRecord rec : records) {
            Map<String, Double> actual = network.process(rec.getInputs());
            for (Map.Entry<String, Double> expected : rec.getOutputs().entrySet()) {
                double err = actual.get(expected.getKey()) - expected.getValue();
                error += err * err;
            }
        }
        return error;
    }

    /**
     * Keeps the best network seen so far, measured by the summed squared error over fixed records.
     * Candidates are scored as they are offered, so only the current best is retained rather than
     * every snapshot taken during training.
     */
    private static final class BestNetworkTracker {

        /**
         * Records used for scoring.
         */
        private final List<SupervisedTrainingRecord> records;

        /**
         * Best network so far, null until a candidate scores below Double.MAX_VALUE.
         */
        private Network best = null;

        /**
         * Error of the best network so far.
         */
        private double bestError = Double.MAX_VALUE;

        /**
         * Creates a tracker.
         *
         * @param records records used for scoring
         */
        BestNetworkTracker(List<SupervisedTrainingRecord> records) {
            this.records = records;
        }

        /**
         * Scores the candidate and keeps it if it is strictly better than the current best.
         * On ties the earlier candidate wins.
         *
         * @param candidate candidate network
         */
        void offer(Network candidate) {
            double error = sumOfSquaredErrors(candidate, records);
            if (error < bestError) {
                best = candidate;
                bestError = error;
            }
        }

        /**
         * Returns the best network.
         *
         * @return best network, or null if no candidate scored below Double.MAX_VALUE
         */
        Network getBest() {
            return best;
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy