All downloads are free. Search and download functionality uses the official Maven repository.

org.neo4j.gds.similarity.knn.RandomWalkKnnSampler Maven / Gradle / Ivy

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.similarity.knn;

import com.carrotsearch.hppc.LongHashSet;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.ml.core.samplers.LongUniformSamplerFromRange;
import org.neo4j.gds.ml.core.samplers.RandomWalkSampler;

import java.util.Arrays;
import java.util.Optional;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.function.LongPredicate;

import static org.neo4j.gds.mem.Estimate.sizeOfInstance;
import static org.neo4j.gds.mem.Estimate.sizeOfLongArray;
import static org.neo4j.gds.mem.Estimate.sizeOfLongHashSet;

/**
 * {@link KnnSampler} that draws candidate neighbours for a node from a random walk started at that node.
 * <p>
 * Walk positions are consumed in order, starting at index 1. Invalid candidates and nodes already taken are
 * skipped. If the walk does not yield enough distinct valid candidates, the remainder is filled by uniform
 * sampling over the node id range {@code [0, nodeCount)}.
 * <p>
 * Each instance keeps mutable scratch state ({@code sampledValuesCache}) that is reset on every
 * {@link #sample} call, so one instance should not be used by several threads at once. {@link Factory#create()}
 * hands out a fresh instance, with its own graph copy and its own split of the random generator, per call.
 */
class RandomWalkKnnSampler implements KnnSampler {

    // Extend the walks a bit to compensate for possibly sampling duplicate nodes in the walks.
    private static final int WALK_LENGTH_MULTIPLIER = 3;

    private final RandomWalkSampler randomWalkSampler;
    private final LongUniformSamplerFromRange uniformSamplerFromRange;
    // Exclusive upper bound of the id range used by the uniform fallback: the graph's node count.
    private final long exclusiveMax;
    // Nodes already taken during the current sample() call; cleared at the start of every call.
    private final LongHashSet sampledValuesCache;

    /**
     * @param graph      graph to walk on; its node count bounds the uniform fallback range
     * @param random     randomness source for the uniform fallback sampler
     * @param randomSeed seed for the random walks; when empty, a fresh seed is drawn from a new {@link Random}
     * @param k          target number of samples; walks are {@value #WALK_LENGTH_MULTIPLIER} times this long.
     *                   Expected to be positive (checked by an assertion only)
     */
    RandomWalkKnnSampler(
        Graph graph,
        SplittableRandom random,
        // Since RandomWalk seeds per node the RandomWalkSampler can't take a SplittableRandom.
        Optional<Long> randomSeed,
        int k
    ) {
        assert k > 0;

        this.randomWalkSampler = new RandomWalkSampler(
            graph,
            graph::degree,
            WALK_LENGTH_MULTIPLIER * k,
            // Prefer deeper walks.
            // NOTE(review): the meaning of these three factors is defined by RandomWalkSampler's constructor.
            0.4,
            0.6,
            1.0,
            // Without an explicit seed, every sampler instance draws its own, so walks are not reproducible.
            randomSeed.orElseGet(() -> new Random().nextLong())
        );
        this.uniformSamplerFromRange = new LongUniformSamplerFromRange(random);
        this.exclusiveMax = graph.nodeCount();
        this.sampledValuesCache = new LongHashSet();
    }

    /**
     * Estimates the memory held by one sampler instance.
     * <p>
     * The estimate covers:
     * <ul>
     *   <li>the walk sampler, with walk length {@code WALK_LENGTH_MULTIPLIER * boundedK};</li>
     *   <li>this instance;</li>
     *   <li>a {@code long[boundedK]} result array;</li>
     *   <li>the dedup set.</li>
     * </ul>
     * The uniform fallback is presumably asked for anywhere between 0 and {@code boundedK} samples. Both
     * extremes are therefore estimated and combined with {@link MemoryRange#union}.
     *
     * @param boundedK the (bounded) number of samples a sampler will produce per call
     * @return memory range of a single sampler
     */
    public static MemoryRange memoryEstimation(long boundedK) {
        var baseEstimation = RandomWalkSampler.memoryEstimation(boundedK * WALK_LENGTH_MULTIPLIER)
            .add(MemoryRange.of(
                sizeOfInstance(RandomWalkKnnSampler.class) +
                sizeOfLongArray(boundedK) +
                sizeOfLongHashSet(boundedK)
            ));

        return baseEstimation
            .add(LongUniformSamplerFromRange.memoryEstimation(0))
            .union(baseEstimation.add(LongUniformSamplerFromRange.memoryEstimation(boundedK)));
    }

    /**
     * Samples up to {@code numberOfSamples} distinct candidate nodes for {@code nodeId}.
     *
     * @param nodeId                          node whose random walk supplies the first candidates
     * @param lowerBoundOnValidSamplesInRange lower bound on the number of valid candidates in
     *                                        {@code [0, nodeCount)}. It is forwarded to the uniform fallback,
     *                                        reduced by the walk candidates already taken
     * @param numberOfSamples                 number of samples requested
     * @param isInvalidSample                 returns {@code true} for nodes that must not be sampled
     * @return distinct valid samples: walk candidates first, then uniformly drawn ones. The array has length
     *     {@code numberOfSamples} unless the fallback supplies too few candidates. In that case it is shorter
     *     instead of being padded with default zeros
     */
    @Override
    public long[] sample(
        long nodeId,
        long lowerBoundOnValidSamplesInRange,
        final int numberOfSamples,
        final LongPredicate isInvalidSample
    ) {
        final var walk = randomWalkSampler.walk(nodeId);

        sampledValuesCache.clear();
        final var samples = new long[numberOfSamples];
        int addedSamples = 0;
        // Index 0 is skipped (presumably the walk's start node, i.e. nodeId itself). The bound is checked in the
        // loop condition, before any write, so numberOfSamples == 0 can no longer overrun the empty array.
        for (int i = 1; i < walk.length && addedSamples < numberOfSamples; i++) {
            long node = walk[i];
            // Invalid nodes are skipped without being recorded. LongHashSet.add returns false for repeats.
            if (isInvalidSample.test(node) || !sampledValuesCache.add(node)) {
                continue;
            }
            samples[addedSamples++] = node;
        }

        if (addedSamples == numberOfSamples) {
            return samples;
        }

        // Fill up with uniformly random nodes if walk did not contain enough unique valid sample candidates.
        int missingSamples = numberOfSamples - addedSamples;
        var uniformSamples = uniformSamplerFromRange.sample(
            0,
            exclusiveMax,
            // Every walk node taken was valid and is excluded by the predicate below, hence the subtraction.
            lowerBoundOnValidSamplesInRange - addedSamples,
            missingSamples,
            node -> isInvalidSample.test(node) || sampledValuesCache.contains(node)
        );

        // Never copy more than fits: an oversized fallback result would otherwise make arraycopy throw.
        int copied = Math.min(missingSamples, uniformSamples.length);
        System.arraycopy(uniformSamples, 0, samples, addedSamples, copied);
        int filled = addedSamples + copied;

        // Trim rather than leave default zeros, which would be indistinguishable from node id 0.
        return filled == numberOfSamples ? samples : Arrays.copyOf(samples, filled);
    }

    /**
     * Creates one {@link RandomWalkKnnSampler} per {@link #create()} call. Each sampler gets its own
     * {@code concurrentCopy()} of the graph and its own {@code split()} of the shared random generator.
     */
    public static class Factory implements KnnSampler.Factory {
        private final Graph graph;
        private final Optional<Long> randomSeed;
        private final int boundedK;
        private final SplittableRandom random;

        Factory(Graph graph, Optional<Long> randomSeed, int boundedK, SplittableRandom random) {
            this.graph = graph;
            this.randomSeed = randomSeed;
            this.boundedK = boundedK;
            this.random = random;
        }

        @Override
        public KnnSampler create() {
            return new RandomWalkKnnSampler(graph.concurrentCopy(), random.split(), randomSeed, boundedK);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy