All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.sda.StackedDenoisingAutoEncoder Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.sda;

import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.da.DenoisingAutoEncoder;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.NeuralNetwork;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;



/**
 * A JBlas implementation of 
 * stacked denoising auto encoders.
 * @author Adam Gibson
 *
 */
public class StackedDenoisingAutoEncoder extends BaseMultiLayerNetwork  {

	private static final long serialVersionUID = 1448581794985193009L;
	private static Logger log = LoggerFactory.getLogger(StackedDenoisingAutoEncoder.class);



	public StackedDenoisingAutoEncoder() {}

	public StackedDenoisingAutoEncoder(int n_ins, int[] hiddenLayerSizes, int nOuts,
			int nLayers, RandomGenerator rng, DoubleMatrix input,DoubleMatrix labels) {
		super(n_ins, hiddenLayerSizes, nOuts, nLayers, rng, input,labels);

	}


	public StackedDenoisingAutoEncoder(int nIns, int[] hiddenLayerSizes, int nOuts,
			int n_layers, RandomGenerator rng) {
		super(nIns, hiddenLayerSizes, nOuts, n_layers, rng);
	}


	public void pretrain( double lr,  double corruptionLevel,  int epochs) {
		pretrain(this.getInput(),lr,corruptionLevel,epochs);
	}

	/**
	 * Unsupervised pretraining based on reconstructing the input
	 * from a corrupted version
	 * @param input the input to train on
	 * @param lr the starting learning rate
	 * @param corruptionLevel the corruption level (the smaller number of inputs; the higher the 
	 * corruption level should be) the percent of inputs to corrupt
	 * @param epochs the number of iterations to run
	 */
	public void pretrain(DoubleMatrix input,double lr,  double corruptionLevel,  int epochs) {
		if(this.getInput() == null)
			initializeLayers(input.dup());

		DoubleMatrix layerInput = null;

		for(int i = 0; i < this.getnLayers(); i++) {  // layer-wise                        
			//input layer
			if(i == 0)
				layerInput = input;
			else
				layerInput = this.getSigmoidLayers()[i - 1].sampleHGivenV(layerInput);
			if(isForceNumEpochs()) {
				for(int epoch = 0; epoch < epochs; epoch++) {
					layers[i].train(layerInput, lr,  new Object[]{corruptionLevel,lr});
					log.info("Error on epoch " + epoch + " for layer " + (i + 1) + " is " + layers[i].getReConstructionCrossEntropy());
				}
			}
			else
				layers[i].trainTillConvergence(layerInput, lr, new Object[]{corruptionLevel,lr});


		}	
	}

	/**
	 * 
	 * @param input input examples
	 * @param labels output labels
	 * @param otherParams
	 * 
	 * (double) learningRate
	 * (double) corruptionLevel
	 * (int) epochs
	 * 
	 * Optional:
	 * (double) finetune lr
	 * (int) finetune epochs
	 * 
	 */
	@Override
	public void trainNetwork(DoubleMatrix input, DoubleMatrix labels,
			Object[] otherParams) {
		if(otherParams == null) {
			otherParams = new Object[]{0.01,0.3,1000};
		}
		
		Double lr = (Double) otherParams[0];
		Double corruptionLevel = (Double) otherParams[1];
		Integer epochs = (Integer) otherParams[2];

		pretrain(input, lr, corruptionLevel, epochs);
		if(otherParams.length <= 3)
			finetune(labels, lr, epochs);
		else {
			Double finetuneLr = (Double) otherParams[3];
			Integer fineTuneEpochs = (Integer) otherParams[4];
			finetune(labels,finetuneLr,fineTuneEpochs);
		}
	}



	@Override
	public NeuralNetwork createLayer(DoubleMatrix input, int nVisible,
			int nHidden, DoubleMatrix W, DoubleMatrix hbias,
			DoubleMatrix vBias, RandomGenerator rng,int index) {
		DenoisingAutoEncoder ret = new DenoisingAutoEncoder.Builder()
		.withHBias(hbias).withInput(input).withWeights(W)
		.withRandom(rng).withMomentum(getMomentum()).withVisibleBias(vBias)
		.numberOfVisible(nVisible).numHidden(nHidden).withDistribution(getDist())
		.withSparsity(this.getSparsity()).renderWeights(getRenderWeightsEveryNEpochs()).fanIn(getFanIn())
		.build();

		return ret;
	}


	@Override
	public NeuralNetwork[] createNetworkLayers(int numLayers) {
		return new DenoisingAutoEncoder[numLayers];
	}


	public static class Builder extends BaseMultiLayerNetwork.Builder {
		public Builder() {
			this.clazz = StackedDenoisingAutoEncoder.class;
		}
	}





}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy