graphql.schema.transform.FieldVisibilitySchemaTransformation Maven / Gradle / Ivy
package graphql.schema.transform;
import com.google.common.collect.ImmutableList;
import graphql.PublicApi;
import graphql.schema.GraphQLEnumType;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLImplementingType;
import graphql.schema.GraphQLInputObjectField;
import graphql.schema.GraphQLInputObjectType;
import graphql.schema.GraphQLInterfaceType;
import graphql.schema.GraphQLNamedSchemaElement;
import graphql.schema.GraphQLNamedType;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLSchema;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLType;
import graphql.schema.GraphQLTypeVisitorStub;
import graphql.schema.GraphQLUnionType;
import graphql.schema.SchemaTraverser;
import graphql.schema.impl.SchemaUtil;
import graphql.schema.transform.VisibleFieldPredicateEnvironment.VisibleFieldPredicateEnvironmentImpl;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static graphql.schema.SchemaTransformer.transformSchema;
/**
* Transforms a schema by applying a visibility predicate to every field.
*/
@PublicApi
public class FieldVisibilitySchemaTransformation {
private final VisibleFieldPredicate visibleFieldPredicate;
private final Runnable beforeTransformationHook;
private final Runnable afterTransformationHook;
public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPredicate) {
this(visibleFieldPredicate, () -> {}, () -> {});
}
public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPredicate,
Runnable beforeTransformationHook,
Runnable afterTransformationHook) {
this.visibleFieldPredicate = visibleFieldPredicate;
this.beforeTransformationHook = beforeTransformationHook;
this.afterTransformationHook = afterTransformationHook;
}
public final GraphQLSchema apply(GraphQLSchema schema) {
Set observedBeforeTransform = new HashSet<>();
Set observedAfterTransform = new HashSet<>();
Set markedForRemovalTypes = new HashSet<>();
// query, mutation, and subscription types should not be removed
final Set protectedTypeNames = getOperationTypes(schema).stream()
.map(GraphQLObjectType::getName)
.collect(Collectors.toSet());
beforeTransformationHook.run();
new SchemaTraverser(getChildrenFn(schema)).depthFirst(new TypeObservingVisitor(observedBeforeTransform), getRootTypes(schema));
// remove fields
GraphQLSchema interimSchema = transformSchema(schema,
new FieldRemovalVisitor(visibleFieldPredicate, markedForRemovalTypes));
new SchemaTraverser(getChildrenFn(interimSchema)).depthFirst(new TypeObservingVisitor(observedAfterTransform), getRootTypes(interimSchema));
// remove types that are not used after removing fields - (connected schema only)
GraphQLSchema connectedSchema = transformSchema(interimSchema,
new TypeVisibilityVisitor(protectedTypeNames, observedBeforeTransform, observedAfterTransform));
// ensure markedForRemovalTypes are not referenced by other schema elements, and delete from the schema
// the ones that aren't.
GraphQLSchema finalSchema = removeUnreferencedTypes(markedForRemovalTypes, connectedSchema);
afterTransformationHook.run();
return finalSchema;
}
// Creates a getChildrenFn that includes interface
private Function> getChildrenFn(GraphQLSchema schema) {
Map> interfaceImplementations = new SchemaUtil().groupImplementationsForInterfacesAndObjects(schema);
return graphQLSchemaElement -> {
if (!(graphQLSchemaElement instanceof GraphQLInterfaceType)) {
return graphQLSchemaElement.getChildren();
}
ArrayList children = new ArrayList<>(graphQLSchemaElement.getChildren());
List implementations = interfaceImplementations.get(((GraphQLInterfaceType) graphQLSchemaElement).getName());
if (implementations != null) {
children.addAll(implementations);
}
return children;
};
}
private GraphQLSchema removeUnreferencedTypes(Set markedForRemovalTypes, GraphQLSchema connectedSchema) {
GraphQLSchema withoutAdditionalTypes = connectedSchema.transform(builder -> {
Set additionalTypes = new HashSet<>(connectedSchema.getAdditionalTypes());
additionalTypes.removeAll(markedForRemovalTypes);
builder.clearAdditionalTypes();
builder.additionalTypes(additionalTypes);
});
// remove from markedForRemovalTypes any type that might still be referenced by other schema elements
transformSchema(withoutAdditionalTypes, new AdditionalTypeVisibilityVisitor(markedForRemovalTypes));
// finally remove the types on the schema we are certain aren't referenced by any other node.
return transformSchema(connectedSchema, new GraphQLTypeVisitorStub() {
@Override
protected TraversalControl visitGraphQLType(GraphQLSchemaElement node, TraverserContext context) {
if (node instanceof GraphQLType && markedForRemovalTypes.contains(node)) {
return deleteNode(context);
}
return super.visitGraphQLType(node, context);
}
});
}
private static class TypeObservingVisitor extends GraphQLTypeVisitorStub {
private final Set observedTypes;
private TypeObservingVisitor(Set observedTypes) {
this.observedTypes = observedTypes;
}
@Override
protected TraversalControl visitGraphQLType(GraphQLSchemaElement node,
TraverserContext context) {
if (node instanceof GraphQLType) {
observedTypes.add((GraphQLType) node);
}
return TraversalControl.CONTINUE;
}
}
private static class FieldRemovalVisitor extends GraphQLTypeVisitorStub {
private final VisibleFieldPredicate visibilityPredicate;
private final Set removedTypes;
private FieldRemovalVisitor(VisibleFieldPredicate visibilityPredicate,
Set removedTypes) {
this.visibilityPredicate = visibilityPredicate;
this.removedTypes = removedTypes;
}
@Override
public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition definition,
TraverserContext context) {
return visitField(definition, context);
}
@Override
public TraversalControl visitGraphQLInputObjectField(GraphQLInputObjectField definition,
TraverserContext context) {
return visitField(definition, context);
}
private TraversalControl visitField(GraphQLNamedSchemaElement element,
TraverserContext context) {
VisibleFieldPredicateEnvironment environment = new VisibleFieldPredicateEnvironmentImpl(
element, context.getParentNode());
if (!visibilityPredicate.isVisible(environment)) {
deleteNode(context);
if (element instanceof GraphQLFieldDefinition) {
removedTypes.add(((GraphQLFieldDefinition) element).getType());
} else if (element instanceof GraphQLInputObjectField) {
removedTypes.add(((GraphQLInputObjectField) element).getType());
}
}
return TraversalControl.CONTINUE;
}
}
private static class TypeVisibilityVisitor extends GraphQLTypeVisitorStub {
private final Set protectedTypeNames;
private final Set observedBeforeTransform;
private final Set observedAfterTransform;
private TypeVisibilityVisitor(Set protectedTypeNames,
Set observedTypes,
Set observedAfterTransform) {
this.protectedTypeNames = protectedTypeNames;
this.observedBeforeTransform = observedTypes;
this.observedAfterTransform = observedAfterTransform;
}
@Override
public TraversalControl visitGraphQLInterfaceType(GraphQLInterfaceType node,
TraverserContext context) {
return super.visitGraphQLInterfaceType(node, context);
}
@Override
public TraversalControl visitGraphQLType(GraphQLSchemaElement node,
TraverserContext context) {
if (observedBeforeTransform.contains(node) &&
!observedAfterTransform.contains(node) &&
(node instanceof GraphQLObjectType ||
node instanceof GraphQLEnumType ||
node instanceof GraphQLInputObjectType ||
node instanceof GraphQLInterfaceType ||
node instanceof GraphQLUnionType)) {
return deleteNode(context);
}
return TraversalControl.CONTINUE;
}
}
private static class AdditionalTypeVisibilityVisitor extends GraphQLTypeVisitorStub {
private final Set markedForRemovalTypes;
private AdditionalTypeVisibilityVisitor(Set markedForRemovalTypes) {
this.markedForRemovalTypes = markedForRemovalTypes;
}
@Override
public TraversalControl visitGraphQLType(GraphQLSchemaElement node,
TraverserContext context) {
if (node instanceof GraphQLNamedType) {
GraphQLNamedType namedType = (GraphQLNamedType) node;
// we encountered a node referencing one of the marked types, so it should not be removed.
if (markedForRemovalTypes.contains(node)) {
markedForRemovalTypes.remove(namedType);
}
}
return TraversalControl.CONTINUE;
}
}
private List getRootTypes(GraphQLSchema schema) {
return ImmutableList.builder()
.addAll(getOperationTypes(schema))
// Include directive definitions as roots, since they won't be removed in the filtering process.
// Some types (enums, input types, etc.) might be reachable only by directive definitions (and
// not by other types or fields).
.addAll(schema.getDirectives())
.build();
}
private List getOperationTypes(GraphQLSchema schema) {
return Stream.of(
schema.getQueryType(),
schema.getSubscriptionType(),
schema.getMutationType()
).filter(Objects::nonNull).collect(Collectors.toList());
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy