All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.carrot2.matrix.factorization.KMeansMatrixFactorization Maven / Gradle / Ivy

Go to download

Carrot2 search results clustering framework. Minimal functional subset (core algorithms and infrastructure, no document sources).

There is a newer version: 3.16.3
Show newest version

/*
 * Carrot2 project.
 *
 * Copyright (C) 2002-2016, Dawid Weiss, Stanisław Osiński.
 * All rights reserved.
 *
 * Refer to the full license file "carrot2.LICENSE"
 * in the root folder of the repository checkout or at:
 * http://www.carrot2.org/carrot2.LICENSE
 */

package org.carrot2.matrix.factorization;

import org.carrot2.mahout.math.function.Functions;
import org.carrot2.mahout.math.function.Mult;
import org.carrot2.mahout.math.matrix.DoubleMatrix2D;
import org.carrot2.mahout.math.matrix.impl.DenseDoubleMatrix2D;
import org.carrot2.matrix.MatrixUtils;

/**
 * Performs matrix factorization using the K-means clustering algorithm. This kind of
 * factorization is sometimes referred to as Concept Decomposition Factorization.
 */
public class KMeansMatrixFactorization extends IterativeMatrixFactorizationBase
{
    /**
     * Creates the KMeansMatrixFactorization object for matrix A. Before accessing
     * results, perform computations by calling the {@link #compute()} method.
     * 
     * @param A matrix to be factorized. The matrix must have Euclidean length-normalized
     *            columns.
     */
    public KMeansMatrixFactorization(DoubleMatrix2D A)
    {
        super(A);
    }

    public void compute()
    {
        int n = A.columns();

        // Distances to centroids
        DoubleMatrix2D D = new DenseDoubleMatrix2D(k, n);

        // Object-cluster assignments
        V = new DenseDoubleMatrix2D(n, k);

        // Initialize the centroids with some document vectors
        U = new DenseDoubleMatrix2D(A.rows(), k);
        U.assign(A.viewPart(0, 0, A.rows(), k));

        int [] minIndices = new int [D.columns()];
        double [] minValues = new double [D.columns()];

        for (iterationsCompleted = 0; iterationsCompleted < maxIterations; iterationsCompleted++)
        {
            // Calculate cosine distances
            U.zMult(A, D, 1, 0, true, false);

            V.assign(0);
            U.assign(0);

            // For each object
            MatrixUtils.maxInColumns(D, minIndices, minValues);
            for (int i = 0; i < minIndices.length; i++)
            {
                V.setQuick(i, minIndices[i], 1);
            }

            // Update centroids
            for (int c = 0; c < V.columns(); c++)
            {
                // Sum
                int count = 0;
                for (int d = 0; d < V.rows(); d++)
                {
                    if (V.getQuick(d, c) != 0)
                    {
                        count++;
                        U.viewColumn(c).assign(A.viewColumn(d), Functions.PLUS);
                    }
                }

                // Divide
                U.viewColumn(c).assign(Mult.div(count));
                MatrixUtils.normalizeColumnL2(U, null);
            }

        }
    }

    public String toString()
    {
        return "KMMF";
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy