All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.clustering.XMeans Maven / Gradle / Ivy

The newest version!
/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *   
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *  
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.clustering;

import java.util.ArrayList;
import smile.math.Math;
import smile.sort.QuickSort;

/**
 * X-Means clustering algorithm, an extended K-Means which tries to
 * automatically determine the number of clusters based on BIC scores.
 * Starting with only one cluster, the X-Means algorithm goes into action
 * after each run of K-Means, making local decisions about which subset of the
 * current centroids should split themselves in order to better fit the data.
 * The splitting decision is done by computing the Bayesian Information
 * Criterion (BIC).
 * 
 * 

References

*
    *
  1. Dan Pelleg and Andrew Moore. X-means: Extending K-means with Efficient Estimation of the Number of Clusters. ICML, 2000.
  2. *
* * @see KMeans * @see GMeans * * @author Haifeng Li */ public class XMeans extends KMeans { private static final double LOG2PI = Math.log(Math.PI * 2.0); /** * Constructor. Clustering data with the number of clusters being * automatically determined by X-Means algorithm. * @param data the input data of which each row is a sample. * @param kmax the maximum number of clusters. */ public XMeans(double[][] data, int kmax) { if (kmax < 2) { throw new IllegalArgumentException("Invalid parameter kmax = " + kmax); } int n = data.length; int d = data[0].length; k = 1; size = new int[k]; size[0] = n; y = new int[n]; centroids = new double[k][d]; for (int i = 0; i < n; i++) { for (int j = 0; j < d; j++) { centroids[0][j] += data[i][j]; } } for (int j = 0; j < d; j++) { centroids[0][j] /= n; } // within-cluster sum of squares double[] wcss = new double[k]; for (int i = 0; i < n; i++) { wcss[0] += Math.squaredDistance(data[i], centroids[0]); } distortion = wcss[0]; System.out.format("X-Means distortion with %d clusters: %.5f\n", k, distortion); BBDTree bbd = new BBDTree(data); while (k < kmax) { ArrayList centers = new ArrayList(); double[] score = new double[k]; KMeans[] kmeans = new KMeans[k]; for (int i = 0; i < k; i++) { // don't split too small cluster. anyway likelihood estimation // not accurate in this case. if (size[i] < 25) { System.out.format("Cluster %3d\ttoo small to split: %d samples\n", i, size[i]); continue; } double[][] subset = new double[size[i]][]; for (int j = 0, l = 0; j < n; j++) { if (y[j] == i) { subset[l++] = data[j]; } } kmeans[i] = new KMeans(subset, 2, 100, 4); double newBIC = bic(2, size[i], d, kmeans[i].distortion, kmeans[i].size); double oldBIC = bic(size[i], d, wcss[i]); score[i] = newBIC - oldBIC; System.out.format("Cluster %3d\tBIC: %.5f\tBIC after split: %.5f\timprovement: %.5f\n", i, oldBIC, newBIC, score[i]); } int[] index = QuickSort.sort(score); for (int i = 0; i < k; i++) { if (score[index[i]] <= 0.0) { centers.add(centroids[index[i]]); } } int m = centers.size(); for (int i = k; --i >= 0;) { if (score[i] > 0) { if (centers.size() + i - m + 1 < kmax) { System.out.format("Split cluster %d...\n", index[i]); centers.add(kmeans[index[i]].centroids[0]); centers.add(kmeans[index[i]].centroids[1]); } else { centers.add(centroids[index[i]]); } } } // no more split. if (centers.size() == k) { break; } k = centers.size(); double[][] sums = new double[k][d]; size = new int[k]; centroids = new double[k][]; for (int i = 0; i < k; i++) { centroids[i] = centers.get(i); } distortion = Double.MAX_VALUE; for (int iter = 0; iter < 100; iter++) { double newDistortion = bbd.clustering(centroids, sums, size, y); for (int i = 0; i < k; i++) { if (size[i] > 0) { for (int j = 0; j < d; j++) { centroids[i][j] = sums[i][j] / size[i]; } } } if (distortion <= newDistortion) { break; } else { distortion = newDistortion; } } wcss = new double[k]; for (int i = 0; i < n; i++) { wcss[y[i]] += Math.squaredDistance(data[i], centroids[y[i]]); } System.out.format("X-Means distortion with %d clusters: %.5f\n", k, distortion); } } /** * Calculates the BIC for single cluster. * @param k the number of clusters. * @param n the total number of samples. * @param d the dimensionality of data. * @param distortion the distortion of clusters. * @return the BIC score. */ private double bic(int n, int d, double distortion) { double variance = distortion / (n - 1); double p1 = -n * LOG2PI; double p2 = -n * d * Math.log(variance); double p3 = -(n - 1); double L = (p1 + p2 + p3) / 2; int numParameters = d + 1; return L - 0.5 * numParameters * Math.log(n); } /** * Calculates the BIC for the given set of centers. * @param k the number of clusters. * @param n the total number of samples. * @param d the dimensionality of data. * @param distortion the distortion of clusters. * @param clusterSize the number of samples in each cluster. * @return the BIC score. */ private double bic(int k, int n, int d, double distortion, int[] clusterSize) { double variance = distortion / (n - k); double L = 0.0; for (int i = 0; i < k; i++) { L += logLikelihood(k, n, clusterSize[i], d, variance); } int numParameters = k + k * d; return L - 0.5 * numParameters * Math.log(n); } /** * Estimate the log-likelihood of the data for the given model. * * @param k the number of clusters. * @param n the total number of samples. * @param ni the number of samples belong to this cluster. * @param d the dimensionality of data. * @param variance the estimated variance of clusters. * @return the likelihood estimate */ private static double logLikelihood(int k, int n, int ni, int d, double variance) { double p1 = -ni * LOG2PI; double p2 = -ni * d * Math.log(variance); double p3 = -(ni - k); double p4 = ni * Math.log(ni); double p5 = -ni * Math.log(n); double loglike = (p1 + p2 + p3) / 2 + p4 + p5; return loglike; } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append(String.format("X-Means distortion: %.5f\n", distortion)); sb.append(String.format("Clusters of %d data points of dimension %d:\n", y.length, centroids[0].length)); for (int i = 0; i < k; i++) { int r = (int) Math.round(1000.0 * size[i] / y.length); sb.append(String.format("%3d\t%5d (%2d.%1d%%)\n", i, size[i], r / 10, r % 10)); } return sb.toString(); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy