All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.nn.HiddenLayer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.nn;

import java.io.Serializable;

import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.nn.activation.ActivationFunction;
import org.deeplearning4j.nn.activation.Sigmoid;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;


/**
 * Vectorized Hidden Layer
 * @author Adam Gibson
 *
 */
public class HiddenLayer implements Serializable {

	private static final long serialVersionUID = 915783367350830495L;
	private int nIn;
	private int nOut;
	private DoubleMatrix W;
	private DoubleMatrix b;
	private RandomGenerator rng;
	private DoubleMatrix input;
	private ActivationFunction activationFunction = new Sigmoid();


	private HiddenLayer() {}

	public HiddenLayer(int nIn, int nOut, DoubleMatrix W, DoubleMatrix b, RandomGenerator rng,DoubleMatrix input,ActivationFunction activationFunction) {
		this.nIn = nIn;
		this.nOut = nOut;
		this.input = input;
		this.activationFunction = activationFunction;

		if(rng == null) {
			this.rng = new MersenneTwister(1234);
		}
		else 
			this.rng = rng;

		if(W == null) {

			NormalDistribution u = new NormalDistribution(this.rng,0,.01,NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);

			this.W = DoubleMatrix.zeros(nIn,nOut);

			for(int i = 0; i < this.W.rows; i++) 
				this.W.putRow(i,new DoubleMatrix(u.sample(this.W.columns)));
		}

		else 
			this.W = W;


		if(b == null) 
			this.b = DoubleMatrix.zeros(nOut);
		else 
			this.b = b;
	}


	public HiddenLayer(int nIn, int nOut, DoubleMatrix W, DoubleMatrix b, RandomGenerator rng,DoubleMatrix input) {
		this.nIn = nIn;
		this.nOut = nOut;
		this.input = input;

		if(rng == null) {
			this.rng = new MersenneTwister(1234);
		}
		else 
			this.rng = rng;

		if(W == null) {
			NormalDistribution u = new NormalDistribution(this.rng,0,.01,NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);

			this.W = DoubleMatrix.zeros(nIn,nOut);

			for(int i = 0; i < this.W.rows; i++) 
				this.W.putRow(i,new DoubleMatrix(u.sample(this.W.columns)));
		}

		else 
			this.W = W;


		if(b == null) 
			this.b = DoubleMatrix.zeros(nOut);
		else 
			this.b = b;
	}

	public synchronized int getnIn() {
		return nIn;
	}

	public synchronized void setnIn(int nIn) {
		this.nIn = nIn;
	}

	public synchronized int getnOut() {
		return nOut;
	}

	public synchronized void setnOut(int nOut) {
		this.nOut = nOut;
	}

	public synchronized DoubleMatrix getW() {
		return W;
	}

	public synchronized void setW(DoubleMatrix w) {
		W = w;
	}

	public synchronized DoubleMatrix getB() {
		return b;
	}

	public synchronized void setB(DoubleMatrix b) {
		this.b = b;
	}

	public synchronized RandomGenerator getRng() {
		return rng;
	}

	public synchronized void setRng(RandomGenerator rng) {
		this.rng = rng;
	}

	public synchronized DoubleMatrix getInput() {
		return input;
	}

	public synchronized void setInput(DoubleMatrix input) {
		this.input = input;
	}

	public synchronized ActivationFunction getActivationFunction() {
		return activationFunction;
	}

	public synchronized void setActivationFunction(
			ActivationFunction activationFunction) {
		this.activationFunction = activationFunction;
	}

	@Override
	public HiddenLayer clone() {
		HiddenLayer layer = new HiddenLayer();
		layer.b = b.dup();
		layer.W = W.dup();
		layer.input = input.dup();
		layer.activationFunction = activationFunction;
		layer.nOut = nOut;
		layer.nIn = nIn;
		layer.rng = rng;
		return layer;
	}


	public HiddenLayer transpose() {
		HiddenLayer layer = new HiddenLayer();
		layer.b = b.dup();
		layer.W = W.transpose();
		layer.input = input.transpose();
		layer.activationFunction = activationFunction;
		layer.nOut = nIn;
		layer.nIn = nOut;
		layer.rng = rng;
		return layer;
	}

	
	
	/**
	 * Trigger an activation with the last specified input
	 * @return the activation of the last specified input
	 */
	public synchronized DoubleMatrix activate() {
		return getActivationFunction().apply(getInput().mmul(getW()).addRowVector(getB()));
	}

	/**
	 * Initialize the layer with the given input
	 * and return the activation for this layer
	 * given this input
	 * @param input the input to use
	 * @return
	 */
	public synchronized DoubleMatrix activate(DoubleMatrix input) {
		if(input != null)
			this.input = input;
		return activate();
	}

	/**
	 * Sample this hidden layer given the input
	 * and initialize this layer with the given input
	 * @param input the input to sample
	 * @return the activation for this layer
	 * given the input
	 */
	public DoubleMatrix sampleHGivenV(DoubleMatrix input) {
		this.input = input;
		DoubleMatrix ret = MatrixUtil.binomial(activate(), 1, rng);
		return ret;
	}

	/**
	 * Sample this hidden layer given the last input.
	 * @return the activation for this layer given 
	 * the previous input
	 */
	public DoubleMatrix sample_h_given_v() {
		DoubleMatrix output = activate();
		//reset the seed to ensure consistent generation of data
		DoubleMatrix ret = MatrixUtil.binomial(output, 1, rng);
		return ret;
	}



	public static class Builder {
		private int nIn;
		private int nOut;
		private DoubleMatrix W;
		private DoubleMatrix b;
		private RandomGenerator rng;
		private DoubleMatrix input;
		private ActivationFunction activationFunction = new Sigmoid();


		public Builder nIn(int nIn) {
			this.nIn = nIn;
			return this;
		}

		public Builder nOut(int nOut) {
			this.nOut = nOut;
			return this;
		}

		public Builder withWeights(DoubleMatrix W) {
			this.W = W;
			return this;
		}

		public Builder withRng(RandomGenerator gen) {
			this.rng = gen;
			return this;
		}

		public Builder withActivation(ActivationFunction function) {
			this.activationFunction = function;
			return this;
		}

		public Builder withBias(DoubleMatrix b) {
			this.b = b;
			return this;
		}

		public Builder withInput(DoubleMatrix input) {
			this.input = input;
			return this;
		}

		public HiddenLayer build() {
			HiddenLayer ret =  new HiddenLayer(nIn,nOut,W,b,rng,input); 
			ret.activationFunction = activationFunction;
			return ret;
		}

	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy