// io.rsocket.RSocketFactory Maven / Gradle / Ivy
/*
* Copyright 2015-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.rsocket;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.rsocket.exceptions.InvalidSetupException;
import io.rsocket.exceptions.RejectedSetupException;
import io.rsocket.frame.FrameHeaderFlyweight;
import io.rsocket.frame.ResumeFrameFlyweight;
import io.rsocket.frame.SetupFrameFlyweight;
import io.rsocket.frame.decoder.PayloadDecoder;
import io.rsocket.internal.ClientServerInputMultiplexer;
import io.rsocket.internal.ClientSetup;
import io.rsocket.internal.KeepAliveData;
import io.rsocket.internal.ServerSetup;
import io.rsocket.keepalive.KeepAliveConnection;
import io.rsocket.plugins.DuplexConnectionInterceptor;
import io.rsocket.plugins.PluginRegistry;
import io.rsocket.plugins.Plugins;
import io.rsocket.plugins.RSocketInterceptor;
import io.rsocket.resume.*;
import io.rsocket.transport.ClientTransport;
import io.rsocket.transport.ServerTransport;
import io.rsocket.util.ConnectionUtils;
import io.rsocket.util.EmptyPayload;
import java.time.Duration;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import reactor.core.publisher.Mono;
/** Factory for creating RSocket clients and servers. */
public class RSocketFactory {
/**
* Creates a factory that establishes client connections to other RSockets.
*
* @return a client factory
*/
public static ClientRSocketFactory connect() {
return new ClientRSocketFactory();
}
/**
* Creates a factory that receives server connections from client RSockets.
*
* @return a server factory.
*/
public static ServerRSocketFactory receive() {
return new ServerRSocketFactory();
}
public interface Start {
Mono start();
}
public interface ClientTransportAcceptor {
Start transport(Supplier transport);
default Start transport(ClientTransport transport) {
return transport(() -> transport);
}
}
public interface ServerTransportAcceptor {
ServerTransport.ConnectionAcceptor toConnectionAcceptor();
Start transport(Supplier> transport);
default Start transport(ServerTransport transport) {
return transport(() -> transport);
}
}
public static class ClientRSocketFactory implements ClientTransportAcceptor {
private Supplier> acceptor =
() -> rSocket -> new AbstractRSocket() {};
private Consumer errorConsumer = Throwable::printStackTrace;
private int mtu = 0;
private PluginRegistry plugins = new PluginRegistry(Plugins.defaultPlugins());
private Payload setupPayload = EmptyPayload.INSTANCE;
private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT;
private Duration tickPeriod = Duration.ofSeconds(20);
private Duration ackTimeout = Duration.ofSeconds(30);
private int missedAcks = 3;
private String metadataMimeType = "application/binary";
private String dataMimeType = "application/binary";
private boolean resumeEnabled;
private boolean resumeCleanupStoreOnKeepAlive;
private Supplier resumeTokenSupplier = ResumeFrameFlyweight::generateResumeToken;
private Function super ByteBuf, ? extends ResumableFramesStore> resumeStoreFactory =
token -> new InMemoryResumableFramesStore("client", 100_000);
private Duration resumeSessionDuration = Duration.ofMinutes(2);
private Duration resumeStreamTimeout = Duration.ofSeconds(10);
private Supplier resumeStrategySupplier =
() ->
new ExponentialBackoffResumeStrategy(Duration.ofSeconds(1), Duration.ofSeconds(16), 2);
private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;
public ClientRSocketFactory byteBufAllocator(ByteBufAllocator allocator) {
Objects.requireNonNull(allocator);
this.allocator = allocator;
return this;
}
public ClientRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) {
plugins.addConnectionPlugin(interceptor);
return this;
}
public ClientRSocketFactory addClientPlugin(RSocketInterceptor interceptor) {
plugins.addClientPlugin(interceptor);
return this;
}
public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) {
plugins.addServerPlugin(interceptor);
return this;
}
/**
* Deprecated as Keep-Alive is not optional according to spec
*
* @return this ClientRSocketFactory
*/
@Deprecated
public ClientRSocketFactory keepAlive() {
return this;
}
public ClientRSocketFactory keepAlive(
Duration tickPeriod, Duration ackTimeout, int missedAcks) {
this.tickPeriod = tickPeriod;
this.ackTimeout = ackTimeout;
this.missedAcks = missedAcks;
return this;
}
public ClientRSocketFactory keepAliveTickPeriod(Duration tickPeriod) {
this.tickPeriod = tickPeriod;
return this;
}
public ClientRSocketFactory keepAliveAckTimeout(Duration ackTimeout) {
this.ackTimeout = ackTimeout;
return this;
}
public ClientRSocketFactory keepAliveMissedAcks(int missedAcks) {
this.missedAcks = missedAcks;
return this;
}
public ClientRSocketFactory mimeType(String metadataMimeType, String dataMimeType) {
this.dataMimeType = dataMimeType;
this.metadataMimeType = metadataMimeType;
return this;
}
public ClientRSocketFactory dataMimeType(String dataMimeType) {
this.dataMimeType = dataMimeType;
return this;
}
public ClientRSocketFactory metadataMimeType(String metadataMimeType) {
this.metadataMimeType = metadataMimeType;
return this;
}
public ClientRSocketFactory resume() {
this.resumeEnabled = true;
return this;
}
public ClientRSocketFactory resumeToken(Supplier resumeTokenSupplier) {
this.resumeTokenSupplier = Objects.requireNonNull(resumeTokenSupplier);
return this;
}
public ClientRSocketFactory resumeStore(
Function super ByteBuf, ? extends ResumableFramesStore> resumeStoreFactory) {
this.resumeStoreFactory = resumeStoreFactory;
return this;
}
public ClientRSocketFactory resumeSessionDuration(Duration sessionDuration) {
this.resumeSessionDuration = Objects.requireNonNull(sessionDuration);
return this;
}
public ClientRSocketFactory resumeStreamTimeout(Duration resumeStreamTimeout) {
this.resumeStreamTimeout = Objects.requireNonNull(resumeStreamTimeout);
return this;
}
public ClientRSocketFactory resumeStrategy(Supplier resumeStrategy) {
this.resumeStrategySupplier = Objects.requireNonNull(resumeStrategy);
return this;
}
public ClientRSocketFactory resumeCleanupOnKeepAlive() {
resumeCleanupStoreOnKeepAlive = true;
return this;
}
@Override
public Start transport(Supplier transportClient) {
return new StartClient(transportClient);
}
public ClientTransportAcceptor acceptor(Function acceptor) {
this.acceptor = () -> acceptor;
return StartClient::new;
}
public ClientTransportAcceptor acceptor(Supplier> acceptor) {
this.acceptor = acceptor;
return StartClient::new;
}
public ClientRSocketFactory fragment(int mtu) {
this.mtu = mtu;
return this;
}
public ClientRSocketFactory errorConsumer(Consumer errorConsumer) {
this.errorConsumer = errorConsumer;
return this;
}
public ClientRSocketFactory setupPayload(Payload payload) {
this.setupPayload = payload;
return this;
}
public ClientRSocketFactory frameDecoder(PayloadDecoder payloadDecoder) {
this.payloadDecoder = payloadDecoder;
return this;
}
private class StartClient implements Start {
private final Supplier transportClient;
StartClient(Supplier transportClient) {
this.transportClient = transportClient;
}
@Override
public Mono start() {
return newConnection()
.flatMap(
connection -> {
ClientSetup clientSetup = clientSetup();
DuplexConnection wrappedConnection = clientSetup.wrappedConnection(connection);
ByteBuf resumeToken = clientSetup.resumeToken();
ClientServerInputMultiplexer multiplexer =
new ClientServerInputMultiplexer(wrappedConnection, plugins);
RSocketClient rSocketClient =
new RSocketClient(
allocator,
multiplexer.asClientConnection(),
payloadDecoder,
errorConsumer,
StreamIdSupplier.clientSupplier());
RSocket wrappedRSocketClient = plugins.applyClient(rSocketClient);
RSocket unwrappedServerSocket = acceptor.get().apply(wrappedRSocketClient);
RSocket wrappedRSocketServer = plugins.applyServer(unwrappedServerSocket);
RSocketServer rSocketServer =
new RSocketServer(
allocator,
multiplexer.asServerConnection(),
wrappedRSocketServer,
payloadDecoder,
errorConsumer);
ByteBuf setupFrame =
SetupFrameFlyweight.encode(
allocator,
false,
(int) keepAliveTickPeriod(),
(int) keepAliveTimeout(),
resumeToken,
metadataMimeType,
dataMimeType,
setupPayload.sliceMetadata(),
setupPayload.sliceData());
return wrappedConnection.sendOne(setupFrame).thenReturn(wrappedRSocketClient);
});
}
private long keepAliveTickPeriod() {
return tickPeriod.toMillis();
}
private long keepAliveTimeout() {
return ackTimeout.toMillis() + tickPeriod.toMillis() * missedAcks;
}
private ClientSetup clientSetup() {
if (resumeEnabled) {
ByteBuf resumeToken = resumeTokenSupplier.get();
return new ClientSetup.ResumableClientSetup(
allocator,
newConnection(),
resumeToken,
resumeStoreFactory.apply(resumeToken),
resumeSessionDuration,
resumeStreamTimeout,
resumeStrategySupplier,
resumeCleanupStoreOnKeepAlive);
} else {
return new ClientSetup.DefaultClientSetup();
}
}
private Mono newConnection() {
return transportClient
.get()
.connect(mtu)
.map(
connection ->
KeepAliveConnection.ofClient(
allocator,
connection,
notUsed -> new KeepAliveData(keepAliveTickPeriod(), keepAliveTimeout()),
errorConsumer));
}
}
}
public static class ServerRSocketFactory {
private SocketAcceptor acceptor;
private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT;
private Consumer errorConsumer = Throwable::printStackTrace;
private int mtu = 0;
private PluginRegistry plugins = new PluginRegistry(Plugins.defaultPlugins());
private boolean resumeSupported;
private Duration resumeSessionDuration = Duration.ofSeconds(120);
private Duration resumeStreamTimeout = Duration.ofSeconds(10);
private Function super ByteBuf, ? extends ResumableFramesStore> resumeStoreFactory =
token -> new InMemoryResumableFramesStore("server", 100_000);
private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;
private boolean resumeCleanupStoreOnKeepAlive;
private ServerRSocketFactory() {}
public ServerRSocketFactory byteBufAllocator(ByteBufAllocator allocator) {
Objects.requireNonNull(allocator);
this.allocator = allocator;
return this;
}
public ServerRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) {
plugins.addConnectionPlugin(interceptor);
return this;
}
public ServerRSocketFactory addClientPlugin(RSocketInterceptor interceptor) {
plugins.addClientPlugin(interceptor);
return this;
}
public ServerRSocketFactory addServerPlugin(RSocketInterceptor interceptor) {
plugins.addServerPlugin(interceptor);
return this;
}
public ServerTransportAcceptor acceptor(SocketAcceptor acceptor) {
this.acceptor = acceptor;
return new ServerStart<>();
}
public ServerRSocketFactory frameDecoder(PayloadDecoder payloadDecoder) {
this.payloadDecoder = payloadDecoder;
return this;
}
public ServerRSocketFactory fragment(int mtu) {
this.mtu = mtu;
return this;
}
public ServerRSocketFactory errorConsumer(Consumer errorConsumer) {
this.errorConsumer = errorConsumer;
return this;
}
public ServerRSocketFactory resume() {
this.resumeSupported = true;
return this;
}
public ServerRSocketFactory resumeStore(
Function super ByteBuf, ? extends ResumableFramesStore> resumeStoreFactory) {
this.resumeStoreFactory = resumeStoreFactory;
return this;
}
public ServerRSocketFactory resumeSessionDuration(Duration sessionDuration) {
this.resumeSessionDuration = Objects.requireNonNull(sessionDuration);
return this;
}
public ServerRSocketFactory resumeStreamTimeout(Duration resumeStreamTimeout) {
this.resumeStreamTimeout = Objects.requireNonNull(resumeStreamTimeout);
return this;
}
public ServerRSocketFactory resumeCleanupOnKeepAlive() {
resumeCleanupStoreOnKeepAlive = true;
return this;
}
private class ServerStart implements Start, ServerTransportAcceptor {
private Supplier> transportServer;
@Override
public ServerTransport.ConnectionAcceptor toConnectionAcceptor() {
return new ServerTransport.ConnectionAcceptor() {
private final ServerSetup serverSetup = serverSetup();
@Override
public Mono apply(DuplexConnection connection) {
return acceptor(serverSetup, connection);
}
};
}
@Override
@SuppressWarnings("unchecked")
public Start transport(Supplier> transport) {
this.transportServer = (Supplier) transport;
return (Start) this::start;
}
private Mono acceptor(ServerSetup serverSetup, DuplexConnection connection) {
connection =
KeepAliveConnection.ofServer(
allocator, connection, serverSetup::keepAliveData, errorConsumer);
ClientServerInputMultiplexer multiplexer =
new ClientServerInputMultiplexer(connection, plugins);
return multiplexer
.asSetupConnection()
.receive()
.next()
.flatMap(startFrame -> accept(serverSetup, startFrame, multiplexer));
}
private Mono acceptResume(
ServerSetup serverSetup, ByteBuf resumeFrame, ClientServerInputMultiplexer multiplexer) {
return serverSetup.acceptRSocketResume(resumeFrame, multiplexer);
}
private Mono accept(
ServerSetup serverSetup, ByteBuf startFrame, ClientServerInputMultiplexer multiplexer) {
switch (FrameHeaderFlyweight.frameType(startFrame)) {
case SETUP:
return acceptSetup(serverSetup, startFrame, multiplexer);
case RESUME:
return acceptResume(serverSetup, startFrame, multiplexer);
default:
return acceptUnknown(startFrame, multiplexer);
}
}
private Mono acceptSetup(
ServerSetup serverSetup, ByteBuf setupFrame, ClientServerInputMultiplexer multiplexer) {
if (!SetupFrameFlyweight.isSupportedVersion(setupFrame)) {
return sendError(
multiplexer,
new InvalidSetupException(
"Unsupported version: "
+ SetupFrameFlyweight.humanReadableVersion(setupFrame)))
.doFinally(
signalType -> {
setupFrame.release();
multiplexer.dispose();
});
}
return serverSetup.acceptRSocketSetup(
setupFrame,
multiplexer,
wrappedMultiplexer -> {
ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(setupFrame);
RSocketClient rSocketClient =
new RSocketClient(
allocator,
wrappedMultiplexer.asServerConnection(),
payloadDecoder,
errorConsumer,
StreamIdSupplier.serverSupplier());
RSocket wrappedRSocketClient = plugins.applyClient(rSocketClient);
return acceptor
.accept(setupPayload, wrappedRSocketClient)
.onErrorResume(
err -> sendError(multiplexer, rejectedSetupError(err)).then(Mono.error(err)))
.doOnNext(
unwrappedServerSocket -> {
RSocket wrappedRSocketServer = plugins.applyServer(unwrappedServerSocket);
RSocketServer rSocketServer =
new RSocketServer(
allocator,
wrappedMultiplexer.asClientConnection(),
wrappedRSocketServer,
payloadDecoder,
errorConsumer);
})
.doFinally(signalType -> setupPayload.release())
.then();
});
}
@Override
public Mono start() {
return Mono.defer(
new Supplier>() {
ServerSetup serverSetup = serverSetup();
@Override
public Mono get() {
return transportServer
.get()
.start(duplexConnection -> acceptor(serverSetup, duplexConnection), mtu)
.doOnNext(c -> c.onClose().doFinally(v -> serverSetup.dispose()).subscribe());
}
});
}
private ServerSetup serverSetup() {
return resumeSupported
? new ServerSetup.ResumableServerSetup(
allocator,
new SessionManager(),
resumeSessionDuration,
resumeStreamTimeout,
resumeStoreFactory,
resumeCleanupStoreOnKeepAlive)
: new ServerSetup.DefaultServerSetup(allocator);
}
private Mono acceptUnknown(ByteBuf frame, ClientServerInputMultiplexer multiplexer) {
return sendError(
multiplexer,
new InvalidSetupException(
"invalid setup frame: " + FrameHeaderFlyweight.frameType(frame)))
.doFinally(
signalType -> {
frame.release();
multiplexer.dispose();
});
}
private Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) {
return ConnectionUtils.sendError(allocator, multiplexer, exception);
}
private Exception rejectedSetupError(Throwable err) {
String msg = err.getMessage();
return new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg);
}
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy