All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.springframework.messaging.rsocket.DefaultRSocketRequester Maven / Gradle / Ivy

There is a newer version: 6.2.0
Show newest version
/*
 * Copyright 2002-2023 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.messaging.rsocket;

import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;

import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.core.RSocketClient;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;

/**
 * Default implementation of {@link RSocketRequester}.
 *
 * @author Rossen Stoyanchev
 * @since 5.2
 */
final class DefaultRSocketRequester implements RSocketRequester {

	private static final Map EMPTY_HINTS = Collections.emptyMap();

	private final RSocketClient rsocketClient;

	@Nullable
	private final RSocket rsocket;

	private final MimeType dataMimeType;

	private final MimeType metadataMimeType;

	private final RSocketStrategies strategies;

	private final Mono emptyBufferMono;


	DefaultRSocketRequester(
			@Nullable RSocketClient rsocketClient, @Nullable RSocket rsocket,
			MimeType dataMimeType, MimeType metadataMimeType, RSocketStrategies strategies) {

		Assert.isTrue(rsocketClient != null || rsocket != null, "RSocketClient or RSocket is required");
		Assert.notNull(dataMimeType, "'dataMimeType' is required");
		Assert.notNull(metadataMimeType, "'metadataMimeType' is required");
		Assert.notNull(strategies, "RSocketStrategies is required");

		this.rsocketClient = (rsocketClient != null ? rsocketClient : RSocketClient.from(rsocket));
		this.rsocket = rsocket;
		this.dataMimeType = dataMimeType;
		this.metadataMimeType = metadataMimeType;
		this.strategies = strategies;
		this.emptyBufferMono = Mono.just(this.strategies.dataBufferFactory().wrap(new byte[0]));
	}


	@Override
	public RSocketClient rsocketClient() {
		return this.rsocketClient;
	}

	@Nullable
	@Override
	public RSocket rsocket() {
		return this.rsocket;
	}

	@Override
	public MimeType dataMimeType() {
		return this.dataMimeType;
	}

	@Override
	public MimeType metadataMimeType() {
		return this.metadataMimeType;
	}

	@Override
	public RSocketStrategies strategies() {
		return this.strategies;
	}


	@Override
	public RequestSpec route(String route, Object... vars) {
		return new DefaultRequestSpec(route, vars);
	}

	@Override
	public RequestSpec metadata(Object metadata, @Nullable MimeType mimeType) {
		return new DefaultRequestSpec(metadata, mimeType);
	}

	private static boolean isVoid(ResolvableType elementType) {
		return (Void.class.equals(elementType.resolve()) || void.class.equals(elementType.resolve()));
	}

	private DataBufferFactory bufferFactory() {
		return this.strategies.dataBufferFactory();
	}


	private class DefaultRequestSpec implements RequestSpec {

		private final MetadataEncoder metadataEncoder = new MetadataEncoder(metadataMimeType(), strategies);

		@Nullable
		private Mono payloadMono;

		@Nullable
		private Flux payloadFlux;


		public DefaultRequestSpec(String route, Object... vars) {
			this.metadataEncoder.route(route, vars);
		}

		public DefaultRequestSpec(Object metadata, @Nullable MimeType mimeType) {
			this.metadataEncoder.metadata(metadata, mimeType);
		}


		@Override
		public RequestSpec metadata(Object metadata, MimeType mimeType) {
			this.metadataEncoder.metadata(metadata, mimeType);
			return this;
		}

		@Override
		public RequestSpec metadata(Consumer> configurer) {
			configurer.accept(this);
			return this;
		}

		@Override
		public RequestSpec data(Object data) {
			Assert.notNull(data, "'data' must not be null");
			createPayload(data, ResolvableType.NONE);
			return this;
		}

		@Override
		public RequestSpec data(Object producer, Class elementClass) {
			Assert.notNull(producer, "'producer' must not be null");
			Assert.notNull(elementClass, "'elementClass' must not be null");
			ReactiveAdapter adapter = getAdapter(producer.getClass());
			Assert.notNull(adapter, () -> "'producer' type is unknown to ReactiveAdapterRegistry: " +
					producer.getClass().getName());
			createPayload(adapter.toPublisher(producer), ResolvableType.forClass(elementClass));
			return this;
		}

		@Nullable
		private ReactiveAdapter getAdapter(Class aClass) {
			return strategies.reactiveAdapterRegistry().getAdapter(aClass);
		}

		@Override
		public RequestSpec data(Object producer, ParameterizedTypeReference elementTypeRef) {
			Assert.notNull(producer, "'producer' must not be null");
			Assert.notNull(elementTypeRef, "'elementTypeRef' must not be null");
			ReactiveAdapter adapter = getAdapter(producer.getClass());
			Assert.notNull(adapter, () -> "'producer' type is unknown to ReactiveAdapterRegistry: " +
					producer.getClass().getName());
			createPayload(adapter.toPublisher(producer), ResolvableType.forType(elementTypeRef));
			return this;
		}

		private void createPayload(Object input, ResolvableType elementType) {
			ReactiveAdapter adapter = getAdapter(input.getClass());
			Publisher publisher;
			if (input instanceof Publisher _publisher) {
				publisher = _publisher;
			}
			else if (adapter != null) {
				publisher = adapter.toPublisher(input);
			}
			else {
				ResolvableType type = ResolvableType.forInstance(input);
				this.payloadMono = firstPayload(Mono.fromCallable(() -> encodeData(input, type, null)));
				this.payloadFlux = null;
				return;
			}

			if (isVoid(elementType) || (adapter != null && adapter.isNoValue())) {
				this.payloadMono = Mono.when(publisher).then(firstPayload(emptyBufferMono));
				this.payloadFlux = null;
				return;
			}

			Encoder encoder = elementType != ResolvableType.NONE && !Object.class.equals(elementType.resolve()) ?
					strategies.encoder(elementType, dataMimeType) : null;

			if (adapter != null && !adapter.isMultiValue()) {
				Mono data = Mono.from(publisher)
						.map(value -> encodeData(value, elementType, encoder))
						.switchIfEmpty(emptyBufferMono);
				this.payloadMono = firstPayload(data);
				this.payloadFlux = null;
				return;
			}

			this.payloadMono = null;
			this.payloadFlux = Flux.from(publisher)
					.map(value -> encodeData(value, elementType, encoder))
					.switchIfEmpty(emptyBufferMono)
					.switchOnFirst((signal, inner) -> {
						DataBuffer data = signal.get();
						if (data != null) {
							return firstPayload(Mono.fromCallable(() -> data))
									.concatWith(inner.skip(1).map(PayloadUtils::createPayload));
						}
						else {
							return inner.map(PayloadUtils::createPayload);
						}
					})
					.doOnDiscard(Payload.class, Payload::release);
		}

		@SuppressWarnings("unchecked")
		private  DataBuffer encodeData(T value, ResolvableType elementType, @Nullable Encoder encoder) {
			if (encoder == null) {
				elementType = ResolvableType.forInstance(value);
				encoder = strategies.encoder(elementType, dataMimeType);
			}
			return ((Encoder) encoder).encodeValue(
					value, bufferFactory(), elementType, dataMimeType, EMPTY_HINTS);
		}

		/**
		 * Create the 1st request payload with encoded data and metadata.
		 * @param encodedData the encoded payload data; expected to not be empty!
		 */
		private Mono firstPayload(Mono encodedData) {
			return Mono.zip(encodedData, this.metadataEncoder.encode())
					.map(tuple -> PayloadUtils.createPayload(tuple.getT1(), tuple.getT2()))
					.doOnDiscard(DataBuffer.class, DataBufferUtils::release)
					.doOnDiscard(Payload.class, Payload::release);
		}

		@Override
		public Mono sendMetadata() {
			return rsocketClient.metadataPush(getPayloadMono());
		}

		@Override
		public Mono send() {
			return rsocketClient.fireAndForget(getPayloadMono());
		}

		@Override
		public  Mono retrieveMono(Class dataType) {
			return retrieveMono(ResolvableType.forClass(dataType));
		}

		@Override
		public  Mono retrieveMono(ParameterizedTypeReference dataTypeRef) {
			return retrieveMono(ResolvableType.forType(dataTypeRef));
		}

		@SuppressWarnings("unchecked")
		private  Mono retrieveMono(ResolvableType elementType) {
			Mono payloadMono = rsocketClient.requestResponse(getPayloadMono());

			if (isVoid(elementType)) {
				return (Mono) payloadMono.then();
			}

			Decoder decoder = strategies.decoder(elementType, dataMimeType);
			return (Mono) payloadMono.map(this::retainDataAndReleasePayload)
					.map(dataBuffer -> decoder.decode(dataBuffer, elementType, dataMimeType, EMPTY_HINTS));
		}

		@Override
		public  Flux retrieveFlux(Class dataType) {
			return retrieveFlux(ResolvableType.forClass(dataType));
		}

		@Override
		public  Flux retrieveFlux(ParameterizedTypeReference dataTypeRef) {
			return retrieveFlux(ResolvableType.forType(dataTypeRef));
		}

		@SuppressWarnings("unchecked")
		private  Flux retrieveFlux(ResolvableType elementType) {

			Flux payloadFlux = (this.payloadFlux != null ?
					rsocketClient.requestChannel(this.payloadFlux) :
					rsocketClient.requestStream(getPayloadMono()));

			if (isVoid(elementType)) {
				return payloadFlux.thenMany(Flux.empty());
			}

			Decoder decoder = strategies.decoder(elementType, dataMimeType);
			return payloadFlux.map(this::retainDataAndReleasePayload).map(dataBuffer ->
					(T) decoder.decode(dataBuffer, elementType, dataMimeType, EMPTY_HINTS));
		}

		private Mono getPayloadMono() {
			Assert.state(this.payloadFlux == null, "No RSocket interaction with Flux request and Mono response.");
			return this.payloadMono != null ? this.payloadMono : firstPayload(emptyBufferMono);
		}

		private DataBuffer retainDataAndReleasePayload(Payload payload) {
			return PayloadUtils.retainDataAndReleasePayload(payload, bufferFactory());
		}
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy