All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.kmeans.Kmeans Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.kmeans;

import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.nodeproperties.ValueType;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;

/**
 * K-Means clustering over an array-valued node property of a {@link Graph}.
 *
 * <p>One run consists of centroid initialization (seeded or sampled), an assign/recompute loop that
 * ends when {@link KmeansIterationStopper} says so, and a final distance pass. The run can be
 * repeated {@code parameters.numberOfRestarts()} times; the run with the smallest accumulated
 * distance is kept. Optionally a silhouette score is computed for the retained solution.
 */
public final class Kmeans extends Algorithm<KmeansResult> {
    /** Community id marking a node that has not been assigned to any cluster yet. */
    private static final int UNASSIGNED = -1;
    private HugeIntArray bestCommunities;
    private final Graph graph;
    private final KmeansParameters parameters;
    private final Concurrency concurrency;
    private final ExecutorService executorService;
    private final SplittableRandom random;
    private final NodePropertyValues nodePropertyValues;
    /** Length of the property array of node 0; every other node and seed centroid must match it. */
    private final int dimensions;

    private double[][] bestCentroids;

    private HugeDoubleArray distanceFromCentroid;
    private final KmeansIterationStopper kmeansIterationStopper;
    /** Per-node silhouette; stays {@code null} unless {@code parameters.computeSilhouette()} is set. */
    private HugeDoubleArray silhouette;
    private double averageSilhouette;
    private double bestDistance;
    private long[] nodesInCluster;


    /**
     * Creates a K-Means instance for the given graph.
     *
     * @param graph           graph whose nodes are clustered
     * @param parameters      algorithm parameters; {@code nodeProperty()} names the property to cluster on
     * @param context         supplies the progress tracker and the executor
     * @param terminationFlag flag handed to the algorithm for cooperative termination
     * @return a ready-to-run instance
     * @throws IllegalArgumentException if the graph has no values for the configured node property
     */
    public static Kmeans createKmeans(Graph graph, KmeansParameters parameters, KmeansContext context, TerminationFlag terminationFlag) {
        String nodeWeightProperty = parameters.nodeProperty();
        NodePropertyValues nodeProperties = graph.nodeProperties(nodeWeightProperty);
        if (nodeProperties == null) {
            throw new IllegalArgumentException("Property '" + nodeWeightProperty + "' does not exist for all nodes");
        }
        return new Kmeans(
            context.progressTracker(),
            context.executor(),
            graph,
            parameters,
            getSplittableRandom(parameters.randomSeed()),
            nodeProperties,
            terminationFlag
        );
    }

    private Kmeans(
        ProgressTracker progressTracker,
        ExecutorService executorService,
        Graph graph,
        KmeansParameters parameters,
        SplittableRandom random,
        NodePropertyValues nodePropertyValues,
        TerminationFlag terminationFlag
    ) {
        super(progressTracker);
        this.executorService = executorService;
        this.graph = graph;

        this.random = random;
        this.bestCommunities = HugeIntArray.newArray(graph.nodeCount());
        this.nodePropertyValues = nodePropertyValues;

        // Node 0 defines the expected dimensionality. Fail with the same message the validation
        // uses instead of a bare NullPointerException when node 0 lacks the property.
        double[] firstNodeValue = nodePropertyValues.doubleArrayValue(0);
        if (firstNodeValue == null) {
            throw new IllegalArgumentException("Property '" + parameters.nodeProperty() + "' does not exist for all nodes");
        }
        this.dimensions = firstNodeValue.length;

        this.kmeansIterationStopper = new KmeansIterationStopper(
            parameters.deltaThreshold(),
            parameters.maxIterations(),
            graph.nodeCount()
        );
        this.distanceFromCentroid = HugeDoubleArray.newArray(graph.nodeCount());
        this.parameters = parameters;
        this.concurrency = parameters.concurrency();
        this.nodesInCluster = new long[parameters.k()];
        this.terminationFlag = terminationFlag;
    }

    /**
     * Validates the input, runs K-Means (with restarts if configured) and returns the best solution.
     *
     * <p>If {@code k} exceeds the node count, a warning is logged and every node is returned as its
     * own cluster with distance 0; no iterations are run and no silhouette is computed.
     *
     * @return communities, per-node distances, centroids, best distance and silhouette data
     * @throws IllegalArgumentException if a node lacks the property or any input value is NaN
     * @throws IllegalStateException    if property arrays or seed centroids differ in dimension
     */
    @Override
    public KmeansResult compute() {
        progressTracker.beginSubTask(); // KMeans start

        checkInputValidity();

        long nodeCount = graph.nodeCount();

        if (parameters.k() > nodeCount) {
            // Every node in its own community. Warn and return early.
            progressTracker.logWarning("Number of requested clusters is larger than the number of nodes.");
            bestCommunities.setAll(v -> (int) v);
            distanceFromCentroid.setAll(v -> 0d);
            progressTracker.endSubTask(); // KMeans end --> conditional!!!
            // k is an int and k > nodeCount, so nodeCount fits in an int here.
            int centroidCount = (int) nodeCount;
            // Rows are taken directly from the node properties, so only the outer array is allocated.
            bestCentroids = new double[centroidCount][];
            for (int i = 0; i < centroidCount; ++i) {
                bestCentroids[i] = nodePropertyValues.doubleArrayValue(i);
            }
            return ImmutableKmeansResult.of(bestCommunities, distanceFromCentroid, bestCentroids, 0.0, silhouette, 0.0);
        }

        var currentCommunities = HugeIntArray.newArray(nodeCount);
        var currentDistanceFromCentroid = HugeDoubleArray.newArray(nodeCount);

        bestDistance = Double.POSITIVE_INFINITY;
        bestCommunities.setAll(v -> UNASSIGNED);

        // We need this `if` because the task tree is different if the number of restart is > 1.
        if (parameters.numberOfRestarts() == 1) {
            kMeans(nodeCount, currentCommunities, currentDistanceFromCentroid, 0);
        } else {
            for (int restartIteration = 0; restartIteration < parameters.numberOfRestarts(); ++restartIteration) {
                progressTracker.beginSubTask(); // KMeans Iteration - start
                kMeans(nodeCount, currentCommunities, currentDistanceFromCentroid, restartIteration);
                progressTracker.endSubTask(); // KMeans Iteration - end
            }
        }

        if (parameters.computeSilhouette()) {
            calculateSilhouette();
        }
        progressTracker.endSubTask(); // KMeans end
        return ImmutableKmeansResult.of(
            bestCommunities,
            distanceFromCentroid,
            bestCentroids,
            bestDistance,
            silhouette,
            averageSilhouette
        );
    }

    /**
     * Executes one complete K-Means run and folds its outcome into the best solution so far.
     *
     * @param nodeCount                   number of nodes to cluster
     * @param currentCommunities          scratch buffer for this run's assignments; reset on entry
     * @param currentDistanceFromCentroid scratch buffer for this run's per-node distances
     * @param restartIteration            zero-based index of this run
     */
    private void kMeans(
        long nodeCount,
        HugeIntArray currentCommunities,
        HugeDoubleArray currentDistanceFromCentroid,
        int restartIteration
    ) {

        //note: currentDistanceFromCentroid is not reset to a [0,...,0] distance array, but it does not have to
        // it's used only in K-Means++ (where it is essentially reset; see func distanceFromLastSampledCentroid in KmeansTask)
        // or during final distance calculation where it is reset as well (see calculateFinalDistance in KmeansTask)

        ClusterManager clusterManager = ClusterManager.createClusterManager(
            nodePropertyValues,
            dimensions,
            parameters.k()
        );

        currentCommunities.setAll(v -> UNASSIGNED);

        var tasks = PartitionUtils.rangePartition(
            concurrency,
            nodeCount,
            partition -> KmeansTask.createTask(
                parameters.samplerType(),
                clusterManager,
                nodePropertyValues,
                currentCommunities,
                currentDistanceFromCentroid,
                parameters.k(),
                dimensions,
                partition
            ),
            Optional.of(minBatchSize(nodeCount))
        );
        int numberOfTasks = tasks.size();

        KmeansSampler sampler = KmeansSampler.createSampler(
            parameters.samplerType(),
            random,
            clusterManager,
            nodeCount,
            parameters.k(),
            concurrency,
            currentDistanceFromCentroid,
            executorService,
            tasks,
            progressTracker
        );

        assert numberOfTasks <= concurrency.value();

        //Initialization do initial centroid computation and assignment
        initializeCentroids(clusterManager, sampler);

        int iteration = 0;
        progressTracker.beginSubTask(); // Main - start
        while (true) {
            progressTracker.beginSubTask(); // Iteration - start

            long numberOfSwaps = 0;
            //assign each node to a centroid
            // The assignment pass is skipped only in the very first iteration of a non-uniform sampler.
            boolean shouldComputeDistance = (iteration > 0)
                || (parameters.samplerType() == SamplerType.UNIFORM);
            if (shouldComputeDistance) {
                RunWithConcurrency.builder()
                    .concurrency(concurrency)
                    .tasks(tasks)
                    .executor(executorService)
                    .run();

                for (KmeansTask task : tasks) {
                    numberOfSwaps += task.getSwaps();
                }
            }
            recomputeCentroids(clusterManager, tasks);
            progressTracker.endSubTask(); // Iteration - end
            if (kmeansIterationStopper.shouldQuit(numberOfSwaps, ++iteration)) {
                break;
            }

        }
        progressTracker.endSubTask(); // Main - end

        double averageDistanceFromCentroid = calculateDistancePhase(tasks);
        updateBestSolution(
            restartIteration,
            clusterManager,
            averageDistanceFromCentroid,
            currentCommunities,
            currentDistanceFromCentroid
        );
    }

    /**
     * Minimum batch size handed to the range partitioner: node count divided by concurrency.
     *
     * <p>The division is done in {@code long} and clamped, so graphs with more than
     * {@link Integer#MAX_VALUE} nodes do not overflow.
     */
    private int minBatchSize(long nodeCount) {
        return (int) Math.min(Integer.MAX_VALUE, nodeCount / concurrency.value());
    }

    /** Sets the initial centroids, either from the configured seeds or through the sampler. */
    private void initializeCentroids(ClusterManager clusterManager, KmeansSampler sampler) {
        progressTracker.beginSubTask(); // Initialization - start
        if (parameters.isSeeded()) {
            clusterManager.assignSeededCentroids(parameters.seedCentroids());
        } else {
            sampler.performInitialSampling();
        }
        progressTracker.endSubTask(); // Initialization - end
    }

    /** Resets the cluster manager, merges every task's contribution, then normalizes the clusters. */
    private void recomputeCentroids(ClusterManager clusterManager, Iterable<KmeansTask> tasks) {
        clusterManager.reset();

        for (KmeansTask task : tasks) {
            clusterManager.updateFromTask(task);
        }
        clusterManager.normalizeClusters();
    }

    /** Returns a generator seeded with {@code randomSeed} if present, otherwise an unseeded one. */
    @NotNull
    private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
        return randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
    }

    /**
     * Checks seed centroids (if any) and every node's property array.
     *
     * @throws IllegalArgumentException if a node has no value or any value is NaN
     * @throws IllegalStateException    if an array's length differs from {@code dimensions}
     */
    private void checkInputValidity() {
        if (parameters.isSeeded()) {
            var seededCentroids = parameters.seedCentroids();
            for (List<Double> centroid : seededCentroids) {
                if (centroid.size() != dimensions) {
                    throw new IllegalStateException(
                        "All property arrays for K-Means should have the same number of dimensions");
                }
                for (double value : centroid) {
                    if (Double.isNaN(value)) {
                        throw new IllegalArgumentException("Input for K-Means should not contain any NaN values");
                    }
                }
            }
        }

        // The value type does not change per node, so decide once which accessor to validate with.
        boolean isFloatArray = nodePropertyValues.valueType() == ValueType.FLOAT_ARRAY;
        // NOTE(review): validation runs with RUNNING_TRUE rather than this algorithm's termination
        // flag, unlike updateBestSolution - confirm this is intentional.
        ParallelUtil.parallelForEachNode(
            graph.nodeCount(),
            concurrency,
            TerminationFlag.RUNNING_TRUE,
            nodeId -> {
                if (isFloatArray) {
                    validateFloatArrayProperty(nodeId);
                } else {
                    validateDoubleArrayProperty(nodeId);
                }
            }
        );
    }

    /** Validates presence, dimension and absence of NaN for a node's float-array property. */
    private void validateFloatArrayProperty(long nodeId) {
        var value = nodePropertyValues.floatArrayValue(nodeId);
        if (value == null) {
            throw new IllegalArgumentException("Property '" + parameters.nodeProperty() + "' does not exist for all nodes");
        }
        if (value.length != dimensions) {
            throw new IllegalStateException(
                "All property arrays for K-Means should have the same number of dimensions");
        }
        for (int dimension = 0; dimension < dimensions; ++dimension) {
            if (Float.isNaN(value[dimension])) {
                throw new IllegalArgumentException("Input for K-Means should not contain any NaN values");
            }
        }
    }

    /** Validates presence, dimension and absence of NaN for a node's double-array property. */
    private void validateDoubleArrayProperty(long nodeId) {
        var value = nodePropertyValues.doubleArrayValue(nodeId);
        if (value == null) {
            throw new IllegalArgumentException("Property '" + parameters.nodeProperty() + "' does not exist for all nodes");
        }
        if (value.length != dimensions) {
            throw new IllegalStateException(
                "All property arrays for K-Means should have the same number of dimensions");
        }
        for (int dimension = 0; dimension < dimensions; ++dimension) {
            if (Double.isNaN(value[dimension])) {
                throw new IllegalArgumentException("Input for K-Means should not contain any NaN values");
            }
        }
    }


    /**
     * Computes the silhouette of the best solution: fills {@code silhouette} and adds each task's
     * {@code getAverageSilhouette()} to {@code averageSilhouette}.
     */
    private void calculateSilhouette() {
        var nodeCount = graph.nodeCount();
        progressTracker.beginSubTask();
        this.silhouette = HugeDoubleArray.newArray(nodeCount);
        var tasks = PartitionUtils.rangePartition(
            concurrency,
            nodeCount,
            partition -> SilhouetteTask.createTask(
                nodePropertyValues,
                bestCommunities,
                silhouette,
                parameters.k(),
                dimensions,
                nodesInCluster,
                partition,
                progressTracker
            ),
            Optional.of(minBatchSize(nodeCount))
        );
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(tasks)
            .executor(executorService)
            .run();

        for (var task : tasks) {
            averageSilhouette += task.getAverageSilhouette();
        }
        progressTracker.endSubTask();

    }

    /**
     * Switches all tasks to the distance phase, runs them once and sums their
     * {@code getDistanceFromCentroidNormalized()} values.
     *
     * @return the summed value, used to compare restarts
     */
    private double calculateDistancePhase(Iterable<KmeansTask> tasks) {
        for (KmeansTask task : tasks) {
            task.switchToPhase(TaskPhase.DISTANCE);
        }
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(tasks)
            .executor(executorService)
            .run();
        double averageDistanceFromCentroid = 0;
        for (KmeansTask task : tasks) {
            averageDistanceFromCentroid += task.getDistanceFromCentroidNormalized();
        }
        return averageDistanceFromCentroid;
    }

    /**
     * Records the just-finished run as the best solution if it is the first run or strictly better
     * than the best so far.
     *
     * <p>The {@code current*} buffers are reused (and reset) by every restart. They may therefore be
     * adopted by reference only when there is exactly one run; with several restarts the values are
     * copied, so a later, worse restart cannot overwrite the retained assignment and distances.
     */
    private void updateBestSolution(
        int restartIteration,
        ClusterManager clusterManager,
        double averageDistanceFromCentroid,
        HugeIntArray currentCommunities,
        HugeDoubleArray currentDistanceFromCentroid
    ) {
        boolean isFirstRun = restartIteration < 1;
        // Written as !(a < b) so a NaN distance never replaces an existing solution.
        if (!isFirstRun && !(averageDistanceFromCentroid < bestDistance)) {
            return;
        }

        bestDistance = averageDistanceFromCentroid;
        bestCentroids = clusterManager.getCentroids();
        if (parameters.computeSilhouette()) {
            nodesInCluster = clusterManager.getNodesInCluster();
        }

        if (parameters.numberOfRestarts() == 1) {
            // Single run: nothing will touch the buffers again, so skip the copy.
            bestCommunities = currentCommunities;
            distanceFromCentroid = currentDistanceFromCentroid;
        } else {
            ParallelUtil.parallelForEachNode(
                graph.nodeCount(),
                concurrency,
                terminationFlag,
                v -> {
                    bestCommunities.set(v, currentCommunities.get(v));
                    distanceFromCentroid.set(v, currentDistanceFromCentroid.get(v));
                }
            );
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy