com.github.waikatodatamining.matrix.algorithms.pls.SparsePLS Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of matrix-algorithms Show documentation
Show all versions of matrix-algorithms Show documentation
Java library of 2-dimensional matrix algorithms.
The newest version!
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* SparsePLS.java
* Copyright (C) 2018 University of Waikato, Hamilton, NZ
*/
package com.github.waikatodatamining.matrix.algorithms.pls;
import com.github.waikatodatamining.matrix.core.StoppedException;
import com.github.waikatodatamining.matrix.core.matrix.Matrix;
import com.github.waikatodatamining.matrix.core.matrix.MatrixFactory;
import com.github.waikatodatamining.matrix.algorithms.Standardize;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
/**
* Sparse PLS algorithm.
*
* See here:
* Sparse partial least squares regression for simultaneous dimension reduction and variable selection
*
* Implementation was oriented at the R SPLS package, which implemnets the above
* mentioned paper:
* Sparse Partial Least Squares (SPLS) Regression and Classification
*
* The lambda parameter controls the features sparseness. For sufficiently small
* lambda, all features will be selected and the algorithm results are equal
* to NIPALS'.
*
* @author Steven Lang
*/
public class SparsePLS
extends AbstractSingleResponsePLS {
private static final long serialVersionUID = -6097279189841762321L;
protected Matrix m_Bpls;
/** NIPALS tolerance threshold */
protected double m_Tol = 1e-7;
/** NIPALS max iterations */
protected int m_MaxIter = 500;
/** Sparsity parameter. Determines sparseness. */
protected double m_lambda = 0.5;
protected Set m_A;
/** Loadings. */
protected Matrix m_W;
/** Standardize X */
protected Standardize m_StandardizeX = new Standardize();
/** Standardize Y */
protected Standardize m_StandardizeY = new Standardize();
public int getMaxIter() {
return m_MaxIter;
}
public void setMaxIter(int maxIter) {
if (maxIter < 0) {
getLogger().warning("Maximum iterations parameter must be positive " +
"but was " + maxIter + ".");
} else {
this.m_MaxIter = maxIter;
reset();
}
}
public double getTol() {
return m_Tol;
}
public void setTol(double tol) {
if (tol < 0) {
getLogger().warning("Tolerance parameter must be positive but " +
"was " + tol + ".");
} else {
this.m_Tol = tol;
reset();
}
}
public double getLambda() {
return m_lambda;
}
public void setLambda(double lambda) {
if (lambda < 0){
getLogger().warning("Sparseness parameter lambda must be positive " +
"but was " + lambda + ".");
} else {
m_lambda = lambda;
reset();
}
}
/**
* Resets the member variables.
*/
@Override
protected void doReset() {
super.doReset();
m_Bpls = null;
m_A = null;
m_W = null;
m_StandardizeX = new Standardize();
m_StandardizeY = new Standardize();
}
/**
* Returns the all the available matrices.
*
* @return the names of the matrices
*/
@Override
public String[] getMatrixNames() {
return new String[]{
"W",
"B"
};
}
/**
* Returns the matrix with the specified name.
*
* @param name the name of the matrix
* @return the matrix, null if not available
*/
@Override
public Matrix getMatrix(String name) {
switch (name) {
case "W":
return m_W;
case "B":
return m_Bpls;
default:
return null;
}
}
/**
* Whether the algorithm supports return of loadings.
*
* @return true if supported
* @see #getLoadings()
*/
public boolean hasLoadings() {
return true;
}
/**
* Returns the loadings, if available.
*
* @return the loadings, null if not available
*/
public Matrix getLoadings() {
return getMatrix("W");
}
/**
* Initializes using the provided data.
*
* @param predictors the input data
* @param response the dependent variable(s)
*/
protected void doPLSConfigure(Matrix predictors, Matrix response) {
Matrix X, y, wk;
getLogger();
X = m_StandardizeX.configureAndTransform(predictors);
y = m_StandardizeY.configureAndTransform(response);
Matrix Xj = X.copy();
Matrix yj = y.copy();
m_A = new TreeSet<>();
m_Bpls = MatrixFactory.zeros(X.numColumns(), y.numColumns());
m_W = MatrixFactory.zeros(X.numColumns(), getNumComponents());
for (int k = 0; k < getNumComponents(); k++) {
if (m_Stopped)
throw new StoppedException();
wk = getDirectionVector(Xj, yj, k);
m_W.setColumn(k, wk);
if (m_Debug) {
checkDirectionVector(wk);
}
collectIndices(wk);
Matrix X_A = getColumnSubmatrixOf(X);
m_Bpls = MatrixFactory.zeros(X.numColumns(), y.numColumns());
Matrix Bpls_A = getRegressionCoefficient(X_A, y, k);
// Fill m_Bpls values at non zero indices with estimated
// regression coefficients
int idxCounter = 0;
for (Integer idx : m_A) {
m_Bpls.setRow(idx, Bpls_A.getRow(idxCounter++));
}
// Deflate
yj = y.sub(X.mul(m_Bpls));
}
if (m_Debug) {
getLogger().info("Selected following features " +
"(" + m_A.size() + "/" + X.numColumns() + "): ");
List l = m_A.stream().map(String::valueOf).collect(Collectors.toList());
getLogger().info(String.join(",", l));
}
}
/**
* Calculate NIPALS regression coefficients.
*
* @param X_A Predictors subset
* @param y Current response vector
* @param k PLS iteration
* @return Bpls (NIPALS regression coefficients)
*/
private Matrix getRegressionCoefficient(Matrix X_A, Matrix y, int k) {
int numComponents = Math.min(X_A.numColumns(), k + 1);
NIPALS nipals = new NIPALS();
nipals.setMaxIter(m_MaxIter);
nipals.setTol(m_Tol);
nipals.setNumComponents(numComponents);
nipals.configure(X_A, y);
return nipals.getCoef();
}
/**
* Get the column submatrix of X given by the indices in m_A
*
* @param X Input Matrix
* @return Submatrix of x
*/
private Matrix getColumnSubmatrixOf(Matrix X) {
Matrix X_A = MatrixFactory.zeros(X.numRows(), m_A.size());
int colCount = 0;
for (Integer i : m_A) {
Matrix col = X.getColumn(i);
X_A.setColumn(colCount, col);
colCount++;
}
return X_A;
}
/**
* Get the row submatrix of X given by the indices in m_A
*
* @param X Input Matrix
* @return Submatrix of x
*/
private Matrix getRowSubmatrixOf(Matrix X) {
Matrix X_A = MatrixFactory.zeros(m_A.size(), X.numColumns());
int rowCount = 0;
for (Integer i : m_A) {
Matrix row = X.getRow(i);
X_A.setRow(rowCount, row);
rowCount++;
}
return X_A;
}
/**
* Collect indices based on the current non zero indices in w and m_Bpls
*
* @param w Direction Vector
*/
private void collectIndices(Matrix w) {
m_A.clear();
m_A.addAll(w.whereVector(d -> Math.abs(d) > 1e-6));
m_A.addAll(m_Bpls.whereVector(d -> Math.abs(d) > 1e-6));
}
/**
* Check if the direction vector is fulfills w^Tw=1
*
* @param w Direction vector
*/
private void checkDirectionVector(Matrix w) {
// Test if w^Tw = 1
if (w.norm2squared() - 1 > 1e-6) {
getLogger().warning("Direction vector condition w'w=1 was violated.");
}
}
/**
* Compute the direction vector.
*
* @param X Predictors
* @param yj Current deflated response
* @param k Iteration
* @return Direction vector
*/
private Matrix getDirectionVector(Matrix X, Matrix yj, int k) {
Matrix Zp = X.t().mul(yj);
// Zp.divi(Zp.norm2()); // Reference paper uses l2 norm
double znorm = Zp.abs().median(); // R package spls uses median norm
Zp = Zp.div(znorm);
Matrix ZpSign = Zp.sign();
Matrix valb = Zp.abs().sub(m_lambda * Zp.abs().max());
// Collect indices where valb is >= 0
List idxs = valb.whereVector(d -> d >= 0);
Matrix preMul = valb.mulElementwise(ZpSign);
Matrix c = MatrixFactory.zeros(Zp.numRows(), 1);
for (Integer idx : idxs) {
double val = preMul.get(idx, 0);
c.set(idx, 0, val);
}
return c.div(c.norm2squared()); // Rescale c and use as estimated direction vector
/* Extension for multivariate Y (needs further testing):
Matrix w;
Matrix c;
Matrix wOld;
Matrix M;
Matrix U;
Matrix V;
Matrix cOld;
double iterationChangeW = m_Tol * 10;
double iterationChangeC = m_Tol * 10;
int iterations = 0;
// Repeat w step and c step until convergence
while ((iterationChangeW > m_Tol || iterationChangeC > m_Tol) && iterations < m_MaxIter) {
// w step
wOld = w;
M = Xt.mul(yj).mul(yj.t()).mul(X);
Matrix mtc = M.mul(c);
U = mtc.svdU();
V = mtc.svdV();
w = U.mul(V.t());
// c step
cOld = c;
Zp = Xt.mul(yj);
Zp.divi(Zp.norm2()); // Reference paper uses l2 norm
// double znorm = Zp.abs().median(); // R package spls uses median norm
// Zp.divi(znorm);
Matrix ZpSign = Zp.sign();
Matrix valb = Zp.abs().sub(m_lambda * Zp.abs().max());
// Collect indices where valb is >= 0
List idxs = valb.whereVector(d -> d >= 0);
Matrix preMul = valb.mulElementwise(ZpSign);
c = new Matrix(Zp.numRows(), 1);
for (Integer idx : idxs) {
double val = preMul.get(idx, 0);
c.set(idx, 0, val);
}
// Update stopping conditions
iterations++;
iterationChangeW = w.sub(wOld).norm2();
iterationChangeC = c.sub(cOld).norm2();
}
return w;*/
}
/**
* Transforms the data.
*
* @param predictors the input data
* @return the transformed data and the predictions
*/
@Override
protected Matrix doPLSTransform(Matrix predictors) {
int numComponents = getNumComponents();
Matrix T = MatrixFactory.zeros(predictors.numRows(), numComponents);
Matrix X = predictors.copy();
for (int k = 0; k < numComponents; k++) {
if (m_Stopped)
throw new StoppedException();
Matrix wk = m_W.getColumn(k);
Matrix tk = X.mul(wk);
T.setColumn(k, tk);
Matrix pk = X.t().mul(tk).div(tk.norm2squared());
X = X.sub(tk.mul(pk.t()));
}
return T;
}
/**
* Returns whether the algorithm can make predictions.
*
* @return true if can make predictions
*/
public boolean canPredict() {
return true;
}
/**
* Performs predictions on the data.
*
* @param predictors the input data
* @return the transformed data and the predictions
*/
@Override
protected Matrix doPLSPredict(Matrix predictors) {
Matrix X = m_StandardizeX.transform(predictors);
Matrix X_A = getColumnSubmatrixOf(X);
Matrix B_A = getRowSubmatrixOf(m_Bpls);
Matrix yMeans = MatrixFactory.fromColumn(m_StandardizeY.getMeans());
Matrix yStd = MatrixFactory.fromColumn(m_StandardizeY.getStdDevs());
Matrix yhat = X_A.mul(B_A).scaleByRowVector(yStd).addByVector(yMeans);
return yhat;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy