com.graphaware.nlp.ml.pagerank.PageRank
GraphAware Framework Module for adding NLP capabilities to Neo4j.
/*
 * Copyright (c) 2013-2018 GraphAware
 *
 * This file is part of the GraphAware Framework.
 *
 * GraphAware Framework is free software: you can redistribute it and/or modify it under the terms of
 * the GNU General Public License as published by the Free Software Foundation, either
 * version 3 of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
 * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details. You should have received a copy of
 * the GNU General Public License along with this program. If not, see
 * <http://www.gnu.org/licenses/>.
 */
package com.graphaware.nlp.ml.pagerank;

import com.google.common.util.concurrent.AtomicDouble;
import com.graphaware.common.log.LoggerFactory;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;
import org.neo4j.logging.Log;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import static com.graphaware.nlp.util.TypeConverter.getDoubleValue;
public class PageRank {

    private static final Log LOG = LoggerFactory.getLogger(PageRank.class);

    protected final GraphDatabaseService database;
    private Map<Long, Double> nodeWeights;

    public PageRank(GraphDatabaseService database) {
        this.database = database;
    }

    public void setNodeWeights(Map<Long, Double> w) {
        this.nodeWeights = w;
    }
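
    /**
     * Runs weighted PageRank over the given co-occurrence graph. Based on the
     * code below, each iteration applies the damped, edge-weighted update
     * PR(v) = (1 - d) / N + d * sum over in-neighbours u of (w(u,v) / W(u)) * PR(u),
     * where w(u,v) is the co-occurrence count of the edge u -> v and W(u) is
     * the total outgoing co-occurrence weight of u. Iteration stops after
     * iter rounds, or earlier once no node's value changes by more than
     * threshold.
     */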
    public Map<Long, Double> run(Map<Long, Map<Long, CoOccurrenceItem>> coOccurrences, int iter, double dampFactor, double threshold) {
        nodeWeights = initializeNodeWeights(coOccurrences);
        Map<Long, Double> pagerank = getInitializedPageRank(nodeWeights, dampFactor);
        int nNodes = pagerank.size();
        boolean thresholdHit = false;
        Map<Long, Double> prTemp = new HashMap<>();
        for (int iteration = 0; iteration < iter && !thresholdHit; iteration++) {
            //Map<Long, Double> prTemp = new HashMap<>();
            // core of the PR update: accumulate contributions from all in-neighbours,
            // weighted by relationship counts
            nodeWeights.entrySet().stream().forEach(enExt -> {
                Long nodeIdExt = enExt.getKey();
                Double nodeWExt = enExt.getValue(); // used only by the commented node-weight variants below
                AtomicDouble internalSum = new AtomicDouble(0.0);
                //AtomicDouble internalNodeWSum = new AtomicDouble(0.0);
                nodeWeights.entrySet().stream()
                        .filter(enInt -> coOccurrences.containsKey(enInt.getKey()) && coOccurrences.get(enInt.getKey()).containsKey(nodeIdExt))
                        .forEach(enInt -> {
                            Long nodeIdInt = enInt.getKey();
                            Double nodeWInt = enInt.getValue();
                            //internalNodeWSum.addAndGet(nodeWInt);
                            Map<Long, CoOccurrenceItem> coOccurrentTags = coOccurrences.get(nodeIdInt);
                            // could be optimized: the total outgoing weight of nodeIdInt is recomputed on every visit
                            double totalWeightSum = coOccurrentTags.values().stream().map(item -> item.getCount()).mapToDouble(Number::doubleValue).sum();
                            //internalSum.addAndGet(1.0d / coOccurrentTags.size() * pagerank.get(nodeIdInt)); // no relationship weights
                            internalSum.addAndGet(((1.0d * coOccurrentTags.get(nodeIdExt).getCount()) / totalWeightSum) * pagerank.get(nodeIdInt)); // with relationship weights
                            //internalSum.addAndGet(((1.0d * coOccurrentTags.get(nodeIdExt).getCount()) / totalWeightSum) * pagerank.get(nodeIdInt) * nodeWInt); // with relationship & node weights
                        });
                double newPrValue = (1 - dampFactor) / nNodes + dampFactor * internalSum.get(); // PR is a probability (PR values add up to 1)
                // PageRank with node weights (experimental variant, kept for reference):
                //long nInt = nodeWeights.entrySet().stream()
                //        .filter(enInt -> coOccurrences.containsKey(enInt.getKey()) && coOccurrences.get(enInt.getKey()).containsKey(nodeIdExt))
                //        .count();
                //double newPrValue = (1 - dampFactor) / nNodes + dampFactor * internalSum.get() * (nInt / internalNodeWSum.get());
                prTemp.put(nodeIdExt, newPrValue);
            });
            thresholdHit = checkThreshold(pagerank, prTemp, threshold);
            if (thresholdHit) {
                LOG.info("Threshold hit after " + (iteration + 1) + " iterations");
            }
            // finish this iteration by copying the new values into the result map
            nodeWeights.keySet().stream().forEach((nodeIdExt) -> {
                pagerank.put(nodeIdExt, prTemp.get(nodeIdExt));
            });
        } // iterations
        return pagerank;
    }
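
    /**
     * Builds the co-occurrence graph by running the given Cypher query, which
     * must return three columns named start, dest and weight (source node id,
     * destination node id, numeric weight). A hypothetical query of the right
     * shape; the label and relationship type are illustrative, not prescribed
     * by this class:
     *
     *   MATCH (t1:Tag)-[r:CO_OCCURS]->(t2:Tag)
     *   RETURN id(t1) AS start, id(t2) AS dest, r.weight AS weight
     *
     * When respectDirections is false, every returned edge is also added in
     * the opposite direction, making the graph effectively undirected.
     */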
    public Map<Long, Map<Long, CoOccurrenceItem>> createGraph(String query, boolean respectDirections) {
        LOG.info("Running query: " + query);
        Map<Long, Map<Long, CoOccurrenceItem>> results = new HashMap<>();
        try (Transaction tx = database.beginTx()) {
            Result res = database.execute(query);
            while (res != null && res.hasNext()) {
                Map<String, Object> next = res.next();
                Long tag1 = (Long) next.get("start");
                Long tag2 = (Long) next.get("dest");
                double w = getDoubleValue(next.get("weight"));
                addTagToCoOccurrence(results, tag1, tag2, w);
                if (!respectDirections) {
                    addTagToCoOccurrence(results, tag2, tag1, w);
                }
            }
            tx.success();
        } catch (Exception e) {
            LOG.error("createGraph() failed: " + e.getMessage());
        }
        return results;
    }
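
    /**
     * Writes each computed PageRank value back to its node as a pagerank
     * property, using one transaction per node.
     */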
    public void storeOnGraph(Map<Long, Double> pageranks) {
        pageranks.keySet().forEach((tag) -> {
            try (Transaction tx = database.beginTx()) {
                database.execute("MATCH (t) WHERE id(t) = " + tag + " SET t.pagerank = " + pageranks.get(tag));
                tx.success();
            } catch (Exception e) {
                LOG.error("storeOnGraph() failed: " + e.getMessage());
            }
        });
    }
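
    /**
     * Registers a directed co-occurrence tag1 -> tag2 with weight w,
     * incrementing the stored count when the pair has been seen before.
     */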
    private void addTagToCoOccurrence(Map<Long, Map<Long, CoOccurrenceItem>> results, Long tag1, Long tag2, double w) {
        Map<Long, CoOccurrenceItem> mapTag1;
        if (!results.containsKey(tag1)) {
            mapTag1 = new HashMap<>();
            results.put(tag1, mapTag1);
        } else {
            mapTag1 = results.get(tag1);
        }
        if (mapTag1.containsKey(tag2)) {
            mapTag1.get(tag2).incCountBy(w);
        } else {
            mapTag1.put(tag2, new CoOccurrenceItem(tag1, tag2));
            mapTag1.get(tag2).setCount(w);
        }
    }
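
    /**
     * Returns the externally supplied node weights if setNodeWeights() was
     * called; otherwise defaults every node appearing in the co-occurrence
     * graph to a weight of 1.0.
     */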
    private Map<Long, Double> initializeNodeWeights(Map<Long, Map<Long, CoOccurrenceItem>> coOccurrences) {
        if (nodeWeights != null && nodeWeights.size() > 0) {
            return nodeWeights;
        }
        Map<Long, Double> nodeInitialWeights = new HashMap<>();
        coOccurrences.entrySet().stream().forEach((coOccurrence) -> {
            coOccurrence.getValue().entrySet().stream().forEach((entry) -> {
                nodeInitialWeights.put(entry.getValue().getSource(), 1.0d);
                nodeInitialWeights.put(entry.getValue().getDestination(), 1.0d);
            });
        });
        return nodeInitialWeights;
    }
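
    /**
     * Initializes every node's PageRank to (1 - damp) / n, i.e. the constant
     * term of the update formula applied in run().
     */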
    private Map<Long, Double> getInitializedPageRank(Map<Long, Double> nodeWeights, double damp) {
        Map<Long, Double> pageRank = new HashMap<>();
        int n = nodeWeights.size();
        nodeWeights.entrySet().stream().forEach((item) -> {
            pageRank.put(item.getKey(), (1. - damp) / n);
        });
        return pageRank;
    }
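
    /**
     * Returns true once no node's PageRank changed by more than the given
     * threshold between two consecutive iterations.
     */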
    private boolean checkThreshold(Map<Long, Double> pagerank, Map<Long, Double> prTemp, double threshold) {
        Iterator<Long> iterator = pagerank.keySet().iterator();
        while (iterator.hasNext()) {
            long nodeIdExt = iterator.next();
            double diff = Math.abs(prTemp.get(nodeIdExt) - pagerank.get(nodeIdExt));
            if (diff > threshold) {
                return false;
            }
        }
        return true;
    }
}
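
A minimal usage sketch follows, assuming the CoOccurrenceItem class from this same package (only the constructor and the setCount/getCount accessors relied on by addTagToCoOccurrence above are used); the node ids, the damping factor of 0.85, the 30 iterations and the 0.0001 threshold are illustrative values, not defaults defined by this class. Because run() never touches the database, the sketch passes null to the constructor.

package com.graphaware.nlp.ml.pagerank;

import java.util.HashMap;
import java.util.Map;

public class PageRankSketch {

    public static void main(String[] args) {
        // a tiny directed co-occurrence graph: 1 -> 2, 2 -> 1, 2 -> 3, 3 -> 1
        Map<Long, Map<Long, CoOccurrenceItem>> coOccurrences = new HashMap<>();
        addEdge(coOccurrences, 1L, 2L, 1.0);
        addEdge(coOccurrences, 2L, 1L, 1.0);
        addEdge(coOccurrences, 2L, 3L, 2.0);
        addEdge(coOccurrences, 3L, 1L, 1.0);

        PageRank pageRank = new PageRank(null); // run() does not use the database
        Map<Long, Double> result = pageRank.run(coOccurrences, 30, 0.85, 0.0001);
        result.forEach((nodeId, pr) -> System.out.println(nodeId + " -> " + pr));
    }

    // hypothetical helper mirroring addTagToCoOccurrence for a single direction
    private static void addEdge(Map<Long, Map<Long, CoOccurrenceItem>> graph, long from, long to, double w) {
        CoOccurrenceItem item = new CoOccurrenceItem(from, to);
        item.setCount(w);
        graph.computeIfAbsent(from, k -> new HashMap<>()).put(to, item);
    }
}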