org.deeplearning4j.example.iris.IrisRBMExample Maven / Gradle / Ivy
package org.deeplearning4j.example.iris;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.dbn.GaussianRectifiedLinearDBN;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.example.mnist.IrisExample;
import org.deeplearning4j.nn.activation.Activations;
import org.deeplearning4j.rbm.GaussianRectifiedLinearRBM;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Fits a {@link GaussianRectifiedLinearRBM} to the Iris data set and logs the normalized input
 * next to the RBM's reconstruction of that input, so the two can be compared by eye.
 */
public class IrisRBMExample {

    private static final Logger log = LoggerFactory.getLogger(IrisRBMExample.class);

    /** Number of hidden units in the RBM. */
    private static final int NUM_HIDDEN = 10;

    /**
     * Runs the example: loads one Iris batch, normalizes it, trains the RBM on it and logs the
     * data together with its reconstruction.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // NOTE(review): presumably (batch size, total examples). With both at 150, the single
        // next() call below should yield the whole Iris set; confirm against IrisDataSetIterator.
        DataSetIterator irisData = new IrisDataSetIterator(150, 150);
        DataSet next = irisData.next();
        // Normalizes the batch in place before it is used for training.
        next.normalizeZeroMeanZeroUnitVariance();
        int numExamples = next.numExamples();
        log.info("Training on {}", numExamples);

        GaussianRectifiedLinearRBM r = new GaussianRectifiedLinearRBM.Builder()
                .numberOfVisible(irisData.inputColumns())
                .useAdaGrad(true)
                .numHidden(NUM_HIDDEN)
                .normalizeByInputRows(false)
                .useRegularization(false)
                .build();

        // NOTE(review): the positional arguments are interpreted by trainTillConvergence, not here.
        // Presumably 1e-3 is a learning rate and the array is {CD steps k, learning rate, epochs};
        // TODO confirm. The literals deliberately box to (Integer, Double, Integer) as before.
        r.trainTillConvergence(next.getFirst(), 1e-3, new Object[] {1, 1e-3, 2000});

        log.info("\nData {}", toRowPerLine(next.getFirst()));
        log.info("\nReconstruct {}", toRowPerLine(r.reconstruct(r.getInput())));
    }

    /**
     * Renders a value for multi-line logging. It prefixes a newline and turns every ';' in the
     * value's string form into a line break. Presumably ';' is the row separator in the matrix's
     * toString(); TODO confirm.
     *
     * @param matrix the value to render; {@code null} renders as "null"
     * @return the rendered text
     */
    private static String toRowPerLine(Object matrix) {
        // A literal substitution is all that is needed; no regex engine required.
        return ("\n" + matrix).replace(";", "\n");
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy