
org.neo4j.gds.similarity.nodesim.NodeSimilarity Maven / Gradle / Ivy
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.similarity.nodesim;
import com.carrotsearch.hppc.BitSet;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.api.properties.relationships.RelationshipConsumer;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.SetBitsIterable;
import org.neo4j.gds.core.utils.paged.HugeLongLongMap;
import org.neo4j.gds.core.utils.progress.BatchingProgressLogger;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.similarity.SimilarityGraphBuilder;
import org.neo4j.gds.similarity.SimilarityGraphResult;
import org.neo4j.gds.similarity.SimilarityResult;
import org.neo4j.gds.similarity.filtering.NodeFilter;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.wcc.WccParameters;
import org.neo4j.gds.wcc.WccStub;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.LongUnaryOperator;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
public class NodeSimilarity extends Algorithm {
// Input graph whose relationships define each node's neighbor vector.
private final Graph graph;
// Algorithm settings (top-K / top-N, degree cutoffs, component handling, ...).
private final NodeSimilarityParameters parameters;
// True when the graph schema exposes more than one relationship type; vectors are then sorted.
private final boolean sortVectors;
// True when a relationship weight property is configured.
private final boolean weighted;
// Nodes that pass the degree filter and the source node filter.
private final BitSet sourceNodes;
// Nodes that pass the degree filter and the target node filter.
private final BitSet targetNodes;
private final NodeFilter sourceNodeFilter;
private final NodeFilter targetNodeFilter;
private final ExecutorService executorService;
private final Concurrency concurrency;
private final MetricSimilarityComputer similarityComputer;
// Per node: the neighbor id vector, or null when the node fails the degree filter.
private HugeObjectArray<long[]> neighbors;
// Per node: the weights belonging to the neighbor vector; only allocated when weighted.
private HugeObjectArray<double[]> weights;
// Maps a node id to the id of the component it belongs to.
private LongUnaryOperator components;
// offset -> ids of source nodes starting at that offset.
private Function<Long, LongStream> sourceNodesStream;
// (componentId, offset) -> ids of target nodes starting at that offset.
private BiFunction<Long, Long, LongStream> targetNodesStream;
private final WccStub wccStub;
public NodeSimilarity(
Graph graph,
NodeSimilarityParameters parameters,
Concurrency concurrency,
ExecutorService executorService,
ProgressTracker progressTracker,
NodeFilter sourceNodeFilter,
NodeFilter targetNodeFilter,
TerminationFlag terminationFlag,
WccStub wccStub
) {
super(progressTracker);
this.graph = graph;
this.sortVectors = graph.schema().relationshipSchema().availableTypes().size() > 1;
this.sourceNodeFilter = sourceNodeFilter;
this.targetNodeFilter = targetNodeFilter;
this.concurrency = concurrency;
this.parameters = parameters;
this.similarityComputer = parameters.similarityComputer();
this.executorService = executorService;
this.sourceNodes = new BitSet(graph.nodeCount());
this.targetNodes = new BitSet(graph.nodeCount());
this.wccStub = wccStub;
this.weighted = this.parameters.hasRelationshipWeightProperty();
this.terminationFlag = terminationFlag;
}
/**
 * Runs the preparation phase and then computes the similarities, either as a
 * stream of results or as a similarity graph, depending on
 * {@code parameters.computeToStream()}.
 *
 * @return a result holding exactly one of the two representations
 */
@Override
public NodeSimilarityResult compute() {
    progressTracker.beginSubTask();
    prepare();

    if (!parameters.computeToStream()) {
        var similarityGraph = computeToGraph();
        progressTracker.endSubTask();
        return ImmutableNodeSimilarityResult.of(
            Optional.empty(),
            Optional.of(similarityGraph)
        );
    }

    var similarityStream = computeToStream();
    progressTracker.endSubTask();
    return ImmutableNodeSimilarityResult.of(
        Optional.of(similarityStream),
        Optional.empty()
    );
}
/**
 * Picks the stream-producing strategy for the configured parameters.
 *
 * @return the stream of similarity results
 */
private Stream computeToStream() {
    terminationFlag.assertRunning();

    // Special case: top-N without top-K. Per the original authors this cannot
    // happen when the algorithm is called from a procedure. It ignores the
    // configured concurrency and always runs single threaded on primitives.
    boolean topNWithoutTopK = parameters.hasTopN() && !parameters.hasTopK();
    if (topNWithoutTopK) {
        return computeTopN();
    }

    if (concurrency.value() > 1) {
        return computeParallel();
    }
    return computeSimilarityResultStream();
}
/**
 * Computes the similarities and materializes them as a graph.
 * With top-K but no top-N the {@link TopKMap} is wrapped directly in a
 * {@link TopKGraph}; otherwise the result stream is fed into a
 * {@link SimilarityGraphBuilder}.
 *
 * @return the similarity graph, the number of source nodes, and whether the graph is a top-K graph
 */
private SimilarityGraphResult computeToGraph() {
    boolean topKWithoutTopN = parameters.hasTopK() && !parameters.hasTopN();

    Graph similarityGraph;
    if (topKWithoutTopN) {
        terminationFlag.assertRunning();
        TopKMap topKMap;
        if (concurrency.value() > 1) {
            topKMap = computeTopKMapParallel();
        } else {
            topKMap = computeTopKMap();
        }
        similarityGraph = new TopKGraph(graph, topKMap);
    } else {
        Stream similarities = computeToStream();
        var graphBuilder = new SimilarityGraphBuilder(
            graph,
            concurrency,
            executorService,
            terminationFlag
        );
        similarityGraph = graphBuilder.build(similarities);
    }

    return new SimilarityGraphResult(similarityGraph, sourceNodes.cardinality(), topKWithoutTopN);
}
/**
 * Preparation phase: resolves the component mapping and then sets up the
 * per-node data and node streams. When {@code parameters.runWCC()} is set,
 * the node set-up is reported as its own nested sub task.
 */
private void prepare() {
    progressTracker.beginSubTask();

    components = initComponents();
    if (!parameters.runWCC()) {
        prepareAsLeaf();
    } else {
        prepareAsChildTask();
    }

    progressTracker.endSubTask();
}
/**
 * Node set-up when the preparation is reported as a single task
 * (no nested sub task is opened).
 */
private void prepareAsLeaf() {
    setUpNodesAndFilters();
}
/**
 * Node set-up wrapped in its own nested sub task of the progress tracker.
 */
private void prepareAsChildTask() {
    progressTracker.beginSubTask();
    setUpNodesAndFilters();
    progressTracker.endSubTask();
}
/**
 * Fills the per-node vectors and the source/target bit sets, then creates
 * the functions that stream source and target node ids from those bit sets.
 */
private void setUpNodesAndFilters() {
    // Must run first: the stream factories below read the bit sets populated here.
    initNodeSpecificFields();

    sourceNodesStream = initSourceNodesStream();
    targetNodesStream = initTargetNodesStream();
}
/**
 * Single-threaded computation of the result stream.
 * Without top-K every pair is emitted; with top-K the map is computed first
 * and optionally reduced further to the top-N entries.
 */
private Stream computeSimilarityResultStream() {
    if (!parameters.hasTopK()) {
        return computeAll();
    }

    var topKMap = computeTopKMap();
    if (parameters.hasTopN()) {
        return computeTopN(topKMap);
    }
    return topKMap.stream();
}
/**
 * Multi-threaded counterpart of {@code computeSimilarityResultStream()}:
 * same decision tree, but backed by the parallel implementations.
 */
private Stream computeParallel() {
    if (!parameters.hasTopK()) {
        return computeAllParallel();
    }

    var topKMap = computeTopKMapParallel();
    if (parameters.hasTopN()) {
        return computeTopN(topKMap);
    }
    return topKMap.stream();
}
/**
 * Builds the node-id to component-id mapping.
 * <ul>
 *   <li>components disabled: every node maps to component 0;</li>
 *   <li>component property configured: the property values are remapped via
 *       {@code initComponentIdMapping};</li>
 *   <li>otherwise: WCC is run and its set ids are used.</li>
 * </ul>
 */
private LongUnaryOperator initComponents() {
    if (parameters.useComponents()) {
        if (parameters.componentProperty() == null) {
            // No precomputed property: run WCC to determine the components.
            var wccParameters = new WccParameters(0D, concurrency);
            var disjointSets = wccStub.wcc(graph, wccParameters, progressTracker, false);
            return disjointSets::setIdOf;
        }

        // Component information is read from the configured node property.
        NodePropertyValues componentValues = graph.nodeProperties(parameters.componentProperty());
        return initComponentIdMapping(graph, componentValues::longValue);
    }

    // Components are not used: treat everything as one single component.
    return nodeId -> 0;
}
/**
 * Populates {@code neighbors} (and {@code weights} when weighted) for every
 * node and marks the nodes admitted as sources / targets.
 * Nodes whose counted degree is rejected by the degree filter get a
 * {@code null} neighbor vector and are added to neither bit set.
 */
private void initNodeSpecificFields() {
    neighbors = HugeObjectArray.newArray(long[].class, graph.nodeCount());
    if (weighted) {
        weights = HugeObjectArray.newArray(double[].class, graph.nodeCount());
    }

    // The helpers are shared across all nodes and reset for each one.
    var degreeCounter = new DegreeComputer();
    VectorComputer vectors = VectorComputer.of(graph, weighted);
    var degreeLimits = new DegreeFilter(parameters.degreeCutoff(), parameters.upperDegreeCutoff());

    neighbors.setAll(nodeId -> {
        // First pass over the relationships: count the degree as DegreeComputer defines it.
        graph.forEachRelationship(nodeId, degreeCounter);
        int countedDegree = degreeCounter.degree;
        degreeCounter.reset();
        vectors.reset(countedDegree);

        progressTracker.logProgress(graph.degree(nodeId));

        if (!degreeLimits.apply(countedDegree)) {
            return null;
        }

        if (sourceNodeFilter.test(nodeId)) {
            sourceNodes.set(nodeId);
        }
        if (targetNodeFilter.test(nodeId)) {
            targetNodes.set(nodeId);
        }

        // TODO: we don't need to do the rest of the prepare for a node that isn't going to be used in the computation
        // Second pass over the relationships: collect target ids (and weights).
        vectors.forEachRelationship(nodeId);
        if (sortVectors) {
            vectors.sortTargetIds();
        }
        if (weighted) {
            weights.set(nodeId, vectors.getWeights());
        }
        return vectors.targetIds.buffer;
    });
}
/**
 * Builds the single-threaded stream of all similarity results: each source
 * node is expanded into the results produced by {@code computeSimilaritiesForNode}.
 */
private Stream computeAll() {
    progressTracker.beginSubTask(calculateWorkload());

    var similarities = loggableAndTerminableSourceNodeStream()
        .boxed()
        .flatMap(sourceNodeId -> computeSimilaritiesForNode(sourceNodeId));

    // NOTE(review): no terminal operation runs in this method, so the sub task
    // is ended before the returned stream is consumed - confirm this is intended.
    progressTracker.endSubTask();
    return similarities;
}
/**
 * Parallel variant of {@code computeAll()}: the source node stream is handed to
 * {@code ParallelUtil.parallelStream} together with the per-node expansion.
 * Unlike {@code computeAll()}, this method opens no progress sub task.
 */
private Stream computeAllParallel() {
    return ParallelUtil.parallelStream(
        loggableAndTerminableSourceNodeStream(),
        concurrency,
        sourceNodeIds -> sourceNodeIds
            .boxed()
            .flatMap(this::computeSimilaritiesForNode)
    );
}
/**
 * Single-threaded computation of the top-K map.
 * <p>
 * When neither a source nor a target filter is active, source and target sets
 * coincide. Each unordered pair is then computed once (targets with a higher
 * id only) and written in both directions. As soon as either filter is
 * active this shortcut is invalid: a target may have a lower id than the
 * source, and the source is not necessarily an admissible target. In that
 * case every (source, target) pair except the diagonal is computed.
 *
 * @return the populated top-K map
 */
private TopKMap computeTopKMap() {
    progressTracker.beginSubTask(calculateWorkload());

    var comparator = parameters.normalizedK() > 0
        ? SimilarityResult.DESCENDING
        : SimilarityResult.ASCENDING;
    var topKMap = new TopKMap(neighbors.size(), sourceNodes, Math.abs(parameters.normalizedK()), comparator);

    // Same notion of "not filtered" as calculateWorkload(): BOTH filters must allow everything.
    boolean isNotFiltered = sourceNodeFilter.equals(NodeFilter.ALLOW_EVERYTHING)
        && targetNodeFilter.equals(NodeFilter.ALLOW_EVERYTHING);

    loggableAndTerminableSourceNodeStream()
        .forEach(sourceNodeId -> {
            if (isNotFiltered) {
                targetNodesStream.apply(components.applyAsLong(sourceNodeId), sourceNodeId + 1)
                    .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId,
                        (source, target, similarity) -> {
                            topKMap.put(source, target, similarity);
                            topKMap.put(target, source, similarity);
                        }
                    ));
            } else {
                targetNodesStream.apply(components.applyAsLong(sourceNodeId), 0L)
                    .filter(targetNodeId -> sourceNodeId != targetNodeId)
                    .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topKMap::put));
            }
        });

    progressTracker.endSubTask();
    return topKMap;
}
/**
 * Parallel computation of the top-K map.
 *
 * @return the populated top-K map
 */
private TopKMap computeTopKMapParallel() {
    progressTracker.beginSubTask(calculateWorkload());

    var comparator = parameters.normalizedK() > 0
        ? SimilarityResult.DESCENDING
        : SimilarityResult.ASCENDING;
    var topKMap = new TopKMap(neighbors.size(), sourceNodes, Math.abs(parameters.normalizedK()), comparator);

    ParallelUtil.parallelStreamConsume(
        loggableAndTerminableSourceNodeStream(),
        concurrency,
        terminationFlag,
        sourceNodeIds -> sourceNodeIds.forEach(sourceNodeId ->
            // We deliberately compute the full matrix (except the diagonal).
            // The parallel workload is partitioned based on the outer stream.
            // The TopKMap stores a priority queue for each node. Writing
            // into these queues is not considered to be thread-safe.
            // Hence, we need to ensure that down the stream, exactly one queue
            // within the TopKMap processes all pairs for a single node.
            targetNodesStream
                .apply(components.applyAsLong(sourceNodeId), 0L)
                .filter(targetNodeId -> targetNodeId != sourceNodeId)
                .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topKMap::put))
        )
    );

    progressTracker.endSubTask();
    return topKMap;
}
/**
 * Computes the global top-N similarities without a preceding top-K step,
 * always single threaded.
 * <p>
 * Only when neither a source nor a target filter is active can each unordered
 * pair be visited once, starting at the first target after the source. With
 * any filter active, targets may have a lower id than the source, so all
 * targets except the source itself are visited.
 *
 * @return stream over the collected top-N results
 */
private Stream computeTopN() {
    progressTracker.beginSubTask(calculateWorkload());

    var topNList = new TopNList(parameters.normalizedN());

    // Same notion of "not filtered" as calculateWorkload(): BOTH filters must allow everything.
    boolean isNotFiltered = sourceNodeFilter.equals(NodeFilter.ALLOW_EVERYTHING)
        && targetNodeFilter.equals(NodeFilter.ALLOW_EVERYTHING);

    loggableAndTerminableSourceNodeStream()
        .forEach(sourceNodeId -> {
            if (isNotFiltered) {
                targetNodesStream.apply(components.applyAsLong(sourceNodeId), sourceNodeId + 1)
                    .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topNList::add));
            } else {
                targetNodesStream.apply(components.applyAsLong(sourceNodeId), 0L)
                    .filter(targetNodeId -> sourceNodeId != targetNodeId)
                    .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topNList::add));
            }
        });

    progressTracker.endSubTask();
    return topNList.stream();
}
/**
 * Reduces an already computed top-K map to the top-N results by feeding
 * every entry of the map into a {@link TopNList}.
 *
 * @param topKMap the top-K results to reduce
 * @return stream over the resulting top-N list
 */
private Stream computeTopN(TopKMap topKMap) {
    var bestOverall = new TopNList(parameters.normalizedN());
    topKMap.forEach(bestOverall::add);

    return bestOverall.stream();
}
/**
 * Creates the function that streams source node ids.
 * The returned type is parameterized so the lambda argument is a {@code Long}
 * (usable as the offset) instead of an {@code Object}, which a raw
 * {@code Function} would give.
 *
 * @return function mapping an offset to the set bits of {@code sourceNodes} from that offset on
 */
private Function<Long, LongStream> initSourceNodesStream() {
    return offset -> new SetBitsIterable(sourceNodes, offset).stream();
}
/**
 * Creates the function that streams target node ids for a given component
 * id and offset.
 * <p>
 * Without components the component id is ignored and the set bits of
 * {@code targetNodes} are streamed. With components the ids come from a
 * {@link ComponentNodes} structure built over the component mapping and the
 * target bit set. The returned type is parameterized so the lambda arguments
 * are {@code Long}s rather than {@code Object}s.
 *
 * @return function mapping (componentId, offset) to a stream of target node ids
 */
private BiFunction<Long, Long, LongStream> initTargetNodesStream() {
    if (!parameters.useComponents()) {
        return (componentId, offset) -> new SetBitsIterable(targetNodes, offset).stream();
    }

    var componentNodes = ComponentNodes.create(components, targetNodes::get, graph.nodeCount(), concurrency);
    return (componentId, offset) -> StreamSupport
        .longStream(componentNodes.spliterator(componentId, offset), false);
}
/**
 * Streams all source nodes (offset 0), decorated with the termination check
 * added by {@code checkProgress}.
 */
private LongStream loggableAndTerminableSourceNodeStream() {
    var allSourceNodes = sourceNodesStream.apply(0L);
    return checkProgress(allSourceNodes);
}
/**
 * Produces the similarity results of one source node against the target
 * nodes of its component whose stream offset is {@code sourceNodeId + 1}.
 * Pairs for which no similarity is reported are dropped.
 * <p>
 * NOTE(review): targets below the offset are never visited from this source;
 * confirm this path is only reached when that is acceptable (e.g. without node filters).
 *
 * @param sourceNodeId id of the source node
 * @return stream of non-null similarity results for that node
 */
private Stream computeSimilaritiesForNode(long sourceNodeId) {
    var laterTargets = targetNodesStream.apply(components.applyAsLong(sourceNodeId), sourceNodeId + 1);
    return laterTargets
        .mapToObj(targetNodeId -> {
            // One-slot holder: the consumer callback may or may not fire.
            SimilarityResult[] slot = new SimilarityResult[1];
            computeSimilarityFor(
                sourceNodeId,
                targetNodeId,
                (source, target, similarity) -> {
                    slot[0] = new SimilarityResult(source, target, similarity);
                }
            );
            return slot[0];
        })
        .filter(result -> result != null);
}
/**
 * Remaps arbitrary component ids to consecutive ids starting at 0, in order
 * of first appearance while iterating the nodes of {@code idMap}.
 * <p>
 * A new mapped id is only allocated when a component is seen for the first
 * time, and the origin mapper is evaluated once per node.
 *
 * @param idMap                   provides the node count and node iteration
 * @param originComponentIdMapper maps a node id to its original component id
 * @return operator mapping a node id to its remapped component id
 */
private static LongUnaryOperator initComponentIdMapping(IdMap idMap, LongUnaryOperator originComponentIdMapper) {
    var componentIdMappings = new HugeLongLongMap();
    // AtomicLong only serves as a mutable counter that the lambda can capture.
    var nextMappedComponentId = new AtomicLong(0L);
    var mappedComponentIdPerNode = HugeLongArray.newArray(idMap.nodeCount());

    idMap.forEachNode(n -> {
        long originComponentIdForNode = originComponentIdMapper.applyAsLong(n);

        long mappedComponentIdForNode;
        if (componentIdMappings.containsKey(originComponentIdForNode)) {
            // Key is known to be present, so the default is never used.
            mappedComponentIdForNode = componentIdMappings.getOrDefault(originComponentIdForNode, -1L);
        } else {
            mappedComponentIdForNode = nextMappedComponentId.getAndIncrement();
            componentIdMappings.put(originComponentIdForNode, mappedComponentIdForNode);
        }

        mappedComponentIdPerNode.set(n, mappedComponentIdForNode);
        return true;
    });

    return mappedComponentIdPerNode::get;
}
/**
 * Callback receiving one computed similarity between a source and a target node.
 * Used as a lambda / method-reference target, hence declared as a functional interface.
 */
@FunctionalInterface
interface SimilarityConsumer {
    /**
     * @param sourceNodeId id of the source node
     * @param targetNodeId id of the target node
     * @param similarity   similarity computed for the pair
     */
    void accept(long sourceNodeId, long targetNodeId, double similarity);
}
/**
 * Computes the similarity of one (source, target) pair from the stored
 * neighbor vectors (and weights when weighted) and passes it to the consumer.
 * A {@code NaN} similarity is not reported.
 *
 * @param sourceNodeId id of the source node
 * @param targetNodeId id of the target node
 * @param consumer     receives the pair and its similarity unless it is {@code NaN}
 */
private void computeSimilarityFor(long sourceNodeId, long targetNodeId, SimilarityConsumer consumer) {
    var sourceVector = neighbors.get(sourceNodeId);
    var targetVector = neighbors.get(targetNodeId);

    double similarity = weighted
        ? computeWeightedSimilarity(
            sourceVector, targetVector, weights.get(sourceNodeId), weights.get(targetNodeId)
        )
        : computeSimilarity(sourceVector, targetVector);

    if (Double.isNaN(similarity)) {
        return;
    }
    consumer.accept(sourceNodeId, targetNodeId, similarity);
}
/**
 * Delegates to the configured similarity computer for the weighted case and
 * logs one unit of progress afterwards.
 *
 * @param sourceNodeNeighbors neighbor ids of the source node
 * @param targetNodeNeighbors neighbor ids of the target node
 * @param sourceNodeWeights   weights belonging to the source neighbors
 * @param targetNodeWeights   weights belonging to the target neighbors
 * @return the value returned by the similarity computer
 */
private double computeWeightedSimilarity(
    long[] sourceNodeNeighbors,
    long[] targetNodeNeighbors,
    double[] sourceNodeWeights,
    double[] targetNodeWeights
) {
    final double weightedSimilarity = similarityComputer.computeWeightedSimilarity(
        sourceNodeNeighbors, targetNodeNeighbors, sourceNodeWeights, targetNodeWeights
    );

    progressTracker.logProgress();

    return weightedSimilarity;
}
/**
 * Delegates to the configured similarity computer for the unweighted case
 * and logs one unit of progress afterwards.
 *
 * @param sourceNodeNeighbors neighbor ids of the source node
 * @param targetNodeNeighbors neighbor ids of the target node
 * @return the value returned by the similarity computer
 */
private double computeSimilarity(long[] sourceNodeNeighbors, long[] targetNodeNeighbors) {
    final double unweightedSimilarity = similarityComputer.computeSimilarity(
        sourceNodeNeighbors,
        targetNodeNeighbors
    );

    progressTracker.logProgress();

    return unweightedSimilarity;
}
/**
 * Decorates a node stream so that the termination flag is asserted for the
 * node ids that match the interval test below.
 *
 * @param stream stream of node ids
 * @return the same stream with the check attached via {@code peek}
 */
private LongStream checkProgress(LongStream stream) {
    return stream.peek(nodeId -> {
        // NOTE(review): this is a bitwise AND against MAXIMUM_LOG_INTERVAL. If that
        // constant is a power of two rather than a (2^k - 1) mask, the test is true
        // whenever that single bit is unset, not once per interval - confirm.
        boolean atCheckpoint = (nodeId & BatchingProgressLogger.MAXIMUM_LOG_INTERVAL) == 0;
        if (atCheckpoint) {
            terminationFlag.assertRunning();
        }
    });
}
/**
 * Estimates the number of pair comparisons for progress tracking.
 *
 * @return source count times target count, halved for unfiltered runs at concurrency 1
 */
private long calculateWorkload() {
    // For each source node all of its target nodes are examined;
    // without filters, sourceNodes == targetNodes.
    long allPairs = sourceNodes.cardinality() * targetNodes.cardinality();

    boolean unfiltered = sourceNodeFilter.equals(NodeFilter.ALLOW_EVERYTHING)
        && targetNodeFilter.equals(NodeFilter.ALLOW_EVERYTHING);

    // At concurrency 1 and without filters a node is only compared with
    // higher-indexed nodes, so the work is halved. This does not hold for
    // filtered similarity, since target nodes might be lower indexed.
    boolean halvedWork = concurrency.value() == 1 && unfiltered;
    return halvedWork ? allPairs / 2 : allPairs;
}
/**
 * Relationship consumer that counts the relationships of a node, skipping
 * self-loops and any relationship whose target equals the previously seen
 * target. Presumably relies on equal targets arriving consecutively - verify
 * against the graph's iteration order.
 */
private static final class DegreeComputer implements RelationshipConsumer {
    long lastTarget = -1;
    int degree = 0;

    /**
     * @return always {@code true}
     */
    @Override
    public boolean accept(long source, long target) {
        boolean selfLoop = source == target;
        boolean repeatedTarget = target == lastTarget;
        if (!(selfLoop || repeatedTarget)) {
            degree++;
        }
        lastTarget = target;
        return true;
    }

    /** Restores the initial state so the instance can be reused for the next node. */
    void reset() {
        degree = 0;
        lastTarget = -1;
    }
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy