All downloads are free. The search and download functionality uses the official Maven repository.

org.deeplearning4j.optimize.LogisticRegressionOptimizer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.optimize;

import java.io.Serializable;

import org.deeplearning4j.nn.LogisticRegression;
import org.deeplearning4j.nn.LogisticRegressionGradient;


import cc.mallet.optimize.Optimizable;

/**
 * Adapter that exposes the parameters of a {@link LogisticRegression} model
 * through Mallet's {@link Optimizable.ByGradientValue} interface.
 * <p>
 * The parameters are presented as one flat vector: the entries of
 * {@code logReg.getW()} come first (addressed by their single-index position),
 * followed by the entries of {@code logReg.getB()}.
 */
public class LogisticRegressionOptimizer implements Optimizable.ByGradientValue,Serializable {

	/** Explicit serialization version identifier for this class. */
	private static final long serialVersionUID = 5229426347154854746L;

	/** Model whose weights and biases are read and written by this adapter. */
	private LogisticRegression logReg;

	/**
	 * Value handed to {@code logReg.getGradient(...)} whenever a gradient is
	 * requested (presumably the learning rate - confirm against LogisticRegression).
	 */
	private double lr;

	/**
	 * Creates an adapter around the given model.
	 *
	 * @param logReg the model to expose to the optimizer
	 * @param lr     the value passed to the model when its gradient is computed
	 */
	public LogisticRegressionOptimizer(LogisticRegression logReg, double lr) {
		this.logReg = logReg;
		this.lr = lr;
	}

	/**
	 * Returns the size of the flat parameter vector.
	 *
	 * @return number of weight entries plus number of bias entries
	 */
	@Override
	public int getNumParameters() {
		final int weightCount = logReg.getW().length;
		final int biasCount = logReg.getB().length;
		return weightCount + biasCount;
	}

	/**
	 * Copies the current parameters into {@code buffer}, one slot at a time.
	 * Every slot of the buffer is filled via {@link #getParameter(int)}.
	 *
	 * @param buffer destination array, overwritten in place
	 */
	@Override
	public void getParameters(double[] buffer) {
		int slot = 0;
		while (slot < buffer.length) {
			buffer[slot] = getParameter(slot);
			slot++;
		}
	}

	/**
	 * Reads a single parameter from the flat vector.
	 *
	 * @param index position in the flat vector; positions below the weight
	 *              count address weights, all others address biases
	 * @return the weight or bias stored at that position
	 */
	@Override
	public double getParameter(int index) {
		final int weightCount = logReg.getW().length;
		if (index < weightCount) {
			return logReg.getW().get(index);
		}
		// Past the weights: shift the index so it is relative to the bias block.
		return logReg.getB().get(index - weightCount);
	}

	/**
	 * Writes every entry of {@code params} into the model via
	 * {@link #setParameter(int, double)}.
	 *
	 * @param params new parameter values in flat-vector order
	 */
	@Override
	public void setParameters(double[] params) {
		for (int position = 0; position < params.length; position++) {
			setParameter(position, params[position]);
		}
	}

	/**
	 * Writes a single parameter of the flat vector into the model.
	 *
	 * @param index position in the flat vector; positions below the weight
	 *              count address weights, all others address biases
	 * @param value the value to store
	 */
	@Override
	public void setParameter(int index, double value) {
		final int weightCount = logReg.getW().length;
		if (index < weightCount) {
			logReg.getW().put(index, value);
			return;
		}
		// Past the weights: shift the index so it is relative to the bias block.
		logReg.getB().put(index - weightCount, value);
	}

	/**
	 * Asks the model for its gradient (passing {@code lr}) and copies it into
	 * {@code buffer} using the same weights-then-biases layout as the
	 * parameter vector.
	 *
	 * @param buffer destination array, overwritten in place
	 */
	@Override
	public void getValueGradient(double[] buffer) {
		final LogisticRegressionGradient gradient = logReg.getGradient(lr);
		for (int slot = 0; slot < buffer.length; slot++) {
			final int weightCount = logReg.getW().length;
			if (slot >= weightCount) {
				buffer[slot] = gradient.getbGradient().get(slot - weightCount);
			} else {
				buffer[slot] = gradient.getwGradient().get(slot);
			}
		}
	}

	/**
	 * Returns the objective value reported to the optimizer.
	 *
	 * @return the negation of the model's {@code negativeLogLikelihood()}
	 */
	@Override
	public double getValue() {
		return -logReg.negativeLogLikelihood();
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy