
org.neo4j.gds.embeddings.graphsage.GraphSageHelper

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.embeddings.graphsage;

import org.apache.commons.lang3.mutable.MutableInt;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.api.schema.NodeSchemaEntry;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainMemoryEstimateParameters;
import org.neo4j.gds.embeddings.graphsage.algo.MultiLabelFeatureExtractors;
import org.neo4j.gds.ml.core.NeighborhoodFunction;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.features.BiasFeature;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.ml.core.features.HugeObjectArrayFeatureConsumer;
import org.neo4j.gds.ml.core.functions.NormalizeRows;
import org.neo4j.gds.ml.core.subgraph.NeighborhoodSampler;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import static org.neo4j.gds.mem.Estimate.sizeOfDoubleArray;
import static org.neo4j.gds.mem.Estimate.sizeOfIntArray;
import static org.neo4j.gds.mem.Estimate.sizeOfLongArray;
import static org.neo4j.gds.mem.Estimate.sizeOfObjectArray;
import static org.neo4j.gds.ml.core.features.FeatureExtraction.featureCount;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;

public final class GraphSageHelper {

    private GraphSageHelper() {}

    static Variable<Matrix> embeddingsComputationGraph(
        List<SubGraph> subGraphs,
        Layer[] layers,
        Variable<Matrix> batchedFeaturesExtractor
    ) {
        Variable<Matrix> previousLayerRepresentations = batchedFeaturesExtractor;

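        // subGraphs.get(0) holds the batch nodes and their directly sampled neighbors, while
        // subGraphs.get(layers.length - 1) holds the outermost sampled hop (the samplers are
        // reversed in subGraphsPerLayer below). The loop therefore applies the first layer's
        // aggregator to the deepest subgraph and ends with the last layer producing the
        // representations of the batch nodes.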
        for (int layerNr = layers.length - 1; layerNr >= 0; layerNr--) {
            Layer layer = layers[layers.length - layerNr - 1];
            previousLayerRepresentations = layer
                .aggregator()
                .aggregate(
                    previousLayerRepresentations,
                    subGraphs.get(layerNr)
                );
        }
        return new NormalizeRows(previousLayerRepresentations);
    }

    // expecting a thread-local graph here
    static List<SubGraph> subGraphsPerLayer(Graph graph, long[] nodeIds, Layer[] layers, long randomSeed) {
        var random = new Random(randomSeed);

        List<NeighborhoodFunction> samplers = Arrays
            .stream(layers)
            .map(layer -> {
                var neighborhoodSampler = new NeighborhoodSampler(random.nextLong());
                return (NeighborhoodFunction) (nodeId) -> neighborhoodSampler.sample(graph, nodeId, layer.sampleSize());
            })
            .collect(Collectors.toList());

        Collections.reverse(samplers);

        return SubGraph.buildSubGraphs(nodeIds, samplers, SubGraph.relationshipWeightFunction(graph));
    }

    public static MemoryEstimation embeddingsEstimation(
        GraphSageTrainMemoryEstimateParameters config,
        long batchSize,
        long nodeCount,
        int labelCount,
        boolean withGradientDescent
    ) {
        var isMultiLabel = config.isMultiLabel();

        var layerConfigs = config.layerConfigs();
        var numberOfLayers = layerConfigs.size();

        var computationGraphBuilder = MemoryEstimations.builder("computationGraph").startField("subgraphs");

        final var minBatchNodeCounts = new ArrayList<Long>(numberOfLayers + 1);
        final var maxBatchNodeCounts = new ArrayList<Long>(numberOfLayers + 1);
        minBatchNodeCounts.add(batchSize);
        maxBatchNodeCounts.add(batchSize);

        for (int i = 0; i < numberOfLayers; i++) {
            var sampleSize = layerConfigs.get(i).sampleSize();

            var min = minBatchNodeCounts.get(i);
            var max = maxBatchNodeCounts.get(i);
            var minNextNodeCount = Math.min(min, nodeCount);
            var maxNextNodeCount = Math.min(max * (sampleSize + 1), nodeCount);
            minBatchNodeCounts.add(minNextNodeCount);
            maxBatchNodeCounts.add(maxNextNodeCount);
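            // Illustrative numbers (not in the original source): with batchSize = 100, sampleSize = 25
            // and a large enough graph, the upper bound grows to min(100 * (25 + 1), nodeCount) = 2600
            // node ids for the next hop, while the lower bound stays at min(100, nodeCount) = 100
            // when no neighbors are sampled at all.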

            var subgraphRange = MemoryRange.of(
                sizeOfIntArray(min) + sizeOfObjectArray(min) + min * sizeOfIntArray(0) + sizeOfLongArray(
                    minNextNodeCount),
                sizeOfIntArray(max) + sizeOfObjectArray(max) + max * sizeOfIntArray(sampleSize) + sizeOfLongArray(
                    maxNextNodeCount)
            );

            computationGraphBuilder.add(MemoryEstimations.of("subgraph " + (i + 1), subgraphRange));
        }

        // aggregators go backwards through the layers
        Collections.reverse(minBatchNodeCounts);
        Collections.reverse(maxBatchNodeCounts);

        var aggregatorsBuilder = MemoryEstimations.builder();
        for (int i = 0; i < numberOfLayers; i++) {
            var layerConfig = layerConfigs.get(i);

            var minPreviousNodeCount = minBatchNodeCounts.get(i);
            var maxPreviousNodeCount = maxBatchNodeCounts.get(i);
            var minNodeCount = minBatchNodeCounts.get(i + 1);
            var maxNodeCount = maxBatchNodeCounts.get(i + 1);

            if (i == 0) {
                var featureSize = config.estimationFeatureDimension();
                MemoryRange firstLayerMemory = MemoryRange.of(
                    sizeOfDoubleArray(minPreviousNodeCount * featureSize),
                    sizeOfDoubleArray(maxPreviousNodeCount * featureSize)
                );
                if (isMultiLabel) {
                    // for the matrix product of weights x node features for a single node
                    firstLayerMemory = firstLayerMemory.add(MemoryRange.of(sizeOfDoubleArray(featureSize)));
                }
                aggregatorsBuilder.fixed("firstLayer", firstLayerMemory);
            }

            var aggregatorType = layerConfig.aggregatorType();
            var embeddingDimension = config.embeddingDimension();

            var aggregatorMemoryEstimation = switch (aggregatorType) {
                case MEAN -> new MeanAggregatorMemoryEstimator();
                case POOL -> new PoolAggregatorMemoryEstimator();
            };

            var aggregatorMemoryRange = aggregatorMemoryEstimation.estimate(
                minNodeCount,
                maxNodeCount,
                minPreviousNodeCount,
                maxPreviousNodeCount,
                layerConfig.cols(),
                embeddingDimension
            );

            aggregatorsBuilder.fixed(
                formatWithLocale("%s %d", aggregatorType.name(), i + 1),
                aggregatorMemoryRange
            );

            if (i == numberOfLayers - 1) {
                aggregatorsBuilder.fixed(
                    "normalizeRows",
                    MemoryRange.of(
                        sizeOfDoubleArray(minNodeCount * embeddingDimension),
                        sizeOfDoubleArray(maxNodeCount * embeddingDimension)
                    )
                );
            }
        }

        computationGraphBuilder = computationGraphBuilder
            .endField();

        if (isMultiLabel) {
            var minFeatureFunction = sizeOfObjectArray(minBatchNodeCounts.get(0));
            var maxFeatureFunction = sizeOfObjectArray(maxBatchNodeCounts.get(0));
            var copyOfLabels = sizeOfObjectArray(labelCount);

            computationGraphBuilder.fixed(
                "multiLabelFeatureFunction",
                MemoryRange.of(minFeatureFunction, maxFeatureFunction).add(MemoryRange.of(copyOfLabels))
            );
        }

        computationGraphBuilder = computationGraphBuilder
            .startField("forward")
            .addComponentsOf(aggregatorsBuilder.build());

        if (withGradientDescent) {
            computationGraphBuilder = computationGraphBuilder
                .endField()
                .startField("backward")
                .addComponentsOf(aggregatorsBuilder.build());
        }
        return computationGraphBuilder.endField().build();
    }

    public static HugeObjectArray<double[]> initializeSingleLabelFeatures(
        Graph graph,
        Collection<String> featureProperties
    ) {
        var features = HugeObjectArray.newArray(double[].class, graph.nodeCount());
        var extractors = FeatureExtraction.propertyExtractors(graph, featureProperties);

        return FeatureExtraction.extract(graph, extractors, features);
    }

    public static MultiLabelFeatureExtractors multiLabelFeatureExtractors(Graph graph, List<String> featureProperties) {
        var filteredKeysPerLabel = filteredPropertyKeysPerNodeLabel(graph, featureProperties);
        var featureCountPerLabel = new HashMap<NodeLabel, Integer>();
        var extractorsPerLabel = new HashMap<NodeLabel, List<FeatureExtractor>>();
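        // One extractor list per node label, built lazily the first time a node with that label
        // is seen: the label's filtered feature properties plus a constant bias feature.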
        graph.forEachNode(nodeId -> {
            var nodeLabel = labelOf(graph, nodeId);
            extractorsPerLabel.computeIfAbsent(nodeLabel, label -> {
                var propertyKeys = filteredKeysPerLabel.get(label);
                var featureExtractors = new ArrayList<>(FeatureExtraction.propertyExtractors(graph, propertyKeys, nodeId));
                featureExtractors.add(new BiasFeature());
                return featureExtractors;
            });
            featureCountPerLabel.computeIfAbsent(
                nodeLabel,
                label -> featureCount(extractorsPerLabel.get(label))
            );
            return true;
        });
        return new MultiLabelFeatureExtractors(featureCountPerLabel, extractorsPerLabel);
    }

    public static HugeObjectArray<double[]> initializeMultiLabelFeatures(
        Graph graph,
        MultiLabelFeatureExtractors multiLabelFeatureExtractors
    ) {
        var features = HugeObjectArray.newArray(double[].class, graph.nodeCount());
        var featureConsumer = new HugeObjectArrayFeatureConsumer(features);
        graph.forEachNode(nodeId -> {
            var nodeLabel = labelOf(graph, nodeId);
            var extractors = multiLabelFeatureExtractors.extractorsPerLabel().get(nodeLabel);
            var featureCount = multiLabelFeatureExtractors.featureCountPerLabel().get(nodeLabel);
            features.set(nodeId, new double[featureCount]);
            FeatureExtraction.extract(nodeId, nodeId, extractors, featureConsumer);
            return true;
        });

        return features;
    }

    public static List<LayerConfig> layerConfigs(
        int featureDimension,
        List<Integer> sampleSizes,
        Optional<Long> randomSeed,
        AggregatorType aggregatorType,
        ActivationFunctionType activationFunction,
        int embeddingDimension
    ) {
        Random random = new Random();
        randomSeed.ifPresent(random::setSeed);

        List<LayerConfig> result = new ArrayList<>(sampleSizes.size());
        for (int i = 0; i < sampleSizes.size(); i++) {
            LayerConfig layerConfig = LayerConfig.builder()
                .aggregatorType(aggregatorType)
                .activationFunction(activationFunction)
                .rows(embeddingDimension)
                .cols(i == 0 ? featureDimension : embeddingDimension)
                .sampleSize(sampleSizes.get(i))
                .randomSeed(random.nextLong())
                .build();

            result.add(layerConfig);
        }

        return result;
    }

    private static Map<NodeLabel, Set<String>> propertyKeysPerNodeLabel(GraphSchema graphSchema) {
        return graphSchema
            .nodeSchema()
            .entries()
            .stream()
            .collect(Collectors.toMap(NodeSchemaEntry::identifier, e -> e.properties().keySet()));
    }

    private static Map<NodeLabel, Set<String>> filteredPropertyKeysPerNodeLabel(Graph graph, List<String> featureProperties) {
        return propertyKeysPerNodeLabel(graph.schema())
            .entrySet()
            .stream()
            .collect(Collectors.toMap(
                Map.Entry::getKey,
                e -> featureProperties
                    .stream()
                    .filter(e.getValue()::contains)
                    .collect(Collectors.toSet())
            ));
    }

    private static NodeLabel labelOf(IdMap idMap, long nodeId) {
        var labelRef = new AtomicReference<NodeLabel>();
        var labelCount = new MutableInt(0);

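        // Consume at most two labels: the lambda returns true (continue) only while the first
        // label is being recorded and false as soon as a second label appears, so the count
        // checked below is 0, 1, or 2.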
        idMap.forEachNodeLabel(nodeId, nodeLabel -> {
            labelRef.set(nodeLabel);
            return labelCount.getAndIncrement() == 0;
        });

        if (labelCount.intValue() != 1) {
            throw new IllegalArgumentException(
                formatWithLocale("Each node must have exactly one label: nodeId=%d, labels=%s", nodeId, idMap.nodeLabels(nodeId))
            );
        }

        return labelRef.get();
    }
}
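
A minimal usage sketch (not part of the GraphSageHelper source), assuming a loaded Graph named graph that exposes the feature properties "age" and "score", and assuming ActivationFunctionType provides a SIGMOID constant:

    // Per-layer configuration for a two-layer model: sample 25 neighbors at the first hop
    // and 10 at the second, with a fixed seed so the layer weights are reproducible.
    int featureDimension = 2; // "age" and "score"
    List<LayerConfig> configs = GraphSageHelper.layerConfigs(
        featureDimension,
        List.of(25, 10),                 // one sample size per layer
        Optional.of(42L),                // random seed for weight initialization
        AggregatorType.MEAN,
        ActivationFunctionType.SIGMOID,  // assumed enum constant
        64                               // embedding dimension
    );

    // Extract the raw node features once; the result is indexed by internal node id.
    HugeObjectArray<double[]> features =
        GraphSageHelper.initializeSingleLabelFeatures(graph, List.of("age", "score"));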



