/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.embeddings.fastrp;

import org.jetbrains.annotations.TestOnly;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.DegreePartition;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionConsumer;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.features.FeatureConsumer;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.utils.CloseableThreadLocal;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.addInPlace;
import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.addWeightedInPlace;
import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.l2Norm;
import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.scale;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;

public class FastRP extends Algorithm<FastRPResult> {

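    // Constants for the very sparse random projection used to build the initial
    // node vectors: an entry is non-zero with total probability 1 / SPARSITY
    // (ENTRY_PROBABILITY per sign); EPSILON guards divisions by a near-zero L2 norm.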
    private static final int SPARSITY = 3;
    private static final double ENTRY_PROBABILITY = 1.0 / (2 * SPARSITY);
    private static final float EPSILON = 10f / Float.MAX_VALUE;

    private final Graph graph;
    private final Concurrency concurrency;
    private final float normalizationStrength;
    private final List<FeatureExtractor> featureExtractors;
    private final Optional<String> relationshipWeightProperty;
    private final double relationshipWeightFallback;
    private final int inputDimension;
    private final float[][] propertyVectors;
    private final HugeObjectArray<float[]> embeddings;
    private final HugeObjectArray<float[]> embeddingA;
    private final HugeObjectArray<float[]> embeddingB;
    private final EmbeddingCombiner embeddingCombiner;
    private final long randomSeed;

    private final int embeddingDimension;
    private final int baseEmbeddingDimension;
    private final Number nodeSelfInfluence;
    private final List<Number> iterationWeights;
    private final int minBatchSize;
    private List<DegreePartition> partitions;

    public FastRP(
        Graph graph,
        FastRPParameters parameters,
        Concurrency concurrency,
        int minBatchSize,
        List<FeatureExtractor> featureExtractors,
        ProgressTracker progressTracker,
        Optional<Long> randomSeed,
        TerminationFlag terminationFlag
    ) {
        super(progressTracker);
        this.graph = graph;
        this.featureExtractors = featureExtractors;
        this.relationshipWeightProperty = parameters.relationshipWeightProperty();
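        // A NaN fallback makes missing weights detectable when a weight property
        // is configured; otherwise every relationship counts with weight 1.0.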
        this.relationshipWeightFallback = this.relationshipWeightProperty.map(s -> Double.NaN).orElse(1.0);
        this.inputDimension = FeatureExtraction.featureCount(featureExtractors);
        this.randomSeed = improveSeed(randomSeed.orElseGet(System::nanoTime));
        this.concurrency = concurrency;
        this.minBatchSize = minBatchSize;

        this.propertyVectors = new float[inputDimension][parameters.propertyDimension()];
        this.embeddings = HugeObjectArray.newArray(float[].class, graph.nodeCount());
        this.embeddingA = HugeObjectArray.newArray(float[].class, graph.nodeCount());
        this.embeddingB = HugeObjectArray.newArray(float[].class, graph.nodeCount());

        this.embeddingDimension = parameters.embeddingDimension();
        this.baseEmbeddingDimension = parameters.embeddingDimension() - parameters.propertyDimension();
        this.iterationWeights = parameters.iterationWeights();
        this.nodeSelfInfluence = parameters.nodeSelfInfluence();
        this.normalizationStrength = parameters.normalizationStrength();
        this.embeddingCombiner = graph.hasRelationshipProperty()
            ? this::addArrayValuesWeighted
            : (lhs, rhs, ignoreWeight) -> addInPlace(lhs, rhs);
        this.embeddings.setAll((i) -> new float[embeddingDimension]);

        this.terminationFlag = terminationFlag;
    }

    @Override
    public FastRPResult compute() {
        progressTracker.beginSubTask();
        initDegreePartition();
        initPropertyVectors();
        initRandomVectors();
        addInitialVectorsToEmbedding();
        propagateEmbeddings();
        progressTracker.endSubTask();
        return new FastRPResult(embeddings);
    }

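    // Partitions the nodes by accumulated degree so that each parallel task sees
    // a similar number of relationships; propagation cost is proportional to degree.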
    public void initDegreePartition() {
        this.partitions = PartitionUtils.degreePartitionStream(
            graph.nodeCount(),
            graph.relationshipCount(),
            concurrency,
            graph::degree
        ).collect(Collectors.toList());
    }

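    // Builds one sparse random projection vector per input feature dimension;
    // entries are +/- sqrt(SPARSITY) / sqrt(embeddingDimension) with probability
    // ENTRY_PROBABILITY each, and 0 otherwise.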
    void initPropertyVectors() {
        int propertyDimension = embeddingDimension - baseEmbeddingDimension;
        float entryValue = (float) Math.sqrt(SPARSITY) / (float) Math.sqrt(embeddingDimension);
        var random = new HighQualityRandom(randomSeed);
        for (int i = 0; i < inputDimension; i++) {
            this.propertyVectors[i] = new float[propertyDimension];
        }
        for (int d = 0; d < propertyDimension; d++) {
            for (int i = 0; i < inputDimension; i++) {
                this.propertyVectors[i][d] = computeRandomEntry(random, entryValue);
            }
        }
    }

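    // Initializes every node's base random vector in parallel. The generator is
    // reseeded per node with randomSeed ^ toOriginalNodeId(nodeId), so results
    // are deterministic for a fixed seed, independent of partitioning.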
    void initRandomVectors() {
        progressTracker.beginSubTask();

        var sqrtEmbeddingDimension = (float) Math.sqrt(embeddingDimension);
        List<Runnable> tasks = PartitionUtils.rangePartition(
            concurrency,
            graph.nodeCount(),
            partition -> new InitRandomVectorTask(
                partition,
                sqrtEmbeddingDimension
            ),
            Optional.of(minBatchSize)
        );
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(tasks)
            .run();

        progressTracker.endSubTask();
    }

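    // If nodeSelfInfluence is non-zero, adds each node's L2-normalized initial
    // random vector to its final embedding, scaled by nodeSelfInfluence.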
    void addInitialVectorsToEmbedding() {
        if (Float.compare(nodeSelfInfluence.floatValue(), 0.0f) == 0) return;
        progressTracker.beginSubTask();

        ParallelUtil.parallelForEachNode(
            graph.nodeCount(),
            concurrency,
            terminationFlag,
            this::addInitialStateToEmbedding
        );

        progressTracker.endSubTask();
    }

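    // Runs the message-passing iterations. embeddingA and embeddingB act as
    // ping-pong buffers: each iteration averages the neighbours' previous
    // vectors, L2-normalizes the average, and accumulates it into the final
    // embedding weighted by that iteration's weight.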
    void propagateEmbeddings() {
        progressTracker.beginSubTask();

        for (int i = 0; i < iterationWeights.size(); i++) {
            progressTracker.beginSubTask();

            HugeObjectArray<float[]> currentEmbeddings = i % 2 == 0 ? embeddingA : embeddingB;
            HugeObjectArray<float[]> previousEmbeddings = i % 2 == 0 ? embeddingB : embeddingA;
            var iterationWeight = iterationWeights.get(i).floatValue();
            boolean firstIteration = i == 0;

            Supplier<PartitionConsumer<DegreePartition>> taskSupplier = () -> new PropagateEmbeddingsTask(
                currentEmbeddings,
                previousEmbeddings,
                iterationWeight,
                firstIteration
            );

            parallelPartitionsConsume(
                RunWithConcurrency.builder().executor(DefaultPool.INSTANCE).concurrency(concurrency),
                partitions.stream(),
                taskSupplier
            );

            progressTracker.endSubTask();
        }
        progressTracker.endSubTask();
    }

    /**
     * This method is useful when |partitions| is much larger than the concurrency, as we only create a single consumer per thread.
     * Compared to parallelForEachNode, thread-local state does not need to be resolved for each node but only once per partition.
     */
    public static <P extends Partition> void parallelPartitionsConsume(
        RunWithConcurrency.Builder runnerBuilder,
        Stream<P> partitions,
        Supplier<PartitionConsumer<P>> taskSupplier
    ) {
        try (var localConsumer = CloseableThreadLocal.withInitial(taskSupplier)) {
            var taskStream = partitions.map(partition -> (Runnable) () -> localConsumer.get().consume(partition));
            runnerBuilder.tasks(taskStream);
            runnerBuilder.run();
        }
    }

    @TestOnly
    HugeObjectArray<float[]> currentEmbedding(int iteration) {
        return iteration % 2 == 0
            ? this.embeddingA
            : this.embeddingB;
    }

    @TestOnly
    float[][] propertyVectors() {
        return propertyVectors;
    }

    @TestOnly
    HugeObjectArray<float[]> embeddings() {
        return embeddings;
    }

    private void addArrayValuesWeighted(float[] lhs, float[] rhs, double weight) {
        for (int i = 0; i < lhs.length; i++) {
            lhs[i] = (float) Math.fma(rhs[i], weight, lhs[i]);
        }
    }

    private static float computeRandomEntry(Random random, float entryValue) {
        double randomValue = random.nextDouble();

        if (randomValue < ENTRY_PROBABILITY) {
            return entryValue;
        } else if (randomValue < ENTRY_PROBABILITY * 2.0) {
            return -entryValue;
        } else {
            return 0.0f;
        }
    }

    private static class HighQualityRandom extends Random {
        private long u;
        private long v;
        private long w;

        public HighQualityRandom(long seed) {
            reseed(seed);
        }

        public void reseed(long seed) {
            v = 4101842887655102017L;
            w = 1;

            u = seed ^ v;
            nextLong();
            v = u;
            nextLong();
            w = v;
            nextLong();
        }

        @Override
        public long nextLong() {
            u = u * 2862933555777941757L + 7046029254386353087L;
            v ^= v >>> 17;
            v ^= v << 31;
            v ^= v >>> 8;
            w = 4294957665L * w + (w >>> 32);
            long x = u ^ (u << 21);
            x ^= x >>> 35;
            x ^= x << 4;
            return (x + v) ^ w;
        }

        @Override
        protected int next(int bits) {
            return (int) (nextLong() >>> (64 - bits));
        }
    }

    private long improveSeed(long randomSeed) {
        return new HighQualityRandom(randomSeed).nextLong();
    }

    private interface EmbeddingCombiner {
        void combine(float[] into, float[] add, double weight);
    }

    private final class InitRandomVectorTask implements Runnable {

        final float sqrtSparsity = (float) Math.sqrt(SPARSITY);

        private final Partition partition;
        private final float sqrtEmbeddingDimension;
        private final PropertyVectorAdder propertyVectorAdder;

        private InitRandomVectorTask(
            Partition partition,
            float sqrtEmbeddingDimension
        ) {
            this.partition = partition;
            this.sqrtEmbeddingDimension = sqrtEmbeddingDimension;
            this.propertyVectorAdder = new PropertyVectorAdder();
        }

        @Override
        public void run() {
            // this value currently doesn't matter because of the reseeding below
            var random = new HighQualityRandom(randomSeed);
            partition.consume(nodeId -> {
                int degree = graph.degree(nodeId);
                float scaling = degree == 0
                    ? 1.0f
                    : (float) Math.pow(degree, normalizationStrength);

                float entryValue = scaling * sqrtSparsity / sqrtEmbeddingDimension;
                random.reseed(randomSeed ^ graph.toOriginalNodeId(nodeId));
                var randomVector = computeRandomVector(nodeId, random, entryValue);
                embeddingB.set(nodeId, randomVector);
                embeddingA.set(nodeId, new float[embeddingDimension]);
            });
            progressTracker.logProgress(partition.nodeCount());
        }

        private float[] computeRandomVector(long nodeId, Random random, float entryValue) {
            var randomVector = new float[embeddingDimension];
            for (int i = 0; i < baseEmbeddingDimension; i++) {
                randomVector[i] = computeRandomEntry(random, entryValue);
            }

            propertyVectorAdder.setRandomVector(randomVector);
            FeatureExtraction.extract(nodeId, -1, featureExtractors, propertyVectorAdder);

            return randomVector;
        }

        private class PropertyVectorAdder implements FeatureConsumer {
            private float[] randomVector;

            void setRandomVector(float[] randomVector) {
                this.randomVector = randomVector;
            }

            @Override
            public void acceptScalar(long ignored, int offset, double value) {
                float floatValue = (float) value;
                for (int i = baseEmbeddingDimension; i < embeddingDimension; i++) {
                    randomVector[i] += floatValue * propertyVectors[offset][i - baseEmbeddingDimension];
                }
            }

            @Override
            public void acceptArray(long ignored, int offset, double[] values) {
                for (int j = 0; j < values.length; j++) {
                    var value = (float) values[j];
                    float[] propertyVector = propertyVectors[offset + j];
                    for (int i = baseEmbeddingDimension; i < embeddingDimension; i++) {
                        randomVector[i] += value * propertyVector[i - baseEmbeddingDimension];
                    }
                }
            }
        }
    }

    private void addInitialStateToEmbedding(long nodeId) {
        var initialVector = embeddingB.get(nodeId);
        var l2Norm = l2Norm(initialVector);
        float adjustedL2Norm = l2Norm < EPSILON ? 1f : l2Norm;
        addWeightedInPlace(embeddings.get(nodeId), initialVector, nodeSelfInfluence.floatValue() / adjustedL2Norm);
        progressTracker.logProgress(1);
    }

    private final class PropagateEmbeddingsTask implements PartitionConsumer<DegreePartition> {

        private final HugeObjectArray<float[]> currentEmbeddings;
        private final HugeObjectArray<float[]> previousEmbeddings;
        private final float iterationWeight;
        private final Graph localGraph;
        private final boolean firstIteration;

        private PropagateEmbeddingsTask(
            HugeObjectArray<float[]> currentEmbeddings,
            HugeObjectArray<float[]> previousEmbeddings,
            float iterationWeight,
            boolean firstIteration
        ) {
            this.currentEmbeddings = currentEmbeddings;
            this.previousEmbeddings = previousEmbeddings;
            this.iterationWeight = iterationWeight;
            this.localGraph = graph.concurrentCopy();
            this.firstIteration = firstIteration;
        }

        public void consume(DegreePartition partition) {
            partition.consume(nodeId -> {
                var embedding = embeddings.get(nodeId);
                var currentEmbedding = currentEmbeddings.get(nodeId);
                Arrays.fill(currentEmbedding, 0.0f);

                // Collect and combine the neighbour embeddings
                localGraph.forEachRelationship(nodeId, relationshipWeightFallback, (source, target, weight) -> {
                    if (firstIteration && Double.isNaN(weight)) {
                        throw new IllegalArgumentException(formatWithLocale(
                            "Missing relationship property `%s` on relationship between nodes with ids `%d` and `%d`.",
                            relationshipWeightProperty.orElse(""),
                            graph.toOriginalNodeId(source),
                            graph.toOriginalNodeId(target)
                        ));
                    }
                    embeddingCombiner.combine(currentEmbedding, previousEmbeddings.get(target), weight);
                    return true;
                });

                // Normalize neighbour embeddings
                var degree = graph.degree(nodeId);
                int adjustedDegree = degree == 0 ? 1 : degree;
                float degreeScale = 1.0f / adjustedDegree;
                scale(currentEmbedding, degreeScale);
                var invL2Norm = 1.0f / l2Norm(currentEmbedding);
                var safeInvL2Norm = Float.isFinite(invL2Norm) ? invL2Norm : 1.0f;

                // Update the result embedding
                addWeightedInPlace(embedding, currentEmbedding, safeInvL2Norm * iterationWeight);
            });
            progressTracker.logProgress(partition.relationshipCount());
        }
    }
}
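
// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the original file: a self-contained demo of
// the very sparse random projection entries used by FastRP above. It mirrors
// computeRandomEntry: an entry is +v or -v with probability ENTRY_PROBABILITY
// (= 1 / (2 * SPARSITY)) each, and 0 otherwise. The class name and the choice
// of dimension d = 128 are hypothetical, for illustration only.
class SparseProjectionEntryDemo {

    public static void main(String[] args) {
        int sparsity = 3;                               // same value as FastRP.SPARSITY
        double entryProbability = 1.0 / (2 * sparsity); // same as FastRP.ENTRY_PROBABILITY
        int embeddingDimension = 128;                   // hypothetical dimension
        float entryValue = (float) (Math.sqrt(sparsity) / Math.sqrt(embeddingDimension));

        var random = new Random(42);
        int samples = 1_000_000;
        int positive = 0;
        int negative = 0;

        for (int i = 0; i < samples; i++) {
            double r = random.nextDouble();
            if (r < entryProbability) {
                positive++;
            } else if (r < entryProbability * 2.0) {
                negative++;
            }
        }

        // For sparsity = 3 we expect roughly 1/6 positive, 1/6 negative, 2/3 zero.
        System.out.printf(
            "entryValue = %.4f, P(+v) ~ %.3f, P(-v) ~ %.3f, P(0) ~ %.3f%n",
            entryValue,
            positive / (double) samples,
            negative / (double) samples,
            (samples - positive - negative) / (double) samples
        );
    }
}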