org.openimaj.ml.neuralnet.OnlineBackpropOneHidden Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of sandbox Show documentation
Show all versions of sandbox Show documentation
A project for various tests that don't quite constitute
demos but might be useful to look at.
/**
* Copyright (c) 2011, The University of Southampton and the individual contributors.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the University of Southampton nor the names of its
* contributors may be used to endorse or promote products derived from this
* software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
* ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.openimaj.ml.neuralnet;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
import org.openimaj.data.RandomData;
import org.openimaj.image.DisplayUtilities;
import org.openimaj.image.FImage;
import org.openimaj.image.colour.ColourMap;
import org.openimaj.util.function.Function;
/**
* Implement an online version of the backprop algorithm against an 2D
*
* @author Sina Samangooei ([email protected])
*
*/
public class OnlineBackpropOneHidden {
private static final double LEARNRATE = 0.005;
private Matrix weightsL1;
private Matrix weightsL2;
MatrixFactory extends Matrix> DMF = DenseMatrixFactoryMTJ.getDenseDefault();
private Function g;
private Function gMat;
private Function gPrime;
private Function gPrimeMat;
/**
* @param nInput
* the number of input values
* @param nHidden
* the number of hidden values
* @param nFinal
* the number of final values
*/
public OnlineBackpropOneHidden(int nInput, int nHidden, int nFinal) {
final double[][] weightsL1dat = RandomData.getRandomDoubleArray(nInput + 1, nHidden, -1, 1.);
final double[][] weightsL2dat = RandomData.getRandomDoubleArray(nHidden + 1, nFinal, -1, 1.);
weightsL1 = DMF.copyArray(weightsL1dat);
weightsL2 = DMF.copyArray(weightsL2dat);
;
g = new Function() {
@Override
public Double apply(Double in) {
return 1. / (1 + Math.exp(-in));
}
};
gPrime = new Function() {
@Override
public Double apply(Double in) {
return g.apply(in) * (1 - g.apply(in));
}
};
gPrimeMat = new Function() {
@Override
public Matrix apply(Matrix in) {
final Matrix out = DMF.copyMatrix(in);
for (int i = 0; i < in.getNumRows(); i++) {
for (int j = 0; j < in.getNumColumns(); j++) {
out.setElement(i, j, gPrime.apply(in.getElement(i, j)));
}
}
return out;
}
};
gMat = new Function() {
@Override
public Matrix apply(Matrix in) {
final Matrix out = DMF.copyMatrix(in);
for (int i = 0; i < in.getNumRows(); i++) {
for (int j = 0; j < in.getNumColumns(); j++) {
out.setElement(i, j, g.apply(in.getElement(i, j)));
}
}
return out;
}
};
}
public void update(double[] x, double[] y) {
final Matrix X = prepareMatrix(x);
final Matrix Y = DMF.copyArray(new double[][] { y });
final Matrix hiddenOutput = weightsL1.transpose().times(X); // nHiddenLayers
// x nInputs
// (usually
// 2 x 1)
final Matrix gHiddenOutput = prepareMatrix(gMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers
// +
// 1
// x
// nInputs
// (usually
// 3x1)
final Matrix gPrimeHiddenOutput = prepareMatrix(gPrimeMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers
// +
// 1
// x
// nInputs
// (usually
// 3x1)
final Matrix finalOutput = weightsL2.transpose().times(gHiddenOutput);
final Matrix finalOutputGPrime = gPrimeMat.apply(finalOutput); // nFinalLayers
// x
// nInputs
// (usually
// 1x1)
final Matrix errmat = Y.minus(finalOutput);
final double err = errmat.sumOfColumns().sum();
Matrix dL2 = finalOutputGPrime.times(gHiddenOutput.transpose()).scale(err * LEARNRATE).transpose(); // should
// be
// nHiddenLayers
// +
// 1
// x
// nInputs
// (3
// x
// 1)
Matrix dL1 = finalOutputGPrime.times(weightsL2.transpose().times(gPrimeHiddenOutput).times(X.transpose()))
.scale(err * LEARNRATE).transpose();
dL1 = repmat(dL1, 1, weightsL1.getNumColumns());
dL2 = repmat(dL2, 1, weightsL2.getNumColumns());
this.weightsL1.plusEquals(dL1);
this.weightsL2.plusEquals(dL2);
}
private Matrix repmat(Matrix dL1, int nRows, int nCols) {
final Matrix out = DMF.createMatrix(nRows * dL1.getNumRows(), nCols * dL1.getNumColumns());
for (int i = 0; i < nRows; i++) {
for (int j = 0; j < nCols; j++) {
out.setSubMatrix(i * dL1.getNumRows(), j * dL1.getNumColumns(), dL1);
}
}
return out;
}
public Matrix predict(double[] x) {
final Matrix X = prepareMatrix(x);
final Matrix hiddenTimes = weightsL1.transpose().times(X);
final Matrix hiddenVal = prepareMatrix(gMat.apply(hiddenTimes).getColumn(0));
final Matrix finalTimes = weightsL2.transpose().times(hiddenVal);
final Matrix finalVal = gMat.apply(finalTimes);
return finalVal;
}
private Matrix prepareMatrix(Vector y) {
final Matrix Y = DMF.createMatrix(1, y.getDimensionality() + 1);
Y.setElement(0, 0, 1);
Y.setSubMatrix(0, 1, DMF.copyRowVectors(y));
return Y.transpose();
}
private Matrix prepareMatrix(double[] y) {
final Matrix Y = DMF.createMatrix(1, y.length + 1);
Y.setElement(0, 0, 1);
Y.setSubMatrix(0, 1, DMF.copyArray(new double[][] { y }));
return Y.transpose();
}
public static void main(String[] args) throws InterruptedException {
final OnlineBackpropOneHidden bp = new OnlineBackpropOneHidden(2, 2, 1);
FImage img = new FImage(200, 200);
img = imagePredict(bp, img);
final ColourMap m = ColourMap.Hot;
DisplayUtilities.displayName(m.apply(img), "xor");
final int npixels = img.width * img.height;
final int half = img.width / 2;
final int[] pixels = RandomData.getUniqueRandomInts(npixels, 0, npixels);
while (true) {
// for (int i = 0; i < pixels.length; i++) {
// int pixel = pixels[i];
// int y = pixel / img.width;
// int x = pixel - (y * img.width);
// bp.update(new double[]{x < half ? -1 : 1,y < half ? -1 : 1},new
// double[]{xorValue(half,x,y)});
// // Thread.sleep(5);
// }
bp.update(new double[] { 0, 0 }, new double[] { 0 });
bp.update(new double[] { 1, 1 }, new double[] { 0 });
bp.update(new double[] { 0, 1 }, new double[] { 1 });
bp.update(new double[] { 1, 0 }, new double[] { 1 });
imagePredict(bp, img);
DisplayUtilities.displayName(m.apply(img), "xor");
}
}
private static FImage imagePredict(OnlineBackpropOneHidden bp, FImage img) {
final double[] pos = new double[2];
final int half = img.width / 2;
for (int y = 0; y < img.height; y++) {
for (int x = 0; x < img.width; x++) {
pos[0] = x < half ? 0 : 1;
pos[1] = y < half ? 0 : 1;
final float ret = (float) bp.predict(pos).getElement(0, 0);
img.pixels[y][x] = ret;
}
}
return img;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy