All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.github.waikatodatamining.matrix.algorithm.pls.KernelPLS Maven / Gradle / Ivy

There is a newer version: 0.1.0
Show newest version
package com.github.waikatodatamining.matrix.algorithm.pls;

import com.github.waikatodatamining.matrix.core.Matrix;
import com.github.waikatodatamining.matrix.core.MatrixFactory;
import com.github.waikatodatamining.matrix.transformation.Center;
import com.github.waikatodatamining.matrix.transformation.kernel.AbstractKernel;
import com.github.waikatodatamining.matrix.transformation.kernel.RBFKernel;

/**
 * Kernel Partial Least Squares algorithm.
 * 
* See here: * Kernel Partial Least Squares Regression in Reproducing * Kernel Hilbert Space * * @author Steven Lang */ public class KernelPLS extends AbstractMultiResponsePLS { private static final long serialVersionUID = -2760078672082710402L; public static final int SEED = 0; /** Calibration data in feature space */ protected Matrix m_K_orig; protected Matrix m_K_deflated; /** Scores on K */ protected Matrix m_T; /** Scores on Y */ protected Matrix m_U; /** Loadings on K */ protected Matrix m_P; /** Loadings on Y */ protected Matrix m_Q; /** Partial regression matrix */ protected Matrix m_B_RHS; /** Training points */ protected Matrix m_X; /** Kernel for feature transformation */ protected AbstractKernel m_Kernel; /** Inner NIPALS loop improvement tolerance */ protected double m_Tol; /** Inner NIPALS loop maximum number of iterations */ protected int m_MaxIter; /** Center X transformation */ protected Center m_CenterX; /** Center Y transformation */ protected Center m_CenterY; @Override protected void initialize() { super.initialize(); setKernel(new RBFKernel()); setTol(1e-6); setMaxIter(500); m_CenterX = new Center(); m_CenterY = new Center(); } public AbstractKernel getKernel() { return m_Kernel; } public void setKernel(AbstractKernel kernel) { this.m_Kernel = kernel; reset(); } public int getMaxIter() { return m_MaxIter; } public void setMaxIter(int maxIter) { if (maxIter < 0) { m_Logger.warning("Maximum iterations parameter must be positive " + "but was " + maxIter + "."); } else { this.m_MaxIter = maxIter; reset(); } } public double getTol() { return m_Tol; } public void setTol(double tol) { if (tol < 0) { m_Logger.warning("Tolerance parameter must be positive but " + "was " + tol + "."); } else { this.m_Tol = tol; reset(); } } @Override protected int getMinColumnsResponse() { return 1; } @Override protected int getMaxColumnsResponse() { return -1; } @Override protected String doPerformInitialization(Matrix predictors, Matrix response) throws Exception { 
Matrix Y, I, t, u, q, w; getLogger(); // Init int numComponents = getNumComponents(); m_X = predictors; m_X = m_CenterX.transform(m_X); Y = response; Y = m_CenterY.transform(Y); int numRows = m_X.numRows(); int numClasses = Y.numColumns(); q = MatrixFactory.zeros(numClasses, 1); t = MatrixFactory.zeros(numRows, 1); w = MatrixFactory.zeros(numRows, 1); I = MatrixFactory.eye(numRows, numRows); m_T = MatrixFactory.zeros(numRows, numComponents); m_U = MatrixFactory.zeros(numRows, numComponents); m_P = MatrixFactory.zeros(numRows, numComponents); m_Q = MatrixFactory.zeros(numClasses, numComponents); m_K_orig = m_Kernel.applyMatrix(m_X); m_K_orig = centralizeTrainInKernelSpace(m_K_orig); m_K_deflated = m_K_orig.copy(); for (int currentComponent = 0; currentComponent < numComponents; currentComponent++) { int iterations = 0; Matrix uOld; u = MatrixFactory.randn(numRows, 1, SEED + currentComponent); double iterationChange = m_Tol * 10; // Repeat 1) - 3) until convergence: either change of u is lower than m_Tol or maximum // number of iterations has been reached (m_MaxIter) while (iterationChange > m_Tol && iterations < m_MaxIter) { // 1) t = m_K_deflated.mul(u).normalized(); w = t.copy(); // 2) q = Y.transpose().mul(t); // 3) uOld = u; u = Y.mul(q).normalized(); // Update stopping conditions iterations++; iterationChange = u.sub(uOld).norm2(); } // Deflate Matrix ttTrans = t.mul(t.transpose()); Matrix part = I.sub(ttTrans); m_K_deflated = part.mul(m_K_deflated).mul(part); Y = Y.sub(t.mul(q.transpose())); Matrix p = m_K_deflated.transpose().mul(w).div(w.transpose().mul(w).asDouble()); // Store u,t,c,p m_T.setColumn(currentComponent, t); m_U.setColumn(currentComponent, u); m_Q.setColumn(currentComponent, q); m_P.setColumn(currentComponent, p); } // Calculate right hand side of the regression matrix B Matrix tTtimesKtimesU = m_T.transpose().mul(m_K_orig).mul(m_U); Matrix inv = tTtimesKtimesU.inverse(); m_B_RHS = inv.mul(m_Q.transpose()); return null; } /** * Centralize a 
kernel matrix in the kernel space via: * K <- (I - 1/n * 1_n * 1_n^T) * K * (I - 1/n * 1_n * 1_n^T) * * @param K Kernel matrix * @return Centralised kernel matrix */ protected Matrix centralizeTrainInKernelSpace(Matrix K) { int n = m_X.numRows(); Matrix I = MatrixFactory.eye(n, n); Matrix one = MatrixFactory.filled(n, 1, 1.0); // Centralize in kernel space Matrix part = I.sub(one.mul(one.transpose()).div(n)); return part.mul(K).mul(part); } /** * @param K Kernel matrix * @return Centralised kernel matrix */ protected Matrix centralizeTestInKernelSpace(Matrix K) { int nTrain = m_X.numRows(); int nTest = K.numRows(); Matrix I = MatrixFactory.eye(nTrain, nTrain); Matrix onesTrainTestScaled = MatrixFactory.filled(nTest, nTrain, 1.0 / nTrain); Matrix onesTrainScaled = MatrixFactory.filled(nTrain, nTrain, 1.0 / nTrain); return (K.sub(onesTrainTestScaled.mul(m_K_orig))).mul(I.sub(onesTrainScaled)); } @Override protected Matrix doPerformPredictions(Matrix predictors) { Matrix K_t = doTransform(predictors); Matrix Y_hat = K_t.mul(m_B_RHS); Y_hat = m_CenterY.inverseTransform(Y_hat); return Y_hat; } @Override protected Matrix doTransform(Matrix predictors) { Matrix predictorsCentered = m_CenterX.transform(predictors); Matrix K_t = m_Kernel.applyMatrix(predictorsCentered, m_X); K_t = centralizeTestInKernelSpace(K_t); return K_t.mul(m_U); } @Override public String[] getMatrixNames() { return new String[]{"K", "T", "U", "P", "Q"}; } @Override public Matrix getMatrix(String name) { switch (name) { case "K": return m_K_deflated; case "T": return m_T; case "U": return m_U; case "P": return m_P; case "Q": return m_Q; } return null; } @Override public boolean hasLoadings() { return true; } @Override protected void reset() { super.reset(); m_K_orig = null; m_K_deflated = null; m_T = null; m_U = null; m_P = null; m_Q = null; m_B_RHS = null; m_X = null; m_CenterX = new Center(); m_CenterY = new Center(); } @Override public Matrix getLoadings() { return m_T; } @Override public boolean 
canPredict() { return true; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy