All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.rbm.CRBM Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.rbm;


import static org.deeplearning4j.util.MatrixUtil.log;
import static org.deeplearning4j.util.MatrixUtil.oneDiv;
import static org.deeplearning4j.util.MatrixUtil.oneMinus;
import static org.deeplearning4j.util.MatrixUtil.uniform;
import static org.jblas.MatrixFunctions.exp;

import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.jblas.DoubleMatrix;


/**
 * Continuous Restricted Boltzmann Machine
 * @author Adam Gibson
 *
 */
public class CRBM extends RBM {

	/**
	 * 
	 */
	private static final long serialVersionUID = 598767790003731193L;


	


	public CRBM() {
		super();
	}

	public CRBM(DoubleMatrix input, int n_visible, int n_hidden,
			DoubleMatrix W, DoubleMatrix hbias, DoubleMatrix vbias,
			RandomGenerator rng, double fanIn, RealDistribution dist) {
		super(input, n_visible, n_hidden, W, hbias, vbias, rng, fanIn, dist);
	}

	

	@Override
	public DoubleMatrix propDown(DoubleMatrix h) {
		return h.mmul(W.transpose()).addRowVector(vBias);
	}

	@Override
	public Pair sampleVGivenH(DoubleMatrix h) {
		DoubleMatrix aH = propDown(h);
		DoubleMatrix en = exp(aH.neg());
		DoubleMatrix ep = exp(aH);


		DoubleMatrix v1Mean = oneDiv(oneMinus(en).sub(oneDiv(aH)));
		DoubleMatrix v1Sample = log(
				oneMinus(
				uniform(rng,v1Mean.rows,v1Mean.columns)
				.mul(oneMinus(ep)))
				).div(aH);


		return new Pair(v1Mean,v1Sample);



	}


	public static class Builder extends BaseNeuralNetwork.Builder {
		public Builder() {
			this.clazz = CRBM.class;
		}
	}



}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy