All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.uci.ics.jung.algorithms.util.KMeansClusterer Maven / Gradle / Ivy

There is a newer version: 2.1.1
Show newest version
/*
 * Copyright (c) 2003, The JUNG Authors
 *
 * All rights reserved.
 *
 * This software is open-source under the BSD license; see either
 * "license.txt" or
 * https://github.com/jrtom/jung/blob/master/LICENSE for a description.
 */
/*
 * Created on Aug 9, 2004
 *
 */
package edu.uci.ics.jung.algorithms.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Set;



/**
 * Groups items into a specified number of clusters, based on their proximity in
 * d-dimensional space, using the k-means algorithm. Calls to
 * cluster will terminate when either of the two following
 * conditions is true:
 * 
    *
  • the number of iterations is > max_iterations *
  • none of the centroids has moved as much as convergence_threshold * since the previous iteration *
* * @author Joshua O'Madadhain */ public class KMeansClusterer { protected int max_iterations; protected double convergence_threshold; protected Random rand; /** * Creates an instance which will terminate when either the maximum number of * iterations has been reached, or all changes are smaller than the convergence threshold. * @param max_iterations the maximum number of iterations to employ * @param convergence_threshold the smallest change we want to track */ public KMeansClusterer(int max_iterations, double convergence_threshold) { this.max_iterations = max_iterations; this.convergence_threshold = convergence_threshold; this.rand = new Random(); } /** * Creates an instance with max iterations of 100 and convergence threshold * of 0.001. */ public KMeansClusterer() { this(100, 0.001); } /** * @return the maximum number of iterations */ public int getMaxIterations() { return max_iterations; } /** * @param max_iterations the maximum number of iterations */ public void setMaxIterations(int max_iterations) { if (max_iterations < 0) throw new IllegalArgumentException("max iterations must be >= 0"); this.max_iterations = max_iterations; } /** * @return the convergence threshold */ public double getConvergenceThreshold() { return convergence_threshold; } /** * @param convergence_threshold the convergence threshold */ public void setConvergenceThreshold(double convergence_threshold) { if (convergence_threshold <= 0) throw new IllegalArgumentException("convergence threshold " + "must be > 0"); this.convergence_threshold = convergence_threshold; } /** * Returns a Collection of clusters, where each cluster is * represented as a Map of Objects to locations * in d-dimensional space. * @param object_locations a map of the items to cluster, to * double arrays that specify their locations in d-dimensional space. * @param num_clusters the number of clusters to create * @return a clustering of the input objects in d-dimensional space * @throws NotEnoughClustersException if {@code num_clusters} is larger than the number of * distinct points in object_locations */ @SuppressWarnings("unchecked") public Collection> cluster(Map object_locations, int num_clusters) { if (object_locations == null || object_locations.isEmpty()) throw new IllegalArgumentException("'objects' must be non-empty"); if (num_clusters < 2 || num_clusters > object_locations.size()) throw new IllegalArgumentException("number of clusters " + "must be >= 2 and <= number of objects (" + object_locations.size() + ")"); Set centroids = new HashSet(); Object[] obj_array = object_locations.keySet().toArray(); Set tried = new HashSet(); // create the specified number of clusters while (centroids.size() < num_clusters && tried.size() < object_locations.size()) { T o = (T)obj_array[(int)(rand.nextDouble() * obj_array.length)]; tried.add(o); double[] mean_value = object_locations.get(o); boolean duplicate = false; for (double[] cur : centroids) { if (Arrays.equals(mean_value, cur)) duplicate = true; } if (!duplicate) centroids.add(mean_value); } if (tried.size() >= object_locations.size()) throw new NotEnoughClustersException(); // put items in their initial clusters Map> clusterMap = assignToClusters(object_locations, centroids); // keep reconstituting clusters until either // (a) membership is stable, or // (b) number of iterations passes max_iterations, or // (c) max movement of any centroid is <= convergence_threshold int iterations = 0; double max_movement = Double.POSITIVE_INFINITY; while (iterations++ < max_iterations && max_movement > convergence_threshold) { max_movement = 0; Set new_centroids = new HashSet(); // calculate new mean for each cluster for (Map.Entry> entry : clusterMap.entrySet()) { double[] centroid = entry.getKey(); Map elements = entry.getValue(); ArrayList locations = new ArrayList(elements.values()); double[] mean = DiscreteDistribution.mean(locations); max_movement = Math.max(max_movement, Math.sqrt(DiscreteDistribution.squaredError(centroid, mean))); new_centroids.add(mean); } // TODO: check membership of clusters: have they changed? // regenerate cluster membership based on means clusterMap = assignToClusters(object_locations, new_centroids); } return clusterMap.values(); } /** * Assigns each object to the cluster whose centroid is closest to the * object. * @param object_locations a map of objects to locations * @param centroids the centroids of the clusters to be formed * @return a map of objects to assigned clusters */ protected Map> assignToClusters(Map object_locations, Set centroids) { Map> clusterMap = new HashMap>(); for (double[] centroid : centroids) clusterMap.put(centroid, new HashMap()); for (Map.Entry object_location : object_locations.entrySet()) { T object = object_location.getKey(); double[] location = object_location.getValue(); // find the cluster with the closest centroid Iterator c_iter = centroids.iterator(); double[] closest = c_iter.next(); double distance = DiscreteDistribution.squaredError(location, closest); while (c_iter.hasNext()) { double[] centroid = c_iter.next(); double dist_cur = DiscreteDistribution.squaredError(location, centroid); if (dist_cur < distance) { distance = dist_cur; closest = centroid; } } clusterMap.get(closest).put(object, location); } return clusterMap; } /** * Sets the seed used by the internal random number generator. * Enables consistent outputs. * @param random_seed the random seed to use */ public void setSeed(int random_seed) { this.rand = new Random(random_seed); } /** * An exception that indicates that the specified data points cannot be * clustered into the number of clusters requested by the user. * This will happen if and only if there are fewer distinct points than * requested clusters. (If there are fewer total data points than * requested clusters, IllegalArgumentException will be thrown.) * * @author Joshua O'Madadhain */ @SuppressWarnings("serial") public static class NotEnoughClustersException extends RuntimeException { @Override public String getMessage() { return "Not enough distinct points in the input data set to form " + "the requested number of clusters"; } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy