package apoc.algo.algorithms;
import apoc.Pools;
import org.neo4j.collection.primitive.Primitive;
import org.neo4j.collection.primitive.PrimitiveIntObjectMap;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.Log;
import org.neo4j.procedure.TerminationGuard;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
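/**
 * Betweenness centrality over the compact adjacency arrays built by {@link Algorithm}
 * (per-node degree, chunk start offsets and flattened relationship targets), using
 * Brandes-style shortest-path counting and dependency accumulation, either sequentially
 * or batched across the shared thread pool.
 *
 * A rough usage sketch (hypothetical caller; the Cypher statements, pool, log and
 * termination guard are assumed to be supplied by the surrounding APOC procedure):
 *
 *   BetweennessCentrality bc = new BetweennessCentrality(db, pool, log, guard);
 *   if (bc.readNodeAndRelCypherData(relCypher, nodeCypher, null, null, concurrency)) {
 *       bc.computeUnweightedParallel();
 *       bc.writeResultsToDB("betweenness_centrality");
 *   }
 */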
public class BetweennessCentrality implements AlgorithmInterface {
public static final int WRITE_BATCH = 100_000;
public final int MINIMUM_BATCH_SIZE = 10_000;
private Algorithm algorithm;
private Log log;
GraphDatabaseAPI db;
ExecutorService pool;
private int nodeCount;
private int relCount;
private Statistics stats = new Statistics();
private PrimitiveIntObjectMap intermediateBcPerThread;
float betweennessCentrality[];
private String property;
private final TerminationGuard guard;
public BetweennessCentrality(GraphDatabaseAPI db,
ExecutorService pool, Log log, TerminationGuard guard)
{
this.pool = pool;
this.db = db;
this.log = log;
this.guard = guard;
algorithm = new Algorithm(db, pool, log);
}
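// Returns the centrality of the given database node id by mapping it to the algorithm's
// internal index; returns -1 when the node was not part of the loaded graph.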
@Override
public double getResult(long node) {
float val = -1;
int logicalIndex = algorithm.getAlgoNodeId((int)node);
if (logicalIndex >= 0 && logicalIndex < betweennessCentrality.length) {
val = betweennessCentrality[logicalIndex];
}
return val;
}
@Override
public long numberOfNodes() {
return nodeCount;
}
@Override
public String getPropertyName() {
return "betweenness_centrality";
}
@Override
public long getMappedNode(int algoId) {
return algorithm.getMappedNode(algoId);
}
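// Loads the graph through the supplied node and relationship Cypher statements via
// Algorithm, which builds the degree/offset/target adjacency arrays used by the compute
// methods, and records node/relationship counts and read timings in the statistics.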
public boolean readNodeAndRelCypherData(String relCypher, String nodeCypher, Number weight, Number batchSize, int concurrency) {
boolean success = algorithm.readNodeAndRelCypher(relCypher, nodeCypher, weight, batchSize, concurrency);
this.nodeCount = algorithm.getNodeCount();
this.relCount = algorithm.relCount;
stats.readNodeMillis = algorithm.readNodeMillis;
stats.readRelationshipMillis = algorithm.readRelationshipMillis;
stats.nodes = nodeCount;
stats.relationships = relCount;
return success;
}
public long numberOfRels() {
return relCount;
}
public Statistics getStatistics() {
return stats;
}
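// Sequential variant: runs a single pass over all nodes on the calling thread; the
// sentinel batch number -1 makes processNodesInBatch() accumulate directly into
// betweennessCentrality.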
public void computeUnweightedSeq() {
computeUnweightedSeq(algorithm.sourceDegreeData,
algorithm.sourceChunkStartingIndex,
algorithm.relationshipTarget);
}
private void computeUnweightedSeq(int[] sourceDegreeData, int[] sourceChunkStartingIndex, int[] relationshipTarget) {
betweennessCentrality = new float[nodeCount];
Arrays.fill(betweennessCentrality, 0);
long before = System.currentTimeMillis();
int start = 0;
int end = nodeCount;
processNodesInBatch(-1, start, end, sourceDegreeData, sourceChunkStartingIndex, relationshipTarget);
long after = System.currentTimeMillis();
long difference = after - before;
log.info("Computations took " + difference + " milliseconds");
stats.computeMillis = difference;
}
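// Parallel variant: splits the node range into one batch per pool thread (falling back to
// a single batch when that would drop below MINIMUM_BATCH_SIZE), runs one Brandes pass per
// batch on the shared pool, and merges the per-batch partial sums afterwards.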
public void computeUnweightedParallel() {
computeUnweightedParallel(algorithm.sourceDegreeData,
algorithm.sourceChunkStartingIndex,
algorithm.relationshipTarget);
}
public void computeUnweightedParallel(int [] sourceDegreeData,
int [] sourceChunkStartingIndex,
int [] relationshipTarget) {
betweennessCentrality = new float[nodeCount];
Arrays.fill(betweennessCentrality, 0);
long before = System.currentTimeMillis();
int numOfThreads = Pools.getNoThreadsInDefaultPool();
assert(numOfThreads != 0);
int batchSize = nodeCount / numOfThreads;
int batches = 0;
if (batchSize > 0)
batches = nodeCount / batchSize;
if (batchSize < MINIMUM_BATCH_SIZE) {
batches = 1;
batchSize = nodeCount;
}
List<Future> futures = new ArrayList<>(batches);
intermediateBcPerThread = Primitive.intObjectMap();
int nodeIter = 0;
int batchNumber = 0;
while(nodeIter < nodeCount) {
final int start = nodeIter;
final int end = Integer.min(start + batchSize, nodeCount);
final int threadBatchNo = batchNumber;
Future future = pool.submit(new Runnable() {
@Override
public void run() {
processNodesInBatch(threadBatchNo, start, end, sourceDegreeData, sourceChunkStartingIndex, relationshipTarget);
}
});
nodeIter = end;
batchNumber++;
futures.add(future);
}
log.info("Total batches: " + batchNumber);
AlgoUtils.waitForTasks(futures);
compileResults(batchNumber);
long after = System.currentTimeMillis();
long difference = after - before;
log.info("Computations took " + difference + " milliseconds");
stats.computeMillis = difference;
}
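// Merges the per-batch partial sums collected in intermediateBcPerThread into the final
// betweennessCentrality array.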
private void compileResults(int batchNumber) {
for (int i = 0; i < nodeCount; i++) {
float value = 0;
Object batchValue;
for (int batch = 0; batch < batchNumber; batch++) {
batchValue = ((PrimitiveIntObjectMap)intermediateBcPerThread.get(batch)).get(i);
if (batchValue != null)
value += (float)batchValue;
}
betweennessCentrality[i] = value;
}
}
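// Runs Brandes' single-source pass for every source node in [start, end). With
// threadBatchNo == -1 (sequential mode) dependencies are added directly to
// betweennessCentrality; otherwise partial sums are kept in a per-batch map that
// compileResults() merges once all batches have finished.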
private void processNodesInBatch(int threadBatchNo,
int start,
int end,
int [] sourceDegreeData,
int [] sourceChunkStartingIndex,
int [] relationshipTarget) {
Stack<Integer> stack = new Stack<>(); // S
Queue<Integer> queue = new LinkedList<>();
log.info("Thread: " + Thread.currentThread().getName() + " processing " + start + " " + end);
// Map<Integer, ArrayList<Integer>> predecessors = new HashMap<>(); // Pw
PrimitiveIntObjectMap predecessors = Primitive.intObjectMap();
int numShortestPaths[] = new int [nodeCount]; // sigma
int distance[] = new int[nodeCount]; // distance
PrimitiveIntObjectMap map = Primitive.intObjectMap();
float delta[] = new float[nodeCount];
int processedNode = 0;
for (int source = start; source < end; source++) {
processedNode++;
if (sourceDegreeData[source] == 0) {
continue;
}
stack.clear();
predecessors.clear();
Arrays.fill(numShortestPaths, 0);
numShortestPaths[source] = 1;
Arrays.fill(distance, -1);
distance[source] = 0;
queue.clear();
queue.add(source);
Arrays.fill(delta, 0);
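// Forward phase: breadth-first search from 'source', counting shortest paths
// (numShortestPaths == sigma) and recording the predecessors that lie on them.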
while (!queue.isEmpty()) {
int nodeDequeued = queue.remove();
stack.push(nodeDequeued);
// For each neighbour of dequeued.
int chunkIndex = sourceChunkStartingIndex[nodeDequeued];
int degree = sourceDegreeData[nodeDequeued];
for (int j = 0; j < degree; j++) {
int target = relationshipTarget[chunkIndex + j];
if (distance[target] < 0) {
queue.add(target);
distance[target] = distance[nodeDequeued] + 1;
}
if (distance[target] == (distance[nodeDequeued] + 1)) {
numShortestPaths[target] = numShortestPaths[target] + numShortestPaths[nodeDequeued];
if (!predecessors.containsKey(target)) {
ArrayList<Integer> list = new ArrayList<>();
predecessors.put(target, list);
}
((ArrayList<Integer>) predecessors.get(target)).add(nodeDequeued);
}
}
}
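// Backward phase: pop nodes in non-increasing distance from 'source' and propagate each
// node's dependency to its predecessors (delta[v] += sigma[v] / sigma[w] * (1 + delta[w])).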
int poppedNode;
double partialDependency;
while (!stack.isEmpty()) {
poppedNode = stack.pop();
ArrayList<Integer> list = (ArrayList<Integer>) predecessors.get(poppedNode);
for (int i = 0; list != null && i < list.size() ; i++) {
int node = list.get(i);
assert(numShortestPaths[poppedNode] != 0);
partialDependency = (numShortestPaths[node] / (double) numShortestPaths[poppedNode]);
partialDependency *= (1.0) + delta[poppedNode];
delta[node] += partialDependency;
}
if (poppedNode != source && delta[poppedNode] != 0.0) {
if (threadBatchNo == -1) {
betweennessCentrality[poppedNode] = betweennessCentrality[poppedNode] + delta[poppedNode];
} else {
Object storedValue = map.get(poppedNode);
if (storedValue != null)
map.put(poppedNode, ((float)storedValue) + delta[poppedNode]);
else
map.put(poppedNode, delta[poppedNode]);
}
}
}
if (processedNode % 10_000 == 0) {
log.debug("Thread: " + Thread.currentThread().getName() + " processed " + processedNode);
}
}
if (threadBatchNo != -1) // the sequential pass (-1) has already written directly into betweennessCentrality
intermediateBcPerThread.put(threadBatchNo, map);
delta = null;
numShortestPaths = null;
stack = null;
queue = null;
distance = null;
log.debug("Thread: " + Thread.currentThread().getName() + " Finishing " + processedNode);
}
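// Writes the computed values back to the database in batches of WRITE_BATCH via
// AlgoUtils.writeBackResults(), recording the write duration and property name in the statistics.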
public void writeResultsToDB(String property) {
this.property = property;
stats.write = true;
long before = System.currentTimeMillis();
AlgoUtils.writeBackResults(pool, db, this, WRITE_BATCH, guard);
stats.writeMillis = System.currentTimeMillis() - before;
stats.property = getPropertyName();
}
}