// graphql.util.TreeParallelTraverser (Maven / Gradle / Ivy)
package graphql.util;
import graphql.Internal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Function;
import static graphql.Assert.assertNotNull;
import static graphql.Assert.assertTrue;
import static graphql.util.TraversalControl.ABORT;
import static graphql.util.TraversalControl.CONTINUE;
import static graphql.util.TraversalControl.QUIT;
@Internal
public class TreeParallelTraverser {
private final Function super T, Map>> getChildren;
private final Map, Object> rootVars = new ConcurrentHashMap<>();
private final ForkJoinPool forkJoinPool;
private Object sharedContextData;
private TreeParallelTraverser(Function super T, Map>> getChildren,
Object sharedContextData,
ForkJoinPool forkJoinPool) {
this.getChildren = assertNotNull(getChildren);
this.sharedContextData = sharedContextData;
this.forkJoinPool = forkJoinPool;
}
public static TreeParallelTraverser parallelTraverser(Function super T, ? extends List> getChildren) {
return parallelTraverser(getChildren, null, ForkJoinPool.commonPool());
}
public static TreeParallelTraverser parallelTraverser(Function super T, ? extends List> getChildren,
Object sharedContextData) {
return new TreeParallelTraverser<>(wrapListFunction(getChildren), sharedContextData, ForkJoinPool.commonPool());
}
public static TreeParallelTraverser parallelTraverser(Function super T, ? extends List> getChildren,
Object sharedContextData,
ForkJoinPool forkJoinPool) {
return new TreeParallelTraverser<>(wrapListFunction(getChildren), sharedContextData, forkJoinPool);
}
public static TreeParallelTraverser parallelTraverserWithNamedChildren(Function super T, Map>> getNamedChildren,
Object sharedContextData) {
return new TreeParallelTraverser<>(getNamedChildren, sharedContextData, ForkJoinPool.commonPool());
}
public static TreeParallelTraverser parallelTraverserWithNamedChildren(Function super T, Map>> getNamedChildren,
Object sharedContextData,
ForkJoinPool forkJoinPool) {
return new TreeParallelTraverser<>(getNamedChildren, sharedContextData, forkJoinPool);
}
private static Function super T, Map>> wrapListFunction(Function super T, ? extends List> listFn) {
return node -> {
List childs = listFn.apply(node);
return Collections.singletonMap(null, childs);
};
}
public TreeParallelTraverser rootVars(Map, Object> rootVars) {
this.rootVars.putAll(assertNotNull(rootVars));
return this;
}
public TreeParallelTraverser rootVar(Class> key, Object value) {
rootVars.put(key, value);
return this;
}
public void traverse(T root, TraverserVisitor super T> visitor) {
traverse(Collections.singleton(root), visitor);
}
public void traverse(Collection extends T> roots, TraverserVisitor super T> visitor) {
traverseImpl(roots, visitor);
}
public DefaultTraverserContext newRootContext(Map, Object> vars) {
return newContextImpl(null, null, vars, null, true);
}
public void traverseImpl(Collection extends T> roots, TraverserVisitor super T> visitor) {
assertNotNull(roots);
assertNotNull(visitor);
DefaultTraverserContext rootContext = newRootContext(rootVars);
forkJoinPool.invoke(new CountedCompleter() {
@Override
public void compute() {
setPendingCount(roots.size());
for (T root : roots) {
DefaultTraverserContext context = newContext(root, rootContext, null);
EnterAction enterAction = new EnterAction(this, context, visitor);
enterAction.fork();
}
tryComplete();
}
});
}
private class EnterAction extends CountedCompleter {
private DefaultTraverserContext currentContext;
private TraverserVisitor super T> visitor;
private EnterAction(CountedCompleter parent, DefaultTraverserContext currentContext, TraverserVisitor super T> visitor) {
super(parent);
this.currentContext = currentContext;
this.visitor = visitor;
}
@Override
public void compute() {
currentContext.setPhase(TraverserContext.Phase.ENTER);
TraversalControl traversalControl = visitor.enter(currentContext);
assertNotNull(traversalControl, () -> "result of enter must not be null");
assertTrue(QUIT != traversalControl, () -> "can't return QUIT for parallel traversing");
if (traversalControl == ABORT) {
tryComplete();
return;
}
assertTrue(traversalControl == CONTINUE);
List children = pushAll(currentContext);
if (children.size() == 0) {
tryComplete();
return;
}
setPendingCount(children.size() - 1);
for (int i = 1; i < children.size(); i++) {
new EnterAction(this, children.get(i), visitor).fork();
}
new EnterAction(this, children.get(0), visitor).compute();
}
}
private List pushAll(TraverserContext traverserContext) {
Map>> childrenContextMap = new LinkedHashMap<>();
LinkedList contexts = new LinkedList<>();
if (!traverserContext.isDeleted()) {
Map> childrenMap = getChildren.apply(traverserContext.thisNode());
childrenMap.keySet().forEach(key -> {
List children = childrenMap.get(key);
for (int i = children.size() - 1; i >= 0; i--) {
T child = assertNotNull(children.get(i), "null child for key %s", key);
NodeLocation nodeLocation = new NodeLocation(key, i);
DefaultTraverserContext context = newContext(child, traverserContext, nodeLocation);
contexts.push(context);
childrenContextMap.computeIfAbsent(key, notUsed -> new ArrayList<>());
childrenContextMap.get(key).add(0, context);
}
});
}
return contexts;
}
private DefaultTraverserContext newContext(T o, TraverserContext parent, NodeLocation position) {
return newContextImpl(o, parent, new LinkedHashMap<>(), position, false);
}
private DefaultTraverserContext newContextImpl(T curNode,
TraverserContext parent,
Map, Object> vars,
NodeLocation nodeLocation,
boolean isRootContext) {
assertNotNull(vars);
return new DefaultTraverserContext<>(curNode, parent, null, vars, sharedContextData, nodeLocation, isRootContext, true);
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy