All downloads are free. Search and download functionality uses the official Maven repository.

weka.core.matrix.LinearRegression Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 * This software is a cooperative product of The MathWorks and the National
 * Institute of Standards and Technology (NIST) which has been released to the
 * public domain. Neither The MathWorks nor NIST assumes any responsibility
 * whatsoever for its use by other parties, and makes no guarantees, expressed
 * or implied, about its quality, reliability, or any other characteristic.
 */

/*
 * LinearRegression.java
 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.core.matrix;

import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 * Class for performing (ridged) linear regression using Tikhonov
 * regularization.
 * 
 * @author Fracpete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 9768 $
 */
 
public class LinearRegression
  implements RevisionHandler {

  /** the coefficients */
  protected double[] m_Coefficients = null;

  /**
   * Performs a (ridged) linear regression.
   *
   * @param a the matrix to perform the regression on
   * @param y the dependent variable vector
   * @param ridge the ridge parameter
   * @throws IllegalArgumentException if not successful
   */
  public LinearRegression(Matrix a, Matrix y, double ridge) {
    calculate(a, y, ridge);
  }

  /**
   * Performs a weighted (ridged) linear regression. 
   *
   * @param a the matrix to perform the regression on
   * @param y the dependent variable vector
   * @param w the array of data point weights
   * @param ridge the ridge parameter
   * @throws IllegalArgumentException if the wrong number of weights were
   * provided.
   */
  public LinearRegression(Matrix a, Matrix y, double[] w, double ridge) {

    if (w.length != a.getRowDimension())
      throw new IllegalArgumentException("Incorrect number of weights provided");
    Matrix weightedThis = new Matrix(
                              a.getRowDimension(), a.getColumnDimension());
    Matrix weightedDep = new Matrix(a.getRowDimension(), 1);
    for (int i = 0; i < w.length; i++) {
      double sqrt_weight = Math.sqrt(w[i]);
      for (int j = 0; j < a.getColumnDimension(); j++)
        weightedThis.set(i, j, a.get(i, j) * sqrt_weight);
      weightedDep.set(i, 0, y.get(i, 0) * sqrt_weight);
    }

    calculate(weightedThis, weightedDep, ridge);
  }

  /**
   * performs the actual regression.
   *
   * @param a the matrix to perform the regression on
   * @param y the dependent variable vector
   * @param ridge the ridge parameter
   * @throws IllegalArgumentException if not successful
   */
  protected void calculate(Matrix a, Matrix y, double ridge) {

    if (y.getColumnDimension() > 1)
      throw new IllegalArgumentException("Only one dependent variable allowed");

    int nc = a.getColumnDimension();
    m_Coefficients = new double[nc];
    Matrix solution;

    Matrix ss = aTa(a);
    Matrix bb = aTy(a, y);

    boolean success = true;

    do {
      // Set ridge regression adjustment
      Matrix ssWithRidge = ss.copy();
      for (int i = 0; i < nc; i++)
        ssWithRidge.set(i, i, ssWithRidge.get(i, i) + ridge);

      // Carry out the regression
      try {
        solution = ssWithRidge.solve(bb);
        for (int i = 0; i < nc; i++)
          m_Coefficients[i] = solution.get(i, 0);
        success = true;
      } catch (Exception ex) {
        ridge *= 10;
        success = false;
      }
    } while (!success);
  }
  
  /**
   * Return aTa (a' * a)
   */
  private static Matrix aTa(Matrix a) {
    int cols = a.getColumnDimension();
    double[][] A = a.getArray();
    Matrix x = new Matrix(cols, cols);
    double[][] X = x.getArray();
    double[] Acol = new double[a.getRowDimension()];
    for (int col1 = 0; col1 < cols; col1++) {
      // cache the column for faster access later
      for (int row = 0; row < Acol.length; row++) {
        Acol[row] = A[row][col1];
      }
      // reference the row for faster lookup
      double[] Xrow = X[col1];
      for (int row = 0; row < Acol.length; row++) {
        double[] Arow = A[row];
        for (int col2 = col1; col2 < Xrow.length; col2++) {
          Xrow[col2] += Acol[row] * Arow[col2];
        }
      }
      // result is symmetric
      for (int col2 = col1 + 1; col2 < Xrow.length; col2++) {
        X[col2][col1] = Xrow[col2];
      }
    }
    return x;
  }

  /**
   * Return aTy (a' * y)
   */
  private static Matrix aTy(Matrix a, Matrix y) {
    double[][] A = a.getArray();
    double[][] Y = y.getArray();
    Matrix x = new Matrix(a.getColumnDimension(), 1);
    double[][] X = x.getArray();
    for (int row = 0; row < A.length; row++) {
      // reference the rows for faster lookup
      double[] Arow = A[row];
      double[] Yrow = Y[row];
      for (int col = 0; col < Arow.length; col++) {
        X[col][0] += Arow[col] * Yrow[0];
      }
    }
    return x;
  }

  /**
   * returns the calculated coefficients
   *
   * @return the coefficients
   */
  public final double[] getCoefficients() {
    return m_Coefficients;
  }

  /**
   * returns the coefficients in a string representation
   */
  public String toString() {
    return Utils.arrayToString(getCoefficients());
  }
  
  /**
   * Returns the revision string.
   * 
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 9768 $");
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy