All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.embeddings.graphsage.GraphSageEmbeddingsGenerator Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.embeddings.graphsage;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Computes GraphSage embeddings for every node of a graph.
 *
 * <p>The node id range {@code [0, graph.nodeCount())} is split into partitions of
 * {@code batchSize} nodes. Each partition becomes one task, and the tasks are run
 * through {@link RunWithConcurrency} on the supplied executor.
 */
public class GraphSageEmbeddingsGenerator {
    private final Layer[] layers;
    private final int batchSize;
    private final Concurrency concurrency;
    private final FeatureFunction featureFunction;
    private final long randomSeed;
    private final ExecutorService executor;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;

    /**
     * Creates a generator.
     *
     * @param layers          the layers handed to {@link GraphSageHelper} for sub-graph
     *                        construction and for building the computation graph
     * @param batchSize       number of nodes per partition, i.e. per task
     * @param concurrency     concurrency the tasks are run with
     * @param featureFunction produces the feature variable for a batch of nodes
     * @param randomSeed      seed passed to the sub-graph construction; when empty, a seed
     *                        is drawn once from {@link ThreadLocalRandom} and reused for
     *                        all batches of this generator
     * @param executor        executor the tasks are submitted to
     * @param progressTracker receives the sub-task begin/end and per-batch progress
     * @param terminationFlag checked at the start of every batch
     */
    public GraphSageEmbeddingsGenerator(
        Layer[] layers,
        int batchSize,
        Concurrency concurrency,
        FeatureFunction featureFunction,
        Optional<Long> randomSeed,
        ExecutorService executor,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        this.layers = layers;
        this.batchSize = batchSize;
        this.concurrency = concurrency;
        this.featureFunction = featureFunction;
        this.randomSeed = randomSeed.orElseGet(() -> ThreadLocalRandom.current().nextLong());
        this.executor = executor;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Computes one embedding per node of {@code graph}.
     *
     * @param graph    the graph to embed; every task works on its own
     *                 {@link Graph#concurrentCopy()}
     * @param features per-node input features, passed through to the feature function
     * @return an array of length {@code graph.nodeCount()} holding the embedding of each
     *         node at the index of its node id
     */
    public HugeObjectArray<double[]> makeEmbeddings(
        Graph graph,
        HugeObjectArray<double[]> features
    ) {
        HugeObjectArray<double[]> result = HugeObjectArray.newArray(
            double[].class,
            graph.nodeCount()
        );

        progressTracker.beginSubTask();

        // One task per partition; each task writes only the result slots of its own node range.
        var tasks = PartitionUtils.rangePartitionWithBatchSize(
            graph.nodeCount(),
            batchSize,
            partition -> createEmbeddings(graph.concurrentCopy(), partition, features, result, terminationFlag)
        );

        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(tasks)
            .executor(executor)
            .run();

        progressTracker.endSubTask();

        return result;
    }

    /**
     * Builds the task that computes and stores the embeddings of one partition.
     *
     * @param graph           graph instance used by this task
     * @param partition       the node range to embed
     * @param features        per-node input features
     * @param result          target array; the task sets the entries for the partition's node ids
     * @param terminationFlag asserted to be running before any work is done
     * @return a runnable performing the computation when executed
     */
    private Runnable createEmbeddings(
        Graph graph,
        Partition partition,
        HugeObjectArray<double[]> features,
        HugeObjectArray<double[]> result,
        TerminationFlag terminationFlag
    ) {
        return () -> {
            terminationFlag.assertRunning();
            List<SubGraph> subGraphs = GraphSageHelper.subGraphsPerLayer(
                graph,
                partition.stream().toArray(),
                layers,
                randomSeed
            );

            // Features are extracted for the original node ids of the last sub-graph in the list.
            Variable<Matrix> batchedFeaturesExtractor = featureFunction.apply(
                graph,
                subGraphs.get(subGraphs.size() - 1).originalNodeIds(),
                features
            );

            Variable<Matrix> embeddingVariable = GraphSageHelper.embeddingsComputationGraph(
                subGraphs,
                layers,
                batchedFeaturesExtractor
            );

            Matrix embeddings = new ComputationContext().forward(embeddingVariable);

            var partitionStartNodeId = partition.startNode();
            var partitionNodeCount = partition.nodeCount();
            // Row i of the embeddings matrix is stored for node (startNode + i);
            // NOTE(review): relies on the rows being ordered like the partition's node ids - confirm.
            for (int partitionIdx = 0; partitionIdx < partitionNodeCount; partitionIdx++) {
                long nodeId = partitionStartNodeId + partitionIdx;
                result.set(nodeId, embeddings.getRow(partitionIdx));
            }

            progressTracker.logProgress(partitionNodeCount);
        };
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy