All downloads are free. Search and download functionality uses the official Maven repository.

org.springframework.web.servlet.function.DefaultServerRequest Maven / Gradle / Ivy

There is a newer version: 6.1.6
Show newest version
/*
 * Copyright 2002-2021 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.web.servlet.function;

import java.io.IOException;
import java.io.PrintWriter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.charset.Charset;
import java.security.Principal;
import java.time.Instant;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.stream.Collectors;

import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.Part;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpRange;
import org.springframework.http.MediaType;
import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.RequestPath;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ObjectUtils;
import org.springframework.web.HttpMediaTypeNotSupportedException;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.servlet.support.ServletUriComponentsBuilder;
import org.springframework.web.util.ServletRequestPathUtils;
import org.springframework.web.util.UriBuilder;

/**
 * {@code ServerRequest} implementation based on a {@link HttpServletRequest}.
 *
 * @author Arjen Poutsma
 * @since 5.2
 */
class DefaultServerRequest implements ServerRequest {

	/** Adapter around the native servlet request; used for the URI, remote address, principal and body reading. */
	private final ServletServerHttpRequest serverHttpRequest;

	/** Parsed request path, resolved once in the constructor. */
	private final RequestPath requestPath;

	/** Read-only view of the request headers. */
	private final Headers headers;

	/** Unmodifiable copy of the converters supplied at construction time. */
	private final List<HttpMessageConverter<?>> messageConverters;

	/** Request parameters, backed by the servlet request. */
	private final MultiValueMap<String, String> params;

	/** Request attributes, backed by the servlet request. */
	private final Map<String, Object> attributes;

	/** Multipart data, populated on first call to {@link #multipartData()}; {@code null} until then. */
	@Nullable
	private MultiValueMap<String, Part> parts;


	/**
	 * Create a new {@code DefaultServerRequest} wrapping the given servlet request.
	 * @param servletRequest the underlying servlet request
	 * @param messageConverters the converters used to read the request body;
	 * the list is copied, so later changes to the argument are not reflected
	 */
	public DefaultServerRequest(HttpServletRequest servletRequest, List<HttpMessageConverter<?>> messageConverters) {
		this.serverHttpRequest = new ServletServerHttpRequest(servletRequest);
		this.messageConverters = Collections.unmodifiableList(new ArrayList<>(messageConverters));

		this.headers = new DefaultRequestHeaders(this.serverHttpRequest.getHeaders());
		this.params = CollectionUtils.toMultiValueMap(new ServletParametersMap(servletRequest));
		this.attributes = new ServletAttributesMap(servletRequest);

		// DispatcherServlet parses the path but for other scenarios (e.g. tests) we might need to
		// parse and cache it here ourselves.
		this.requestPath = (ServletRequestPathUtils.hasParsedRequestPath(servletRequest) ?
				ServletRequestPathUtils.getParsedRequestPath(servletRequest) :
				ServletRequestPathUtils.parseAndCache(servletRequest));
	}


	/**
	 * Return the HTTP method name exactly as reported by the underlying servlet request.
	 */
	@Override
	public String methodName() {
		HttpServletRequest nativeRequest = servletRequest();
		return nativeRequest.getMethod();
	}

	/**
	 * Return the request URI as computed by the wrapped {@link ServletServerHttpRequest}.
	 */
	@Override
	public URI uri() {
		URI requestUri = this.serverHttpRequest.getURI();
		return requestUri;
	}

	/**
	 * Return a new {@link UriBuilder} created from the underlying servlet request
	 * via {@link ServletUriComponentsBuilder#fromRequest}.
	 */
	@Override
	public UriBuilder uriBuilder() {
		HttpServletRequest nativeRequest = servletRequest();
		return ServletUriComponentsBuilder.fromRequest(nativeRequest);
	}

	/**
	 * Return the parsed request path that was resolved when this request was constructed.
	 */
	@Override
	public RequestPath requestPath() {
		final RequestPath parsedPath = this.requestPath;
		return parsedPath;
	}

	/**
	 * Return the {@link Headers} instance created for this request at construction time.
	 */
	@Override
	public Headers headers() {
		final Headers requestHeaders = this.headers;
		return requestHeaders;
	}

	/**
	 * Return the request cookies grouped by cookie name.
	 * <p>A new map is built on every call; a request without cookies yields an empty map.
	 */
	@Override
	public MultiValueMap<String, Cookie> cookies() {
		Cookie[] cookies = servletRequest().getCookies();
		if (cookies == null) {
			// getCookies() may hand back null instead of an empty array
			cookies = new Cookie[0];
		}
		MultiValueMap<String, Cookie> result = new LinkedMultiValueMap<>(cookies.length);
		for (Cookie cookie : cookies) {
			result.add(cookie.getName(), cookie);
		}
		return result;
	}

	/**
	 * Return the native {@link HttpServletRequest} held by the wrapped {@link ServletServerHttpRequest}.
	 */
	@Override
	public HttpServletRequest servletRequest() {
		HttpServletRequest nativeRequest = this.serverHttpRequest.getServletRequest();
		return nativeRequest;
	}

	/**
	 * Return the remote address as reported by the wrapped {@link ServletServerHttpRequest}.
	 * <p>Uses {@link Optional#of}, i.e. the wrapped request is expected to return a non-null address.
	 */
	@Override
	public Optional<InetSocketAddress> remoteAddress() {
		return Optional.of(this.serverHttpRequest.getRemoteAddress());
	}

	@Override
	public List> messageConverters() {
		return this.messageConverters;
	}

	/**
	 * Read the request body as the given class.
	 * <p>The class serves both as the generic type and as the raw class for converter lookup.
	 * @param bodyType the class to convert the body to
	 * @return the converted body
	 * @throws IOException propagated from the selected message converter
	 * @throws ServletException if no converter can read the body for the request content type
	 */
	@Override
	public <T> T body(Class<T> bodyType) throws IOException, ServletException {
		return bodyInternal(bodyType, bodyType);
	}

	/**
	 * Read the request body as the given parameterized type.
	 * <p>The full generic type is offered to generic converters; its raw class
	 * (see {@link #bodyClass(Type)}) is used for plain converters.
	 * @param bodyType reference describing the target type
	 * @return the converted body
	 * @throws IOException propagated from the selected message converter
	 * @throws ServletException if no converter can read the body for the request content type
	 */
	@Override
	public <T> T body(ParameterizedTypeReference<T> bodyType) throws IOException, ServletException {
		Type type = bodyType.getType();
		return bodyInternal(type, bodyClass(type));
	}

	/**
	 * Determine the raw class for the given type.
	 * @param type the type to inspect
	 * @return the type itself if it is a {@code Class}; the raw type of a
	 * {@code ParameterizedType} when that raw type is a {@code Class};
	 * {@code Object.class} for anything else (e.g. type variables, wildcards)
	 */
	static Class<?> bodyClass(Type type) {
		if (type instanceof Class) {
			return (Class<?>) type;
		}
		if (type instanceof ParameterizedType) {
			ParameterizedType parameterizedType = (ParameterizedType) type;
			if (parameterizedType.getRawType() instanceof Class) {
				return (Class<?>) parameterizedType.getRawType();
			}
		}
		return Object.class;
	}

	/**
	 * Read the request body using the first message converter able to handle it.
	 * <p>Converters are tried in list order. A {@link GenericHttpMessageConverter} is first
	 * asked about the full generic type; if it declines (or the converter is not generic),
	 * the plain {@code canRead(Class, MediaType)} check on the raw class is used.
	 * A request without a {@code Content-Type} is treated as {@code application/octet-stream}.
	 * @param bodyType the (possibly generic) target type
	 * @param bodyClass the raw class corresponding to {@code bodyType}
	 * @return the converted body
	 * @throws HttpMediaTypeNotSupportedException if no converter can read the body
	 * @throws IOException propagated from the selected converter
	 */
	@SuppressWarnings("unchecked")
	private <T> T bodyInternal(Type bodyType, Class<?> bodyClass) throws ServletException, IOException {
		MediaType contentType = this.headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);

		for (HttpMessageConverter<?> messageConverter : this.messageConverters) {
			if (messageConverter instanceof GenericHttpMessageConverter) {
				GenericHttpMessageConverter<T> genericMessageConverter =
						(GenericHttpMessageConverter<T>) messageConverter;
				if (genericMessageConverter.canRead(bodyType, bodyClass, contentType)) {
					return genericMessageConverter.read(bodyType, bodyClass, this.serverHttpRequest);
				}
			}
			// Not generic, or the generic check declined: fall back to the raw class
			if (messageConverter.canRead(bodyClass, contentType)) {
				HttpMessageConverter<T> theConverter =
						(HttpMessageConverter<T>) messageConverter;
				Class<? extends T> clazz = (Class<? extends T>) bodyClass;
				return theConverter.read(clazz, this.serverHttpRequest);
			}
		}
		throw new HttpMediaTypeNotSupportedException(contentType, getSupportedMediaTypes(bodyClass));
	}

	/**
	 * Collect the media types that the configured converters report for the given class,
	 * sorted with {@link MediaType#SPECIFICITY_COMPARATOR}.
	 * @param bodyClass the class to look up supported media types for
	 * @return the combined media types of all converters (duplicates are not removed)
	 */
	private List<MediaType> getSupportedMediaTypes(Class<?> bodyClass) {
		return this.messageConverters.stream()
				.flatMap(converter -> converter.getSupportedMediaTypes(bodyClass).stream())
				.sorted(MediaType.SPECIFICITY_COMPARATOR)
				.collect(Collectors.toList());
	}

	/**
	 * Look up a single request attribute directly on the servlet request.
	 * @param name the attribute name
	 * @return the attribute value, or empty if the servlet request returns {@code null}
	 */
	@Override
	public Optional<Object> attribute(String name) {
		return Optional.ofNullable(servletRequest().getAttribute(name));
	}

	/**
	 * Return the request attributes as a map backed by the servlet request.
	 */
	@Override
	public Map<String, Object> attributes() {
		return this.attributes;
	}

	/**
	 * Look up a single request parameter directly on the servlet request.
	 * @param name the parameter name
	 * @return the parameter value, or empty if the servlet request returns {@code null}
	 */
	@Override
	public Optional<String> param(String name) {
		return Optional.ofNullable(servletRequest().getParameter(name));
	}

	/**
	 * Return the request parameters as a multi-value map backed by the servlet request.
	 */
	@Override
	public MultiValueMap<String, String> params() {
		return this.params;
	}

	/**
	 * Return the multipart parts of the request, grouped by part name.
	 * <p>The parts are read from the servlet request on first call and cached in
	 * {@code this.parts}; later calls return the cached map.
	 * @throws IOException propagated from {@code HttpServletRequest#getParts()}
	 * @throws ServletException propagated from {@code HttpServletRequest#getParts()}
	 */
	@Override
	public MultiValueMap<String, Part> multipartData() throws IOException, ServletException {
		MultiValueMap<String, Part> result = this.parts;
		if (result == null) {
			result = servletRequest().getParts().stream()
					.collect(Collectors.groupingBy(Part::getName,
							LinkedMultiValueMap::new,
							Collectors.toList()));
			this.parts = result;
		}
		return result;
	}

	/**
	 * Return the URI template variables stored in the request attribute
	 * {@link RouterFunctions#URI_TEMPLATE_VARIABLES_ATTRIBUTE}.
	 * @return the stored map, or an empty map if the attribute is not set
	 */
	@Override
	@SuppressWarnings("unchecked")
	public Map<String, String> pathVariables() {
		Map<String, String> pathVariables = (Map<String, String>)
				servletRequest().getAttribute(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE);
		if (pathVariables != null) {
			return pathVariables;
		}
		else {
			return Collections.emptyMap();
		}
	}

	/**
	 * Return the HTTP session, asking the servlet request to create one if none exists yet.
	 */
	@Override
	public HttpSession session() {
		final boolean createIfAbsent = true;
		return servletRequest().getSession(createIfAbsent);
	}

	/**
	 * Return the principal reported by the wrapped {@link ServletServerHttpRequest},
	 * or empty if it reports {@code null}.
	 */
	@Override
	public Optional<Principal> principal() {
		return Optional.ofNullable(this.serverHttpRequest.getPrincipal());
	}

	/**
	 * Describe this request as {@code "HTTP <method> <path>"}.
	 */
	@Override
	public String toString() {
		Object httpMethod = method();
		Object requestedPath = path();
		return String.format("HTTP %s %s", httpMethod, requestedPath);
	}

	/**
	 * Check whether the requested resource has been modified, delegating to
	 * {@link ServletWebRequest#checkNotModified(String, long)}.
	 * <p>A recording {@code HttpServletResponse} captures the status and headers
	 * set during the check, so they can be replayed on the returned response.
	 * @param servletRequest the current servlet request
	 * @param lastModified the last-modified time of the resource; {@code null} or
	 * anything not after the epoch is passed on as {@code -1}
	 * @param etag the entity tag of the resource, may be {@code null}
	 * @return a response with the recorded status and headers if the check
	 * reports "not modified"; empty otherwise
	 */
	static Optional<ServerResponse> checkNotModified(
			HttpServletRequest servletRequest, @Nullable Instant lastModified, @Nullable String etag) {

		long lastModifiedTimestamp = -1;
		if (lastModified != null && lastModified.isAfter(Instant.EPOCH)) {
			lastModifiedTimestamp = lastModified.toEpochMilli();
		}

		CheckNotModifiedResponse response = new CheckNotModifiedResponse();
		WebRequest webRequest = new ServletWebRequest(servletRequest, response);
		if (webRequest.checkNotModified(etag, lastModifiedTimestamp)) {
			return Optional.of(ServerResponse.status(response.status)
					.headers(headers -> headers.addAll(response.headers))
					.build());
		}
		else {
			return Optional.empty();
		}
	}


	/**
	 * Default implementation of {@link Headers}, delegating to a read-only
	 * {@link HttpHeaders} instance.
	 */
	static class DefaultRequestHeaders implements Headers {

		private final HttpHeaders httpHeaders;

		/**
		 * Create a new instance exposing a read-only view of the given headers.
		 * @param httpHeaders the headers to wrap
		 */
		public DefaultRequestHeaders(HttpHeaders httpHeaders) {
			this.httpHeaders = HttpHeaders.readOnlyHttpHeaders(httpHeaders);
		}

		/** Return the media types of the {@code Accept} header. */
		@Override
		public List<MediaType> accept() {
			return this.httpHeaders.getAccept();
		}

		/** Return the charsets of the {@code Accept-Charset} header. */
		@Override
		public List<Charset> acceptCharset() {
			return this.httpHeaders.getAcceptCharset();
		}

		/** Return the language ranges of the {@code Accept-Language} header. */
		@Override
		public List<Locale.LanguageRange> acceptLanguage() {
			return this.httpHeaders.getAcceptLanguage();
		}

		/**
		 * Return the content length, or empty when the underlying headers
		 * report {@code -1}.
		 */
		@Override
		public OptionalLong contentLength() {
			long value = this.httpHeaders.getContentLength();
			return (value != -1 ? OptionalLong.of(value) : OptionalLong.empty());
		}

		/** Return the content type, or empty when the underlying headers report {@code null}. */
		@Override
		public Optional<MediaType> contentType() {
			return Optional.ofNullable(this.httpHeaders.getContentType());
		}

		/** Return the value of the {@code Host} header as reported by the underlying headers. */
		@Override
		public InetSocketAddress host() {
			return this.httpHeaders.getHost();
		}

		/** Return the ranges of the {@code Range} header. */
		@Override
		public List<HttpRange> range() {
			return this.httpHeaders.getRange();
		}

		/**
		 * Return all values of the named header.
		 * @param headerName the header name
		 * @return the header values, or an empty list if the header is absent
		 */
		@Override
		public List<String> header(String headerName) {
			List<String> headerValues = this.httpHeaders.get(headerName);
			return (headerValues != null ? headerValues : Collections.emptyList());
		}

		/** Return the underlying read-only {@link HttpHeaders}. */
		@Override
		public HttpHeaders asHttpHeaders() {
			return this.httpHeaders;
		}

		@Override
		public String toString() {
			return this.httpHeaders.toString();
		}
	}


	private static final class ServletParametersMap extends AbstractMap> {

		private final HttpServletRequest servletRequest;

		private ServletParametersMap(HttpServletRequest servletRequest) {
			this.servletRequest = servletRequest;
		}

		@Override
		public Set>> entrySet() {
			return this.servletRequest.getParameterMap().entrySet().stream()
					.map(entry -> {
						List value = Arrays.asList(entry.getValue());
						return new SimpleImmutableEntry<>(entry.getKey(), value);
					})
					.collect(Collectors.toSet());
		}

		@Override
		public int size() {
			return this.servletRequest.getParameterMap().size();
		}

		@Override
		public List get(Object key) {
			String name = (String) key;
			String[] parameterValues = this.servletRequest.getParameterValues(name);
			if (!ObjectUtils.isEmpty(parameterValues)) {
				return Arrays.asList(parameterValues);
			}
			else {
				return Collections.emptyList();
			}
		}

		@Override
		public List put(String key, List value) {
			throw new UnsupportedOperationException();
		}

		@Override
		public List remove(Object key) {
			throw new UnsupportedOperationException();
		}

		@Override
		public void clear() {
			throw new UnsupportedOperationException();
		}
	}


	private static final class ServletAttributesMap extends AbstractMap {

		private final HttpServletRequest servletRequest;

		private ServletAttributesMap(HttpServletRequest servletRequest) {
			this.servletRequest = servletRequest;
		}

		@Override
		public boolean containsKey(Object key) {
			String name = (String) key;
			return this.servletRequest.getAttribute(name) != null;
		}

		@Override
		public void clear() {
			List attributeNames = Collections.list(this.servletRequest.getAttributeNames());
			attributeNames.forEach(this.servletRequest::removeAttribute);
		}

		@Override
		public Set> entrySet() {
			return Collections.list(this.servletRequest.getAttributeNames()).stream()
					.map(name -> {
						Object value = this.servletRequest.getAttribute(name);
						return new SimpleImmutableEntry<>(name, value);
					})
					.collect(Collectors.toSet());
		}

		@Override
		public Object get(Object key) {
			String name = (String) key;
			return this.servletRequest.getAttribute(name);
		}

		@Override
		public Object put(String key, Object value) {
			Object oldValue = this.servletRequest.getAttribute(key);
			this.servletRequest.setAttribute(key, value);
			return oldValue;
		}

		@Override
		public Object remove(Object key) {
			String name = (String) key;
			Object value = this.servletRequest.getAttribute(name);
			this.servletRequest.removeAttribute(name);
			return value;
		}
	}


	/**
	 * Simple implementation of {@link HttpServletResponse} used by
	 * {@link #checkNotModified(HttpServletRequest, Instant, String)} to record status and headers set by
	 * {@link ServletWebRequest#checkNotModified(String, long)}. Throws an {@code UnsupportedOperationException}
	 * for other methods.
	 */
	private static final class CheckNotModifiedResponse implements HttpServletResponse {

		/** Headers recorded through the supported setter methods. */
		private final HttpHeaders headers = new HttpHeaders();

		/** Recorded status code; starts out as 200. */
		private int status = 200;

		@Override
		public boolean containsHeader(String name) {
			return this.headers.containsKey(name);
		}

		@Override
		public void setDateHeader(String name, long date) {
			this.headers.setDate(name, date);
		}

		@Override
		public void setHeader(String name, String value) {
			this.headers.set(name, value);
		}

		@Override
		public void addHeader(String name, String value) {
			this.headers.add(name, value);
		}

		@Override
		public void setStatus(int sc) {
			this.status = sc;
		}

		/** Records the status code only; the message argument is ignored. */
		@Override
		@Deprecated
		public void setStatus(int sc, String sm) {
			this.status = sc;
		}

		@Override
		public int getStatus() {
			return this.status;
		}

		@Override
		@Nullable
		public String getHeader(String name) {
			return this.headers.getFirst(name);
		}

		/** Return the recorded values for the given header, or an empty collection if none. */
		@Override
		public Collection<String> getHeaders(String name) {
			List<String> result = this.headers.get(name);
			return (result != null ? result : Collections.emptyList());
		}

		@Override
		public Collection<String> getHeaderNames() {
			return this.headers.keySet();
		}


		// Unsupported

		@Override
		public void addCookie(Cookie cookie) {
			throw new UnsupportedOperationException();
		}

		@Override
		public String encodeURL(String url) {
			throw new UnsupportedOperationException();
		}

		@Override
		public String encodeRedirectURL(String url) {
			throw new UnsupportedOperationException();
		}

		@Override
		@Deprecated
		public String encodeUrl(String url) {
			throw new UnsupportedOperationException();
		}

		@Override
		@Deprecated
		public String encodeRedirectUrl(String url) {
			throw new UnsupportedOperationException();
		}

		@Override
		public void sendError(int sc, String msg) throws IOException {
			throw new UnsupportedOperationException();
		}

		@Override
		public void sendError(int sc) throws IOException {
			throw new UnsupportedOperationException();
		}

		@Override
		public void sendRedirect(String location) throws IOException {
			throw new UnsupportedOperationException();
		}

		@Override
		public void addDateHeader(String name, long date) {
			throw new UnsupportedOperationException();
		}

		@Override
		public void setIntHeader(String name, int value) {
			throw new UnsupportedOperationException();
		}

		@Override
		public void addIntHeader(String name, int value) {
			throw new UnsupportedOperationException();
		}


		@Override
		public String getCharacterEncoding() {
			throw new UnsupportedOperationException();
		}

		@Override
		public String getContentType() {
			throw new UnsupportedOperationException();
		}

		@Override
		public ServletOutputStream getOutputStream() throws IOException {
			throw new UnsupportedOperationException();
		}

		@Override
		public PrintWriter getWriter() throws IOException {
			throw new UnsupportedOperationException();
		}

		@Override
		public void setCharacterEncoding(String charset) {
			throw new UnsupportedOperationException();
		}

		@Override
		public void setContentLength(int len) {
			throw new UnsupportedOperationException();
		}

		@Override
		public void setContentLengthLong(long len) {
			throw new UnsupportedOperationException();
		}

		@Override
		public void setContentType(String type) {
			throw new UnsupportedOperationException();
		}

		@Override
		public void setBufferSize(int size) {
			throw new UnsupportedOperationException();
		}

		@Override
		public int getBufferSize() {
			throw new UnsupportedOperationException();
		}

		@Override
		public void flushBuffer() throws IOException {
			throw new UnsupportedOperationException();
		}

		@Override
		public void resetBuffer() {
			throw new UnsupportedOperationException();
		}

		@Override
		public boolean isCommitted() {
			throw new UnsupportedOperationException();
		}

		@Override
		public void reset() {
			throw new UnsupportedOperationException();
		}

		@Override
		public void setLocale(Locale loc) {
			throw new UnsupportedOperationException();
		}

		@Override
		public Locale getLocale() {
			throw new UnsupportedOperationException();
		}
	}

}