All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.caseystella.analytics.outlier.batch.rpca.RPCA Maven / Gradle / Ivy

There is a newer version: 0.9.7
Show newest version
/**
 * Copyright (C) 2016 Hurence ([email protected])
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.caseystella.analytics.outlier.batch.rpca;


import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;

public class RPCA {

  private RealMatrix X;
  private RealMatrix L;
  private RealMatrix S;
  private RealMatrix E;

  private double lpenalty;
  private double spenalty;

  private static final int MAX_ITERS = 228;

  public RPCA(double[][] data, double lpenalty, double spenalty) {
    this.X = MatrixUtils.createRealMatrix(data);
    this.lpenalty = lpenalty;
    this.spenalty = spenalty;
    initMatrices();
    computeRSVD();
  }

  public RPCA(RealMatrix X, double lpenalty, double spenalty) {
    this.X = X;
    this.lpenalty = lpenalty;
    this.spenalty = spenalty;
    initMatrices();
    computeRSVD();
  }

  private void initMatrices() {
    this.L = MatrixUtils.createRealMatrix(this.X.getRowDimension(), this.X.getColumnDimension());
    this.S = MatrixUtils.createRealMatrix(this.X.getRowDimension(), this.X.getColumnDimension());
    this.E = MatrixUtils.createRealMatrix(this.X.getRowDimension(), this.X.getColumnDimension());
  }

  private void computeRSVD() {
    double mu = X.getColumnDimension() * X.getRowDimension() / (4 * l1norm(X.getData()));
    double objPrev = 0.5*Math.pow(X.getFrobeniusNorm(), 2);
    double obj = objPrev;
    double tol = 1e-8 * objPrev;
    double diff = 2 * tol;
    int iter = 0;

    while(diff > tol && iter < MAX_ITERS) {
      double nuclearNorm = computeS(mu);
      double l1Norm = computeL(mu);
      double l2Norm = computeE();

      obj = computeObjective(nuclearNorm, l1Norm, l2Norm);
      diff = Math.abs(objPrev - obj);
      objPrev = obj;

      mu = computeDynamicMu();

      iter = iter + 1;
    }
  }

  private double[] softThreshold(double[] x, double penalty) {
    for(int i = 0; i < x.length; i++) {
      x[i] = Math.signum(x[i]) * Math.max(Math.abs(x[i]) - penalty, 0);
    }
    return x;
  }

  private double[][] softThreshold(double[][] x, double penalty) {
    for(int i = 0; i < x.length; i++) {
      for(int j = 0; j < x[i].length; j++) {
        x[i][j] = Math.signum(x[i][j]) * Math.max(Math.abs(x[i][j]) - penalty, 0);
      }
    }
    return x;
  }

  private double sum(double[] x) {
    double sum = 0;
    for (int i = 0; i < x.length; i++)
      sum += x[i];
    return (sum);
  }

  private double l1norm(double[][] x) {
    double l1norm = 0;
    for (int i = 0; i < x.length; i++) {
      for (int j = 0; j < x[i].length; j++) {
        l1norm += Math.abs(x[i][j]);
      }
    }
    return l1norm;
  }

  private double computeL(double mu) {
    double LPenalty = lpenalty * mu;
    SingularValueDecomposition svd = new SingularValueDecomposition(X.subtract(S));
    double[] penalizedD = softThreshold(svd.getSingularValues(), LPenalty);
    RealMatrix D_matrix = MatrixUtils.createRealDiagonalMatrix(penalizedD);
    L = svd.getU().multiply(D_matrix).multiply(svd.getVT());
    return sum(penalizedD) * LPenalty;
  }

  private double computeS(double mu) {
    double SPenalty = spenalty * mu;
    double[][] penalizedS = softThreshold(X.subtract(L).getData(), SPenalty);
    S = MatrixUtils.createRealMatrix(penalizedS);
    return l1norm(penalizedS) * SPenalty;
  }

  private double computeE() {
    E = X.subtract(L).subtract(S);
    double norm = E.getFrobeniusNorm();
    return Math.pow(norm, 2);
  }

  private double computeObjective(double nuclearnorm, double l1norm, double l2norm) {
    return 0.5*l2norm + nuclearnorm + l1norm;
  }

  private double computeDynamicMu() {
    int m = E.getRowDimension();
    int n = E.getColumnDimension();

    double E_sd = standardDeviation(E.getData());
    double mu = E_sd * Math.sqrt(2*Math.max(m,n));

    return Math.max(.01, mu);
  }

	/*private double MedianAbsoluteDeviation(double[][] x) {
		DescriptiveStatistics stats = new DescriptiveStatistics();
		for (int i = 0; i < x.length; i ++)
			for (int j = 0; j < x[i].length; j++)
				stats.addValue(x[i][j]);
		double median = stats.getPercentile(50);

		DescriptiveStatistics absoluteDeviationStats = new DescriptiveStatistics();
		for (int i = 0; i < x.length; i ++)
			for (int j = 0; j < x[i].length; j++)
				absoluteDeviationStats.addValue(Math.abs(x[i][j] - median));

		return absoluteDeviationStats.getPercentile(50) * 1.4826;
	}*/

  private double standardDeviation(double[][] x) {
    DescriptiveStatistics stats = new DescriptiveStatistics();
    for (int i = 0; i < x.length; i ++)
      for (int j = 0; j < x[i].length; j++)
        stats.addValue(x[i][j]);
    return stats.getStandardDeviation();
  }

  public RealMatrix getL() {
    return L;
  }

  public RealMatrix getS() {
    return S;
  }

  public RealMatrix getE() {
    return E;
  }



}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy