smile.clustering.KMeans Maven / Gradle / Ivy
The newest version!
/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.clustering;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import smile.math.Math;
import smile.util.MulticoreExecutor;
/**
* K-Means clustering aims to partition n observations into k clusters in which
* each observation belongs to the cluster with the nearest mean.
* Although finding an exact solution to the k-means problem for arbitrary
* input is NP-hard, the standard approach to finding an approximate solution
* (often called Lloyd's algorithm or the k-means algorithm) is used widely
* and frequently finds reasonable solutions quickly.
*
* However, the k-means algorithm has at least two major theoretic shortcomings:
*
* - First, it has been shown that the worst case running time of the
* algorithm is super-polynomial in the input size.
*
* - Second, the approximation found can be arbitrarily bad with respect
* to the objective function compared to the optimal clustering.
*
* In this implementation, we use k-means++ which addresses the second of these
* obstacles by specifying a procedure to initialize the cluster centers before
* proceeding with the standard k-means optimization iterations. With the
* k-means++ initialization, the algorithm is guaranteed to find a solution
* that is O(log k) competitive to the optimal k-means solution.
*
* We also use k-d trees to speed up each k-means step as described in the filter
* algorithm by Kanungo, et al.
*
* K-means is a hard clustering method, i.e. each sample is assigned to
* a specific cluster. In contrast, soft clustering, e.g. the
* Expectation-Maximization algorithm for Gaussian mixtures, assigns samples
* to different clusters with different probabilities.
*
*
* References:
*
* - Tapas Kanungo, David M. Mount, Nathan S. Netanyahu, Christine D. Piatko, Ruth Silverman, and Angela Y. Wu. An Efficient k-Means Clustering Algorithm: Analysis and Implementation. IEEE TRANS. PAMI, 2002.
* - D. Arthur and S. Vassilvitskii. "K-means++: the advantages of careful seeding". ACM-SIAM symposium on Discrete algorithms, 1027-1035, 2007.
* - Anna D. Peterson, Arka P. Ghosh and Ranjan Maitra. A systematic evaluation of different methods for initializing the K-means clustering algorithm. 2010.
*
*
* @see XMeans
* @see GMeans
* @see CLARANS
* @see SIB
* @see SOM
* @see NeuralGas
* @see BIRCH
* @see BBDTree
*
* @author Haifeng Li
*/
public class KMeans extends PartitionClustering {
/**
 * The total distortion of the clustering. In this file it starts at
 * {@code Double.MAX_VALUE} and is lowered each time an iteration yields a
 * smaller value. Package-private so that the nested tasks and the static
 * {@code lloyd} methods can read and assign it directly.
 */
double distortion;
/**
 * The centroids of each cluster, allocated as {@code k} rows by
 * {@code d} columns (one row per cluster, one column per variable).
 */
double[][] centroids;
/**
 * Package-private constructor that creates an empty model without running
 * any clustering. Within this class it is used by the multi-run constructor
 * and by {@code lloyd}, both of which fill in the fields afterwards.
 */
KMeans() {
}
/**
 * Returns the distortion of this clustering.
 * @return the current value of the {@code distortion} field.
 */
public double distortion() {
    return this.distortion;
}
/**
 * Returns the cluster centroids.
 * @return the internal centroid matrix, one row per cluster. The array is
 * returned as is (no copy is made).
 */
public double[][] centroids() {
    return this.centroids;
}
/**
 * Assigns a new instance to the cluster whose centroid is closest.
 * Ties are resolved in favour of the lowest cluster index, and index 0 is
 * returned if no centroid is closer than {@code Double.MAX_VALUE}.
 * @param x a new instance.
 * @return the cluster label, which is the index of nearest centroid.
 */
@Override
public int predict(double[] x) {
    int nearest = 0;
    double smallest = Double.MAX_VALUE;

    for (int c = 0; c < k; c++) {
        final double candidate = Math.squaredDistance(x, centroids[c]);
        if (candidate < smallest) {
            nearest = c;
            smallest = candidate;
        }
    }

    return nearest;
}
/**
 * Constructor. Clustering data into k clusters up to 100 iterations.
 * Delegates to {@link #KMeans(double[][], int, int)} with a fixed
 * iteration limit of 100.
 * @param data the input data of which each row is a sample.
 * @param k the number of clusters.
 */
public KMeans(double[][] data, int k) {
this(data, k, 100);
}
/**
 * Constructor. Clustering data into k clusters in a single run.
 * A new {@code BBDTree} is built over the data and passed, together with
 * the remaining arguments, to the package-private constructor that does
 * the actual work.
 * @param data the input data of which each row is a sample.
 * @param k the number of clusters.
 * @param maxIter the maximum number of iterations for each running.
 */
public KMeans(double[][] data, int k, int maxIter) {
this(new BBDTree(data), data, k, maxIter);
}
/**
 * Constructor. Runs one pass of BBD-tree based k-means on the data.
 * Labels are first obtained from {@code seed}, the initial centroids are the
 * means of the seeded clusters, and then up to {@code maxIter} refinement
 * steps are performed through {@code bbd.clustering}. Refinement stops as
 * soon as a step does not lower the distortion.
 * @param bbd the BBD-tree of data for fast clustering.
 * @param data the input data of which each row is a sample.
 * @param k the number of clusters.
 * @param maxIter the maximum number of iterations for each running.
 * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}.
 */
KMeans(BBDTree bbd, double[][] data, int k, int maxIter) {
    if (k < 2) {
        throw new IllegalArgumentException("Invalid number of clusters: " + k);
    }
    if (maxIter <= 0) {
        throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
    }

    final int rows = data.length;
    final int cols = data[0].length;

    this.k = k;
    this.distortion = Double.MAX_VALUE;
    this.y = seed(data, k, DistanceMethod.EUCLIDEAN);
    this.size = new int[k];
    this.centroids = new double[k][cols];

    // Single pass over the samples: count members and accumulate the
    // per-cluster coordinate sums into the centroid rows.
    for (int row = 0; row < rows; row++) {
        final int label = y[row];
        size[label]++;
        for (int col = 0; col < cols; col++) {
            centroids[label][col] += data[row][col];
        }
    }

    // Turn the sums into means.
    for (int c = 0; c < k; c++) {
        for (int col = 0; col < cols; col++) {
            centroids[c][col] /= size[c];
        }
    }

    final double[][] sums = new double[k][cols];
    for (int iter = 0; iter < maxIter; iter++) {
        final double dist = bbd.clustering(centroids, sums, size, y);

        // Clusters that received no sample keep their previous centroid.
        for (int c = 0; c < k; c++) {
            if (size[c] <= 0) {
                continue;
            }
            for (int col = 0; col < cols; col++) {
                centroids[c][col] = sums[c][col] / size[c];
            }
        }

        if (dist >= distortion) {
            break;
        }
        distortion = dist;
    }
}
/**
 * Clustering data into k clusters. Run the algorithm for given times
 * and return the best one with smallest distortion.
 * <p>
 * One BBD-tree is built and shared by all runs, which are submitted to
 * {@code MulticoreExecutor}. If the parallel execution throws, the exception
 * is printed to {@code System.err} and the runs are repeated sequentially
 * with {@code lloyd} instead.
 * <p>
 * The first available result is always accepted, so the model is populated
 * even when no run reports a distortion below {@code Double.MAX_VALUE}
 * (for instance when the distortion is NaN).
 * @param data the input data of which each row is a sample.
 * @param k the number of clusters.
 * @param maxIter the maximum number of iterations for each running.
 * @param runs the number of runs of K-Means algorithm.
 * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0}
 * or {@code runs <= 0}.
 */
public KMeans(double[][] data, int k, int maxIter, int runs) {
    if (k < 2) {
        throw new IllegalArgumentException("Invalid number of clusters: " + k);
    }
    if (maxIter <= 0) {
        throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
    }
    if (runs <= 0) {
        throw new IllegalArgumentException("Invalid number of runs: " + runs);
    }

    BBDTree bbd = new BBDTree(data);
    List<Callable<KMeans>> tasks = new ArrayList<Callable<KMeans>>();
    for (int i = 0; i < runs; i++) {
        tasks.add(new KMeansThread(bbd, data, k, maxIter));
    }

    // Placeholder until the first real result arrives; its centroids are null.
    KMeans best = new KMeans();
    best.distortion = Double.MAX_VALUE;

    try {
        List<KMeans> clusters = MulticoreExecutor.run(tasks);
        for (KMeans kmeans : clusters) {
            if (best.centroids == null || kmeans.distortion < best.distortion) {
                best = kmeans;
            }
        }
    } catch (Exception ex) {
        // Best-effort fallback: report the failure and redo the runs serially.
        System.err.println(ex);

        for (int i = 0; i < runs; i++) {
            KMeans kmeans = lloyd(data, k, maxIter);
            if (best.centroids == null || kmeans.distortion < best.distortion) {
                best = kmeans;
            }
        }
    }

    this.k = best.k;
    this.distortion = best.distortion;
    this.centroids = best.centroids;
    this.y = best.y;
    this.size = best.size;
}
/**
 * Adapter for running BBD-Tree based K-Means algorithm in thread pool.
 * Each task holds the arguments of the package-private
 * {@code KMeans(BBDTree, double[][], int, int)} constructor and invokes it
 * when called.
 */
static class KMeansThread implements Callable<KMeans> {
    /** The BBD-tree handed to the KMeans constructor. */
    final BBDTree bbd;
    /** The input data of which each row is a sample. */
    final double[][] data;
    /** The number of clusters. */
    final int k;
    /** The maximum number of iterations. */
    final int maxIter;

    /**
     * Constructor. Stores the arguments for a later clustering run.
     * @param bbd the BBD-tree of data for fast clustering.
     * @param data the input data of which each row is a sample.
     * @param k the number of clusters.
     * @param maxIter the maximum number of iterations.
     */
    KMeansThread(BBDTree bbd, double[][] data, int k, int maxIter) {
        this.bbd = bbd;
        this.data = data;
        this.k = k;
        this.maxIter = maxIter;
    }

    /**
     * Performs one clustering run.
     * @return a new KMeans model built from the stored arguments.
     */
    @Override
    public KMeans call() {
        return new KMeans(bbd, data, k, maxIter);
    }
}
/**
 * The implementation of Lloyd algorithm as a benchmark. The data may
 * contain missing values (i.e. Double.NaN). The algorithm runs up to
 * 100 iterations. Delegates to {@link #lloyd(double[][], int, int)} with
 * an iteration limit of 100.
 * @param data the input data of which each row is a sample.
 * @param k the number of clusters.
 * @return the clustering model produced by the three-argument overload.
 */
public static KMeans lloyd(double[][] data, int k) {
return lloyd(data, k, 100);
}
/**
 * The implementation of Lloyd algorithm as a benchmark. The data may
 * contain missing values (i.e. Double.NaN).
 * <p>
 * Labels are initialised with {@code seed}. Each iteration recomputes the
 * centroids from the current labels and then reassigns every sample to its
 * nearest centroid. Iteration stops after {@code maxIter} passes or as soon
 * as the within-cluster sum of squares does not decrease.
 * <p>
 * When there are at least 1000 samples and the thread pool has at least two
 * threads, the assignment step is split into chunks of at least 100 samples
 * and run through {@code MulticoreExecutor}. If that fails, the exception is
 * printed to {@code System.err} and the step is redone serially.
 * @param data the input data of which each row is a sample.
 * @param k the number of clusters.
 * @param maxIter the maximum number of iterations for each running.
 * @return the clustering model holding labels, sizes, centroids and the
 * smallest distortion observed.
 * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}.
 */
public static KMeans lloyd(double[][] data, int k, int maxIter) {
    if (k < 2) {
        throw new IllegalArgumentException("Invalid number of clusters: " + k);
    }
    if (maxIter <= 0) {
        throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
    }

    int n = data.length;
    int d = data[0].length;
    int[][] nd = new int[k][d]; // The number of non-missing values per cluster per variable.
    double distortion = Double.MAX_VALUE;
    int[] size = new int[k];
    double[][] centroids = new double[k][d];
    int[] y = seed(data, k, DistanceMethod.EUCLIDEAN_MISSING_VALUES);

    int np = MulticoreExecutor.getThreadPoolSize();
    List<Callable<Double>> tasks = null;
    if (n >= 1000 && np >= 2) {
        tasks = new ArrayList<Callable<Double>>(np + 1);
        int step = n / np;
        if (step < 100) {
            step = 100;
        }

        // Emit full-size chunks only while a further chunk still fits
        // strictly inside the data. Once step has been clamped to 100 there
        // may be fewer than np chunks; without this bound the chunk ends
        // would run past n and every parallel pass would fail.
        int start = 0;
        for (int i = 0; i < np - 1 && step < n - start; i++) {
            tasks.add(new LloydThread(data, centroids, y, start, start + step));
            start += step;
        }
        // The last chunk takes whatever remains.
        tasks.add(new LloydThread(data, centroids, y, start, n));
    }

    // Serial assignment step over the whole data set, used when no parallel
    // tasks exist or when the parallel pass did not yield a usable value.
    LloydThread serial = new LloydThread(data, centroids, y, 0, n);

    for (int iter = 0; iter < maxIter; iter++) {
        recomputeCentroids(data, y, size, centroids, nd);

        double wcss = Double.NaN;
        if (tasks != null) {
            try {
                List<Double> partials = MulticoreExecutor.run(tasks);
                wcss = 0.0;
                for (double ss : partials) {
                    wcss += ss;
                }
            } catch (Exception ex) {
                System.err.println(ex);
                wcss = Double.NaN;
            }
        }

        if (Double.isNaN(wcss)) {
            wcss = serial.call();
        }

        if (distortion <= wcss) {
            break;
        } else {
            distortion = wcss;
        }
    }

    // In case of early stop, we should recalculate centroids and clusterSize.
    recomputeCentroids(data, y, size, centroids, nd);

    KMeans kmeans = new KMeans();
    kmeans.k = k;
    kmeans.distortion = distortion;
    kmeans.size = size;
    kmeans.centroids = centroids;
    kmeans.y = y;
    return kmeans;
}

/**
 * Recomputes cluster sizes and centroids in place from the current labels.
 * Missing values (NaN) are skipped, so every centroid coordinate is the mean
 * of the non-missing values of that variable within the cluster. A
 * coordinate with no non-missing value ends up as NaN (0.0 / 0).
 * @param data the input data of which each row is a sample.
 * @param y the current cluster label of each sample.
 * @param size output: the number of samples per cluster.
 * @param centroids output: the centroid of each cluster.
 * @param nd output: the number of non-missing values per cluster per variable.
 */
private static void recomputeCentroids(double[][] data, int[] y, int[] size, double[][] centroids, int[][] nd) {
    int k = centroids.length;

    Arrays.fill(size, 0);
    for (int i = 0; i < k; i++) {
        Arrays.fill(centroids[i], 0.0);
        Arrays.fill(nd[i], 0);
    }

    for (int i = 0; i < data.length; i++) {
        int m = y[i];
        size[m]++;
        for (int j = 0; j < centroids[m].length; j++) {
            if (!Double.isNaN(data[i][j])) {
                centroids[m][j] += data[i][j];
                nd[m][j]++;
            }
        }
    }

    for (int i = 0; i < k; i++) {
        for (int j = 0; j < centroids[i].length; j++) {
            centroids[i][j] /= nd[i][j];
        }
    }
}
/**
 * Runs the benchmark Lloyd algorithm several times, one run after another,
 * and keeps the model with the smallest distortion. The first run is the
 * initial winner; a later run replaces it only if its distortion is strictly
 * smaller. The data may contain missing values (i.e. Double.NaN).
 * @param data the input data of which each row is a sample.
 * @param k the number of clusters.
 * @param maxIter the maximum number of iterations for each running.
 * @param runs the number of runs of K-Means algorithm.
 * @return the model with the smallest distortion among all runs.
 * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0}
 * or {@code runs <= 0}.
 */
public static KMeans lloyd(double[][] data, int k, int maxIter, int runs) {
    if (k < 2) {
        throw new IllegalArgumentException("Invalid number of clusters: " + k);
    }
    if (maxIter <= 0) {
        throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
    }
    if (runs <= 0) {
        throw new IllegalArgumentException("Invalid number of runs: " + runs);
    }

    KMeans winner = lloyd(data, k, maxIter);

    int remaining = runs - 1;
    while (remaining > 0) {
        final KMeans candidate = lloyd(data, k, maxIter);
        if (candidate.distortion < winner.distortion) {
            winner = candidate;
        }
        remaining--;
    }

    return winner;
}
/**
 * Adapter for running Lloyd algorithm in thread pool. A task performs the
 * assignment step for the samples with index in {@code [start, end)}.
 */
static class LloydThread implements Callable<Double> {
    /**
     * The start index of data portion for this task.
     */
    final int start;
    /**
     * The end index (exclusive) of data portion for this task.
     */
    final int end;
    /**
     * The input data of which each row is a sample.
     */
    final double[][] data;
    /**
     * The number of clusters, taken from the number of centroid rows.
     */
    final int k;
    /**
     * The centroids, shared with the caller and read on every call.
     */
    final double[][] centroids;
    /**
     * The cluster labels, shared with the caller. This task writes only
     * the entries with index in {@code [start, end)}.
     */
    final int[] y;

    /**
     * Constructor.
     * @param data the input data of which each row is a sample.
     * @param centroids the centroids to compare every sample against.
     * @param y the label array that receives the assignments.
     * @param start the first sample index handled by this task.
     * @param end the index after the last sample handled by this task.
     */
    LloydThread(double[][] data, double[][] centroids, int[] y, int start, int end) {
        this.data = data;
        this.k = centroids.length;
        this.y = y;
        this.centroids = centroids;
        this.start = start;
        this.end = end;
    }

    /**
     * Assigns every sample of this task's range to its nearest centroid,
     * as measured by {@code squaredDistance}, and stores the label in
     * {@code y}. On ties the lowest cluster index wins.
     * @return the sum of the nearest distances over the range.
     */
    @Override
    public Double call() {
        double wcss = 0.0;
        for (int i = start; i < end; i++) {
            double nearest = Double.MAX_VALUE;
            for (int j = 0; j < k; j++) {
                double dist = squaredDistance(data[i], centroids[j]);
                if (nearest > dist) {
                    y[i] = j;
                    nearest = dist;
                }
            }
            wcss += nearest;
        }
        return wcss;
    }
}
/**
 * Builds a textual summary of the model: the distortion, the number of
 * data points and their dimension, followed by one line per cluster with
 * its index, its size and its share of the data as a percentage with one
 * decimal place.
 * @return the multi-line summary.
 */
@Override
public String toString() {
    final StringBuilder out = new StringBuilder();
    out.append(String.format("K-Means distortion: %.5f\n", distortion));

    final int total = y.length;
    out.append(String.format("Clusters of %d data points of dimension %d:\n", total, centroids[0].length));

    for (int c = 0; c < k; c++) {
        // Share of the data in tenths of a percent, split below into the
        // integer part and the single decimal digit.
        final int tenths = (int) Math.round(1000.0 * size[c] / total);
        final int whole = tenths / 10;
        final int fraction = tenths % 10;
        out.append(String.format("%3d\t%5d (%2d.%1d%%)\n", c, size[c], whole, fraction));
    }

    return out.toString();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy