All downloads are free. The search and download functionality uses the official Maven repository.

org.ojalgo.data.cluster.GeneralisedKMeans Maven / Gradle / Ivy

Go to download

oj! Algorithms - ojAlgo - is Open Source Java code that has to do with mathematics, linear algebra and optimisation.

There is a newer version: 55.1.0
Show newest version
package org.ojalgo.data.cluster;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.ToDoubleBiFunction;

import org.ojalgo.type.context.NumberContext;

/**
 * Contains the outline of the k-means algorithm, but designed for customisation.
 * 
    *
  • Works with any type of data *
  • Allows for custom distance calculations *
  • Allows for custom centroid initialisation and updating *
*/ public final class GeneralisedKMeans implements ClusteringAlgorithm { private static final NumberContext ACCURACY = NumberContext.of(4); private final Function, T> myCentroidUpdater; private final ToDoubleBiFunction myDistanceCalculator; private final Function, List> myCentroidInitialiser; /** * You have to configure how distances are measured and how centroids are derived. * * @param centroidInitialiser The initialisation function should return a list of k centroids. This * function determines 'K'. * @param centroidUpdater The update function should return a new centroid based on a collection of points * (the set of items in a cluster). * @param distanceCalculator A function that calculates the distance between two points. */ public GeneralisedKMeans(final Function, List> centroidInitialiser, final Function, T> centroidUpdater, final ToDoubleBiFunction distanceCalculator) { super(); myCentroidInitialiser = centroidInitialiser; myCentroidUpdater = centroidUpdater; myDistanceCalculator = distanceCalculator; } @Override public List> cluster(final Collection input) { List centroids = myCentroidInitialiser.apply(input); int k = centroids.size(); int maxIterations = Math.max(5, Math.min((int) Math.round(Math.sqrt(input.size())), 50)); List> clusters = new ArrayList<>(k); for (int i = 0; i < k; i++) { clusters.add(i, new HashSet<>()); } int iterations = 0; boolean converged = false; do { converged = true; for (Set cluster : clusters) { cluster.clear(); } for (T point : input) { int bestCluster = 0; double minDistance = Double.MAX_VALUE; for (int i = 0; i < k; i++) { double distance = myDistanceCalculator.applyAsDouble(centroids.get(i), point); if (distance < minDistance) { minDistance = distance; bestCluster = i; } } clusters.get(bestCluster).add(point); } Set cluster; T oldCenter; T newCenter; for (int i = 0; i < k; i++) { cluster = clusters.get(i); if (!cluster.isEmpty()) { oldCenter = centroids.get(i); newCenter = myCentroidUpdater.apply(cluster); converged &= 
ACCURACY.isZero(myDistanceCalculator.applyAsDouble(oldCenter, newCenter)); centroids.set(i, newCenter); } } } while (++iterations < maxIterations && !converged); return clusters; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy