org.deeplearning4j.example.mnist.DenoisingAutoEncoderMnistExample Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of deeplearning4j-examples Show documentation
Examples of training different data sets
The newest version!
package org.deeplearning4j.example.mnist;
import org.apache.commons.math3.random.MersenneTwister;
import org.deeplearning4j.da.DenoisingAutoEncoder;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.RawMnistDataSetIterator;
import org.deeplearning4j.datasets.mnist.draw.DrawMnistGreyScale;
import org.deeplearning4j.nn.NeuralNetwork.LossFunction;
import org.deeplearning4j.nn.NeuralNetwork.OptimizationAlgorithm;
import org.deeplearning4j.plot.FilterRenderer;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
/**
 * Trains a denoising autoencoder on a small MNIST subset, renders the learned weight
 * filters to {@code example-render.jpg}, then displays each example next to a binomially
 * sampled reconstruction of it.
 */
public class DenoisingAutoEncoderMnistExample {

    /** Side length, in pixels, of a (square) MNIST image. */
    private static final int IMAGE_SIDE = 28;
    /** Mini-batch size handed to the MNIST iterator. */
    private static final int BATCH_SIZE = 10;
    /**
     * Second argument of {@code RawMnistDataSetIterator}; presumably the total number of
     * examples iterated (a small subset rather than the full training set) - TODO confirm.
     */
    private static final int NUM_EXAMPLES = 30;
    /** Number of full passes over the iterator during training. */
    private static final int TRAINING_PASSES = 20;
    /** Learning-rate argument passed to {@code trainTillConvergence}. */
    private static final double LEARNING_RATE = 1e-1;
    /** How long each REAL/TEST pair stays on screen before its windows are disposed. */
    private static final long DISPLAY_MILLIS = 10000L;
    /** Seed for the generator used to sample reconstructions, so runs are repeatable. */
    private static final int RNG_SEED = 123;

    /**
     * Runs the example: build, train, render filters, then show reconstructions.
     *
     * @param args unused
     * @throws Exception propagated from data loading, training, or rendering
     */
    public static void main(String[] args) throws Exception {
        DenoisingAutoEncoder autoEncoder = buildAutoEncoder();
        DataSetIterator iter = new RawMnistDataSetIterator(BATCH_SIZE, NUM_EXAMPLES);

        train(autoEncoder, iter);

        FilterRenderer render = new FilterRenderer();
        render.renderFilters(autoEncoder.getW(), "example-render.jpg", IMAGE_SIDE, IMAGE_SIDE);

        // train() leaves the iterator rewound, so this walks the same examples again.
        showReconstructions(autoEncoder, iter);
    }

    /** Configures the autoencoder: one visible unit per pixel, 500 hidden units. */
    private static DenoisingAutoEncoder buildAutoEncoder() throws Exception {
        return new DenoisingAutoEncoder.Builder()
                .numberOfVisible(IMAGE_SIDE * IMAGE_SIDE)
                .numHidden(500)
                .normalizeByInputRows(true)
                .withLossFunction(LossFunction.NEGATIVELOGLIKELIHOOD)
                .useAdaGrad(true)
                .useRegularization(true)
                .withSparsity(0)
                .withL2(0.01)
                .withOptmizationAlgo(OptimizationAlgorithm.GRADIENT_DESCENT)
                .withMomentum(0.5)
                .build();
    }

    /** Trains on every batch of the iterator, {@link #TRAINING_PASSES} times over. */
    private static void train(DenoisingAutoEncoder autoEncoder, DataSetIterator iter)
            throws Exception {
        for (int pass = 0; pass < TRAINING_PASSES; pass++) {
            while (iter.hasNext()) {
                DataSet batch = iter.next();
                // NOTE(review): the extra params are presumably {corruption level, learning rate,
                // max epochs} - confirm against DenoisingAutoEncoder.trainTillConvergence.
                autoEncoder.trainTillConvergence(
                        batch.getFirst(), LEARNING_RATE, new Object[] {0.6, 1e-1, 1000});
            }
            // Rewind so the next pass (and the display phase afterwards) starts at the first batch.
            iter.reset();
        }
    }

    /**
     * Shows every remaining example next to a binomial sample drawn from its reconstruction.
     * Blocks for {@link #DISPLAY_MILLIS} per example.
     */
    private static void showReconstructions(DenoisingAutoEncoder autoEncoder, DataSetIterator iter)
            throws Exception {
        // One generator for the whole display pass. Creating a fresh generator with the same seed
        // per image would make every image's sample reuse an identical random stream.
        MersenneTwister rng = new MersenneTwister(RNG_SEED);
        while (iter.hasNext()) {
            DataSet batch = iter.next();
            DoubleMatrix reconstruction = autoEncoder.reconstruct(batch.getFirst());
            for (int j = 0; j < batch.numExamples(); j++) {
                DoubleMatrix real = batch.get(j).getFirst().mul(255);
                // NOTE(review): binomial(..., 1, rng) presumably treats each reconstructed value as
                // a per-pixel probability and draws 0/1; scaled to 0..255 for display.
                DoubleMatrix sampled =
                        MatrixUtil.binomial(reconstruction.getRow(j), 1, rng).mul(255);
                showSideBySide(real, sampled);
            }
        }
    }

    /**
     * Opens a REAL and a TEST window, waits {@link #DISPLAY_MILLIS}, then disposes both.
     * The windows are disposed even if drawing or sleeping throws. You may need to drag the
     * TEST window next to the REAL one to compare them.
     */
    private static void showSideBySide(DoubleMatrix real, DoubleMatrix reconstructed)
            throws Exception {
        DrawMnistGreyScale realDrawer = new DrawMnistGreyScale(real);
        realDrawer.title = "REAL";
        DrawMnistGreyScale testDrawer = null;
        try {
            realDrawer.draw();
            testDrawer = new DrawMnistGreyScale(reconstructed, 1000, 1000);
            testDrawer.title = "TEST";
            testDrawer.draw();
            Thread.sleep(DISPLAY_MILLIS);
        } finally {
            dispose(realDrawer);
            dispose(testDrawer);
        }
    }

    /** Disposes the drawer's frame if present; tolerates a null drawer or a null frame. */
    private static void dispose(DrawMnistGreyScale drawer) {
        if (drawer != null && drawer.frame != null) {
            drawer.frame.dispose();
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy