All downloads are FREE. The search and download functionality uses the official Maven repository.

org.carrot2.matrix.factorization.seeding.KMeansSeedingStrategy Maven / Gradle / Ivy


/*
 * Carrot2 project.
 *
 * Copyright (C) 2002-2016, Dawid Weiss, Stanisław Osiński.
 * All rights reserved.
 *
 * Refer to the full license file "carrot2.LICENSE"
 * in the root folder of the repository checkout or at:
 * http://www.carrot2.org/carrot2.LICENSE
 */

package org.carrot2.matrix.factorization.seeding;

import org.carrot2.mahout.math.matrix.*;
import org.carrot2.matrix.factorization.KMeansMatrixFactorization;

/**
 * Matrix seeding based on the k-means algorithms.
 */
public class KMeansSeedingStrategy implements ISeedingStrategy
{
    /** The maximum number of KMeans iterations */
    private int maxIterations;
    private static final int DEFAULT_MAX_ITERATIONS = 5;

    /**
     * Creates the KMeansSeedingStrategy.
     */
    public KMeansSeedingStrategy()
    {
        this(DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates the KMeansSeedingStrategy.
     * 
     * @param maxIterations maximum number of KMeans iterations.
     */
    public KMeansSeedingStrategy(int maxIterations)
    {
        this.maxIterations = maxIterations;
    }

    public void seed(DoubleMatrix2D A, DoubleMatrix2D U, DoubleMatrix2D V)
    {
        KMeansMatrixFactorization kMeansMatrixFactorization = new KMeansMatrixFactorization(
                A);
        kMeansMatrixFactorization.setK(U.columns());
        kMeansMatrixFactorization.setMaxIterations(maxIterations);
        kMeansMatrixFactorization.compute();

        U.assign(kMeansMatrixFactorization.getU());
        for (int r = 0; r < U.rows(); r++)
        {
            for (int c = 0; c < U.columns(); c++)
            {
                if (U.getQuick(r, c) < 0.001)
                {
                    U.setQuick(r, c, 0.05);
                }
            }
        }

        V.assign(kMeansMatrixFactorization.getV());
        for (int r = 0; r < V.rows(); r++)
        {
            for (int c = 0; c < V.columns(); c++)
            {
                if (V.getQuick(r, c) == 0)
                {
                    V.setQuick(r, c, 0.05);
                }
            }
        }
    }
    
    public String toString()
    {
        return "KM";
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy