All downloads are free. Search and download functionality uses the official Maven repository.

com.xlrit.gears.server.graphql.GraphqlMultipartHandler Maven / Gradle / Ivy

There is a newer version: 1.17.6
Show newest version
package com.xlrit.gears.server.graphql;

import java.net.URI;
import java.util.*;
import java.util.regex.Pattern;
import jakarta.servlet.ServletException;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.graphql.ExecutionGraphQlRequest;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlRequest;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.util.AlternativeJdkIdGenerator;
import org.springframework.util.IdGenerator;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.support.AbstractMultipartHttpServletRequest;
import org.springframework.web.server.ServerWebInputException;
import org.springframework.web.servlet.function.ServerRequest;
import org.springframework.web.servlet.function.ServerResponse;
import reactor.core.publisher.Mono;

// from https://github.com/spring-projects/spring-graphql/issues/69
/**
 * Handles GraphQL requests sent as {@code multipart/form-data}.
 *
 * <p>The request carries:
 * <ul>
 *   <li>an {@code operations} part holding the JSON request;</li>
 *   <li>a {@code map} part mapping each file part name to one or more object paths inside
 *       {@code variables};</li>
 *   <li>the uploaded file parts themselves.</li>
 * </ul>
 * The files are injected into the variables before the request is executed.
 */
@RequiredArgsConstructor
public class GraphqlMultipartHandler {
	private static final Logger LOG = LoggerFactory.getLogger(GraphqlMultipartHandler.class);

	/** Media types the response body may be rendered as; JSON is used when none is accepted. */
	public static final List<MediaType> SUPPORTED_RESPONSE_MEDIA_TYPES =
		List.of(MediaType.APPLICATION_GRAPHQL, MediaType.APPLICATION_JSON);

	private final WebGraphQlHandler graphQlHandler;
	private final ObjectMapper objectMapper;
	private final IdGenerator idGenerator = new AlternativeJdkIdGenerator();

	/**
	 * Parses the multipart request, places the uploaded files into the query variables
	 * and executes the GraphQL request asynchronously.
	 *
	 * @param serverRequest the incoming multipart request
	 * @return an async response whose body is the GraphQL result map
	 * @throws ServletException declared for compatibility with existing callers; not thrown by this implementation
	 * @throws ServerWebInputException if a part is malformed, a file mapping is invalid,
	 *         or the request is not a multipart request
	 */
	public ServerResponse handleRequest(ServerRequest serverRequest) throws ServletException {
		Map<String, Object> inputQuery = readJson(serverRequest.param("operations"), new TypeReference<>() {});
		Map<String, Object> queryVariables = readObjectEntry(inputQuery, "variables");
		Map<String, Object> extensions = readObjectEntry(inputQuery, "extensions");

		Map<String, MultipartFile> fileParams = readMultipartBody(serverRequest);
		Map<String, List<String>> fileMapInput = readJson(serverRequest.param("map"), new TypeReference<>() {});
		mapFilesIntoVariables(fileMapInput, fileParams, queryVariables);

		String query = readStringEntry(inputQuery, "query");
		String opName = readStringEntry(inputQuery, "operationName");

		WebGraphQlRequest graphQlRequest = new MultipartGraphQlRequest(
			query,
			opName,
			queryVariables,
			extensions,
			serverRequest.uri(), serverRequest.headers().asHttpHeaders(),
			this.idGenerator.generateId().toString(), LocaleContextHolder.getLocale());

		// Parameterized logging: toString() only runs when debug is enabled.
		LOG.debug("Executing: {}", graphQlRequest);

		Mono<ServerResponse> responseMono = this.graphQlHandler.handleRequest(graphQlRequest)
			.map(response -> {
				LOG.debug("Execution complete");
				ServerResponse.BodyBuilder builder = ServerResponse.ok();
				builder.headers(headers -> headers.putAll(response.getResponseHeaders()));
				builder.contentType(selectResponseMediaType(serverRequest));
				return builder.body(response.toMap());
			});

		return ServerResponse.async(responseMono);
	}

	/**
	 * Injects each uploaded file into {@code variables} at every object path listed for it
	 * in the {@code map} part.
	 *
	 * <p>Files named in the map but absent from the request are skipped.
	 */
	private static void mapFilesIntoVariables(
			Map<String, List<String>> fileMap, Map<String, MultipartFile> files, Map<String, Object> variables) {
		fileMap.forEach((fileKey, objectPaths) -> {
			MultipartFile file = files.get(fileKey);
			if (file == null || objectPaths == null) {
				return;
			}
			for (String objectPath : objectPaths) {
				try {
					MultipartVariableMapper.mapVariable(objectPath, variables, file);
				}
				catch (RuntimeException ex) {
					// An unusable mapping is a client error, not a server failure.
					throw new ServerWebInputException(
						"Cannot map file '" + fileKey + "' to '" + objectPath + "'", null, ex);
				}
			}
		});
	}

	/**
	 * Returns the JSON-object entry {@code key} of {@code source}.
	 *
	 * <p>Returns a new empty mutable map when the entry is absent or JSON {@code null}.
	 *
	 * @throws ServerWebInputException if the entry is present but not a JSON object
	 */
	@SuppressWarnings("unchecked")
	private static Map<String, Object> readObjectEntry(Map<String, Object> source, String key) {
		Object value = source.get(key);
		if (value == null) {
			return new HashMap<>();
		}
		if (!(value instanceof Map)) {
			throw new ServerWebInputException("'" + key + "' must be a JSON object");
		}
		return (Map<String, Object>) value;
	}

	/**
	 * Returns the string entry {@code key} of {@code source}, or {@code null} when absent.
	 *
	 * @throws ServerWebInputException if the entry is present but not a string
	 */
	private static String readStringEntry(Map<String, Object> source, String key) {
		Object value = source.get(key);
		if (value != null && !(value instanceof String)) {
			throw new ServerWebInputException("'" + key + "' must be a string");
		}
		return (String) value;
	}

	/**
	 * Deserializes an optional request parameter as JSON.
	 *
	 * <p>A missing parameter, or one holding the JSON literal {@code null}, yields an empty
	 * {@link HashMap}. That is why {@code T} must be a {@code Map} type.
	 *
	 * @throws ServerWebInputException if the parameter is not valid JSON for {@code t}
	 */
	@SuppressWarnings("unchecked")
	private <T> T readJson(Optional<String> string, TypeReference<T> t) {
		if (string.isEmpty()) {
			return (T) new HashMap<String, Object>();
		}
		try {
			T value = objectMapper.readValue(string.get(), t);
			return (value != null) ? value : (T) new HashMap<String, Object>();
		}
		catch (JsonProcessingException e) {
			throw new ServerWebInputException("Malformed JSON in multipart request part", null, e);
		}
	}

	/**
	 * Returns the uploaded files keyed by part name.
	 *
	 * @throws ServerWebInputException if the servlet request is not an
	 *         {@link AbstractMultipartHttpServletRequest} or its parts cannot be read
	 */
	private static Map<String, MultipartFile> readMultipartBody(ServerRequest request) {
		try {
			AbstractMultipartHttpServletRequest multipartRequest =
				(AbstractMultipartHttpServletRequest) request.servletRequest();
			return multipartRequest.getFileMap();
		}
		catch (RuntimeException ex) {
			throw new ServerWebInputException("Error while reading request parts", null, ex);
		}
	}

	/**
	 * Picks the response content type from the first Accept entry that matches a supported type.
	 *
	 * <p>Entries are compared by type and subtype only, so parameters such as {@code q=0.9}
	 * neither prevent a match nor leak into the response Content-Type. Falls back to
	 * {@code application/json}.
	 */
	private static MediaType selectResponseMediaType(ServerRequest serverRequest) {
		for (MediaType accepted : serverRequest.headers().accept()) {
			for (MediaType supported : SUPPORTED_RESPONSE_MEDIA_TYPES) {
				if (supported.equalsTypeAndSubtype(accepted)) {
					return supported;
				}
			}
		}
		return MediaType.APPLICATION_JSON;
	}
}

// As in DGS, this is borrowed from https://github.com/graphql-java-kickstart/graphql-java-servlet/blob/eb4dfdb5c0198adc1b4d4466c3b4ea4a77def5d1/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/core/internal/VariableMapper.java
class MultipartVariableMapper {

    private static final Pattern PERIOD = Pattern.compile("\\.");

    private static final Mapper> MAP_MAPPER = new Mapper<>() {
		@Override
		public Object set(Map location, String target, MultipartFile value) {
			return location.put(target, value);
		}

		@Override
		public Object recurse(Map location, String target) {
			return location.get(target);
		}
	};

    private static final Mapper> LIST_MAPPER = new Mapper<>() {
		@Override
		public Object set(List location, String target, MultipartFile value) {
			return location.set(Integer.parseInt(target), value);
		}

		@Override
		public Object recurse(List location, String target) {
			return location.get(Integer.parseInt(target));
		}
	};

    @SuppressWarnings({"unchecked", "rawtypes"})
    public static void mapVariable(String objectPath, Map variables, MultipartFile part) {
        String[] segments = PERIOD.split(objectPath);

        if (segments.length < 2) {
            throw new RuntimeException("object-path in map must have at least two segments");
        }
		else if (!"variables".equals(segments[0])) {
            throw new RuntimeException("can only map into variables");
        }

        Object currentLocation = variables;
        for (int i = 1; i < segments.length; i++) {
            String segmentName = segments[i];
            Mapper mapper = determineMapper(currentLocation, objectPath, segmentName);

            if (i == segments.length - 1) {
                if (null != mapper.set(currentLocation, segmentName, part)) {
                    throw new RuntimeException("expected null value when mapping " + objectPath);
                }
            }
			else {
                currentLocation = mapper.recurse(currentLocation, segmentName);
                if (null == currentLocation) {
                    throw new RuntimeException(
                        "found null intermediate value when trying to map " + objectPath);
                }
            }
        }
    }

    private static Mapper determineMapper(Object currentLocation, String objectPath, String segmentName) {
        if (currentLocation instanceof Map) {
            return MAP_MAPPER;
        }
		else if (currentLocation instanceof List) {
            return LIST_MAPPER;
        }
        throw new RuntimeException("expected a map or list at " + segmentName + " when trying to map " + objectPath);
    }

    interface Mapper {
        Object set(T location, String target, MultipartFile value);
        Object recurse(T location, String target);
    }
}

// This class could be removed if WebGraphQlRequest had an additional constructor accepting the document, operation name, variables and extensions directly.
/**
 * A {@link WebGraphQlRequest} whose document, operation name, variables and extensions are
 * supplied directly rather than parsed from a JSON body.
 *
 * <p>This lets the variables carry already-mapped {@link MultipartFile} values.
 *
 * <p>The superclass only receives a minimal body containing {@code query}. The overridden
 * getters expose the values passed to this constructor.
 */
class MultipartGraphQlRequest extends WebGraphQlRequest implements ExecutionGraphQlRequest {

	private final String document;
	private final String operationName;
	private final Map<String, Object> variables;
	private final Map<String, Object> extensions;

	/**
	 * @param query         the GraphQL document
	 * @param operationName the operation to execute, may be {@code null}
	 * @param variables     the request variables; {@code null} is treated as empty
	 * @param extensions    the request extensions; {@code null} is treated as empty
	 * @param uri           the request URI
	 * @param headers       the request headers
	 * @param id            the request id
	 * @param locale        the request locale
	 */
	public MultipartGraphQlRequest(
		String query,
		String operationName,
		Map<String, Object> variables,
		Map<String, Object> extensions,
		URI uri, HttpHeaders headers,
		String id, Locale locale) {

		super(uri, headers, fakeBody(query), id, locale);

		this.document = query;
		this.operationName = operationName;
		// The ExecutionGraphQlRequest getters are expected to return a map, never null.
		this.variables = (variables != null ? variables : Collections.emptyMap());
		this.extensions = (extensions != null ? extensions : Collections.emptyMap());
	}

	/**
	 * Builds the minimal body handed to the superclass constructor.
	 *
	 * <p>A {@link HashMap} is used because it tolerates a {@code null} query.
	 */
	private static Map<String, Object> fakeBody(String query) {
		Map<String, Object> body = new HashMap<>();
		body.put("query", query);
		return body;
	}

	@Override
	public String getDocument() {
		return document;
	}

	@Override
	public String getOperationName() {
		return operationName;
	}

	@Override
	public Map<String, Object> getVariables() {
		return variables;
	}

	@Override
	public Map<String, Object> getExtensions() {
		return extensions;
	}
}