com.enterprisemath.math.nn.BackpropagationUtils (Maven / Gradle / Ivy artifact: em-math)
Advanced mathematical algorithms.
package com.enterprisemath.math.nn;
import com.enterprisemath.utils.DomainUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
/**
* Utilities for backpropagation.
*
* This class is for internal use only.
* Modifying or removing it is not considered a breaking change.
*
* @author radek.hecl
*/
class BackpropagationUtils {
/**
* Prevents construction.
*/
private BackpropagationUtils() {
}
/**
* Trains the network.
*
* @param inputs input node identifiers; names cannot start with "bias" or "hidden"
* @param numHiddenNodes number of hidden nodes
* @param outputs output node identifiers; names cannot start with "bias" or "hidden"
* @param records training records
* @param maxIterations maximum number of iterations
* @param minError1000Delta minimal improvement of the 1000-iteration average error required to keep training
* @param allowedError error level below which training stops early
* @return created network
*/
public static Network train(List<String> inputs, int numHiddenNodes, List<String> outputs, List<SupervisedTrainingRecord> records,
int maxIterations, double minError1000Delta, double allowedError) {
List<SupervisedTrainingRecord> recs = DomainUtils.softCopyList(records);
List<Network> possibleNetworks = new ArrayList<Network>();
int numNodes = inputs.size() + numHiddenNodes + outputs.size();
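// node index layout used throughout: [0, inputs.size()) is the input layer,
// [inputs.size(), inputs.size() + numHiddenNodes) is the hidden layer,
// [inputs.size() + numHiddenNodes, numNodes) is the output layer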
double[][] weights = new double[numNodes][numNodes];
double[][] deltaWeights = new double[numNodes][numNodes];
double[] biases = new double[numNodes];
double[] deltaBiases = new double[numNodes];
double[] values = new double[numNodes];
double[] expectedValues = new double[outputs.size()];
double learningRate = 0.2;
double momentum = 0.1;
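// the update loops below apply the incremental delta rule with momentum:
// delta(w) = learningRate * value * gradient + momentum * previousDelta(w),
// using the logistic sigmoid 1 / (1 + exp(-x)) as the activation function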
Random random = new Random(System.currentTimeMillis());
for (int i = 0; i < numNodes; ++i) {
for (int j = 0; j < numNodes; ++j) {
weights[i][j] = random.nextDouble();
}
}
double minError = Double.POSITIVE_INFINITY;
double error = Double.POSITIVE_INFINITY;
double error1000 = 0;
double targetError1000 = Double.POSITIVE_INFINITY;
double divergenceErrorLimitTrunc = Double.POSITIVE_INFINITY;
int numGThanDivergenceErrorLimitTrunc = 0;
for (int iteration = 0; iteration < maxIterations && error > allowedError; ++iteration) {
Collections.shuffle(recs, random);
error = 0;
// run across all records
for (SupervisedTrainingRecord rec : recs) {
// expected values
for (int i = 0; i < outputs.size(); ++i) {
expectedValues[i] = rec.getOutputs().get(outputs.get(i));
}
// input values
for (int i = 0; i < inputs.size(); ++i) {
values[i] = rec.getInputs().get(inputs.get(i));
}
// propagate to the hidden layer
for (int h = inputs.size(); h < inputs.size() + numHiddenNodes; ++h) {
double weightedInput = 0;
for (int i = 0; i < inputs.size(); ++i) {
weightedInput += weights[i][h] * values[i];
}
weightedInput += (-1 * biases[h]);
values[h] = 1.0 / (1.0 + Math.exp(-weightedInput));
}
// propagate to the output layer
for (int o = inputs.size() + numHiddenNodes; o < numNodes; ++o) {
double weightedInput = 0;
for (int h = inputs.size(); h < inputs.size() + numHiddenNodes; ++h) {
weightedInput += weights[h][o] * values[h];
}
weightedInput += (-1 * biases[o]);
values[o] = 1.0 / (1.0 + Math.exp(-weightedInput));
}
// update weights
double sumOfSquaredErrors = 0;
for (int o = inputs.size() + numHiddenNodes; o < numNodes; ++o) {
double absoluteError = expectedValues[o - inputs.size() - numHiddenNodes] - values[o];
sumOfSquaredErrors += absoluteError * absoluteError;
double outputErrorGradient = values[o] * (1.0 - values[o]) * absoluteError;
// update weights in the hidden layer
for (int h = inputs.size(); h < inputs.size() + numHiddenNodes; ++h) {
// the hidden-layer gradient is computed from the weight value before the update
double hiddenErrorGradient = values[h] * (1 - values[h]) * outputErrorGradient * weights[h][o];
double delta = learningRate * values[h] * outputErrorGradient + momentum * deltaWeights[h][o];
weights[h][o] += delta;
deltaWeights[h][o] = delta;
for (int i = 0; i < inputs.size(); ++i) {
double hdelta = learningRate * values[i] * hiddenErrorGradient + momentum * deltaWeights[i][h];
weights[i][h] += hdelta;
deltaWeights[i][h] = hdelta;
}
double biasDelta = learningRate * -1 * hiddenErrorGradient + momentum * deltaBiases[h];
biases[h] += biasDelta;
deltaBiases[h] = biasDelta;
}
double biasDelta = learningRate * -1 * outputErrorGradient + momentum * deltaBiases[o];
biases[o] += biasDelta;
deltaBiases[o] = biasDelta;
}
error += sumOfSquaredErrors;
}
// possibly save the best
if (error < minError) {
possibleNetworks.add(createNetwork(inputs, numHiddenNodes, outputs, weights, biases));
minError = error;
}
// evaluate the errors
error1000 += error;
double truncError = ((int) (error * 1000)) / 1000d;
if (truncError > divergenceErrorLimitTrunc) {
++numGThanDivergenceErrorLimitTrunc;
}
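// every 1000 iterations: if more than 700 of them failed to beat the previous
// checkpoint's truncated average error, the search is considered stuck and the
// weights are rebalanced by a random Gaussian perturbation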
if (iteration > 0 && iteration % 1000 == 0) {
error1000 = error1000 / 1000;
boolean rebalance = numGThanDivergenceErrorLimitTrunc > 700;
//System.out.println("Backpropagation: iteration = " + iteration + "; error1000 = " + error1000 +
// "; rebalance = " + rebalance + "; numGThanDivergenceErrorLimitTrunc = " + numGThanDivergenceErrorLimitTrunc);
if (rebalance) {
possibleNetworks.add(createNetwork(inputs, numHiddenNodes, outputs, weights, biases));
for (int i = 0; i < numNodes; ++i) {
for (int j = 0; j < numNodes; ++j) {
weights[i][j] = weights[i][j] + random.nextGaussian();
}
biases[i] = biases[i] + random.nextGaussian();
}
deltaWeights = new double[numNodes][numNodes];
deltaBiases = new double[numNodes];
divergenceErrorLimitTrunc = Double.POSITIVE_INFINITY;
targetError1000 = Double.POSITIVE_INFINITY;
}
else {
divergenceErrorLimitTrunc = ((int) (error1000 * 1000)) / 1000d;
if (targetError1000 < error1000) {
break;
}
else {
targetError1000 = error1000 - minError1000Delta;
error1000 = 0;
}
}
numGThanDivergenceErrorLimitTrunc = 0;
}
}
possibleNetworks.add(createNetwork(inputs, numHiddenNodes, outputs, weights, biases));
return selectBestNetwork(possibleNetworks, records);
}
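/*
 * Minimal usage sketch (hypothetical; assumes SupervisedTrainingRecord offers a
 * way to build a record from input and output maps, which is not shown here):
 *
 * List<String> ins = java.util.Arrays.asList("a", "b");
 * List<String> outs = java.util.Arrays.asList("xor");
 * List<SupervisedTrainingRecord> recs = new ArrayList<SupervisedTrainingRecord>();
 * // ... fill recs with the four XOR truth-table rows ...
 * Network net = BackpropagationUtils.train(ins, 4, outs, recs, 100000, 0.000001, 0.01);
 */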
/**
* Creates a network with the given weights and biases.
*
* @param inputs names of the input nodes
* @param numHiddenNodes number of hidden nodes
* @param outputs names of the output nodes
* @param weights weight matrix indexed as [fromNode][toNode]
* @param biases node biases
* @return created network
*/
private static Network createNetwork(List<String> inputs, int numHiddenNodes, List<String> outputs, double[][] weights, double[] biases) {
FFSHLNetwork.Builder res = new FFSHLNetwork.Builder();
// create neurons
for (String input : inputs) {
res.addInput(IdentityNeuron.create(input));
}
res.addInput(ConstantNeuron.create("input_bias"));
for (int i = 0; i < numHiddenNodes; ++i) {
res.addHidden(SigmoidNeuron.create("hidden" + i));
}
res.addHidden(ConstantNeuron.create("hidden_bias"));
for (String output : outputs) {
res.addOutput(SigmoidNeuron.create(output));
}
// synapses from input to hidden
for (int i = 0; i < inputs.size(); ++i) {
for (int h = inputs.size(); h < inputs.size() + numHiddenNodes; ++h) {
res.addSynapse(Synapse.create(inputs.get(i), "hidden" + (h - inputs.size()), weights[i][h]));
}
}
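// during training biases enter as "-bias", so the synapses from the constant
// bias neurons (assumed to emit a constant 1) carry weight -biases[node]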
for (int h = inputs.size(); h < inputs.size() + numHiddenNodes; ++h) {
res.addSynapse(Synapse.create("input_bias", "hidden" + (h - inputs.size()), -biases[h]));
}
// synapses from hidden to output
for (int h = inputs.size(); h < inputs.size() + numHiddenNodes; ++h) {
for (int o = inputs.size() + numHiddenNodes; o < inputs.size() + numHiddenNodes + outputs.size(); ++o) {
res.addSynapse(Synapse.create("hidden" + (h - inputs.size()), outputs.get(o - inputs.size() - numHiddenNodes), weights[h][o]));
}
}
for (int o = inputs.size() + numHiddenNodes; o < inputs.size() + numHiddenNodes + outputs.size(); ++o) {
res.addSynapse(Synapse.create("hidden_bias", outputs.get(o - inputs.size() - numHiddenNodes), -biases[o]));
}
return res.build();
}
/**
* Selects the best network for fitting the problem.
*
* @param possibleNetworks candidate networks
* @param records training records used for evaluation
* @return best network
*/
private static Network selectBestNetwork(List<Network> possibleNetworks, List<SupervisedTrainingRecord> records) {
Network res = null;
double minError = Double.MAX_VALUE;
for (Network network : possibleNetworks) {
double error = 0;
for (SupervisedTrainingRecord rec : records) {
Map<String, Double> actual = network.process(rec.getInputs());
Map<String, Double> expected = rec.getOutputs();
for (String key : expected.keySet()) {
double err = actual.get(key) - expected.get(key);
error += err * err;
}
}
if (error < minError) {
res = network;
minError = error;
}
}
return res;
}
}