org.carrot2.matrix.factorization.LocalNonnegativeMatrixFactorization Maven / Gradle / Ivy
/*
* Carrot2 project.
*
* Copyright (C) 2002-2016, Dawid Weiss, Stanisław Osiński.
* All rights reserved.
*
* Refer to the full license file "carrot2.LICENSE"
* in the root folder of the repository checkout or at:
* http://www.carrot2.org/carrot2.LICENSE
*/
package org.carrot2.matrix.factorization;
import org.carrot2.mahout.math.function.DoubleDoubleFunction;
import org.carrot2.mahout.math.function.DoubleFunction;
import org.carrot2.mahout.math.function.Functions;
import org.carrot2.mahout.math.matrix.DoubleMatrix2D;
import org.carrot2.mahout.math.matrix.impl.DenseDoubleMatrix2D;
import org.carrot2.matrix.MatrixUtils;
/**
* Performs matrix factorization using the Local Non-negative Matrix Factorization
* algorithm with minimization of the Kullback-Leibler divergence between A and UV' and
* multiplicative updating.
*/
public class LocalNonnegativeMatrixFactorization extends IterativeMatrixFactorizationBase
{
    /**
     * Small constant added both to A and to the current approximation UV' before the
     * element-wise division, so that zero entries never cause a division by zero.
     */
    private static final double EPS = 1e-9;

    /**
     * Creates the LocalNonnegativeMatrixFactorization object for matrix A. Before
     * accessing results, perform computations by calling the {@link #compute()} method.
     *
     * @param A matrix to be factorized
     */
    public LocalNonnegativeMatrixFactorization(DoubleMatrix2D A)
    {
        super(A);
    }

    /**
     * Computes the factorization A ~ UV' using the LNMF multiplicative update rules.
     * <p>
     * U (rows(A) x k) and V (columns(A) x k) are seeded by {@code seedingStrategy} and then
     * refined for at most {@code maxIterations} iterations. When {@code stopThreshold >= 0},
     * the approximation error is tracked via {@code updateApproximationError()} and the loop
     * stops early as soon as that call returns {@code true}. When {@code ordered} is set,
     * {@code order()} is invoked after the iterations finish.
     */
    public void compute()
    {
        // Prototype Matlab code for the LNMF
        //
        // function [U, V, C] = lnmf(A)
        // [m, n] = size(A);
        // k = 2; % the desired number of base vectors
        // maxiter = 50; % the number of iterations
        // eps = 1e-9; % machine epsilon
        //
        // U = rand(m, k); % initialize U randomly
        // V = rand(n, k); % initialize V randomly
        // O = ones(m, m); % a matrix of ones
        //
        // for iter = 1:maxiter
        // V = sqrt( V .* (((A+eps) ./ (U*V'+eps))' * U)); % update V
        // U = U .* (((A+eps)./(U*V'+eps)) * V); % update U
        // U = U ./ (O * U); % normalise U's columns
        // C(1, iter) = norm((A-U*V'), 'fro'); % approximation quality
        // end
        //

        // Seed U and V with initial values
        U = new DenseDoubleMatrix2D(A.rows(), k);
        V = new DenseDoubleMatrix2D(A.columns(), k);
        seedingStrategy.seed(A, U, V);

        // Element-wise functions used by the update rules
        final DoubleFunction plusEps = Functions.plus(EPS);
        final DoubleDoubleFunction invDiv = Functions.swapArgs(Functions.DIV); // (x, y) -> y / x
        final DoubleDoubleFunction sqrtMult = Functions.chain(Functions.SQRT, Functions.MULT); // (x, y) -> sqrt(x * y)

        // A + eps does not change between iterations, so it is computed only once.
        final DoubleMatrix2D Aeps = A.copy();
        Aeps.assign(plusEps);

        // Reusable temporaries; the factor matrices have the same shapes as V and U.
        final DoubleMatrix2D ratio = new DenseDoubleMatrix2D(A.rows(), A.columns());
        final DoubleMatrix2D vFactor = new DenseDoubleMatrix2D(A.columns(), k);
        final DoubleMatrix2D uFactor = new DenseDoubleMatrix2D(A.rows(), k);
        final double [] work = new double [U.columns()];

        if (stopThreshold >= 0)
        {
            updateApproximationError();
        }

        for (int i = 0; i < maxIterations; i++)
        {
            // Update V: V <- sqrt(V .* (ratio' * U))
            computeRatio(Aeps, ratio, plusEps, invDiv);
            ratio.zMult(U, vFactor, 1, 0, true, false); // vFactor <- ratio' * U
            V.assign(vFactor, sqrtMult);

            // Update U: U <- U .* (ratio * V), then normalise U's columns
            // (the Matlab prototype's U ./ (O * U) step).
            computeRatio(Aeps, ratio, plusEps, invDiv);
            ratio.zMult(V, uFactor, 1, 0, false, false); // uFactor <- ratio * V
            U.assign(uFactor, Functions.MULT);
            MatrixUtils.normalizeColumnL1(U, work);

            iterationsCompleted++;

            if (stopThreshold >= 0 && updateApproximationError())
            {
                break;
            }
        }

        if (ordered)
        {
            order();
        }
    }

    /**
     * Overwrites {@code ratio} with the element-wise quotient {@code Aeps ./ (U * V' + eps)},
     * computed from the current values of the {@code U} and {@code V} fields.
     *
     * @param Aeps the input matrix with {@code eps} already added to every element
     * @param ratio output matrix of shape rows(A) x columns(A); its previous contents are discarded
     * @param plusEps function adding {@code eps} to its argument
     * @param invDiv function computing {@code y / x} for arguments {@code (x, y)}
     */
    private void computeRatio(DoubleMatrix2D Aeps, DoubleMatrix2D ratio,
        DoubleFunction plusEps, DoubleDoubleFunction invDiv)
    {
        U.zMult(V, ratio, 1, 0, false, true); // ratio <- U * V'
        ratio.assign(plusEps);                // ratio <- ratio + eps
        ratio.assign(Aeps, invDiv);           // ratio <- Aeps ./ ratio
    }

    @Override
    public String toString()
    {
        return "LNMF-" + seedingStrategy.toString();
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy