All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.clustering.KMeansClusterer Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                KMeansClusterer.java
 * Authors:             Justin Basilico and Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright February 21, 2006, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.clustering;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.CodeReviews;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.Cluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.ClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.ClusterDivergenceFunction;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.FixedClusterInitializer;
import gov.sandia.cognition.learning.function.distance.DivergenceFunctionContainer;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;

/**
 * The {@code KMeansClusterer} class implements the standard k-means
 * (k-centroids) clustering algorithm.
 *
 * @param <DataType> The type of the data to cluster. This is typically
 * defined by the divergence function used.
 * @param <ClusterType> The type of {@code Cluster} created by the algorithm.
 * This is typically defined by the cluster creator function used.
 * @author Justin Basilico
 * @author Kevin R. Dixon
 * @since 1.0
 */
@CodeReviews(
    reviews =
    {
        @CodeReview(
            reviewer = "Kevin R. Dixon",
            date = "2008-10-06",
            changesNeeded = true,
            comments =
            {
                "The constructors for this class are not user friendly.",
                "I've been trying to write a test GUI for k-means for over an hour and STILL can't figure out the combination of classes to configure the constructor.",
                "Please make a constructor that configures the class with meaningful, user-friendly default arguments."
            }
        ),
        @CodeReview(
            reviewer = "Kevin R. Dixon",
            date = "2008-07-22",
            changesNeeded = false,
            comments =
            {
                "Changed the condition to be 'members.size() > 0' instead of 1 in createClustersFromAssignments()",
                "Cleaned up javadoc.",
                "Code generally looks fine."
            }
        )
    }
)
@PublicationReferences(
    references =
    {
        @PublicationReference(
            author = "Wikipedia",
            title = "K-means algorithm",
            type = PublicationType.WebPage,
            year = 2008,
            url = "http://en.wikipedia.org/wiki/K-means_algorithm"
        ),
        @PublicationReference(
            author = "Matteo Matteucci",
            title = "A Tutorial on Clustering Algorithms: k-means Demo",
            type = PublicationType.WebPage,
            year = 2008,
            url
            = "http://home.dei.polimi.it/matteucc/Clustering/tutorial_html/AppletKM.html"
        )
    }
)
public class KMeansClusterer>
    extends AbstractAnytimeBatchLearner, Collection>
    implements BatchClusterer,
    MeasurablePerformanceAlgorithm,
    DivergenceFunctionContainer
{

    /**
     * The default number of requested clusters is {@value}.
     */
    public static final int DEFAULT_NUM_REQUESTED_CLUSTERS = 10;

    /**
     * The default maximum number of iterations is {@value}.
     */
    public static final int DEFAULT_MAX_ITERATIONS = 1000;

    /**
     * The number of clusters requested.
     */
    protected int numRequestedClusters;

    /**
     * The initializer for the algorithm.
     */
    protected FixedClusterInitializer initializer;

    /**
     * The divergence function between cluster being used.
     */
    protected ClusterDivergenceFunction divergenceFunction;

    /**
     * The cluster creator for creating clusters.
     */
    private ClusterCreator creator;

    /**
     * The current set of clusters.
     */
    protected ArrayList clusters;

    /**
     * The current assignments of elements to clusters.
     */
    protected int[] assignments;

    /**
     * The current number of elements assigned to each cluster.
     */
    protected int[] clusterCounts;

    /**
     * Returns the number of samples that changed assignment between iterations
     */
    private int numChanged;

    /**
     * Creates a new instance of {@code KMeansClusterer} with default
     * parameters.
     */
    public KMeansClusterer()
    {
        this(DEFAULT_NUM_REQUESTED_CLUSTERS, DEFAULT_MAX_ITERATIONS,
            null, null, null);
    }

    /**
     * Creates a new instance of KMeansClusterer using the given parameters.
     *
     * @param numRequestedClusters The number of clusters requested (k).
     * @param maxIterations Maximum number of iterations before stopping
     * @param initializer The initializer for the clusters.
     * @param divergenceFunction The divergence function.
     * @param creator The cluster creator.
     */
    public KMeansClusterer(
        int numRequestedClusters,
        int maxIterations,
        FixedClusterInitializer initializer,
        ClusterDivergenceFunction divergenceFunction,
        ClusterCreator creator)
    {
        super(maxIterations);

        this.setNumRequestedClusters(numRequestedClusters);
        this.setInitializer(initializer);
        this.setDivergenceFunction(divergenceFunction);
        this.setCreator(creator);
    }

    @Override
    public KMeansClusterer clone()
    {
        @SuppressWarnings("unchecked")
        final KMeansClusterer result
            = (KMeansClusterer) super.clone();
        result.initializer = ObjectUtil.cloneSmart(this.initializer);
        result.divergenceFunction = ObjectUtil.cloneSmart(
            this.divergenceFunction);
        result.creator = ObjectUtil.cloneSmart(this.creator);

        result.clusters = null;
        result.assignments = null;
        result.clusterCounts = null;

        return result;

    }

    @Override
    protected boolean initializeAlgorithm()
    {
        // Set the cluster state variables.
        this.setClusters(this.initializer.initializeClusters(
            this.numRequestedClusters, this.getData()));
        this.setClusterCounts(new int[this.getNumClusters()]);

        this.setAssignments(new int[this.getNumElements()]);
        Arrays.fill(this.assignments, -1);
        Arrays.fill(this.clusterCounts, 0);

        this.setNumChanged(0);

        // we can only run k-means if we have at least as many datapoints as
        // clusters we are requested to find.
        return this.getNumClusters() <= this.getNumElements();
    }

    /**
     * Do a step of the clustering algorithm.
     *
     * @return true means keep going, false means stop clustering.
     */
    @Override
    protected boolean step()
    {
        // First, assign each data point to a cluster, given the current
        // location of the clusters
        int[] newAssignements = this.assignDataToClusters(this.getData());
        int nc = 0;
        for (int i = 0; i < newAssignements.length; i++)
        {
            final int newAssignment = newAssignements[i];
            if (this.setAssignment(i, newAssignment))
            {
                nc++;
            }
        }

        this.setNumChanged(nc);

        // There was a change so create the clusters and keep going.
        if (this.getNumChanged() > 0)
        {
            // Now, re-estimate the cluster locations, given the current
            // assignments of the data points
            this.createClustersFromAssignments();
            return true;
        }
        // If the cluster assignments didn't change, then we're done
        else
        {
            return false;
        }

    }

    @Override
    protected void cleanupAlgorithm()
    {
    }

    /**
     * Creates the cluster assignments given the current locations of clusters
     *
     * @param data Data to assign
     * @return Assignments of the data to each of the k-clusters
     */
    protected int[] assignDataToClusters(
        Collection data)
    {
        // Loop through the elements and find the closest cluster for each.
        int i = 0;
        int[] localAssignments = new int[ data.size() ];
        for (DataType element : data)
        {
            // Get the i-th element and find the index of the closest cluster
            // to it.
            localAssignments[i] = this.getClosestClusterIndex(element);
            i++;
        }
        
        return localAssignments;
        
    }

    @Override
    public void setData(
        Collection data)
    {
        super.setData(data);
    }

    /**
     * Puts the data into a list of lists for each cluster to then estimate
     *
     * @return The list of lists for each cluster to then estimate
     */
    protected ArrayList> assignDataFromIndices()
    {
        // Loop through the clusters and initialize their membership lists
        // based on who is in them.
        int numClusters = this.getNumClusters();
        ArrayList> clustersMembers = new ArrayList<>(
            numClusters);
        for (int i = 0; i < numClusters; i++)
        {
            int clusterSize = this.clusterCounts[i];
            clustersMembers.add(new ArrayList<>(clusterSize));
        }

        // Go through and add each element to its proper cluster based on
        // the current assignments.
        int index = 0;
        for (DataType element : this.getData())
        {
            int assignment = this.assignments[index];
            clustersMembers.get(assignment).add(element);
            index++;
        }

        return clustersMembers;

    }

    /**
     * Creates the set of clusters using the current cluster assignments.
     */
    protected void createClustersFromAssignments()
    {
        // Loop through the clusters and initialize their membership lists
        // based on who is in them.
        final ArrayList> clustersMembers
            = this.assignDataFromIndices();

        // Create the clusters from their memberships.
        int clusterIndex = 0;
        for (final ArrayList members : clustersMembers)
        {
            final ClusterType cluster;
            if (members.size() > 0)
            {
                cluster = this.creator.createCluster(members);
            }
            else
            {
                cluster = null;
            }
            this.clusters.set(clusterIndex, cluster);
            clusterIndex++;
        }

    }

    /**
     * Gets the index of the closest cluster for the given element.
     *
     * @param element The element to get the closet cluster for.
     * @return The index of the closest cluster.
     */
    protected int getClosestClusterIndex(
        DataType element)
    {
        // Find the closest cluster.
        double minDistance = Double.MAX_VALUE;
        int closestClusterIndex = -1;

        // Loop over all the clusters.
        for (int i = 0; i < this.getNumClusters(); i++)
        {
            // Get the i-th cluster.
            ClusterType cluster = this.clusters.get(i);

            if (cluster != null)
            {
                // Compute the distance to the i-th cluster.
                double distance = this.divergenceFunction.evaluate(cluster,
                    element);

                if (closestClusterIndex < 0 || distance < minDistance)
                {
                    // This is the closest so far.
                    minDistance = distance;
                    closestClusterIndex = i;
                }
                // else - There is already a closer cluster.
            }
            // else - Ignore empty clusters.
        }

        // Return the index of the closest cluster.
        return closestClusterIndex;
    }

    /**
     * Sets the assignment of the given element to the new cluster index,
     * updating the cluster counts as well.
     *
     * @param elementIndex The index of the element.
     * @param newClusterIndex The new cluster the element is assigned to.
     * @return True if the assignment changed. Otherwise, false.
     */
    protected boolean setAssignment(
        int elementIndex,
        int newClusterIndex)
    {
        // Save the old assignment.
        int oldClusterIndex = this.assignments[elementIndex];

        // Set the new assignment.
        this.assignments[elementIndex] = newClusterIndex;

        if (oldClusterIndex >= 0)
        {
            // Decrement the counter for the old assignment since the element
            // is no longer in that cluster.
            this.clusterCounts[oldClusterIndex]--;
        }

        if (newClusterIndex >= 0)
        {
            // Increment the counter for the new assignment since the element
            // is now in that cluster.
            this.clusterCounts[newClusterIndex]++;
        }

        return newClusterIndex != oldClusterIndex;
    }

    /**
     * Gets the cluster for the given index.
     *
     * @param index The index of the cluster.
     * @return The cluster for the given index.
     */
    protected ClusterType getCluster(
        int index)
    {
        return this.clusters.get(index);
    }

    /**
     * Gets the actual number of clusters that were created.
     *
     * @return The actual number of clusters.
     */
    protected int getNumClusters()
    {
        return (this.getClusters() == null) ? 0 : this.getClusters().size();
    }

    /**
     * Gets the number of clusters that were requested.
     *
     * @return The number of clusters that were requested.
     */
    public int getNumRequestedClusters()
    {
        return this.numRequestedClusters;
    }

    /**
     * Gets the cluster initializer.
     *
     * @return The cluster initializer.
     */
    public FixedClusterInitializer getInitializer()
    {
        return this.initializer;
    }

    /**
     * Gets the divergence function used in clustering.
     *
     * @return The divergence function.
     */
    @Override
    public ClusterDivergenceFunction
        getDivergenceFunction()
    {
        return this.divergenceFunction;
    }

    /**
     * Gets the cluster creator.
     *
     * @return The cluster creator.
     */
    public ClusterCreator getCreator()
    {
        return this.creator;
    }

    /**
     * Sets the number of requested clusters.
     *
     * @param numRequestedClusters The number of requested clusters.
     */
    public void setNumRequestedClusters(
        int numRequestedClusters)
    {
        if (numRequestedClusters < 0)
        {
            // Error: Bad number of clusters requested.
            throw new IllegalArgumentException(
                "The number of clusters cannot be less than zero.");
        }

        this.numRequestedClusters = numRequestedClusters;
    }

    /**
     * Sets the cluster initializer.
     *
     * @param initializer The cluster initializer.
     */
    public void setInitializer(
        FixedClusterInitializer initializer)
    {
        this.initializer = initializer;
    }

    /**
     * Sets the divergence function.
     *
     * @param divergenceFunction The divergence function.
     */
    public void setDivergenceFunction(
        ClusterDivergenceFunction divergenceFunction)
    {
        this.divergenceFunction = divergenceFunction;
    }

    /**
     * Sets the cluster creator.
     *
     * @param creator The creator for clusters.
     */
    public void setCreator(
        ClusterCreator creator)
    {
        this.creator = creator;
    }

    /**
     * Returns the number of elements
     *
     * @return number of elements being clustered
     */
    public int getNumElements()
    {
        if (this.getData() != null)
        {
            return this.getData().size();
        }
        else
        {
            return 0;
        }
    }

    /**
     * Sets the clusters.
     *
     * @param clusters The clusters.
     */
    protected void setClusters(
        ArrayList clusters)
    {
        this.clusters = clusters;
    }

    /**
     * Getter for clusters
     *
     * @return list of clusters in the algorithm
     */
    public ArrayList getClusters()
    {
        return this.clusters;
    }

    @Override
    public ArrayList getResult()
    {
        return this.getClusters();
    }

    /**
     * Sets the assignment of elements to clusters.
     *
     * @param assignments The new assignments.
     */
    private void setAssignments(
        int[] assignments)
    {
        this.assignments = assignments;
    }

    /**
     * Getter for assignments
     *
     * @return The assignment of elements to clusters
     */
    protected int[] getAssignments()
    {
        return this.assignments;
    }

    /**
     * Sets the counts for how many elements are in each cluster.
     *
     * @param clusterCounts The new cluster counts.
     */
    private void setClusterCounts(
        int[] clusterCounts)
    {
        this.clusterCounts = clusterCounts;
    }

    /**
     * Getter for clusterCounts
     *
     * @return counts for how many elements are assigned to each cluster
     */
    protected int[] getClusterCounts()
    {
        return this.clusterCounts;
    }

    /**
     * Getter for numChanged
     *
     * @return Returns the number of samples that changed assignment between
     * iterations
     */
    public int getNumChanged()
    {
        return this.numChanged;
    }

    /**
     * Setter for numChanged
     *
     * @param numChanged Returns the number of samples that changed assignment
     * between iterations
     */
    protected void setNumChanged(
        int numChanged)
    {
        this.numChanged = numChanged;
    }

    /**
     * Gets the performance, which is the number changed on the last iteration.
     *
     * @return The performance of the algorithm.
     */
    @Override
    public NamedValue getPerformance()
    {
        return new DefaultNamedValue<>(
            "Assignments changed", this.getNumChanged());
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy