All Downloads are FREE. Search and download functionalities are using the official Maven repository.

apoc.nodes.Grouping Maven / Gradle / Ivy

There is a newer version: 5.25.1
Show newest version
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package apoc.nodes;

import static java.util.Collections.*;

import apoc.Pools;
import apoc.result.VirtualNode;
import apoc.result.VirtualRelationship;
import apoc.util.Util;
import apoc.util.collection.Iterables;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.graphdb.*;
import org.neo4j.logging.Log;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.NotThreadSafe;
import org.neo4j.procedure.Procedure;

/**
 * @author mh
 * @since 14.06.17
 */
public class Grouping {

    // Maximum number of nodes taken from the node iterator for one background task; also the
    // accumulated-node threshold after which grouped entries are submitted as one relationship task.
    private static final int BATCHSIZE = 10000;

    // Context-provided database handle; passed to Util.inTxFuture when scheduling background work.
    @Context
    public GraphDatabaseService db;

    // Context-provided transaction; used to list all labels and to look up the nodes of each label.
    @Context
    public Transaction tx;

    // Context-provided log; receives the errors caught inside the background grouping tasks.
    @Context
    public Log log;

    // Context-provided pools; source of the executor service the grouping tasks run on.
    @Context
    public Pools pools;

    /**
     * One output row of the grouping procedure: a grouping node together with either a single
     * relationship or the full list of relationships attached to it.
     *
     * <p>The collection fields carry explicit element types so that {@code relationships.get(0)} and
     * the lambda in {@link #spread()} resolve to {@link Relationship} rather than {@code Object}.
     */
    public static class GroupResult {
        @Description("A list of grouped nodes represented as virtual nodes.")
        public List<Node> nodes;

        @Description("A list of grouped relationships represented as virtual relationships.")
        public List<Relationship> relationships;

        @Description("The grouping node.")
        public Node node;

        @Description("The grouping relationship.")
        public Relationship relationship;

        /**
         * Creates a row holding exactly one node and one relationship.
         *
         * @param node the grouping node
         * @param relationship the grouping relationship of this row
         */
        public GroupResult(Node node, Relationship relationship) {
            this.node = node;
            this.relationship = relationship;
            this.nodes = singletonList(node);
            this.relationships = singletonList(relationship);
        }

        /**
         * Creates a row holding one node and all of its relationships; {@code relationship} is set to
         * the first list element, or {@code null} when there is none.
         *
         * @param node the grouping node
         * @param relationships the relationships of the node; {@code null} is treated as an empty list
         */
        public GroupResult(Node node, List<Relationship> relationships) {
            // Normalise null to an empty list so the isEmpty()/stream() calls below and in spread()
            // cannot fail with a NullPointerException.
            List<Relationship> safeRelationships = relationships;
            if (safeRelationships == null) {
                safeRelationships = emptyList();
            }
            this.nodes = singletonList(node);
            this.relationships = safeRelationships;
            this.node = node;
            this.relationship = safeRelationships.isEmpty() ? null : safeRelationships.get(0);
        }

        /**
         * Expands this row into a stream: this row itself, followed by one single-relationship row for
         * every relationship after the first one.
         *
         * @return the expanded rows, never empty
         */
        public Stream<GroupResult> spread() {
            // skip(1): the first relationship is already exposed through this row's `relationship` field.
            return Stream.concat(Stream.of(this), relationships.stream().skip(1).map(r -> new GroupResult(node, r)));
        }
    }

    /**
     * Groups nodes by label and by the values of the given properties and exposes every group as a
     * virtual node; relationships leaving the members of a group are folded into virtual
     * relationships between the corresponding virtual nodes.
     *
     * <p>Flow of the method as visible here:
     * <ol>
     *   <li>Resolve the label set (a {@code "*"} entry is replaced by every label returned by
     *       {@code tx.getAllLabels()}) and read the aggregation and config options.</li>
     *   <li>Per label, take nodes in batches of {@code BATCHSIZE} and, on the pool's executor, bucket
     *       them by {@code keyFor(...)} while creating and aggregating one {@code VirtualNode} per key.</li>
     *   <li>Walk the buckets in batches and, for every outgoing relationship that passes the
     *       relationship-type restriction, create and aggregate one virtual relationship per
     *       {@code RelKey}.</li>
     *   <li>Post-process the aggregates, apply {@code filter}, {@code limitNodes}, {@code orphans}
     *       and {@code limitRels}, and spread each node into one row plus one extra row per
     *       additional relationship ({@code GroupResult::spread}).</li>
     * </ol>
     *
     * <p>Exceptions raised inside the background tasks are caught and logged by the tasks themselves,
     * so a failing batch is not rethrown from within the task body.
     *
     * <p>NOTE(review): several declarations below appear to have lost their generic type arguments
     * (for example {@code List> aggregations} and {@code Map> grouped}); this looks like an extraction
     * artifact and should be restored from the upstream source before compiling.
     *
     * @param labelNames node labels to aggregate over; {@code "*"} selects all labels
     * @param groupByProperties property keys whose values define a group
     * @param aggregations first map: node property aggregations; second map: relationship property
     *     aggregations; {@code null} or empty falls back to counting {@code "*"} for both
     * @param config optional settings read below: {@code orphans}, {@code selfRels},
     *     {@code limitNodes}, {@code limitRels}, {@code relsPerNode}, {@code filter}, plus whatever
     *     {@code computeIncludedRels} reads
     * @return a stream of {@code GroupResult} rows
     */
    @NotThreadSafe
    @Procedure("apoc.nodes.group")
    @Description("Allows for the aggregation of `NODE` values based on the given properties.\n"
            + "This procedure returns virtual `NODE` values.")
    public Stream group(
            @Name(
                            value = "labels",
                            description =
                                    "The list of node labels to aggregate over. Use `['*']` to indicate all node labels should be looked at.")
                    List labelNames,
            @Name(value = "groupByProperties", description = "The property keys to group the nodes by.")
                    List groupByProperties,
            @Name(
                            value = "aggregations",
                            defaultValue = "[{`*`:\"count\"},{`*`:\"count\"}]",
                            description =
                                    "The first map specifies the node properties to aggregate with their corresponding aggregation functions, while the second map specifies the relationship properties for aggregation.")
                    List> aggregations,
            @Name(
                            value = "config",
                            defaultValue = "{}",
                            description =
                                    """
                    {
                        includeRels :: STRING | LIST
                        excludeRels :: STRING | LIST,
                        orphans = true :: BOOLEAN,
                        selfRels = true :: BOOLEAN,
                        limitNodes = -1 :: INTEGER,
                        limitRels = -1 :: INTEGER,
                        relsPerNode = -1 :: INTEGER,
                        filter :: MAP
                    }
                    """)
                    Map config) {

        // Copy the labels into a set; a "*" entry is removed and replaced by every label in the database.
        Set labels = new HashSet<>(labelNames);
        if (labels.remove("*"))
            labels.addAll(Iterables.stream(tx.getAllLabels()).map(Label::name).collect(Collectors.toSet()));

        String[] keys = groupByProperties.toArray(new String[groupByProperties.size()]);

        // Fall back to counting "*" for both nodes and relationships when no aggregations are given.
        if (aggregations == null || aggregations.isEmpty()) {
            aggregations = Arrays.asList(singletonMap("*", "count"), singletonMap("*", "count"));
        }
        // First aggregation map applies to nodes, second (if present) to relationships.
        Map> nodeAggNames =
                (aggregations.size() > 0) ? toStringListMap(aggregations.get(0)) : emptyMap();
        String[] nodeAggKeys = keyArray(nodeAggNames, "*");

        Map> relAggNames =
                (aggregations.size() > 1) ? toStringListMap(aggregations.get(1)) : emptyMap();
        String[] relAggKeys = keyArray(relAggNames, "*");

        // todo bitset
        // A null result means "no restriction" (see the null check in the relationship loop below).
        Set includeRels = computeIncludedRels(config);

        /*
        config:{orphans:false,selfRels:false,limitNodes:100, limitRels:1000, filter:{Person.count_*.min,10,Person.sum_age.max,200,KNOWS.count_*.min:5}}
         */
        // NOTE(review): these casts assume the config values are already Boolean / Long; a value of
        // any other type (or an explicit null) fails here with a ClassCastException / NullPointerException.
        boolean orphans = (boolean) config.getOrDefault("orphans", true);
        boolean selfRels = (boolean) config.getOrDefault("selfRels", true);
        long limitNodes = (long) config.getOrDefault("limitNodes", -1L);
        long limitRels = (long) config.getOrDefault("limitRels", -1L);
        long relsPerNode = (long) config.getOrDefault("relsPerNode", -1L);

        // filter min, max on aggregated properties
        // (TYPE.)prop.min: value,(TYPE.)prop.max: value,
        // also filter (esp. max) during aggregation?
        // null when no filter is configured (see configuredFilter).
        Map filter = configuredFilter(config);

        // Shared between the background tasks, hence concurrent maps:
        // grouped      - group key -> member nodes
        // virtualNodes - group key -> virtual node carrying the group's properties and aggregates
        // virtualRels  - relationship key -> virtual relationship carrying the aggregates
        Map> grouped = new ConcurrentHashMap<>();
        Map virtualNodes = new ConcurrentHashMap<>();
        Map virtualRels = new ConcurrentHashMap<>();

        List futures = new ArrayList<>(1000);

        // Phase 1: bucket nodes per label and build one virtual node per group key.
        ExecutorService pool = pools.getDefaultExecutorService();
        for (String labelName : labels) {
            Label label = Label.label(labelName);
            Label[] singleLabel = {label};

            // NOTE(review): "*" was removed from `labels` above, so the getAllNodes() branch looks
            // unreachable from this loop - confirm before relying on it.
            try (ResourceIterator nodes =
                    (labelName.equals("*")) ? tx.getAllNodes().iterator() : tx.findNodes(label)) {
                while (nodes.hasNext()) {
                    List batch = Util.take(nodes, BATCHSIZE);
                    futures.add(Util.inTxFuture(pool, db, txInThread -> {
                        try {
                            for (Node node : batch) {
                                // Re-attach the node to the task's own transaction before reading it.
                                final Node boundNode = Util.rebind(txInThread, node);
                                NodeKey key = keyFor(boundNode, labelName, keys);
                                // compute() keeps the create-or-update of the member set atomic per key.
                                grouped.compute(key, (k, v) -> {
                                    if (v == null) v = new HashSet<>();
                                    v.add(boundNode);
                                    return v;
                                });
                                // Create the virtual node on first sight of the key, then fold this
                                // node's property values into its aggregates.
                                virtualNodes.compute(key, (k, v) -> {
                                    if (v == null) {
                                        v = new VirtualNode(singleLabel, propertiesFor(boundNode, keys));
                                    }
                                    VirtualNode vn = v;
                                    if (!nodeAggNames.isEmpty()) {
                                        aggregate(
                                                vn,
                                                nodeAggNames,
                                                nodeAggKeys.length > 0
                                                        ? boundNode.getProperties(nodeAggKeys)
                                                        : Collections.emptyMap());
                                    }
                                    return vn;
                                });
                            }
                        } catch (Exception e) {
                            // Logged only: a failing batch does not abort the procedure from here.
                            log.error("Error grouping nodes", e);
                        }
                        return null;
                    }));
                    // Drop completed futures so the list does not grow with every batch.
                    Util.removeFinished(futures);
                }
            }
        }
        // All virtual nodes must exist before relationships between them can be created.
        Util.waitForFutures(futures);
        futures.clear();
        // Phase 2: walk the groups and build the virtual relationships between virtual nodes.
        Iterator>> entries = grouped.entrySet().iterator();
        int size = 0;
        List>> batch = new ArrayList<>();
        while (entries.hasNext()) {
            Map.Entry> outerEntry = entries.next();
            batch.add(outerEntry);
            size += outerEntry.getValue().size();
            // Submit once more than BATCHSIZE member nodes are collected, or when this was the last entry.
            if (size > BATCHSIZE || !entries.hasNext()) {
                // Copy the batch so the task sees a stable list while `batch` is reused.
                ArrayList>> submitted = new ArrayList<>(batch);
                batch.clear();
                size = 0;
                futures.add(Util.inTxFuture(pool, db, txInThread -> {
                    try {
                        for (Map.Entry> entry : submitted) {
                            for (Node node : entry.getValue()) {
                                node = Util.rebind(txInThread, node);
                                NodeKey startKey = entry.getKey();
                                VirtualNode v1 = virtualNodes.get(startKey);
                                // Only outgoing relationships are visited from each member node.
                                for (Relationship rel : node.getRelationships(Direction.OUTGOING)) {
                                    // Skip relationship types outside the configured restriction.
                                    if (includeRels != null
                                            && !includeRels.contains(
                                                    rel.getType().name())) continue;
                                    Node endNode = rel.getEndNode();
                                    for (NodeKey endKey : keysFor(endNode, labels, keys)) {
                                        VirtualNode v2 = virtualNodes.get(endKey);
                                        // End node is not part of any group built in phase 1.
                                        if (v2 == null) continue;
                                        // Relationships within one group are dropped when selfRels is false.
                                        if (!selfRels && startKey.equals(endKey)) continue;
                                        virtualRels.compute(new RelKey(startKey, endKey, rel), (rk, vRel) -> {
                                            if (vRel == null) vRel = v1.createRelationshipTo(v2, rel.getType());
                                            if (!relAggNames.isEmpty()) {
                                                aggregate(
                                                        vRel,
                                                        relAggNames,
                                                        relAggKeys.length > 0
                                                                ? rel.getProperties(relAggKeys)
                                                                : Collections.emptyMap());
                                            }
                                            return vRel;
                                        });
                                    }
                                }
                            }
                        }
                    } catch (Exception e) {
                        // Logged only: a failing batch does not abort the procedure from here.
                        log.error("Error grouping relationships", e);
                    }
                    return null;
                }));
                Util.removeFinished(futures);
            }
        }
        Util.waitForFutures(futures);
        // Phase 3: finalize aggregates, then filter / limit and shape the output rows.
        Stream stream = fixAggregates(virtualNodes.values()).stream();
        // apply filter
        if (filter != null) stream = stream.filter(n -> filter(n.getLabels(), n.getAllProperties(), filter));
        // A negative limit means "unlimited".
        if (limitNodes > -1) stream = stream.limit(limitNodes);

        // NOTE(review): relsPerNode is narrowed from long to int here; values beyond the int range wrap.
        Stream groupResultStream =
                stream.map(n -> new GroupResult(n, getRelationships(n, filter, (int) relsPerNode)));
        // With orphans disabled, keep only nodes that have relationships and a non-zero degree.
        if (!orphans)
            groupResultStream = groupResultStream.filter(
                    g -> g.relationships != null && !g.relationships.isEmpty() && g.node.getDegree() > 0);
        // One row per node plus one extra row per additional relationship of that node.
        groupResultStream = groupResultStream.flatMap(GroupResult::spread);

        // limitRels caps the number of emitted rows after spreading.
        if (limitRels > -1) groupResultStream = groupResultStream.limit(limitRels);
        return groupResultStream;
    }

    /**
     * Reads the optional {@code filter} entry from the procedure configuration.
     *
     * @param config the procedure configuration map
     * @return the map stored under {@code "filter"}, or {@code null} when the entry is absent or empty
     * @throws ClassCastException if the {@code "filter"} entry is present but is not a map
     */
    @SuppressWarnings("unchecked") // config values are untyped; the entry is only read as a map of settings
    private Map<String, Object> configuredFilter(Map<String, Object> config) {
        Map<String, Object> filter = (Map<String, Object>) config.get("filter");
        // Callers test the result against null to decide whether filtering is active at all.
        return (filter == null || filter.isEmpty()) ? null : filter;
    }

    /**
     * Decides whether an entity of the given type passes the configured filter.
     *
     * @param type the type name the filter entries are matched against
     * @param props the properties of the entity being checked
     * @param filter the configured filter, or {@code null} when filtering is disabled
     * @return {@code true} when no filter is configured or {@code props} is empty; otherwise the
     *     result of {@code filterProps(type, props, filter)}
     */
    private boolean filter(String type, Map<String, Object> props, Map<String, Object> filter) {
        // Nothing to check against, or nothing to check: the entity passes.
        if (filter == null || props.isEmpty()) {
            return true;
        }
        return filterProps(type, props, filter);
    }

    private boolean filter(Iterable