All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.kmeans.KmeansTask Maven / Gradle / Ivy

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.kmeans;

import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.mem.Estimate;


/**
 * Unit of work of the K-Means computation that operates on a single {@link Partition}.
 *
 * <p>Every call to {@link #run()} walks the node range
 * {@code [partition.startNode(), partition.startNode() + partition.nodeCount())} and performs the
 * work selected by the current {@link TaskPhase}:
 * <ul>
 *     <li>{@code ITERATION} - assign every node to its closest centroid and count swaps,</li>
 *     <li>{@code DISTANCE} - compute the distance of every node to the centroid of its community,</li>
 *     <li>any other phase - update the per-node distances against the most recently picked centroid.</li>
 * </ul>
 *
 * <p>Subclasses provide the bookkeeping that has to happen whenever a node is assigned to a
 * community, see {@link #reset()} and {@link #updateAfterAssignmentToCentroid(long, int)}.
 */
public abstract class KmeansTask implements Runnable {
    private final ClusterManager clusterManager;
    private final Partition partition;
    final NodePropertyValues nodePropertyValues;

    // Per-node distance to the assigned centroid. During the initial (sampling) phase an entry
    // that is <= -1 is decoded as the community id `-value - 1` (see distanceFromLastSampledCentroid).
    private final HugeDoubleArray distanceFromCentroid;

    final HugeIntArray communities;
    final long[] communitySizes;
    final int k;
    final int dimensions;

    // Number of nodes of this partition that changed community in the last assignment pass.
    private long swaps;

    // Sum of node-to-centroid distances of this partition, filled by the DISTANCE phase.
    private double distance;

    // Sum of squared node-to-centroid distances of this partition, filled by the initial phase.
    private double squaredDistance = 0;

    private TaskPhase phase;

    /**
     * Creates a task for the given partition.
     *
     * <p>With {@code SamplerType.UNIFORM} the task starts directly in the {@code ITERATION} phase;
     * with any other sampler type it starts in the {@code INITIAL} phase.
     *
     * @param samplerType          the sampler type, decides the starting phase
     * @param clusterManager       used for closest-centroid lookups and distance computations
     * @param nodePropertyValues   node properties, made available to subclasses
     * @param communities          community id per node; written by this task for the nodes of its partition
     * @param distanceFromCentroid distance per node; written by this task for the nodes of its partition
     * @param k                    number of communities, also the length of {@code communitySizes}
     * @param dimensions           dimension count, made available to subclasses
     * @param partition            the node range this task processes
     */
    KmeansTask(
        SamplerType samplerType,
        ClusterManager clusterManager,
        NodePropertyValues nodePropertyValues,
        HugeIntArray communities,
        HugeDoubleArray distanceFromCentroid,
        int k,
        int dimensions,
        Partition partition
    ) {
        this.clusterManager = clusterManager;
        this.nodePropertyValues = nodePropertyValues;
        this.communities = communities;
        this.distanceFromCentroid = distanceFromCentroid;
        this.k = k;
        this.dimensions = dimensions;
        this.partition = partition;
        this.communitySizes = new long[k];
        this.phase = samplerType == SamplerType.UNIFORM ? TaskPhase.ITERATION : TaskPhase.INITIAL;
        this.distance = 0d;
    }

    /**
     * Creates the task implementation matching the given cluster manager: a
     * {@code DoubleKmeansTask} when {@code clusterManager} is a {@code DoubleClusterManager},
     * a {@code FloatKmeansTask} otherwise. All arguments are forwarded unchanged to the
     * chosen constructor.
     *
     * @return a new task for {@code partition}
     */
    static KmeansTask createTask(
        SamplerType samplerType,
        ClusterManager clusterManager,
        NodePropertyValues nodePropertyValues,
        HugeIntArray communities,
        HugeDoubleArray distanceFromCentroid,
        int k,
        int dimensions,
        Partition partition
    ) {
        if (clusterManager instanceof DoubleClusterManager) {
            return new DoubleKmeansTask(
                samplerType,
                clusterManager,
                nodePropertyValues,
                communities,
                distanceFromCentroid,
                k,
                dimensions,
                partition
            );
        }
        return new FloatKmeansTask(
            samplerType,
            clusterManager,
            nodePropertyValues,
            communities,
            distanceFromCentroid,
            k,
            dimensions,
            partition
        );
    }

    /**
     * Describes the memory held by a single task: a fixed {@code long[k]} for the community
     * sizes plus a range for the coordinate sums, spanning from {@code k} float arrays to
     * {@code k} double arrays of {@code fakeDimensions} entries each.
     *
     * @param k              number of communities
     * @param fakeDimensions dimension count to assume for the estimation
     * @return the memory estimation of one task
     */
    static MemoryEstimation memoryEstimation(int k, int fakeDimensions) {
        var coordinateSumsRange = MemoryRange.of(
            k * Estimate.sizeOfFloatArray(fakeDimensions),
            k * Estimate.sizeOfDoubleArray(fakeDimensions)
        );
        return MemoryEstimations.builder(KmeansTask.class)
            .fixed("communitySizes", Estimate.sizeOfLongArray(k))
            .add("communityCoordinateSums", MemoryEstimations.of("communityCoordinateSums", coordinateSumsRange))
            .build();
    }

    /**
     * Invoked at the beginning of every assignment pass, before any node is processed.
     * NOTE(review): implementations presumably clear {@code communitySizes} and their
     * coordinate sums here - confirm against the subclasses.
     */
    abstract void reset();

    /**
     * Invoked once for every node right after it has been assigned to {@code community}.
     *
     * @param nodeId    the node that was assigned
     * @param community the community the node now belongs to
     */
    abstract void updateAfterAssignmentToCentroid(long nodeId, int community);

    /**
     * @param ith community index
     * @return the number of nodes of this partition counted for community {@code ith}
     */
    long getNumAssignedAtCluster(int ith) {
        return communitySizes[ith];
    }

    /**
     * @return the number of nodes of this partition whose community changed in the last assignment pass
     */
    long getSwaps() {
        return swaps;
    }

    /**
     * @return the distance sum of this partition divided by the size of the whole
     *     {@code communities} array (not by the partition size)
     */
    public double getDistanceFromCentroidNormalized() {
        return distance / communities.size();
    }

    /**
     * @return the sum of squared distances computed by the last pass of the initial phase
     */
    public double getSquaredDistance() {
        return squaredDistance;
    }

    /**
     * Selects the work performed by the next call to {@link #run()}.
     *
     * @param newPhase the phase to switch to
     */
    void switchToPhase(TaskPhase newPhase) {
        phase = newPhase;
    }

    /**
     * Processes the node range of this task's partition according to the current phase.
     */
    @Override
    public void run() {
        long startNode = partition.startNode();
        long endNode = startNode + partition.nodeCount();
        if (phase == TaskPhase.ITERATION) {
            assignNodeToCentroid(startNode, endNode);
        } else if (phase == TaskPhase.DISTANCE) {
            calculateFinalDistance(startNode, endNode);
        } else {
            distanceFromLastSampledCentroid(startNode, endNode, clusterManager.getCurrentlyAssigned());
        }
    }

    /**
     * Assigns every node in {@code [startNode, endNode)} to its closest centroid, counting the
     * community sizes and the number of nodes that changed community.
     */
    private void assignNodeToCentroid(long startNode, long endNode) {
        swaps = 0;

        reset();

        for (long nodeId = startNode; nodeId < endNode; nodeId++) {
            int closestCommunity = clusterManager.findClosestCentroid(nodeId);
            communitySizes[closestCommunity]++;
            if (communities.get(nodeId) != closestCommunity) {
                swaps++;
            }
            communities.set(nodeId, closestCommunity);
            //Note for potential improvement : This is potentially costly when clusters have somewhat stabilized.
            //Because we keep adding the same nodes to the same clusters. Perhaps instead of making the sum 0
            //we keep as is and do a subtraction when a node changes its cluster.
            //On that note,  maybe we can skip stable communities (i.e., communities that did not change between one iteration to another)
            // or avoid calculating their distance from other nodes etc...
            updateAfterAssignmentToCentroid(nodeId, closestCommunity);
        }
    }

    /**
     * Stores, for every node in {@code [startNode, endNode)}, the distance to the centroid of
     * its community and sums these distances into {@code distance}.
     */
    private void calculateFinalDistance(long startNode, long endNode) {
        // Start from zero like the other phases do with their accumulators; otherwise running
        // this phase again on the same task would count every node of the partition twice.
        distance = 0;
        for (long nodeId = startNode; nodeId < endNode; nodeId++) {
            double nodeCentroidDistance = clusterManager.euclidean(nodeId, communities.get(nodeId));
            distance += nodeCentroidDistance;
            distanceFromCentroid.set(nodeId, nodeCentroidDistance);
        }
    }

    /**
     * Updates the nodes in {@code [startNode, endNode)} against the most recently picked
     * centroid, i.e. the one with index {@code numAssigned - 1}.
     *
     * <p>A node whose stored distance is greater than {@code -1} is moved to the new centroid
     * when that centroid is the first one or strictly closer than the stored distance; its
     * (possibly unchanged) squared distance is added to {@code squaredDistance}. Nodes with a
     * stored distance of {@code -1} or less are skipped in this step.
     *
     * <p>Once all {@code k} centroids are picked, skipped nodes get community
     * {@code -storedDistance - 1} and distance {@code 0}, and every node is counted for its
     * community and reported through {@link #updateAfterAssignmentToCentroid(long, int)}.
     * NOTE(review): the negative entries presumably mark the nodes that were themselves chosen
     * as centroids - confirm against the sampler.
     *
     * @param numAssigned number of centroids picked so far
     */
    private void distanceFromLastSampledCentroid(long startNode, long endNode, int numAssigned) {
        int newestCentroid = numAssigned - 1;
        boolean isFirstCentroid = numAssigned == 1;
        boolean allCentroidsPicked = numAssigned == k;

        squaredDistance = 0;
        for (long nodeId = startNode; nodeId < endNode; nodeId++) {
            // Only this task touches the entries of its partition, so the value read here stays
            // valid for the decisions below.
            double storedDistance = distanceFromCentroid.get(nodeId);

            if (storedDistance > -1) {
                double nodeCentroidDistance = clusterManager.euclidean(nodeId, newestCentroid);
                if (isFirstCentroid || storedDistance > nodeCentroidDistance) {
                    distanceFromCentroid.set(nodeId, nodeCentroidDistance);
                    squaredDistance += nodeCentroidDistance * nodeCentroidDistance;
                    communities.set(nodeId, newestCentroid);
                } else {
                    squaredDistance += storedDistance * storedDistance;
                }
            }

            if (allCentroidsPicked) {
                if (storedDistance <= -1) {
                    // Decode the community id from the negative marker and clear the marker.
                    communities.set(nodeId, (int) -storedDistance - 1);
                    distanceFromCentroid.set(nodeId, 0);
                }
                int communityId = communities.get(nodeId);
                communitySizes[communityId]++;
                updateAfterAssignmentToCentroid(nodeId, communityId);
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy