All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.similarity.knn.Knn Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.similarity.knn;

import com.carrotsearch.hppc.LongArrayList;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.stream.LongStream;

/**
 * Computes a {@link KnnResult} for a {@link Graph}.
 *
 * <p>{@link #compute()} first generates random initial neighbors for every node and then runs up to
 * {@code maxIterations} rounds, each consisting of a split step, a reverse step and a join step.
 * The rounds stop early as soon as a round reports an update count that is less than or equal to
 * the update threshold taken from {@link K}.
 */
public class Knn extends Algorithm<KnnResult> {

    /**
     * Creates a {@code Knn} instance from a parameter object and a context.
     *
     * <p>The progress tracker and executor are taken from {@code context}, the similarity computer
     * is wrapped in a {@link SimilarityFunction}, and {@code NeighbourConsumers.no_op} is used as
     * the neighbour consumer.
     *
     * @param graph                 graph to run on
     * @param parameters            source of k, concurrency, batch size, iteration limit, cutoff,
     *                              perturbation rate, random joins, seed and sampler type
     * @param similarityComputer    computer wrapped into the similarity function
     * @param neighborFilterFactory factory for the per-task neighbor filters
     * @param context               provides the progress tracker and the executor
     * @param terminationFlag       termination flag handed to the parallel runs
     * @return a new, not yet executed {@code Knn}
     */
    public static Knn create(
        Graph graph,
        KnnParameters parameters,
        SimilarityComputer similarityComputer,
        NeighborFilterFactory neighborFilterFactory,
        KnnContext context,
        TerminationFlag terminationFlag
    ) {
        var similarityFunction = new SimilarityFunction(similarityComputer);
        return new Knn(
            graph,
            context.progressTracker(),
            context.executor(),
            parameters.kHolder(),
            parameters.concurrency(),
            parameters.minBatchSize(),
            parameters.maxIterations(),
            parameters.similarityCutoff(),
            parameters.perturbationRate(),
            parameters.randomJoins(),
            parameters.randomSeed(),
            parameters.samplerType(),
            similarityFunction,
            neighborFilterFactory,
            NeighbourConsumers.no_op,
            terminationFlag
        );
    }

    private final Graph graph;
    private final Concurrency concurrency;
    // Upper bound on the number of rounds executed by compute().
    private final int maxIterations;
    // When greater than zero, compute() runs a final filter pass over all nodes with this cutoff.
    private final double similarityCutoff;
    // Minimum batch size handed to the partitioning helpers.
    private final int minBatchSize;
    private final NeighborFilterFactory neighborFilterFactory;
    private final ExecutorService executorService;
    // Chosen in the constructor according to the requested sampler type.
    private final KnnSampler.Factory samplerFactory;
    private final JoinNeighbors.Factory joinNeighborsFactory;
    private final GenerateRandomNeighbors.Factory generateRandomNeighborsFactory;
    private final SplitOldAndNewNeighbors.Factory splitOldAndNewNeighborsFactory;
    // A round whose update count is <= this value ends the iteration loop.
    private final long updateThreshold;

    /**
     * Creates a {@code Knn} instance from individual settings.
     *
     * @param graph                 graph to run on
     * @param progressTracker       tracker passed to the superclass and to the task factories
     * @param executorService       executor used for all parallel phases
     * @param k                     provides {@code value}, {@code sampledValue} and
     *                              {@code updateThreshold}
     * @param concurrency           concurrency used for partitioning and parallel runs
     * @param minBatchSize          minimum batch size used for partitioning
     * @param maxIterations         maximum number of rounds
     * @param similarityCutoff      cutoff for the final filter pass; the pass only runs when this
     *                              is greater than zero
     * @param perturbationRate      forwarded to the {@link JoinNeighbors.Factory}
     * @param randomJoins           forwarded to the {@link JoinNeighbors.Factory}
     * @param randomSeed            seed for the shared {@link SplittableRandom}; when empty, an
     *                              unseeded {@code SplittableRandom} is created
     * @param initialSamplerType    selects the sampler used for the random initial neighbors
     * @param similarityFunction    similarity function used by the generate and join steps
     * @param neighborFilterFactory factory for the per-task neighbor filters
     * @param neighborConsumers     forwarded to the {@link GenerateRandomNeighbors.Factory}
     * @param terminationFlag       termination flag handed to the parallel runs
     * @throws IllegalStateException if {@code initialSamplerType} is neither {@code UNIFORM} nor
     *                               {@code RANDOMWALK}
     */
    public Knn(
        Graph graph,
        ProgressTracker progressTracker,
        ExecutorService executorService,
        K k,
        Concurrency concurrency,
        int minBatchSize,
        int maxIterations,
        double similarityCutoff,
        double perturbationRate,
        int randomJoins,
        Optional<Long> randomSeed,
        KnnSampler.SamplerType initialSamplerType,
        SimilarityFunction similarityFunction,
        NeighborFilterFactory neighborFilterFactory,
        NeighbourConsumers neighborConsumers,
        TerminationFlag terminationFlag
    ) {
        super(progressTracker);
        this.graph = graph;
        this.concurrency = concurrency;
        this.maxIterations = maxIterations;
        this.similarityCutoff = similarityCutoff;
        this.minBatchSize = minBatchSize;
        this.neighborFilterFactory = neighborFilterFactory;
        this.executorService = executorService;

        this.updateThreshold = k.updateThreshold;

        // One random source is shared by the sampler and all three step factories.
        var splittableRandom = randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
        switch (initialSamplerType) {
            case UNIFORM:
                this.samplerFactory = new UniformKnnSampler.Factory(graph.nodeCount(), splittableRandom);
                break;
            case RANDOMWALK:
                this.samplerFactory = new RandomWalkKnnSampler.Factory(graph, randomSeed, k.value, splittableRandom);
                break;
            default:
                throw new IllegalStateException("Invalid KnnSampler");
        }
        this.generateRandomNeighborsFactory = new GenerateRandomNeighbors.Factory(
            similarityFunction,
            neighborConsumers,
            k.value,
            splittableRandom,
            progressTracker
        );
        this.splitOldAndNewNeighborsFactory = new SplitOldAndNewNeighbors.Factory(
            k.sampledValue,
            splittableRandom,
            progressTracker
        );
        this.joinNeighborsFactory = new JoinNeighbors.Factory(
            similarityFunction,
            k.sampledValue,
            perturbationRate,
            randomJoins,
            splittableRandom,
            progressTracker
        );

        // NOTE(review): terminationFlag is not declared in this class; presumably a field
        // inherited from Algorithm - confirm.
        this.terminationFlag = terminationFlag;
    }

    /**
     * Returns the executor this instance was constructed with.
     *
     * @return the executor used for the parallel phases
     */
    public ExecutorService executorService() {
        return executorService;
    }

    /**
     * Runs the algorithm.
     *
     * <p>For graphs with fewer than two nodes an empty result is returned immediately and no
     * progress task is started. Otherwise the result carries the neighbor data, the number of
     * rounds that were run, whether the update threshold was reached, the sum of
     * {@code neighborsFound()} and {@code joinCounter()}, and the node count.
     *
     * @return the computed result
     */
    @Override
    public KnnResult compute() {
        if (graph.nodeCount() < 2) {
            return new EmptyResult();
        }
        // Outer task, closed right before the result is built.
        progressTracker.beginSubTask();
        // First nested task: random initialisation.
        progressTracker.beginSubTask();
        var neighbors = initializeRandomNeighbors();
        progressTracker.endSubTask();

        int iteration = 0;
        boolean didConverge = false;

        // Second nested task: the rounds plus the optional cutoff filter pass.
        progressTracker.beginSubTask();
        for (; iteration < maxIterations; iteration++) {
            long updateCount = iteration(neighbors);
            if (updateCount <= updateThreshold) {
                // The break skips the loop increment, so count the round that just ran here.
                iteration++;
                didConverge = true;
                break;
            }
        }
        if (similarityCutoff > 0) {
            var neighborFilterTasks = PartitionUtils.rangePartition(
                concurrency,
                neighbors.size(),
                partition -> (Runnable) () -> partition.consume(
                    nodeId -> neighbors.filterHighSimilarityResult(nodeId, similarityCutoff)
                ),
                Optional.of(minBatchSize)
            );
            RunWithConcurrency.builder()
                .concurrency(concurrency)
                .tasks(neighborFilterTasks)
                .terminationFlag(terminationFlag)
                .executor(executorService)
                .run();
        }
        progressTracker.endSubTask();

        progressTracker.endSubTask();
        return ImmutableKnnResult.of(
            neighbors.data(),
            iteration,
            didConverge,
            neighbors.neighborsFound() + neighbors.joinCounter(),
            graph.nodeCount()
        );
    }

    /**
     * Allocates the {@link Neighbors} structure and fills it in parallel, one generator task per
     * range partition of the node ids.
     *
     * @return the populated neighbors structure
     */
    private Neighbors initializeRandomNeighbors() {
        var neighbors = new Neighbors(graph.nodeCount());

        var randomNeighborGenerators = PartitionUtils.rangePartition(
            concurrency,
            graph.nodeCount(),
            partition -> generateRandomNeighborsFactory.create(
                partition,
                neighbors,
                samplerFactory.create(),
                neighborFilterFactory.create()
            ),
            Optional.of(minBatchSize)
        );

        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(randomNeighborGenerators)
            .terminationFlag(terminationFlag)
            .executor(executorService)
            .run();

        return neighbors;
    }

    /**
     * Runs one round: split into old and new neighbors, build the reversed lists, then join.
     * Each of the three steps is wrapped in its own progress sub task.
     *
     * @param neighbors the neighbors structure that the join tasks operate on
     * @return the sum of the update counts reported by all join tasks of this round
     */
    private long iteration(Neighbors neighbors) {
        var nodeCount = graph.nodeCount();

        // TODO: init in ctor and reuse - benchmark against new allocations
        var allOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        var allNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);

        progressTracker.beginSubTask();
        ParallelUtil.readParallel(concurrency, nodeCount, executorService, splitOldAndNewNeighborsFactory.create(
            neighbors,
            allOldNeighbors,
            allNewNeighbors
        ));
        progressTracker.endSubTask();

        // TODO: init in ctor and reuse - benchmark against new allocations
        var reverseOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        var reverseNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);

        progressTracker.beginSubTask();
        reverseOldAndNewNeighbors(
            allOldNeighbors,
            allNewNeighbors,
            reverseOldNeighbors,
            reverseNewNeighbors,
            concurrency,
            minBatchSize,
            progressTracker
        );
        progressTracker.endSubTask();

        var neighborsJoiners = PartitionUtils.rangePartition(
            concurrency,
            nodeCount,
            partition -> joinNeighborsFactory.create(
                partition,
                neighbors,
                allOldNeighbors,
                allNewNeighbors,
                reverseOldNeighbors,
                reverseNewNeighbors,
                neighborFilterFactory.create()
            ),
            Optional.of(minBatchSize)
        );

        progressTracker.beginSubTask();
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(neighborsJoiners)
            .terminationFlag(terminationFlag)
            .executor(executorService)
            .run();
        progressTracker.endSubTask();

        return neighborsJoiners.stream().mapToLong(JoinNeighbors::updateCount).sum();
    }

    /**
     * Fills the two reverse arrays by walking all node ids sequentially and reversing both the
     * old and the new neighbor lists of each node.
     *
     * <p>Progress is logged in batches of the adjusted batch size; nodes left over after the last
     * full batch are logged once after the loop so that every processed node is reported.
     *
     * @param allOldNeighbors     per-node old neighbor lists (entries may be {@code null})
     * @param allNewNeighbors     per-node new neighbor lists (entries may be {@code null});
     *                            its size determines the number of nodes visited
     * @param reverseOldNeighbors target for the reversed old neighbors
     * @param reverseNewNeighbors target for the reversed new neighbors
     * @param concurrency         used only to derive the logging batch size
     * @param minBatchSize        used only to derive the logging batch size
     * @param progressTracker     receives the progress updates
     */
    private static void reverseOldAndNewNeighbors(
        HugeObjectArray<LongArrayList> allOldNeighbors,
        HugeObjectArray<LongArrayList> allNewNeighbors,
        HugeObjectArray<LongArrayList> reverseOldNeighbors,
        HugeObjectArray<LongArrayList> reverseNewNeighbors,
        Concurrency concurrency,
        int minBatchSize,
        ProgressTracker progressTracker
    ) {
        long nodeCount = allNewNeighbors.size();
        long logBatchSize = ParallelUtil.adjustedBatchSize(nodeCount, concurrency, minBatchSize);

        // TODO: cursors
        for (long nodeId = 0; nodeId < nodeCount; nodeId++) {
            reverseNeighbors(nodeId, allOldNeighbors, reverseOldNeighbors);
            reverseNeighbors(nodeId, allNewNeighbors, reverseNewNeighbors);

            if ((nodeId + 1) % logBatchSize == 0) {
                progressTracker.logProgress(logBatchSize);
            }
        }

        // The loop only reports whole batches; report the trailing partial batch as well.
        long unreported = nodeCount % logBatchSize;
        if (unreported > 0) {
            progressTracker.logProgress(unreported);
        }
    }

    /**
     * Adds {@code nodeId} to the reverse list of each of its neighbors, creating a reverse list
     * on first use. Does nothing when the node has no neighbor list.
     *
     * @param nodeId           node whose neighbor list is reversed
     * @param allNeighbors     per-node neighbor lists (entries may be {@code null})
     * @param reverseNeighbors per-node reverse lists that are appended to
     */
    static void reverseNeighbors(
        long nodeId,
        HugeObjectArray<LongArrayList> allNeighbors,
        HugeObjectArray<LongArrayList> reverseNeighbors
    ) {
        // adding nodeId to the neighbors of its neighbors (reversing the neighbors direction)
        var neighbors = allNeighbors.get(nodeId);
        if (neighbors != null) {
            for (var neighbor : neighbors) {
                // A node is not expected to appear in its own neighbor list.
                assert neighbor.value != nodeId;
                var oldReverse = reverseNeighbors.get(neighbor.value);
                if (oldReverse == null) {
                    oldReverse = new LongArrayList();
                    reverseNeighbors.set(neighbor.value, oldReverse);
                }
                oldReverse.add(nodeId);
            }
        }
    }

    /**
     * Result returned by {@link #compute()} when the graph has fewer than two nodes: no
     * neighbors, zero counters and {@code didConverge() == false}.
     */
    private static final class EmptyResult extends KnnResult {

        // NOTE(review): the element type of this array is not visible in this file, so the
        // return type is left as declared - confirm against KnnResult.
        @Override
        HugeObjectArray neighborList() {
            return HugeObjectArray.of();
        }

        @Override
        public int ranIterations() {
            return 0;
        }

        @Override
        public boolean didConverge() {
            return false;
        }

        @Override
        public long nodePairsConsidered() {
            return 0;
        }

        @Override
        public LongStream neighborsOf(long nodeId) {
            return LongStream.empty();
        }

        @Override
        public long size() {
            return 0;
        }

        @Override
        public long nodesCompared() {
            return 0;
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy