All downloads are free. The search and download features use the official Maven repository.

org.deeplearning4j.rbm.RBMOptimizer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.rbm;

import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.NeuralNetworkGradient;
import org.deeplearning4j.optimize.NeuralNetworkOptimizer;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.jblas.DoubleMatrix;


/**
 * Optimizer that fills a flat gradient buffer from the wrapped network's
 * gradient, gradually raising {@code k} as more gradient evaluations happen.
 *
 * NOTE(review): {@code k} is presumably the number of sampling steps used by
 * the RBM's gradient computation; confirm against the network implementation.
 */
public class RBMOptimizer extends NeuralNetworkOptimizer {

	private static final long serialVersionUID = 3676032651650426749L;

	/** Current k handed to the network; a non-positive value means "not yet seeded". */
	protected int k = -1;

	/** Number of times {@link #getValueGradient(double[])} has been invoked. */
	protected int numTimesIterated = 0;

	/**
	 * Forwards all arguments unchanged to the superclass constructor.
	 *
	 * @param network        the network whose gradient is queried
	 * @param lr             the learning rate passed along with k when asking for a gradient
	 * @param trainingParams extra training parameters; presumably surfaced by the
	 *                       superclass as {@code extraParams} (first entry read as k) - confirm
	 */
	public RBMOptimizer(BaseNeuralNetwork network, double lr,
			Object[] trainingParams) {
		super(network, lr, trainingParams);
	}

	/**
	 * Requests a gradient from the network and writes it into {@code buffer} as
	 * weight gradient entries, then visible-bias entries, then hidden-bias entries.
	 *
	 * k is adaptive: it starts from {@code extraParams[0]}, grows by one on every
	 * tenth call, and is never allowed above 15. Typically, over time, you want
	 * to increase k.
	 *
	 * @param buffer destination array; must have room for every entry of the
	 *               three gradients combined
	 */
	@Override
	public synchronized void getValueGradient(double[] buffer) {
		// Read the configured starting value before touching any state.
		final int configuredK = (int) extraParams[0];
		numTimesIterated++;

		// Seed k from the configured value while it is still non-positive.
		if (this.k <= 0) {
			this.k = configuredK;
		}

		// Every tenth evaluation bumps k by one.
		if (numTimesIterated % 10 == 0) {
			this.k++;
		}

		// Don't go over 15.
		this.k = Math.min(this.k, 15);

		NeuralNetworkGradient gradient = network.getGradient(new Object[]{this.k, lr});

		// Fetch all three parts up front, in the order they are laid out in the buffer.
		DoubleMatrix[] parts = {
				gradient.getwGradient(),
				gradient.getvBiasGradient(),
				gradient.gethBiasGradient()
		};

		// Flatten each part into the buffer using the matrix's linear indexing.
		int offset = 0;
		for (DoubleMatrix part : parts) {
			for (int i = 0; i < part.length; i++) {
				buffer[offset] = part.get(i);
				offset++;
			}
		}
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy