All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression Maven / Gradle / Ivy

There is a newer version: 23.0.6
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.regression;

import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.descriptive.moment.SecondMoment;

/**
 * 

Implements ordinary least squares (OLS) to estimate the parameters of a * multiple linear regression model.

* *

The regression coefficients, b, satisfy the normal equations: *

 XT X b = XT y 

* *

To solve the normal equations, this implementation uses QR decomposition * of the X matrix. (See {@link QRDecomposition} for details on the * decomposition algorithm.) The X matrix, also known as the design matrix, * has rows corresponding to sample observations and columns corresponding to independent * variables. When the model is estimated using an intercept term (i.e. when * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the X * matrix includes an initial column identically equal to 1. We solve the normal equations * as follows: *

 XTX b = XT y
 * (QR)T (QR) b = (QR)Ty
 * RT (QTQ) R b = RT QT y
 * RT R b = RT QT y
 * (RT)-1 RT R b = (RT)-1 RT QT y
 * R b = QT y 

* *

Given Q and R, the last equation is solved by back-substitution.

* * @version $Id: OLSMultipleLinearRegression.java 1416643 2012-12-03 19:37:14Z tn $ * @since 2.0 */ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression { /** Cached QR decomposition of X matrix */ private QRDecomposition qr = null; /** * Loads model x and y sample data, overriding any previous sample. * * Computes and caches QR decomposition of the X matrix. * @param y the [n,1] array representing the y sample * @param x the [n,k] array representing the x sample * @throws MathIllegalArgumentException if the x and y array data are not * compatible for the regression */ public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException { validateSampleData(x, y); newYSampleData(y); newXSampleData(x); } /** * {@inheritDoc} *

This implementation computes and caches the QR decomposition of the X matrix.

*/ @Override public void newSampleData(double[] data, int nobs, int nvars) { super.newSampleData(data, nobs, nvars); qr = new QRDecomposition(getX()); } /** *

Compute the "hat" matrix. *

*

The hat matrix is defined in terms of the design matrix X * by X(XTX)-1XT *

*

The implementation here uses the QR decomposition to compute the * hat matrix as Q IpQT where Ip is the * p-dimensional identity matrix augmented by 0's. This computational * formula is from "The Hat Matrix in Regression and ANOVA", * David C. Hoaglin and Roy E. Welsch, * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22. *

*

Data for the model must have been successfully loaded using one of * the {@code newSampleData} methods before invoking this method; otherwise * a {@code NullPointerException} will be thrown.

* * @return the hat matrix */ public RealMatrix calculateHat() { // Create augmented identity matrix RealMatrix Q = qr.getQ(); final int p = qr.getR().getColumnDimension(); final int n = Q.getColumnDimension(); // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3 Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n); double[][] augIData = augI.getDataRef(); for (int i = 0; i < n; i++) { for (int j =0; j < n; j++) { if (i == j && i < p) { augIData[i][j] = 1d; } else { augIData[i][j] = 0d; } } } // Compute and return Hat matrix // No DME advertised - args valid if we get here return Q.multiply(augI).multiply(Q.transpose()); } /** *

Returns the sum of squared deviations of Y from its mean.

* *

If the model has no intercept term, 0 is used for the * mean of Y - i.e., what is returned is the sum of the squared Y values.

* *

The value returned by this method is the SSTO value used in * the {@link #calculateRSquared() R-squared} computation.

* * @return SSTO - the total sum of squares * @throws MathIllegalArgumentException if the sample has not been set or does * not contain at least 3 observations * @see #isNoIntercept() * @since 2.2 */ public double calculateTotalSumOfSquares() throws MathIllegalArgumentException { if (isNoIntercept()) { return StatUtils.sumSq(getY().toArray()); } else { return new SecondMoment().evaluate(getY().toArray()); } } /** * Returns the sum of squared residuals. * * @return residual sum of squares * @since 2.2 */ public double calculateResidualSumOfSquares() { final RealVector residuals = calculateResiduals(); // No advertised DME, args are valid return residuals.dotProduct(residuals); } /** * Returns the R-Squared statistic, defined by the formula
     * R2 = 1 - SSR / SSTO
     * 
* where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals} * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares} * * @return R-square statistic * @throws MathIllegalArgumentException if the sample has not been set or does * not contain at least 3 observations * @since 2.2 */ public double calculateRSquared() throws MathIllegalArgumentException { return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares(); } /** *

Returns the adjusted R-squared statistic, defined by the formula

     * R2adj = 1 - [SSR (n - 1)] / [SSTO (n - p)]
     * 
* where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}, * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number * of observations and p is the number of parameters estimated (including the intercept).

* *

If the regression is estimated without an intercept term, what is returned is

     *  1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) 
     * 

* * @return adjusted R-Squared statistic * @throws MathIllegalArgumentException if the sample has not been set or does * not contain at least 3 observations * @see #isNoIntercept() * @since 2.2 */ public double calculateAdjustedRSquared() throws MathIllegalArgumentException { final double n = getX().getRowDimension(); if (isNoIntercept()) { return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension())); } else { return 1 - (calculateResidualSumOfSquares() * (n - 1)) / (calculateTotalSumOfSquares() * (n - getX().getColumnDimension())); } } /** * {@inheritDoc} *

This implementation computes and caches the QR decomposition of the X matrix * once it is successfully loaded.

*/ @Override protected void newXSampleData(double[][] x) { super.newXSampleData(x); qr = new QRDecomposition(getX()); } /** * Calculates the regression coefficients using OLS. * *

Data for the model must have been successfully loaded using one of * the {@code newSampleData} methods before invoking this method; otherwise * a {@code NullPointerException} will be thrown.

* * @return beta */ @Override protected RealVector calculateBeta() { return qr.getSolver().solve(getY()); } /** *

Calculates the variance-covariance matrix of the regression parameters. *

*

Var(b) = (XTX)-1 *

*

Uses QR decomposition to reduce (XTX)-1 * to (RTR)-1, with only the top p rows of * R included, where p = the length of the beta vector.

* *

Data for the model must have been successfully loaded using one of * the {@code newSampleData} methods before invoking this method; otherwise * a {@code NullPointerException} will be thrown.

* * @return The beta variance-covariance matrix */ @Override protected RealMatrix calculateBetaVariance() { int p = getX().getColumnDimension(); RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1); RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse(); return Rinv.multiply(Rinv.transpose()); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy