All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.factorization.impl.SVDAdapterEjml Maven / Gradle / Ivy

package com.expleague.ml.factorization.impl;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.Pair;
import com.expleague.ml.factorization.Factorization;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.DecompositionFactory;
import org.ejml.interfaces.decomposition.SingularValueDecomposition;
import org.ejml.ops.SingularOps;

/**
 * {@link Factorization} backed by the EJML singular value decomposition.
 *
 * <p>For an input matrix {@code X} (m x n) it computes the SVD, sorts the singular values in
 * descending order and keeps the {@code factorDim} leading components.
 *
 * User: qdeee
 * Date: 12.01.15
 */
public class SVDAdapterEjml implements Factorization {
  private final int factorDim;
  private final boolean needCompact;

  /**
   * @param factorDim   number of leading singular components to keep; must be positive
   * @param needCompact passed to EJML as the "compact" flag of the decomposition
   * @throws IllegalArgumentException if {@code factorDim} is not positive
   */
  public SVDAdapterEjml(final int factorDim, final boolean needCompact) {
    if (factorDim <= 0) {
      throw new IllegalArgumentException("factorDim must be positive, got " + factorDim);
    }
    this.factorDim = factorDim;
    this.needCompact = needCompact;
  }

  /**
   * Creates an adapter that uses the compact EJML decomposition.
   *
   * @param factorDim number of leading singular components to keep; must be positive
   */
  public SVDAdapterEjml(final int factorDim) {
    this(factorDim, true);
  }

  /** Creates a rank-1 adapter that uses the compact EJML decomposition. */
  public SVDAdapterEjml() {
    this(1, true);
  }

  /**
   * Computes a truncated SVD {@code X ~ U_k * W_k * V_k^T} with {@code k = factorDim}.
   *
   * @param X matrix to factorize (m x n)
   * @return a pair whose first element is the m x k matrix {@code U_k * W_k} (left singular
   *     vectors scaled by their singular values) and whose second element is the n x k matrix
   *     {@code V_k} of right singular vectors; both are matrices typed as {@link Vec}
   * @throws IllegalStateException if EJML fails to decompose {@code X}, or if {@code factorDim}
   *     exceeds the number of singular values, {@code min(m, n)}
   */
  @Override
  public Pair<Vec, Vec> factorize(final Mx X) {
    final int m = X.rows();
    final int n = X.columns();

    final DenseMatrix64F denseMatrix64F = toEjmlMatrix(X);

    final SingularValueDecomposition<DenseMatrix64F> svd =
        DecompositionFactory.svd(m, n, true, true, needCompact);
    // decomposeSafe works on a copy, so the input matrix is left untouched.
    if (!DecompositionFactory.decomposeSafe(svd, denseMatrix64F)) {
      throw new IllegalStateException("Decomposition failed");
    }

    final DenseMatrix64F U = svd.getU(null, false);
    final DenseMatrix64F W = svd.getW(null);
    final DenseMatrix64F V = svd.getV(null, false);
    // Sort so that the first factorDim columns correspond to the largest singular values.
    SingularOps.descendingOrder(U, false, W, V, false);

    // W is square in compact mode but rectangular (m x n) otherwise; in both modes only
    // min(rows, cols) singular values exist, and U/V have at least that many columns.
    final int maxRank = Math.min(W.getNumRows(), W.getNumCols());
    if (maxRank < factorDim) {
      throw new IllegalStateException("Factor dim " + factorDim + " is too big for this mx. Try a smaller value (" + maxRank + ")");
    }

    final Mx u = getSubFromEjmlMatrix(U, 0, 0, m, factorDim);
    final Mx w = getSubFromEjmlMatrix(W, 0, 0, factorDim, factorDim);
    final Vec v = getSubFromEjmlMatrix(V, 0, 0, n, factorDim);

    final Vec mult = MxTools.multiply(u, w);
    return Pair.create(mult, v);
  }

  /** Copies every element of {@code X} into a freshly allocated EJML dense matrix. */
  private static DenseMatrix64F toEjmlMatrix(final Mx X) {
    final int m = X.rows();
    final int n = X.columns();
    final DenseMatrix64F result = new DenseMatrix64F(m, n);
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        result.set(i, j, X.get(i, j));
      }
    }
    return result;
  }

  /**
   * Copies the {@code height x width} sub-matrix of {@code ejmlMatrix} whose top-left corner is
   * at ({@code iPos}, {@code jPos}) into a new {@link VecBasedMx}.
   */
  private static Mx getSubFromEjmlMatrix(final DenseMatrix64F ejmlMatrix, final int iPos, final int jPos, final int height, final int width) {
    final Mx result = new VecBasedMx(height, width);
    for (int i = 0; i < height; i++) {
      for (int j = 0; j < width; j++) {
        result.set(i, j, ejmlMatrix.get(iPos + i, jPos + j));
      }
    }
    return result;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy