![JAR search and dependency download from the Maven repository](/logo.png)
gov.sandia.cognition.learning.algorithm.clustering.MiniBatchKMeansClusterer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: MiniBatchKMeansClusterer.java
* Authors: Jeff Piersol
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright April 1, 2016, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.clustering;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.ClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.MiniBatchCentroidCluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.VectorMeanMiniBatchCentroidClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.CentroidClusterDivergenceFunction;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.FixedClusterInitializer;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.GreedyClusterInitializer;
import gov.sandia.cognition.learning.function.distance.EuclideanDistanceMetric;
import gov.sandia.cognition.math.Semimetric;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.DiscreteSamplingUtil;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.RandomAccess;
import static java.util.stream.Collectors.*;
import java.util.stream.IntStream;
import java.util.stream.Stream;
/**
 * Approximates k-means clustering by working on random subsets
 * (mini-batches) of the data. This method is particularly useful for large
 * data sets.
 *
 * For spherical k-means, use a Cosine distance semimetric for the
 * divergence function, normalize all input data, and normalize each centroid
 * after the algorithm converges.
 *
 * @param <DataType> The type of the data to cluster. This is typically
 * defined by the divergence function used.
 * @author Jeff Piersol
 * @since 4.0.0
 */
@PublicationReference(
    author = "Jeff Piersol",
    title = "Parallel Mini-Batch k-means Clustering",
    type = PublicationType.Conference,
    year = 2016,
    publication = "to appear",
    url = "to appear"
)
public class MiniBatchKMeansClusterer
    extends KMeansClusterer
    implements Randomized
{
    private static final long serialVersionUID = 2587013040037999607L;

    /**
     * The default maximum number of iterations is {@value}.
     */
    public static final int DEFAULT_MAX_ITERATIONS = 100000;

    /**
     * The random number generator to use for initialization and subset
     * selection.
     */
    protected Random random;

    /**
     * The size of the mini-batches. A value less than or equal to zero
     * requests the heuristic applied in {@link #initializeAlgorithm()}.
     */
    private int minibatchSize;

    /**
     * Indices of the data: the range of ints [0, data.size()), rebuilt by
     * {@link #setData(Collection)}. Mini-batches are sampled from this list.
     */
    protected List dataIndices;

    /**
     * Early-stopping threshold. After each step, iteration continues only
     * while the fraction of sampled points that changed assignment is
     * strictly greater than this number.
     */
    private double stoppingCriterion = 0.01;

    /**
     * Create a clusterer with the default parameters. This is a 'vanilla'
     * mini-batch k-means, using Euclidean distance for the semimetric.
     *
     * @param numClusters the number of clusters to output
     */
    public MiniBatchKMeansClusterer(int numClusters)
    {
        this(numClusters, new Random(),
            VectorMeanMiniBatchCentroidClusterCreator.INSTANCE);
    }

    /**
     * Creates a clusterer that uses Euclidean distance, a greedy initializer,
     * and {@link #DEFAULT_MAX_ITERATIONS}. The same random number generator
     * is handed to the initializer and used for mini-batch sampling.
     *
     * @param numClusters the number of clusters to output
     * @param random the random number generator to use
     * @param creator the cluster creator to use
     */
    private MiniBatchKMeansClusterer(int numClusters,
        Random random,
        ClusterCreator creator)
    {
        this(
            numClusters, DEFAULT_MAX_ITERATIONS,
            new GreedyClusterInitializer<>(EuclideanDistanceMetric.INSTANCE,
                creator, random),
            EuclideanDistanceMetric.INSTANCE, creator, random
        );
    }

    /**
     * Creates a new {@link MiniBatchKMeansClusterer}.
     *
     * @param numClusters the number of clusters to create
     * @param maxIterations the number of iterations before stopping
     * @param initializer sets the initial centroids
     * @param metric the metric to use
     * @param creator the cluster creator to use
     * @param random the random number generator to use
     */
    public MiniBatchKMeansClusterer(
        int numClusters,
        int maxIterations,
        FixedClusterInitializer initializer,
        Semimetric super Vector> metric,
        ClusterCreator creator,
        Random random)
    {
        super(numClusters, maxIterations, initializer,
            new CentroidClusterDivergenceFunction<>(metric), creator);
        this.setRandom(random);
    }

    /**
     * Creates a copy of this clusterer whose random number generator is an
     * independent copy of this one's.
     *
     * @return the copy
     */
    @Override
    @SuppressWarnings("unchecked")
    public MiniBatchKMeansClusterer clone()
    {
        final MiniBatchKMeansClusterer result
            = (MiniBatchKMeansClusterer) super.clone();
        // The copy must receive the cloned generator. Assigning it to
        // this.random (as was done previously) mutated the object being
        // cloned and left the copy sharing the original generator.
        result.random = ObjectUtil.cloneSmart(this.random);
        return result;
    }

    /**
     * Runs the superclass initialization and, if it succeeds, resolves the
     * mini-batch size: zero when there are no clusters, otherwise the
     * configured size, or min(number of elements, 10000) when the configured
     * size is not positive.
     *
     * @return the result of the superclass initialization
     */
    @Override
    protected boolean initializeAlgorithm()
    {
        final boolean initialized = super.initializeAlgorithm();
        if (initialized)
        {
            if (getNumClusters() < 1)
            {
                minibatchSize = 0;
            }
            else if (minibatchSize <= 0)
            {
                // Ad-hoc default (not derived from a reference): use the
                // whole data set, capped at 10000 points per mini-batch.
                // NOTE(review): the computed value is stored back into the
                // field, so a later run on different data reuses it instead
                // of re-applying the heuristic.
                minibatchSize = Math.min(getNumElements(), 10000);
            }
        }
        return initialized;
    }

    /**
     * Do a step of the clustering algorithm: sample a mini-batch, assign
     * each sampled point to its closest cluster, update the centroids of the
     * clusters that received samples, and record changed assignments.
     *
     * @return true means keep going, false means stop clustering.
     */
    @Override
    protected boolean step()
    {
        List extends DataType> data = getData();

        // Draw a mini-batch of data indices (with replacement) and look up
        // the corresponding points.
        ArrayList sampleIndices
            = DiscreteSamplingUtil.sampleWithReplacement(random, dataIndices,
                minibatchSize);
        List samples
            = sampleIndices.stream().map(data::get).collect(toList());

        // Assign each sampled point to a cluster, given the current
        // location of the clusters.
        int[] sampleAssignments = this.assignDataToClusters(samples);

        // Bin all of the samples into their clusters.
        Map> samplesInCluster
            = IntStream.range(0, sampleIndices.size()).parallel()
                .mapToObj(Integer::valueOf)
                .collect(groupingByConcurrent(
                    idx -> clusters.get(sampleAssignments[idx]),
                    mapping(idx -> samples.get(idx), toList())
                ));

        // Update the centroid of every cluster that received samples; each
        // cluster is updated only from its own bin.
        samplesInCluster.entrySet().stream().parallel().forEach(
            (Map.Entry> clusterAndSamples)
            -> clusterAndSamples.getKey().updateCluster(
                clusterAndSamples.getValue())
        );

        // Record the new assignments of the sampled points and count how
        // many of them changed.
        int numChanged = 0;
        for (int i = 0; i < sampleAssignments.length; i++)
        {
            int assignment = sampleAssignments[i];
            if (this.setAssignment(sampleIndices.get(i), assignment))
            {
                numChanged++;
            }
        }
        this.setNumChanged(numChanged);

        // Continue iterating if a significant number of points changed
        // assignment.
        return this.getNumChanged() / (double) minibatchSize > stoppingCriterion;
    }

    /**
     * Saves the final clustering for each data point: every point is
     * assigned to its closest cluster, and each cluster's member list is
     * replaced by the points assigned to it. Does nothing if there are no
     * clusters.
     */
    protected void saveFinalClustering()
    {
        if (clusters.size() > 0)
        {
            List extends DataType> data = getData();
            assignments = assignDataToClusters(data);
            clusters.forEach(cluster -> cluster.getMembers().clear());
            IntStream.range(0, assignments.length).parallel()
                .mapToObj(Integer::valueOf)
                .collect(
                    groupingByConcurrent(idx -> assignments[idx]))
                .forEach((clusterIdx, clusterPoints)
                    -> clusters.get(clusterIdx).getMembers()
                        .addAll(
                            clusterPoints.stream()
                                .map(idx -> data.get(idx))
                                .collect(toList())));
        }
    }

    /**
     * Stores the final clustering once iteration ends.
     */
    @Override
    protected void cleanupAlgorithm()
    {
        saveFinalClustering();
    }

    @Override
    public Random getRandom()
    {
        return random;
    }

    @Override
    public final void setRandom(Random random)
    {
        this.random = random;
    }

    /**
     * Gets the data being clustered. The data is stored as a list by
     * {@link #setData(Collection)}.
     *
     * @return the data being clustered
     */
    @Override
    @SuppressWarnings("unchecked")
    public List extends DataType> getData()
    {
        return (List extends DataType>) super.getData();
    }

    /**
     * Set the data to be clustered. If the data is not a
     * {@link RandomAccess} {@link List}, it will be copied into one. A null
     * collection is treated as empty.
     *
     * @param data the data to cluster
     */
    @Override
    public void setData(Collection extends Vector> data)
    {
        if (data == null)
        {
            data = new ArrayList<>();
        }
        super.setData(data instanceof List && data instanceof RandomAccess
            ? data
            : new ArrayList<>(data));
        this.dataIndices = IntStream.range(0, data.size()).boxed().collect(
            toList());
    }

    /**
     * Get the stopping criterion for this clusterer. See
     * {@link #setStoppingCriterion(double)} for details on the criterion.
     *
     * @return the stopping criterion
     */
    public double getStoppingCriterion()
    {
        return stoppingCriterion;
    }

    /**
     * Set the stopping criterion for this clusterer.
     *
     * @param stoppingCriterion iteration continues only while the fraction
     * of sampled points that changed assignment in a step is greater than
     * this number. Set this to a negative value to disable early stopping.
     */
    public void setStoppingCriterion(double stoppingCriterion)
    {
        this.stoppingCriterion = stoppingCriterion;
    }

    /**
     * Get the size of the mini-batches used. After clustering has been
     * initialized, this reflects the size actually used (including the
     * heuristic value, if one was computed).
     *
     * @return the mini-batch size
     */
    public int getMinibatchSize()
    {
        return minibatchSize;
    }

    /**
     * Set the size of the mini-batches. If the size is less than or equal to
     * zero, a heuristic will be used to compute the size before clustering.
     *
     * @param minibatchSize the mini-batch size
     */
    public void setMinibatchSize(int minibatchSize)
    {
        this.minibatchSize = minibatchSize;
    }

    /**
     * Assigns each point to the index of its closest cluster.
     *
     * @param data the points to assign
     * @return the closest-cluster index of each point, in iteration order
     */
    @Override
    protected int[] assignDataToClusters(
        Collection extends Vector> data)
    {
        // Parallelize if there are more than a few data points
        Stream extends Vector> dataStream = data.size() > 25
            ? data.parallelStream() : data.stream();
        return dataStream.mapToInt(point -> this.getClosestClusterIndex(point)).toArray();
    }

    /**
     * Can be used to create custom {@link MiniBatchKMeansClusterer}s without
     * using the big constructor.
     *
     * @param <DataType> the type of the data to cluster
     */
    public static class Builder
    {
        private int numClusters, maxIterations, minibatchSize;

        /**
         * The explicitly configured initializer, or null to create a
         * {@link GreedyClusterInitializer} from the current metric, creator
         * and random number generator when {@link #build()} is called.
         */
        private FixedClusterInitializer initializer;

        private Semimetric super Vector> metric;

        private ClusterCreator creator;

        private Random random;

        /**
         * Create a mini-batch k-means clusterer builder and set it to
         * the given number of clusters. Unless another initializer is
         * supplied, centroids will be initialized using a
         * {@link GreedyClusterInitializer}.
         *
         * @param numClusters the number of clusters to create
         */
        @SuppressWarnings("unchecked")
        public Builder(int numClusters)
        {
            this(numClusters, EuclideanDistanceMetric.INSTANCE);
        }

        /**
         * Create a mini-batch k-means clusterer builder and set it to
         * the given number of clusters. Unless another initializer is
         * supplied, centroids will be initialized using a
         * {@link GreedyClusterInitializer}, and the given metric will be used
         * to measure all distances.
         *
         * @param numClusters the number of clusters to create
         * @param metric the semimetric to use to measure distances
         */
        @SuppressWarnings("unchecked")
        public Builder(int numClusters,
            Semimetric super Vector> metric)
        {
            this.numClusters = numClusters;
            this.maxIterations = DEFAULT_MAX_ITERATIONS;
            this.random = new Random();
            this.creator = VectorMeanMiniBatchCentroidClusterCreator.INSTANCE;
            this.metric = metric;
            // The default initializer is created in build(), so that it uses
            // whatever creator and random generator are configured by then.
            this.initializer = null;
        }

        /**
         * Builds the clusterer. If no initializer was supplied, a
         * {@link GreedyClusterInitializer} is created from the builder's
         * current metric, creator and random number generator.
         *
         * @return
         * The newly built clusterer.
         */
        @SuppressWarnings("unchecked")
        public MiniBatchKMeansClusterer build()
        {
            FixedClusterInitializer effectiveInitializer = initializer;
            if (effectiveInitializer == null)
            {
                effectiveInitializer = new GreedyClusterInitializer<>(metric,
                    creator, random);
            }
            MiniBatchKMeansClusterer clusterer
                = new MiniBatchKMeansClusterer<>(numClusters, maxIterations,
                    effectiveInitializer, metric, creator, random);
            clusterer.setMinibatchSize(minibatchSize);
            return clusterer;
        }

        /**
         * @param numClusters the number of clusters to create
         * @return the builder
         */
        public Builder withNumClusters(int numClusters)
        {
            this.numClusters = numClusters;
            return this;
        }

        /**
         * @param maxIterations the number of iterations before stopping
         * @return the builder
         */
        public Builder withMaxIterations(int maxIterations)
        {
            this.maxIterations = maxIterations;
            return this;
        }

        /**
         * @param minibatchSize the mini-batch size
         * @return the builder
         * @see MiniBatchKMeansClusterer#setMinibatchSize(int)
         */
        public Builder withMinibatchSize(int minibatchSize)
        {
            this.minibatchSize = minibatchSize;
            return this;
        }

        /**
         * @param initializer sets the initial centroids; null restores the
         * default greedy initializer
         * @return the builder
         */
        public Builder withInitializer(
            FixedClusterInitializer initializer)
        {
            this.initializer = initializer;
            return this;
        }

        /**
         * @param creator the cluster creator to use (also used by the
         * default initializer)
         * @return the builder
         */
        public Builder withCreator(
            ClusterCreator creator)
        {
            this.creator = creator;
            return this;
        }

        /**
         * @param random the random number generator to use (also used by the
         * default initializer)
         * @return the builder
         */
        public Builder withRandom(Random random)
        {
            this.random = random;
            return this;
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy