All downloads are free. Search and download functionality uses the official Maven repository.

io.restassured.module.webtestclient.internal.WebTestClientRequestSenderImpl Maven / Gradle / Ivy

/*
 * Copyright 2019 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.restassured.module.webtestclient.internal;

import io.restassured.RestAssured;
import io.restassured.authentication.NoAuthScheme;
import io.restassured.builder.MultiPartSpecBuilder;
import io.restassured.filter.Filter;
import io.restassured.filter.log.RequestLoggingFilter;
import io.restassured.http.Cookies;
import io.restassured.http.Header;
import io.restassured.http.Headers;
import io.restassured.http.Method;
import io.restassured.internal.RequestSpecificationImpl;
import io.restassured.internal.ResponseParserRegistrar;
import io.restassured.internal.ResponseSpecificationImpl;
import io.restassured.internal.filter.FilterContextImpl;
import io.restassured.internal.log.LogRepository;
import io.restassured.internal.multipart.MultiPartInternal;
import io.restassured.internal.support.PathSupport;
import io.restassured.internal.util.SafeExceptionRethrower;
import io.restassured.module.spring.commons.HeaderHelper;
import io.restassured.module.spring.commons.ParamApplier;
import io.restassured.module.spring.commons.config.ConfigConverter;
import io.restassured.module.webtestclient.config.RestAssuredWebTestClientConfig;
import io.restassured.module.webtestclient.response.WebTestClientResponse;
import io.restassured.module.webtestclient.specification.WebTestClientRequestSender;
import io.restassured.specification.RequestSpecification;
import io.restassured.specification.ResponseSpecification;
import org.apache.commons.codec.Charsets;
import org.apache.commons.lang3.ArrayUtils;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.MultipartBodyBuilder;
import org.springframework.test.web.reactive.server.EntityExchangeResult;
import org.springframework.test.web.reactive.server.FluxExchangeResult;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.util.MultiValueMap;
import org.springframework.web.util.UriBuilder;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils;

import java.io.File;
import java.net.URI;
import java.net.URL;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

import static io.restassured.internal.common.assertion.AssertParameter.notNull;
import static io.restassured.internal.support.PathSupport.mergeAndRemoveDoubleSlash;
import static io.restassured.module.spring.commons.BodyHelper.toByteArray;
import static io.restassured.module.spring.commons.HeaderHelper.mapToArray;
import static io.restassured.module.spring.commons.RequestLogger.logParamsAndHeaders;
import static io.restassured.module.spring.commons.RequestLogger.logRequestBody;
import static java.util.Optional.ofNullable;
import static org.apache.commons.lang3.StringUtils.*;
import static org.springframework.http.HttpMethod.*;
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
import static org.springframework.http.MediaType.parseMediaType;

public class WebTestClientRequestSenderImpl implements WebTestClientRequestSender {

	private static final String CONTENT_TYPE = "Content-Type";
	private static final Pattern PATH_PARAM_PATTERN = Pattern.compile("\\{([^/]+?)\\}");

	private final WebTestClient webTestClient;
	private final Map params;
	private final Map namedPathParams;
	private final Map queryParams;
	private final Map formParams;
	private final Map attributes;
	private final RestAssuredWebTestClientConfig config;
	private final Object requestBody;
	private final Cookies cookies;
	private final List multiParts;
	private final String basePath;
	private final ResponseSpecification responseSpecification;
	private final LogRepository logRepository;
	private Headers headers;
	private final RequestLoggingFilter requestLoggingFilter;
	private Consumer> consumer;

	@Override
	public WebTestClientResponse get(Function uriFunction) {
		return sendRequest(GET, uriFunction);
	}

	@Override
	public WebTestClientResponse post(Function uriFunction) {
		return sendRequest(POST, uriFunction);
	}

	@Override
	public WebTestClientResponse put(Function uriFunction) {
		return sendRequest(PUT, uriFunction);
	}

	@Override
	public WebTestClientResponse delete(Function uriFunction) {
		return sendRequest(DELETE, uriFunction);
	}

	@Override
	public WebTestClientResponse patch(Function uriFunction) {
		return sendRequest(PATCH, uriFunction);
	}

	@Override
	public WebTestClientResponse head(Function uriFunction) {
		return sendRequest(HEAD, uriFunction);
	}

	@Override
	public WebTestClientResponse options(Function uriFunction) {
		return sendRequest(OPTIONS, uriFunction);
	}

	@Override
	public WebTestClientResponse request(Method method, Function uriFunction) {
		return request(notNull(method, Method.class).name(), uriFunction);
	}

	@Override
	public WebTestClientResponse request(String method, Function uriFunction) {
		return sendRequest(toValidHttpMethod(method), uriFunction);
	}

	private HttpMethod toValidHttpMethod(String method) {
		String httpMethodAsString = notNull(trimToNull(method), "HTTP Method");
		try {
			return HttpMethod.valueOf(httpMethodAsString.toUpperCase());
		} catch (IllegalArgumentException e) {
			throw new IllegalArgumentException(String.format("HTTP method '%s' is not supported by WebTestClient", method));
		}
	}

	@Override
	public WebTestClientResponse get(String path, Object... pathParams) {
		return sendRequest(GET, path, pathParams);
	}

	public WebTestClientRequestSenderImpl(
			WebTestClient webTestClient,
			Map params,
			Map namedPathParams,
			Map queryParams,
			Map formParams,
			Map attributes,
			RestAssuredWebTestClientConfig config,
			Object requestBody,
			Headers headers,
			Cookies cookies,
			List multiParts,
			RequestLoggingFilter requestLoggingFilter,
			String basePath,
			ResponseSpecification responseSpecification,
			LogRepository logRepository
	) {
		this.webTestClient = webTestClient;
		this.params = params;
		this.namedPathParams = namedPathParams;
		this.queryParams = queryParams;
		this.formParams = formParams;
		this.attributes = attributes;
		this.config = config;
		this.requestBody = requestBody;
		this.headers = headers;
		this.cookies = cookies;
		this.multiParts = multiParts;
		this.basePath = basePath;
		this.responseSpecification = responseSpecification;
		this.logRepository = logRepository;
		this.requestLoggingFilter = requestLoggingFilter;
	}

	@Override
	public WebTestClientRequestSender consumeWith(Consumer> consumer) {
		this.consumer = consumer;
		return this;
	}

	@Override
	public WebTestClientResponse get(String path, Map pathParams) {
		return get(path, mapToArray(pathParams));
	}

	@Override
	public WebTestClientResponse post(String path, Object... pathParams) {
		return sendRequest(POST, path, pathParams);
	}

	@Override
	public WebTestClientResponse post(String path, Map pathParams) {
		return post(path, mapToArray(pathParams));
	}

	@Override
	public WebTestClientResponse put(String path, Object... pathParams) {
		return sendRequest(PUT, path, pathParams);
	}

	@Override
	public WebTestClientResponse put(String path, Map pathParams) {
		return put(path, mapToArray(pathParams));
	}

	@Override
	public WebTestClientResponse delete(String path, Object... pathParams) {
		return sendRequest(DELETE, path, pathParams);
	}

	@Override
	public WebTestClientResponse delete(String path, Map pathParams) {
		return delete(path, mapToArray(pathParams));
	}

	@Override
	public WebTestClientResponse head(String path, Object... pathParams) {
		return sendRequest(HEAD, path, pathParams);
	}

	@Override
	public WebTestClientResponse head(String path, Map pathParams) {
		return head(path, mapToArray(pathParams));
	}

	@Override
	public WebTestClientResponse patch(String path, Object... pathParams) {
		return sendRequest(PATCH, path, pathParams);
	}

	@Override
	public WebTestClientResponse patch(String path, Map pathParams) {
		return patch(path, mapToArray(pathParams));
	}

	@Override
	public WebTestClientResponse options(String path, Object... pathParams) {
		return sendRequest(OPTIONS, path, pathParams);
	}

	@Override
	public WebTestClientResponse options(String path, Map pathParams) {
		return options(path, mapToArray(pathParams));
	}

	@Override
	public WebTestClientResponse get(URI uri) {
		return get(uri.toString());
	}

	@Override
	public WebTestClientResponse post(URI uri) {
		return post(uri.toString());
	}

	@Override
	public WebTestClientResponse put(URI uri) {
		return put(uri.toString());
	}

	@Override
	public WebTestClientResponse delete(URI uri) {
		return delete(uri.toString());
	}

	@Override
	public WebTestClientResponse head(URI uri) {
		return head(uri.toString());
	}

	@Override
	public WebTestClientResponse patch(URI uri) {
		return patch(uri.toString());
	}

	@Override
	public WebTestClientResponse options(URI uri) {
		return options(uri.toString());
	}

	@Override
	public WebTestClientResponse get(URL url) {
		return get(url.toString());
	}

	@Override
	public WebTestClientResponse post(URL url) {
		return post(url.toString());
	}

	@Override
	public WebTestClientResponse put(URL url) {
		return put(url.toString());
	}

	@Override
	public WebTestClientResponse delete(URL url) {
		return delete(url.toString());
	}

	@Override
	public WebTestClientResponse head(URL url) {
		return head(url.toString());
	}

	@Override
	public WebTestClientResponse patch(URL url) {
		return patch(url.toString());
	}

	@Override
	public WebTestClientResponse options(URL url) {
		return options(url.toString());
	}

	@Override
	public WebTestClientResponse get() {
		return get("");
	}

	@Override
	public WebTestClientResponse post() {
		return post("");
	}

	@Override
	public WebTestClientResponse put() {
		return put("");
	}

	@Override
	public WebTestClientResponse delete() {
		return delete("");
	}

	@Override
	public WebTestClientResponse head() {
		return head("");
	}

	@Override
	public WebTestClientResponse patch() {
		return patch("");
	}

	@Override
	public WebTestClientResponse options() {
		return options("");
	}

	@Override
	public WebTestClientResponse request(Method method) {
		return request(method, "");
	}

	@Override
	public WebTestClientResponse request(String method) {
		return request(method, "");
	}

	@Override
	public WebTestClientResponse request(Method method, String path, Object... pathParams) {
		return request(notNull(method, Method.class).name(), path, pathParams);
	}

	@Override
	public WebTestClientResponse request(String method, String path, Object... pathParams) {
		return sendRequest(toValidHttpMethod(method), path, pathParams);
	}

	@Override
	public WebTestClientResponse request(Method method, URI uri) {
		return request(method, notNull(uri, URI.class).toString());
	}

	@Override
	public WebTestClientResponse request(Method method, URL url) {
		return request(method, notNull(url, URL.class).toString());
	}

	@Override
	public WebTestClientResponse request(String method, URI uri) {
		return request(method, notNull(uri, URI.class).toString());
	}

	@Override
	public WebTestClientResponse request(String method, URL url) {
		return request(method, notNull(url, URL.class).toString());
	}

	private WebTestClientResponse sendRequest(HttpMethod method, String path, Object[] unnamedPathParams) {
		String requestContentType = HeaderHelper.findContentType(headers, (List) (List) multiParts, config);
		WebTestClient.RequestBodySpec requestBodySpec = buildFromPath(method, requestContentType, path, unnamedPathParams);
		addRequestElements(method, requestContentType, requestBodySpec);
		logRequestIfApplicable(method, getBaseUri(path), path, unnamedPathParams);
		return performRequest(requestBodySpec);
	}

	private WebTestClient.RequestBodySpec buildFromPath(
			HttpMethod method,
			String requestContentType,
			String path,
			Object[] unnamedPathParams
	) {
		notNull(path, "Path");
		String baseUri = getBaseUri(path);
		final UriContainer uriContainer = buildUri(method, requestContentType, baseUri, unnamedPathParams);
		return webTestClient.method(method).uri(uriContainer.getUri(), uriContainer.getUriVariables());
	}

	private void addRequestElements(HttpMethod method, String requestContentType, WebTestClient.RequestBodySpec requestBodySpec) {
		verifyNoBodyAndMultipartTogether();
		if (isNotBlank(requestContentType)) {
			requestBodySpec.contentType(parseMediaType(requestContentType));
		}
		applyRequestBody(requestBodySpec);
		applyMultipartBody(method, requestBodySpec);
		applyAttributes(requestBodySpec);
		headers.forEach(header -> requestBodySpec.header(header.getName(), header.getValue()));
		cookies.asList().forEach(cookie -> requestBodySpec.cookie(cookie.getName(), cookie.getValue()));
	}

	private void logRequestIfApplicable(HttpMethod method, String uri, String originalPath, Object[] unnamedPathParams) {
		if (requestLoggingFilter == null) {
			return;
		}
		final RequestSpecificationImpl reqSpec = new RequestSpecificationImpl("http://localhost",
				RestAssured.UNDEFINED_PORT, "", new NoAuthScheme(), Collections.emptyList(),
				null, true, ConfigConverter.convertToRestAssuredConfig(config), logRepository, null, true, true);
		logParamsAndHeaders(reqSpec, method.toString(), uri, unnamedPathParams, params, namedPathParams, queryParams, formParams, headers, cookies);
		logRequestBody(reqSpec, requestBody, headers, (List) (List) multiParts, config);
		ofNullable(multiParts).map(List::stream).orElseGet(Stream::empty)
				.forEach(multiPart -> addMultipartToReqSpec(reqSpec, multiPart));
		String originalUriPath = PathSupport.getPath(originalPath);
		String uriPath = PathSupport.getPath(uri);
		requestLoggingFilter.filter(reqSpec, null, new FilterContextImpl(uri, originalUriPath,
				uriPath, uri, uri, new Object[0], method.toString(), null,
				Collections.emptyList().iterator(), new HashMap<>()));
	}

	private String getBaseUri(String path) {
		String baseUri;
		if (isNotBlank(basePath)) {
			baseUri = mergeAndRemoveDoubleSlash(basePath, path);
		} else {
			baseUri = path;
		}
		return baseUri;
	}

	private WebTestClientResponse performRequest(WebTestClient.RequestBodySpec requestBuilder) {
		FluxExchangeResult result;
		WebTestClientRestAssuredResponseImpl restAssuredResponse;
		try {
			final long start = System.currentTimeMillis();
			WebTestClient.ResponseSpec responseSpec = requestBuilder.exchange();
			final long responseTime = System.currentTimeMillis() - start;
			result = responseSpec.returnResult(byte[].class);
			restAssuredResponse = new ExchangeResultConverter().toRestAssuredResponse(result, responseSpec, responseTime,
					logRepository, config, consumer, getRpr());
			if (responseSpecification != null) {
				responseSpecification.validate(ResponseConverter.toStandardResponse(restAssuredResponse));
			}
		} catch (Exception e) {
			return SafeExceptionRethrower.safeRethrow(e);
		}
		return restAssuredResponse;
	}

	private UriContainer buildUri(HttpMethod method, String requestContentType, String baseUri, Object[] unnamedPathParams) {
		final UriComponentsBuilder uriComponentsBuilder = UriComponentsBuilder.fromUriString(baseUri);
		final UriContainer.Builder uriContainerBuilder = UriContainer.newBuilder(baseUri);

		applyQueryParams(uriComponentsBuilder);
		applyPathParams(uriContainerBuilder, baseUri, unnamedPathParams);
		applyParams(method, uriComponentsBuilder, requestContentType);
		applyFormParams(method, uriComponentsBuilder, requestContentType);

		final String uriWithoutPathParams = uriComponentsBuilder.cloneBuilder()
				.uriVariables(Collections.emptyMap())
				.build(false)
				.toUriString();

		return uriContainerBuilder.uri(uriWithoutPathParams).build();
	}

	private void verifyNoBodyAndMultipartTogether() {
		if (requestBody != null && !multiParts.isEmpty()) {
			throw new IllegalStateException("You cannot specify a request body and a multi-part body in the same request." +
					" Perhaps you want to change the body to a multi part?");
		}
	}

	private void applyRequestBody(WebTestClient.RequestBodySpec requestBodySpec) {
		if (requestBody != null) {
			if (requestBody instanceof byte[]) {
				requestBodySpec.syncBody(requestBody);
			} else if (requestBody instanceof File) {
				byte[] bytes = toByteArray((File) requestBody);
				requestBodySpec.syncBody(bytes);
			} else {
				requestBodySpec.syncBody(requestBody.toString());
			}
		}
	}

	private void applyMultipartBody(HttpMethod method, WebTestClient.RequestBodySpec requestBodySpec) {
		if (!multiParts.isEmpty()) {
			if (method != POST && method != PUT && method != PATCH) {
				throw new IllegalArgumentException("Currently multi-part file data uploading only works for POST, PUT and PATCH method.");
			}
			requestBodySpec.syncBody(getMultipartBody());
		}
	}

	private void applyAttributes(WebTestClient.RequestBodySpec requestBodySpec) {
		if (!attributes.isEmpty()) {
			new ParamApplier(attributes) {
				@Override
				protected void applyParam(String paramName, String[] paramValues) {
					requestBodySpec.attribute(paramName, paramValues[0]);
				}
			}.applyParams();
		}
	}

	private void addMultipartToReqSpec(RequestSpecification requestSpecification, MultiPartInternal multiPart) {
		requestSpecification.multiPart(new MultiPartSpecBuilder(multiPart.getContent())
				.controlName(multiPart.getControlName()).
						fileName(multiPart.getFileName()).
						mimeType(multiPart.getMimeType()).
						build());
	}

	private void applyQueryParams(UriComponentsBuilder uriComponentsBuilder) {
		if (!queryParams.isEmpty()) {
			new ParamApplier(queryParams) {
				@Override
				protected void applyParam(String paramName, String[] paramValues) {
					uriComponentsBuilder.queryParam(paramName, paramValues);
				}
			}.applyParams();
		}
	}

	private void applyPathParams(
			final UriContainer.Builder uriContainerBuilder,
			final String baseUri,
			final Object[] unnamedPathParams
	) {
		final Matcher pathParamMatcher = PATH_PARAM_PATTERN.matcher(baseUri);
		if (!pathParamMatcher.find()) {
			return;
		}

		if (namedPathParams.isEmpty() && ArrayUtils.isEmpty(unnamedPathParams)) {
			throw new IllegalArgumentException("No values were found for the request's pathParams.");
		}

		final AtomicInteger nextUnnamedPathParamIndex = new AtomicInteger(0);
		final Function> getPathParamValueFunction = param -> {
			if (namedPathParams.containsKey(param)) {
				return Optional.of(namedPathParams.get(param));
			}

			if (unnamedPathParams.length > 0) {
				return Optional.of(unnamedPathParams[nextUnnamedPathParamIndex.getAndIncrement()]);
			}

			return Optional.empty();
		};

		final Map uriVariables = new HashMap<>();
		do {
			final String paramName = pathParamMatcher.group(1);
			getPathParamValueFunction.apply(paramName).ifPresent(paramValue ->
					uriVariables.put(paramName, UriUtils.encode((String) paramValue, Charsets.UTF_8))
			);
		} while (pathParamMatcher.find());

		uriContainerBuilder.uriVariables(uriVariables);
	}

	private void applyParams(HttpMethod method, UriComponentsBuilder uriComponentsBuilder, String requestContentType) {
		if (!params.isEmpty()) {
			new ParamApplier(params) {
				@Override
				protected void applyParam(String paramName, String[] paramValues) {
					uriComponentsBuilder.queryParam(paramName, paramValues);
				}
			}.applyParams();

			if (isBlank(requestContentType) && method == POST && !isMultipartRequest()) {
				setContentTypeToApplicationFormUrlEncoded();
			}
		}
	}

	private void applyFormParams(HttpMethod method, UriComponentsBuilder uriComponentsBuilder, String requestContentType) {
		if (!formParams.isEmpty()) {
			if (method == GET) {
				throw new IllegalArgumentException("Cannot use form parameters in a GET request");
			}
			new ParamApplier(formParams) {
				@Override
				protected void applyParam(String paramName, String[] paramValues) {
					uriComponentsBuilder.queryParam(paramName, paramValues);
				}
			}.applyParams();
			if (isBlank(requestContentType) && !isMultipartRequest()) {
				setContentTypeToApplicationFormUrlEncoded();
			}
		}
	}

	private MultiValueMap> getMultipartBody() {
		MultipartBodyBuilder multipartBodyBuilder = new MultipartBodyBuilder();
		multiParts.stream().forEach(multipart -> multipartBodyBuilder.part(multipart.getFileName(), multipart.getContentBody()));
		return multipartBodyBuilder.build();
	}

	private boolean isMultipartRequest() {
		return !multiParts.isEmpty();
	}

	private void setContentTypeToApplicationFormUrlEncoded() {
		String requestContentType = parseMediaType(HeaderHelper.buildApplicationFormEncodedContentType(config,
				APPLICATION_FORM_URLENCODED_VALUE)).toString();
		List
newHeaders = new ArrayList<>(headers.asList()); newHeaders.add(new Header(CONTENT_TYPE, requestContentType)); headers = new Headers(newHeaders); } private WebTestClientResponse sendRequest(HttpMethod method, Function uriFunction) { String requestContentType = HeaderHelper.findContentType(headers, (List) (List) multiParts, config); WebTestClient.RequestBodySpec requestBodySpec = buildFromUriFunction(method, uriFunction); addRequestElements(method, requestContentType, requestBodySpec); logRequestIfApplicable(method, uriFunction); return performRequest(requestBodySpec); } private WebTestClient.RequestBodySpec buildFromUriFunction(HttpMethod method, Function uriFunction) { return webTestClient.method(method).uri(uriFunction); } private void logRequestIfApplicable(HttpMethod method, Function uriFunction) { if (requestLoggingFilter == null) { return; } final RequestSpecificationImpl reqSpec = new RequestSpecificationImpl("http://localhost", RestAssured.UNDEFINED_PORT, "", new NoAuthScheme(), Collections.emptyList(), null, true, ConfigConverter.convertToRestAssuredConfig(config), logRepository, null, true, true); logParamsAndHeaders(reqSpec, method.toString(), "Request from uri function" + uriFunction.toString(), null, params, namedPathParams, queryParams, formParams, headers, cookies); logRequestBody(reqSpec, requestBody, headers, (List) (List) multiParts, config); ofNullable(multiParts).map(List::stream).orElseGet(Stream::empty) .forEach(multiPart -> addMultipartToReqSpec(reqSpec, multiPart)); requestLoggingFilter.filter(reqSpec, null, new FilterContextImpl("Request from uri function" + uriFunction.toString(), null, null, null, null, new Object[0], method.toString(), null, Collections.emptyList().iterator(), new HashMap<>())); } private ResponseParserRegistrar getRpr() { if (responseSpecification instanceof ResponseSpecificationImpl) { return ((ResponseSpecificationImpl) responseSpecification).getRpr(); } return new ResponseParserRegistrar(); } }