graphql.servlet.GraphQLServlet Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of graphql-java-servlet Show documentation
relay.js-compatible GraphQL servlet
package graphql.servlet;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.InjectableValues;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectReader;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import graphql.ExecutionInput;
import graphql.ExecutionResult;
import graphql.GraphQL;
import graphql.GraphQLError;
import graphql.execution.instrumentation.Instrumentation;
import graphql.introspection.IntrospectionQuery;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLSchema;
import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileItemFactory;
import org.apache.commons.fileupload.disk.DiskFileItemFactory;
import org.apache.commons.fileupload.servlet.ServletFileUpload;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.security.auth.Subject;
import javax.servlet.Servlet;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* @author Andrew Potter
*/
public abstract class GraphQLServlet extends HttpServlet implements Servlet, GraphQLMBean {
/** Logger used by the servlet, its request handlers and listener/callback error reporting. */
public static final Logger log = LoggerFactory.getLogger(GraphQLServlet.class);
/** Content type set on every JSON response this servlet writes. */
public static final String APPLICATION_JSON_UTF8 = "application/json;charset=UTF-8";
/** Status set on every executed query or batch; GraphQL errors are reported in the response body, not via the status. */
public static final int STATUS_OK = 200;
/**
 * Status set for unparseable POST bodies, multipart requests without a usable "graphql"/"query" part,
 * and GET requests that are neither "/schema.json" nor carry a "query" parameter.
 */
public static final int STATUS_BAD_REQUEST = 400;
/** Supplies schemas: GET queries use getReadOnlySchema(request); POST and "/schema.json" use getSchema(request). */
protected abstract GraphQLSchemaProvider getSchemaProvider();
/** Builds the per-request context; the GET/POST handlers pass the HTTP objects, executeQuery passes two empty Optionals. */
protected abstract GraphQLContext createContext(Optional request, Optional response);
/** Builds the root object for execution; called with the same arguments as createContext. */
protected abstract Object createRootObject(Optional request, Optional response);
/** Supplies the query, mutation and subscription execution strategies installed by newGraphQL. */
protected abstract ExecutionStrategyProvider getExecutionStrategyProvider();
/** Instrumentation installed on every GraphQL instance built by newGraphQL. */
protected abstract Instrumentation getInstrumentation();
/**
 * Hook applied to the variables just before each execution in query().
 * Note: {@code variables} may be null (e.g. a multipart POST without a "variables" part).
 */
protected abstract Map transformVariables(GraphQLSchema schema, String query, Map variables);
/** Decides whether errors are present (also choosing onError vs onSuccess callbacks) and processes them for the "errors" entry. */
protected abstract GraphQLErrorHandler getGraphQLErrorHandler();
/** Source of the ObjectMapper returned by getMapper(), built from the configurer given to the constructor. */
private final LazyObjectMapperBuilder lazyObjectMapperBuilder;
/** Listeners notified per HTTP request and per operation; mutated by addListener/removeListener. */
private final List listeners;
/** Parses multipart POST bodies into named parts. */
private final ServletFileUpload fileUpload;
/** Handler for GET requests, built in the constructor and dispatched by doGet. */
private final HttpRequestHandler getHandler;
/** Handler for POST requests, built in the constructor and dispatched by doPost. */
private final HttpRequestHandler postHandler;
/**
 * Creates a servlet that uses the default object-mapper configurer, starts with no listeners
 * and stores multipart parts with a {@link DiskFileItemFactory}.
 */
public GraphQLServlet() {
    this((ObjectMapperConfigurer) null, (List) null, (FileItemFactory) null);
}
/**
 * Creates the servlet and builds its GET and POST request handlers.
 *
 * @param objectMapperConfigurer customizes the Jackson mapper; {@code null} selects {@code DefaultObjectMapperConfigurer}
 * @param listeners initial listeners, copied; {@code null} means none
 * @param fileItemFactory storage for multipart parts; {@code null} selects a {@link DiskFileItemFactory}
 */
public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List listeners, FileItemFactory fileItemFactory) {
    this.lazyObjectMapperBuilder = new LazyObjectMapperBuilder(objectMapperConfigurer != null ? objectMapperConfigurer : new DefaultObjectMapperConfigurer());
    // Copy-on-write: addListener/removeListener are public and may run while request threads
    // iterate this list in runListeners; an ArrayList could throw ConcurrentModificationException.
    this.listeners = listeners != null ? new CopyOnWriteArrayList<>(listeners) : new CopyOnWriteArrayList<>();
    this.fileUpload = new ServletFileUpload(fileItemFactory != null ? fileItemFactory : new DiskFileItemFactory());

    // GET: "/schema.json" returns the introspection result; otherwise the "query", "variables" and
    // "operationName" parameters are executed against getReadOnlySchema(request).
    this.getHandler = (request, response) -> {
        final GraphQLContext context = createContext(Optional.of(request), Optional.of(response));
        final Object rootObject = createRootObject(Optional.of(request), Optional.of(response));
        String path = request.getPathInfo();
        if (path == null) {
            path = request.getServletPath();
        }
        if (path.contentEquals("/schema.json")) {
            doQuery(IntrospectionQuery.INTROSPECTION_QUERY, null, new HashMap<>(), getSchemaProvider().getSchema(request), context, rootObject, request, response);
        } else {
            String query = request.getParameter("query");
            if (query != null) {
                if (isBatchedQuery(query)) {
                    doBatchedQuery(getGraphQLRequestMapper().readValues(query), getSchemaProvider().getReadOnlySchema(request), context, rootObject, request, response);
                } else {
                    final Map variables = new HashMap<>();
                    if (request.getParameter("variables") != null) {
                        variables.putAll(deserializeVariables(request.getParameter("variables")));
                    }
                    String operationName = null;
                    if (request.getParameter("operationName") != null) {
                        operationName = request.getParameter("operationName");
                    }
                    doQuery(query, operationName, variables, getSchemaProvider().getReadOnlySchema(request), context, rootObject, request, response);
                }
            } else {
                response.setStatus(STATUS_BAD_REQUEST);
                log.info("Bad GET request: path was not \"/schema.json\" or no query variable named \"query\" given");
            }
        }
    };

    // POST: a multipart body carries either a "graphql" part (a full JSON request, possibly batched)
    // or a "query" part with optional "variables"/"operationName" parts; any other body is parsed as
    // JSON. Any exception during parsing or execution is logged and answered with 400.
    this.postHandler = (request, response) -> {
        final GraphQLContext context = createContext(Optional.of(request), Optional.of(response));
        final Object rootObject = createRootObject(Optional.of(request), Optional.of(response));
        try {
            if (ServletFileUpload.isMultipartContent(request)) {
                final Map<String, List<FileItem>> fileItems = fileUpload.parseParameterMap(request);
                context.setFiles(Optional.of(fileItems));
                if (fileItems.containsKey("graphql")) {
                    final Optional<FileItem> graphqlItem = getFileItem(fileItems, "graphql");
                    if (graphqlItem.isPresent()) {
                        // Close the part stream when done: disk-backed parts hold an open file.
                        try (InputStream inputStream = ensureMarkSupported(graphqlItem.get().getInputStream())) {
                            if (isBatchedQuery(inputStream)) {
                                doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
                            } else {
                                doQuery(getGraphQLRequestMapper().readValue(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
                            }
                        }
                        return;
                    }
                } else if (fileItems.containsKey("query")) {
                    final Optional<FileItem> queryItem = getFileItem(fileItems, "query");
                    if (queryItem.isPresent()) {
                        try (InputStream inputStream = ensureMarkSupported(queryItem.get().getInputStream())) {
                            if (isBatchedQuery(inputStream)) {
                                doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
                                return;
                            }
                        }
                        final String query = readUtf8(queryItem.get());
                        Map variables = null;
                        final Optional<FileItem> variablesItem = getFileItem(fileItems, "variables");
                        if (variablesItem.isPresent()) {
                            variables = deserializeVariables(readUtf8(variablesItem.get()));
                        }
                        String operationName = null;
                        final Optional<FileItem> operationNameItem = getFileItem(fileItems, "operationName");
                        if (operationNameItem.isPresent()) {
                            operationName = readUtf8(operationNameItem.get()).trim();
                        }
                        doQuery(query, operationName, variables, getSchemaProvider().getSchema(request), context, rootObject, request, response);
                        return;
                    }
                }
                response.setStatus(STATUS_BAD_REQUEST);
                log.info("Bad POST multipart request: no part named \"graphql\" or \"query\"");
            } else {
                // Not multipart: the body itself is a JSON request (or a batch of them).
                final InputStream inputStream = ensureMarkSupported(request.getInputStream());
                if (isBatchedQuery(inputStream)) {
                    doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
                } else {
                    doQuery(getGraphQLRequestMapper().readValue(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
                }
            }
        } catch (Exception e) {
            log.info("Bad POST request: parsing failed", e);
            response.setStatus(STATUS_BAD_REQUEST);
        }
    };
}

/**
 * Returns {@code in} when it supports mark/reset, otherwise wraps it in a {@link BufferedInputStream}.
 * The stream is first inspected by isBatchedQuery and then handed to Jackson; isBatchedQuery is
 * defined outside this view and presumably peeks via mark/reset.
 */
private static InputStream ensureMarkSupported(InputStream in) {
    return in.markSupported() ? in : new BufferedInputStream(in);
}

/**
 * Decodes a multipart part's content as UTF-8 (the encoding JSON and the responses here use)
 * instead of the JVM's platform-default charset.
 */
private static String readUtf8(FileItem item) {
    return new String(item.get(), StandardCharsets.UTF_8);
}
/**
 * Returns the Jackson mapper used for parsing requests and serializing results.
 */
protected ObjectMapper getMapper() {
    final ObjectMapper mapper = this.lazyObjectMapperBuilder.getMapper();
    return mapper;
}
/**
 * Builds an {@link ObjectReader} bound to {@link GraphQLRequest}.
 * The servlet's ObjectMapper is registered as an injectable value so that
 * VariablesDeserializer can reach it during deserialization.
 */
private ObjectReader getGraphQLRequestMapper() {
    final InjectableValues injected = new InjectableValues.Std().addValue(ObjectMapper.class, getMapper());
    return getMapper()
            .reader(injected)
            .forType(GraphQLRequest.class);
}
/**
 * Registers a listener that will be notified about subsequent requests and operations.
 *
 * @param servletListener the listener to add
 */
public void addListener(GraphQLServletListener servletListener) {
    this.listeners.add(servletListener);
}
/**
 * Unregisters a previously added listener; does nothing if it is not registered.
 *
 * @param servletListener the listener to remove
 */
public void removeListener(GraphQLServletListener servletListener) {
    this.listeners.remove(servletListener);
}
/**
 * Lists the names of the fields on the schema's query type, in definition order.
 */
@Override
public String[] getQueries() {
    final List<GraphQLFieldDefinition> fields = getSchemaProvider().getSchema().getQueryType().getFieldDefinitions();
    final String[] names = new String[fields.size()];
    for (int i = 0; i < names.length; i++) {
        names[i] = fields.get(i).getName();
    }
    return names;
}
/**
 * Lists the names of the fields on the schema's mutation type, in definition order.
 *
 * @return the mutation field names, or an empty array when the schema defines no mutation type
 */
@Override
public String[] getMutations() {
    final GraphQLSchema schema = getSchemaProvider().getSchema();
    // getMutationType() is null for query-only schemas; chaining through it would throw NPE.
    if (schema.getMutationType() == null) {
        return new String[0];
    }
    return schema.getMutationType().getFieldDefinitions().stream().map(GraphQLFieldDefinition::getName).toArray(String[]::new);
}
/**
 * Executes {@code query} outside any HTTP request (context and root object are created from
 * empty Optionals) and returns the JSON-serialized result.
 * <p>
 * On failure the exception is logged with its stack trace and the caller receives its message,
 * or its {@code toString()} when the exception carries no message.
 *
 * @param query the GraphQL query text
 * @return the serialized result, or a description of the failure
 */
@Override
public String executeQuery(String query) {
    try {
        final ExecutionResult result = newGraphQL(getSchemaProvider().getSchema()).execute(new ExecutionInput(query, null, createContext(Optional.empty(), Optional.empty()), createRootObject(Optional.empty(), Optional.empty()), new HashMap<>()));
        return getMapper().writeValueAsString(createResultFromDataAndErrors(result.getData(), result.getErrors()));
    } catch (Exception e) {
        // The caller only gets a string back, so keep the stack trace in the server log.
        log.error("Error executing GraphQL query via executeQuery", e);
        return e.getMessage() != null ? e.getMessage() : e.toString();
    }
}
/**
 * Runs {@code handler} for one HTTP request, surrounding it with the listeners' request callbacks.
 * Any throwable from the handler is logged and turns into status 500; onFinally always runs.
 */
private void doRequest(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler) {
    // Each listener may return a per-request callback (null results are dropped by runListeners).
    final List callbacks = runListeners(listener -> listener.onRequest(request, response));
    try {
        handler.handle(request, response);
        runCallbacks(callbacks, callback -> callback.onSuccess(request, response));
    } catch (Throwable failure) {
        response.setStatus(500);
        log.error("Error executing GraphQL request!", failure);
        runCallbacks(callbacks, callback -> callback.onError(request, response, failure));
    } finally {
        runCallbacks(callbacks, callback -> callback.onFinally(request, response));
    }
}
/**
 * Dispatches a GET request to the GET handler built in the constructor.
 */
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
    final HttpRequestHandler handler = this.getHandler;
    doRequest(req, resp, handler);
}
/**
 * Dispatches a POST request to the POST handler built in the constructor.
 */
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
    final HttpRequestHandler handler = this.postHandler;
    doRequest(req, resp, handler);
}
private Optional getFileItem(Map> fileItems, String name) {
List items = fileItems.get(name);
if(items == null || items.isEmpty()) {
return Optional.empty();
}
return items.stream().findFirst();
}
/**
 * Builds a GraphQL engine for {@code schema} with the configured execution strategies and instrumentation.
 */
private GraphQL newGraphQL(GraphQLSchema schema) {
    final ExecutionStrategyProvider strategies = getExecutionStrategyProvider();
    GraphQL.Builder builder = GraphQL.newGraphQL(schema);
    builder = builder.queryExecutionStrategy(strategies.getQueryExecutionStrategy());
    builder = builder.mutationExecutionStrategy(strategies.getMutationExecutionStrategy());
    builder = builder.subscriptionExecutionStrategy(strategies.getSubscriptionExecutionStrategy());
    builder = builder.instrumentation(getInstrumentation());
    return builder.build();
}
/**
 * Unpacks a parsed {@link GraphQLRequest} and executes it, writing the result to {@code httpRes}.
 */
private void doQuery(GraphQLRequest graphQLRequest, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest httpReq, HttpServletResponse httpRes) throws Exception {
    final String queryText = graphQLRequest.getQuery();
    final String operation = graphQLRequest.getOperationName();
    final Map vars = graphQLRequest.getVariables();
    doQuery(queryText, operation, vars, schema, context, rootObject, httpReq, httpRes);
}
/**
 * Executes a single operation and writes its JSON result, status and content type to {@code resp}.
 */
private void doQuery(String query, String operationName, Map variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception {
    final GraphQLResponseHandler writeToHttp = graphQLResponse -> {
        resp.setContentType(APPLICATION_JSON_UTF8);
        resp.setStatus(graphQLResponse.getStatus());
        resp.getWriter().write(graphQLResponse.getResponse());
    };
    query(query, operationName, variables, schema, context, rootObject, writeToHttp);
}
/**
 * Executes each request of a batch in order and streams the results as one JSON array.
 * Status 200 and the JSON content type are set up front, before any element runs.
 */
private void doBatchedQuery(Iterator graphQLRequests, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception {
    resp.setContentType(APPLICATION_JSON_UTF8);
    resp.setStatus(STATUS_OK);
    final Writer body = resp.getWriter();
    // Each element's serialized result is appended verbatim between the array separators.
    final GraphQLResponseHandler appendElement = element -> body.write(element.getResponse());
    body.write('[');
    while (graphQLRequests.hasNext()) {
        final GraphQLRequest current = (GraphQLRequest) graphQLRequests.next();
        query(current.getQuery(), current.getOperationName(), current.getVariables(), schema, context, rootObject, appendElement);
        if (graphQLRequests.hasNext()) {
            body.write(',');
        }
    }
    body.write(']');
}
/**
 * Executes one GraphQL operation and hands the serialized result to {@code responseHandler}.
 * <p>
 * An empty {@code operationName} is normalized to {@code null} by re-entering this method.
 * If the current access-control context has no {@link Subject} but the GraphQL context carries one,
 * the method re-enters itself inside {@link Subject#doAs} so execution runs as that subject; checked
 * exceptions raised there are rethrown wrapped in a {@link RuntimeException}. Otherwise the operation
 * runs directly: listeners' onOperation hooks run first, the result is always reported with
 * {@link #STATUS_OK} (GraphQL errors go into the body), then onError or onSuccess, then onFinally,
 * are invoked on the collected operation callbacks.
 * <p>
 * NOTE(review): if execution, serialization or {@code responseHandler} throws, none of the operation
 * callbacks (not even onFinally) are invoked — confirm whether listeners rely on onFinally always running.
 *
 * @param variables may be {@code null} (e.g. a multipart POST without a "variables" part)
 * @throws Exception from serialization or from the response handler
 */
private void query(String query, String operationName, Map variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, GraphQLResponseHandler responseHandler) throws Exception {
// Treat "" the same as an absent operation name.
if (operationName != null && operationName.isEmpty()) {
query(query, null, variables, schema, context, rootObject, responseHandler);
// Inside doAs the access-control context has a subject, so the re-entry below takes the final branch.
} else if (Subject.getSubject(AccessController.getContext()) == null && context.getSubject().isPresent()) {
Subject.doAs(context.getSubject().get(), (PrivilegedAction) () -> {
try {
query(query, operationName, variables, schema, context, rootObject, responseHandler);
} catch (Exception e) {
// PrivilegedAction cannot throw checked exceptions.
throw new RuntimeException(e);
}
return null;
});
} else {
List operationCallbacks = runListeners(l -> l.onOperation(context, operationName, query, variables));
final ExecutionResult executionResult = newGraphQL(schema).execute(new ExecutionInput(query, operationName, context, rootObject, transformVariables(schema, query, variables)));
final List errors = executionResult.getErrors();
final Object data = executionResult.getData();
final String response = getMapper().writeValueAsString(createResultFromDataAndErrors(data, errors));
GraphQLResponse graphQLResponse = new GraphQLResponse();
// Status is 200 even when GraphQL errors are present; errors are part of the JSON body.
graphQLResponse.setStatus(STATUS_OK);
graphQLResponse.setResponse(response);
responseHandler.handle(graphQLResponse);
if(getGraphQLErrorHandler().errorsPresent(errors)) {
runCallbacks(operationCallbacks, c -> c.onError(context, operationName, query, variables, data, errors));
} else {
runCallbacks(operationCallbacks, c -> c.onSuccess(context, operationName, query, variables, data));
}
runCallbacks(operationCallbacks, c -> c.onFinally(context, operationName, query, variables, data));
}
}
/**
 * Builds the response body: a "data" entry always, plus an "errors" entry holding the processed
 * errors when the error handler reports any.
 */
private Map createResultFromDataAndErrors(Object data, List errors) {
    final Map body = new HashMap<>();
    body.put("data", data);
    if (!getGraphQLErrorHandler().errorsPresent(errors)) {
        return body;
    }
    body.put("errors", getGraphQLErrorHandler().processErrors(errors));
    return body;
}
private List runListeners(Function super GraphQLServletListener, R> action) {
if (listeners == null) {
return Collections.emptyList();
}
return listeners.stream()
.map(listener -> {
try {
return action.apply(listener);
} catch (Throwable t) {
log.error("Error running listener: {}", listener, t);
return null;
}
})
.filter(Objects::nonNull)
.collect(Collectors.toList());
}
/**
 * Invokes {@code action} on each callback in order; a callback that throws is logged and the
 * remaining callbacks still run.
 *
 * @param callbacks callbacks previously collected by runListeners
 * @param action the callback method to invoke
 * @param <T> the callback type
 */
private <T> void runCallbacks(List<T> callbacks, Consumer<? super T> action) {
    for (T callback : callbacks) {
        try {
            action.accept(callback);
        } catch (Throwable t) {
            log.error("Error running callback: {}", callback, t);
        }
    }
}
protected static class VariablesDeserializer extends JsonDeserializer