All Downloads are FREE. Search and download functionalities are using the official Maven repository.

data.regression.MultipleLinearRegressionModel Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2017 Jacob Rachiele
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this software
 * and associated documentation files (the "Software"), to deal in the Software without restriction
 * including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense
 * and/or sell copies of the Software, and to permit persons to whom the Software is furnished to
 * do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all copies or
 * substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED
 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
 * PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
 * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
 * USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * Contributors:
 *
 * Jacob Rachiele
 */
package data.regression;

import com.google.common.collect.ImmutableList;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import org.ejml.alg.dense.mult.MatrixVectorMult;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.interfaces.decomposition.QRDecomposition;
import org.ejml.interfaces.linsol.LinearSolver;
import org.ejml.ops.CommonOps;
import math.stats.Statistics;

import java.util.ArrayList;
import java.util.List;

import static data.DoubleFunctions.*;

/**
 * A linear regression model with support for both single and multiple prediction variables.
 * This implementation is immutable and thread-safe.
 */
@EqualsAndHashCode @ToString
public final class MultipleLinearRegressionModel implements LinearRegressionModel {

    private final List> predictors;
    private final List response;
    private final List beta;
    private final List standardErrors;
    private final List fitted;
    private final List residuals;
    private final double sigma2;
    private final boolean hasIntercept;

    private MultipleLinearRegressionModel(Builder builder) {
        this.predictors = builder.listBuilder.build();
        this.response = builder.response;
        this.hasIntercept = builder.hasIntercept;
        MatrixFormulation matrixFormulation = new MatrixFormulation();
        this.beta = matrixFormulation.getBetaEstimates();
        this.fitted = matrixFormulation.getFittedvalues();
        this.residuals = matrixFormulation.getResiduals();
        this.sigma2 = matrixFormulation.getSigma2();
        this.standardErrors = matrixFormulation.getBetaStandardErrors(beta.size());
    }

    @Override
    public List> predictors() {
        return this.predictors;
    }

    @Override
    public List beta() {
        return ImmutableList.copyOf(this.beta);
    }

    @Override
    public List standardErrors() {
        return ImmutableList.copyOf(this.standardErrors);
    }

    @Override
    public List response() {
        return this.response;
    }

    @Override
    public List fitted() {
        return ImmutableList.copyOf(this.fitted);
    }

    @Override
    public List residuals() {
        return ImmutableList.copyOf(this.residuals);
    }

    @Override
    public double sigma2() {
        return this.sigma2;
    }

    @Override
    public boolean hasIntercept() {
        return this.hasIntercept;
    }

    /**
     * Create a new linear regression model from this one, using the given boolean to determine whether
     * to fit an intercept or not.
     *
     * @param hasIntercept whether or not the new regression should have an intercept.
     * @return a new linear regression model using the given boolean to determine whether to fit an intercept.
     */
    public MultipleLinearRegressionModel withHasIntercept(boolean hasIntercept) {
        return new Builder().from(this).hasIntercept(hasIntercept).build();
    }

    /**
     * Create a new linear regression model from this one, replacing the current response with the provided one.
     *
     * @param response the response variable of the new regression.
     * @return a new linear regression model with the given response variable in place of the current one.
     */
    public MultipleLinearRegressionModel withResponse(List response) {
        return new Builder().from(this).response(response).build();
    }

    /**
     * Create a new regression from this one, adding the given predictor to the current ones.
     *
     * @param predictor The prediction variable to add to this regression.
     * @return a new regression with the given predictor added to the current ones.
     */
    public MultipleLinearRegressionModel withPredictor(List predictor) {
        return new Builder().from(this).predictor(predictor).build();
    }

    /**
     * Create a new linear regression model from this one, with the given predictors fully replacing the current ones.
     *
     * @param predictors The new list of prediction variables to use for the regression.
     * @return a new linear regression model using the given predictors in place of the current ones.
     */
    public MultipleLinearRegressionModel withPredictors(List> predictors) {
        return new Builder().from(this).predictors(predictors).build();
    }

    /**
     * Create and return a new builder for this class.
     *
     * @return a new builder for this class.
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * A builder for a multiple linear regression model.
     */
    public static final class Builder {

        private ImmutableList.Builder> listBuilder;
        private List response;
        private boolean hasIntercept = true;

        /**
         * Copy the attributes of the given regression object to this builder and return this builder.
         *
         * @param regression the object to copy the attributes from.
         * @return this builder.
         */
        public Builder from(LinearRegressionModel regression) {
            this.listBuilder = ImmutableList.builder();
            for (List predictor : regression.predictors()) {
                this.listBuilder.add(ImmutableList.copyOf(predictor));
            }
            this.response = ImmutableList.copyOf(regression.response());
            this.hasIntercept = regression.hasIntercept();
            return this;
        }

        Builder predictors(List> predictors) {
            this.listBuilder = ImmutableList.builder();
            for (List predictor : predictors) {
                this.listBuilder.add(ImmutableList.copyOf(predictor));
            }
            return this;
        }

        public Builder predictor(List predictor) {
            if (this.listBuilder == null) {
                this.listBuilder = ImmutableList.builder();
            }
            this.listBuilder.add(ImmutableList.copyOf(predictor));
            return this;
        }

        public Builder response(List response) {
            this.response = ImmutableList.copyOf(response);
            return this;
        }

        public Builder hasIntercept(boolean hasIntercept) {
            this.hasIntercept = hasIntercept;
            return this;
        }

        public MultipleLinearRegressionModel build() {
            return new MultipleLinearRegressionModel(this);
        }
    }

    private class MatrixFormulation {

        private final DenseMatrix64F A; // The design matrix.
        private final DenseMatrix64F At; // The transpose of A.
        private final DenseMatrix64F AtAInv; // The inverse of At times A.
        private final DenseMatrix64F b; // The parameter estimate vector.
        private final DenseMatrix64F y; // The response vector.
        private final D1Matrix64F fitted;
        private final List residuals;
        private final double sigma2;
        private final DenseMatrix64F covarianceMatrix;

        private MatrixFormulation() {
            int numRows = response.size();
            int numCols = predictors.size() + ((hasIntercept()) ? 1 : 0);
            this.A = createMatrixA(numRows, numCols);
            this.At = new DenseMatrix64F(numCols, numRows);
            CommonOps.transpose(A, At);
            this.AtAInv = new DenseMatrix64F(numCols, numCols);
            this.b = new DenseMatrix64F(numCols, 1);
            this.y = new DenseMatrix64F(numRows, 1);
            solveSystem(numRows, numCols);
            this.fitted = computeFittedValues();
            this.residuals = computeResiduals();
            this.sigma2 = estimateSigma2(numCols);
            this.covarianceMatrix = new DenseMatrix64F(numCols, numCols);
            CommonOps.scale(sigma2, AtAInv, covarianceMatrix);
        }

        private void solveSystem(int numRows, int numCols) {
            LinearSolver qrSolver = LinearSolverFactory.qr(numRows, numCols);
            QRDecomposition decomposition = qrSolver.getDecomposition();
            qrSolver.setA(A);
            y.setData(arrayFrom(response));
            qrSolver.solve(this.y, this.b);
            DenseMatrix64F R = decomposition.getR(null, true);
            LinearSolver linearSolver = LinearSolverFactory.linear(numCols);
            linearSolver.setA(R);
            DenseMatrix64F Rinverse = new DenseMatrix64F(numCols, numCols);
            linearSolver.invert(Rinverse); // stores solver's solution inside of Rinverse.
            CommonOps.multOuter(Rinverse, this.AtAInv);
        }

        private DenseMatrix64F createMatrixA(int numRows, int numCols) {
            double[] data = hasIntercept ? fill(numRows, 1.0) : arrayFrom();
            for (List predictor : predictors) {
                data = combine(data, arrayFrom(predictor));
            }
            boolean isRowMajor = false;
            return new DenseMatrix64F(numRows, numCols, isRowMajor, data);
        }

        private D1Matrix64F computeFittedValues() {
            D1Matrix64F fitted = new DenseMatrix64F(response.size(), 1);
            MatrixVectorMult.mult(A, b, fitted);
            return fitted;
        }

        private List computeResiduals() {
            List fitted = getFittedvalues();
            List residuals = new ArrayList<>(fitted.size());
            for (int i = 0; i < fitted.size(); i++) {
                residuals.add(response.get(i) - fitted.get(i));
            }
            return residuals;
        }

        private double estimateSigma2(int df) {
            double ssq = Statistics.sumOfSquared(arrayFrom(this.residuals));
            return ssq / (this.residuals.size() - df);
        }

        private List getFittedvalues() {
            return listFrom(fitted.getData());
        }

        private List getResiduals() {
            return residuals;
        }

        private List getBetaEstimates() {
            return listFrom(b.getData());
        }

        private List getBetaStandardErrors(int numCols) {
            DenseMatrix64F diag = new DenseMatrix64F(numCols, 1);
            CommonOps.extractDiag(this.covarianceMatrix, diag);
            return listFrom(sqrt(diag.getData()));
        }

        private double getSigma2() {
            return this.sigma2;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy