![JAR search and dependency download from the Maven repository](/logo.png)
linear.regression.MultipleLinearRegression Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of java-timeseries Show documentation
Show all versions of java-timeseries Show documentation
Time Series Analysis in Java
The newest version!
/*
* Copyright (c) 2017 Jacob Rachiele
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software
* and associated documentation files (the "Software"), to deal in the Software without restriction
* including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense
* and/or sell copies of the Software, and to permit persons to whom the Software is furnished to
* do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or
* substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED
* INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
* PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
* LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
* USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Contributors:
*
* Jacob Rachiele
*/
package linear.regression;
import com.google.common.collect.ImmutableList;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import org.ejml.alg.dense.mult.MatrixVectorMult;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.interfaces.decomposition.QRDecomposition;
import org.ejml.interfaces.linsol.LinearSolver;
import org.ejml.ops.CommonOps;
import stats.Statistics;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import static data.DoubleFunctions.*;
/**
* A linear regression model with support for both single and multiple prediction variables.
* This implementation is immutable and thread-safe.
*/
@EqualsAndHashCode @ToString
public final class MultipleLinearRegression implements LinearRegression {
private final List> predictors;
private final List response;
private final List beta;
private final List standardErrors;
private final List fitted;
private final List residuals;
private final double sigma2;
private final boolean hasIntercept;
private MultipleLinearRegression(Builder builder) {
this.predictors = builder.listBuilder.build();
this.response = builder.response;
this.hasIntercept = builder.hasIntercept;
MatrixFormulation matrixFormulation = new MatrixFormulation();
this.beta = matrixFormulation.getBetaEstimates();
this.fitted = matrixFormulation.getFittedvalues();
this.residuals = matrixFormulation.getResiduals();
this.sigma2 = matrixFormulation.getSigma2();
this.standardErrors = matrixFormulation.getBetaStandardErrors(beta.size());
}
@Override
public List> predictors() {
return this.predictors;
}
@Override
public List beta() {
return ImmutableList.copyOf(this.beta);
}
@Override
public List standardErrors() {
return ImmutableList.copyOf(this.standardErrors);
}
@Override
public List response() {
return this.response;
}
@Override
public List fitted() {
return ImmutableList.copyOf(this.fitted);
}
@Override
public List residuals() {
return ImmutableList.copyOf(this.residuals);
}
@Override
public double sigma2() {
return this.sigma2;
}
@Override
public boolean hasIntercept() {
return this.hasIntercept;
}
/**
* Create a new regression from this one, using the given boolean to determine whether to fit an intercept.
*
* @param hasIntercept whether or not the new regression should have an intercept.
* @return a new regression using the given boolean to determine whether to fit an intercept.
*/
public MultipleLinearRegression withHasIntercept(boolean hasIntercept) {
return new Builder().from(this).hasIntercept(hasIntercept).build();
}
/**
* Create a new regression from this one, replacing the current response with the provided one.
*
* @param response the response variable of the new regression.
* @return a new regression with the given response variable in place of the current one.
*/
public MultipleLinearRegression withResponse(List response) {
return new Builder().from(this).response(response).build();
}
/**
* Create a new regression from this one, adding the given predictor to the current ones.
*
* @param predictor The prediction variable to add to this regression.
* @return a new regression with the given predictor added to the current ones.
*/
public MultipleLinearRegression withPredictor(List predictor) {
return new Builder().from(this).predictor(predictor).build();
}
/**
* Create a new regression from this one, with the given predictors fully replacing the current ones.
*
* @param predictors The new list of prediction variables to use for the regression.
* @return a new regression using the given predictors in place of the current ones.
*/
public MultipleLinearRegression withPredictors(List> predictors) {
return new Builder().from(this).predictors(predictors).build();
}
/**
* Create and return a new builder for this class.
*
* @return a new builder for this class.
*/
public static Builder builder() {
return new Builder();
}
/**
* A builder for a multiple linear regression model.
*/
public static final class Builder {
private ImmutableList.Builder> listBuilder;
private List response;
private boolean hasIntercept = true;
/**
* Copy the attributes of the given regression object to this builder and return this builder.
*
* @param regression the object to copy the attributes from.
* @return this builder.
*/
public Builder from(LinearRegression regression) {
this.listBuilder = ImmutableList.builder();
for (List predictor : regression.predictors()) {
this.listBuilder.add(ImmutableList.copyOf(predictor));
}
this.response = ImmutableList.copyOf(regression.response());
this.hasIntercept = regression.hasIntercept();
return this;
}
Builder predictors(List> predictors) {
this.listBuilder = ImmutableList.builder();
for (List predictor : predictors) {
this.listBuilder.add(ImmutableList.copyOf(predictor));
}
return this;
}
public Builder predictor(List predictor) {
if (this.listBuilder == null) {
this.listBuilder = ImmutableList.builder();
}
this.listBuilder.add(ImmutableList.copyOf(predictor));
return this;
}
public Builder response(List response) {
this.response = ImmutableList.copyOf(response);
return this;
}
public Builder hasIntercept(boolean hasIntercept) {
this.hasIntercept = hasIntercept;
return this;
}
public MultipleLinearRegression build() {
return new MultipleLinearRegression(this);
}
}
private class MatrixFormulation {
private final DenseMatrix64F A; // The design matrix.
private final DenseMatrix64F At; // The transpose of A.
private final DenseMatrix64F AtAInv; // The inverse of At times A.
private final DenseMatrix64F b; // The parameter estimate vector.
private final DenseMatrix64F y; // The response vector.
private final D1Matrix64F fitted;
private final List residuals;
private final double sigma2;
private final DenseMatrix64F covarianceMatrix;
private MatrixFormulation() {
int numRows = response.size();
int numCols = predictors.size() + ((hasIntercept()) ? 1 : 0);
this.A = createMatrixA(numRows, numCols);
this.At = new DenseMatrix64F(numCols, numRows);
CommonOps.transpose(A, At);
this.AtAInv = new DenseMatrix64F(numCols, numCols);
this.b = new DenseMatrix64F(numCols, 1);
this.y = new DenseMatrix64F(numRows, 1);
solveSystem(numRows, numCols);
this.fitted = computeFittedValues();
this.residuals = computeResiduals();
this.sigma2 = estimateSigma2(numCols);
this.covarianceMatrix = new DenseMatrix64F(numCols, numCols);
CommonOps.scale(sigma2, AtAInv, covarianceMatrix);
}
private void solveSystem(int numRows, int numCols) {
LinearSolver qrSolver = LinearSolverFactory.qr(numRows, numCols);
QRDecomposition decomposition = qrSolver.getDecomposition();
qrSolver.setA(A);
y.setData(arrayFrom(response));
qrSolver.solve(this.y, this.b);
DenseMatrix64F R = decomposition.getR(null, true);
LinearSolver linearSolver = LinearSolverFactory.linear(numCols);
linearSolver.setA(R);
DenseMatrix64F Rinverse = new DenseMatrix64F(numCols, numCols);
linearSolver.invert(Rinverse); // stores solver's solution inside of Rinverse.
CommonOps.multOuter(Rinverse, this.AtAInv);
}
private DenseMatrix64F createMatrixA(int numRows, int numCols) {
double[] data = hasIntercept ? fill(numRows, 1.0) : arrayFrom();
for (List predictor : predictors) {
data = combine(data, arrayFrom(predictor));
}
boolean isRowMajor = false;
return new DenseMatrix64F(numRows, numCols, isRowMajor, data);
}
private D1Matrix64F computeFittedValues() {
D1Matrix64F fitted = new DenseMatrix64F(response.size(), 1);
MatrixVectorMult.mult(A, b, fitted);
return fitted;
}
private List computeResiduals() {
List fitted = getFittedvalues();
List residuals = new ArrayList<>(fitted.size());
for (int i = 0; i < fitted.size(); i++) {
residuals.add(response.get(i) - fitted.get(i));
}
return residuals;
}
private double estimateSigma2(int df) {
double ssq = Statistics.sumOfSquared(arrayFrom(this.residuals));
return ssq / (this.residuals.size() - df);
}
private List getFittedvalues() {
return listFrom(fitted.getData());
}
private List getResiduals() {
return residuals;
}
private List getBetaEstimates() {
return listFrom(b.getData());
}
private List getBetaStandardErrors(int numCols) {
DenseMatrix64F diag = new DenseMatrix64F(numCols, 1);
CommonOps.extractDiag(this.covarianceMatrix, diag);
return listFrom(sqrt(diag.getData()));
}
private double getSigma2() {
return this.sigma2;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy