// Source listing of gov.sandia.cognition.learning.algorithm.clustering.KMeansClusterer
// (Maven / Gradle / Ivy artifact "cognitive-foundry": a single jar with all
// the Cognitive Foundry components).
/*
* File: KMeansClusterer.java
* Authors: Justin Basilico and Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright February 21, 2006, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.clustering;
import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.CodeReviews;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.Cluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.ClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.ClusterDivergenceFunction;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.FixedClusterInitializer;
import gov.sandia.cognition.learning.function.distance.DivergenceFunctionContainer;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
/**
* The {@code KMeansClusterer} class implements the standard k-means
* (k-centroids) clustering algorithm.
*
* @param The type of the data to cluster. This is typically defined
* by the divergence function used.
* @param The type of {@code Cluster} created by the algorithm.
* This is typically defined by the cluster creator function used.
* @author Justin Basilico
* @author Kevin R. Dixon
* @since 1.0
*/
@CodeReviews(
reviews =
{
@CodeReview(
reviewer = "Kevin R. Dixon",
date = "2008-10-06",
changesNeeded = true,
comments =
{
"The constructors for this class are not user friendly.",
"I've been trying to write a test GUI for k-means for over an hour and STILL can't figure out the combination of classes to configure the constructor.",
"Please make a constructor that configures the class with meaningful, user-friendly default arguments."
}
),
@CodeReview(
reviewer = "Kevin R. Dixon",
date = "2008-07-22",
changesNeeded = false,
comments =
{
"Changed the condition to be 'members.size() > 0' instead of 1 in createClustersFromAssignments()",
"Cleaned up javadoc.",
"Code generally looks fine."
}
)
}
)
@PublicationReferences(
references =
{
@PublicationReference(
author = "Wikipedia",
title = "K-means algorithm",
type = PublicationType.WebPage,
year = 2008,
url = "http://en.wikipedia.org/wiki/K-means_algorithm"
),
@PublicationReference(
author = "Matteo Matteucci",
title = "A Tutorial on Clustering Algorithms: k-means Demo",
type = PublicationType.WebPage,
year = 2008,
url
= "http://home.dei.polimi.it/matteucc/Clustering/tutorial_html/AppletKM.html"
)
}
)
public class KMeansClusterer>
extends AbstractAnytimeBatchLearner, Collection>
implements BatchClusterer,
MeasurablePerformanceAlgorithm,
DivergenceFunctionContainer
{
/**
* The default number of requested clusters is {@value}.
*/
public static final int DEFAULT_NUM_REQUESTED_CLUSTERS = 10;
/**
* The default maximum number of iterations is {@value}.
*/
public static final int DEFAULT_MAX_ITERATIONS = 1000;
/**
* The number of clusters requested.
*/
protected int numRequestedClusters;
/**
* The initializer for the algorithm.
*/
protected FixedClusterInitializer initializer;
/**
* The divergence function between cluster being used.
*/
protected ClusterDivergenceFunction super ClusterType, ? super DataType> divergenceFunction;
/**
* The cluster creator for creating clusters.
*/
private ClusterCreator creator;
/**
* The current set of clusters.
*/
protected ArrayList clusters;
/**
* The current assignments of elements to clusters.
*/
protected int[] assignments;
/**
* The current number of elements assigned to each cluster.
*/
protected int[] clusterCounts;
/**
* Returns the number of samples that changed assignment between iterations
*/
private int numChanged;
/**
* Creates a new instance of {@code KMeansClusterer} with default
* parameters.
*/
public KMeansClusterer()
{
this(DEFAULT_NUM_REQUESTED_CLUSTERS, DEFAULT_MAX_ITERATIONS,
null, null, null);
}
/**
* Creates a new instance of KMeansClusterer using the given parameters.
*
* @param numRequestedClusters The number of clusters requested (k).
* @param maxIterations Maximum number of iterations before stopping
* @param initializer The initializer for the clusters.
* @param divergenceFunction The divergence function.
* @param creator The cluster creator.
*/
public KMeansClusterer(
int numRequestedClusters,
int maxIterations,
FixedClusterInitializer initializer,
ClusterDivergenceFunction super ClusterType, ? super DataType> divergenceFunction,
ClusterCreator creator)
{
super(maxIterations);
this.setNumRequestedClusters(numRequestedClusters);
this.setInitializer(initializer);
this.setDivergenceFunction(divergenceFunction);
this.setCreator(creator);
}
@Override
public KMeansClusterer clone()
{
@SuppressWarnings("unchecked")
final KMeansClusterer result
= (KMeansClusterer) super.clone();
result.initializer = ObjectUtil.cloneSmart(this.initializer);
result.divergenceFunction = ObjectUtil.cloneSmart(
this.divergenceFunction);
result.creator = ObjectUtil.cloneSmart(this.creator);
result.clusters = null;
result.assignments = null;
result.clusterCounts = null;
return result;
}
@Override
protected boolean initializeAlgorithm()
{
// Set the cluster state variables.
this.setClusters(this.initializer.initializeClusters(
this.numRequestedClusters, this.getData()));
this.setClusterCounts(new int[this.getNumClusters()]);
this.setAssignments(new int[this.getNumElements()]);
Arrays.fill(this.assignments, -1);
Arrays.fill(this.clusterCounts, 0);
this.setNumChanged(0);
// we can only run k-means if we have at least as many datapoints as
// clusters we are requested to find.
return this.getNumClusters() <= this.getNumElements();
}
/**
* Do a step of the clustering algorithm.
*
* @return true means keep going, false means stop clustering.
*/
@Override
protected boolean step()
{
// First, assign each data point to a cluster, given the current
// location of the clusters
int[] newAssignements = this.assignDataToClusters(this.getData());
int nc = 0;
for (int i = 0; i < newAssignements.length; i++)
{
final int newAssignment = newAssignements[i];
if (this.setAssignment(i, newAssignment))
{
nc++;
}
}
this.setNumChanged(nc);
// There was a change so create the clusters and keep going.
if (this.getNumChanged() > 0)
{
// Now, re-estimate the cluster locations, given the current
// assignments of the data points
this.createClustersFromAssignments();
return true;
}
// If the cluster assignments didn't change, then we're done
else
{
return false;
}
}
@Override
protected void cleanupAlgorithm()
{
}
/**
* Creates the cluster assignments given the current locations of clusters
*
* @param data Data to assign
* @return Assignments of the data to each of the k-clusters
*/
protected int[] assignDataToClusters(
Collection extends DataType> data)
{
// Loop through the elements and find the closest cluster for each.
int i = 0;
int[] localAssignments = new int[ data.size() ];
for (DataType element : data)
{
// Get the i-th element and find the index of the closest cluster
// to it.
localAssignments[i] = this.getClosestClusterIndex(element);
i++;
}
return localAssignments;
}
@Override
public void setData(
Collection extends DataType> data)
{
super.setData(data);
}
/**
* Puts the data into a list of lists for each cluster to then estimate
*
* @return The list of lists for each cluster to then estimate
*/
protected ArrayList> assignDataFromIndices()
{
// Loop through the clusters and initialize their membership lists
// based on who is in them.
int numClusters = this.getNumClusters();
ArrayList> clustersMembers = new ArrayList<>(
numClusters);
for (int i = 0; i < numClusters; i++)
{
int clusterSize = this.clusterCounts[i];
clustersMembers.add(new ArrayList<>(clusterSize));
}
// Go through and add each element to its proper cluster based on
// the current assignments.
int index = 0;
for (DataType element : this.getData())
{
int assignment = this.assignments[index];
clustersMembers.get(assignment).add(element);
index++;
}
return clustersMembers;
}
/**
* Creates the set of clusters using the current cluster assignments.
*/
protected void createClustersFromAssignments()
{
// Loop through the clusters and initialize their membership lists
// based on who is in them.
final ArrayList> clustersMembers
= this.assignDataFromIndices();
// Create the clusters from their memberships.
int clusterIndex = 0;
for (final ArrayList members : clustersMembers)
{
final ClusterType cluster;
if (members.size() > 0)
{
cluster = this.creator.createCluster(members);
}
else
{
cluster = null;
}
this.clusters.set(clusterIndex, cluster);
clusterIndex++;
}
}
/**
* Gets the index of the closest cluster for the given element.
*
* @param element The element to get the closet cluster for.
* @return The index of the closest cluster.
*/
protected int getClosestClusterIndex(
DataType element)
{
// Find the closest cluster.
double minDistance = Double.MAX_VALUE;
int closestClusterIndex = -1;
// Loop over all the clusters.
for (int i = 0; i < this.getNumClusters(); i++)
{
// Get the i-th cluster.
ClusterType cluster = this.clusters.get(i);
if (cluster != null)
{
// Compute the distance to the i-th cluster.
double distance = this.divergenceFunction.evaluate(cluster,
element);
if (closestClusterIndex < 0 || distance < minDistance)
{
// This is the closest so far.
minDistance = distance;
closestClusterIndex = i;
}
// else - There is already a closer cluster.
}
// else - Ignore empty clusters.
}
// Return the index of the closest cluster.
return closestClusterIndex;
}
/**
* Sets the assignment of the given element to the new cluster index,
* updating the cluster counts as well.
*
* @param elementIndex The index of the element.
* @param newClusterIndex The new cluster the element is assigned to.
* @return True if the assignment changed. Otherwise, false.
*/
protected boolean setAssignment(
int elementIndex,
int newClusterIndex)
{
// Save the old assignment.
int oldClusterIndex = this.assignments[elementIndex];
// Set the new assignment.
this.assignments[elementIndex] = newClusterIndex;
if (oldClusterIndex >= 0)
{
// Decrement the counter for the old assignment since the element
// is no longer in that cluster.
this.clusterCounts[oldClusterIndex]--;
}
if (newClusterIndex >= 0)
{
// Increment the counter for the new assignment since the element
// is now in that cluster.
this.clusterCounts[newClusterIndex]++;
}
return newClusterIndex != oldClusterIndex;
}
/**
* Gets the cluster for the given index.
*
* @param index The index of the cluster.
* @return The cluster for the given index.
*/
protected ClusterType getCluster(
int index)
{
return this.clusters.get(index);
}
/**
* Gets the actual number of clusters that were created.
*
* @return The actual number of clusters.
*/
protected int getNumClusters()
{
return (this.getClusters() == null) ? 0 : this.getClusters().size();
}
/**
* Gets the number of clusters that were requested.
*
* @return The number of clusters that were requested.
*/
public int getNumRequestedClusters()
{
return this.numRequestedClusters;
}
/**
* Gets the cluster initializer.
*
* @return The cluster initializer.
*/
public FixedClusterInitializer getInitializer()
{
return this.initializer;
}
/**
* Gets the divergence function used in clustering.
*
* @return The divergence function.
*/
@Override
public ClusterDivergenceFunction super ClusterType, ? super DataType>
getDivergenceFunction()
{
return this.divergenceFunction;
}
/**
* Gets the cluster creator.
*
* @return The cluster creator.
*/
public ClusterCreator getCreator()
{
return this.creator;
}
/**
* Sets the number of requested clusters.
*
* @param numRequestedClusters The number of requested clusters.
*/
public void setNumRequestedClusters(
int numRequestedClusters)
{
if (numRequestedClusters < 0)
{
// Error: Bad number of clusters requested.
throw new IllegalArgumentException(
"The number of clusters cannot be less than zero.");
}
this.numRequestedClusters = numRequestedClusters;
}
/**
* Sets the cluster initializer.
*
* @param initializer The cluster initializer.
*/
public void setInitializer(
FixedClusterInitializer initializer)
{
this.initializer = initializer;
}
/**
* Sets the divergence function.
*
* @param divergenceFunction The divergence function.
*/
public void setDivergenceFunction(
ClusterDivergenceFunction super ClusterType, ? super DataType> divergenceFunction)
{
this.divergenceFunction = divergenceFunction;
}
/**
* Sets the cluster creator.
*
* @param creator The creator for clusters.
*/
public void setCreator(
ClusterCreator creator)
{
this.creator = creator;
}
/**
* Returns the number of elements
*
* @return number of elements being clustered
*/
public int getNumElements()
{
if (this.getData() != null)
{
return this.getData().size();
}
else
{
return 0;
}
}
/**
* Sets the clusters.
*
* @param clusters The clusters.
*/
protected void setClusters(
ArrayList clusters)
{
this.clusters = clusters;
}
/**
* Getter for clusters
*
* @return list of clusters in the algorithm
*/
public ArrayList getClusters()
{
return this.clusters;
}
@Override
public ArrayList getResult()
{
return this.getClusters();
}
/**
* Sets the assignment of elements to clusters.
*
* @param assignments The new assignments.
*/
private void setAssignments(
int[] assignments)
{
this.assignments = assignments;
}
/**
* Getter for assignments
*
* @return The assignment of elements to clusters
*/
protected int[] getAssignments()
{
return this.assignments;
}
/**
* Sets the counts for how many elements are in each cluster.
*
* @param clusterCounts The new cluster counts.
*/
private void setClusterCounts(
int[] clusterCounts)
{
this.clusterCounts = clusterCounts;
}
/**
* Getter for clusterCounts
*
* @return counts for how many elements are assigned to each cluster
*/
protected int[] getClusterCounts()
{
return this.clusterCounts;
}
/**
* Getter for numChanged
*
* @return Returns the number of samples that changed assignment between
* iterations
*/
public int getNumChanged()
{
return this.numChanged;
}
/**
* Setter for numChanged
*
* @param numChanged Returns the number of samples that changed assignment
* between iterations
*/
protected void setNumChanged(
int numChanged)
{
this.numChanged = numChanged;
}
/**
* Gets the performance, which is the number changed on the last iteration.
*
* @return The performance of the algorithm.
*/
@Override
public NamedValue getPerformance()
{
return new DefaultNamedValue<>(
"Assignments changed", this.getNumChanged());
}
}