All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.labelpropagation.LabelPropagation Maven / Gradle / Ivy

There is a newer version: 2.15.0
Show newest version
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.labelpropagation;

import org.neo4j.gds.Algorithm;
import org.neo4j.gds.algorithms.community.CommunityCompanion;
import org.neo4j.gds.api.DefaultValue;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.primitive.PrimitiveLongCollections;
import org.neo4j.gds.collections.primitive.PrimitiveLongIterable;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.loading.NullPropertyMap.DoubleNullPropertyMap;
import org.neo4j.gds.core.loading.NullPropertyMap.LongNullPropertyMap;
import org.neo4j.gds.core.utils.LazyBatchCollection;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutorService;

import static java.util.concurrent.TimeUnit.MICROSECONDS;
import static org.neo4j.kernel.api.StatementConstants.NO_SUCH_LABEL;

/**
 * Label Propagation community detection over a {@link Graph}.
 * <p>
 * The node id space is split into batches, and one {@link StepRunner} is created per batch, starting with an
 * {@link InitStep}. Each iteration runs all step runners concurrently. The computation stops as soon as every
 * runner reports convergence, or after {@link LabelPropagationParameters#maxIterations()} iterations, whichever
 * comes first.
 */
public class LabelPropagation extends Algorithm<LabelPropagationResult> {

    static final String LABEL_PROPAGATION_DESCRIPTION =
        "The Label Propagation algorithm is a fast algorithm for finding communities in a graph.";

    private final long nodeCount;
    // Seed values; a LongNullPropertyMap when no seed property is configured or it is absent from the graph.
    private final NodePropertyValues nodePropertyValues;
    // Node weights; a DoubleNullPropertyMap(1.0) when no weight property is configured or it is absent.
    private final NodePropertyValues nodeWeights;
    private final ExecutorService executor;

    private final Graph graph;
    private final LabelPropagationParameters parameters;
    // Allocated in compute(); reused across calls while its size still matches the node count.
    private HugeLongArray labels;
    // Maximum seed value, or NO_SUCH_LABEL when there is none; handed to every InitStep.
    private final long maxLabelId;
    // Requested batch size; adjusted against the node count when the batches are built.
    private int batchSize;

    /**
     * Creates a new Label Propagation run.
     *
     * @param graph           the graph to compute communities on
     * @param parameters      algorithm parameters (seed and node-weight property keys, iterations, concurrency)
     * @param executor        executor used to run the per-batch steps
     * @param progressTracker tracker receiving nested sub-task events during {@link #compute()}
     * @param terminationFlag flag passed to every concurrent run so the computation can be aborted
     */
    public LabelPropagation(
        Graph graph,
        LabelPropagationParameters parameters,
        ExecutorService executor,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        super(progressTracker);
        this.graph = graph;
        this.nodeCount = graph.nodeCount();
        this.parameters = parameters;
        this.executor = executor;
        this.batchSize = ParallelUtil.DEFAULT_BATCH_SIZE;

        // A configured-but-missing seed property is treated the same as "no seed property".
        NodePropertyValues seedProperty;
        String seedPropertyKey = parameters.seedProperty();
        if (seedPropertyKey != null && graph.availableNodeProperties().contains(seedPropertyKey)) {
            seedProperty = CommunityCompanion.extractSeedingNodePropertyValues(graph, seedPropertyKey);
        } else {
            seedProperty = new LongNullPropertyMap(DefaultValue.LONG_DEFAULT_FALLBACK);
        }
        this.nodePropertyValues = seedProperty;

        // Likewise, a missing weight property falls back to a constant weight of 1.0 for every node.
        NodePropertyValues nodeWeightProperty;
        String nodeWeightPropertyKey = parameters.nodeWeightProperty();
        if (nodeWeightPropertyKey != null && graph.availableNodeProperties().contains(nodeWeightPropertyKey)) {
            nodeWeightProperty = graph.nodeProperties(nodeWeightPropertyKey);
        } else {
            nodeWeightProperty = new DoubleNullPropertyMap(1.0);
        }
        this.nodeWeights = nodeWeightProperty;

        this.maxLabelId = seedProperty.getMaxLongPropertyValue().orElse(NO_SUCH_LABEL);

        this.terminationFlag = terminationFlag;
    }

    /**
     * Runs the initialization step followed by up to {@code maxIterations} propagation iterations.
     *
     * @return the labels array together with whether all runners converged and how many iterations ran
     * @throws IllegalArgumentException if {@code maxIterations} is not positive
     */
    @Override
    public LabelPropagationResult compute() {
        if (parameters.maxIterations() <= 0L) {
            throw new IllegalArgumentException("Must iterate at least 1 time");
        }

        progressTracker.beginSubTask();

        if (labels == null || labels.size() != nodeCount) {
            labels = HugeLongArray.newArray(nodeCount);
        }

        // Building the runners also executes their initialization step (inside its own progress sub-task).
        List<StepRunner> stepRunners = stepRunners();

        long ranIterations = 0L;
        boolean didConverge = false;

        progressTracker.beginSubTask();
        while (ranIterations < parameters.maxIterations()) {
            progressTracker.beginSubTask();
            runConcurrently(stepRunners);
            ++ranIterations;
            didConverge = stepRunners.stream().allMatch(StepRunner::didConverge);
            progressTracker.endSubTask();
            if (didConverge) {
                break;
            }
        }
        progressTracker.endSubTask();

        stepRunners.forEach(StepRunner::release);
        progressTracker.endSubTask();

        return new LabelPropagationResult(labels, didConverge, ranIterations);
    }

    /**
     * Partitions the node id space into batches, wraps an {@link InitStep} for each batch in a
     * {@link StepRunner}, and runs that initialization step for all batches concurrently.
     *
     * @return the initialized step runners, one per node batch
     */
    private List<StepRunner> stepRunners() {
        long adjustedBatchSize = ParallelUtil.adjustedBatchSize(nodeCount, this.batchSize);

        // Each batch covers node ids start .. start + length - 1; the last id of the batch is passed as the end.
        Collection<PrimitiveLongIterable> nodeBatches = LazyBatchCollection.of(
            nodeCount,
            adjustedBatchSize,
            (start, length) -> () -> PrimitiveLongCollections.range(start, start + length - 1L)
        );

        List<StepRunner> tasks = new ArrayList<>(nodeBatches.size());
        for (PrimitiveLongIterable nodeBatch : nodeBatches) {
            InitStep initStep = new InitStep(
                graph,
                nodePropertyValues,
                nodeWeights,
                nodeBatch,
                labels,
                progressTracker,
                maxLabelId
            );
            tasks.add(new StepRunner(initStep));
        }

        progressTracker.beginSubTask();
        runConcurrently(tasks);
        progressTracker.endSubTask();
        return tasks;
    }

    /**
     * Runs the given step runners on {@link #executor} with the configured concurrency, passing along the
     * termination flag.
     */
    private void runConcurrently(List<StepRunner> tasks) {
        RunWithConcurrency.builder()
            .concurrency(parameters.concurrency())
            .tasks(tasks)
            .waitTime(1L, MICROSECONDS)
            .terminationFlag(terminationFlag)
            .executor(executor)
            .run();
    }

    /**
     * Overrides the requested batch size used to partition nodes on subsequent {@link #compute()} calls.
     */
    void withBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy