All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.leiden.GraphAggregationPhase Maven / Gradle / Ivy

There is a newer version: 2.15.0
Show newest version
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.leiden;

import org.neo4j.gds.ImmutableRelationshipProjections;
import org.neo4j.gds.NodeProjections;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.RelationshipProjection;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.CSRGraphStoreFactory;
import org.neo4j.gds.api.DefaultValue;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.schema.Direction;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
import org.neo4j.gds.core.Aggregation;
import org.neo4j.gds.core.ImmutableGraphDimensions;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.loading.construction.GraphFactory;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.LongToIntFunction;

class GraphAggregationPhase {

    static MemoryEstimation memoryEstimation() {
        return MemoryEstimations.builder(GraphAggregationPhase.class)
            .rangePerGraphDimension("aggregated graph", (rootDimensions, concurrency) -> {
                // The input graph might have multiple node and relationship properties
                // but the aggregated graph will never have more than a single relationship property
                // so let's

                // Handle the case where the input graph has only one node
                var minNodeCount = Math.min(2, rootDimensions.nodeCount());
                var minRelCount = Math.min(1, rootDimensions.relCountUpperBound());

                var minDimensions = ImmutableGraphDimensions
                    .builder()
                    .nodeCount(minNodeCount)
                    .highestPossibleNodeCount(minNodeCount)
                    .relationshipCounts(Map.of(RelationshipType.of("foo"), minRelCount))
                    .relCountUpperBound(minRelCount)
                    .highestRelationshipId(minRelCount)
                    .build();

                var relationshipProjections = ImmutableRelationshipProjections.builder()
                    .putProjection(
                        RelationshipType.of("AGGREGATE"),
                        RelationshipProjection.builder()
                            .type("AGGREGATE")
                            .orientation(Orientation.UNDIRECTED)
                            .aggregation(Aggregation.SUM)
                            .addProperty("prop", "prop", DefaultValue.of(1.0))
                            .build()
                    )
                    .build();

                var memoryEstimation = CSRGraphStoreFactory.getMemoryEstimation(
                    NodeProjections.all(),
                    relationshipProjections,
                    false
                );
                var min = memoryEstimation.estimate(minDimensions, concurrency).memoryUsage().min;
                var max = memoryEstimation.estimate(rootDimensions, concurrency).memoryUsage().max;

                return MemoryRange.of(min, max);
            }).perNode("sorted communities", HugeLongArray::memoryEstimation)
            .perNode("atomic coordination array", HugeAtomicLongArray::memoryEstimation).
            build();
    }

    private final Graph workingGraph;
    private final HugeLongArray communities;
    private final Direction direction;
    private final long maxCommunityId;
    private final ExecutorService executorService;
    private final Concurrency concurrency;
    private final TerminationFlag terminationFlag;
    private final ProgressTracker progressTracker;

    GraphAggregationPhase(
        Graph workingGraph,
        Direction direction,
        HugeLongArray communities,
        long maxCommunityId,
        ExecutorService executorService,
        Concurrency concurrency,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker
    ) {
        this.workingGraph = workingGraph;
        this.communities = communities;
        this.direction = direction;
        this.maxCommunityId = maxCommunityId;
        this.executorService = executorService;
        this.concurrency = concurrency;

        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
    }

    Graph run() {
        var nodesBuilder = GraphFactory.initNodesBuilder()
            .maxOriginalId(maxCommunityId)
            .concurrency(this.concurrency)
            .build();

        terminationFlag.assertRunning();

        ParallelUtil.parallelForEachNode(
            workingGraph.nodeCount(),
            concurrency,
            TerminationFlag.RUNNING_TRUE,
            nodeId -> nodesBuilder.addNode(communities.get(nodeId))
        );

        terminationFlag.assertRunning();

        IdMap idMap = nodesBuilder.build().idMap();
        RelationshipsBuilder relationshipsBuilder = GraphFactory.initRelationshipsBuilder()
            .nodes(idMap)
            .relationshipType(RelationshipType.of("_IGNORED_"))
            .orientation(direction.toOrientation())
            .addPropertyConfig(GraphFactory.PropertyConfig.builder()
                .propertyKey("property")
                .aggregation(Aggregation.SUM)
                .build())
            .executorService(executorService)
            .build();

        var sortedNodesByCommunity = getNodesSortedByCommunity(
            communities,
            concurrency
        );

        LongToIntFunction customDegree = x -> workingGraph.degree(sortedNodesByCommunity.get(x));
        var relationshipCreators = PartitionUtils.customDegreePartitionWithBatchSize(
            workingGraph,
            concurrency,
            customDegree,
            partition ->
                new RelationshipCreator(
                    sortedNodesByCommunity,
                    communities,
                    partition,
                    relationshipsBuilder,
                    workingGraph.concurrentCopy(),
                    direction,
                    progressTracker
                ),
            Optional.empty(),
            Optional.of(workingGraph.relationshipCount())
        );

        ParallelUtil.run(relationshipCreators, executorService);

        return GraphFactory.create(idMap, relationshipsBuilder.build());
    }

    static HugeLongArray getNodesSortedByCommunity(HugeLongArray communities, Concurrency concurrency) {
        long nodeCount = communities.size();

        var sortedNodesByCommunity = HugeLongArray.newArray(nodeCount);
        var communityCoordinateArray = HugeAtomicLongArray.of(nodeCount, ParalleLongPageCreator.passThrough(concurrency));


        ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, nodeId -> {
            {
                long communityId = communities.get(nodeId);
                communityCoordinateArray.getAndAdd(communityId, 1);
            }
        });
        AtomicLong atomicNodeSum = new AtomicLong();
        ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, indexId ->
        {
            if (communityCoordinateArray.get(indexId) > 0) {
                var nodeSum = atomicNodeSum.addAndGet(communityCoordinateArray.get(indexId));
                communityCoordinateArray.set(indexId, nodeSum);
            }
        });

        ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, indexId ->
        {

            long nodeId = nodeCount - indexId - 1;
            long communityId = communities.get(nodeId);
            long coordinate = communityCoordinateArray.getAndAdd(communityId, -1);
            sortedNodesByCommunity.set(coordinate - 1, nodeId);

        });

        return sortedNodesByCommunity;

    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy