All downloads are free. Search and download functionality uses the official Maven repository.

angry1980.neo4j.louvain.Louvain Maven / Gradle / Ivy

There is a newer version: 0.0.10
Show newest version
package angry1980.neo4j.louvain;

import it.unimi.dsi.fastutil.longs.*;
import org.neo4j.graphdb.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

/**
 * Louvain-style community detection over a graph loaded from Neo4j through a {@link TaskAdapter}.
 *
 * <p>Each {@link #pass(int)} runs a local-moving phase ({@link #firstPhase()}), then an aggregation
 * phase ({@link #secondPhase(int)}). The aggregation phase records every node's community in the
 * current {@link LouvainLayer} and then collapses each community into a single macro node.
 * {@link #execute()} repeats passes until the number of communities stops changing.
 *
 * <p>Instances hold mutable state and perform no synchronization.
 */
public class Louvain {

    private static final Logger LOG = LoggerFactory.getLogger(Louvain.class);

    /**
     * Sum of the weights of all nodes loaded by the constructor. Aggregation adds every original
     * relationship weight into exactly one macro relationship, so macro-node weights sum to the same
     * total and this value stays valid across layers.
     */
    private final double totalEdgeWeight;
    private final LouvainResult louvainResult;
    private int layerCount = 0;
    // Linked hash map: O(1) lookup by id (an array map is linear per lookup) while keeping
    // insertion-order iteration, which the greedy first phase relies on for tie-breaking.
    private Long2ObjectMap<LNode> nodes = new Long2ObjectLinkedOpenHashMap<>();

    //todo: use LNode list as constructor argument
    public Louvain(GraphDatabaseService g) {
        this(g, new DefaultTaskAdapter());
    }

    /**
     * Loads every node returned by {@code adapter.getNodes(g)}, together with the relationships
     * returned by {@code adapter.getRelationships(n)}, inside a single transaction. Each node starts
     * in its own community (community id == node id).
     *
     * @param g       database to read from
     * @param adapter supplies nodes, relationships, ids and initial relationship weights
     */
    public Louvain(GraphDatabaseService g, TaskAdapter adapter) {
        this.louvainResult = new LouvainResult();
        try (Transaction tx = g.beginTx()) {
            for (Node n : adapter.getNodes(g)) {
                List<LRel> rels = new ArrayList<>();
                for (Relationship r : adapter.getRelationships(n)) {
                    rels.add(new LRel(adapter.getInitWeight(r), adapter.getId(r.getOtherNode(n))));
                }
                LNode ln = new LNode(adapter.getId(n), adapter.getId(n), rels);
                nodes.put(ln.id, ln);
            }
            tx.success();
        }
        // NOTE(review): if the adapter reports each relationship from both endpoints, this is twice
        // the graph's total edge weight — confirm against TaskAdapter.getRelationships.
        totalEdgeWeight = nodes.values().stream().mapToDouble(LNode::getWeight).sum();
    }

    /** Runs passes until a pass reports that no new macro nodes were created. */
    public void execute() {
        int macroNodeCount = 0;
        do {
            LOG.debug("Layer count: {}", layerCount);
            macroNodeCount = this.pass(macroNodeCount);
        } while (macroNodeCount != 0);
    }

    /**
     * Runs one local-moving phase and one aggregation phase, then advances the layer counter.
     *
     * @param macroNodeCount number of macro nodes produced by the previous pass (0 for the first)
     * @return number of macro nodes created, or 0 if the community count did not change
     */
    public int pass(int macroNodeCount) {
        this.firstPhase();
        LOG.debug("Starting modularity...");
        int totMacroNodes = this.secondPhase(macroNodeCount);
        LOG.debug("Created {}", totMacroNodes);

        layerCount++;
        return totMacroNodes;
    }

    /**
     * Sweeps all nodes repeatedly. Each node moves to the neighbouring community with the largest
     * strictly positive {@link #calculateDelta} (the first such community wins ties). Sweeping stops
     * once a full sweep moves no node.
     */
    public void firstPhase() {
        int movements;

        do {
            movements = 0;
            for (LNode src : nodes.values()) {
                long bestCommunity = src.community;
                double bestDelta = 0.0;
                for (LRel r : src.rels) {
                    long neighCommunity = nodes.get(r.otherNode).community;

                    double delta = this.calculateDelta(src, src.community, neighCommunity);
                    if (delta > bestDelta) {
                        bestDelta = delta;
                        bestCommunity = neighCommunity;
                    }
                }

                if (src.community != bestCommunity) {
                    src.community = bestCommunity;
                    movements++;
                }
            }
            LOG.debug("Movements so far: {}", movements);
        } while (movements != 0);
    }

    /**
     * Gain score for moving {@code n} from {@code srcCommunity} to {@code dstCommunity}. It is 0 when
     * the two communities are equal.
     *
     * <p>The first term is the change in edge weight between {@code n} and the community, scaled by
     * {@code 1/totalEdgeWeight}. The second term is the difference in community volumes (excluding
     * {@code n}) times the weight of {@code n}, scaled by {@code 1/(2 * totalEdgeWeight^2)}.
     * NOTE(review): the relative scaling of the two terms matches textbook modularity gain only if
     * {@code totalEdgeWeight} is the total edge weight m (not 2m) — confirm.
     */
    private double calculateDelta(LNode n, long srcCommunity, long dstCommunity) {
        double first, second;

        first = n.communityWeightWithout(dstCommunity) - n.communityWeightWithout(srcCommunity);
        first = first / totalEdgeWeight;

        second = (n.communityVolumeWithout(srcCommunity) - n.communityVolumeWithout(dstCommunity)) * n.getWeight();
        second = second / (2 * Math.pow(totalEdgeWeight, 2));

        return first + second;
    }

    /**
     * Records every node's community in the current layer, then collapses each community into one
     * macro node whose id and community are the community id. Every relationship {@code a -> b}
     * becomes (or adds its weight to) the macro relationship {@code community(a) -> community(b)}.
     * Intra-community relationships therefore become self-loops.
     *
     * @param macroNodeCount number of macro nodes produced by the previous pass
     * @return number of macro nodes created, or 0 (with the node set left unchanged) if the
     *         number of communities equals {@code macroNodeCount}
     */
    public int secondPhase(int macroNodeCount) {
        LongSet macroNodesCommunities = new LongOpenHashSet();
        LouvainLayer louvainLayer = louvainResult.layer(layerCount);
        for (LNode n : nodes.values()) {
            macroNodesCommunities.add(n.community);
            louvainLayer.add(n.id, n.community);
        }

        // Same community count as the previous layer: nothing was merged, so stop here.
        if (macroNodesCommunities.size() == macroNodeCount) {
            return 0;
        }

        // One macro node per community, created in first-seen order so iteration stays deterministic.
        Long2ObjectMap<LNode> macros = new Long2ObjectLinkedOpenHashMap<>();
        for (LNode activeNode : nodes.values()) {
            if (macros.get(activeNode.community) == null) {
                macros.put(activeNode.community,
                        new LNode(activeNode.community, activeNode.community, new ArrayList<>()));
            }
        }

        // Aggregate each original relationship into the relationship between the two macro nodes.
        // Macro relationships are keyed by the *macro* id of the neighbour (its community), and
        // accumulate the original weight so the total weight is preserved across layers.
        for (LNode activeNode : nodes.values()) {
            LNode macroNode = macros.get(activeNode.community);
            for (LRel r : activeNode.rels) {
                long targetMacroId = nodes.get(r.otherNode).community;
                LRel macroRel = macroNode.tryToFindRel(targetMacroId)
                        .orElseGet(() -> {
                            LRel mr = new LRel(0.0, targetMacroId);
                            macroNode.rels.add(mr);
                            return mr;
                        });
                macroRel.weight += r.weight;
            }
        }
        nodes = macros;
        return macros.size();
    }

    public LouvainResult getResult() {
        return this.louvainResult;
    }

    /**
     * A (possibly aggregated) node: its id, current community and outgoing relationships. Inner
     * (non-static) because the community helpers read the enclosing instance's {@code nodes} map.
     */
    class LNode {
        long id;
        List<LRel> rels;
        // Cached sum of relationship weights; 0 means "not yet computed" (see getWeight).
        double weight;
        long community;

        public LNode(long id, long community, List<LRel> rels) {
            this.id = id;
            this.rels = rels;
            this.community = community;
        }

        public long getCommunity() {
            return community;
        }

        /** Sum of this node's relationship weights, computed lazily and cached once non-zero. */
        public double getWeight() {
            if (weight == 0) {
                weight = rels.stream().mapToDouble(LRel::getWeight).sum();
            }
            return weight;
        }

        /** Returns a relationship of this node whose {@code otherNode} equals {@code other}, if any. */
        public Optional<LRel> tryToFindRel(long other) {
            return rels.stream()
                    .filter(r -> r.otherNode == other)
                    .findAny();
        }

        /** Total weight of this node's relationships to nodes currently in community {@code cId}. */
        public double communityWeightWithout(long cId) {
            return rels.stream()
                    .filter(r -> nodes.get(r.otherNode).community == cId)
                    .mapToDouble(LRel::getWeight)
                    .sum();
        }

        /** Sum of the weights of all nodes in community {@code cId}, excluding this node. */
        public double communityVolumeWithout(long cId) {
            return nodes.values().stream()
                    .filter(n -> n.community == cId)
                    .filter(n -> !n.equals(this))
                    .mapToDouble(LNode::getWeight)
                    .sum();
        }

    }

    /** A weighted relationship to the node (or macro node) with id {@code otherNode}. */
    class LRel {
        double weight;
        long otherNode;

        public LRel(double weight, long otherNode) {
            this.weight = weight;
            this.otherNode = otherNode;
        }

        public double getWeight() {
            return weight;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy