maths.errorfunctions.LogisticMSEVectorFunction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jstat Show documentation
Java Library for Statistical Analysis.
The newest version!
package maths.errorfunctions;
import base.CommonConstants;
import datasets.VectorDouble;
import datastructs.I2DDataSet;
import datastructs.IVector;
import maths.functions.IVectorRealFunction;
/**
 * Error function intended for logistic-regression style hypotheses.
 *
 * <p>NOTE(review): despite "MSE" in the class name, the visible per-row term in
 * {@code evaluate} is {@code y*log(h) + (1-y)*log(1-h)}. The accumulated sum is
 * then negated and divided by the number of rows, which is the binary cross-entropy
 * (log-loss) form rather than a mean squared error. Confirm the intended name/semantics.
 *
 * <p>NOTE(review): this listing appears to have lost every span of text between
 * '<' and '>'. That includes generic type arguments, loop bounds, statements and comparison
 * operators. For example, {@code DataSetType} is used as a type below, but no type-parameter
 * declaration for it is visible in this class header. As shown, the class does not
 * compile; restore the truncated lines from the original source before relying on it.
 */
public class LogisticMSEVectorFunction implements IVectorErrorRealFunction {
/**
 * Constructor.
 *
 * @param hypothesis the hypothesis function whose output this error function scores;
 *                   stored in the {@code hypothesis} field
 * @throws IllegalArgumentException if {@code hypothesis} is null
 */
// NOTE(review): the parameter type's generic argument appears truncated
// (only the trailing '>' survives).
public LogisticMSEVectorFunction(IVectorRealFunction> hypothesis ){
if(hypothesis == null){
throw new IllegalArgumentException("Hypothesis function cannot be null");
}
this.hypothesis = hypothesis;
}
/**
 * Evaluate the error function using the given data and labels.
 *
 * <p>Accumulates a per-row contribution into {@code result} and returns
 * {@code -result / data.m()}. In the "normal" branch the contribution is
 * {@code y*log(h) + (1-y)*log(1-h)}, where {@code h} is {@code hypothesisValue}.
 * {@code y} and {@code hypothesisValue} are presumably the row's label and hypothesis
 * output; their declarations are lost in this listing.
 *
 * <p>NOTE(review): if {@code data.m()} is 0 the loop does not add anything, and the
 * final division is {@code -0.0 / 0}, which evaluates to NaN rather than throwing.
 *
 * @param data   the data set; its row count {@code data.m()} must equal {@code labels.size()}
 * @param labels the label vector, one entry per data row
 * @return the negated mean of the accumulated per-row terms
 * @throws IllegalArgumentException if {@code data.m()} differs from {@code labels.size()}
 */
@Override
public double evaluate(DataSetType data, VectorDouble labels){
if(data.m() != labels.size()){
throw new IllegalArgumentException("Invalid number of data points and labels vector size");
}
double result = 0.0;
// NOTE(review): the line below is truncated. The loop bound, the per-row setup and the
// start of the branch condition (compared against CommonConstants.getTol(); the operator
// itself is lost) are missing. In that special branch 1.0 is added to result.
for(int rowIdx=0; rowIdx CommonConstants.getTol()){
result += 1.0;
}
}
else{
//do it normally
//calculate the logarithms and check if they are
//infinite or NaN
// NOTE(review): no infinity/NaN check is visible after these calls.
// Math.log(0.0) is -Infinity, so hypothesisValue exactly 0 or 1 would make result infinite.
double log_one_minus_h = Math.log(1. - hypothesisValue);
double log_h = Math.log(hypothesisValue);
result += y*log_h +(1.-y)*log_one_minus_h;
}
}
return -result/data.m();
}
/**
 * Returns the gradients on the given data.
 *
 * <p>Allocates a {@code VectorDouble} of length {@code hypothesis.numCoeffs()} filled with
 * 0.0. For each row, it obtains {@code hypothesis.coeffGradients(row)}.
 *
 * <p>NOTE(review): the rest of this method (row loop bound, inner coefficient loop body,
 * accumulation and return statement) is truncated in this listing, so how the returned
 * gradients are combined cannot be determined here.
 *
 * @param data   the data set
 * @param labels the label vector
 * @return the gradient vector, presumably of length {@code hypothesis.numCoeffs()} (TODO confirm)
 */
@Override
public VectorDouble gradients(DataSetType data, VectorDouble labels){
VectorDouble gradients = new VectorDouble(this.hypothesis.numCoeffs(), 0.0);
// NOTE(review): truncated line. The loop bound, the row fetch and the generic type of
// hypothesisGrads are missing.
for(int rowIdx=0; rowIdx hypothesisGrads = this.hypothesis.coeffGradients(row);
// NOTE(review): this line appears to fuse the start of the inner coefficient loop with the
// tail of the 'hypothesis' field declaration (presumably a generic IVectorRealFunction field).
for(int coeff=0; coeff> hypothesis;
}