com.github.waikatodatamining.matrix.algorithm.pls.SIMPLS Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of matrix-algorithms Show documentation
Show all versions of matrix-algorithms Show documentation
Java library of 2-dimensional matrix algorithms.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
*/
/*
* SIMPLS.java
* Copyright (C) 2018 University of Waikato, Hamilton, NZ
*/
package com.github.waikatodatamining.matrix.algorithm.pls;
import com.github.waikatodatamining.matrix.core.Matrix;
import com.github.waikatodatamining.matrix.core.MatrixFactory;
import com.github.waikatodatamining.matrix.core.MatrixHelper;
import com.github.waikatodatamining.matrix.core.Utils;
/**
* SIMPLS algorithm.
*
* See here:
* SIMPLS (StatSoft)
*
* @author FracPete (fracpete at waikato dot ac dot nz)
*/
public class SIMPLS
extends AbstractSingleResponsePLS {
private static final long serialVersionUID = 4899661745515419256L;
/** the number of coefficients in W to keep (0 keep all). */
protected int m_NumCoefficients;
/** the W matrix */
protected Matrix m_W;
/** the B matrix (used for prediction) */
protected Matrix m_B;
/** Q matrix to regress T (XW) on y */
protected Matrix m_Q;
/**
* Initializes the members.
*/
@Override
protected void initialize() {
super.initialize();
setNumCoefficients(0);
}
/**
* Resets the member variables.
*/
@Override
protected void reset() {
super.reset();
m_B = null;
m_W = null;
}
/**
* Sets the number of coefficients of W matrix to keep (rest gets zeroed).
*
* @param value the number of coefficients, 0 to keep all
*/
public void setNumCoefficients(int value) {
m_NumCoefficients = value;
reset();
}
/**
* returns the number of coefficients of W matrix to keep (rest gets zeroed).
*
* @return the maximum number of attributes, 0 to keep all
*/
public int getNumCoefficients() {
return m_NumCoefficients;
}
/**
* Returns the all the available matrices.
*
* @return the names of the matrices
*/
@Override
public String[] getMatrixNames() {
return new String[]{
"W",
"B",
"Q"
};
}
/**
* Returns the matrix with the specified name.
*
* @param name the name of the matrix
* @return the matrix, null if not available
*/
@Override
public Matrix getMatrix(String name) {
switch (name) {
case "W":
return m_W;
case "B":
return m_B;
case "Q":
return m_Q;
default:
return null;
}
}
/**
* Whether the algorithm supports return of loadings.
*
* @return true if supported
* @see #getLoadings()
*/
public boolean hasLoadings() {
return true;
}
/**
* Returns the loadings, if available.
*
* @return the loadings, null if not available
*/
public Matrix getLoadings() {
return getMatrix("W");
}
/**
* Zeroes the coefficients of the W matrix beyond the specified number of
* coefficients.
*
* @param in the matrix to process in-place
*/
protected void slim(Matrix in) {
double[][] B = in.toRawCopy2D();
for (int i = 0; i < in.numColumns(); i++) {
Matrix l = in.getSubMatrix(0, in.numRows(), i, i + 1);
double[] ld = l.toRawCopy1D();
for (int t = 0; t < ld.length; t++) {
ld[t] = Math.abs(ld[t]);
}
int[] srt = Utils.sort(ld);
//int index = srt.length - 1 - srt[Math.min(getNumCoefficients(),srt.length-1)]; //nonono
int index = srt[Math.max(srt.length - 1 - getNumCoefficients(), 0)];
double val = ld[index];
for (int c = 0; c < in.numRows(); c++) {
if (Math.abs(B[c][i]) < val) {
in.set(c, i, 0);
}
}
}
}
/**
* Initializes using the provided data.
*
* @param predictors the input data
* @param response the dependent variable(s)
* @throws Exception if analysis fails
* @return null if successful, otherwise error message
*/
protected String doPerformInitialization(Matrix predictors, Matrix response) throws Exception {
Matrix A, A_trans;
Matrix M;
Matrix X_trans;
Matrix C, c;
Matrix Q, q;
Matrix W, w;
Matrix P, p, p_trans;
Matrix v, v_trans;
int h;
X_trans = predictors.transpose();
A = X_trans.mul(response);
M = X_trans.mul(predictors);
C = MatrixFactory.eye(predictors.numColumns(), predictors.numColumns());
W = MatrixFactory.zeros(predictors.numColumns(), getNumComponents());
P = MatrixFactory.zeros(predictors.numColumns(), getNumComponents());
Q = MatrixFactory.zeros(1, getNumComponents());
for (h = 0; h < getNumComponents(); h++) {
// 1. qh as dominant EigenVector of Ah'*Ah
A_trans = A.transpose();
q = MatrixHelper.getDominantEigenVector(A_trans.mul(A));
// 2. wh=Ah*qh, ch=wh'*Mh*wh, wh=wh/sqrt(ch), store wh in W as column
w = A.mul(q);
c = w.transpose().mul(M).mul(w);
w = w.mul(1.0 / StrictMath.sqrt(c.asDouble()));
W.setColumn(h, w);
// 3. ph=Mh*wh, store ph in P as column
p = M.mul(w);
p_trans = p.transpose();
P.setColumn(h, p);
// 4. qh=Ah'*wh, store qh in Q as column
q = A_trans.mul(w);
Q.setColumn(h, q);
// 5. vh=Ch*ph, vh=vh/||vh||
v = C.mul(p);
v = v.normalized();
v_trans = v.transpose();
// 6. Ch+1=Ch-vh*vh', Mh+1=Mh-ph*ph'
C = C.sub(v.mul(v_trans));
M = M.sub(p.mul(p_trans));
// 7. Ah+1=ChAh (actually Ch+1)
A = C.mul(A);
}
// finish
if (m_NumCoefficients > 0)
slim(W);
m_W = W;
m_B = W.mul(Q.transpose());
m_Q = Q;
return null;
}
/**
* Transforms the data.
*
* @param predictors the input data
* @throws Exception if analysis fails
* @return the transformed data
*/
@Override
protected Matrix doTransform(Matrix predictors) throws Exception {
return predictors.mul(m_W);
}
/**
* Returns whether the algorithm can make predictions.
*
* @return true if can make predictions
*/
public boolean canPredict() {
return true;
}
/**
* Performs predictions on the data.
*
* @param predictors the input data
* @throws Exception if analysis fails
* @return the predictions
*/
@Override
protected Matrix doPerformPredictions(Matrix predictors) throws Exception {
return predictors.mul(m_B);
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy