
cc.mallet.cluster.KMeans


MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

/*
 * Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
 * This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
 * http://www.cs.umass.edu/~mccallum/mallet This software is provided under the
 * terms of the Common Public License, version 1.0, as published by
 * http://www.opensource.org. For further information, see the file `LICENSE'
 * included with this distribution.
 */

/**
 * Clusters a set of points via k-Means. The instances that are clustered are
 * expected to be of the type FeatureVector.
 * 
 * EMPTY_SINGLE and other changes implemented March 2005. Heuristic cluster
 * selection implemented May 2005.
 * 
 * @author Jerod Weinman [email protected]
 * @author Mike Winter [email protected]
 * 
 */

package cc.mallet.cluster;

import java.util.ArrayList;
import java.util.Random;
import java.util.logging.Logger;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Metric;
import cc.mallet.types.SparseVector;
import cc.mallet.util.VectorStats;

/**
 * KMeans Clusterer
 * 
 * Clusters the points into k clusters by minimizing the total intra-cluster
 * variance. It uses a given {@link Metric} to find the distance between
 * {@link Instance}s, which should have {@link SparseVector}s in the data
 * field.
 * 
 */
public class KMeans extends Clusterer {

  private static final long serialVersionUID = 1L;

  // Stop after movement of means is less than this
  static double MEANS_TOLERANCE = 1e-2;

  // Maximum number of iterations
  static int MAX_ITER = 100;

  // Minimum fraction of points that move
  static double POINTS_TOLERANCE = .005;

  /**
   * Treat an empty cluster as an error condition.
   */
  public static final int EMPTY_ERROR = 0;
  /**
   * Drop an empty cluster
   */
  public static final int EMPTY_DROP = 1;
  /**
   * Place the single instance furthest from the previous cluster mean
   */
  public static final int EMPTY_SINGLE = 2;

  Random randinator;
  Metric metric;
  int numClusters;
  int emptyAction;
  ArrayList<SparseVector> clusterMeans;

  private static Logger logger = Logger
      .getLogger("edu.umass.cs.mallet.base.cluster.KMeans");

  /**
   * Construct a KMeans object
   * 
   * @param instancePipe Pipe for the instances being clustered
   * @param numClusters Number of clusters to use
   * @param metric Metric object to measure instance distances
   * @param emptyAction Specify what should happen when an empty cluster occurs
   */
  public KMeans(Pipe instancePipe, int numClusters, Metric metric,
      int emptyAction) {

    super(instancePipe);

    this.emptyAction = emptyAction;
    this.metric = metric;
    this.numClusters = numClusters;

    this.clusterMeans = new ArrayList<SparseVector>(numClusters);
    this.randinator = new Random();

  }
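
  // A minimal construction sketch (hypothetical names: "myPipe" is assumed to be
  // a Pipe producing SparseVector data and "myMetric" any Metric implementation;
  // neither is defined in this file):
  //
  //   KMeans km = new KMeans(myPipe, 10, myMetric, KMeans.EMPTY_DROP);
  //
  // With EMPTY_DROP, a cluster that becomes empty during iteration is removed
  // rather than treated as an error.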

  /**
   * Construct a KMeans object
   * 
   * @param instancePipe Pipe for the instances being clustered
   * @param numClusters Number of clusters to use
   * @param metric Metric object to measure instance distances.
   *          If an empty cluster occurs, it is considered an error.
   */
  public KMeans(Pipe instancePipe, int numClusters, Metric metric) {
    this(instancePipe, numClusters, metric, EMPTY_ERROR);
  }

  /**
   * Cluster instances
   * 
   * @param instances List of instances to cluster
   */
  @Override
  public Clustering cluster(InstanceList instances) {

    assert (instances.getPipe() == this.instancePipe);

    // Initialize clusterMeans
    initializeMeansSample(instances, this.metric);

    int clusterLabels[] = new int[instances.size()];
    ArrayList<InstanceList> instanceClusters = new ArrayList<InstanceList>(numClusters);
    int instClust;
    double instClustDist, instDist;
    double deltaMeans = Double.MAX_VALUE;
    double deltaPoints = (double) instances.size();
    int iterations = 0;
    SparseVector clusterMean;

    for (int c = 0; c < numClusters; c++) {
      instanceClusters.add(c, new InstanceList(instancePipe));
    }

    logger.info("Entering KMeans iteration");

    while (deltaMeans > MEANS_TOLERANCE && iterations < MAX_ITER
        && deltaPoints > instances.size() * POINTS_TOLERANCE) {

      iterations++;
      deltaPoints = 0;

      // For each instance, measure its distance to the current cluster
      // means, and subsequently assign it to the closest cluster
      // by adding it to a corresponding instance list.
      // The mean of each cluster InstanceList is then updated.
      for (int n = 0; n < instances.size(); n++) {

        instClust = 0;
        instClustDist = Double.MAX_VALUE;

        for (int c = 0; c < numClusters; c++) {
          instDist = metric.distance(clusterMeans.get(c),
              (SparseVector) instances.get(n).getData());

          if (instDist < instClustDist) {
            instClust = c;
            instClustDist = instDist;
          }
        }

        // Add to closest cluster & label it such
        instanceClusters.get(instClust).add(instances.get(n));

        if (clusterLabels[n] != instClust) {
          clusterLabels[n] = instClust;
          deltaPoints++;
        }
      }

      deltaMeans = 0;

      for (int c = 0; c < numClusters; c++) {

        if (instanceClusters.get(c).size() > 0) {
          clusterMean = VectorStats.mean(instanceClusters.get(c));

          deltaMeans += metric.distance(clusterMeans.get(c), clusterMean);
          clusterMeans.set(c, clusterMean);
          instanceClusters.set(c, new InstanceList(instancePipe));

        } else {

          logger.info("Empty cluster found.");

          switch (emptyAction) {

            case EMPTY_ERROR:
              return null;

            case EMPTY_DROP:
              logger.fine("Removing cluster " + c);
              clusterMeans.remove(c);
              instanceClusters.remove(c);

              for (int n = 0; n < instances.size(); n++) {

                assert (clusterLabels[n] != c) : "Cluster size is "
                    + instanceClusters.get(c).size()
                    + "+ yet clusterLabels[n] is " + clusterLabels[n];

                if (clusterLabels[n] > c)
                  clusterLabels[n]--;
              }

              numClusters--;
              c--; // <-- note this trickiness. bad style? maybe.
                   // it just means now that we've deleted the entry,
                   // we have to repeat the index to get the next entry.
              break;

            case EMPTY_SINGLE:

              // Get the instance furthest from any centroid
              // and make it a new centroid.
              double newCentroidDist = 0;
              int newCentroid = 0;
              InstanceList cacheList = null;

              for (int clusters = 0; clusters < clusterMeans.size(); clusters++) {
                SparseVector centroid = clusterMeans.get(clusters);
                InstanceList centInstances = instanceClusters.get(clusters);

                // Don't create new empty clusters.
                if (centInstances.size() <= 1)
                  continue;

                for (int n = 0; n < centInstances.size(); n++) {
                  double currentDist = metric.distance(centroid,
                      (SparseVector) centInstances.get(n).getData());
                  if (currentDist > newCentroidDist) {
                    newCentroid = n;
                    newCentroidDist = currentDist;
                    cacheList = centInstances;
                  }
                }
              }

              if (cacheList == null) {
                // Can't find an instance to move.
                logger.info("Can't find an instance to move. Exiting.");
                return null;
              } else {
                clusterMeans.set(c, (SparseVector) cacheList.get(newCentroid).getData());
              }
              break; // don't fall through to the default (error) case

            default:
              return null;
          }
        }
      }

      logger.info("Iter " + iterations + " deltaMeans = " + deltaMeans);
    }

    if (deltaMeans <= MEANS_TOLERANCE)
      logger.info("KMeans converged with deltaMeans = " + deltaMeans);
    else if (iterations >= MAX_ITER)
      logger.info("Maximum number of iterations (" + MAX_ITER + ") reached.");
    else if (deltaPoints <= instances.size() * POINTS_TOLERANCE)
      logger.info("Minimum number of points (np*" + POINTS_TOLERANCE + "="
          + (int) (instances.size() * POINTS_TOLERANCE)
          + ") moved in last iteration. Saying converged.");

    return new Clustering(instances, numClusters, clusterLabels);
  }

  /**
   * Uses a MAX-MIN heuristic to seed the initial cluster means.
   * 
   * @param instList List of instances.
   * @param metric Distance metric.
   */
  private void initializeMeansSample(InstanceList instList, Metric metric) {

    // InstanceList has no remove() and null instances aren't
    // parsed out by most Pipes, so we have to pre-process
    // here and possibly leave some instances without
    // cluster assignments.
    ArrayList<Instance> instances = new ArrayList<Instance>(instList.size());
    for (int i = 0; i < instList.size(); i++) {
      Instance ins = instList.get(i);
      SparseVector sparse = (SparseVector) ins.getData();
      if (sparse.numLocations() == 0)
        continue;
      instances.add(ins);
    }

    // Add the next center that has the MAX of the MIN of the distances from
    // each of the previous j-1 centers (idea from Andrew Moore tutorial,
    // not sure who came up with it originally).
    for (int i = 0; i < numClusters; i++) {
      double max = 0;
      int selected = 0;
      for (int k = 0; k < instances.size(); k++) {
        double min = Double.MAX_VALUE;
        Instance ins = instances.get(k);
        SparseVector inst = (SparseVector) ins.getData();
        for (int j = 0; j < clusterMeans.size(); j++) {
          SparseVector centerInst = clusterMeans.get(j);
          double dist = metric.distance(centerInst, inst);
          if (dist < min)
            min = dist;
        }
        if (min > max) {
          selected = k;
          max = min;
        }
      }
      Instance newCenter = instances.remove(selected);
      clusterMeans.add((SparseVector) newCenter.getData());
    }
  }

  /**
   * Return the ArrayList of cluster means after a run of the algorithm.
   * 
   * @return An ArrayList of SparseVector cluster means, one per cluster.
   */
  public ArrayList<SparseVector> getClusterMeans() {
    return this.clusterMeans;
  }
}
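
For orientation, here is a minimal usage sketch of the class above. It is not part of the MALLET source: the token-to-feature-vector pipe chain, ArrayIterator, NormalizedDotProductMetric, and Clustering.getLabels() are assumed to be available in the MALLET release being used, and the tiny document array is purely illustrative.

import java.util.ArrayList;
import java.util.regex.Pattern;

import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.KMeans;
import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.FeatureSequence2FeatureVector;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.iterator.ArrayIterator;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Metric;
import cc.mallet.types.NormalizedDotProductMetric;

public class KMeansUsageSketch {

  public static void main(String[] args) {

    // Text -> tokens -> feature sequence -> feature vector; KMeans expects
    // SparseVector (FeatureVector) data on each Instance.
    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
    pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}+")));
    pipeList.add(new TokenSequence2FeatureSequence());
    pipeList.add(new FeatureSequence2FeatureVector());
    Pipe pipe = new SerialPipes(pipeList);

    // Illustrative in-memory documents; real use would load a corpus.
    String[] docs = { "the cat sat on the mat", "the dog chased the cat",
        "stock prices fell sharply", "markets rallied on earnings" };
    InstanceList instances = new InstanceList(pipe);
    instances.addThruPipe(new ArrayIterator(docs));

    // Any cc.mallet.types.Metric works; a cosine-style distance is a common choice.
    Metric metric = new NormalizedDotProductMetric();

    // Two clusters; drop a cluster if it ever becomes empty.
    KMeans kMeans = new KMeans(pipe, 2, metric, KMeans.EMPTY_DROP);
    Clustering clustering = kMeans.cluster(instances);

    if (clustering != null) {
      int[] labels = clustering.getLabels(); // one cluster index per instance
      for (int i = 0; i < labels.length; i++) {
        System.out.println(docs[i] + " -> cluster " + labels[i]);
      }
    }
  }
}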




