
aima.core.learning.neural.BackPropLearning
AIMA-Java core algorithms from the book Artificial Intelligence: A Modern Approach, 3rd Edition.
package aima.core.learning.neural;
import aima.core.util.math.Matrix;
import aima.core.util.math.Vector;
/**
 * Standard backpropagation (with an optional momentum term) as a training
 * scheme for a feed-forward network consisting of one hidden layer and one
 * output layer.
 *
 * @author Ravi Mohan
 */
public class BackPropLearning implements NNTrainingScheme {
    private final double learningRate;
    private final double momentum;

    private Layer hiddenLayer;
    private Layer outputLayer;
    private LayerSensitivity hiddenSensitivity;
    private LayerSensitivity outputSensitivity;
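
    /**
     * @param learningRate the step size alpha used for the weight and bias updates
     * @param momentum     the momentum term (0.0 disables momentum smoothing)
     */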
    public BackPropLearning(double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
    }
    public void setNeuralNetwork(FunctionApproximator fapp) {
        FeedForwardNeuralNetwork ffnn = (FeedForwardNeuralNetwork) fapp;
        this.hiddenLayer = ffnn.getHiddenLayer();
        this.outputLayer = ffnn.getOutputLayer();
        this.hiddenSensitivity = new LayerSensitivity(hiddenLayer);
        this.outputSensitivity = new LayerSensitivity(outputLayer);
    }
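
    /**
     * Forward pass: propagates the input through the hidden layer and then the
     * output layer, returning the network's output activation values.
     */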
    public Vector processInput(FeedForwardNeuralNetwork network, Vector input) {
        hiddenLayer.feedForward(input);
        outputLayer.feedForward(hiddenLayer.getLastActivationValues());
        return outputLayer.getLastActivationValues();
    }
    public void processError(FeedForwardNeuralNetwork network, Vector error) {
        // TODO calculate total error somewhere

        // create the sensitivity matrices
        outputSensitivity.sensitivityMatrixFromErrorMatrix(error);
        hiddenSensitivity.sensitivityMatrixFromSucceedingLayer(outputSensitivity);

        // calculate the weight updates
        calculateWeightUpdates(outputSensitivity,
                hiddenLayer.getLastActivationValues(), learningRate, momentum);
        calculateWeightUpdates(hiddenSensitivity,
                hiddenLayer.getLastInputValues(), learningRate, momentum);

        // calculate the bias updates
        calculateBiasUpdates(outputSensitivity, learningRate, momentum);
        calculateBiasUpdates(hiddenSensitivity, learningRate, momentum);

        // apply the accumulated updates to the weights and biases
        outputLayer.updateWeights();
        outputLayer.updateBiases();
        hiddenLayer.updateWeights();
        hiddenLayer.updateBiases();
    }
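
    /**
     * Computes the momentum-smoothed weight update
     * deltaW = momentum * lastUpdate - (1 - momentum) * alpha * S * a^T,
     * where S is the layer's sensitivity matrix and a is the preceding layer's
     * activation (or the raw input), registers it with the layer, and returns it.
     */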
    public Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity,
            Vector previousLayerActivationOrInput, double alpha, double momentum) {
        Layer layer = layerSensitivity.getLayer();
        Matrix activationTranspose = previousLayerActivationOrInput.transpose();
        Matrix momentumLessUpdate = layerSensitivity.getSensitivityMatrix()
                .times(activationTranspose).times(alpha).times(-1.0);
        Matrix updateWithMomentum = layer.getLastWeightUpdateMatrix()
                .times(momentum).plus(momentumLessUpdate.times(1.0 - momentum));
        layer.acceptNewWeightUpdate(updateWithMomentum.copy());
        return updateWithMomentum;
    }
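
    /**
     * Momentum-free variant: computes deltaW = -alpha * S * a^T, registers it
     * with the layer, and returns it.
     */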
    public static Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity,
            Vector previousLayerActivationOrInput, double alpha) {
        Layer layer = layerSensitivity.getLayer();
        Matrix activationTranspose = previousLayerActivationOrInput.transpose();
        Matrix weightUpdateMatrix = layerSensitivity.getSensitivityMatrix()
                .times(activationTranspose).times(alpha).times(-1.0);
        layer.acceptNewWeightUpdate(weightUpdateMatrix.copy());
        return weightUpdateMatrix;
    }
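
    /**
     * Computes the momentum-smoothed bias update
     * deltaB = momentum * lastUpdate - (1 - momentum) * alpha * S,
     * flattens the single-column result into a Vector, registers it with the
     * layer, and returns it.
     */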
    public Vector calculateBiasUpdates(LayerSensitivity layerSensitivity,
            double alpha, double momentum) {
        Layer layer = layerSensitivity.getLayer();
        Matrix biasUpdateMatrixWithoutMomentum = layerSensitivity
                .getSensitivityMatrix().times(alpha).times(-1.0);
        Matrix biasUpdateMatrixWithMomentum = layer.getLastBiasUpdateVector()
                .times(momentum)
                .plus(biasUpdateMatrixWithoutMomentum.times(1.0 - momentum));
        Vector result = new Vector(biasUpdateMatrixWithMomentum.getRowDimension());
        for (int i = 0; i < biasUpdateMatrixWithMomentum.getRowDimension(); i++) {
            result.setValue(i, biasUpdateMatrixWithMomentum.get(i, 0));
        }
        layer.acceptNewBiasUpdate(result.copyVector());
        return result;
    }
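
    /**
     * Momentum-free variant: computes deltaB = -alpha * S, flattens it into a
     * Vector, registers it with the layer, and returns it.
     */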
    public static Vector calculateBiasUpdates(LayerSensitivity layerSensitivity,
            double alpha) {
        Layer layer = layerSensitivity.getLayer();
        Matrix biasUpdateMatrix = layerSensitivity.getSensitivityMatrix()
                .times(alpha).times(-1.0);
        Vector result = new Vector(biasUpdateMatrix.getRowDimension());
        for (int i = 0; i < biasUpdateMatrix.getRowDimension(); i++) {
            result.setValue(i, biasUpdateMatrix.get(i, 0));
        }
        layer.acceptNewBiasUpdate(result.copyVector());
        return result;
    }
}
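
For orientation, below is a minimal, hypothetical sketch of one online training step using this scheme. It is not part of the listing above; it assumes the four-argument FeedForwardNeuralNetwork constructor (hidden/output weights and biases), the network's setTrainingScheme, processInput and processError methods, and Vector.minus from aima-core, and all layer sizes, weights, and target values are made up for illustration.

// Illustrative sketch only: a tiny 2-2-1 network trained for a single step.
import aima.core.learning.neural.BackPropLearning;
import aima.core.learning.neural.FeedForwardNeuralNetwork;
import aima.core.util.math.Matrix;
import aima.core.util.math.Vector;

public class BackPropLearningDemo {
    public static void main(String[] args) {
        // Hand-picked initial weights; biases default to zero vectors (hypothetical values).
        Matrix hiddenWeights = new Matrix(new double[][] { { 0.1, -0.2 }, { 0.3, 0.4 } });
        Vector hiddenBias = new Vector(2);
        Matrix outputWeights = new Matrix(new double[][] { { 0.5, -0.6 } });
        Vector outputBias = new Vector(1);

        // Assumed aima-core constructor taking explicit weights and biases.
        FeedForwardNeuralNetwork network = new FeedForwardNeuralNetwork(
                hiddenWeights, hiddenBias, outputWeights, outputBias);
        // Learning rate 0.1, momentum 0.9.
        network.setTrainingScheme(new BackPropLearning(0.1, 0.9));

        // Forward pass on one example.
        Vector input = new Vector(2);
        input.setValue(0, 1.0);
        input.setValue(1, 0.0);
        Vector prediction = network.processInput(input);

        // Error = target - prediction (sign convention assumed), then backward pass.
        Vector target = new Vector(1);
        target.setValue(0, 1.0);
        Vector error = target.minus(prediction);
        network.processError(error);
    }
}

Repeating the forward/backward step above over a data set for several epochs yields plain online backpropagation with momentum.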