All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.example.iris.IrisRBMExample Maven / Gradle / Ivy

package org.deeplearning4j.example.iris;

import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.dbn.GaussianRectifiedLinearDBN;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.example.mnist.IrisExample;
import org.deeplearning4j.nn.activation.Activations;
import org.deeplearning4j.rbm.GaussianRectifiedLinearRBM;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Example that trains a {@link GaussianRectifiedLinearRBM} on the normalized Iris data set
 * and logs the input matrix followed by the RBM's reconstruction of its input.
 */
public class IrisRBMExample {

	private static final Logger log = LoggerFactory.getLogger(IrisRBMExample.class);

	/** Number of hidden units passed to the RBM builder. */
	private static final int NUM_HIDDEN = 10;

	/**
	 * Loads one batch of Iris data, normalizes it, trains the RBM until convergence,
	 * and logs the data alongside its reconstruction.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		// Both arguments are 150 (the size of the Iris data set), so presumably the single
		// next() call below returns the whole set — TODO confirm the constructor's argument order.
		DataSetIterator irisData = new IrisDataSetIterator(150, 150);
		DataSet next = irisData.next();
		next.normalizeZeroMeanZeroUnitVariance();

		int numExamples = next.numExamples();
		log.info("Training on {}", numExamples);

		GaussianRectifiedLinearRBM r = new GaussianRectifiedLinearRBM.Builder()
				.numberOfVisible(irisData.inputColumns())
				.useAdaGrad(true)
				.numHidden(NUM_HIDDEN)
				.normalizeByInputRows(false)
				.useRegularization(false)
				.build();

		// NOTE(review): the meaning of 1e-3 and of the extra-params array {1, 1e-3, 2000}
		// is defined by trainTillConvergence, which is not visible here — confirm before tuning.
		r.trainTillConvergence(next.getFirst(), 1e-3, new Object[] {1, 1e-3, 2000});

		log.info("\nData {}", formatRows(next.getFirst()));
		log.info("\nReconstruct {}", formatRows(r.reconstruct(r.getInput())));
	}

	/**
	 * Formats a value for multi-line logging. The result is a leading newline followed by the
	 * value's {@code toString()}, with every {@code ';'} replaced by a newline.
	 * The intent appears to be printing one matrix row per line.
	 *
	 * @param matrix value to render; {@code null} renders as {@code "\nnull"}
	 * @return the formatted text
	 */
	private static String formatRows(Object matrix) {
		// Literal replace: ';' is the separator being split onto lines, so no regex is needed.
		return ("\n" + matrix).replace(";", "\n");
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy