
org.ojalgo.data.cluster.GeneralisedKMeans Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of ojalgo Show documentation
Show all versions of ojalgo Show documentation
oj! Algorithms - ojAlgo - is Open Source Java code that has to do with mathematics, linear algebra and optimisation.
package org.ojalgo.data.cluster;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.ToDoubleBiFunction;
import org.ojalgo.type.context.NumberContext;
/**
* Contains the outline of the k-means algorithm, but designed for customisation.
*
* - Works with any type of data
*
- Allows for custom distance calculations
*
- Allows for custom centroid initialisation and updating
*
*/
public final class GeneralisedKMeans implements ClusteringAlgorithm {
private static final NumberContext ACCURACY = NumberContext.of(4);
private final Function, T> myCentroidUpdater;
private final ToDoubleBiFunction myDistanceCalculator;
private final Function, List> myCentroidInitialiser;
/**
* You have to configure how distances are measured and how centroids are derived.
*
* @param centroidInitialiser The initialisation function should return a list of k centroids. This
* function determines 'K'.
* @param centroidUpdater The update function should return a new centroid based on a collection of points
* (the set of items in a cluster).
* @param distanceCalculator A function that calculates the distance between two points.
*/
public GeneralisedKMeans(final Function, List> centroidInitialiser, final Function, T> centroidUpdater,
final ToDoubleBiFunction distanceCalculator) {
super();
myCentroidInitialiser = centroidInitialiser;
myCentroidUpdater = centroidUpdater;
myDistanceCalculator = distanceCalculator;
}
@Override
public List> cluster(final Collection input) {
List centroids = myCentroidInitialiser.apply(input);
int k = centroids.size();
int maxIterations = Math.max(5, Math.min((int) Math.round(Math.sqrt(input.size())), 50));
List> clusters = new ArrayList<>(k);
for (int i = 0; i < k; i++) {
clusters.add(i, new HashSet<>());
}
int iterations = 0;
boolean converged = false;
do {
converged = true;
for (Set cluster : clusters) {
cluster.clear();
}
for (T point : input) {
int bestCluster = 0;
double minDistance = Double.MAX_VALUE;
for (int i = 0; i < k; i++) {
double distance = myDistanceCalculator.applyAsDouble(centroids.get(i), point);
if (distance < minDistance) {
minDistance = distance;
bestCluster = i;
}
}
clusters.get(bestCluster).add(point);
}
Set cluster;
T oldCenter;
T newCenter;
for (int i = 0; i < k; i++) {
cluster = clusters.get(i);
if (!cluster.isEmpty()) {
oldCenter = centroids.get(i);
newCenter = myCentroidUpdater.apply(cluster);
converged &= ACCURACY.isZero(myDistanceCalculator.applyAsDouble(oldCenter, newCenter));
centroids.set(i, newCenter);
}
}
} while (++iterations < maxIterations && !converged);
return clusters;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy