/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.streaming.cluster;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.WeightedVector;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.Multinomial;
import org.apache.mahout.math.random.WeightedThing;
/**
* Implements a ball k-means algorithm for weighted vectors with probabilistic seeding similar to k-means++.
* The idea is that k-means++ gives good starting clusters and ball k-means can tune up the final result very nicely
* in only a few passes (or even in a single iteration for well-clusterable data).
*
* A good reference for this class of algorithms is "The Effectiveness of Lloyd-Type Methods for the k-Means Problem"
* by Rafail Ostrovsky, Yuval Rabani, Leonard J. Schulman and Chaitanya Swamy. The code here uses the seeding strategy
* as described in section 4.1.1 of that paper and the ball k-means step as described in section 4.2. We support
* multiple iterations in contrast to the algorithm described in the paper.
*/
public class BallKMeans implements Iterable<Centroid> {
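/*
* Example usage (an illustrative sketch, not part of the original class): build a searcher, then
* cluster weighted points. BruteSearch, EuclideanDistanceMeasure, DenseVector and Centroid are
* existing Mahout classes; the data and parameter choices below are made up for illustration.
*
*   List<Centroid> points = new ArrayList<>();
*   Random rand = RandomUtils.getRandom();
*   for (int i = 0; i < 1000; i++) {
*     Vector v = new DenseVector(new double[] {rand.nextGaussian(), rand.nextGaussian()});
*     points.add(new Centroid(i, v, 1.0));  // unit weight for every point
*   }
*   // 10 clusters, at most 20 ball k-means iterations.
*   BallKMeans clusterer = new BallKMeans(new BruteSearch(new EuclideanDistanceMeasure()), 10, 20);
*   clusterer.cluster(points);
*   for (Centroid c : clusterer) {
*     System.out.println(c.getIndex() + " " + c.getVector() + " weight=" + c.getWeight());
*   }
*/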
/**
* The searcher containing the centroids.
*/
private final UpdatableSearcher centroids;
/**
* The number of clusters to cluster the data into.
*/
private final int numClusters;
/**
* The maximum number of iterations of the algorithm to run, waiting for the cluster assignments
* to stabilize. If the assignments stop changing sooner, the algorithm finishes early.
*/
private final int maxNumIterations;
/**
* When deciding which points to include in the new centroid calculation,
* it's preferable to exclude outliers since doing so increases the rate of convergence.
* So, we calculate the distance from each cluster to its closest neighboring cluster. When
* evaluating the points assigned to a cluster, we compare the distance between the centroid to
* the point with the distance between the centroid and its closest centroid neighbor
* multiplied by this trimFraction. If the distance between the centroid and the point is
* greater, we consider it an outlier and we don't use it.
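* In symbols: a point x assigned to cluster c_i contributes to the updated centroid only if
* d(c_i, x) < trimFraction * min_{j != i} d(c_i, c_j).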
*/
private final double trimFraction;
/**
* Selecting the initial centroids is the most important part of ball k-means clustering. Poor choices, like two
* centroids in the same actual cluster, result in a low-quality final clustering.
* k-means++ initialization yields good quality clusters, especially when using BallKMeans after StreamingKMeans as
* the points have weights.
* Simple, random selection of the points based on their weights is faster but sometimes fails to produce the
* desired number of clusters.
* This field is true if the initialization should be done with k-means++.
*/
private final boolean kMeansPlusPlusInit;
/**
* When using trimFraction, the weight of each centroid will not be the sum of the weights of
* the vectors assigned to that cluster because outliers are not used to compute the updated
* centroid.
* So, the total weight is probably wrong. This can be fixed by doing another pass over the
* data points and adjusting the weights of each centroid. This doesn't update the coordinates
* of the centroids, but is useful if the weights matter.
*/
private final boolean correctWeights;
/**
* When running multiple ball k-means passes to pick the one with the smallest total cost, we can compute the
* overall cost using all the points, or reserve a fraction of them (testProbability) as a test set.
* The cost is the sum of the distances between each point and its corresponding centroid.
* The test set is then used to compute the total cost, so we select the clustering that best
* describes the underlying distribution of the clusters.
* This field is the probability of assigning a given point to the test set. If this is 0, the cost will be computed
* on the entire set of points.
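* For example, with testProbability = 0.1 and 1000 points, 100 points are held out and the winning
* run is the one with the smallest summed point-to-centroid distance over those held-out points.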
*/
private final double testProbability;
/**
* Whether or not testProbability > 0, i.e., there exists a non-empty 'test' set.
*/
private final boolean splitTrainTest;
/**
* How many k-means runs to have. If there's more than one run, we compute the cost of each clustering as described
* above and select the clustering that minimizes the cost.
* Multiple runs are a lot more useful when using random initialization. With k-means++, 1-2 runs are enough
* and more runs are not likely to improve quality much.
*/
private final int numRuns;
/**
* Random object to sample values from.
*/
private final Random random;
public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations) {
// By default, the trimFraction is 0.9, k-means++ is used, the weights will be corrected at the end,
// there will be 0 points in the test set and 1 run.
this(searcher, numClusters, maxNumIterations, 0.9, true, true, 0.0, 1);
}
public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations,
boolean kMeansPlusPlusInit, int numRuns) {
// By default, the trimFraction is 0.9, k-means++ is used, the weights will be corrected at the end,
// and 10% of the points will be in the test set.
this(searcher, numClusters, maxNumIterations, 0.9, kMeansPlusPlusInit, true, 0.1, numRuns);
}
public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations,
double trimFraction, boolean kMeansPlusPlusInit, boolean correctWeights,
double testProbability, int numRuns) {
Preconditions.checkArgument(searcher.size() == 0, "Searcher must be empty initially to populate with centroids");
Preconditions.checkArgument(numClusters > 0, "The requested number of clusters must be positive");
Preconditions.checkArgument(maxNumIterations > 0, "The maximum number of iterations must be positive");
Preconditions.checkArgument(trimFraction > 0, "The trim fraction must be positive");
Preconditions.checkArgument(testProbability >= 0 && testProbability < 1, "The testProbability must be in [0, 1)");
Preconditions.checkArgument(numRuns > 0, "There has to be at least one run");
this.centroids = searcher;
this.numClusters = numClusters;
this.maxNumIterations = maxNumIterations;
this.trimFraction = trimFraction;
this.kMeansPlusPlusInit = kMeansPlusPlusInit;
this.correctWeights = correctWeights;
this.testProbability = testProbability;
this.splitTrainTest = testProbability > 0;
this.numRuns = numRuns;
this.random = RandomUtils.getRandom();
}
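/**
* Splits the datapoints into a training set and a test set, holding out about testProbability of the
* points for testing. Note that the input list is shuffled in place as a side effect.
*
* @param datapoints the points to split; this list is shuffled by this method.
* @return a pair of (training points, test points); the test list is empty when testProbability is 0.
*/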
public Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> splitTrainTest(
List<? extends WeightedVector> datapoints) {
// If there will be no points assigned to the test set, return now.
if (testProbability == 0) {
return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>(datapoints,
new ArrayList<WeightedVector>());
}
int numTest = (int) (testProbability * datapoints.size());
// Note: Guava's Preconditions message templates only support %s placeholders.
Preconditions.checkArgument(numTest > 0 && numTest < datapoints.size(),
"Must have nonzero number of training and test vectors. Asked for %s%% of %s vectors for test",
testProbability * 100, datapoints.size());
Collections.shuffle(datapoints);
return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>(
datapoints.subList(numTest, datapoints.size()), datapoints.subList(0, numTest));
}
/**
* Clusters the datapoints in the list using either random seeding of the centroids or k-means++.
*
* @param datapoints the points to be clustered.
* @return an UpdatableSearcher with the resulting clusters.
*/
public UpdatableSearcher cluster(List<? extends WeightedVector> datapoints) {
Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> trainTestSplit = splitTrainTest(datapoints);
List<Centroid> bestCentroids = new ArrayList<>();
double cost = Double.POSITIVE_INFINITY;
double bestCost = Double.POSITIVE_INFINITY;
for (int i = 0; i < numRuns; ++i) {
centroids.clear();
if (kMeansPlusPlusInit) {
// Use k-means++ to set initial centroids.
initializeSeedsKMeansPlusPlus(trainTestSplit.getFirst());
} else {
// Randomly select the initial centroids.
initializeSeedsRandomly(trainTestSplit.getFirst());
}
// Do k-means iterations with trimmed mean computation (aka ball k-means).
if (numRuns > 1) {
// Assign points and update centroids iteratively. The run is only useful if the clustering
// is successful (there are no zero-weight centroids).
iterativeAssignment(trainTestSplit.getFirst());
// Compute the cost of the clustering and possibly save the centroids. The cost is evaluated on
// the held-out test set when there is one; otherwise it falls back to all of the points.
cost = ClusteringUtils.totalClusterCost(
splitTrainTest ? trainTestSplit.getSecond() : datapoints, centroids);
if (cost < bestCost) {
bestCost = cost;
bestCentroids.clear();
Iterables.addAll(bestCentroids, centroids);
}
} else {
// If there is only going to be one run, the cost doesn't need to be computed, so we just return the clustering.
iterativeAssignment(datapoints);
return centroids;
}
}
if (bestCost == Double.POSITIVE_INFINITY) {
throw new RuntimeException("No valid clustering was found");
}
if (cost != bestCost) {
centroids.clear();
centroids.addAll(bestCentroids);
}
if (correctWeights) {
for (WeightedVector testDatapoint : trainTestSplit.getSecond()) {
WeightedVector closest = (WeightedVector) centroids.searchFirst(testDatapoint, false).getValue();
closest.setWeight(closest.getWeight() + testDatapoint.getWeight());
}
}
return centroids;
}
/**
* Selects some of the original points randomly with probability proportional to their weights. This is much
* less sophisticated than the k-means++ approach, but it is faster and, coupled with multiple runs,
* can still produce reasonable clusterings.
*
* The side effect of this method is to fill the centroids structure itself.
*
* @param datapoints The datapoints to select from. These datapoints should be WeightedVectors of some kind.
*/
private void initializeSeedsRandomly(List<? extends WeightedVector> datapoints) {
int numDatapoints = datapoints.size();
double totalWeight = 0;
for (WeightedVector datapoint : datapoints) {
totalWeight += datapoint.getWeight();
}
Multinomial<Integer> seedSelector = new Multinomial<>();
for (int i = 0; i < numDatapoints; ++i) {
seedSelector.add(i, datapoints.get(i).getWeight() / totalWeight);
}
for (int i = 0; i < numClusters; ++i) {
int sample = seedSelector.sample();
seedSelector.delete(sample);
Centroid centroid = new Centroid(datapoints.get(sample));
centroid.setIndex(i);
centroids.add(centroid);
}
}
/**
* Selects some of the original points according to the k-means++ algorithm. The basic idea is that
* points are selected with probability proportional to their distance to the nearest previously selected point. In
* this version, points have weights which multiply their likelihood of being selected. This is the
* same as if there were as many copies of the same point as indicated by the weight.
*
* This is pretty expensive, but it vastly improves the quality and convergences of the k-means algorithm.
* The basic idea can be made much faster by only processing a random subset of the original points.
* In the context of streaming k-means, the total number of possible seeds will be about k log n so this
* selection will cost O(k^2 (log n)^2) which isn't much worse than the random sampling idea. At
* n = 10^9, the cost of this initialization will be about 10x worse than a reasonable random sampling
* implementation.
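* For instance, with n = 10^9 and k = 100, log n is about 21, so there are roughly k log n = 2100
* candidate seeds and about k^2 (log n)^2 = 4.4 million distance computations.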
*
* The side effect of this method is to fill the centroids structure itself.
*
* @param datapoints The datapoints to select from. These datapoints should be WeightedVectors of some kind.
*/
private void initializeSeedsKMeansPlusPlus(List<? extends WeightedVector> datapoints) {
Preconditions.checkArgument(datapoints.size() > 1, "Must have at least two datapoints to cluster " +
"sensibly");
Preconditions.checkArgument(datapoints.size() >= numClusters,
String.format("Must have at least as many datapoints [%d] as clusters [%d]", datapoints.size(), numClusters));
// Compute the centroid of all of the datapoints. This is then used to compute the squared radius of the datapoints.
Centroid center = new Centroid(datapoints.iterator().next());
for (WeightedVector row : Iterables.skip(datapoints, 1)) {
center.update(row);
}
// Given the centroid, we can compute \Delta_1^2(X), the total squared distance for the datapoints.
// This accelerates seed selection.
double deltaX = 0;
DistanceMeasure distanceMeasure = centroids.getDistanceMeasure();
for (WeightedVector row : datapoints) {
deltaX += distanceMeasure.distance(row, center);
}
// Find the first seed c_1 (and conceptually the second, c_2) as might be done in the 2-means clustering so that
// the probability of selecting c_1 and c_2 is proportional to || c_1 - c_2 ||^2. This is done
// by first selecting c_1 with probability:
//
// p(c_1) = \sum_{c_2} || c_1 - c_2 ||^2 / \sum_{c_1, c_2} || c_1 - c_2 ||^2
//
// Using the identity \sum_{x, y} || x - y ||^2 = 2 n \sum_x || x - c ||^2, this simplifies to:
//
// p(c_1) = (\Delta_1^2(X) + n || c_1 - c ||^2) / (2 n \Delta_1^2(X))
//
// where c = (\sum_x x) / n and \Delta_1^2(X) = \sum_x || x - c ||^2
//
// All subsequent seeds c_i (including c_2) can then be selected from the remaining points with probability
// proportional to Pr(c_i == x_j) = min_{m < i} || c_m - x_j ||^2.
// Multinomial distribution over vector indices used for seed selection. These correspond to
// the indices of the vectors in the original datapoints list.
Multinomial<Integer> seedSelector = new Multinomial<>();
for (int i = 0; i < datapoints.size(); ++i) {
double selectionProbability =
deltaX + datapoints.size() * distanceMeasure.distance(datapoints.get(i), center);
seedSelector.add(i, selectionProbability);
}
// Sample c_1 according to the probabilities just loaded into seedSelector.
int selected = seedSelector.sample();
Centroid c_1 = new Centroid(datapoints.get(selected).clone());
c_1.setIndex(0);
// Construct a set of weighted things which can be used for random selection. Initial weights are
// based on the distance from c_1, scaled up by the log of each point's weight.
for (int i = 0; i < datapoints.size(); ++i) {
WeightedVector row = datapoints.get(i);
double w = distanceMeasure.distance(c_1, row) * 2 * Math.log(1 + row.getWeight());
seedSelector.set(i, w);
}
// From here, seeds are selected with probability proportional to:
//
// r_i = min_{c_j} || x_i - c_j ||^2
//
// While we only have c_1, the loop above has already set these distances; as we select each new
// seed, we update the minimum distances.
centroids.add(c_1);
int clusterIndex = 1;
while (centroids.size() < numClusters) {
// Select according to weights.
int seedIndex = seedSelector.sample();
Centroid nextSeed = new Centroid(datapoints.get(seedIndex));
nextSeed.setIndex(clusterIndex++);
centroids.add(nextSeed);
// Don't select this one again.
seedSelector.delete(seedIndex);
// Re-weight everything according to the minimum distance to a seed.
for (int currSeedIndex : seedSelector) {
WeightedVector curr = datapoints.get(currSeedIndex);
double newWeight = nextSeed.getWeight() * distanceMeasure.distance(nextSeed, curr);
if (newWeight < seedSelector.getWeight(currSeedIndex)) {
seedSelector.set(currSeedIndex, newWeight);
}
}
}
}
/**
* Examines the datapoints and updates cluster centers to be the centroid of the nearest datapoints. To
* compute a new center for cluster c_i, we average all points that are closer to c_i than d_i * trimFraction,
* where d_i is
*
* d_i = min_{j != i} || c_j - c_i ||
*
* By ignoring distant points, the centroids converge more quickly to a good approximation of the
* optimal k-means solution (given good starting points).
*
* @param datapoints the points to cluster.
*/
private void iterativeAssignment(List<? extends WeightedVector> datapoints) {
DistanceMeasure distanceMeasure = centroids.getDistanceMeasure();
// closestClusterDistances.get(i) is the distance from the i'th cluster to its closest
// neighboring cluster.
List<Double> closestClusterDistances = new ArrayList<>(numClusters);
// clusterAssignments.get(i) == j means that the i'th point is assigned to the j'th cluster. When
// these don't change, we are done.
// Each point is assigned to the invalid "-1" cluster initially.
List<Integer> clusterAssignments = new ArrayList<>(Collections.nCopies(datapoints.size(), -1));
boolean changed = true;
for (int i = 0; changed && i < maxNumIterations; i++) {
changed = false;
// We compute what the distance between each cluster and its closest neighbor is to set a
// proportional distance threshold for points that should be involved in calculating the
// centroid.
closestClusterDistances.clear();
for (Vector center : centroids) {
// If a centroid has no points assigned to it, the clustering failed.
Vector closestOtherCluster = centroids.searchFirst(center, true).getValue();
closestClusterDistances.add(distanceMeasure.distance(center, closestOtherCluster));
}
// Copies the current cluster centroids to newClusters and sets their weights to 0. This is
// so we calculate the new centroids as we go through the datapoints.
List<Centroid> newCentroids = new ArrayList<>();
for (Vector centroid : centroids) {
// need a deep copy because we will mutate these values
Centroid newCentroid = (Centroid)centroid.clone();
newCentroid.setWeight(0);
newCentroids.add(newCentroid);
}
// Pass over the datapoints computing new centroids.
for (int j = 0; j < datapoints.size(); ++j) {
WeightedVector datapoint = datapoints.get(j);
// Get the closest cluster this point belongs to.
WeightedThing<Vector> closestPair = centroids.searchFirst(datapoint, false);
int closestIndex = ((WeightedVector) closestPair.getValue()).getIndex();
double closestDistance = closestPair.getWeight();
// Update its cluster assignment if necessary.
if (closestIndex != clusterAssignments.get(j)) {
changed = true;
clusterAssignments.set(j, closestIndex);
}
// Only update if the datapoint is near enough. What this means is that the weight
// of outliers is NOT taken into account and the final weights of the centroids will
// reflect this (it will be less or equal to the initial sum of the weights).
if (closestDistance < trimFraction * closestClusterDistances.get(closestIndex)) {
newCentroids.get(closestIndex).update(datapoint);
}
}
// Add the new centers back into searcher.
centroids.clear();
centroids.addAll(newCentroids);
}
if (correctWeights) {
for (Vector v : centroids) {
((Centroid)v).setWeight(0);
}
for (WeightedVector datapoint : datapoints) {
Centroid closestCentroid = (Centroid) centroids.searchFirst(datapoint, false).getValue();
closestCentroid.setWeight(closestCentroid.getWeight() + datapoint.getWeight());
}
}
}
@Override
public Iterator<Centroid> iterator() {
return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() {
@Override
public Centroid apply(Vector input) {
Preconditions.checkArgument(input instanceof Centroid, "Non-centroid in centroids " +
"searcher");
//noinspection ConstantConditions
return (Centroid)input;
}
});
}
}