All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.rbm.RBM Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.rbm;


import static org.deeplearning4j.util.MatrixUtil.binomial;
import static org.deeplearning4j.util.MatrixUtil.mean;
import static org.deeplearning4j.util.MatrixUtil.sigmoid;

import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.NeuralNetworkOptimizer;
import org.jblas.DoubleMatrix;



/**
 * Restricted Boltzmann Machine.
 * 
 * Markov chain with gibbs sampling.
 * 
 * 
 * Based on Hinton et al.'s work 
 * 
 * Great reference:
 * http://www.iro.umontreal.ca/~lisa/publications2/index.php/publications/show/239
 * 
 * 
 * @author Adam Gibson
 *
 */
@SuppressWarnings("unused")
public class RBM extends BaseNeuralNetwork {

	/**
	 * 
	 */
	private static final long serialVersionUID = 6189188205731511957L;
	protected NeuralNetworkOptimizer optimizer;
	public RBM() {}




	public RBM(DoubleMatrix input, int n_visible, int n_hidden, DoubleMatrix W,
			DoubleMatrix hbias, DoubleMatrix vbias, RandomGenerator rng,double fanIn,RealDistribution dist) {
		super(input, n_visible, n_hidden, W, hbias, vbias, rng,fanIn,dist);
	}

	/**
	 * Trains till global minimum is found.
	 * @param learningRate
	 * @param k
	 * @param input
	 */
	public void trainTillConvergence(double learningRate,int k,DoubleMatrix input) {
		if(input != null)
			this.input = input;
		optimizer = new RBMOptimizer(this, learningRate, new Object[]{k,learningRate});
		optimizer.train(input);
	}

	/**
	 * Contrastive divergence revolves around the idea 
	 * of approximating the log likelihood around x1(input) with repeated sampling.
	 * Given is an energy based model: the higher k is (the more we sample the model)
	 * the more we lower the energy (increase the likelihood of the model)
	 * 
	 * and lower the likelihood (increase the energy) of the hidden samples.
	 * 
	 * Other insights:
	 *    CD - k involves keeping the first k samples of a gibbs sampling of the model.
	 *    
	 * @param learningRate the learning rate to scale by
	 * @param k the number of iterations to do
	 * @param input the input to sample from
	 */
	public void contrastiveDivergence(double learningRate,int k,DoubleMatrix input) {
		if(input != null)
			this.input = input;
		NeuralNetworkGradient gradient = getGradient(new Object[]{k,learningRate});
		W.addi(gradient.getwGradient());
		hBias.addi(gradient.gethBiasGradient());
		vBias.addi(gradient.getvBiasGradient());

	}


	@Override
	public NeuralNetworkGradient getGradient(Object[] params) {
		int k = (int) params[0];
		double learningRate = (double) params[1];
		/*
		 * Cost and updates dictionary.
		 * This is the update rules for weights and biases
		 */
		Pair probHidden = sampleHiddenGivenVisible(input);

		/*
		 * Start the gibbs sampling.
		 */
		DoubleMatrix chainStart = probHidden.getSecond();

		/*
		 * Note that at a later date, we can explore alternative methods of 
		 * storing the chain transitions for different kinds of sampling
		 * and exploring the search space.
		 */
		Pair,Pair> matrices = null;
		//negative visible means or expected values
		DoubleMatrix nvMeans = null;
		//negative value samples
		DoubleMatrix nvSamples = null;
		//negative hidden means or expected values
		DoubleMatrix nhMeans = null;
		//negative hidden samples
		DoubleMatrix nhSamples = null;

		/*
		 * K steps of gibbs sampling. THis is the positive phase of contrastive divergence.
		 * 
		 * There are 4 matrices being computed for each gibbs sampling.
		 * The samples from both the positive and negative phases and their expected values or averages.
		 * 
		 */

		for(int i = 0; i < k; i++) {


			if(i == 0) 
				matrices = gibbhVh(chainStart);
			else
				matrices = gibbhVh(nhSamples);
			//get the cost updates for sampling in the chain after k iterations
			nvMeans = matrices.getFirst().getFirst();
			nvSamples = matrices.getFirst().getSecond();
			nhMeans = matrices.getSecond().getFirst();
			nhSamples = matrices.getSecond().getSecond();
		}

		/*
		 * Update gradient parameters
		 */
		DoubleMatrix wGradient = input.transpose().mmul(probHidden.getSecond()).sub(nvSamples.transpose().mmul(nhMeans)).mul(learningRate);

		//weight decay via l2 regularization
		if(useRegularization) 
			wGradient.subi(W.muli(l2));
		if(momentum != 0)
			wGradient.muli( 1 - momentum);

		wGradient.divi(input.rows);

		DoubleMatrix hBiasGradient = null;

		if(this.sparsity != 0) {
			//all hidden units must stay around this number
			hBiasGradient = mean(probHidden.getSecond().add( -sparsity),0).mul(learningRate);
		}
		else {
			//update rule: the expected values of the hidden input - the negative hidden  means adjusted by the learning rate
			hBiasGradient = mean(probHidden.getSecond().sub(nhMeans), 0).mul(learningRate);
		}

		//update rule: the expected values of the input - the negative samples adjusted by the learning rate
		DoubleMatrix  vBiasGradient = mean(input.sub(nvSamples), 0).mul(learningRate);


		return new NeuralNetworkGradient(wGradient, vBiasGradient, hBiasGradient);
	}



	/**
	 * Binomial sampling of the hidden values given visible
	 * @param v the visible values
	 * @return a binomial distribution containing the expected values and the samples
	 */
	public Pair sampleHiddenGivenVisible(DoubleMatrix v) {
		DoubleMatrix h1Mean = propUp(v);
		DoubleMatrix h1Sample = binomial(h1Mean, 1, rng);
		return new Pair(h1Mean,h1Sample);

	}

	/**
	 * Gibbs sampling step: hidden ---> visible ---> hidden
	 * @param h the hidden input
	 * @return the expected values and samples of both the visible samples given the hidden
	 * and the new hidden input and expected values
	 */
	public Pair,Pair> gibbhVh(DoubleMatrix h) {
		Pair v1MeanAndSample = sampleVGivenH(h);
		DoubleMatrix vSample = v1MeanAndSample.getSecond();
		Pair h1MeanAndSample = sampleHiddenGivenVisible(vSample);
		return new Pair<>(v1MeanAndSample,h1MeanAndSample);
	}


	/**
	 * Guess the visible values given the hidden
	 * @param h
	 * @return
	 */
	public Pair sampleVGivenH(DoubleMatrix h) {
		DoubleMatrix v1Mean = propDown(h);
		DoubleMatrix v1Sample = binomial(v1Mean, 1, rng);
		return new Pair<>(v1Mean,v1Sample);
	}


	public DoubleMatrix propUp(DoubleMatrix v) {
		DoubleMatrix preSig = v.mmul(W).addiRowVector(hBias);
		return sigmoid(preSig);

	}

	/**
	 * Propagates hidden down to visible
	 * @param h the hidden layer
	 * @return the approximated output of the hidden layer
	 */
	public DoubleMatrix propDown(DoubleMatrix h) {
		DoubleMatrix preSig = h.mmul(W.transpose()).addRowVector(vBias);
		return sigmoid(preSig);

	}

	/**
	 * Reconstructs the visible input.
	 * A reconstruction is a propdown of the reconstructed hidden input.
	 * @param v the visible input
	 * @return the reconstruction of the visible input
	 */
	@Override
	public DoubleMatrix reconstruct(DoubleMatrix v) {
		//reconstructed: propUp ----> hidden propDown to reconstruct
		return propDown(propUp(v));
	}

	public static class Builder extends BaseNeuralNetwork.Builder {
		public Builder() {
			clazz =  RBM.class;
		}

	}

	/**
	 * Note: k is the first input in params.
	 */
	@Override
	public void trainTillConvergence(DoubleMatrix input, double lr,
			Object[] params) {
		if(input != null)
			this.input = input;
		optimizer = new RBMOptimizer(this, lr, params);
		optimizer.train(input);
	}

	@Override
	public double lossFunction(Object[] params) {
		return getReConstructionCrossEntropy();
	}

	@Override
	public void train(DoubleMatrix input,double lr, Object[] params) {
		int k = (int) params[0];
		contrastiveDivergence(lr, k, input);
	}



}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy