// Source listing: gov.sandia.cognition.learning.algorithm.clustering.DBSCANClusterer
// Artifact: cognitive-foundry - a single jar with all the Cognitive Foundry components.
/*
* File: DBSCANClusterer.java
* Authors: Quinn McNamara
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright August 16, 2016, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.clustering;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.Cluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.ClusterCreator;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.Metric;
import gov.sandia.cognition.math.Semimetric;
import gov.sandia.cognition.math.geometry.KDTree;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
/**
* The DBSCAN algorithm requires three parameters: a distance
* metric, a value for neighborhood radius, and a value for the minimum number
* of surrounding neighbors for a point to be considered non-noise. It clusters
* by iterating point-by-point and grouping points that are close together (in
* the same neighborhood). Points that are not in any neighborhood are labeled
* as noise. Noise points are grouped into the first resultant cluster.
*
* This implementation conditionally uses a KD tree to store the data points and
* perform efficient queries for neighborhoods. The KD tree is only used when
* the metric is a {@link Metric} (not a {@link Semimetric} like
* CosineDistanceMetric). When one of these conditions is not met, neighborhood
* querying has O(n) complexity, giving the overall algorithm O(n^2) time
* complexity. If a KD tree is used, queries have O(logn) complexity, giving the
* overall algorithm O(n logn) complexity.
*
* @param <DataType> The type of the data to cluster. This is typically defined
* by the metric used.
* @param <ClusterType> The type of {@code Cluster} created by the algorithm.
* This is typically defined by the cluster creator function used.
* @author Quinn McNamara
* @since 4.0.0
*/
@PublicationReference(
author =
{
"Martin Ester",
"Hans-Peter Kriegel",
"Jiirg Sander",
"Xiaowei Xu"
},
title = "A Density-Based Algorithm for Discovering Clusters in "
+ "Large Spatial Databases with Noise.",
type = PublicationType.Journal,
publication = "AAAI Press",
pages =
{
226 - 231
},
year = 1996,
url
= "https://www.aaai.org/Papers/KDD/1996/KDD96-037.pdf"
)
public class DBSCANClusterer>
extends AbstractAnytimeBatchLearner, Collection>
implements BatchClusterer
{
/**
* The default eps is {@value}.
*/
public static final double DEFAULT_EPS = 0.5;
/**
* The default minimum samples is {@value}.
*/
public static final int DEFAULT_MIN_SAMPLES = 5;
/**
* The default maximum number of iterations {@value}
*/
public static final int DEFAULT_MAX_ITERATIONS = Integer.MAX_VALUE;
/**
* The radius of a neighborhood.
*/
private double eps;
/**
* The minimum number of samples in the neighborhood for the point to be
* considered a core point.
*/
private int minSamples;
/**
* The distance metric to use.
*/
private Semimetric super DataType> metric;
/**
* A spatial index of the points to improve neighborhood querying.
*/
private KDTree> spatialIndex;
/**
* Number of clusters created so far.
*/
private int clusterCount;
/**
* Set of points that have been clustered so far.
*/
private HashSet clustered;
/**
* Set of points that have been visited so far.
*/
private HashSet visited;
/**
* All the data in an indexible structure.
*/
private ArrayList points;
/**
* The cluster creator for creating clusters.
*/
private ClusterCreator creator;
/**
* The current set of clusters.
*/
private ArrayList clusters;
/**
* The current cluster.
*/
private ArrayList currentCluster;
/**
* The cluster for points marked as noise.
*/
private ArrayList noiseCluster;
/**
* Index of the next point to process in step().
*/
private int pointIndex;
/**
* Creates a new instance of DBSCANClusterer.
*
* @param metric
* @param creator
*/
public DBSCANClusterer(
Semimetric super DataType> metric,
ClusterCreator creator)
{
this(DEFAULT_EPS, DEFAULT_MIN_SAMPLES, metric, creator);
}
/**
* Creates a new instance of AffinityPropagation.
*
* @param eps The divergence function to use to determine the divergence
* between two examples.
* @param minSamples The value for self-divergence to use, which controls
* the number of clusters created.
* @param metric The damping factor (lambda). Must be between 0.0 and 1.0.
* @param creator The cluster creator.
*/
public DBSCANClusterer(
double eps,
int minSamples,
Semimetric super DataType> metric,
ClusterCreator creator)
{
super(DEFAULT_MAX_ITERATIONS);
this.setNeighborhoodRadius(eps);
this.setMinSamples(minSamples);
this.setMetric(metric);
this.setCreator(creator);
}
@Override
public DBSCANClusterer clone()
{
@SuppressWarnings("unchecked")
final DBSCANClusterer result
= (DBSCANClusterer) super.clone();
result.metric = ObjectUtil.cloneSmart(this.metric);
result.clusters = null;
return result;
}
@Override
protected boolean initializeAlgorithm()
{
if (this.getData() == null || this.getData().size() <= 0)
{
// Make sure that the data is valid.
return false;
}
// Copy data into a data structure that can be indexed
this.points = new ArrayList(this.getData());
if (this.metric instanceof Metric)
{
// Construct a new KDTree with the data points
List> pts = this.points.stream()
.map(pt -> new DefaultInputOutputPair<>(pt, 0.0))
.collect(Collectors.toList());
this.spatialIndex = new KDTree<>(pts);
}
// Initialize the main data for the algorithm
this.clusters = new ArrayList();
this.clustered = new HashSet();
this.visited = new HashSet();
this.currentCluster = new ArrayList();
this.noiseCluster = new ArrayList();
// Noise cluster will be 0th cluster, start core clusters index at 1
this.clusters.add(0, this.creator.createCluster(this.noiseCluster));
this.clusterCount = 1;
// Ready to learn.
return true;
}
protected boolean step()
{
// Retrieve the point to process
DataType point = this.points.get(this.pointIndex);
if (!this.visited.contains(point))
{
// Mark point as visited
this.visited.add(point);
// Get all points in this point's neighborhood
List neighbors = this.regionQuery(point);
if (neighbors.size() < this.minSamples)
{
// Assign this point to the noise cluster
this.noiseCluster.add(point);
this.clustered.add(point);
}
else
{
// Expand this cluster
this.expandCluster(point, neighbors);
// Add expanded cluster to set of clusters
this.clusters.add(this.clusterCount, this.creator.createCluster(
this.currentCluster));
// Create the next cluster and set it as the current cluster
this.clusterCount++;
this.currentCluster = new ArrayList();
}
}
// Move to the next point, complete if we have processed every point
this.pointIndex++;
return (this.pointIndex < this.points.size());
}
/**
* Expands the current cluster to include the point given and all
* neighboring points that are not clustered. Repeats this process for the
* neighbors of neighboring points, etc.
*
* @param point The base point to expand the current cluster from.
* @param neighborsLinkedList The list of neighbors of the base point.
*/
private void expandCluster(
DataType point,
List neighbors)
{
LinkedList neighborsLinkedList = new LinkedList<>(neighbors);
this.currentCluster.add(point);
this.clustered.add(point);
while (!neighborsLinkedList.isEmpty())
{
DataType p = neighborsLinkedList.remove();
if (!this.visited.contains(p))
{
this.visited.add(p);
List neighbors2 = regionQuery(p);
if (neighbors2.size() >= this.minSamples)
{
neighborsLinkedList.addAll(neighbors2);
}
}
if (!this.clustered.contains(p))
{
this.currentCluster.add(p);
this.clustered.add(p);
}
}
}
/**
* Gets all the points neighboring the given point. These will be points
* that are at most eps (radius) away from the given point. Uses the spatial
* index (KD tree) if the metric is a Metric and points are Vectorizable.
* Otherwise, performs a brute-force search.
*
* @param point The point to get the neighborhood for.
* @return
*/
private List regionQuery(DataType point)
{
List result;
if (this.metric instanceof Metric)
{
result = this.spatialIndex.findNearestWithinRadius(
point, this.eps, (Metric super DataType>) this.metric)
.stream()
.map(pair -> pair.getFirst())
.collect(Collectors.toList());
}
else
{
result = this.points.stream()
.filter((p) -> (this.metric.evaluate(p, point) <= this.eps))
.collect(Collectors.toList());
}
return result;
}
protected void cleanupAlgorithm()
{
}
public ArrayList getResult()
{
return this.getClusters();
}
/**
* Gets the neighborhood radius.
*
* @return The eps.
*/
public double getNeighborhoodRadius()
{
return this.eps;
}
/**
* Sets the neighborhood radius.
*
* @param eps The eps.
*/
public void setNeighborhoodRadius(double eps)
{
if (eps < 0.0 || eps > 1.0)
{
throw new IllegalArgumentException(
"The eps must be between 0.0 and 1.0.");
}
this.eps = eps;
}
/**
* Gets the minimum number of samples.
*
* @return The minSamples.
*/
public double getMinSamples()
{
return this.minSamples;
}
/**
* Sets the minimum number of samples.
*
* @param minSamples The minSamples.
*/
public void setMinSamples(int minSamples)
{
this.minSamples = minSamples;
}
/**
* Gets the distance metric the clustering uses.
*
* @return The metric.
*/
public Semimetric super DataType> getMetric()
{
return this.metric;
}
/**
* Sets the distance metric the clustering uses.
*
* @param metric The metric.
*/
public void setMetric(Semimetric super DataType> metric)
{
this.metric = metric;
}
/**
* Gets the current clusters, which is a sparse mapping of exemplar
* identifier to cluster object.
*
* @return The current clusters.
*/
protected ArrayList getClusters()
{
return this.clusters;
}
/**
* Get the cluster at this index.
*
* @param i The index of the cluster.
* @return
*/
public ClusterType getCluster(int i)
{
return this.getClusters().get(i);
}
/**
* Sets the current clusters, which is a sparse mapping of exemplar
* identifier to cluster object.
*
* @param clusters The current clusters.
*/
protected void setClusters(
final ArrayList clusters)
{
this.clusters = clusters;
}
/**
* Gets the cluster creator.
*
* @return The cluster creator.
*/
public ClusterCreator getCreator()
{
return this.creator;
}
/**
* Sets the cluster creator.
*
* @param creator The creator for clusters.
*/
public void setCreator(
ClusterCreator creator)
{
this.creator = creator;
}
/**
* Gets the spatial index.
*
* @return The spatial index.
*/
public KDTree> getSpatialIndex()
{
return this.spatialIndex;
}
/**
* Sets the spatial index.
*
* @param spatialIndex The spatial index (speeds up neighborhood queries).
*/
public void setCreator(
KDTree> spatialIndex)
{
this.spatialIndex = spatialIndex;
}
/**
* Gets the list of points.
*
* @return The points being clustered.
*/
public ArrayList getPoints()
{
return this.points;
}
/**
* Sets the list of points.
*
* @param points The points to be clustered.
*/
public void setPoints(
ArrayList points)
{
this.points = points;
}
/**
* Gets the number of clusters.
*
* @return The number of clusters.
*/
public int getClusterCount()
{
return this.clusterCount;
}
/**
* Sets the number of clusters.
*
* @param count The number of clusters.
*/
public void setClusterCount(
int count)
{
this.clusterCount = count;
}
/**
* Gets the point index.
*
* @return The point index.
*/
public int getPointIndex()
{
return this.pointIndex;
}
/**
* Sets the point index.
*
* @param index The point index.
*/
public void setPointIndex(
int index)
{
this.pointIndex = index;
}
}