All downloads are free. Search and download functionality uses the official Maven repository.

com.github.waikatodatamining.matrix.algorithm.pls.OPLS Maven / Gradle / Ivy

There is a newer version: 0.1.0
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * OPLS.java
 * Copyright (C) 2018 University of Waikato, Hamilton, NZ
 */

package com.github.waikatodatamining.matrix.algorithm.pls;

import com.github.waikatodatamining.matrix.core.Matrix;
import com.github.waikatodatamining.matrix.core.MatrixFactory;


/**
 * OPLS algorithm.
 * 
* See here: * Orthogonal Projections to latent structures (O-PLS) * * @author Steven Lang */ public class OPLS extends AbstractSingleResponsePLS { private static final long serialVersionUID = -6097279189841762321L; /** the P matrix */ protected Matrix m_Porth; /** the T matrix */ protected Matrix m_Torth; /** the W matrix */ protected Matrix m_Worth; /** Data with orthogonal signal components removed */ protected Matrix m_Xosc; /** Base PLS that is trained on the cleaned data */ protected AbstractPLS m_BasePLS; /** Get the base PLS model that is fitted on the OSC cleaned data */ public AbstractPLS getBasePLS() { return m_BasePLS; } /** Set the base PLS model that is fitted on the OSC cleaned data */ public void setBasePLS(AbstractPLS basePLS) { m_BasePLS = basePLS; reset(); } /** * Resets the member variables. */ @Override protected void reset() { super.reset(); m_Porth = null; m_Worth = null; m_Torth = null; } @Override protected void initialize() { super.initialize(); setBasePLS(new PLS1()); } /** * Returns the all the available matrices. * * @return the names of the matrices */ @Override public String[] getMatrixNames() { return new String[]{ "P_orth", "W_orth", "T_orth" }; } /** * Returns the matrix with the specified name. * * @param name the name of the matrix * @return the matrix, null if not available */ @Override public Matrix getMatrix(String name) { switch (name) { case "P_orth": return m_Porth; case "W_orth": return m_Worth; case "T_orth": return m_Torth; default: return null; } } /** * Whether the algorithm supports return of loadings. * * @return true if supported * @see #getLoadings() */ public boolean hasLoadings() { return true; } /** * Returns the loadings, if available. * * @return the loadings, null if not available */ public Matrix getLoadings() { return getMatrix("P_orth"); } /** * Initializes using the provided data. 
* * @param predictors the input data * @param response the dependent variable(s) * @return null if successful, otherwise error message */ protected String doPerformInitialization(Matrix predictors, Matrix response) throws Exception { Matrix X, Xtrans, y; Matrix w, wOrth; Matrix t, tOrth; Matrix p, pOrth; X = predictors.copy(); Xtrans = X.transpose(); y = response; // init m_Worth = MatrixFactory.zeros(predictors.numColumns(), getNumComponents()); m_Porth = MatrixFactory.zeros(predictors.numColumns(), getNumComponents()); m_Torth = MatrixFactory.zeros(predictors.numRows(), getNumComponents()); w = Xtrans.mul(y).mul(invL2Squared(y)).normalized(); for (int currentComponent = 0; currentComponent < getNumComponents(); currentComponent++) { // Calculate scores vector t = X.mul(w).mul(invL2Squared(w)); // Calculate loadings of X p = Xtrans.mul(t).mul(invL2Squared(t)); // Orthogonalize weight wOrth = p.sub(w.mul(w.transpose().mul(p).mul(invL2Squared(w)).asDouble())); wOrth = wOrth.normalized(); tOrth = X.mul(wOrth).mul(invL2Squared(wOrth)); pOrth = Xtrans.mul(tOrth).mul(invL2Squared(tOrth)); // Remove orthogonal components from X X = X.sub(tOrth.mul(pOrth.transpose())); Xtrans = X.transpose(); // Store results m_Worth.setColumn(currentComponent, wOrth); m_Torth.setColumn(currentComponent, tOrth); m_Porth.setColumn(currentComponent, pOrth); } m_Xosc = X.copy(); m_BasePLS.initialize(this.doTransform(predictors), response); return null; } /** * Get the inverse of the squared l2 norm. * @param v Input vector * @return 1.0 / norm2(v)^2 */ protected double invL2Squared(Matrix v) { double l2 = v.norm2(); return 1.0 / (l2 * l2); } /** * Transforms the data. 
* * @param predictors the input data * @return the transformed data and the predictions * @throws Exception if analysis fails */ @Override protected Matrix doTransform(Matrix predictors) { // Remove signal from X_test that is orthogonal to y_train // X_clean = X_test - X_test*W_orth*P_orth^T Matrix T = predictors.mul(m_Worth); Matrix Xorth = T.mul(m_Porth.transpose()); return predictors.sub(Xorth); } /** * Returns whether the algorithm can make predictions. * * @return true if can make predictions */ public boolean canPredict() { return true; } /** * Performs predictions on the data. * * @param predictors the input data * @return the transformed data and the predictions * @throws Exception if analysis fails */ @Override protected Matrix doPerformPredictions(Matrix predictors) throws Exception { Matrix Xtransformed = transform(predictors); return m_BasePLS.predict(Xtransformed); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy