
org.neo4j.gds.procedures.pipelines.LinkPredictionSimilarityComputer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pipelines-procedure-facade Show documentation
Show all versions of pipelines-procedure-facade Show documentation
Neo4j Graph Data Science :: Pipelines Procedure Facade
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.procedures.pipelines;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.gds.similarity.knn.NeighborFilter;
import org.neo4j.gds.similarity.knn.NeighborFilterFactory;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
/**
 * {@link SimilarityComputer} that scores a candidate link by feeding the pair's link features to a
 * trained classifier.
 *
 * <p>The similarity of a pair is the probability the classifier assigns to the positive class,
 * whose index is derived from {@link EdgeSplitter#POSITIVE}.
 */
public class LinkPredictionSimilarityComputer implements SimilarityComputer {
    /** Position of the positive ("link") class in the classifier's probability output. */
    private static final int POSITIVE_CLASS_INDEX = (int) EdgeSplitter.POSITIVE;

    private final LinkFeatureExtractor linkFeatureExtractor;
    private final Classifier classifier;

    /**
     * @param linkFeatureExtractor produces the feature input for a (source, target) node pair
     * @param classifier           model whose positive-class probability is reported as similarity
     */
    public LinkPredictionSimilarityComputer(
        LinkFeatureExtractor linkFeatureExtractor,
        Classifier classifier
    ) {
        this.classifier = classifier;
        this.linkFeatureExtractor = linkFeatureExtractor;
    }

    /**
     * Extracts the link features of {@code (sourceId, targetId)} and returns the classifier's
     * predicted probability for the positive class.
     */
    @Override
    public double similarity(long sourceId, long targetId) {
        var probabilities = classifier.predictProbabilities(
            linkFeatureExtractor.extractFeatures(sourceId, targetId)
        );
        return probabilities[POSITIVE_CLASS_INDEX];
    }

    /** Delegates the symmetry question to the link feature extractor. */
    @Override
    public boolean isSymmetric() {
        return linkFeatureExtractor.isSymmetric();
    }

    /**
     * KNN neighbour filter that only admits pairs joining a source-filter node with a target-filter
     * node (in either orientation) that are not already connected in {@code graph}.
     */
    public static final class LinkFilter implements NeighborFilter {
        private final Graph graph;
        private final LPNodeFilter sourceNodeFilter;
        private final LPNodeFilter targetNodeFilter;

        private LinkFilter(Graph graph, LPNodeFilter sourceNodeFilter, LPNodeFilter targetNodeFilter) {
            this.graph = graph;
            this.sourceNodeFilter = sourceNodeFilter;
            this.targetNodeFilter = targetNodeFilter;
        }

        /**
         * Returns {@code true} when the pair must be skipped.
         *
         * <p>A pair is skipped when both ids are the same node, or when neither orientation matches
         * the source/target filters. It is also skipped when {@code graph.exists(firstNodeId,
         * secondNodeId)} reports that an edge is already present.
         */
        @Override
        public boolean excludeNodePair(long firstNodeId, long secondNodeId) {
            if (firstNodeId == secondNodeId) {
                return true;
            }

            boolean firstToSecond = sourceNodeFilter.test(firstNodeId) && targetNodeFilter.test(secondNodeId);
            boolean eligible = firstToSecond
                || (sourceNodeFilter.test(secondNodeId) && targetNodeFilter.test(firstNodeId));
            if (!eligible) {
                return true;
            }

            // graph.exists is slower than a dedicated lookup structure but needs no extra memory;
            // it could be swapped for such a structure if this becomes a bottleneck.
            // NOTE(review): only the (first, second) direction is queried; presumably the graph is
            // undirected here. Confirm.
            return graph.exists(firstNodeId, secondNodeId);
        }

        /**
         * Conservative count of partners {@code node} could still be paired with.
         *
         * <p>A node passing the source filter is measured against the target-filter node count;
         * any other node is measured against the source-filter node count. The node itself and its
         * degree are then subtracted, and the result is clamped at zero.
         */
        @Override
        public long lowerBoundOfPotentialNeighbours(long node) {
            var candidatePool = sourceNodeFilter.test(node)
                ? targetNodeFilter.validNodeCount()
                : sourceNodeFilter.validNodeCount();
            // NOTE(review): a node that passes neither filter can never form an eligible pair in
            // excludeNodePair, yet it still gets the source-filter based count here. Confirm this
            // is intended for a lower bound.
            return Math.max(candidatePool - 1 - graph.degree(node), 0L);
        }
    }

    /**
     * Factory producing {@link LinkFilter} instances, each over its own
     * {@code graph.concurrentCopy()}.
     */
    public static class LinkFilterFactory implements NeighborFilterFactory {
        private final Graph graph;
        private final LPNodeFilter sourceNodeFilter;
        private final LPNodeFilter targetNodeFilter;

        public LinkFilterFactory(Graph graph, LPNodeFilter sourceNodeFilter, LPNodeFilter targetNodeFilter) {
            this.graph = graph;
            this.sourceNodeFilter = sourceNodeFilter;
            this.targetNodeFilter = targetNodeFilter;
        }

        @Override
        public NeighborFilter create() {
            // Presumably a fresh copy per filter so filters can be used from different threads.
            // Confirm against Graph#concurrentCopy.
            var graphCopy = graph.concurrentCopy();
            return new LinkFilter(graphCopy, sourceNodeFilter, targetNodeFilter);
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy