All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.kmeans.KmeansPlusPlusSampler Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.kmeans;


import com.carrotsearch.hppc.BitSet;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

import java.util.List;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;

/**
 * K-means++ style sampler for the initial cluster centers.
 * <p>
 * The first center is drawn uniformly at random. Every further center is drawn by a
 * cumulative-sum (roulette wheel) scan over {@code distanceFromClosestCentroid}, so a node is
 * picked with a weight equal to the square of its stored value. If that weighted draw cannot be
 * performed or does not hit a node, a not-yet-selected node is drawn uniformly instead.
 * <p>
 * Already selected nodes are marked in {@code distanceFromClosestCentroid} with the negative
 * value {@code -(position + 1)}, which is what lets the scan skip them.
 * <p>
 * NOTE(review): this class assumes that running the {@link KmeansTask}s refreshes
 * {@code distanceFromClosestCentroid} and the per-task squared-distance sums; the task
 * implementation is not visible from this file - confirm against {@code KmeansTask}.
 */
public class KmeansPlusPlusSampler extends KmeansSampler {

    // Parameterized (not raw) so that the enhanced-for loops over KmeansTask below type-check.
    private final List<KmeansTask> tasks;
    private final Concurrency concurrency;
    private final ProgressTracker progressTracker;
    private final HugeDoubleArray distanceFromClosestCentroid;
    private final ExecutorService executorService;


    /**
     * @param random                      source of randomness, forwarded to the superclass
     * @param clusterManager              receives every selected node via {@code initialAssignCluster}
     * @param nodeCount                   number of nodes; node ids {@code 0..nodeCount-1} are sampled
     * @param k                           number of centers to select
     * @param distanceFromClosestCentroid per-node value read for the weighted draw and overwritten
     *                                    with a negative marker for selected nodes
     * @param concurrency                 concurrency used when running {@code tasks}
     * @param executorService             executor used when running {@code tasks}
     * @param tasks                       tasks executed before every weighted draw and once after the last
     * @param progressTracker             notified with one unit of progress per selected center
     */
    public KmeansPlusPlusSampler(
        SplittableRandom random,
        ClusterManager clusterManager,
        long nodeCount,
        int k,
        HugeDoubleArray distanceFromClosestCentroid,
        Concurrency concurrency,
        ExecutorService executorService,
        List<KmeansTask> tasks,
        ProgressTracker progressTracker
    ) {
        super(random, clusterManager, nodeCount, k);
        this.distanceFromClosestCentroid = distanceFromClosestCentroid;
        this.executorService = executorService;
        this.tasks = tasks;
        this.concurrency = concurrency;
        this.progressTracker = progressTracker;
    }


    /**
     * Selects {@code k} distinct nodes as initial centers, registers each with the cluster
     * manager, and finally switches all tasks to {@link TaskPhase#ITERATION}.
     */
    @Override
    public void performInitialSampling() {
        // The first center is a uniformly random node.
        long firstId = random.nextLong(nodeCount);

        // Tracks the nodes that have already been selected as centers.
        BitSet bitSet = new BitSet(nodeCount);
        assignToCluster(bitSet, 0, firstId);
        progressTracker.logProgress(1);
        for (int selectionClusterId = 1; selectionClusterId < k; ++selectionClusterId) {

            runTasks();

            // Total weight of the draw: the sum of what each task reports.
            double squaredDistance = 0;
            for (KmeansTask task : tasks) {
                squaredDistance += task.getSquaredDistance();
            }
            long nextNode = -1;

            // Only attempt the weighted draw when the total is usable. An infinite total
            // (overflow) or a non-positive one goes straight to the uniform fallback below.
            // A NaN total passes this check, but then no `x <= curr` comparison can succeed,
            // so it ends up in the fallback as well.
            if (!(Double.isInfinite(squaredDistance) || squaredDistance <= 0)) {

                // Threshold in [0, squaredDistance); the first node whose running sum
                // reaches it is selected.
                double x = random.nextDouble() * squaredDistance;
                double curr = 0;

                for (long nodeId = 0; nodeId < nodeCount; nodeId++) {
                    double distanceFromCentroid = distanceFromClosestCentroid.get(nodeId);
                    // Values <= -1 are the markers written by assignToCluster: skip selected nodes.
                    if (distanceFromCentroid <= -1) {
                        continue;
                    }
                    curr += distanceFromCentroid * distanceFromCentroid;

                    if (x <= curr) {
                        nextNode = nodeId;
                        break;
                    }

                }
            }
            // Fallback: no node was hit (or the draw was skipped), so pick an unselected node
            // uniformly by rejection sampling.
            // NOTE(review): this loop only terminates while an unselected node exists, i.e. it
            // relies on k <= nodeCount - presumably validated by the caller; confirm.
            if (nextNode == -1) {
                nextNode = random.nextLong(nodeCount);
                while (bitSet.get(nextNode)) {
                    nextNode = random.nextLong(nodeCount);
                }
            }
            assignToCluster(bitSet, selectionClusterId, nextNode);
            progressTracker.logProgress(1);
        }
        // Now we have k centers, but the tasks last ran before the final center was chosen, so
        // (per the original author's note) each node's closest community only covers 0..k-2.
        // Run the tasks one last time so that the best community among 0..k-1 is recorded.
        runTasks();

        for (KmeansTask task : tasks) {
            task.switchToPhase(TaskPhase.ITERATION);
        }
    }

    /** Runs all tasks once with the configured concurrency and executor. */
    private void runTasks() {
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(tasks)
            .executor(executorService)
            .run();
    }

    /**
     * Records {@code nextNode} as a selected center: marks it in {@code bitSet}, hands it to the
     * cluster manager, and overwrites its distance entry with the negative marker
     * {@code -(position + 1)}.
     */
    private void assignToCluster(BitSet bitSet, int position, long nextNode) {
        bitSet.set(nextNode);
        clusterManager.initialAssignCluster(nextNode);
        distanceFromClosestCentroid.set(nextNode, -(position + 1));
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy