All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.embeddings.node2vec.Node2Vec Maven / Gradle / Ivy

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.embeddings.node2vec;

import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.EmbeddingUtils;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.traversal.RandomWalkCompanion;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Node2Vec algorithm entry point.
 *
 * <p>{@link #compute()} runs in two phases: it first samples random walks over the
 * graph using one {@link Node2VecRandomWalkTask} per unit of concurrency, and then
 * hands the collected walks to a {@link Node2VecModel} for training.
 */
public class Node2Vec extends Algorithm<Node2VecResult> {

    private final Graph graph;
    private final Concurrency concurrency;
    private final SamplingWalkParameters samplingWalkParameters;
    private final List<Long> sourceNodes;
    private final Optional<Long> maybeRandomSeed;
    private final TrainParameters trainParameters;
    private final int walkBufferSize;

    /**
     * @param graph              the graph to sample walks from and to embed
     * @param concurrency        number of walk tasks to create; also forwarded to validation and training
     * @param sourceNodes        nodes handed to {@code RandomWalkCompanion.nextNodeSupplier}
     * @param maybeRandomSeed    seed for walking and training; a random seed is drawn when empty
     * @param walkBufferSize     buffer size forwarded to every walk task
     * @param node2VecParameters supplies the sampling/walk parameters and the training parameters
     * @param progressTracker    tracker used to report sub task progress
     * @param terminationFlag    flag forwarded to every walk task
     */
    public Node2Vec(
        Graph graph,
        Concurrency concurrency,
        List<Long> sourceNodes,
        Optional<Long> maybeRandomSeed,
        int walkBufferSize,
        Node2VecParameters node2VecParameters,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        super(progressTracker);
        this.graph = graph;
        this.concurrency = concurrency;
        this.samplingWalkParameters = node2VecParameters.samplingWalkParameters();
        this.walkBufferSize = walkBufferSize;
        this.sourceNodes = sourceNodes;
        this.maybeRandomSeed = maybeRandomSeed;
        this.trainParameters = node2VecParameters.trainParameters();

        this.terminationFlag = terminationFlag;
    }

    /**
     * Samples random walks and trains the embedding model on them.
     *
     * @return the result produced by {@link Node2VecModel#train()}
     */
    @Override
    public Node2VecResult compute() {
        progressTracker.beginSubTask("Node2Vec");

        // Weighted graphs are only accepted when every relationship weight is non-negative.
        if (graph.hasRelationshipProperty()) {
            EmbeddingUtils.validateRelationshipWeightPropertyValue(
                graph,
                concurrency,
                weight -> weight >= 0,
                "Node2Vec only supports non-negative weights.",
                DefaultPool.INSTANCE
            );
        }

        var probabilitiesBuilder = new RandomWalkProbabilities.Builder(
            graph.nodeCount(),
            concurrency,
            samplingWalkParameters.positiveSamplingFactor(),
            samplingWalkParameters.negativeSamplingExponent()
        );
        // Sized for walksPerNode walks for each node in the graph.
        var walks = new CompressedRandomWalks(graph.nodeCount() * samplingWalkParameters.walksPerNode());

        progressTracker.beginSubTask("RandomWalk");

        var tasks = walkTasks(
            walks,
            probabilitiesBuilder,
            graph,
            maybeRandomSeed,
            concurrency,
            sourceNodes,
            samplingWalkParameters,
            walkBufferSize,
            DefaultPool.INSTANCE,
            progressTracker,
            terminationFlag
        );

        progressTracker.beginSubTask("create walks");
        RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).run();

        // The longest walk seen by any task becomes the max walk length of the walk store.
        walks.setMaxWalkLength(tasks.stream()
            .mapToInt(Node2VecRandomWalkTask::maxWalkLength)
            .max()
            .orElse(0));

        // The store size is one past the highest index written by any task.
        walks.setSize(tasks.stream()
            .mapToLong(task -> 1 + task.maxIndex())
            .max()
            .orElse(0L));

        progressTracker.endSubTask("create walks");
        progressTracker.endSubTask("RandomWalk");

        var node2VecModel = new Node2VecModel(
            graph::toOriginalNodeId,
            graph.nodeCount(),
            trainParameters,
            concurrency,
            maybeRandomSeed,
            walks,
            probabilitiesBuilder.build(),
            progressTracker
        );

        var result = node2VecModel.train();

        progressTracker.endSubTask("Node2Vec");
        return result;
    }

    /**
     * Creates one random walk task per unit of concurrency.
     *
     * <p>Every task receives its own {@code graph.concurrentCopy()}, while the next-node
     * supplier, the cumulative weights supplier, the walk index counter, the walk store,
     * the probabilities builder and the random seed are the same instances/values for all tasks.
     *
     * @return the created tasks, {@code concurrency.value()} of them
     */
    private List<Node2VecRandomWalkTask> walkTasks(
        CompressedRandomWalks compressedRandomWalks,
        RandomWalkProbabilities.Builder randomWalkProbabilitiesBuilder,
        Graph graph,
        Optional<Long> maybeRandomSeed,
        Concurrency concurrency,
        List<Long> sourceNodes,
        SamplingWalkParameters samplingWalkParameters,
        int walkBufferSize,
        ExecutorService executorService,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        List<Node2VecRandomWalkTask> tasks = new ArrayList<>();
        // Draw a seed once so that all tasks are constructed with the same value.
        long randomSeed = maybeRandomSeed.orElseGet(() -> new Random().nextLong());
        var nextNodeSupplier = RandomWalkCompanion.nextNodeSupplier(graph, sourceNodes);
        var cumulativeWeightsSupplier = RandomWalkCompanion.cumulativeWeights(
            graph,
            concurrency,
            executorService,
            progressTracker
        );

        // Single counter passed to every task.
        AtomicLong index = new AtomicLong();
        for (int i = 0; i < concurrency.value(); ++i) {
            tasks.add(new Node2VecRandomWalkTask(
                graph.concurrentCopy(),
                nextNodeSupplier,
                samplingWalkParameters.walksPerNode(),
                cumulativeWeightsSupplier,
                progressTracker,
                terminationFlag,
                index,
                compressedRandomWalks,
                randomWalkProbabilitiesBuilder,
                walkBufferSize,
                randomSeed,
                samplingWalkParameters.walkLength(),
                samplingWalkParameters.returnFactor(),
                samplingWalkParameters.inOutFactor()
            ));
        }
        return tasks;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy