stats.OLSLinearRegression Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jstat Show documentation
Show all versions of jstat Show documentation
Java Library for Statistical Analysis.
The newest version!
package stats;
import org.apache.commons.math3.stat.regression.*;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.Table;
import utils.TableOperations;
/**
* Simple class to perform linear regression
*/
public class OLSLinearRegression {
/**
 * Creates a model with no coefficients yet; they are only assigned when
 * {@link #fit(Table, String[], String)} is called.
 */
public OLSLinearRegression() {
    // Intentionally empty: fit(...) is responsible for populating the coefficients.
}
/**
 * Returns the intercept (constant) term of the fitted model, which is the
 * first element of the estimated coefficient array.
 *
 * @return the estimated intercept
 * @throws IllegalStateException if no coefficients are available, e.g.
 *         because {@link #fit(Table, String[], String)} has not been called
 */
public final double getIntercept(){
    // Fail with a descriptive error instead of an NPE / index error when unfitted.
    if (this.coeffs == null || this.coeffs.length == 0) {
        throw new IllegalStateException("Model has not been fitted; call fit(...) first");
    }
    return this.coeffs[0];
}
/**
 * Returns the coefficients estimated by the last call to
 * {@link #fit(Table, String[], String)}.
 *
 * <p>Element 0 is the intercept (the value returned by {@link #getIntercept()});
 * the remaining elements are the slope terms, presumably in the order of the
 * {@code xCols} passed to {@code fit} (TODO confirm against
 * {@code TableOperations.getTableColumnsForRegressionMatrix}).
 *
 * @return a copy of the coefficient array, or {@code null} if the coefficients
 *         have not been set (presumably before {@code fit} is called; confirm
 *         against the field declaration). Modifying the returned array does
 *         not affect this model.
 */
public final double[] getCoeffs(){
    // Defensive copy: handing out the internal array would let callers
    // silently alter the fitted model.
    return this.coeffs == null ? null : this.coeffs.clone();
}
/**
 * Fits an ordinary-least-squares linear model to the given data set.
 *
 * <p>The response values are read from column {@code yCol}. The regressor
 * matrix is built from {@code xCols} by
 * {@code TableOperations.getTableColumnsForRegressionMatrix}. Both are passed
 * to Commons Math's {@code OLSMultipleLinearRegression}. The estimated
 * parameters are stored only after estimation completes, so an exception
 * during fitting leaves any previously fitted coefficients unchanged.
 *
 * @param dataSet the table holding both the response and regressor columns
 * @param xCols   names of the regressor (independent variable) columns
 * @param yCol    name of the response (dependent variable) column
 * @throws IllegalStateException if the response column {@code yCol} cannot be
 *         found in {@code dataSet}
 */
public void fit(Table dataSet, String[] xCols, String yCol){
    // Look the response column up once and reuse it below.
    DoubleColumn y = dataSet.doubleColumn(yCol);
    // NOTE(review): Tablesaw may already throw when the column is missing;
    // this guard is kept in case a null is ever returned.
    if (y == null) {
        throw new IllegalStateException("Column: "+yCol+" not in data set");
    }
    double[][] x = TableOperations.getTableColumnsForRegressionMatrix(dataSet, xCols, y.size());
    // the object that will do the fitting for us
    OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
    regression.newSampleData(y.asDoubleArray(), x);
    // Assigned last so a failure above never leaves the model half-updated.
    this.coeffs = regression.estimateRegressionParameters();
}
/**
* Predict for the given data point
* @return the prediction
*/
public double predict(double[] x){
double rslt = this.coeffs[0];
for(int i = 1; i