All downloads are free. Search and download functionality uses the official Maven repository.

graphql.util.TreeParallelTraverser Maven / Gradle / Ivy

package graphql.util;

import graphql.Internal;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Function;

import static graphql.Assert.assertNotNull;
import static graphql.Assert.assertTrue;
import static graphql.util.TraversalControl.ABORT;
import static graphql.util.TraversalControl.CONTINUE;
import static graphql.util.TraversalControl.QUIT;

@Internal
public class TreeParallelTraverser {

    private final Function>> getChildren;
    private final Map, Object> rootVars = new ConcurrentHashMap<>();

    private final ForkJoinPool forkJoinPool;

    private Object sharedContextData;


    private TreeParallelTraverser(Function>> getChildren,
                                  Object sharedContextData,
                                  ForkJoinPool forkJoinPool) {
        this.getChildren = assertNotNull(getChildren);
        this.sharedContextData = sharedContextData;
        this.forkJoinPool = forkJoinPool;
    }

    public static  TreeParallelTraverser parallelTraverser(Function> getChildren) {
        return parallelTraverser(getChildren, null, ForkJoinPool.commonPool());
    }

    public static  TreeParallelTraverser parallelTraverser(Function> getChildren,
                                                                 Object sharedContextData) {
        return new TreeParallelTraverser<>(wrapListFunction(getChildren), sharedContextData, ForkJoinPool.commonPool());
    }

    public static  TreeParallelTraverser parallelTraverser(Function> getChildren,
                                                                 Object sharedContextData,
                                                                 ForkJoinPool forkJoinPool) {
        return new TreeParallelTraverser<>(wrapListFunction(getChildren), sharedContextData, forkJoinPool);
    }


    public static  TreeParallelTraverser parallelTraverserWithNamedChildren(Function>> getNamedChildren,
                                                                                  Object sharedContextData) {
        return new TreeParallelTraverser<>(getNamedChildren, sharedContextData, ForkJoinPool.commonPool());
    }

    public static  TreeParallelTraverser parallelTraverserWithNamedChildren(Function>> getNamedChildren,
                                                                                  Object sharedContextData,
                                                                                  ForkJoinPool forkJoinPool) {
        return new TreeParallelTraverser<>(getNamedChildren, sharedContextData, forkJoinPool);
    }


    private static  Function>> wrapListFunction(Function> listFn) {
        return node -> {
            List childs = listFn.apply(node);
            return Collections.singletonMap(null, childs);
        };
    }

    public TreeParallelTraverser rootVars(Map, Object> rootVars) {
        this.rootVars.putAll(assertNotNull(rootVars));
        return this;
    }

    public TreeParallelTraverser rootVar(Class key, Object value) {
        rootVars.put(key, value);
        return this;
    }

    public void traverse(T root, TraverserVisitor visitor) {
        traverse(Collections.singleton(root), visitor);
    }

    public void traverse(Collection roots, TraverserVisitor visitor) {
        traverseImpl(roots, visitor);
    }

    public DefaultTraverserContext newRootContext(Map, Object> vars) {
        return newContextImpl(null, null, vars, null, true);
    }


    public void traverseImpl(Collection roots, TraverserVisitor visitor) {
        assertNotNull(roots);
        assertNotNull(visitor);

        DefaultTraverserContext rootContext = newRootContext(rootVars);
        forkJoinPool.invoke(new CountedCompleter() {
            @Override
            public void compute() {
                setPendingCount(roots.size());
                for (T root : roots) {
                    DefaultTraverserContext context = newContext(root, rootContext, null);
                    EnterAction enterAction = new EnterAction(this, context, visitor);
                    enterAction.fork();
                }
                tryComplete();
            }

        });
    }

    private class EnterAction extends CountedCompleter {
        private DefaultTraverserContext currentContext;
        private TraverserVisitor visitor;

        private EnterAction(CountedCompleter parent, DefaultTraverserContext currentContext, TraverserVisitor visitor) {
            super(parent);
            this.currentContext = currentContext;
            this.visitor = visitor;
        }

        @Override
        public void compute() {
            currentContext.setPhase(TraverserContext.Phase.ENTER);
            TraversalControl traversalControl = visitor.enter(currentContext);
            assertNotNull(traversalControl, () -> "result of enter must not be null");
            assertTrue(QUIT != traversalControl, () -> "can't return QUIT for parallel traversing");
            if (traversalControl == ABORT) {
                tryComplete();
                return;
            }
            assertTrue(traversalControl == CONTINUE);
            List children = pushAll(currentContext);
            if (children.size() == 0) {
                tryComplete();
                return;
            }
            setPendingCount(children.size() - 1);
            for (int i = 1; i < children.size(); i++) {
                new EnterAction(this, children.get(i), visitor).fork();
            }
            new EnterAction(this, children.get(0), visitor).compute();
        }
    }

    private List pushAll(TraverserContext traverserContext) {

        Map>> childrenContextMap = new LinkedHashMap<>();

        LinkedList contexts = new LinkedList<>();
        if (!traverserContext.isDeleted()) {

            Map> childrenMap = getChildren.apply(traverserContext.thisNode());
            childrenMap.keySet().forEach(key -> {
                List children = childrenMap.get(key);
                for (int i = children.size() - 1; i >= 0; i--) {
                    T child = assertNotNull(children.get(i), "null child for key %s", key);
                    NodeLocation nodeLocation = new NodeLocation(key, i);
                    DefaultTraverserContext context = newContext(child, traverserContext, nodeLocation);
                    contexts.push(context);
                    childrenContextMap.computeIfAbsent(key, notUsed -> new ArrayList<>());
                    childrenContextMap.get(key).add(0, context);
                }
            });
        }
        return contexts;
    }

    private DefaultTraverserContext newContext(T o, TraverserContext parent, NodeLocation position) {
        return newContextImpl(o, parent, new LinkedHashMap<>(), position, false);
    }

    private DefaultTraverserContext newContextImpl(T curNode,
                                                      TraverserContext parent,
                                                      Map, Object> vars,
                                                      NodeLocation nodeLocation,
                                                      boolean isRootContext) {
        assertNotNull(vars);
        return new DefaultTraverserContext<>(curNode, parent, null, vars, sharedContextData, nodeLocation, isRootContext, true);
    }
}






© 2015 - 2025 Weber Informatics LLC | Privacy Policy