All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.paths.traverse.BFS Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see .
 */
package org.neo4j.gds.paths.traverse;


import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Parallel implementation of the BFS algorithm.
 *
 * It uses the concept of bucketing/chunking to keep track of the ordering of the
 * visited nodes.
 *
 * Conceptually, a bucket keeps all nodes at a fixed distance from the starting node.
 * The nodes within each bucket are kept in a list ordered by their final position
 * in the output BFS ordering.
 *
 * To implement parallelism, the nodes within a bucket are processed concurrently.
 * For this, the nodes of the bucket are partitioned into chunks, where each chunk
 * contains a continuous segment from the list of nodes. Threads are then assigned
 * chunks in parallel and process (relax) each node within the assigned chunk.
 *
 * To maintain a correct ordering, once the parallel processing phase has concluded,
 * we perform a sequential step, where we examine the chunks from earliest to latest
 * to create the next bucket, such that a correct BFS ordering is returned where all
 * descendants from the nodes of a chunk, appear together before those from a later
 * chunk.
 */
public final class BFS extends Algorithm {

    private static final int DEFAULT_DELTA = 64;
    public static final int ALL_DEPTHS_ALLOWED = -1;

    private final long sourceNodeId;
    private final ExitPredicate exitPredicate;
    private final Aggregator aggregatorFunction;
    private final Graph graph;
    private final int delta;
    private final long maximumDepth;
    // An array to keep the node ids that were already traversed in the correct order.
    // It is initialized with the total number of nodes, but may contain less than that.
    private final HugeLongArray traversedNodes;

    // An array to keep the weight/depth of the node at the same position in `traversedNodes`.
    // It is initialized with the total number of nodes, but may contain less than that.
    // This is used for early termination when `maxDepth` parameter is specified.
    // `maxDepth` specifies the number of "layers" that will be traversed in the input graph,
    // starting from `startNodeId`.
    private final HugeDoubleArray weights;

    // Used to keep track of the visited nodes, the value at each index will be `true` for
    // each node id in the `traversedNodes`.
    private final HugeAtomicBitSet visited;

    private final Concurrency concurrency;

    public static BFS create(
        Graph graph,
        long startNodeId,
        ExitPredicate exitPredicate,
        Aggregator aggregatorFunction,
        Concurrency concurrency,
        ProgressTracker progressTracker,
        long maximumDepth,
        TerminationFlag terminationFlag
    ) {
        return create(
            graph,
            startNodeId,
            exitPredicate,
            aggregatorFunction,
            concurrency,
            progressTracker,
            DEFAULT_DELTA,
            maximumDepth,
            terminationFlag
        );
    }

    static BFS create(
        Graph graph,
        long startNodeId,
        ExitPredicate exitPredicate,
        Aggregator aggregatorFunction,
        Concurrency concurrency,
        ProgressTracker progressTracker,
        int delta,
        long maximumDepth,
        TerminationFlag terminationFlag
    ) {

        var nodeCount = graph.nodeCount();

        var traversedNodes = HugeLongArray.newArray(nodeCount);
        var weights = HugeDoubleArray.newArray(nodeCount);
        var visited = HugeAtomicBitSet.create(nodeCount);

        return new BFS(
            graph,
            startNodeId,
            traversedNodes,
            weights,
            visited,
            exitPredicate,
            aggregatorFunction,
            concurrency,
            progressTracker,
            delta,
            maximumDepth,
            terminationFlag
        );
    }

    private BFS(
        Graph graph,
        long sourceNodeId,
        HugeLongArray traversedNodes,
        HugeDoubleArray weights,
        HugeAtomicBitSet visited,
        ExitPredicate exitPredicate,
        Aggregator aggregatorFunction,
        Concurrency concurrency,
        ProgressTracker progressTracker,
        int delta,
        long maximumDepth,
        TerminationFlag terminationFlag
    ) {
        super(progressTracker);
        this.graph = graph;
        this.sourceNodeId = sourceNodeId;
        this.exitPredicate = exitPredicate;
        this.aggregatorFunction = aggregatorFunction;
        this.concurrency = concurrency;
        this.delta = delta;
        this.maximumDepth = maximumDepth;
        this.traversedNodes = traversedNodes;
        this.weights = weights;
        this.visited = visited;
        this.terminationFlag = terminationFlag;
    }

    @Override
    public HugeLongArray compute() {
        progressTracker.beginSubTask(graph.relationshipCount());

        // This is used to read from `traversedNodes` in chunks, updated in `BFSTask`.
        var traversedNodesIndex = new AtomicLong(0);
        // This keeps the current length of the `traversedNodes`, updated in `BFSTask.syncNextChunk`.
        var traversedNodesLength = new AtomicLong(1);
        // Used for early exit when target node is reached (if specified by the user), updated in `BFSTask`.
        var targetFoundIndex = new AtomicLong(Long.MAX_VALUE);

        // The minimum position of a predecessor that contains a relationship to the node in the `traversedNodes`.
        // This is updated in `BFSTask` and is helping to maintain the correct traversal order for the output.
        var minimumChunk = HugeAtomicLongArray.of(
            graph.nodeCount(),
            ParalleLongPageCreator.of(concurrency, l -> Long.MAX_VALUE)
        );

        visited.set(sourceNodeId);
        traversedNodes.set(0, sourceNodeId);
        weights.set(0, 0);

        var bfsTaskList = initializeBfsTasks(
            traversedNodesIndex,
            traversedNodesLength,
            targetFoundIndex,
            minimumChunk,
            delta
        );
        int bfsTaskListSize = bfsTaskList.size();
        long currentDepth = 0;
        while (terminationFlag.running()) {
            if (currentDepth == maximumDepth) {
                break;
            }
            ParallelUtil.run(bfsTaskList, DefaultPool.INSTANCE);

            if (targetFoundIndex.get() != Long.MAX_VALUE) {
                break;
            }

            // Synchronize the results sequentially
            var previousTraversedNodesLength = traversedNodesLength.get();
            int numberOfFinishedTasks = 0;
            int numberOfTasksWithChunks = countTasksWithChunks(bfsTaskList);
            while (numberOfFinishedTasks != numberOfTasksWithChunks && terminationFlag.running()) {
                int minimumTaskIndex = -1;
                for (int bfsTaskIndex = 0; bfsTaskIndex < bfsTaskListSize; ++bfsTaskIndex) {
                    var currentBfsTask = bfsTaskList.get(bfsTaskIndex);
                    if (currentBfsTask.hasMoreChunks()) {
                        if (minimumTaskIndex == -1) {
                            minimumTaskIndex = bfsTaskIndex;
                        } else {
                            if (bfsTaskList.get(minimumTaskIndex).currentChunkId() > currentBfsTask.currentChunkId()) {
                                minimumTaskIndex = bfsTaskIndex;
                            }
                        }
                    }
                }
                var minimumIndexBfsTask = bfsTaskList.get(minimumTaskIndex);
                minimumIndexBfsTask.syncNextChunk();
                if (!minimumIndexBfsTask.hasMoreChunks()) {
                    numberOfFinishedTasks++;
                }
            }

            if (traversedNodesLength.get() == previousTraversedNodesLength) {
                break;
            }

            traversedNodesIndex.set(previousTraversedNodesLength);
            currentDepth++;
        }

        // Find the portion of `traversedNodes` that contains the actual result, doesn't account for target node, hence the `if` statement.
        var nodesLengthToRetain = traversedNodesLength.get();
        if (targetFoundIndex.get() != Long.MAX_VALUE) {
            nodesLengthToRetain = targetFoundIndex.longValue() + 1;
        }

        var result = traversedNodes.copyOf(nodesLengthToRetain);

        progressTracker.endSubTask();
        return result;
    }

    private List initializeBfsTasks(
        AtomicLong traversedNodesIndex,
        AtomicLong traversedNodesLength,
        AtomicLong targetFoundIndex,
        HugeAtomicLongArray minimumChunk,
        int delta
    ) {
        var bfsTaskList = new ArrayList(concurrency.value());
        for (int i = 0; i < concurrency.value(); ++i) {
            bfsTaskList.add(new BFSTask(
                graph,
                traversedNodes,
                traversedNodesIndex,
                traversedNodesLength,
                visited,
                weights,
                targetFoundIndex,
                minimumChunk,
                exitPredicate,
                aggregatorFunction,
                delta,
                sourceNodeId,
                terminationFlag,
                progressTracker
            ));
        }
        return bfsTaskList;
    }

    private int countTasksWithChunks(Collection bfsTaskList) {
        return (int) bfsTaskList.stream().filter(BFSTask::hasMoreChunks).count();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy