com.graphaware.nlp.ml.pagerank.PageRank Maven / Gradle / Ivy

/*
 * Copyright (c) 2013-2018 GraphAware
 *
 * This file is part of the GraphAware Framework.
 *
 * GraphAware Framework is free software: you can redistribute it and/or modify it under the terms of
 * the GNU General Public License as published by the Free Software Foundation, either
 * version 3 of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
 * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details. You should have received a copy of
 * the GNU General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 */
package com.graphaware.nlp.ml.pagerank;

import com.google.common.util.concurrent.AtomicDouble;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Relationship;
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;
import org.neo4j.logging.Log;
import com.graphaware.common.log.LoggerFactory;
import static com.graphaware.nlp.util.TypeConverter.getDoubleValue;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

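/**
 * Weighted PageRank computed over an in-memory co-occurrence graph, represented
 * as nested maps (source node id -> target node id -> {@link CoOccurrenceItem}).
 *
 * <p>A minimal usage sketch. The Cypher query, label and relationship type below
 * are illustrative, not prescribed by this class; any query returning columns
 * named "start", "dest" and "weight" will do, and the iteration cap, damping
 * factor and threshold are arbitrary example values:</p>
 *
 * <pre>{@code
 * PageRank pageRank = new PageRank(database);
 * Map<Long, Map<Long, CoOccurrenceItem>> graph = pageRank.createGraph(
 *         "MATCH (t1:Tag)-[r:CO_OCCUR]->(t2:Tag) "
 *         + "RETURN id(t1) AS start, id(t2) AS dest, r.weight AS weight",
 *         false);
 * Map<Long, Double> scores = pageRank.run(graph, 30, 0.85, 0.0001);
 * pageRank.storeOnGraph(scores);
 * }</pre>
 */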
public class PageRank {

    private static final Log LOG = LoggerFactory.getLogger(PageRank.class);

    protected final GraphDatabaseService database;
    private Map<Long, Double> nodeWeights;

    public PageRank(GraphDatabaseService database) {
        this.database = database;
    }

    public void setNodeWeights(Map<Long, Double> w) {
        this.nodeWeights = w;
    }

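    /**
     * Runs weighted PageRank over the given co-occurrence graph. Each iteration
     * applies, for every node v:
     *
     *   PR(v) = (1 - d) / N + d * sum over incoming u of (w(u,v) / W(u)) * PR(u)
     *
     * where w(u,v) is the co-occurrence count, W(u) is the total outgoing weight
     * of u, d is the damping factor and N the number of nodes.
     *
     * @param coOccurrences adjacency map: source node id -> (target node id -> co-occurrence item)
     * @param iter          maximum number of iterations
     * @param dampFactor    damping factor d
     * @param threshold     convergence threshold: iteration stops early once no
     *                      score changes by more than this amount
     * @return node id -> PageRank score
     */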
    public Map<Long, Double> run(Map<Long, Map<Long, CoOccurrenceItem>> coOccurrences, int iter, double dampFactor, double threshold) {
        nodeWeights = initializeNodeWeights(coOccurrences);
        Map<Long, Double> pagerank = getInitializedPageRank(nodeWeights, dampFactor);
        int nNodes = pagerank.size();
        boolean thresholdHit = false;
        Map<Long, Double> prTemp = new HashMap<>();
        for (int iteration = 0; iteration < iter && !thresholdHit; iteration++) {
            // main part of the PR update: accumulate rank from incoming neighbours,
            // weighted by relationship (co-occurrence) weights; node-weight variants
            // are left commented out below
            nodeWeights.entrySet().stream().forEach(enExt -> {
                Long nodeIdExt = enExt.getKey();
                Double nodeWExt = enExt.getValue();
                AtomicDouble internalSum = new AtomicDouble(0.0);
                //AtomicDouble internalNodeWSum = new AtomicDouble(0.0);
                nodeWeights.entrySet().stream()
                        .filter(enInt -> coOccurrences.containsKey(enInt.getKey()) && coOccurrences.get(enInt.getKey()).containsKey(nodeIdExt))
                        .forEach(enInt -> {
                            Long nodeIdInt = enInt.getKey();
                            Double nodeWInt = enInt.getValue();
                            //internalNodeWSum.addAndGet(nodeWInt);
                            Map<Long, CoOccurrenceItem> coOccurrentTags = coOccurrences.get(nodeIdInt);
                            // Could be optimized: totalWeightSum depends only on nodeIdInt, but is recomputed for every target node
                            double totalWeightSum = coOccurrentTags.values().stream().map(item -> item.getCount()).mapToDouble(Number::doubleValue).sum();
                            //internalSum.addAndGet(1.0d/coOccurrentTags.size() * pagerank.get(nodeIdInt)); // no relationship weights
                            internalSum.addAndGet(((1.0d * coOccurrentTags.get(nodeIdExt).getCount()) / totalWeightSum) * pagerank.get(nodeIdInt)); // with relationship weights
                            //internalSum.addAndGet(((1.0d * coOccurrentTags.get(nodeIdExt).getCount()) / totalWeightSum) * pagerank.get(nodeIdInt) * nodeWInt); // with relationship & node weights
                        });
                double newPrValue = (1 - dampFactor) / nNodes + dampFactor * internalSum.get(); // PR is a probability (PR values add up to 1)

                // PageRank with node weights
                //long nInt = nodeWeights.entrySet().stream()
                //        .filter(enInt -> coOccurrences.containsKey(enInt.getKey()) && coOccurrences.get(enInt.getKey()).containsKey(nodeIdExt))
                //        .count();
                //double newPrValue = (1 - dampFactor) / nNodes + dampFactor * internalSum.get() * (nInt / internalNodeWSum.get()); // PR is a probability (PR values add up to 1); WITH node weights
                prTemp.put(nodeIdExt, newPrValue);
            });
            thresholdHit = checkThreshold(pagerank, prTemp, threshold);
            if (thresholdHit) {
                LOG.info("Threshold hit after " + (iteration + 1) + " iterations");
            }
            // commit this iteration's values as the current PageRank scores
            nodeWeights.keySet().stream().forEach((nodeIdExt) -> {
                pagerank.put(nodeIdExt, prTemp.get(nodeIdExt));
            });

        } // iterations
        return pagerank;
    }

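    /**
     * Builds the in-memory co-occurrence graph by running the given Cypher query,
     * which must return three columns: "start" and "dest" (node ids as longs) and
     * "weight" (a numeric edge weight). When respectDirections is false, every
     * edge is also added in the reverse direction, i.e. the graph is treated as
     * undirected.
     */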
    public Map<Long, Map<Long, CoOccurrenceItem>> createGraph(String query, boolean respectDirections) {
        LOG.info("Running query: " + query);
        Map> results = new HashMap<>();
        try (Transaction tx = database.beginTx()) {
            Result res = database.execute(query);
            while (res != null && res.hasNext()) {
                Map<String, Object> next = res.next();
                Long tag1 = (Long) next.get("start");
                Long tag2 = (Long) next.get("dest");
                double w = getDoubleValue(next.get("weight"));
                addTagToCoOccurrence(results, tag1, tag2, w);
                if (!respectDirections)
                    addTagToCoOccurrence(results, tag2, tag1, w);
            }
            tx.success();
        } catch (Exception e) {
            LOG.error("processGraph() failed: " + e.getMessage());
        }
        return results;
    }

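    /**
     * Writes each score back to the database as a "pagerank" property on the
     * corresponding node, using one transaction per node.
     */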
    public void storeOnGraph(Map<Long, Double> pageranks) {
        pageranks.keySet().forEach((tag) -> {
            try (Transaction tx = database.beginTx()) {
                database.execute("MATCH (t) WHERE id(t)=" + tag + "\n SET t.pagerank = " + pageranks.get(tag));
                tx.success();
            } catch (Exception e) {
                LOG.error("storeOnGraph() failed: " + e.getMessage());
            }
        });
    }

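    /**
     * Adds the weighted edge tag1 -> tag2 to the adjacency map, accumulating the
     * weight if the edge already exists.
     */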
    private void addTagToCoOccurrence(Map<Long, Map<Long, CoOccurrenceItem>> results, Long tag1, Long tag2, double w) {
        Map<Long, CoOccurrenceItem> mapTag1;
        if (!results.containsKey(tag1)) {
            mapTag1 = new HashMap<>();
            results.put(tag1, mapTag1);
        } else {
            mapTag1 = results.get(tag1);
        }
        if (mapTag1.containsKey(tag2)) {
            mapTag1.get(tag2).incCountBy(w);
        } else {
            mapTag1.put(tag2, new CoOccurrenceItem(tag1, tag2));
            mapTag1.get(tag2).setCount(w);
        }
    }

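    /**
     * Returns the externally supplied node weights if any were set; otherwise
     * assigns a uniform weight of 1.0 to every node seen in the co-occurrence
     * graph.
     */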
    private Map<Long, Double> initializeNodeWeights(Map<Long, Map<Long, CoOccurrenceItem>> coOccurrences) {
        if (nodeWeights != null && nodeWeights.size() > 0) {
            return nodeWeights;
        }
        Map<Long, Double> nodeInitialWeights = new HashMap<>();
        coOccurrences.entrySet().stream().forEach((coOccurrence) -> {
            coOccurrence.getValue().entrySet().stream().forEach((entry) -> {
                nodeInitialWeights.put(entry.getValue().getSource(), 1.0d);
                nodeInitialWeights.put(entry.getValue().getDestination(), 1.0d);
            });
        });
        return nodeInitialWeights;
    }

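    /**
     * Seeds every node's PageRank with (1 - damp) / n, the damping term of the
     * update formula, so the first iteration starts from a uniform baseline.
     */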
    private Map<Long, Double> getInitializedPageRank(Map<Long, Double> nodeWeights, double damp) {
        Map<Long, Double> pageRank = new HashMap<>();
        int n = nodeWeights.size();
        nodeWeights.entrySet().stream().forEach((item) -> {
            pageRank.put(item.getKey(), (1. - damp) / n);
        });

        return pageRank;
    }

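    /**
     * Returns true once the scores have converged, i.e. no node's new value
     * differs from its previous value by more than the threshold.
     */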
    private boolean checkThreshold(Map<Long, Double> pagerank, Map<Long, Double> prTemp, double threshold) {
        Iterator<Long> iterator = pagerank.keySet().iterator();
        while (iterator.hasNext()) {
            long nodeIdExt = iterator.next();
            double diff = Math.abs(prTemp.get(nodeIdExt) - pagerank.get(nodeIdExt));
            if (diff > threshold) {
                return false;
            }
        }
        return true;
    }

}