graphql.execution.instrumentation.dataloader.FieldLevelTrackingApproach Maven / Gradle / Ivy
package graphql.execution.instrumentation.dataloader;
import graphql.Assert;
import graphql.ExecutionResult;
import graphql.Internal;
import graphql.execution.FieldValueInfo;
import graphql.execution.ResultPath;
import graphql.execution.instrumentation.ExecutionStrategyInstrumentationContext;
import graphql.execution.instrumentation.InstrumentationContext;
import graphql.execution.instrumentation.InstrumentationState;
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters;
import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters;
import org.dataloader.DataLoaderRegistry;
import org.slf4j.Logger;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
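/**
 * This approach uses field-level tracking to make data loader dispatching more efficient: it counts, per result
 * level, the expected and the actual field fetches and execution-strategy calls, and uses those counts to decide
 * when a dispatch is needed.
 */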
@Internal
public class FieldLevelTrackingApproach {
// Source of the DataLoaderRegistry, kept as a supplier rather than a registry instance.
// NOTE(review): presumably resolved when dispatching (that code is outside this view) — confirm.
private final Supplier<DataLoaderRegistry> dataLoaderRegistrySupplier;
// Logger handed in through the constructor.
private final Logger log;
/**
 * Bookkeeping of how much work is expected versus how much has happened at each result level.
 * <p>
 * Every map is keyed by level and holds a running count. The class performs no synchronization of its own;
 * the callers visible in this file mutate it inside {@code synchronized (callStack)} blocks.
 */
private static class CallStack implements InstrumentationState {

    // level -> number of field fetches expected at that level
    private final Map<Integer, Integer> expectedFetchCountPerLevel = new LinkedHashMap<>();
    // level -> number of field fetches recorded so far
    private final Map<Integer, Integer> fetchCountPerLevel = new LinkedHashMap<>();
    // level -> number of execution-strategy calls expected
    private final Map<Integer, Integer> expectedStrategyCallsPerLevel = new LinkedHashMap<>();
    // level -> number of execution-strategy calls that have begun
    private final Map<Integer, Integer> happenedStrategyCallsPerLevel = new LinkedHashMap<>();
    // level -> number of onFieldValuesInfo callbacks received
    private final Map<Integer, Integer> happenedOnFieldValueCallsPerLevel = new LinkedHashMap<>();
    // levels for which a dispatch has already been claimed
    private final Set<Integer> dispatchedLevels = new LinkedHashSet<>();

    CallStack() {
        // level 1 starts out expecting exactly one strategy call
        expectedStrategyCallsPerLevel.put(1, 1);
    }

    /**
     * Adds {@code count} to the expected fetch count of {@code level}.
     *
     * @return the expected fetch count of {@code level} after the increase
     */
    int increaseExpectedFetchCount(int level, int count) {
        return expectedFetchCountPerLevel.merge(level, count, Integer::sum);
    }

    /** Records one more completed fetch at {@code level}. */
    void increaseFetchCount(int level) {
        fetchCountPerLevel.merge(level, 1, Integer::sum);
    }

    /** Adds {@code count} to the number of strategy calls expected at {@code level}. */
    void increaseExpectedStrategyCalls(int level, int count) {
        expectedStrategyCallsPerLevel.merge(level, count, Integer::sum);
    }

    /** Records one more strategy call that has begun at {@code level}. */
    void increaseHappenedStrategyCalls(int level) {
        happenedStrategyCallsPerLevel.merge(level, 1, Integer::sum);
    }

    /** Records one more {@code onFieldValuesInfo} callback for {@code level}. */
    void increaseHappenedOnFieldValueCalls(int level) {
        happenedOnFieldValueCallsPerLevel.merge(level, 1, Integer::sum);
    }

    /** @return true when the begun strategy calls at {@code level} equal the expected ones (absent on both sides also counts as equal) */
    boolean allStrategyCallsHappened(int level) {
        return Objects.equals(happenedStrategyCallsPerLevel.get(level), expectedStrategyCallsPerLevel.get(level));
    }

    /**
     * @return true when the field-value callbacks for {@code level} equal the expected strategy calls there;
     * each strategy call is matched against one callback
     */
    boolean allOnFieldCallsHappened(int level) {
        return Objects.equals(happenedOnFieldValueCallsPerLevel.get(level), expectedStrategyCallsPerLevel.get(level));
    }

    /** @return true when the recorded fetches at {@code level} equal the expected fetches */
    boolean allFetchesHappened(int level) {
        return Objects.equals(fetchCountPerLevel.get(level), expectedFetchCountPerLevel.get(level));
    }

    @Override
    public String toString() {
        return "CallStack{" +
                "expectedFetchCountPerLevel=" + expectedFetchCountPerLevel +
                ", fetchCountPerLevel=" + fetchCountPerLevel +
                ", expectedStrategyCallsPerLevel=" + expectedStrategyCallsPerLevel +
                ", happenedStrategyCallsPerLevel=" + happenedStrategyCallsPerLevel +
                ", happenedOnFieldValueCallsPerLevel=" + happenedOnFieldValueCallsPerLevel +
                ", dispatchedLevels=" + dispatchedLevels +
                '}';
    }

    /**
     * Claims the dispatch of {@code level}.
     * <p>
     * A second claim for the same level is reported through {@code Assert.assertShouldNeverHappen}.
     * NOTE(review): that call presumably throws, which would make {@code return false} a fallback only.
     *
     * @return true if this call claimed the level
     */
    public boolean dispatchIfNotDispatchedBefore(int level) {
        // Set.add leaves the set unchanged and returns false when the level is already present
        if (!dispatchedLevels.add(level)) {
            Assert.assertShouldNeverHappen("level " + level + " already dispatched");
            return false;
        }
        return true;
    }

    /**
     * Discards all counters, then seeds {@code level} as if one strategy call there has already begun and
     * one fetch is expected.
     * <p>
     * Afterwards {@link #allStrategyCallsHappened(int)} is true for {@code level}. {@link #allFetchesHappened(int)}
     * stays false until one fetch is recorded there.
     */
    public void clearAndMarkCurrentLevelAsReady(int level) {
        expectedFetchCountPerLevel.clear();
        fetchCountPerLevel.clear();
        expectedStrategyCallsPerLevel.clear();
        happenedStrategyCallsPerLevel.clear();
        happenedOnFieldValueCallsPerLevel.clear();
        dispatchedLevels.clear();

        // make sure the level is ready
        expectedFetchCountPerLevel.put(level, 1);
        expectedStrategyCallsPerLevel.put(level, 1);
        happenedStrategyCallsPerLevel.put(level, 1);
    }
}
/**
 * Creates the tracking approach.
 *
 * @param log                        logger stored on this instance
 * @param dataLoaderRegistrySupplier supplier of the {@link DataLoaderRegistry}; it is stored as-is and not
 *                                   invoked here
 */
public FieldLevelTrackingApproach(Logger log, Supplier<DataLoaderRegistry> dataLoaderRegistrySupplier) {
    this.dataLoaderRegistrySupplier = dataLoaderRegistrySupplier;
    this.log = log;
}
/**
 * Builds a brand-new, empty per-level tracking state.
 *
 * @return a new {@code CallStack} instance on every call; its level 1 starts out expecting one strategy call
 */
public InstrumentationState createState() {
    final CallStack freshCallStack = new CallStack();
    return freshCallStack;
}
/**
 * Records the start of an execution-strategy call.
 * <p>
 * The call is attributed to level {@code path.getLevel() + 1}. Under the call-stack lock, that level's expected
 * fetch count is raised by the number of fields the strategy resolves, and its begun-strategy-call count is
 * incremented.
 *
 * @param parameters strategy parameters; its instrumentation state is expected to be the {@code CallStack}
 *                   produced by {@link #createState()}
 * @return a context whose {@code onFieldValuesInfo} updates the call stack and calls {@code dispatch()} when the
 * update reports that a dispatch is needed
 */
ExecutionStrategyInstrumentationContext beginExecutionStrategy(InstrumentationExecutionStrategyParameters parameters) {
    CallStack callStack = parameters.getInstrumentationState();
    ResultPath path = parameters.getExecutionStrategyParameters().getPath();
    int parentLevel = path.getLevel();
    int curLevel = parentLevel + 1;
    int fieldCount = parameters.getExecutionStrategyParameters().getFields().size();
    synchronized (callStack) {
        callStack.increaseExpectedFetchCount(curLevel, fieldCount);
        callStack.increaseHappenedStrategyCalls(curLevel);
    }

    return new ExecutionStrategyInstrumentationContext() {
        @Override
        public void onDispatched(CompletableFuture<ExecutionResult> result) {
            // intentionally empty: progress is tracked through onFieldValuesInfo
        }

        @Override
        public void onCompleted(ExecutionResult result, Throwable t) {
            // intentionally empty: progress is tracked through onFieldValuesInfo
        }

        @Override
        public void onFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList) {
            boolean dispatchNeeded;
            synchronized (callStack) {
                dispatchNeeded = handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel);
            }
            // dispatch() is called after leaving the synchronized block, so the lock is not held during dispatch
            if (dispatchNeeded) {
                dispatch();
            }
        }
    };
}
/**
 * Records one field-values callback for {@code curLevel} and registers the strategy calls the next level should expect.
 * <p>
 * Each OBJECT value accounts for one expected strategy call at {@code curLevel + 1}. A LIST value contributes
 * the number of OBJECT elements found by {@link #getCountForList(FieldValueInfo)}, including nested lists. Any
 * other value type contributes nothing.
 * <p>
 * Thread safety: must be called while holding {@code synchronized (callStack)}.
 *
 * @param fieldValueInfoList the completed field values of one strategy call at {@code curLevel}
 * @param callStack          the tracking state to update
 * @param curLevel           the level of the strategy call that produced the values
 * @return the result of {@code dispatchIfNeeded} for level {@code curLevel + 1}
 */
private boolean handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, CallStack callStack, int curLevel) {
    callStack.increaseHappenedOnFieldValueCalls(curLevel);
    int nextLevel = curLevel + 1;
    int expectedStrategyCalls = 0;
    for (FieldValueInfo fieldValueInfo : fieldValueInfoList) {
        FieldValueInfo.CompleteValueType valueType = fieldValueInfo.getCompleteValueType();
        if (valueType == FieldValueInfo.CompleteValueType.OBJECT) {
            expectedStrategyCalls++;
        } else if (valueType == FieldValueInfo.CompleteValueType.LIST) {
            expectedStrategyCalls += getCountForList(fieldValueInfo);
        }
    }
    callStack.increaseExpectedStrategyCalls(nextLevel, expectedStrategyCalls);
    return dispatchIfNeeded(callStack, nextLevel);
}
/**
 * Counts the OBJECT values held inside a LIST value, descending into nested lists.
 * <p>
 * An OBJECT element counts as one, and a nested LIST element is counted recursively. Every other element type
 * counts as zero.
 *
 * @param fieldValueInfo a list value whose {@code getFieldValueInfos()} supplies the elements
 * @return the number of OBJECT elements at any nesting depth
 */
private int getCountForList(FieldValueInfo fieldValueInfo) {
    int objectCount = 0;
    for (FieldValueInfo element : fieldValueInfo.getFieldValueInfos()) {
        FieldValueInfo.CompleteValueType elementType = element.getCompleteValueType();
        if (elementType == FieldValueInfo.CompleteValueType.LIST) {
            objectCount += getCountForList(element);
        } else if (elementType == FieldValueInfo.CompleteValueType.OBJECT) {
            objectCount += 1;
        }
    }
    return objectCount;
}
public InstrumentationContext
© 2015 - 2025 Weber Informatics LLC | Privacy Policy