org.carrot2.matrix.factorization.seeding.KMeansSeedingStrategy Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of carrot2-mini Show documentation
Carrot2 search results clustering framework. Minimal functional subset
(core algorithms and infrastructure, no document sources).
/*
* Carrot2 project.
*
* Copyright (C) 2002-2019, Dawid Weiss, Stanisław Osiński.
* All rights reserved.
*
* Refer to the full license file "carrot2.LICENSE"
* in the root folder of the repository checkout or at:
* http://www.carrot2.org/carrot2.LICENSE
*/
package org.carrot2.matrix.factorization.seeding;
import org.carrot2.mahout.math.matrix.*;
import org.carrot2.matrix.factorization.KMeansMatrixFactorization;
/**
 * Matrix seeding based on the k-means algorithm. The input matrix is factorized with
 * {@link KMeansMatrixFactorization} and the resulting factors are copied into the
 * matrices being seeded, with (near-)zero entries replaced by a small positive value.
 */
public class KMeansSeedingStrategy implements ISeedingStrategy
{
    /** Default maximum number of KMeans iterations. */
    private static final int DEFAULT_MAX_ITERATIONS = 5;

    /** Elements of U strictly below this value are replaced with {@link #SEED_FLOOR}. */
    private static final double U_ZERO_THRESHOLD = 0.001;

    /**
     * Value written over elements of U below {@link #U_ZERO_THRESHOLD} and over elements
     * of V that are exactly zero. NOTE(review): presumably keeps the seed strictly
     * positive for the factorization that consumes it; confirm with callers.
     */
    private static final double SEED_FLOOR = 0.05;

    /** The maximum number of KMeans iterations. */
    private final int maxIterations;

    /**
     * Creates the KMeansSeedingStrategy with the default maximum number of KMeans
     * iterations ({@value #DEFAULT_MAX_ITERATIONS}).
     */
    public KMeansSeedingStrategy()
    {
        this(DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates the KMeansSeedingStrategy.
     *
     * @param maxIterations maximum number of KMeans iterations, passed unchanged to
     *            {@link KMeansMatrixFactorization#setMaxIterations(int)}.
     */
    public KMeansSeedingStrategy(int maxIterations)
    {
        this.maxIterations = maxIterations;
    }

    /**
     * Seeds {@code U} and {@code V} from a k-means factorization of {@code A}, using
     * {@code k = U.columns()}.
     * <p>
     * The factorization's U is copied into {@code U}, and every element below
     * {@link #U_ZERO_THRESHOLD} is set to {@link #SEED_FLOOR}. The factorization's V is
     * copied into {@code V}, and every element exactly equal to zero is set to
     * {@link #SEED_FLOOR}.
     * <p>
     * NOTE(review): {@code U} and {@code V} are assumed to already have the shapes of
     * the factorization's U and V. No shape check is done here; {@code assign} is
     * relied upon for that.
     *
     * @param A the matrix to factorize.
     * @param U receives the seeded U factor; its column count defines k.
     * @param V receives the seeded V factor.
     */
    public void seed(DoubleMatrix2D A, DoubleMatrix2D U, DoubleMatrix2D V)
    {
        final KMeansMatrixFactorization kMeans = new KMeansMatrixFactorization(A);
        kMeans.setK(U.columns());
        kMeans.setMaxIterations(maxIterations);
        kMeans.compute();

        // U is thresholded, whereas V is compared against exact zero.
        // NOTE(review): presumably V holds 0/1 cluster memberships, so an exact
        // comparison suffices there; confirm against KMeansMatrixFactorization.
        U.assign(kMeans.getU());
        replaceBelow(U, U_ZERO_THRESHOLD, SEED_FLOOR);

        V.assign(kMeans.getV());
        replaceZeros(V, SEED_FLOOR);
    }

    /**
     * Sets every element of {@code matrix} that is strictly less than {@code threshold}
     * to {@code value}, in place.
     */
    private static void replaceBelow(DoubleMatrix2D matrix, double threshold, double value)
    {
        for (int r = 0; r < matrix.rows(); r++)
        {
            for (int c = 0; c < matrix.columns(); c++)
            {
                if (matrix.getQuick(r, c) < threshold)
                {
                    matrix.setQuick(r, c, value);
                }
            }
        }
    }

    /**
     * Sets every element of {@code matrix} that is exactly equal to zero to
     * {@code value}, in place.
     */
    private static void replaceZeros(DoubleMatrix2D matrix, double value)
    {
        for (int r = 0; r < matrix.rows(); r++)
        {
            for (int c = 0; c < matrix.columns(); c++)
            {
                if (matrix.getQuick(r, c) == 0)
                {
                    matrix.setQuick(r, c, value);
                }
            }
        }
    }

    /**
     * Returns the short identifier of this seeding strategy.
     *
     * @return {@code "KM"}.
     */
    @Override
    public String toString()
    {
        return "KM";
    }
}