// org.neo4j.gds.similarity.knn.Knn Maven / Gradle / Ivy
// Go to download
// Show more of this group Show more artifacts with this name
// Show all versions of algo Show documentation
// Show all versions of algo Show documentation
// Neo4j Graph Data Science :: Algorithms
// The newest version!
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.similarity.knn;
import com.carrotsearch.hppc.LongArrayList;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
import org.neo4j.gds.termination.TerminationFlag;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.stream.LongStream;
/**
 * K-nearest-neighbours computation over the nodes of a {@link Graph}.
 *
 * <p>{@link #compute()} first builds an initial, randomly sampled neighbour list for every node and then
 * repeatedly refines it. Each refinement round ({@link #iteration(Neighbors)}) consists of three phases:
 * splitting the current neighbours into "old" and "new" ones, building the reverse (incoming) neighbour
 * lists, and joining the candidates. Refinement stops after {@code maxIterations} rounds or as soon as a
 * round performs no more than {@code updateThreshold} updates.
 */
public class Knn extends Algorithm<KnnResult> {

    /**
     * Builds a {@code Knn} from bundled parameters and context.
     *
     * <p>Wraps {@code similarityComputer} in a {@link SimilarityFunction}, takes the progress tracker and
     * executor from {@code context}, and installs the no-op neighbour consumer.
     *
     * @param graph                 the graph whose nodes are compared
     * @param parameters            supplies k, concurrency, batch size, iteration limit, cutoff, perturbation
     *                              rate, random joins, random seed and sampler type
     * @param similarityComputer    computes the similarity used by the algorithm
     * @param neighborFilterFactory creates one neighbour filter per parallel task
     * @param context               supplies the progress tracker and the executor
     * @param terminationFlag       flag handed to the parallel runs started by this algorithm
     * @return a fully configured algorithm instance
     */
    public static Knn create(
        Graph graph,
        KnnParameters parameters,
        SimilarityComputer similarityComputer,
        NeighborFilterFactory neighborFilterFactory,
        KnnContext context,
        TerminationFlag terminationFlag
    ) {
        var similarityFunction = new SimilarityFunction(similarityComputer);
        return new Knn(
            graph,
            context.progressTracker(),
            context.executor(),
            parameters.kHolder(),
            parameters.concurrency(),
            parameters.minBatchSize(),
            parameters.maxIterations(),
            parameters.similarityCutoff(),
            parameters.perturbationRate(),
            parameters.randomJoins(),
            parameters.randomSeed(),
            parameters.samplerType(),
            similarityFunction,
            neighborFilterFactory,
            NeighbourConsumers.no_op,
            terminationFlag
        );
    }

    private final Graph graph;
    private final Concurrency concurrency;
    private final int maxIterations;
    private final double similarityCutoff;
    private final int minBatchSize;
    private final NeighborFilterFactory neighborFilterFactory;
    private final ExecutorService executorService;
    private final KnnSampler.Factory samplerFactory;
    private final JoinNeighbors.Factory joinNeighborsFactory;
    private final GenerateRandomNeighbors.Factory generateRandomNeighborsFactory;
    private final SplitOldAndNewNeighbors.Factory splitOldAndNewNeighborsFactory;
    // A round that performs at most this many updates counts as converged (see compute()).
    private final long updateThreshold;

    /**
     * Creates the algorithm and wires up the factories for its individual phases.
     *
     * @param graph                 the graph whose nodes are compared
     * @param progressTracker       tracker passed to the base class and to every phase factory
     * @param executorService       executor used for all parallel phases
     * @param k                     holder providing {@code value}, {@code sampledValue} and {@code updateThreshold}
     * @param concurrency           concurrency used for partitioning and parallel execution
     * @param minBatchSize          minimum batch size passed to the partitioning utilities
     * @param maxIterations         upper bound on the number of refinement rounds
     * @param similarityCutoff      when greater than zero, neighbour lists are filtered with this value after
     *                              the refinement rounds
     * @param perturbationRate      forwarded to the {@link JoinNeighbors.Factory}
     * @param randomJoins           forwarded to the {@link JoinNeighbors.Factory}
     * @param randomSeed            optional seed for the shared {@link SplittableRandom}; when empty, an
     *                              unseeded generator is created
     * @param initialSamplerType    selects the sampler used for the initial random neighbours
     * @param similarityFunction    similarity function used while generating and joining neighbours
     * @param neighborFilterFactory creates one neighbour filter per parallel task
     * @param neighborConsumers     forwarded to the {@link GenerateRandomNeighbors.Factory}
     * @param terminationFlag       flag handed to the parallel runs started by this algorithm
     * @throws IllegalStateException if {@code initialSamplerType} is neither {@code UNIFORM} nor {@code RANDOMWALK}
     */
    public Knn(
        Graph graph,
        ProgressTracker progressTracker,
        ExecutorService executorService,
        K k,
        Concurrency concurrency,
        int minBatchSize,
        int maxIterations,
        double similarityCutoff,
        double perturbationRate,
        int randomJoins,
        Optional<Long> randomSeed,
        KnnSampler.SamplerType initialSamplerType,
        SimilarityFunction similarityFunction,
        NeighborFilterFactory neighborFilterFactory,
        NeighbourConsumers neighborConsumers,
        TerminationFlag terminationFlag
    ) {
        super(progressTracker);
        this.graph = graph;
        this.concurrency = concurrency;
        this.maxIterations = maxIterations;
        this.similarityCutoff = similarityCutoff;
        this.minBatchSize = minBatchSize;
        this.neighborFilterFactory = neighborFilterFactory;
        this.executorService = executorService;
        this.updateThreshold = k.updateThreshold;

        // One generator instance is shared by the sampler and all phase factories below.
        var splittableRandom = randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);

        switch (initialSamplerType) {
            case UNIFORM:
                this.samplerFactory = new UniformKnnSampler.Factory(graph.nodeCount(), splittableRandom);
                break;
            case RANDOMWALK:
                this.samplerFactory = new RandomWalkKnnSampler.Factory(graph, randomSeed, k.value, splittableRandom);
                break;
            default:
                throw new IllegalStateException("Invalid KnnSampler");
        }

        this.generateRandomNeighborsFactory = new GenerateRandomNeighbors.Factory(
            similarityFunction,
            neighborConsumers,
            k.value,
            splittableRandom,
            progressTracker
        );
        this.splitOldAndNewNeighborsFactory = new SplitOldAndNewNeighbors.Factory(
            k.sampledValue,
            splittableRandom,
            progressTracker
        );
        this.joinNeighborsFactory = new JoinNeighbors.Factory(
            similarityFunction,
            k.sampledValue,
            perturbationRate,
            randomJoins,
            splittableRandom,
            progressTracker
        );
        this.terminationFlag = terminationFlag;
    }

    /**
     * @return the executor this algorithm runs its parallel phases on
     */
    public ExecutorService executorService() {
        return executorService;
    }

    /**
     * Runs the algorithm.
     *
     * <p>Graphs with fewer than two nodes yield an empty result without touching the progress tracker.
     * Otherwise the neighbour lists are initialised randomly, refined for at most {@code maxIterations}
     * rounds, and - if {@code similarityCutoff} is positive - filtered in parallel before the result is built.
     *
     * @return the neighbour lists together with the number of rounds run and the convergence flag
     */
    @Override
    public KnnResult compute() {
        if (graph.nodeCount() < 2) {
            return new EmptyResult();
        }

        progressTracker.beginSubTask();

        progressTracker.beginSubTask();
        var neighbors = initializeRandomNeighbors();
        progressTracker.endSubTask();

        int iteration = 0;
        boolean didConverge = false;

        progressTracker.beginSubTask();
        for (; iteration < maxIterations; iteration++) {
            long updateCount = iteration(neighbors);
            if (updateCount <= updateThreshold) {
                // Count the converging round as well; 'break' skips the loop's own increment.
                iteration++;
                didConverge = true;
                break;
            }
        }

        if (similarityCutoff > 0) {
            var neighborFilterTasks = PartitionUtils.rangePartition(
                concurrency,
                neighbors.size(),
                partition -> (Runnable) () -> partition.consume(
                    nodeId -> neighbors.filterHighSimilarityResult(nodeId, similarityCutoff)
                ),
                Optional.of(minBatchSize)
            );
            RunWithConcurrency.builder()
                .concurrency(concurrency)
                .tasks(neighborFilterTasks)
                .terminationFlag(terminationFlag)
                .executor(executorService)
                .run();
        }
        progressTracker.endSubTask();

        progressTracker.endSubTask();

        return ImmutableKnnResult.of(
            neighbors.data(),
            iteration,
            didConverge,
            neighbors.neighborsFound() + neighbors.joinCounter(),
            graph.nodeCount()
        );
    }

    /**
     * Creates the neighbour store and fills it in parallel, one task per node-range partition.
     * Every task gets its own sampler and its own neighbour filter.
     *
     * @return the populated neighbour store
     */
    private Neighbors initializeRandomNeighbors() {
        var neighbors = new Neighbors(graph.nodeCount());

        var randomNeighborGenerators = PartitionUtils.rangePartition(
            concurrency,
            graph.nodeCount(),
            partition -> generateRandomNeighborsFactory.create(
                partition,
                neighbors,
                samplerFactory.create(),
                neighborFilterFactory.create()
            ),
            Optional.of(minBatchSize)
        );

        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(randomNeighborGenerators)
            .terminationFlag(terminationFlag)
            .executor(executorService)
            .run();

        return neighbors;
    }

    /**
     * Performs one refinement round: split, reverse, join. Each phase is wrapped in its own progress sub task.
     *
     * @param neighbors the neighbour store to refine in place
     * @return the sum of the update counts reported by all join tasks of this round
     */
    private long iteration(Neighbors neighbors) {
        var nodeCount = graph.nodeCount();

        // TODO: init in ctor and reuse - benchmark against new allocations
        var allOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        var allNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);

        progressTracker.beginSubTask();
        ParallelUtil.readParallel(concurrency, nodeCount, executorService, splitOldAndNewNeighborsFactory.create(
            neighbors,
            allOldNeighbors,
            allNewNeighbors
        ));
        progressTracker.endSubTask();

        // TODO: init in ctor and reuse - benchmark against new allocations
        var reverseOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        var reverseNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);

        progressTracker.beginSubTask();
        reverseOldAndNewNeighbors(
            allOldNeighbors,
            allNewNeighbors,
            reverseOldNeighbors,
            reverseNewNeighbors,
            concurrency,
            minBatchSize,
            progressTracker
        );
        progressTracker.endSubTask();

        var neighborsJoiners = PartitionUtils.rangePartition(
            concurrency,
            nodeCount,
            partition -> joinNeighborsFactory.create(
                partition,
                neighbors,
                allOldNeighbors,
                allNewNeighbors,
                reverseOldNeighbors,
                reverseNewNeighbors,
                neighborFilterFactory.create()
            ),
            Optional.of(minBatchSize)
        );

        progressTracker.beginSubTask();
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(neighborsJoiners)
            .terminationFlag(terminationFlag)
            .executor(executorService)
            .run();
        progressTracker.endSubTask();

        return neighborsJoiners.stream().mapToLong(JoinNeighbors::updateCount).sum();
    }

    /**
     * Builds the reverse neighbour lists for both the old and the new neighbours.
     *
     * <p>The nodes are processed sequentially in a single loop; {@code concurrency} and {@code minBatchSize}
     * are only used to derive how often progress is reported.
     *
     * @param allOldNeighbors     forward "old" neighbour lists, indexed by node id
     * @param allNewNeighbors     forward "new" neighbour lists, indexed by node id; its size defines the node count
     * @param reverseOldNeighbors receives the reversed "old" lists
     * @param reverseNewNeighbors receives the reversed "new" lists
     * @param concurrency         used for the progress batch size only
     * @param minBatchSize        used for the progress batch size only
     * @param progressTracker     receives one progress report per completed batch
     */
    private static void reverseOldAndNewNeighbors(
        HugeObjectArray<LongArrayList> allOldNeighbors,
        HugeObjectArray<LongArrayList> allNewNeighbors,
        HugeObjectArray<LongArrayList> reverseOldNeighbors,
        HugeObjectArray<LongArrayList> reverseNewNeighbors,
        Concurrency concurrency,
        int minBatchSize,
        ProgressTracker progressTracker
    ) {
        long nodeCount = allNewNeighbors.size();
        long logBatchSize = ParallelUtil.adjustedBatchSize(nodeCount, concurrency, minBatchSize);

        // TODO: cursors
        for (long nodeId = 0; nodeId < nodeCount; nodeId++) {
            reverseNeighbors(nodeId, allOldNeighbors, reverseOldNeighbors);
            reverseNeighbors(nodeId, allNewNeighbors, reverseNewNeighbors);

            // Progress is reported per full batch only; a trailing partial batch is not reported here.
            if ((nodeId + 1) % logBatchSize == 0) {
                progressTracker.logProgress(logBatchSize);
            }
        }
    }

    /**
     * Adds {@code nodeId} to the reverse list of each of its neighbours, creating reverse lists on demand.
     * Does nothing when {@code nodeId} has no neighbour list.
     *
     * @param nodeId           the node whose outgoing neighbours are reversed
     * @param allNeighbors     forward neighbour lists, indexed by node id; entries may be {@code null}
     * @param reverseNeighbors reverse neighbour lists to update, indexed by node id
     */
    static void reverseNeighbors(
        long nodeId,
        HugeObjectArray<LongArrayList> allNeighbors,
        HugeObjectArray<LongArrayList> reverseNeighbors
    ) {
        // adding nodeId to the neighbors of its neighbors (reversing the neighbors direction)
        var neighbors = allNeighbors.get(nodeId);
        if (neighbors != null) {
            for (var neighbor : neighbors) {
                assert neighbor.value != nodeId;
                var oldReverse = reverseNeighbors.get(neighbor.value);
                if (oldReverse == null) {
                    oldReverse = new LongArrayList();
                    reverseNeighbors.set(neighbor.value, oldReverse);
                }
                oldReverse.add(nodeId);
            }
        }
    }

    /**
     * Result returned for graphs with fewer than two nodes: no neighbours, zero iterations, not converged.
     */
    private static final class EmptyResult extends KnnResult {

        // NOTE(review): the element type of this array is not recoverable from this file, so the return
        // type is left raw - restore the type argument declared by KnnResult#neighborList.
        @Override
        HugeObjectArray neighborList() {
            return HugeObjectArray.of();
        }

        @Override
        public int ranIterations() {
            return 0;
        }

        @Override
        public boolean didConverge() {
            return false;
        }

        @Override
        public long nodePairsConsidered() {
            return 0;
        }

        @Override
        public LongStream neighborsOf(long nodeId) {
            return LongStream.empty();
        }

        @Override
        public long size() {
            return 0;
        }

        @Override
        public long nodesCompared() {
            return 0;
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy