
org.neo4j.gds.betweenness.RandomDegreeSelectionStrategy Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of algo Show documentation
Show all versions of algo Show documentation
Neo4j Graph Data Science :: Algorithms
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.betweenness;
import com.carrotsearch.hppc.BitSet;
import com.carrotsearch.hppc.BitSetIterator;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import java.util.Collection;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import static com.carrotsearch.hppc.BitSetIterator.NO_MORE;
/**
 * A {@link SelectionStrategy} that samples up to {@code samplingSize} nodes, favouring nodes
 * with a high degree.
 * <p>
 * During {@link #init}, every node is visited once and accepted with probability
 * {@code degree / maxDegree}, as long as the shared sampling budget is not yet exhausted.
 * The budget is shared across concurrently processed partitions. Once it is used up, the
 * remaining nodes are skipped, so which nodes end up in the sample can depend on scheduling.
 * If fewer than {@code samplingSize} nodes were accepted that way, the sample is topped up with
 * nodes drawn at random from the not-yet-selected ones. {@link #next()} then hands out the
 * sampled node ids by walking the id space in ascending order.
 */
public class RandomDegreeSelectionStrategy implements SelectionStrategy {
    private final long samplingSize;
    private final Optional<Long> maybeRandomSeed;
    // Cursor into the node id space, shared by all callers of next().
    private final AtomicLong nodeQueue = new AtomicLong();
    private long graphSize;
    // Bit i is set iff node i belongs to the sample; populated by init().
    private BitSet sampleSet;

    /**
     * Creates an unseeded strategy.
     *
     * @param samplingSize number of nodes to sample; expected to be at most the node count of
     *                     the graph passed to {@link #init}
     */
    public RandomDegreeSelectionStrategy(long samplingSize) {
        this(samplingSize, Optional.empty());
    }

    /**
     * Creates a strategy, optionally seeded.
     *
     * @param samplingSize    number of nodes to sample; expected to be at most the node count
     *                        of the graph passed to {@link #init}
     * @param maybeRandomSeed seed for the random source; when empty, an unseeded
     *                        {@link SplittableRandom} is used
     */
    public RandomDegreeSelectionStrategy(long samplingSize, Optional<Long> maybeRandomSeed) {
        this.samplingSize = samplingSize;
        this.maybeRandomSeed = maybeRandomSeed;
    }

    @Override
    public void init(Graph graph, ExecutorService executorService, Concurrency concurrency) {
        assert samplingSize <= graph.nodeCount();
        this.sampleSet = new BitSet(graph.nodeCount());
        this.graphSize = graph.nodeCount();
        nodeQueue.set(0);
        // Partitions are aligned to Long.SIZE, presumably so that each partition writes whole
        // words of the (non thread-safe) BitSet and concurrent set() calls never share a word.
        var partitions = PartitionUtils.numberAlignedPartitioning(concurrency, graph.nodeCount(), Long.SIZE);
        var maxDegree = maxDegree(graph, partitions, executorService, concurrency);
        selectNodes(graph, partitions, maxDegree, executorService, concurrency);
    }

    /**
     * Returns the next sampled node id, or {@code NONE_SELECTED} once the id space is exhausted.
     * The cursor is an {@link AtomicLong}, so concurrent callers each receive distinct ids.
     * Assumes {@link #init} has completed (and is visible to the caller) before the first call.
     */
    @Override
    public long next() {
        long nextNodeId;
        while ((nextNodeId = nodeQueue.getAndIncrement()) < graphSize) {
            if (sampleSet.get(nextNodeId)) {
                return nextNodeId;
            }
        }
        return NONE_SELECTED;
    }

    /**
     * Computes the maximum degree over all nodes, processing partitions concurrently.
     *
     * @return the largest degree found, or 0 if there are no nodes or no relationships
     */
    private static int maxDegree(
        Graph graph,
        Collection<Partition> partitions,
        ExecutorService executorService,
        Concurrency concurrency
    ) {
        var maxDegree = new AtomicInteger(0);
        var tasks = partitions.stream()
            .map(partition -> (Runnable) () -> {
                // Reduce locally first so the shared atomic is touched once per partition
                // instead of being contended on for every node.
                int partitionMax = 0;
                long fromNode = partition.startNode();
                long toNode = fromNode + partition.nodeCount();
                for (long nodeId = fromNode; nodeId < toNode; nodeId++) {
                    partitionMax = Math.max(partitionMax, graph.degree(nodeId));
                }
                maxDegree.accumulateAndGet(partitionMax, Math::max);
            })
            .collect(Collectors.toList());
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(tasks)
            .executor(executorService)
            .run();
        return maxDegree.get();
    }

    /**
     * Fills {@link #sampleSet}. A degree-biased pass runs first, followed by a random top-up
     * if that pass selected fewer than {@code samplingSize} nodes.
     */
    private void selectNodes(
        Graph graph,
        Collection<Partition> partitions,
        int maxDegree,
        ExecutorService executorService,
        Concurrency concurrency
    ) {
        var random = maybeRandomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
        var selectionSize = new AtomicLong(0);

        // With maxDegree == 0 every node has degree 0 and can never pass the degree test,
        // and SplittableRandom#nextInt(0) would throw. Skip straight to the top-up.
        if (maxDegree > 0) {
            var tasks = partitions.stream()
                .map(partition -> {
                    // SplittableRandom is not thread-safe. Split on this (the calling) thread,
                    // in partition order, rather than inside the concurrently running tasks.
                    var partitionRandom = random.split();
                    return (Runnable) () -> selectByDegree(
                        graph,
                        partition,
                        maxDegree,
                        partitionRandom,
                        selectionSize
                    );
                })
                .collect(Collectors.toList());
            RunWithConcurrency.builder()
                .concurrency(concurrency)
                .tasks(tasks)
                .executor(executorService)
                .run();
        }

        fillUpUniformly(graph.nodeCount(), selectionSize.get(), random);
    }

    /**
     * Visits the nodes of one partition. Each node is accepted with probability
     * {@code min(degree, maxDegree) / maxDegree}, as long as the shared budget
     * {@code selectionSize < samplingSize} allows it.
     *
     * @param maxDegree must be positive
     * @param random    random source owned exclusively by this partition
     */
    private void selectByDegree(
        Graph graph,
        Partition partition,
        int maxDegree,
        SplittableRandom random,
        AtomicLong selectionSize
    ) {
        long fromNode = partition.startNode();
        long toNode = fromNode + partition.nodeCount();
        for (long nodeId = fromNode; nodeId < toNode; nodeId++) {
            long currentSelectionSize = selectionSize.get();
            if (currentSelectionSize >= samplingSize) {
                break;
            }
            int nodeDegree = graph.degree(nodeId);
            // probabilityFactor is uniform in [1, maxDegree] (both ends inclusive), so the node
            // passes with probability min(nodeDegree, maxDegree) / maxDegree.
            int probabilityFactor = random.nextInt(maxDegree) + 1;
            if (probabilityFactor > nodeDegree) {
                continue;
            }
            // Claim one slot of the shared budget; give up if other partitions filled it first.
            while (true) {
                long actualCurrentSelectionSize = selectionSize.compareAndExchange(
                    currentSelectionSize,
                    currentSelectionSize + 1
                );
                if (currentSelectionSize == actualCurrentSelectionSize) {
                    sampleSet.set(nodeId);
                    break;
                }
                if (actualCurrentSelectionSize >= samplingSize) {
                    break;
                }
                currentSelectionSize = actualCurrentSelectionSize;
            }
        }
    }

    /**
     * Adds randomly chosen, not-yet-selected nodes to {@link #sampleSet} until it holds
     * {@code min(samplingSize, nodeCount)} nodes. Runs single-threaded.
     *
     * @param nodeCount       number of nodes in the graph
     * @param alreadySelected number of bits currently set in {@link #sampleSet}
     * @param random          random source; used only on the calling thread
     */
    private void fillUpUniformly(long nodeCount, long alreadySelected, SplittableRandom random) {
        // Clamp so the loop below always terminates, even if the init() assertion is disabled
        // and samplingSize exceeds the number of available nodes.
        long target = Math.min(samplingSize, nodeCount);
        if (alreadySelected >= target) {
            return;
        }
        // Flip the bitset so the unselected nodes become the set bits we can iterate.
        // The upper bound is the node count rather than BitSet#size(), which is rounded up
        // to a multiple of 64.
        sampleSet.flip(0, nodeCount);
        long selected = alreadySelected;
        // Each pass accepts every remaining candidate with probability 1/2. Repeat passes until
        // the target is met; enough candidates exist because target <= nodeCount.
        while (selected < target) {
            BitSetIterator iterator = sampleSet.iterator();
            var unselectedNode = iterator.nextSetBit();
            while (unselectedNode != NO_MORE && selected < target) {
                if (random.nextDouble() >= 0.5) {
                    sampleSet.flip(unselectedNode);
                    selected++;
                }
                unselectedNode = iterator.nextSetBit();
            }
        }
        // Flip back so that set bits again mean "selected".
        sampleSet.flip(0, nodeCount);
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy