org.apache.mahout.clustering.streaming.cluster.StreamingKMeans Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.streaming.cluster;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.jet.math.Constants;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.WeightedThing;
/**
* Implements a streaming k-means algorithm for weighted vectors.
* The goal is to cluster points one at a time, which is especially useful for MapReduce mappers that get their inputs
* one at a time.
*
* A rough description of the algorithm:
* Suppose there are l clusters at one point and a new point p is added.
* The new point can either be added to one of the existing l clusters or become a new cluster. To decide:
* - let c be the closest cluster to point p;
* - let d be the distance between c and p;
* - if d > distanceCutoff, create a new cluster from p (p is too far away from the clusters to be part of them;
* distanceCutoff represents the largest distance from a point to its assigned cluster's centroid);
* - else (d <= distanceCutoff), create a new cluster with probability d / distanceCutoff (the probability of creating
* a new cluster increases as d increases).
* More precisely, the probability used is min(1, w * d / distanceCutoff), where w is the weight of p, so the two cases
* above describe a point of weight 1.
* There will be either l clusters or l + 1 clusters after processing a new point.
*
* As the number of clusters increases, it will go over the numClusters limit (numClusters represents a recommendation
* for the number of clusters that there should be at the end). To decrease the number of clusters, the existing
* clusters are treated as data points and are re-clustered (collapsed). This tends to make the number of clusters go
* down. If the number of clusters is still too high, distanceCutoff is increased.
*
* For more details, see:
* - "Streaming k-means approximation" by N. Ailon, R. Jaiswal, C. Monteleoni
* http://books.nips.cc/papers/files/nips22/NIPS2009_1085.pdf
* - "Fast and Accurate k-means for Large Datasets" by M. Shindler, A. Wong, A. Meyerson,
* http://books.nips.cc/papers/files/nips24/NIPS2011_1271.pdf
*/
public class StreamingKMeans implements Iterable<Centroid> {
  /**
   * The searcher containing the centroids that resulted from the clustering of points until now. When adding a new
   * point we either assign it to one of the existing clusters in this searcher or create a new centroid for it.
   */
  private final UpdatableSearcher centroids;

  /**
   * The estimated number of clusters to cluster the data in. If the actual number of clusters increases beyond this
   * limit, the clusters will be "collapsed" (re-clustered, by treating them as data points). This doesn't happen
   * recursively and a collapse might not necessarily make the number of actual clusters drop to less than this limit.
   *
   * If the goal is clustering a large data set into k clusters, numClusters SHOULD NOT BE SET to k. StreamingKMeans is
   * useful to reduce the size of the data set by the mappers so that it can fit into memory in one reducer that runs
   * BallKMeans.
   *
   * It is NOT MEANT to cluster the data into k clusters in one pass because it can't guarantee that there will in fact
   * be k clusters in total. This is because of the dynamic nature of numClusters over the course of the runtime.
   * To get an exact number of clusters, another clustering algorithm needs to be applied to the results.
   */
  private int numClusters;

  /**
   * The number of data points seen so far. This is important for re-estimating numClusters when deciding to collapse
   * the existing clusters.
   *
   * This is a long so that streams longer than Integer.MAX_VALUE points cannot overflow it: a negative count would make
   * Math.log(numProcessedDatapoints) NaN, and the (int) cast of that NaN would reset numClusters to 0.
   */
  private long numProcessedDatapoints;

  /**
   * This is the current value of the distance cutoff. Points which are much closer than this to a centroid will stick
   * to it almost certainly. Points of weight 1 or more that are further than this from every centroid will form a new
   * cluster (the probability of creating a new cluster is scaled by the point's weight).
   *
   * This increases (is multiplied by beta) when a cluster collapse did not make the number of clusters drop to below
   * numClusters (it effectively increases the tolerance for cluster compactness discouraging the creation of new
   * clusters). Since a collapse only happens when centroids.size() > clusterOvershoot * numClusters, the cutoff
   * increases when the collapse didn't at least remove the slack in the number of clusters.
   */
  private double distanceCutoff;

  /**
   * Parameter that controls the growth of the distanceCutoff. After n increases of the
   * distanceCutoff starting at d_0, the final value is d_0 * beta^n (distance cutoffs increase following a geometric
   * progression with ratio beta).
   */
  private final double beta;

  /**
   * Multiplying clusterLogFactor with log(numProcessedDatapoints) gets an estimate of the suggested
   * number of clusters. This mirrors the recommended number of clusters for n points where there should be k actual
   * clusters, k * log n. In the case of our estimate we use clusterLogFactor * log(numProcessedDatapoints).
   *
   * It is important to note that numClusters is NOT k. It is an estimate of k * log n.
   */
  private final double clusterLogFactor;

  /**
   * Centroids are collapsed when the number of clusters becomes greater than clusterOvershoot * numClusters. This
   * effectively means having a slack in numClusters so that the actual number of centroids, centroids.size() tracks
   * numClusters approximately. The idea is that the actual number of clusters should be at least numClusters but not
   * much more (so that we don't end up having 1 cluster / point).
   */
  private final double clusterOvershoot;

  /**
   * Random object to sample values from. All randomness in this class (new-cluster sampling and the shuffle done
   * before a collapse) is drawn from this single instance.
   */
  private final Random random = RandomUtils.getRandom();

  /**
   * Calls StreamingKMeans(searcher, numClusters, 1.0 / numClusters, 1.3, 20, 2).
   * @see #StreamingKMeans(UpdatableSearcher, int, double, double, double, double)
   */
  public StreamingKMeans(UpdatableSearcher searcher, int numClusters) {
    this(searcher, numClusters, 1.0 / numClusters, 1.3, 20, 2);
  }

  /**
   * Calls StreamingKMeans(searcher, numClusters, distanceCutoff, 1.3, 20, 2).
   * @see #StreamingKMeans(UpdatableSearcher, int, double, double, double, double)
   */
  public StreamingKMeans(UpdatableSearcher searcher, int numClusters, double distanceCutoff) {
    this(searcher, numClusters, distanceCutoff, 1.3, 20, 2);
  }

  /**
   * Creates a new StreamingKMeans class given a searcher and the number of clusters to generate.
   *
   * @param searcher A Searcher that is used for performing nearest neighbor search. It MUST BE
   *                 EMPTY initially because it will be used to keep track of the cluster
   *                 centroids.
   * @param numClusters An estimated number of clusters to generate for the data points.
   *                    This value may grow during clustering (when clusters are collapsed), and the
   *                    actual number of clusters will depend on the data.
   * @param distanceCutoff The initial distance cutoff representing the value of the
   *                       distance between a point and its closest centroid after which
   *                       the new point (of weight at least 1) will definitely be assigned to a new cluster.
   * @param beta Ratio of geometric progression to use when increasing distanceCutoff. After n increases, distanceCutoff
   *             becomes distanceCutoff * beta^n. A smaller value increases the distanceCutoff less aggressively.
   * @param clusterLogFactor Value multiplied with the logarithm of the number of points counted so far to estimate the
   *                         number of clusters to aim for. If the final number of clusters is known and this
   *                         clustering is only for a sketch of the data, this can be the final number of clusters, k.
   * @param clusterOvershoot Multiplicative slack factor for slowing down the collapse of the clusters: a collapse is
   *                         triggered once the number of centroids exceeds clusterOvershoot * numClusters.
   */
  public StreamingKMeans(UpdatableSearcher searcher, int numClusters,
                         double distanceCutoff, double beta, double clusterLogFactor, double clusterOvershoot) {
    this.centroids = searcher;
    this.numClusters = numClusters;
    this.distanceCutoff = distanceCutoff;
    this.beta = beta;
    this.clusterLogFactor = clusterLogFactor;
    this.clusterOvershoot = clusterOvershoot;
  }

  /**
   * @return an Iterator to the Centroids contained in this clusterer.
   */
  @Override
  public Iterator<Centroid> iterator() {
    return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() {
      @Override
      public Centroid apply(Vector input) {
        // This class only ever adds Centroid instances to the searcher (which must start empty), so the cast is
        // expected to succeed.
        return (Centroid) input;
      }
    });
  }

  /**
   * Cluster the rows of a matrix, treating them as Centroids with weight 1.
   * @param data matrix whose rows are to be clustered.
   * @return the UpdatableSearcher containing the resulting centroids.
   */
  public UpdatableSearcher cluster(Matrix data) {
    return cluster(Iterables.transform(data, new Function<MatrixSlice, Centroid>() {
      @Override
      public Centroid apply(MatrixSlice input) {
        // The key in a Centroid is actually the MatrixSlice's index.
        return Centroid.create(input.index(), input.vector());
      }
    }));
  }

  /**
   * Cluster the data points in an Iterable.
   * @param datapoints Iterable whose elements are to be clustered.
   * @return the UpdatableSearcher containing the resulting centroids.
   */
  public UpdatableSearcher cluster(Iterable<Centroid> datapoints) {
    return clusterInternal(datapoints, false);
  }

  /**
   * Cluster one data point.
   * @param datapoint to be clustered.
   * @return the UpdatableSearcher containing the resulting centroids.
   */
  public UpdatableSearcher cluster(Centroid datapoint) {
    // A singleton list is a one-element Iterable whose iterator honours the Iterator contract
    // (next() throws NoSuchElementException once exhausted; remove() is unsupported).
    return cluster(Collections.singletonList(datapoint));
  }

  /**
   * @return the number of clusters computed from the points until now.
   */
  public int getNumClusters() {
    return centroids.size();
  }

  /**
   * Internal clustering method that gets called from the other wrappers.
   * @param datapoints Iterable of data points to be clustered.
   * @param collapseClusters whether this is an "inner" clustering and the datapoints are the previously computed
   *                         centroids. Some logic is different to ensure counters are consistent but it behaves
   *                         nearly the same.
   * @return the UpdatableSearcher containing the resulting centroids.
   * @throws IllegalStateException if a centroid found by the searcher cannot be removed from it for updating.
   */
  private UpdatableSearcher clusterInternal(Iterable<Centroid> datapoints, boolean collapseClusters) {
    Iterator<Centroid> datapointsIterator = datapoints.iterator();
    if (!datapointsIterator.hasNext()) {
      return centroids;
    }

    long oldNumProcessedDataPoints = numProcessedDatapoints;
    // We clear the centroids we have in case of cluster collapse, the old clusters are the
    // datapoints but we need to re-cluster them.
    if (collapseClusters) {
      centroids.clear();
      numProcessedDatapoints = 0;
    }

    if (centroids.size() == 0) {
      // Assign the first datapoint to the first cluster.
      // Adding a vector to a searcher would normally just reference the copy,
      // but we could potentially mutate it and so we need to make a clone.
      centroids.add(datapointsIterator.next().clone());
      ++numProcessedDatapoints;
    }

    // To cluster, we scan the data and either add each point to the nearest group or create a new group.
    // When we get too many groups, we need to increase the threshold and rescan our current groups.
    while (datapointsIterator.hasNext()) {
      Centroid row = datapointsIterator.next();
      // Get the closest vector and its weight as a WeightedThing.
      // The weight of the WeightedThing is the distance to the query and the value is a
      // reference to one of the vectors we added to the searcher previously.
      WeightedThing<Vector> closestPair = centroids.searchFirst(row, false);

      // We get a uniformly distributed random number between 0 and 1 and compare it with the
      // (weight-scaled) distance to the closest cluster divided by the distanceCutoff.
      // If the ratio exceeds 1 a new cluster is always created; otherwise a new cluster is
      // created with probability proportional to the distance to the closest cluster.
      double sample = random.nextDouble();
      if (sample < row.getWeight() * closestPair.getWeight() / distanceCutoff) {
        // Add new centroid, note that the vector is copied because we may mutate it later.
        centroids.add(row.clone());
      } else {
        // Merge the new point with the existing centroid. This will update the centroid's actual position.
        // Every vector this class inserts into the searcher is a Centroid, so the cast is expected to succeed.
        Centroid centroid = (Centroid) closestPair.getValue();
        // We will update the centroid by removing it from the searcher and reinserting it to
        // ensure consistency.
        if (!centroids.remove(centroid, Constants.EPSILON)) {
          throw new IllegalStateException("Unable to remove centroid");
        }
        centroid.update(row);
        centroids.add(centroid);
      }

      ++numProcessedDatapoints;

      if (!collapseClusters && centroids.size() > clusterOvershoot * numClusters) {
        collapseCentroids();
      }
    }

    // An inner (collapse) pass re-counts the centroids as points; restore the outer count afterwards.
    if (collapseClusters) {
      numProcessedDatapoints = oldNumProcessedDataPoints;
    }
    return centroids;
  }

  /**
   * Re-estimates numClusters, re-clusters the current centroids (treating them as data points), and multiplies
   * distanceCutoff by beta if the collapse did not bring the number of centroids down to numClusters.
   */
  private void collapseCentroids() {
    numClusters = (int) Math.max(numClusters, clusterLogFactor * Math.log(numProcessedDatapoints));

    // Snapshot the centroids before clusterInternal clears the searcher.
    List<Centroid> shuffled = new ArrayList<>(centroids.size());
    for (Vector vector : centroids) {
      shuffled.add((Centroid) vector);
    }
    // Use this instance's Random rather than Collections' internal default one, so that all randomness comes from
    // the single `random` field.
    Collections.shuffle(shuffled, random);

    // Re-cluster using the shuffled centroids as data points. The centroids member variable
    // is modified directly.
    clusterInternal(shuffled, true);

    if (centroids.size() > numClusters) {
      distanceCutoff *= beta;
    }
  }

  /**
   * Assigns consecutive indices 0, 1, 2, ... to the current centroids, in the order returned by {@link #iterator()}.
   */
  public void reindexCentroids() {
    int nextIndex = 0;
    for (Centroid centroid : this) {
      centroid.setIndex(nextIndex++);
    }
  }

  /**
   * @return the current distanceCutoff (the distance beyond which a unit-weight point always starts a new cluster).
   */
  public double getDistanceCutoff() {
    return distanceCutoff;
  }

  /**
   * Overrides the current distance cutoff; subsequent points are compared against this value.
   * @param distanceCutoff the new distance cutoff.
   */
  public void setDistanceCutoff(double distanceCutoff) {
    this.distanceCutoff = distanceCutoff;
  }

  /**
   * @return the distance measure used by the underlying centroid searcher.
   */
  public DistanceMeasure getDistanceMeasure() {
    return centroids.getDistanceMeasure();
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy