All downloads are free. Search and download functionality uses the official Maven repository.

org.springframework.web.client.DefaultRestClient Maven / Gradle / Ivy

There is a newer version: 6.1.13
Show newest version
/*
 * Copyright 2002-2023 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.client;

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.net.URI;
import java.nio.charset.Charset;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ResolvableType;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRequest;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.StreamingHttpOutputMessage;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInitializer;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.InterceptingClientHttpRequestFactory;
import org.springframework.http.client.observation.ClientHttpObservationDocumentation;
import org.springframework.http.client.observation.ClientRequestObservationContext;
import org.springframework.http.client.observation.ClientRequestObservationConvention;
import org.springframework.http.client.observation.DefaultClientRequestObservationConvention;
import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.HttpMessageNotReadableException;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.util.UriBuilder;
import org.springframework.web.util.UriBuilderFactory;

/**
 * The default implementation of {@link RestClient},
 * as created by the static factory methods.
 *
 * @author Arjen Poutsma
 * @author Sebastien Deleuze
 * @since 6.1
 * @see RestClient#create()
 * @see RestClient#create(String)
 * @see RestClient#create(RestTemplate)
 */
final class DefaultRestClient implements RestClient {

	private static final Log logger = LogFactory.getLog(DefaultRestClient.class);

	private static final ClientRequestObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultClientRequestObservationConvention();


	private final ClientHttpRequestFactory clientRequestFactory;

	@Nullable
	private volatile ClientHttpRequestFactory interceptingRequestFactory;

	@Nullable
	private final List initializers;

	@Nullable
	private final List interceptors;

	private final UriBuilderFactory uriBuilderFactory;

	@Nullable
	private final HttpHeaders defaultHeaders;

	private final List defaultStatusHandlers;

	private final DefaultRestClientBuilder builder;

	private final List> messageConverters;

	private final ObservationRegistry observationRegistry;

	@Nullable
	private final ClientRequestObservationConvention observationConvention;


	DefaultRestClient(ClientHttpRequestFactory clientRequestFactory,
			@Nullable List interceptors,
			@Nullable List initializers,
			UriBuilderFactory uriBuilderFactory,
			@Nullable HttpHeaders defaultHeaders,
			@Nullable List statusHandlers,
			List> messageConverters,
			ObservationRegistry observationRegistry,
			@Nullable ClientRequestObservationConvention observationConvention,
			DefaultRestClientBuilder builder) {

		this.clientRequestFactory = clientRequestFactory;
		this.initializers = initializers;
		this.interceptors = interceptors;
		this.uriBuilderFactory = uriBuilderFactory;
		this.defaultHeaders = defaultHeaders;
		this.defaultStatusHandlers = (statusHandlers != null ? new ArrayList<>(statusHandlers) : new ArrayList<>());
		this.messageConverters = messageConverters;
		this.observationRegistry = observationRegistry;
		this.observationConvention = observationConvention;
		this.builder = builder;
	}

	@Override
	public RequestHeadersUriSpec get() {
		return methodInternal(HttpMethod.GET);
	}

	@Override
	public RequestHeadersUriSpec head() {
		return methodInternal(HttpMethod.HEAD);
	}

	@Override
	public RequestBodyUriSpec post() {
		return methodInternal(HttpMethod.POST);
	}

	@Override
	public RequestBodyUriSpec put() {
		return methodInternal(HttpMethod.PUT);
	}

	@Override
	public RequestBodyUriSpec patch() {
		return methodInternal(HttpMethod.PATCH);
	}

	@Override
	public RequestHeadersUriSpec delete() {
		return methodInternal(HttpMethod.DELETE);
	}

	@Override
	public RequestHeadersUriSpec options() {
		return methodInternal(HttpMethod.OPTIONS);
	}

	@Override
	public RequestBodyUriSpec method(HttpMethod method) {
		Assert.notNull(method, "HttpMethod must not be null");
		return methodInternal(method);
	}

	private RequestBodyUriSpec methodInternal(HttpMethod httpMethod) {
		return new DefaultRequestBodyUriSpec(httpMethod);
	}

	@Override
	public Builder mutate() {
		return new DefaultRestClientBuilder(this.builder);
	}

	@SuppressWarnings({"rawtypes", "unchecked"})
	private  T readWithMessageConverters(ClientHttpResponse clientResponse, Runnable callback, Type bodyType, Class bodyClass) {
		MediaType contentType = getContentType(clientResponse);

		try (clientResponse) {
			callback.run();

			for (HttpMessageConverter messageConverter : this.messageConverters) {
				if (messageConverter instanceof GenericHttpMessageConverter genericHttpMessageConverter) {
					if (genericHttpMessageConverter.canRead(bodyType, null, contentType)) {
						if (logger.isDebugEnabled()) {
							logger.debug("Reading to [" + ResolvableType.forType(bodyType) + "]");
						}
						return (T) genericHttpMessageConverter.read(bodyType, null, clientResponse);
					}
				}
				if (messageConverter.canRead(bodyClass, contentType)) {
					if (logger.isDebugEnabled()) {
						logger.debug("Reading to [" + bodyClass.getName() + "] as \"" + contentType + "\"");
					}
					return (T) messageConverter.read((Class)bodyClass, clientResponse);
				}
			}
			throw new UnknownContentTypeException(bodyType, contentType,
					clientResponse.getStatusCode(), clientResponse.getStatusText(),
					clientResponse.getHeaders(), RestClientUtils.getBody(clientResponse));
		}
		catch (UncheckedIOException | IOException | HttpMessageNotReadableException ex) {
			throw new RestClientException("Error while extracting response for type [" +
					ResolvableType.forType(bodyType) + "] and content type [" + contentType + "]", ex);
		}
	}

	private static MediaType getContentType(ClientHttpResponse clientResponse) {
		MediaType contentType = clientResponse.getHeaders().getContentType();
		if (contentType == null) {
			contentType = MediaType.APPLICATION_OCTET_STREAM;
		}
		return contentType;
	}

	@SuppressWarnings("unchecked")
	private static  Class bodyClass(Type type) {
		if (type instanceof Class clazz) {
			return (Class) clazz;
		}
		if (type instanceof ParameterizedType parameterizedType &&
				parameterizedType.getRawType() instanceof Class rawType) {
			return (Class) rawType;
		}
		return (Class) Object.class;
	}




	private class DefaultRequestBodyUriSpec implements RequestBodyUriSpec {

		private final HttpMethod httpMethod;

		@Nullable
		private URI uri;

		@Nullable
		private HttpHeaders headers;

		@Nullable
		private InternalBody body;

		@Nullable
		private String uriTemplate;

		@Nullable
		private Consumer httpRequestConsumer;

		public DefaultRequestBodyUriSpec(HttpMethod httpMethod) {
			this.httpMethod = httpMethod;
		}

		@Override
		public RequestBodySpec uri(String uriTemplate, Object... uriVariables) {
			this.uriTemplate = uriTemplate;
			return uri(DefaultRestClient.this.uriBuilderFactory.expand(uriTemplate, uriVariables));
		}

		@Override
		public RequestBodySpec uri(String uriTemplate, Map uriVariables) {
			this.uriTemplate = uriTemplate;
			return uri(DefaultRestClient.this.uriBuilderFactory.expand(uriTemplate, uriVariables));
		}

		@Override
		public RequestBodySpec uri(String uriTemplate, Function uriFunction) {
			this.uriTemplate = uriTemplate;
			return uri(uriFunction.apply(DefaultRestClient.this.uriBuilderFactory.uriString(uriTemplate)));
		}

		@Override
		public RequestBodySpec uri(Function uriFunction) {
			return uri(uriFunction.apply(DefaultRestClient.this.uriBuilderFactory.builder()));
		}

		@Override
		public RequestBodySpec uri(URI uri) {
			this.uri = uri;
			return this;
		}

		private HttpHeaders getHeaders() {
			if (this.headers == null) {
				this.headers = new HttpHeaders();
			}
			return this.headers;
		}

		@Override
		public DefaultRequestBodyUriSpec header(String headerName, String... headerValues) {
			for (String headerValue : headerValues) {
				getHeaders().add(headerName, headerValue);
			}
			return this;
		}

		@Override
		public DefaultRequestBodyUriSpec headers(Consumer headersConsumer) {
			headersConsumer.accept(getHeaders());
			return this;
		}

		@Override
		public DefaultRequestBodyUriSpec accept(MediaType... acceptableMediaTypes) {
			getHeaders().setAccept(Arrays.asList(acceptableMediaTypes));
			return this;
		}

		@Override
		public DefaultRequestBodyUriSpec acceptCharset(Charset... acceptableCharsets) {
			getHeaders().setAcceptCharset(Arrays.asList(acceptableCharsets));
			return this;
		}

		@Override
		public DefaultRequestBodyUriSpec contentType(MediaType contentType) {
			getHeaders().setContentType(contentType);
			return this;
		}

		@Override
		public DefaultRequestBodyUriSpec contentLength(long contentLength) {
			getHeaders().setContentLength(contentLength);
			return this;
		}

		@Override
		public DefaultRequestBodyUriSpec ifModifiedSince(ZonedDateTime ifModifiedSince) {
			getHeaders().setIfModifiedSince(ifModifiedSince);
			return this;
		}

		@Override
		public DefaultRequestBodyUriSpec ifNoneMatch(String... ifNoneMatches) {
			getHeaders().setIfNoneMatch(Arrays.asList(ifNoneMatches));
			return this;
		}

		@Override
		public RequestBodySpec httpRequest(Consumer requestConsumer) {
			this.httpRequestConsumer = (this.httpRequestConsumer != null ?
					this.httpRequestConsumer.andThen(requestConsumer) : requestConsumer);
			return this;
		}

		@Override
		public RequestBodySpec body(Object body) {
			this.body = clientHttpRequest -> writeWithMessageConverters(body, body.getClass(), clientHttpRequest);
			return this;
		}

		@Override
		public  RequestBodySpec body(T body, ParameterizedTypeReference bodyType) {
			this.body = clientHttpRequest -> writeWithMessageConverters(body, bodyType.getType(), clientHttpRequest);
			return this;
		}

		@Override
		public RequestBodySpec body(StreamingHttpOutputMessage.Body body) {
			this.body = request -> body.writeTo(request.getBody());
			return this;
		}

		@SuppressWarnings({"rawtypes", "unchecked"})
		private void writeWithMessageConverters(Object body, Type bodyType, ClientHttpRequest clientRequest)
				throws IOException {

			MediaType contentType = clientRequest.getHeaders().getContentType();
			Class bodyClass = body.getClass();

			for (HttpMessageConverter messageConverter : DefaultRestClient.this.messageConverters) {
				if (messageConverter instanceof GenericHttpMessageConverter genericMessageConverter) {
					if (genericMessageConverter.canWrite(bodyType, bodyClass, contentType)) {
						logBody(body, contentType, genericMessageConverter);
						genericMessageConverter.write(body, bodyType, contentType, clientRequest);
						return;
					}
				}
				if (messageConverter.canWrite(bodyClass, contentType)) {
					logBody(body, contentType, messageConverter);
					messageConverter.write(body, contentType, clientRequest);
					return;
				}
			}
			String message = "No HttpMessageConverter for " + bodyClass.getName();
			if (contentType != null) {
				message += " and content type \"" + contentType + "\"";
			}
			throw new RestClientException(message);
		}

		private void logBody(Object body, @Nullable MediaType mediaType, HttpMessageConverter converter) {
			if (logger.isDebugEnabled()) {
				StringBuilder msg = new StringBuilder("Writing [");
				msg.append(body);
				msg.append("] ");
				if (mediaType != null) {
					msg.append("as \"");
					msg.append(mediaType);
					msg.append("\" ");
				}
				msg.append("with ");
				msg.append(converter.getClass().getName());
				logger.debug(msg.toString());
			}
		}


		@Override
		public ResponseSpec retrieve() {
			return exchangeInternal(DefaultResponseSpec::new, false);
		}

		@Override
		public  T exchange(ExchangeFunction exchangeFunction, boolean close) {
			return exchangeInternal(exchangeFunction, close);
		}

		private  T exchangeInternal(ExchangeFunction exchangeFunction, boolean close) {
			Assert.notNull(exchangeFunction, "ExchangeFunction must not be null");

			ClientHttpResponse clientResponse = null;
			Observation observation = null;
			URI uri = null;
			try {
				uri = initUri();
				HttpHeaders headers = initHeaders();
				ClientHttpRequest clientRequest = createRequest(uri);
				clientRequest.getHeaders().addAll(headers);
				ClientRequestObservationContext observationContext = new ClientRequestObservationContext(clientRequest);
				observationContext.setUriTemplate(this.uriTemplate);
				observation = ClientHttpObservationDocumentation.HTTP_CLIENT_EXCHANGES.observation(observationConvention,
						DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, observationRegistry).start();
				if (this.body != null) {
					this.body.writeTo(clientRequest);
				}
				if (this.httpRequestConsumer != null) {
					this.httpRequestConsumer.accept(clientRequest);
				}
				clientResponse = clientRequest.execute();
				observationContext.setResponse(clientResponse);
				ConvertibleClientHttpResponse convertibleWrapper = new DefaultConvertibleClientHttpResponse(clientResponse);
				return exchangeFunction.exchange(clientRequest, convertibleWrapper);
			}
			catch (IOException ex) {
				ResourceAccessException resourceAccessException = createResourceAccessException(uri, this.httpMethod, ex);
				if (observation != null) {
					observation.error(resourceAccessException);
				}
				throw resourceAccessException;
			}
			catch (Throwable error) {
				if (observation != null) {
					observation.error(error);
				}
				throw error;
			}
			finally {
				if (close && clientResponse != null) {
					clientResponse.close();
				}
				if (observation != null) {
					observation.stop();
				}
			}
		}

		private URI initUri() {
			return (this.uri != null ? this.uri : DefaultRestClient.this.uriBuilderFactory.expand(""));
		}

		private HttpHeaders initHeaders() {
			HttpHeaders defaultHeaders = DefaultRestClient.this.defaultHeaders;
			if (CollectionUtils.isEmpty(this.headers)) {
				return (defaultHeaders != null ? defaultHeaders : new HttpHeaders());
			}
			else if (CollectionUtils.isEmpty(defaultHeaders)) {
				return this.headers;
			}
			else {
				HttpHeaders result = new HttpHeaders();
				result.putAll(defaultHeaders);
				result.putAll(this.headers);
				return result;
			}
		}

		private ClientHttpRequest createRequest(URI uri) throws IOException {
			ClientHttpRequestFactory factory;
			if (DefaultRestClient.this.interceptors != null) {
				factory = DefaultRestClient.this.interceptingRequestFactory;
				if (factory == null) {
					factory = new InterceptingClientHttpRequestFactory(DefaultRestClient.this.clientRequestFactory, DefaultRestClient.this.interceptors);
					DefaultRestClient.this.interceptingRequestFactory = factory;
				}
			}
			else {
				factory = DefaultRestClient.this.clientRequestFactory;
			}
			ClientHttpRequest request = factory.createRequest(uri, this.httpMethod);
			if (DefaultRestClient.this.initializers != null) {
				DefaultRestClient.this.initializers.forEach(initializer -> initializer.initialize(request));
			}
			return request;
		}

		private static ResourceAccessException createResourceAccessException(URI url, HttpMethod method, IOException ex) {
			StringBuilder msg = new StringBuilder("I/O error on ");
			msg.append(method.name());
			msg.append(" request for \"");
			String urlString = url.toString();
			int idx = urlString.indexOf('?');
			if (idx != -1) {
				msg.append(urlString, 0, idx);
			}
			else {
				msg.append(urlString);
			}
			msg.append("\": ");
			msg.append(ex.getMessage());
			return new ResourceAccessException(msg.toString(), ex);
		}


		@FunctionalInterface
		private interface InternalBody {

			void writeTo(ClientHttpRequest request) throws IOException;
		}
	}


	private class DefaultResponseSpec implements ResponseSpec {

		private final HttpRequest clientRequest;

		private final ClientHttpResponse clientResponse;

		private final List statusHandlers = new ArrayList<>(1);

		private final int defaultStatusHandlerCount;

		DefaultResponseSpec(HttpRequest clientRequest, ClientHttpResponse clientResponse) {
			this.clientRequest = clientRequest;
			this.clientResponse = clientResponse;
			this.statusHandlers.addAll(DefaultRestClient.this.defaultStatusHandlers);
			this.statusHandlers.add(StatusHandler.defaultHandler(DefaultRestClient.this.messageConverters));
			this.defaultStatusHandlerCount = this.statusHandlers.size();
		}

		@Override
		public ResponseSpec onStatus(Predicate statusPredicate, ErrorHandler errorHandler) {
			Assert.notNull(statusPredicate, "StatusPredicate must not be null");
			Assert.notNull(errorHandler, "ErrorHandler must not be null");

			return onStatusInternal(StatusHandler.of(statusPredicate, errorHandler));
		}

		@Override
		public ResponseSpec onStatus(ResponseErrorHandler errorHandler) {
			Assert.notNull(errorHandler, "ResponseErrorHandler must not be null");

			return onStatusInternal(StatusHandler.fromErrorHandler(errorHandler));
		}

		private ResponseSpec onStatusInternal(StatusHandler statusHandler) {
			Assert.notNull(statusHandler, "StatusHandler must not be null");

			int index = this.statusHandlers.size() - this.defaultStatusHandlerCount;  // Default handlers always last
			this.statusHandlers.add(index, statusHandler);
			return this;
		}

		@Override
		public  T body(Class bodyType) {
			return readBody(bodyType, bodyType);
		}

		@Override
		public  T body(ParameterizedTypeReference bodyType) {
			Type type = bodyType.getType();
			Class bodyClass = bodyClass(type);
			return readBody(type, bodyClass);
		}

		@Override
		public  ResponseEntity toEntity(Class bodyType) {
			return toEntityInternal(bodyType, bodyType);
		}

		@Override
		public  ResponseEntity toEntity(ParameterizedTypeReference bodyType) {
			Type type = bodyType.getType();
			Class bodyClass = bodyClass(type);
			return toEntityInternal(type, bodyClass);
		}

		private  ResponseEntity toEntityInternal(Type bodyType, Class bodyClass) {
			T body = readBody(bodyType, bodyClass);
			try {
				return ResponseEntity.status(this.clientResponse.getStatusCode())
						.headers(this.clientResponse.getHeaders())
						.body(body);
			}
			catch (IOException ex) {
				throw new ResourceAccessException("Could not retrieve response status code: " + ex.getMessage(), ex);
			}
		}

		@Override
		public ResponseEntity toBodilessEntity() {
			try (this.clientResponse) {
				applyStatusHandlers();
				return ResponseEntity.status(this.clientResponse.getStatusCode())
						.headers(this.clientResponse.getHeaders())
						.build();
			}
			catch (UncheckedIOException ex) {
				throw new ResourceAccessException("Could not retrieve response status code: " + ex.getMessage(), ex.getCause());
			}
			catch (IOException ex) {
				throw new ResourceAccessException("Could not retrieve response status code: " + ex.getMessage(), ex);
			}
		}


		private  T readBody(Type bodyType, Class bodyClass) {
			return DefaultRestClient.this.readWithMessageConverters(this.clientResponse, this::applyStatusHandlers,
					bodyType, bodyClass);

		}

		private void applyStatusHandlers() {
			try {
				ClientHttpResponse response = this.clientResponse;
				if (response instanceof DefaultConvertibleClientHttpResponse convertibleResponse) {
					response = convertibleResponse.delegate;
				}
				for (StatusHandler handler : this.statusHandlers) {
					if (handler.test(response)) {
						handler.handle(this.clientRequest, response);
						return;
					}
				}
			}
			catch (IOException ex) {
				throw new UncheckedIOException(ex);
			}
		}
	}


	private class DefaultConvertibleClientHttpResponse implements RequestHeadersSpec.ConvertibleClientHttpResponse {

		private final ClientHttpResponse delegate;


		public DefaultConvertibleClientHttpResponse(ClientHttpResponse delegate) {
			this.delegate = delegate;
		}


		@Nullable
		@Override
		public  T bodyTo(Class bodyType) {
			return readWithMessageConverters(this.delegate, () -> {} , bodyType, bodyType);
		}

		@Nullable
		@Override
		public  T bodyTo(ParameterizedTypeReference bodyType) {
			Type type = bodyType.getType();
			Class bodyClass = bodyClass(type);
			return readWithMessageConverters(this.delegate, () -> {} , type, bodyClass);
		}

		@Override
		public InputStream getBody() throws IOException {
			return this.delegate.getBody();
		}

		@Override
		public HttpHeaders getHeaders() {
			return this.delegate.getHeaders();
		}

		@Override
		public HttpStatusCode getStatusCode() throws IOException {
			return this.delegate.getStatusCode();
		}

		@Override
		public String getStatusText() throws IOException {
			return this.delegate.getStatusText();
		}

		@Override
		public void close() {
			this.delegate.close();
		}

	}


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy