All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.clustering.streaming.cluster.StreamingKMeans Maven / Gradle / Ivy

There is a newer version: 0.13.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.streaming.cluster;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.jet.math.Constants;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.WeightedThing;

/**
 * Implements a streaming k-means algorithm for weighted vectors.
 * The goal is clustering points one at a time, which is especially useful for MapReduce mappers that get inputs one
 * at a time.
 *
 * A rough description of the algorithm:
 * Suppose there are l clusters at one point and a new point p is added.
 * The new point can either be added to one of the existing l clusters or become a new cluster. To decide:
 * - let c be the closest cluster to point p;
 * - let d be the distance between c and p;
 * - if d > distanceCutoff, create a new cluster from p (p is too far away from the clusters to be part of them;
 * distanceCutoff represents the largest distance from a point to its assigned cluster's centroid);
 * - else (d <= distanceCutoff), create a new cluster with probability weight(p) * d / distanceCutoff (the probability
 * of creating a new cluster increases as d increases); otherwise p is merged into c.
 * There will be either l clusters or l + 1 clusters after processing a new point.
 *
 * As the number of clusters increases, it will go over the numClusters limit (numClusters represents a recommendation
 * for the number of clusters that there should be at the end). To decrease the number of clusters the existing clusters
 * are treated as data points and are re-clustered (collapsed). This tends to make the number of clusters go down.
 * If the number of clusters is still too high, distanceCutoff is increased.
 *
 * For more details, see:
 * - "Streaming  k-means approximation" by N. Ailon, R. Jaiswal, C. Monteleoni
 * http://books.nips.cc/papers/files/nips22/NIPS2009_1085.pdf
 * - "Fast and Accurate k-means for Large Datasets" by M. Shindler, A. Wong, A. Meyerson,
 * http://books.nips.cc/papers/files/nips24/NIPS2011_1271.pdf
 */
public class StreamingKMeans implements Iterable<Centroid> {
  /**
   * The searcher containing the centroids that resulted from the clustering of points until now. When adding a new
   * point we either assign it to one of the existing clusters in this searcher or create a new centroid for it.
   */
  private final UpdatableSearcher centroids;

  /**
   * The estimated number of clusters to cluster the data in. If the actual number of clusters increases beyond this
   * limit, the clusters will be "collapsed" (re-clustered, by treating them as data points). This doesn't happen
   * recursively and a collapse might not necessarily make the number of actual clusters drop to less than this limit.
   *
   * If the goal is clustering a large data set into k clusters, numClusters SHOULD NOT BE SET to k. StreamingKMeans is
   * useful to reduce the size of the data set by the mappers so that it can fit into memory in one reducer that runs
   * BallKMeans.
   *
   * It is NOT MEANT to cluster the data into k clusters in one pass because it can't guarantee that there will in fact
   * be k clusters in total. This is because of the dynamic nature of numClusters over the course of the runtime.
   * To get an exact number of clusters, another clustering algorithm needs to be applied to the results.
   */
  private int numClusters;

  /**
   * The number of data points seen so far. This is important for re-estimating numClusters when deciding to collapse
   * the existing clusters.
   */
  private int numProcessedDatapoints = 0;

  /**
   * This is the current value of the distance cutoff.  Points which are much closer than this to a centroid will stick
   * to it almost certainly. Points further than this to any centroid will form a new cluster.
   *
   * This increases (is multiplied by beta) when a cluster collapse did not make the number of clusters drop to below
   * numClusters (it effectively increases the tolerance for cluster compactness discouraging the creation of new
   * clusters). Since a collapse only happens when centroids.size() > clusterOvershoot * numClusters, the cutoff
   * increases when the collapse didn't at least remove the slack in the number of clusters.
   */
  private double distanceCutoff;

  /**
   * Parameter that controls the growth of the distanceCutoff. After n increases of the
   * distanceCutoff starting at d_0, the final value is d_0 * beta^n (distance cutoffs increase following a geometric
   * progression with ratio beta).
   */
  private final double beta;

  /**
   * Multiplying clusterLogFactor with the natural logarithm of numProcessedDatapoints gets an estimate of the
   * suggested number of clusters. This mirrors the recommended number of clusters for n points where there should be
   * k actual clusters, k * log n. In the case of our estimate we use clusterLogFactor * log(numProcessedDatapoints).
   *
   * It is important to note that numClusters is NOT k. It is an estimate of k * log n.
   */
  private final double clusterLogFactor;

  /**
   * Centroids are collapsed when the number of clusters becomes greater than clusterOvershoot * numClusters. This
   * effectively means having a slack in numClusters so that the actual number of centroids, centroids.size() tracks
   * numClusters approximately. The idea is that the actual number of clusters should be at least numClusters but not
   * much more (so that we don't end up having 1 cluster / point).
   */
  private final double clusterOvershoot;

  /**
   * Random object to sample values from.
   */
  private final Random random = RandomUtils.getRandom();

  /**
   * Calls StreamingKMeans(searcher, numClusters, 1.0 / numClusters, 1.3, 20, 2).
   * @see StreamingKMeans#StreamingKMeans(org.apache.mahout.math.neighborhood.UpdatableSearcher, int,
   * double, double, double, double)
   */
  public StreamingKMeans(UpdatableSearcher searcher, int numClusters) {
    this(searcher, numClusters, 1.0 / numClusters, 1.3, 20, 2);
  }

  /**
   * Calls StreamingKMeans(searcher, numClusters, distanceCutoff, 1.3, 20, 2).
   * @see StreamingKMeans#StreamingKMeans(org.apache.mahout.math.neighborhood.UpdatableSearcher, int,
   * double, double, double, double)
   */
  public StreamingKMeans(UpdatableSearcher searcher, int numClusters, double distanceCutoff) {
    this(searcher, numClusters, distanceCutoff, 1.3, 20, 2);
  }

  /**
   * Creates a new StreamingKMeans class given a searcher and the number of clusters to generate.
   *
   * @param searcher A Searcher that is used for performing nearest neighbor search. It MUST BE
   *                 EMPTY initially because it will be used to keep track of the cluster
   *                 centroids.
   * @param numClusters An estimated number of clusters to generate for the data points.
   *                    This can be adjusted while clustering, and the actual number of clusters
   *                    produced will depend on the data.
   * @param distanceCutoff  The initial distance cutoff representing the value of the
   *                      distance between a point and its closest centroid after which
   *                      the new point will definitely be assigned to a new cluster.
   * @param beta Ratio of geometric progression to use when increasing distanceCutoff. After n increases, distanceCutoff
   *             becomes distanceCutoff * beta^n. A smaller value increases the distanceCutoff less aggressively.
   * @param clusterLogFactor Value multiplied with the logarithm of the number of points counted so far, estimating
   *                         the number of clusters to aim for. If the final number of clusters is known and this
   *                         clustering is only for a sketch of the data, this can be the final number of clusters, k.
   * @param clusterOvershoot Multiplicative slack factor for slowing down the collapse of the clusters.
   */
  public StreamingKMeans(UpdatableSearcher searcher, int numClusters,
                         double distanceCutoff, double beta, double clusterLogFactor, double clusterOvershoot) {
    this.centroids = searcher;
    this.numClusters = numClusters;
    this.distanceCutoff = distanceCutoff;
    this.beta = beta;
    this.clusterLogFactor = clusterLogFactor;
    this.clusterOvershoot = clusterOvershoot;
  }

  /**
   * Exposes the vectors held by the centroids searcher as Centroids. Each element is cast, which relies on this class
   * only ever inserting Centroid instances into the searcher.
   *
   * @return an Iterator to the Centroids contained in this clusterer.
   */
  @Override
  public Iterator<Centroid> iterator() {
    return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() {
      @Override
      public Centroid apply(Vector input) {
        return (Centroid) input;
      }
    });
  }

  /**
   * Cluster the rows of a matrix, treating them as Centroids with weight 1.
   * @param data matrix whose rows are to be clustered.
   * @return the UpdatableSearcher containing the resulting centroids.
   */
  public UpdatableSearcher cluster(Matrix data) {
    return cluster(Iterables.transform(data, new Function<MatrixSlice, Centroid>() {
      @Override
      public Centroid apply(MatrixSlice input) {
        // The key in a Centroid is actually the MatrixSlice's index.
        return Centroid.create(input.index(), input.vector());
      }
    }));
  }

  /**
   * Cluster the data points in an Iterable.
   * @param datapoints Iterable whose elements are to be clustered.
   * @return the UpdatableSearcher containing the resulting centroids.
   */
  public UpdatableSearcher cluster(Iterable<Centroid> datapoints) {
    return clusterInternal(datapoints, false);
  }

  /**
   * Cluster one data point.
   * @param datapoint to be clustered.
   * @return the UpdatableSearcher containing the resulting centroids.
   */
  public UpdatableSearcher cluster(final Centroid datapoint) {
    // A singleton collection is a one-element Iterable with a contract-abiding, read-only iterator.
    return cluster(Collections.singleton(datapoint));
  }

  /**
   * @return the number of clusters computed from the points until now.
   */
  public int getNumClusters() {
    return centroids.size();
  }

  /**
   * Internal clustering method that gets called from the other wrappers.
   *
   * Besides updating the centroids searcher, this may increase numClusters and distanceCutoff when the number of
   * centroids overshoots and a collapse is triggered. numProcessedDatapoints is advanced once per data point for
   * regular calls and is left unchanged by collapse calls.
   *
   * @param datapoints Iterable of data points to be clustered.
   * @param collapseClusters whether this is an "inner" clustering and the datapoints are the previously computed
   *                         centroids. Some logic is different to ensure counters are consistent but it behaves
   *                         nearly the same.
   * @return the UpdatableSearcher containing the resulting centroids.
   * @throws IllegalStateException if the closest centroid cannot be removed from the searcher for updating.
   */
  private UpdatableSearcher clusterInternal(Iterable<Centroid> datapoints, boolean collapseClusters) {
    Iterator<Centroid> datapointsIterator = datapoints.iterator();
    if (!datapointsIterator.hasNext()) {
      return centroids;
    }

    int oldNumProcessedDataPoints = numProcessedDatapoints;
    // We clear the centroids we have in case of cluster collapse, the old clusters are the
    // datapoints but we need to re-cluster them.
    if (collapseClusters) {
      centroids.clear();
      numProcessedDatapoints = 0;
    }

    if (centroids.size() == 0) {
      // Assign the first datapoint to the first cluster.
      // Adding a vector to a searcher would normally just reference the copy,
      // but we could potentially mutate it and so we need to make a clone.
      centroids.add(datapointsIterator.next().clone());
      ++numProcessedDatapoints;
    }

    // To cluster, we scan the data and either add each point to the nearest group or create a new group.
    // when we get too many groups, we need to increase the threshold and rescan our current groups
    while (datapointsIterator.hasNext()) {
      Centroid row = datapointsIterator.next();
      // Get the closest vector and its weight as a WeightedThing.
      // The weight of the WeightedThing is the distance to the query and the value is a
      // reference to one of the vectors we added to the searcher previously.
      WeightedThing<Vector> closestPair = centroids.searchFirst(row, false);

      // We get a uniformly distributed random number between 0 and 1 and compare it with the
      // distance to the closest cluster (scaled by the point's weight) divided by the distanceCutoff.
      // This is so that if the closest cluster is further than distanceCutoff,
      // closestPair.getWeight() / distanceCutoff > 1 which will trigger the creation of a new
      // cluster anyway.
      // However, if the ratio is less than 1, we want to create a new cluster with probability
      // proportional to the distance to the closest cluster.
      double sample = random.nextDouble();
      if (sample < row.getWeight() * closestPair.getWeight() / distanceCutoff) {
        // Add new centroid, note that the vector is copied because we may mutate it later.
        centroids.add(row.clone());
      } else {
        // Merge the new point with the existing centroid. This will update the centroid's actual
        // position.
        // Every vector this class inserts into the centroids searcher is a Centroid (or a clone of
        // one), so the cast is expected to always succeed.
        Centroid centroid = (Centroid) closestPair.getValue();

        // We will update the centroid by removing it from the searcher and reinserting it to
        // ensure consistency.
        if (!centroids.remove(centroid, Constants.EPSILON)) {
          throw new IllegalStateException("Unable to remove centroid");
        }
        centroid.update(row);
        centroids.add(centroid);
      }
      ++numProcessedDatapoints;

      // Collapsing is never triggered from within a collapse, so the recursion is at most one level deep.
      if (!collapseClusters && centroids.size() > clusterOvershoot * numClusters) {
        numClusters = (int) Math.max(numClusters, clusterLogFactor * Math.log(numProcessedDatapoints));

        // Copy the current centroids out first: the recursive call clears the searcher before re-adding them.
        List<Centroid> shuffled = new ArrayList<>(centroids.size());
        for (Vector vector : centroids) {
          shuffled.add((Centroid) vector);
        }
        // Shuffle with this instance's Random so the collapse order is drawn from the same source
        // as the cluster-creation samples above.
        Collections.shuffle(shuffled, random);
        // Re-cluster using the shuffled centroids as data points. The centroids member variable
        // is modified directly.
        clusterInternal(shuffled, true);

        if (centroids.size() > numClusters) {
          distanceCutoff *= beta;
        }
      }
    }

    // A collapse re-processes existing centroids rather than new data, so restore the real counter.
    if (collapseClusters) {
      numProcessedDatapoints = oldNumProcessedDataPoints;
    }
    return centroids;
  }

  /**
   * Renumbers the centroids by setting each one's index to its position (starting at 0) in this clusterer's
   * iteration order.
   */
  public void reindexCentroids() {
    int numCentroids = 0;
    for (Centroid centroid : this) {
      centroid.setIndex(numCentroids++);
    }
  }

  /**
   * @return the distanceCutoff (an upper bound for the maximum distance within a cluster).
   */
  public double getDistanceCutoff() {
    return distanceCutoff;
  }

  /**
   * Overrides the current distance cutoff used when deciding whether a point starts a new cluster.
   * @param distanceCutoff the new cutoff value.
   */
  public void setDistanceCutoff(double distanceCutoff) {
    this.distanceCutoff = distanceCutoff;
  }

  /**
   * @return the DistanceMeasure used by the underlying centroids searcher.
   */
  public DistanceMeasure getDistanceMeasure() {
    return centroids.getDistanceMeasure();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy