gov.sandia.cognition.learning.algorithm.clustering.AgglomerativeClusterer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: AgglomerativeClusterer.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright June 28, 2006, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.clustering;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.CodeReviewResponse;
import gov.sandia.cognition.learning.algorithm.clustering.hierarchy.BatchHierarchicalClusterer;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.algorithm.clustering.hierarchy.BinaryClusterHierarchyNode;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.Cluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.ClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.hierarchy.ClusterHierarchyNode;
import gov.sandia.cognition.learning.algorithm.clustering.hierarchy.DefaultClusterHierarchyNode;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.ClusterToClusterDivergenceFunction;
import gov.sandia.cognition.learning.function.distance.DivergenceFunctionContainer;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
/**
* The {@code AgglomerativeClusterer} implements an agglomerative clustering
* algorithm, which is a type of hierarchical clustering algorithm.
* Such a clustering algorithm works by initially creating one
* cluster for each element in the collection to cluster and then
* repeatedly merging the two closest clusters until the stopping
* condition is met or there is only one cluster remaining. This
* implementation supports multiple methods for determining the
* distance between two clusters by supplying an
* {@code ClusterToClusterDivergenceFunction} object. There are two stopping
* conditions for the algorithm, which are parameters that can be set. The first
* is that the clustering will stop when some minimum number of
* clusters is reached, which defaults to 1. The second criteria is
* that the clustering will stop when the distance between the two
* closest clusters is larger than a given value. This threshold can
* be used to create clusters when the number of clusters is not
* known ahead of time.
*
* @param The type of the data to cluster. This is typically
* defined by the divergence function used.
* @param The type of {@code Cluster} created by the algorithm.
* This is typically defined by the cluster creator function used.
* @author Justin Basilico
* @since 1.0
*/
@CodeReview(
reviewer="Kevin R. Dixon",
date="2008-07-22",
changesNeeded=true,
comments={
"I *really* don't like the use of 'continue', but I will defer.",
"Please implement the sections previously marked as 'to do'"
},
response=@CodeReviewResponse(
respondent="Justin Basilico",
date="2008-10-07",
moreChangesNeeded=false,
comments="The clusterer now supports hierarchical clustering."
)
)
public class AgglomerativeClusterer
>
extends AbstractAnytimeBatchLearner
, Collection>
implements BatchClusterer,
BatchHierarchicalClusterer,
DivergenceFunctionContainer
{
/** The default minimum number of clusters is {@value}. */
public static final int DEFAULT_MIN_NUM_CLUSTERS = 1;
/** The default maximum minimum distance is {@value}. */
public static final double DEFAULT_MAX_MIN_DISTANCE = Double.MAX_VALUE;
/** The default maximum number of iterations {@value} */
public static final int DEFAULT_MAX_ITERATIONS = Integer.MAX_VALUE;
/**
* The divergence function used to find the distance between two clusters.
*/
protected ClusterToClusterDivergenceFunction super ClusterType, ? super DataType>
divergenceFunction;
/** The merger used to merge two clusters into one element. */
protected ClusterCreator creator;
/** The minimum number of clusters allowed. */
protected int minNumClusters;
/** The maximum minimum distance between clusters allowed. */
protected double maxMinDistance;
/** The current set of clusters. */
protected ArrayList clusters;
/** The current set of hierarchical clusters. */
protected ArrayList>
clustersHierarchy;
/**
* An array list mapping the cached minimum distance from the cluster with
* the given index to any other clusters.
*/
protected transient ArrayList minDistances;
/**
* The array of indexes that maps the cluster index to the closest cluster.
*/
protected transient ArrayList minClusters;
/**
* Creates a new instance of AgglomerativeClusterer.
*/
public AgglomerativeClusterer()
{
this(null, null);
}
/**
* Initializes the clustering to use the given metric between
* clusters, and the given cluster creator. The minimum number of
* clusters will be set to 1.
*
* @param divergenceFunction The distance metric between clusters.
* @param creator The method for creating clusters.
*/
public AgglomerativeClusterer(
final ClusterToClusterDivergenceFunction super ClusterType, ? super DataType>
divergenceFunction,
final ClusterCreator creator)
{
this(divergenceFunction, creator, DEFAULT_MIN_NUM_CLUSTERS);
}
/**
* Initializes the clustering to use the given metric between
* clusters, the given cluster creator, and the minimum number of
* clusters to allow.
*
* @param divergenceFunction The distance metric between clusters.
* @param creator The method for creating clusters.
* @param minNumClusters The minimum number of clusters to allow. Must
* be greater than zero.
*/
public AgglomerativeClusterer(
final ClusterToClusterDivergenceFunction super ClusterType, ? super DataType>
divergenceFunction,
ClusterCreator creator,
int minNumClusters)
{
this(divergenceFunction, creator, minNumClusters,
DEFAULT_MAX_MIN_DISTANCE);
}
/**
* Initializes the clustering to use the given metric between
* clusters, the given cluster merger, and the maximum minimum
* distance between clusters to allow.
*
* @param divergenceFunction The distance metric between clusters.
* @param creator The method for creating clusters.
* @param maxMinDistance The maximum minimum distance between clusters
* to allow.
*/
public AgglomerativeClusterer(
ClusterToClusterDivergenceFunction super ClusterType, ? super DataType>
divergenceFunction,
ClusterCreator creator,
double maxMinDistance)
{
this(divergenceFunction, creator, 1, maxMinDistance);
}
/**
* Initializes the clustering to use the given metric between
* clusters, the given cluster merger, the minimum number of
* clusters to allow, and the maximum minimum distance between
* clusters to allow.
*
* @param divergenceFunction The distance metric between clusters.
* @param creator The method for creating clusters.
* @param minNumClusters The minimum number of clusters to allow. Must
* be greater than zero.
* @param maxMinDistance The maximum minimum distance between clusters
* to allow.
*/
public AgglomerativeClusterer(
ClusterToClusterDivergenceFunction super ClusterType, ? super DataType>
divergenceFunction,
ClusterCreator creator,
int minNumClusters,
double maxMinDistance)
{
super(DEFAULT_MAX_ITERATIONS);
this.setDivergenceFunction(divergenceFunction);
this.setCreator(creator);
this.setMinNumClusters(minNumClusters);
this.setMaxMinDistance(maxMinDistance);
this.setClusters(null);
this.setClustersHierarchy(null);
this.setMinDistances(null);
this.setMinClusters(null);
}
@Override
public AgglomerativeClusterer clone()
{
@SuppressWarnings("unchecked")
final AgglomerativeClusterer result =
(AgglomerativeClusterer) super.clone();
result.divergenceFunction = ObjectUtil.cloneSmart(this.divergenceFunction);
result.creator = ObjectUtil.cloneSmart(this.creator);
result.clusters = null;
result.clustersHierarchy = null;
result.minDistances = null;
result.minClusters = null;
return result;
}
public ClusterHierarchyNode clusterHierarchically(
Collection extends DataType> data)
{
// Turn off the stopping criteria to do with the minimum number of
// clusters or the maximum minimum distance.
final int tempMinNumClusters = this.getMinNumClusters();
final double tempMaxMinDistance = this.getMaxMinDistance();
this.setMinNumClusters(1);
this.setMaxMinDistance(Double.MAX_VALUE);
this.learn(data);
this.setMinNumClusters(tempMinNumClusters);
this.setMaxMinDistance(tempMaxMinDistance);
if (CollectionUtil.isEmpty(this.clustersHierarchy))
{
// No clusters.
return null;
}
else if (this.clustersHierarchy.size() == 1)
{
// Get the root of the hierarchy.
return this.clustersHierarchy.get(0);
}
else
{
// This should really never happen, but it is possible that
// clustering got stopped early. If that is the case, we bind
// together all the clusters into one root node.
final DefaultClusterHierarchyNode root =
new DefaultClusterHierarchyNode();
// Set the children.
root.setChildren(
new ArrayList>(
this.clustersHierarchy));
return root;
}
}
protected boolean initializeAlgorithm()
{
// Create the arrays to store the cluster information.
int numElements = this.data.size();
// Initialize our data structures.
this.setClusters(new ArrayList(numElements));
this.setClustersHierarchy(
new ArrayList>(
numElements));
this.setMinDistances(new ArrayList(numElements));
this.setMinClusters(new ArrayList(numElements));
// Initialize one cluster for each element.
for (DataType element : this.data)
{
// Create the cluster object for the element.
LinkedList singleton = new LinkedList();
singleton.add(element);
ClusterType cluster = this.creator.createCluster(singleton);
// Add the cluster.
this.clusters.add(cluster);
this.clustersHierarchy.add(
new HierarchyNode(cluster));
this.minDistances.add(Double.MAX_VALUE);
this.minClusters.add(-1);
}
// Initialize the minimum distance calculation for each cluster.
for (int i = 0; i < this.getNumClusters(); i++)
{
this.updateMinDistance(i);
}
return true;
}
protected boolean step()
{
if (this.getNumClusters() <= this.minNumClusters)
{
// Make sure we haven't violated the minimum number of clusters.
return false;
}
// Find the two clusters that are the closest together.
double minDistance = Double.MAX_VALUE;
int minFrom = -1;
int minTo = -1;
for (int i = 0; i < this.getNumClusters(); i++)
{
// Get the distance to the closest cluster for this
// cluster.
// ClusterType cluster = this.clusters.get(i);
double distance = (double) minDistances.get(i);
if (minFrom < 0 || distance < minDistance)
{
// This is the smallest distance seen so far so store
// the information about the cluster.
minDistance = distance;
minFrom = i;
minTo = this.minClusters.get(i);
}
}
if (minDistance > this.maxMinDistance)
{
// The minimum distance clusters are too far apart
// to merge.
return false;
}
// Merge the two clusters into one.
int mergedIndex = this.mergeClusters(minFrom, minTo, minDistance);
ClusterType merged = this.clusters.get(mergedIndex);
// Update the minimum distance for the merged cluster.
this.updateMinDistance(mergedIndex);
// Update the cached minimum distances for the other
// clusters.
for (int i = 0; i < this.getNumClusters(); i++)
{
ClusterType other = this.clusters.get(i);
if (other == merged)
{
// Don't update the new cluster.
continue;
}
// Get the current minimum cluster.
int minClusterIndex = this.minClusters.get(i);
// ClusterType minCluster = this.clusters.get(minClusterIndex);
if (minClusterIndex == minTo || minClusterIndex == minFrom)
{
// The minimum distance was to the cluster we just
// merged, so we need to do a complete update on the
// distances for it.
this.updateMinDistance(i);
}
else
{
// Get the current minimum distance.
double distance =
this.divergenceFunction.evaluate(other, merged);
if (distance < (double) minDistances.get(i))
{
// The new cluster is the closest.
this.minDistances.set(i, distance);
this.minClusters.set(i, mergedIndex);
}
}
}
return this.getNumClusters() > this.minNumClusters;
}
protected void cleanupAlgorithm()
{
this.setMinDistances(null);
this.setMinClusters(null);
}
/**
* Updates the cached minimum distance for this cluster by
* comparing it to all the other clusters.
*
* @param index The cluster to update.
*/
protected void updateMinDistance(
int index)
{
// Search for the closest cluster to this cluster.
ClusterType cluster = this.clusters.get(index);
double minDistance = Double.MAX_VALUE;
int minCluster = -1;
for (int i = 0; i < this.getNumClusters(); i++)
{
ClusterType other = this.clusters.get(i);
if (cluster == other)
{
// Don't compute the distance to self, since it will be
// zero for a valid distance metric.
continue;
}
// Compute the distance.
double distance = this.divergenceFunction.evaluate(cluster, other);
if (minCluster < 0 || distance < minDistance)
{
// This is the closest one found so far so save it.
minDistance = distance;
minCluster = i;
}
}
// Save the closest cluster found to this one.
this.minDistances.set(index, minDistance);
this.minClusters.set(index, minCluster);
}
/**
* Merges two clusters together, creating a new BinaryTreeCluster
* and updating the internal state.
*
* @param firstIndex The first cluster.
* @param secondIndex The second cluster.
* @param distance The distance between the clusters.
* @return The new, merged cluster.
*/
protected int mergeClusters(
int firstIndex,
int secondIndex,
double distance)
{
// Get the two clusters.
ClusterType first = this.clusters.get(firstIndex);
ClusterType second = this.clusters.get(secondIndex);
// Figure out the larger and smaller indices of the two given
// clusters.
int minIndex = Math.min(firstIndex, secondIndex);
int maxIndex = Math.max(firstIndex, secondIndex);
// Create a list of all the members of the clusters.
ArrayList members = new ArrayList();
members.addAll(first.getMembers());
members.addAll(second.getMembers());
// If we have the ability to merge the clusters, merge them.
ClusterType merged = this.creator.createCluster(members);
// Create the new parent cluster for the two that are merged.
HierarchyNode firstChild =
this.clustersHierarchy.get(firstIndex);
HierarchyNode secondChild =
this.clustersHierarchy.get(secondIndex);
HierarchyNode parent =
new HierarchyNode(
merged, firstChild, secondChild, distance);
// Move the cluster at the end of the list to the larger index of
// the two clusters to merge in order to remove one element from
// the list.
int endClusterNum = this.clusters.size() - 1;
if (endClusterNum != maxIndex)
{
ClusterType endCluster = this.clusters.get(endClusterNum);
this.clusters.set(maxIndex, endCluster);
this.minDistances.set(maxIndex, this.minDistances.get(endClusterNum));
this.minClusters.set(maxIndex, this.minClusters.get(endClusterNum));
// TODO: Make the minClusters array not store an index but instead a pointer so
// that we don't have to do this update step.
// Move all the pointers to the end cluster.
for (int i = 0; i < this.getNumClusters(); i++)
{
if (endClusterNum == this.minClusters.get(i))
{
this.minClusters.set(i, maxIndex);
}
}
}
// else - The end cluster is the one we are removing.
// Store the information about the parent.
int newIndex = minIndex;
this.clusters.set(newIndex, merged);
this.clustersHierarchy.set(newIndex, parent);
this.minDistances.set(newIndex, Double.MAX_VALUE);
this.minClusters.set(newIndex, null);
// Remove the last element from the list.
this.clusters.remove(endClusterNum);
this.clustersHierarchy.remove(endClusterNum);
this.minDistances.remove(endClusterNum);
this.minClusters.remove(endClusterNum);
// Return the new cluster that we just created.
return newIndex;
}
public ArrayList getResult()
{
return this.clusters;
}
/**
* Gets the number of clusters.
*
* @return The number of clusters.
*/
public int getNumClusters()
{
if (this.clusters == null)
{
return 0;
}
else
{
return this.clusters.size();
}
}
/**
* Gets the divergence function used in clustering.
*
* @return The divergence function.
*/
public ClusterToClusterDivergenceFunction super ClusterType, ? super DataType>
getDivergenceFunction()
{
return this.divergenceFunction;
}
/**
* Sets the divergence function.
*
* @param divergenceFunction The divergence function.
*/
public void setDivergenceFunction(
ClusterToClusterDivergenceFunction super ClusterType, ? super DataType>
divergenceFunction)
{
this.divergenceFunction = divergenceFunction;
}
/**
* Gets the cluster creator.
*
* @return The cluster creator.
*/
public ClusterCreator getCreator()
{
return this.creator;
}
/**
* Sets the cluster creator.
*
* @param creator The creator for clusters.
*/
public void setCreator(
ClusterCreator creator)
{
this.creator = creator;
}
/**
* The minimum number of clusters to allow. To create a cluster tree,
* set this value to 1. If the number of clusters drops to this number
* (or below) then the clustering will stop. Must be greater than
* zero.
*
* @return The minimum number of clusters allowed.
*/
public int getMinNumClusters()
{
return this.minNumClusters;
}
/**
* The minimum number of clusters to allow. To create a cluster tree,
* set this value to 1. If the number of clusters drops to this number
* (or below) then the clustering will stop. Must be greater than
* zero.
*
* @param minNumClusters The new minimum number of clusters.
*/
public void setMinNumClusters(
int minNumClusters)
{
this.minNumClusters = Math.max(1, minNumClusters);
}
/**
* The maximum minimum distance between clusters that is allowed
* for the two clusters to be merged. If there are no clusters
* that remain that have a distance between them less than or
* equal to this value, then the clustering will halt. To not
* have this value factored into the clustering, set it to
* something such as Double.MAX_VALUE.
*
* @return The maximum minimum distance between clusters.
*/
public double getMaxMinDistance()
{
return this.maxMinDistance;
}
/**
* The maximum minimum distance between clusters that is allowed
* for the two clusters to be merged. If there are no clusters
* that remain that have a distance between them less than or
* equal to this value, then the clustering will halt. To not
* have this value factored into the clustering, set it to
* something such as Double.MAX_VALUE.
*
* @param maxMinDistance The new maximum minimum distance.
*/
public void setMaxMinDistance(
double maxMinDistance)
{
this.maxMinDistance = maxMinDistance;
}
/**
* Sets the clusters.
*
* @param clusters The clusters.
*/
protected void setClusters(
ArrayList clusters)
{
this.clusters = clusters;
}
/**
* Gets the hierarchy of clusters.
*
* @return The hierarchy of clusters.
*/
public ArrayList>
getClustersHierarchy()
{
return clustersHierarchy;
}
/**
* Sets the hierarchy of clusters.
*
* @param clustersHierarchy The hierarchy of clusters.
*/
protected void setClustersHierarchy(
final ArrayList>
clustersHierarchy)
{
this.clustersHierarchy = clustersHierarchy;
}
/**
* Sets the minimum distances for each clusters.
*
* @param minDistances The array of minimum distances.
*/
protected void setMinDistances(
ArrayList minDistances)
{
this.minDistances = minDistances;
}
/**
* Sets the index of the closest cluster.
*
* @param minClusters The array of closest cluster indices.
*/
protected void setMinClusters(
ArrayList minClusters)
{
this.minClusters = minClusters;
}
/**
* Holds the hierarchy information for the agglomerative clusterer. It
* is a binary node that also keeps track of the divergence between
* children.
*
* @param
* The type of the data being clustered.
* @param
* The type of the clusters being created.
*/
public static class HierarchyNode>
extends BinaryClusterHierarchyNode
{
/** The divergence between the two children, if they exist. */
protected double childrenDivergence;
/**
* Creates a new {@code HierarchyNode}.
*/
public HierarchyNode()
{
this(null);
}
/**
* Creates a new {@code HierarchyNode}.
*
* @param cluster
* The cluster associated with the node.
*/
public HierarchyNode(
final ClusterType cluster)
{
this(cluster, null, null, 0.0);
}
/**
* Creates a new {@code HierarchyNode}.
*
* @param cluster
* The cluster associated with the node.
* @param firstChild
* The first child.
* @param secondChild
* The second child.
* @param childrenDivergence
* The divergence between the children.
*/
public HierarchyNode(
final ClusterType cluster,
final HierarchyNode firstChild,
final HierarchyNode secondChild,
final double childrenDivergence)
{
super(cluster, firstChild, secondChild);
this.setChildrenDivergence(childrenDivergence);
}
/**
* Gets the divergence between the two children, if they exist;
* otherwise, 0.0.
*
* @return The divergence between the two children, if they exist.
*/
public double getChildrenDivergence()
{
return this.childrenDivergence;
}
/**
* Sets the divergence between the two children.
*
* @param childrenDivergence
* The divergence between the two children.
*/
public void setChildrenDivergence(
final double childrenDivergence)
{
this.childrenDivergence = childrenDivergence;
}
}
}