All downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.commons.math.stat.clustering.KMeansPlusPlusClusterer Maven / Gradle / Ivy

There is a newer version: 6.5.21
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math.stat.clustering;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

import org.apache.commons.math.exception.ConvergenceException;
import org.apache.commons.math.exception.util.LocalizedFormats;
import org.apache.commons.math.stat.descriptive.moment.Variance;

/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
 * @param  type of the points to cluster
 * @see K-means++ (wikipedia)
 * @version $Revision: 1054333 $ $Date: 2011-01-02 01:34:58 +0100 (dim. 02 janv. 2011) $
 * @since 2.0
 */
public class KMeansPlusPlusClusterer> {

    /** Strategies to use for replacing an empty cluster. */
    public static enum EmptyClusterStrategy {

        /** Split the cluster with largest distance variance. */
        LARGEST_VARIANCE,

        /** Split the cluster with largest number of points. */
        LARGEST_POINTS_NUMBER,

        /** Create a cluster around the point farthest from its centroid. */
        FARTHEST_POINT,

        /** Generate an error. */
        ERROR

    }

    /** Random generator for choosing initial centers. */
    private final Random random;

    /** Selected strategy for empty clusters. */
    private final EmptyClusterStrategy emptyStrategy;

    /** Build a clusterer.
     * 

* The default strategy for handling empty clusters that may appear during * algorithm iterations is to split the cluster with largest distance variance. *

* @param random random generator to use for choosing initial centers */ public KMeansPlusPlusClusterer(final Random random) { this(random, EmptyClusterStrategy.LARGEST_VARIANCE); } /** Build a clusterer. * @param random random generator to use for choosing initial centers * @param emptyStrategy strategy to use for handling empty clusters that * may appear during algorithm iterations * @since 2.2 */ public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) { this.random = random; this.emptyStrategy = emptyStrategy; } /** * Runs the K-means++ clustering algorithm. * * @param points the points to cluster * @param k the number of clusters to split the data into * @param maxIterations the maximum number of iterations to run the algorithm * for. If negative, no maximum will be used * @return a list of clusters containing the points */ public List> cluster(final Collection points, final int k, final int maxIterations) { // create the initial clusters List> clusters = chooseInitialCenters(points, k, random); assignPointsToClusters(clusters, points); // iterate through updating the centers until we're done final int max = (maxIterations < 0) ? 
Integer.MAX_VALUE : maxIterations; for (int count = 0; count < max; count++) { boolean clusteringChanged = false; List> newClusters = new ArrayList>(); for (final Cluster cluster : clusters) { final T newCenter; if (cluster.getPoints().isEmpty()) { switch (emptyStrategy) { case LARGEST_VARIANCE : newCenter = getPointFromLargestVarianceCluster(clusters); break; case LARGEST_POINTS_NUMBER : newCenter = getPointFromLargestNumberCluster(clusters); break; case FARTHEST_POINT : newCenter = getFarthestPoint(clusters); break; default : throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); } clusteringChanged = true; } else { newCenter = cluster.getCenter().centroidOf(cluster.getPoints()); if (!newCenter.equals(cluster.getCenter())) { clusteringChanged = true; } } newClusters.add(new Cluster(newCenter)); } if (!clusteringChanged) { return clusters; } assignPointsToClusters(newClusters, points); clusters = newClusters; } return clusters; } /** * Adds the given points to the closest {@link Cluster}. * * @param type of the points to cluster * @param clusters the {@link Cluster}s to add the points to * @param points the points to add to the given {@link Cluster}s */ private static > void assignPointsToClusters(final Collection> clusters, final Collection points) { for (final T p : points) { Cluster cluster = getNearestCluster(clusters, p); cluster.addPoint(p); } } /** * Use K-means++ to choose the initial centers. * * @param type of the points to cluster * @param points the points to choose the initial centers from * @param k the number of centers to choose * @param random random generator to use * @return the initial centers */ private static > List> chooseInitialCenters(final Collection points, final int k, final Random random) { final List pointSet = new ArrayList(points); final List> resultSet = new ArrayList>(); // Choose one center uniformly at random from among the data points. 
final T firstPoint = pointSet.remove(random.nextInt(pointSet.size())); resultSet.add(new Cluster(firstPoint)); final double[] dx2 = new double[pointSet.size()]; while (resultSet.size() < k) { // For each data point x, compute D(x), the distance between x and // the nearest center that has already been chosen. int sum = 0; for (int i = 0; i < pointSet.size(); i++) { final T p = pointSet.get(i); final Cluster nearest = getNearestCluster(resultSet, p); final double d = p.distanceFrom(nearest.getCenter()); sum += d * d; dx2[i] = sum; } // Add one new data point as a center. Each point x is chosen with // probability proportional to D(x)2 final double r = random.nextDouble() * sum; for (int i = 0 ; i < dx2.length; i++) { if (dx2[i] >= r) { final T p = pointSet.remove(i); resultSet.add(new Cluster(p)); break; } } } return resultSet; } /** * Get a random point from the {@link Cluster} with the largest distance variance. * * @param clusters the {@link Cluster}s to search * @return a random point from the selected cluster */ private T getPointFromLargestVarianceCluster(final Collection> clusters) { double maxVariance = Double.NEGATIVE_INFINITY; Cluster selected = null; for (final Cluster cluster : clusters) { if (!cluster.getPoints().isEmpty()) { // compute the distance variance of the current cluster final T center = cluster.getCenter(); final Variance stat = new Variance(); for (final T point : cluster.getPoints()) { stat.increment(point.distanceFrom(center)); } final double variance = stat.getResult(); // select the cluster with the largest variance if (variance > maxVariance) { maxVariance = variance; selected = cluster; } } } // did we find at least one non-empty cluster ? 
if (selected == null) { throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); } // extract a random point from the cluster final List selectedPoints = selected.getPoints(); return selectedPoints.remove(random.nextInt(selectedPoints.size())); } /** * Get a random point from the {@link Cluster} with the largest number of points * * @param clusters the {@link Cluster}s to search * @return a random point from the selected cluster */ private T getPointFromLargestNumberCluster(final Collection> clusters) { int maxNumber = 0; Cluster selected = null; for (final Cluster cluster : clusters) { // get the number of points of the current cluster final int number = cluster.getPoints().size(); // select the cluster with the largest number of points if (number > maxNumber) { maxNumber = number; selected = cluster; } } // did we find at least one non-empty cluster ? if (selected == null) { throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); } // extract a random point from the cluster final List selectedPoints = selected.getPoints(); return selectedPoints.remove(random.nextInt(selectedPoints.size())); } /** * Get the point farthest to its cluster center * * @param clusters the {@link Cluster}s to search * @return point farthest to its cluster center */ private T getFarthestPoint(final Collection> clusters) { double maxDistance = Double.NEGATIVE_INFINITY; Cluster selectedCluster = null; int selectedPoint = -1; for (final Cluster cluster : clusters) { // get the farthest point final T center = cluster.getCenter(); final List points = cluster.getPoints(); for (int i = 0; i < points.size(); ++i) { final double distance = points.get(i).distanceFrom(center); if (distance > maxDistance) { maxDistance = distance; selectedCluster = cluster; selectedPoint = i; } } } // did we find at least one non-empty cluster ? 
if (selectedCluster == null) { throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); } return selectedCluster.getPoints().remove(selectedPoint); } /** * Returns the nearest {@link Cluster} to the given point * * @param type of the points to cluster * @param clusters the {@link Cluster}s to search * @param point the point to find the nearest {@link Cluster} for * @return the nearest {@link Cluster} to the given point */ private static > Cluster getNearestCluster(final Collection> clusters, final T point) { double minDistance = Double.MAX_VALUE; Cluster minCluster = null; for (final Cluster c : clusters) { final double distance = point.distanceFrom(c.getCenter()); if (distance < minDistance) { minDistance = distance; minCluster = c; } } return minCluster; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy