weka.core.matrix.LinearRegression Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-stable Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the stable version. Apart from bugfixes, this version
does not receive any other updates.
/*
* This software is a cooperative product of The MathWorks and the National
* Institute of Standards and Technology (NIST) which has been released to the
* public domain. Neither The MathWorks nor NIST assumes any responsibility
* whatsoever for its use by other parties, and makes no guarantees, expressed
* or implied, about its quality, reliability, or any other characteristic.
*/
/*
* LinearRegression.java
* Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
*
*/
package weka.core.matrix;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
/**
 * Class for performing (ridged) linear regression using Tikhonov
 * regularization.
 * <p>
 * The coefficients are obtained by solving the normal equations
 * {@code (A'A + ridge * I) x = A'y}. If that system cannot be solved, the
 * ridge is increased and the solve is retried.
 *
 * @author Fracpete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 9768 $
 */
public class LinearRegression
  implements RevisionHandler {

  /**
   * Ridge to retry with when a solve fails while the current ridge is not
   * positive. Without it, a ridge of 0 would be "increased" by multiplying
   * by 10 and the retry loop would never terminate.
   */
  private static final double FALLBACK_RIDGE = 1.0e-8;

  /** the coefficients */
  protected double[] m_Coefficients = null;

  /**
   * Performs a (ridged) linear regression.
   *
   * @param a the matrix to perform the regression on
   * @param y the dependent variable vector (one column, one row per row of
   *          {@code a})
   * @param ridge the ridge parameter
   * @throws IllegalArgumentException if {@code y} has more than one column,
   *           if {@code y} and {@code a} have different row counts, or if no
   *           solution could be found
   */
  public LinearRegression(Matrix a, Matrix y, double ridge) {
    calculate(a, y, ridge);
  }

  /**
   * Performs a weighted (ridged) linear regression.
   * <p>
   * Each row of {@code a} and {@code y} is scaled by {@code sqrt(w[i])}, so
   * the squared residual of row i is weighted by {@code w[i]}. Weights are
   * presumably expected to be non-negative; a negative weight produces NaN
   * (via {@code Math.sqrt}) and therefore NaN coefficients.
   *
   * @param a the matrix to perform the regression on
   * @param y the dependent variable vector (one column, one row per row of
   *          {@code a})
   * @param w the array of data point weights, one per row of {@code a}
   * @param ridge the ridge parameter
   * @throws IllegalArgumentException if the wrong number of weights were
   *           provided, if {@code y} has more than one column, if {@code y}
   *           and {@code a} have different row counts, or if no solution
   *           could be found
   */
  public LinearRegression(Matrix a, Matrix y, double[] w, double ridge) {
    if (w.length != a.getRowDimension())
      throw new IllegalArgumentException("Incorrect number of weights provided");
    // Validate y before weighting: the loop below reads only column 0 of y,
    // so an extra column would otherwise be dropped silently.
    checkDependentVariable(a, y);

    int rows = a.getRowDimension();
    int cols = a.getColumnDimension();
    Matrix weightedThis = new Matrix(rows, cols);
    Matrix weightedDep = new Matrix(rows, 1);
    for (int i = 0; i < rows; i++) {
      // scaling by sqrt(w) turns the plain least-squares objective into the
      // weighted one: (sqrt(w) * residual)^2 == w * residual^2
      double sqrtWeight = Math.sqrt(w[i]);
      for (int j = 0; j < cols; j++)
        weightedThis.set(i, j, a.get(i, j) * sqrtWeight);
      weightedDep.set(i, 0, y.get(i, 0) * sqrtWeight);
    }
    calculate(weightedThis, weightedDep, ridge);
  }

  /**
   * Checks that the dependent variable is a single column with as many rows
   * as the regression matrix.
   *
   * @param a the matrix to perform the regression on
   * @param y the dependent variable vector
   * @throws IllegalArgumentException if {@code y} has more than one column or
   *           its row count differs from that of {@code a}
   */
  private static void checkDependentVariable(Matrix a, Matrix y) {
    if (y.getColumnDimension() > 1)
      throw new IllegalArgumentException("Only one dependent variable allowed");
    if (y.getRowDimension() != a.getRowDimension())
      throw new IllegalArgumentException(
        "Number of rows in dependent variable (" + y.getRowDimension()
          + ") does not match number of rows in matrix ("
          + a.getRowDimension() + ")");
  }

  /**
   * performs the actual regression.
   * <p>
   * Solves {@code (a'a + ridge * I) x = a'y} and stores {@code x} in
   * {@link #m_Coefficients}. Whenever the solve throws, the ridge is
   * increased (multiplied by 10, or set to a small positive value if it is
   * not positive) and the solve is retried.
   *
   * @param a the matrix to perform the regression on
   * @param y the dependent variable vector
   * @param ridge the initial ridge parameter
   * @throws IllegalArgumentException if {@code y} has more than one column,
   *           if the row counts of {@code a} and {@code y} differ, or if the
   *           system could not be solved before the ridge stopped being
   *           finite
   */
  protected void calculate(Matrix a, Matrix y, double ridge) {
    checkDependentVariable(a, y);

    int nc = a.getColumnDimension();
    m_Coefficients = new double[nc];

    // a'a and a'y do not depend on the ridge, so compute them only once
    Matrix ss = aTa(a);
    Matrix bb = aTy(a, y);

    while (true) {
      // Set ridge regression adjustment (add ridge to the diagonal)
      Matrix ssWithRidge = ss.copy();
      for (int i = 0; i < nc; i++)
        ssWithRidge.set(i, i, ssWithRidge.get(i, i) + ridge);

      // Carry out the regression
      try {
        Matrix solution = ssWithRidge.solve(bb);
        for (int i = 0; i < nc; i++)
          m_Coefficients[i] = solution.get(i, 0);
        return;
      } catch (Exception ex) {
        // Give up instead of looping forever once escalation is meaningless.
        if (Double.isNaN(ridge) || Double.isInfinite(ridge))
          throw new IllegalArgumentException(
            "Unable to solve regression, ridge escalated to " + ridge, ex);
        // A non-positive ridge cannot be escalated by multiplication
        // (0 * 10 == 0), so restart from a small positive value.
        ridge = (ridge > 0) ? ridge * 10 : FALLBACK_RIDGE;
      }
    }
  }

  /**
   * Return aTa (a' * a).
   * <p>
   * Only the upper triangle is accumulated; it is then mirrored, because the
   * result is symmetric. Writes go through {@code x.getArray()}, so this
   * relies on that method returning the matrix's backing array.
   *
   * @param a the input matrix
   * @return the {@code cols x cols} matrix {@code a' * a}
   */
  private static Matrix aTa(Matrix a) {
    int cols = a.getColumnDimension();
    double[][] A = a.getArray();
    Matrix x = new Matrix(cols, cols);
    double[][] X = x.getArray();
    double[] Acol = new double[a.getRowDimension()];
    for (int col1 = 0; col1 < cols; col1++) {
      // cache the column for faster access later
      for (int row = 0; row < Acol.length; row++) {
        Acol[row] = A[row][col1];
      }
      // reference the row for faster lookup
      double[] Xrow = X[col1];
      for (int row = 0; row < Acol.length; row++) {
        double[] Arow = A[row];
        for (int col2 = col1; col2 < Xrow.length; col2++) {
          Xrow[col2] += Acol[row] * Arow[col2];
        }
      }
      // result is symmetric
      for (int col2 = col1 + 1; col2 < Xrow.length; col2++) {
        X[col2][col1] = Xrow[col2];
      }
    }
    return x;
  }

  /**
   * Return aTy (a' * y).
   * <p>
   * Only column 0 of {@code y} is used. Like {@link #aTa(Matrix)}, this
   * writes through {@code x.getArray()}.
   *
   * @param a the input matrix
   * @param y the dependent variable vector (at least as many rows as
   *          {@code a})
   * @return the {@code cols x 1} matrix {@code a' * y}
   */
  private static Matrix aTy(Matrix a, Matrix y) {
    double[][] A = a.getArray();
    double[][] Y = y.getArray();
    Matrix x = new Matrix(a.getColumnDimension(), 1);
    double[][] X = x.getArray();
    for (int row = 0; row < A.length; row++) {
      // reference the rows for faster lookup
      double[] Arow = A[row];
      double[] Yrow = Y[row];
      for (int col = 0; col < Arow.length; col++) {
        X[col][0] += Arow[col] * Yrow[0];
      }
    }
    return x;
  }

  /**
   * returns the calculated coefficients.
   * <p>
   * The returned array is the one held by this object, not a copy.
   *
   * @return the coefficients, one per column of the regression matrix
   */
  public final double[] getCoefficients() {
    return m_Coefficients;
  }

  /**
   * returns the coefficients in a string representation
   *
   * @return the coefficients formatted by {@code Utils.arrayToString}
   */
  @Override
  public String toString() {
    return Utils.arrayToString(getCoefficients());
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 9768 $");
  }
}