All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.leiden.Leiden Maven / Gradle / Ivy

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.leiden;

import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.api.schema.Direction;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.DoubleAdder;

// TODO: take care of potential issues with self-loops

/**
 * Leiden community detection.
 *
 * <p>Each iteration works on a (progressively aggregated) working graph and consists of:
 * <ol>
 *   <li>a local move phase that moves nodes between communities,</li>
 *   <li>a refinement phase that derives refined communities from the local-move communities,</li>
 *   <li>an aggregation phase that builds a new graph in which every refined community becomes one node.</li>
 * </ol>
 *
 * <p>Iterating stops when:
 * <ul>
 *   <li>the local move phase performs no swaps, or</li>
 *   <li>the modularity gain between consecutive iterations is below {@code tolerance}, or</li>
 *   <li>modularity decreases (the previous iteration's result is kept), or</li>
 *   <li>{@code maxIterations} is reached.</li>
 * </ul>
 */
public class Leiden extends Algorithm<LeidenResult> {

    private final Graph rootGraph;
    private final Direction direction;
    private final int maxIterations;
    private final double initialGamma;
    private final double theta;
    // modularity computed after the local move phase, indexed by iteration
    private final double[] modularities;
    // modularity of the most recently accepted iteration; reported in the result
    private double modularity;
    private final LeidenDendrogramManager dendrogramManager;
    private final Optional<NodePropertyValues> seedValues;
    private final ExecutorService executorService;
    private final Concurrency concurrency;
    private final long randomSeed;

    // a modularity gain smaller than this between two iterations counts as convergence
    private final double tolerance;

    /**
     * Creates a Leiden run that schedules its parallel work on {@link DefaultPool#INSTANCE}.
     *
     * @param graph                          the graph to detect communities in
     * @param maxIterations                  upper bound on the number of Leiden iterations
     * @param initialGamma                   resolution parameter; scaled by the inverse total volume
     * @param theta                          randomness parameter forwarded to the refinement phase
     * @param includeIntermediateCommunities whether the dendrogram keeps the communities of every iteration
     * @param randomSeed                     seed forwarded to the refinement phase
     * @param seedValues                     optional initial community assignment, may be {@code null}
     * @param tolerance                      minimum modularity gain required to keep iterating
     * @param concurrency                    parallelism used by the individual phases
     * @param progressTracker                progress reporting
     * @param terminationFlag                checked by the dendrogram manager and the aggregation phase
     */
    public Leiden(
        Graph graph,
        int maxIterations,
        double initialGamma,
        double theta,
        boolean includeIntermediateCommunities,
        long randomSeed,
        @Nullable NodePropertyValues seedValues,
        double tolerance,
        Concurrency concurrency,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        this(
            graph,
            maxIterations,
            initialGamma,
            theta,
            includeIntermediateCommunities,
            randomSeed,
            seedValues,
            tolerance,
            concurrency,
            DefaultPool.INSTANCE,
            progressTracker,
            terminationFlag
        );
    }

    /**
     * Creates a Leiden run that schedules its parallel work on the given executor.
     *
     * <p>Parameters are the same as for the other constructor. The extra {@code executorService}
     * is used by the initialization, refinement, modularity and aggregation phases.
     */
    public Leiden(
        Graph graph,
        int maxIterations,
        double initialGamma,
        double theta,
        boolean includeIntermediateCommunities,
        long randomSeed,
        @Nullable NodePropertyValues seedValues,
        double tolerance,
        Concurrency concurrency,
        ExecutorService executorService,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        super(progressTracker);
        this.rootGraph = graph;
        this.direction = rootGraph.schema().direction();
        this.maxIterations = maxIterations;
        this.initialGamma = initialGamma;
        this.theta = theta;
        this.randomSeed = randomSeed;
        this.executorService = executorService;
        this.concurrency = concurrency;
        this.dendrogramManager = new LeidenDendrogramManager(
            rootGraph,
            maxIterations,
            concurrency,
            includeIntermediateCommunities,
            terminationFlag
        );
        this.seedValues = Optional.ofNullable(seedValues);
        this.modularities = new double[maxIterations];
        this.modularity = 0d;
        this.tolerance = tolerance;

        this.terminationFlag = terminationFlag;
    }

    @Override
    public LeidenResult compute() {
        progressTracker.beginSubTask("Leiden");
        var workingGraph = rootGraph;
        var nodeCount = workingGraph.nodeCount();
        var localMoveCommunities = LeidenUtils.createStartingCommunities(nodeCount, seedValues.orElse(null));

        SeedCommunityManager seedCommunityManager = SeedCommunityManager.create(
            seedValues.isPresent(),
            localMoveCommunities
        );

        // volume of a node: the sum of the weights of the node's outgoing relationships
        var localMoveNodeVolumes = HugeDoubleArray.newArray(nodeCount);
        // volume of a community: the sum of the volumes of all nodes in that community
        var localMoveCommunityVolumes = HugeDoubleArray.newArray(nodeCount);
        double modularityScaleCoefficient = initVolumes(
            localMoveNodeVolumes,
            localMoveCommunityVolumes,
            localMoveCommunities
        );

        double gamma = this.initialGamma * modularityScaleCoefficient;

        // currentActualCommunities maps every root-graph node to the community it currently belongs to.
        // Without seeding, these values can be used as output directly.
        // With seeding, they describe the current state in internal ids. For example, if seed 42 is mapped
        // to community 0, then currentActualCommunities.get(x) == 0, whereas the final output should be 42.
        HugeLongArray currentActualCommunities = HugeLongArray.newArray(rootGraph.nodeCount());

        boolean didConverge = false;
        int iteration;
        progressTracker.beginSubTask("Iteration");

        for (iteration = 0; iteration < maxIterations; iteration++) {
            // 1. LOCAL MOVE PHASE - starts from the communities of the current working graph
            progressTracker.beginSubTask("Local Move");
            var localMovePhase = LocalMovePhase.create(
                workingGraph,
                localMoveCommunities,
                localMoveNodeVolumes,
                localMoveCommunityVolumes,
                gamma,
                concurrency
            );

            localMovePhase.run();
            // any swap means the partition changed, so we have not converged yet
            boolean localPhaseConverged = localMovePhase.swaps == 0;
            progressTracker.endSubTask("Local Move");

            progressTracker.beginSubTask("Modularity Computation");
            updateModularity(
                workingGraph,
                localMoveCommunities,
                localMoveCommunityVolumes,
                modularityScaleCoefficient,
                gamma,
                localPhaseConverged,
                iteration
            );

            progressTracker.endSubTask("Modularity Computation");

            if (localPhaseConverged) {
                didConverge = true;
                break;
            }
            var toleranceStatus = getToleranceStatus(iteration);

            // modularity deteriorated: stop and report the previous iteration's result
            if (toleranceStatus == ToleranceStatus.DECREASE) {
                break;
            }
            // record this iteration's communities as the user-facing output
            dendrogramManager.updateOutputDendrogram(
                workingGraph,
                currentActualCommunities,
                localMoveCommunities,
                seedCommunityManager,
                iteration
            );

            // only a small gain over the previous iteration: keep this result and stop
            if (toleranceStatus == ToleranceStatus.CONVERGED) {
                didConverge = true;
                modularity = modularities[iteration];
                iteration++;
                break;
            }

            // refinement and aggregation only prepare the next iteration, so skip them on the last one
            if (iteration < maxIterations - 1) {
                // 2. REFINE
                progressTracker.beginSubTask("Refinement");
                var refinementPhase = RefinementPhase.create(
                    workingGraph,
                    localMoveCommunities,
                    localMoveNodeVolumes,
                    localMoveCommunityVolumes,
                    gamma,
                    theta,
                    randomSeed,
                    concurrency,
                    executorService,
                    progressTracker
                );
                var refinementPhaseResult = refinementPhase.run();
                var refinedCommunities = refinementPhaseResult.communities();
                var refinedCommunityVolumes = refinementPhaseResult.communityVolumes();
                var maximumRefinedCommunityId = refinementPhaseResult.maximumRefinedCommunityId();

                progressTracker.endSubTask("Refinement");

                progressTracker.beginSubTask("Aggregation");
                // update the actual communities with the refined ones
                dendrogramManager.updateAlgorithmDendrogram(
                    workingGraph,
                    currentActualCommunities,
                    refinedCommunities,
                    iteration
                );

                // 3. CREATE NEW GRAPH - one node per refined community
                var graphAggregationPhase = new GraphAggregationPhase(
                    workingGraph,
                    this.direction,
                    refinedCommunities,
                    maximumRefinedCommunityId,
                    this.executorService,
                    this.concurrency,
                    this.terminationFlag,
                    this.progressTracker
                );
                var previousNodeCount = workingGraph.nodeCount();
                workingGraph = graphAggregationPhase.run();

                // Post-aggregation step: MAINTAIN PARTITION
                var communityData = maintainPartition(
                    workingGraph,
                    localMoveCommunities,
                    refinedCommunityVolumes,
                    previousNodeCount
                );
                localMoveCommunities = communityData.seededCommunitiesForNextIteration;
                localMoveCommunityVolumes = communityData.communityVolumes;
                localMoveNodeVolumes = communityData.aggregatedNodeSeedVolume;
                progressTracker.endSubTask("Aggregation");
            }
            modularity = modularities[iteration];

        }
        progressTracker.endSubTask("Iteration");

        progressTracker.endSubTask("Leiden");

        return getLeidenResult(didConverge, iteration);
    }

    /**
     * Builds the final result.
     *
     * <p>If the algorithm converged in the very first iteration, no output dendrogram was written.
     * In that case the starting communities are returned, together with the first iteration's modularity.
     *
     * @param didConverge whether the run stopped because of convergence
     * @param iteration   number of iterations that are part of the result
     * @return the Leiden result
     */
    @NotNull
    private LeidenResult getLeidenResult(boolean didConverge, int iteration) {
        boolean stoppedAtFirstIteration = didConverge && iteration == 0;
        if (stoppedAtFirstIteration) {
            var firstModularity = modularities[0];
            return new LeidenResult(
                LeidenUtils.createStartingCommunities(rootGraph.nodeCount(), seedValues.orElse(null)),
                1,
                true,
                null,
                new double[]{firstModularity},
                firstModularity
            );
        } else {
            return new LeidenResult(
                dendrogramManager.getCurrent(),
                iteration,
                didConverge,
                dendrogramManager,
                resizeModularitiesArray(iteration),
                modularity
            );
        }
    }

    /**
     * Stores the modularity of the current local-move partition in {@code modularities[iteration]}.
     *
     * <p>Modularity is computed only if:
     * <ul>
     *   <li>the local phase has not converged (i.e., at least one swap was done), or</li>
     *   <li>we terminate in the first iteration (e.g., the given seeding is already optimal, or the
     *       graph is empty).</li>
     * </ul>
     * When the local phase converges in a later iteration, that iteration is discarded and its
     * modularity is not needed.
     */
    private void updateModularity(
        Graph workingGraph,
        HugeLongArray localMoveCommunities,
        HugeDoubleArray localMoveCommunityVolumes,
        double modularityScaleCoefficient,
        double gamma,
        boolean localPhaseConverged,
        int iteration
    ) {
        boolean shouldCalculateModularity = !localPhaseConverged || iteration == 0;

        if (shouldCalculateModularity) {
            modularities[iteration] = ModularityComputer.compute(
                workingGraph,
                localMoveCommunities,
                localMoveCommunityVolumes,
                gamma,
                modularityScaleCoefficient,
                concurrency,
                executorService,
                progressTracker
            );
        }
    }

    /**
     * Initializes node and community volumes of the root graph.
     *
     * <p>Node volumes are computed as follows:
     * <ul>
     *   <li>Weighted graphs: the volumes are computed in parallel by {@code InitVolumeTask}s.</li>
     *   <li>Unweighted graphs: each node's volume is its degree, and the total volume is the
     *       relationship count.</li>
     * </ul>
     * Each community's volume is the sum of its nodes' volumes.
     *
     * @param nodeVolumes        filled with the volume of every node
     * @param communityVolumes   accumulates the volume of every initial community
     * @param initialCommunities starting community of every node
     * @return {@code 1 / totalVolume}, the coefficient used to scale gamma and modularity
     */
    private double initVolumes(
        HugeDoubleArray nodeVolumes,
        HugeDoubleArray communityVolumes,
        HugeLongArray initialCommunities
    ) {
        progressTracker.beginSubTask("Initialization");
        double totalVolume;
        var volumeAdder = new DoubleAdder();
        if (rootGraph.hasRelationshipProperty()) {
            var tasks = PartitionUtils.rangePartition(
                concurrency,
                rootGraph.nodeCount(),
                partition -> new InitVolumeTask(
                    rootGraph.concurrentCopy(),
                    nodeVolumes,
                    partition,
                    volumeAdder
                ),
                Optional.empty()
            );
            RunWithConcurrency.builder()
                .concurrency(concurrency)
                .tasks(tasks)
                .executor(executorService)
                .run();
            totalVolume = volumeAdder.sum();
        } else {
            nodeVolumes.setAll(rootGraph::degree);
            totalVolume = rootGraph.relationshipCount();
        }
        rootGraph.forEachNode(nodeId -> {
            long communityId = initialCommunities.get(nodeId);
            progressTracker.logProgress();
            communityVolumes.addTo(communityId, nodeVolumes.get(nodeId));
            return true;
        });
        progressTracker.endSubTask("Initialization");

        return 1 / totalVolume;
    }

    /**
     * Carries the local-move partition of the previous graph over to the freshly aggregated graph.
     *
     * <p>Every node of {@code workingGraph} represents one refined community of the previous graph.
     * Its original id is that refined community's id.
     *
     * <p>Aggregated nodes whose refined communities belonged to the same local-move community are
     * placed into one community. That community takes the id of the first such aggregated node
     * encountered.
     *
     * @param workingGraph            the aggregated graph
     * @param localPhaseCommunities   local-move communities of the previous graph, indexed by previous node id
     * @param refinedCommunityVolumes volumes of the refined communities, indexed by refined community id
     * @param previousNodeCount       node count of the previous graph
     * @return the starting communities, community volumes and node volumes for the next iteration
     */
    static @NotNull CommunityData maintainPartition(
        Graph workingGraph,
        @NotNull HugeLongArray localPhaseCommunities,
        HugeDoubleArray refinedCommunityVolumes,
        long previousNodeCount
    ) {
        HugeLongArray inputCommunities = HugeLongArray.newArray(workingGraph.nodeCount());

        // -1 marks a local-phase community that has not yet been assigned an id in the new graph
        var localPhaseCommunityToAggregatedNewId = HugeLongArray.newArray(previousNodeCount);
        localPhaseCommunityToAggregatedNewId.setAll(l -> -1);
        // This works under the following constraint:
        //   for every community x, node x of the previous graph (i.e., the original node) is in community x.
        // Otherwise, we would need a reverse map (see Louvain).

        // refined:    the refined communities of the previous step (in their original ids)
        // aggregated: the refined communities of the previous step (as node ids in the new graph)
        // localPhase: the local-phase communities of the previous step
        HugeDoubleArray aggregatedCommunitySeedVolume = HugeDoubleArray.newArray(workingGraph.nodeCount());
        HugeDoubleArray aggregatedNodeSeedVolume = HugeDoubleArray.newArray(workingGraph.nodeCount());
        workingGraph.forEachNode(aggregatedCommunityId -> {
            long refinedCommunityId = workingGraph.toOriginalNodeId(aggregatedCommunityId);
            long localPhaseCommunityId = localPhaseCommunities.get(refinedCommunityId);
            long aggregatedSeedCommunityId;
            // cache the `aggregatedSeedCommunityId`
            if (localPhaseCommunityToAggregatedNewId.get(localPhaseCommunityId) != -1) {
                aggregatedSeedCommunityId = localPhaseCommunityToAggregatedNewId.get(localPhaseCommunityId);
            } else {
                aggregatedSeedCommunityId = aggregatedCommunityId;
                localPhaseCommunityToAggregatedNewId.set(localPhaseCommunityId, aggregatedSeedCommunityId);
            }

            double volumeOfTheAggregatedCommunity = refinedCommunityVolumes.get(refinedCommunityId);
            aggregatedCommunitySeedVolume.addTo(aggregatedSeedCommunityId, volumeOfTheAggregatedCommunity);

            inputCommunities.set(aggregatedCommunityId, aggregatedSeedCommunityId);
            aggregatedNodeSeedVolume.set(aggregatedCommunityId, volumeOfTheAggregatedCommunity);

            return true;
        });
        return new CommunityData(
            inputCommunities,
            aggregatedCommunitySeedVolume,
            aggregatedNodeSeedVolume
        );
    }

    /**
     * Returns the modularities of the first {@code iteration} iterations.
     *
     * <p>If every iteration was used, the internal array itself is returned. Otherwise a truncated copy
     * is returned.
     */
    private double[] resizeModularitiesArray(int iteration) {
        if (iteration >= maxIterations) {
            return modularities;
        }
        double[] resizedModularities = new double[iteration];
        System.arraycopy(this.modularities, 0, resizedModularities, 0, iteration);
        return resizedModularities;
    }

    /**
     * Compares this iteration's modularity to the previous one.
     *
     * @return {@code CONTINUE} on the first iteration; otherwise {@code DECREASE} if modularity dropped,
     *     {@code CONVERGED} if the gain is below {@code tolerance}, and {@code CONTINUE} otherwise
     */
    private ToleranceStatus getToleranceStatus(int iteration) {
        if (iteration == 0) {
            return ToleranceStatus.CONTINUE;
        } else {
            var difference = modularities[iteration] - modularities[iteration - 1];
            if (difference < 0) {
                return ToleranceStatus.DECREASE;
            }
            return (Double.compare(difference, tolerance) < 0) ? ToleranceStatus.CONVERGED : ToleranceStatus.CONTINUE;
        }
    }

    /** Outcome of comparing the modularity of two consecutive iterations. */
    private enum ToleranceStatus {
        // gain below tolerance: keep the current iteration and stop
        CONVERGED,
        // modularity dropped: discard the current iteration and stop
        DECREASE,
        // keep iterating
        CONTINUE
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy