
// io.vertx.ext.web.handler.graphql.impl.GraphQLHandlerImpl (Maven / Gradle / Ivy)
/*
* Copyright 2021 Red Hat, Inc.
*
* Red Hat licenses this file to you under the Apache License, version 2.0
* (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.vertx.ext.web.handler.graphql.impl;
import graphql.ExecutionInput;
import graphql.GraphQL;
import graphql.execution.preparsed.persisted.PersistedQuerySupport;
import io.vertx.core.AsyncResult;
import io.vertx.core.CompositeFuture;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.impl.NoStackTraceThrowable;
import io.vertx.core.json.Json;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.FileUpload;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.graphql.ExecutionInputBuilderWithContext;
import io.vertx.ext.web.handler.graphql.GraphQLHandler;
import io.vertx.ext.web.handler.graphql.GraphQLHandlerOptions;
import org.dataloader.DataLoaderRegistry;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.regex.Pattern;
import static io.vertx.core.http.HttpMethod.*;
/**
* @author Thomas Segismont
*/
public class GraphQLHandlerImpl implements GraphQLHandler {
private static final Pattern IS_NUMBER = Pattern.compile("\\d+");
private static final Function DEFAULT_QUERY_CONTEXT_FACTORY = rc -> rc;
private static final Function DEFAULT_DATA_LOADER_REGISTRY_FACTORY = rc -> null;
private static final Function DEFAULT_LOCALE_FACTORY = rc -> null;
private final GraphQL graphQL;
private final GraphQLHandlerOptions options;
private Function queryContextFactory = DEFAULT_QUERY_CONTEXT_FACTORY;
private Function dataLoaderRegistryFactory = DEFAULT_DATA_LOADER_REGISTRY_FACTORY;
private Function localeFactory = DEFAULT_LOCALE_FACTORY;
private Handler> beforeExecute;
/**
 * Creates a handler bound to the given GraphQL engine and handler options.
 *
 * @param graphQL the GraphQL instance used to execute queries, must not be {@code null}
 * @param options the handler options, must not be {@code null}
 * @throws NullPointerException if either argument is {@code null}
 */
public GraphQLHandlerImpl(GraphQL graphQL, GraphQLHandlerOptions options) {
  this.graphQL = Objects.requireNonNull(graphQL, "graphQL");
  this.options = Objects.requireNonNull(options, "options");
}
/**
 * Sets the factory that produces the query context for each request.
 *
 * @param factory the factory to use; {@code null} restores the default, which returns the routing context itself
 * @return this handler
 */
@Override
public synchronized GraphQLHandler queryContext(Function<RoutingContext, Object> factory) {
  queryContextFactory = factory != null ? factory : DEFAULT_QUERY_CONTEXT_FACTORY;
  return this;
}
/**
 * Sets the factory that produces the data loader registry for each request.
 *
 * @param factory the factory to use; {@code null} restores the default, which supplies no registry
 * @return this handler
 */
@Override
public synchronized GraphQLHandler dataLoaderRegistry(Function<RoutingContext, DataLoaderRegistry> factory) {
  dataLoaderRegistryFactory = factory != null ? factory : DEFAULT_DATA_LOADER_REGISTRY_FACTORY;
  return this;
}
/**
 * Sets the factory that produces the locale for each request.
 *
 * @param factory the factory to use; {@code null} restores the default, which supplies no locale
 * @return this handler
 */
@Override
public synchronized GraphQLHandler locale(Function<RoutingContext, Locale> factory) {
  localeFactory = factory != null ? factory : DEFAULT_LOCALE_FACTORY;
  return this;
}
@Override
public synchronized GraphQLHandler beforeExecute(Handler> beforeExecute) {
this.beforeExecute = beforeExecute;
return this;
}
/**
 * Entry point: dispatches the request according to its HTTP method.
 * GET requests are read from query parameters, POST requests from the already-buffered body,
 * and any other method fails with status 405.
 *
 * @param rc the routing context of the current request
 */
@Override
public void handle(RoutingContext rc) {
  HttpMethod method = rc.request().method();
  if (method == GET) {
    handleGet(rc);
    return;
  }
  if (method != POST) {
    rc.fail(405);
    return;
  }
  if (rc.body().available()) {
    handlePost(rc, rc.body().buffer());
    return;
  }
  // the body handler was not set, so we cannot securely process POST bodies
  // we could just add an ad-hoc body handler but this can lead to DDoS attacks
  // and it doesn't really cover all the uploads, such as multipart, etc...
  // as well as resource cleanup
  rc.fail(500, new NoStackTraceThrowable("BodyHandler is required to process POST requests"));
}
/**
 * Handles a GET request: the query, operation name, variables, initial value and extensions
 * are all read from query parameters.
 * <p>
 * A parameter that cannot be parsed fails the request with status 400. When the {@code query}
 * parameter is absent, the request is only accepted if the extensions contain a
 * {@code persistedQuery} entry, in which case the persisted query marker is used as the query.
 *
 * @param rc the routing context of the current request
 */
private void handleGet(RoutingContext rc) {
  Map<String, Object> variables;
  Object initialValue;
  Map<String, Object> extensions;
  try {
    // Parsed in the same order as before; the first failure aborts the request.
    variables = getVariablesFromQueryParam(rc);
    initialValue = getInitialValueFromQueryParam(rc);
    extensions = getExtensionsFromQueryParam(rc);
  } catch (Exception e) {
    rc.fail(400, e);
    return;
  }

  String query = rc.queryParams().get("query");
  if (query == null) {
    if (extensions == null || !extensions.containsKey("persistedQuery")) {
      failQueryMissing(rc);
      return;
    }
    query = PersistedQuerySupport.PERSISTED_QUERY_MARKER;
  }

  GraphQLQuery graphQLQuery = new GraphQLQuery()
    .setQuery(query)
    .setOperationName(rc.queryParams().get("operationName"))
    .setVariables(variables)
    .setInitialValue(initialValue)
    .setExtensions(extensions);
  executeOne(rc, graphQLQuery);
}
/**
 * Handles a POST request.
 * <p>
 * Variables, initial value and extensions are first read from query parameters; a parameter that
 * cannot be parsed fails the request with status 400. If a {@code query} query parameter is present
 * it is executed directly and the body is not inspected. Otherwise the body is interpreted according
 * to the request content type: {@code application/json}, {@code multipart/form-data} or
 * {@code application/graphql}. Any other content type fails with status 415.
 *
 * @param rc   the routing context of the current request
 * @param body the buffered request body, possibly {@code null}
 */
private void handlePost(RoutingContext rc, Buffer body) {
  Map<String, Object> variables;
  Object initialValue;
  Map<String, Object> extensions;
  try {
    // Parsed in the same order as before; the first failure aborts the request.
    variables = getVariablesFromQueryParam(rc);
    initialValue = getInitialValueFromQueryParam(rc);
    extensions = getExtensionsFromQueryParam(rc);
  } catch (Exception e) {
    rc.fail(400, e);
    return;
  }

  String query = rc.queryParams().get("query");
  String operationName = rc.queryParams().get("operationName");
  if (query != null) {
    // A query given as query parameter takes precedence over anything in the body.
    executeOne(rc, new GraphQLQuery()
      .setQuery(query)
      .setOperationName(operationName)
      .setVariables(variables)
      .setInitialValue(initialValue)
      .setExtensions(extensions));
    return;
  }

  switch (getContentType(rc)) {
    case "application/json":
      if (body == null) {
        // plain failure as the json body is missing
        rc.fail(400, new NoStackTraceThrowable("No body"));
        return;
      }
      handlePostJson(rc, body, operationName, variables, initialValue, extensions);
      break;
    case "multipart/form-data":
      handlePostMultipart(rc, operationName, variables, initialValue, extensions);
      break;
    case "application/graphql":
      if (body == null) {
        // plain failure as the query is missing
        rc.fail(400, new NoStackTraceThrowable("No body"));
        return;
      }
      // The whole body is the query text.
      executeOne(rc, new GraphQLQuery()
        .setQuery(body.toString())
        .setOperationName(operationName)
        .setVariables(variables)
        .setInitialValue(initialValue)
        .setExtensions(extensions));
      break;
    default:
      rc.fail(415);
  }
}
/**
 * Decodes a JSON request body and dispatches it as either a batch or a single query.
 * A body that cannot be decoded fails the request with status 400; a decoded input that is
 * neither a batch nor a single query fails with status 500.
 *
 * @param rc            the routing context of the current request
 * @param body          the JSON request body
 * @param operationName operation name from the query parameters, or {@code null}
 * @param variables     variables from the query parameters, or {@code null}
 * @param initialValue  initial value from the query parameters, or {@code null}
 * @param extensions    extensions from the query parameters, or {@code null}
 */
private void handlePostJson(RoutingContext rc, Buffer body, String operationName, Map<String, Object> variables, Object initialValue, Map<String, Object> extensions) {
  GraphQLInput graphQLInput;
  try {
    graphQLInput = GraphQLInput.decode(body);
  } catch (Exception e) {
    rc.fail(400, e);
    return;
  }
  if (graphQLInput instanceof GraphQLBatch) {
    handlePostBatch(rc, (GraphQLBatch) graphQLInput, operationName, variables, initialValue, extensions);
  } else if (graphQLInput instanceof GraphQLQuery) {
    handlePostQuery(rc, (GraphQLQuery) graphQLInput, operationName, variables, initialValue, extensions);
  } else {
    rc.fail(500);
  }
}
/**
 * Prepares and executes a batch of queries received in a POST request.
 * <p>
 * Fails with status 400 when request batching is disabled in the options. Every non-null value
 * taken from the query parameters overrides the corresponding value of each query in the batch.
 * A query without text is accepted only if its extensions contain a {@code persistedQuery} entry,
 * in which case the persisted query marker is used; otherwise the whole request fails.
 *
 * @param rc            the routing context of the current request
 * @param batch         the decoded batch of queries
 * @param operationName operation name override, or {@code null}
 * @param variables     variables override, or {@code null}
 * @param initialValue  initial value override, or {@code null}
 * @param extensions    extensions override, or {@code null}
 */
private void handlePostBatch(RoutingContext rc, GraphQLBatch batch, String operationName, Map<String, Object> variables, Object initialValue, Map<String, Object> extensions) {
  if (!options.isRequestBatchingEnabled()) {
    rc.fail(400);
    return;
  }
  for (GraphQLQuery query : batch) {
    if (operationName != null) {
      query.setOperationName(operationName);
    }
    if (variables != null) {
      query.setVariables(variables);
    }
    if (initialValue != null) {
      query.setInitialValue(initialValue);
    }
    if (extensions != null) {
      query.setExtensions(extensions);
    }
    if (query.getQuery() == null) {
      // Checked after the overrides so that extensions from the query parameters are taken into account.
      Map<String, Object> exts = query.getExtensions();
      if (exts != null && exts.containsKey("persistedQuery")) {
        query.setQuery(PersistedQuerySupport.PERSISTED_QUERY_MARKER);
      } else {
        failQueryMissing(rc);
        return;
      }
    }
  }
  executeBatch(rc, batch);
}
/**
 * Executes every query of the batch and sends the collected results as a single JSON array.
 *
 * @param rc    the routing context of the current request
 * @param batch the queries to execute
 */
@SuppressWarnings("rawtypes")
private void executeBatch(RoutingContext rc, GraphQLBatch batch) {
  // The element type stays the raw Future on purpose: that is what the list is handed over as below.
  List<Future> futures = new ArrayList<>(batch.size());
  for (GraphQLQuery graphQLQuery : batch) {
    futures.add(execute(rc, graphQLQuery));
  }
  CompositeFuture.all(futures)
    .map(cf -> new JsonArray(cf.list()).toBuffer())
    .onComplete(ar -> sendResponse(rc, ar));
}
/**
 * Prepares and executes a single query received in a POST request.
 * <p>
 * Every non-null value taken from the query parameters overrides the corresponding value of the
 * query. A query without text is accepted only if its extensions contain a {@code persistedQuery}
 * entry, in which case the persisted query marker is used; otherwise the request fails.
 *
 * @param rc            the routing context of the current request
 * @param query         the decoded query
 * @param operationName operation name override, or {@code null}
 * @param variables     variables override, or {@code null}
 * @param initialValue  initial value override, or {@code null}
 * @param extensions    extensions override, or {@code null}
 */
private void handlePostQuery(RoutingContext rc, GraphQLQuery query, String operationName, Map<String, Object> variables, Object initialValue, Map<String, Object> extensions) {
  if (operationName != null) {
    query.setOperationName(operationName);
  }
  if (variables != null) {
    query.setVariables(variables);
  }
  if (initialValue != null) {
    query.setInitialValue(initialValue);
  }
  if (extensions != null) {
    query.setExtensions(extensions);
  }
  if (query.getQuery() == null) {
    // Checked after the overrides so that extensions from the query parameters are taken into account.
    Map<String, Object> exts = query.getExtensions();
    if (exts != null && exts.containsKey("persistedQuery")) {
      query.setQuery(PersistedQuerySupport.PERSISTED_QUERY_MARKER);
    } else {
      failQueryMissing(rc);
      return;
    }
  }
  executeOne(rc, query);
}
/**
 * Handles a {@code multipart/form-data} POST request.
 * <p>
 * An "operations object" is an Apollo GraphQL POST request (or array of requests if batching).
 * An "operations path" is an object-path string to locate a file within an operations object.
 * <p>
 * So operations can be resolved while the files are still uploading, the fields are ordered:
 * <ol>
 * <li>operations: A JSON encoded operations object with files replaced with null.</li>
 * <li>map: A JSON encoded map of where files occurred in the operations. For each file, the key is
 * the file multipart form field name and the value is an array of operations paths.</li>
 * <li>File fields: Each file extracted from the operations object with a unique, arbitrary field name.</li>
 * </ol>
 * The request fails with status 415 when multipart support is disabled in the options, with
 * status 400 when the form attributes cannot be parsed, and with status 500 when the parsed
 * input is neither a batch nor a single query.
 *
 * @param rc            the routing context of the current request
 * @param operationName operation name from the query parameters, or {@code null}
 * @param variables     variables from the query parameters, or {@code null}
 * @param initialValue  initial value from the query parameters, or {@code null}
 * @param extensions    extensions from the query parameters, or {@code null}
 * @see <a href="https://github.com/jaydenseric/graphql-multipart-request-spec">GraphQL multipart request specification</a>
 **/
private void handlePostMultipart(RoutingContext rc, String operationName, Map<String, Object> variables, Object initialValue, Map<String, Object> extensions) {
  if (!options.isRequestMultipartEnabled()) {
    rc.fail(415);
    return;
  }
  GraphQLInput graphQLInput;
  try {
    graphQLInput = parseMultipartAttributes(rc);
  } catch (Exception e) {
    rc.fail(400, e);
    return;
  }
  if (graphQLInput instanceof GraphQLBatch) {
    handlePostBatch(rc, (GraphQLBatch) graphQLInput, operationName, variables, initialValue, extensions);
  } else if (graphQLInput instanceof GraphQLQuery) {
    handlePostQuery(rc, (GraphQLQuery) graphQLInput, operationName, variables, initialValue, extensions);
  } else {
    rc.fail(500);
  }
}
private GraphQLInput parseMultipartAttributes(RoutingContext rc) {
MultiMap attrs = rc.request().formAttributes();
@SuppressWarnings("unchecked")
Map filesMap = (Map) Json.decodeValue(attrs.get("map"), Map.class);
GraphQLInput graphQLInput = GraphQLInput.decode(Json.decodeValue(attrs.get("operations")));
Map> variablesMap = new HashMap<>();
Iterable batch = (graphQLInput instanceof GraphQLBatch)
? (GraphQLBatch) graphQLInput
: Collections.singletonList((GraphQLQuery) graphQLInput);
int i = 0;
Iterator iterator = batch.iterator();
for (; iterator.hasNext(); i++) {
GraphQLQuery query = iterator.next();
Map variables = new HashMap<>();
variables.put("variables", query.getVariables());
variablesMap.put(String.valueOf(i), variables);
}
for (Map.Entry entry : filesMap.entrySet()) {
for (Object fullPath : (List) entry.getValue()) {
String[] path = ((String) fullPath).split("\\.");
int end = path.length;
int idx = -1;
if (IS_NUMBER.matcher(path[end - 1]).matches()) {
idx = Integer.parseInt(path[end - 1]);
--end;
}
Map, ?> variables;
int start = 0;
if (IS_NUMBER.matcher(path[0]).matches()) {
variables = variablesMap.get(path[0]);
++start;
} else {
variables = variablesMap.get("0");
}
String attr = path[--end];
Map obj = variables;
for (; start < end; ++start) {
String token = path[start];
obj = (Map) obj.get(token);
}
FileUpload file = rc.fileUploads().stream()
.filter(f -> f.name().equals(entry.getKey())).findFirst().orElse(null);
if (file != null) {
if (idx == -1) {
obj.put(attr, file);
} else {
((List) obj.get(attr)).set(idx, file);
}
}
}
}
return graphQLInput;
}
/**
 * Runs a single query and writes its result, serialized to a buffer, as the response.
 *
 * @param rc    the routing context of the current request
 * @param query the query to run
 */
private void executeOne(RoutingContext rc, GraphQLQuery query) {
  execute(rc, query)
    .map(JsonObject::toBuffer)
    .onComplete(outcome -> {
      sendResponse(rc, outcome);
    });
}
private Future execute(RoutingContext rc, GraphQLQuery query) {
ExecutionInput.Builder builder = ExecutionInput.newExecutionInput();
builder.query(query.getQuery());
String operationName = query.getOperationName();
if (operationName != null) {
builder.operationName(operationName);
}
Map variables = query.getVariables();
if (variables != null) {
builder.variables(variables);
}
Object initialValue = query.getInitialValue();
if (initialValue != null) {
builder.root(initialValue);
}
Map extensions = query.getExtensions();
if (extensions != null) {
builder.extensions(extensions);
}
Function qc;
Function dlr;
Function l;
Handler> be;
synchronized (this) {
qc = queryContextFactory;
dlr = dataLoaderRegistryFactory;
l = localeFactory;
be = beforeExecute;
}
builder.context(qc.apply(rc));
DataLoaderRegistry registry = dlr.apply(rc);
if (registry != null) {
builder.dataLoaderRegistry(registry);
}
Locale locale = l.apply(rc);
if (locale != null) {
builder.locale(locale);
}
builder.graphQLContext(Collections.singletonMap(RoutingContext.class, rc));
if (be != null) {
be.handle(new ExecutionInputBuilderWithContext() {
@Override
public RoutingContext context() {
return rc;
}
@Override
public ExecutionInput.Builder builder() {
return builder;
}
});
}
return Future.fromCompletionStage(graphQL.executeAsync(builder.build()), rc.vertx().getOrCreateContext())
.map(executionResult -> new JsonObject(executionResult.toSpecification()));
}
/**
 * Returns the lower-cased content type of the request, defaulting to {@code application/json}
 * when the parsed content type value is empty.
 *
 * @param rc the routing context of the current request
 * @return the content type to dispatch on
 */
private String getContentType(RoutingContext rc) {
  String contentType = rc.parsedHeaders().contentType().value();
  // Locale.ROOT keeps the result independent of the JVM default locale, so that e.g. an upper-case
  // "I" is always mapped to "i" and the value can be compared against fixed lower-case media types.
  return contentType.isEmpty() ? "application/json" : contentType.toLowerCase(Locale.ROOT);
}
/**
 * Reads the {@code variables} query parameter and parses it as a JSON object.
 *
 * @param rc the routing context of the current request
 * @return the variables as a map, or {@code null} if the parameter is absent
 * @throws Exception if the parameter value cannot be parsed as a JSON object
 */
private Map<String, Object> getVariablesFromQueryParam(RoutingContext rc) throws Exception {
  String variablesParam = rc.queryParams().get("variables");
  if (variablesParam == null) {
    return null;
  }
  return new JsonObject(variablesParam).getMap();
}
/**
 * Reads the {@code initialValue} query parameter and decodes it as a JSON value.
 *
 * @param rc the routing context of the current request
 * @return the decoded value, or {@code null} if the parameter is absent or empty
 * @throws Exception if the parameter value cannot be decoded
 */
private Object getInitialValueFromQueryParam(RoutingContext rc) throws Exception {
  String raw = rc.queryParams().get("initialValue");
  boolean present = raw != null && !raw.isEmpty();
  return present ? Json.decodeValue(raw) : null;
}
/**
 * Reads the {@code extensions} query parameter and parses it as a JSON object.
 *
 * @param rc the routing context of the current request
 * @return the extensions as a map, or {@code null} if the parameter is absent
 * @throws Exception if the parameter value cannot be parsed as a JSON object
 */
private Map<String, Object> getExtensionsFromQueryParam(RoutingContext rc) throws Exception {
  String extensionsParam = rc.queryParams().get("extensions");
  if (extensionsParam == null) {
    return null;
  }
  return new JsonObject(extensionsParam).getMap();
}
/**
 * Completes the request from the outcome of an execution: on success the buffer is written as an
 * {@code application/json} response, on failure the routing context is failed with the cause.
 *
 * @param rc the routing context of the current request
 * @param ar the outcome carrying the serialized response body
 */
private void sendResponse(RoutingContext rc, AsyncResult<Buffer> ar) {
  if (ar.succeeded()) {
    rc.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json").end(ar.result());
  } else {
    rc.fail(ar.cause());
  }
}
/**
 * Fails the request with status 400 because no query text was provided.
 *
 * @param rc the routing context of the current request
 */
private void failQueryMissing(RoutingContext rc) {
  NoStackTraceThrowable cause = new NoStackTraceThrowable("Query is missing");
  rc.fail(400, cause);
}
}