All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.jauntsdn.rsocket.RSocketRequester Maven / Gradle / Ivy

There is a newer version: 0.9.8
Show newest version
/*
 * Copyright 2015-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.jauntsdn.rsocket;

import static com.jauntsdn.rsocket.StreamErrorMappers.*;
import static com.jauntsdn.rsocket.keepalive.KeepAlive.ClientKeepAlive;

import com.jauntsdn.rsocket.exceptions.ConnectionErrorException;
import com.jauntsdn.rsocket.exceptions.Exceptions;
import com.jauntsdn.rsocket.frame.*;
import com.jauntsdn.rsocket.frame.decoder.PayloadDecoder;
import com.jauntsdn.rsocket.internal.SynchronizedIntObjectHashMap;
import com.jauntsdn.rsocket.internal.UnboundedProcessor;
import com.jauntsdn.rsocket.keepalive.KeepAlive;
import com.jauntsdn.rsocket.keepalive.KeepAlive.ServerKeepAlive;
import com.jauntsdn.rsocket.keepalive.KeepAliveHandler;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.collection.IntObjectMap;
import java.nio.channels.ClosedChannelException;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Consumer;
import java.util.function.LongConsumer;
import javax.annotation.Nonnull;
import org.reactivestreams.Processor;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.*;
import reactor.core.scheduler.Scheduler;

/**
 * Requester Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketResponder} of peer
 */
class RSocketRequester implements RSocket {
  private static final AtomicReferenceFieldUpdater TERMINATION_ERROR =
      AtomicReferenceFieldUpdater.newUpdater(
          RSocketRequester.class, Throwable.class, "terminationError");
  private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();

  static {
    CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]);
  }

  private final DuplexConnection connection;
  private final PayloadDecoder payloadDecoder;
  private final Consumer errorConsumer;
  private final ErrorFrameMapper errorFrameMapper;
  private final StreamIdSupplier streamIdSupplier;
  private final IntObjectMap senders;
  private final IntObjectMap> receivers;
  private final UnboundedProcessor sendProcessor;
  private final ByteBufAllocator allocator;
  private final Consumer keepAliveFramesAcceptor;
  private final Scheduler transportScheduler;
  private volatile Throwable terminationError;
  private final KeepAlive keepAlive;

  RSocketRequester(
      ByteBufAllocator allocator,
      DuplexConnection connection,
      PayloadDecoder payloadDecoder,
      Consumer errorConsumer,
      ErrorFrameMapper errorFrameMapper,
      StreamIdSupplier streamIdSupplier,
      int keepAliveTickPeriod,
      int keepAliveAckTimeout,
      KeepAliveHandler keepAliveHandler) {
    this.allocator = allocator;
    this.connection = connection;
    this.payloadDecoder = payloadDecoder;
    this.errorConsumer = errorConsumer;
    this.errorFrameMapper = errorFrameMapper;
    this.streamIdSupplier = streamIdSupplier;
    this.senders = new SynchronizedIntObjectHashMap<>();
    this.receivers = new SynchronizedIntObjectHashMap<>();
    this.transportScheduler = connection.scheduler();

    // DO NOT Change the order here. The Send processor must be subscribed to before receiving
    this.sendProcessor = new UnboundedProcessor<>();

    connection
        .onClose()
        .doFinally(signalType -> tryTerminate(CLOSED_CHANNEL_EXCEPTION))
        .subscribe(null, errorConsumer);
    connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError);

    connection.receive().subscribe(this::handleIncomingFrames, errorConsumer);

    KeepAlive keepAlive =
        keepAliveTickPeriod != 0
            ? new ClientKeepAlive(
                transportScheduler,
                allocator,
                keepAliveTickPeriod,
                keepAliveAckTimeout,
                keepAliveFrame -> sendProcessor.onNext(keepAliveFrame))
            : new ServerKeepAlive(
                transportScheduler,
                allocator,
                keepAliveAckTimeout,
                keepAliveFrame -> sendProcessor.onNext(keepAliveFrame));
    this.keepAlive = keepAlive;
    this.keepAliveFramesAcceptor =
        keepAliveHandler.start(keepAlive, () -> tryTerminate(keepAliveAckTimeout));
  }

  @Override
  public Mono fireAndForget(Payload payload) {
    return handleFireAndForget(payload);
  }

  @Override
  public Mono requestResponse(Payload payload) {
    return handleRequestResponse(payload);
  }

  @Override
  public Flux requestStream(Payload payload) {
    return handleRequestStream(payload);
  }

  @Override
  public Flux requestChannel(Publisher payloads) {
    return handleChannel(Flux.from(payloads));
  }

  @Override
  public Mono metadataPush(Payload payload) {
    return handleMetadataPush(payload);
  }

  @Override
  public double availability() {
    return connection.availability();
  }

  @Override
  public Optional scheduler() {
    return Optional.of(transportScheduler);
  }

  @Override
  public void dispose() {
    connection.dispose();
  }

  @Override
  public boolean isDisposed() {
    return connection.isDisposed();
  }

  @Override
  public Mono onClose() {
    return connection.onClose();
  }

  Throwable checkAllowed() {
    return terminationError;
  }

  void handleLeaseFrame(ByteBuf frame) {}

  KeepAlive keepAlive() {
    return keepAlive;
  }

  private Mono handleFireAndForget(Payload payload) {
    final Throwable err = checkAllowed();
    if (err != null) {
      payload.release();
      return Mono.error(err);
    }

    return Mono.create(
            new Consumer>() {

              boolean isRequestSent;

              @Override
              public void accept(MonoSink sink) {
                if (isRequestSent) {
                  sink.error(new IllegalStateException("only a single Subscriber is allowed"));
                  return;
                }
                isRequestSent = true;

                Throwable e = terminationError;
                if (e != null) {
                  sink.error(e);
                  payload.release();
                  return;
                }

                final int streamId = streamIdSupplier.nextStreamId(receivers);
                final ByteBuf requestFrame =
                    RequestFireAndForgetFrameFlyweight.encode(
                        allocator,
                        streamId,
                        payload.hasMetadata() ? payload.sliceMetadata().retain() : null,
                        payload.sliceData().retain());
                payload.release();

                sendProcessor.onNext(requestFrame);
                sink.success();
              }
            })
        .subscribeOn(transportScheduler);
  }

  private Mono handleRequestResponse(final Payload payload) {
    final Throwable err = checkAllowed();
    if (err != null) {
      payload.release();
      return Mono.error(err);
    }
    final UnboundedProcessor sendProcessor = this.sendProcessor;

    final Stream stream = new Stream();
    final MonoProcessor receiver = MonoProcessor.create();
    return receiver
        .doOnSubscribe(
            new Consumer() {

              boolean isRequestSent;

              @Override
              public void accept(@Nonnull Subscription subscription) {
                if (isRequestSent) {
                  receiver.onError(
                      new IllegalStateException("only a single Subscriber is allowed"));
                  return;
                }
                isRequestSent = true;

                if (receiver.isTerminated()) {
                  payload.release();
                  return;
                }

                Throwable err = terminationError;
                if (err != null) {
                  receiver.onError(err);
                  payload.release();
                  return;
                }

                int streamId = stream.setId(streamIdSupplier.nextStreamId(receivers));
                receivers.put(streamId, receiver);

                final ByteBuf requestFrame =
                    RequestResponseFrameFlyweight.encode(
                        allocator,
                        streamId,
                        payload.hasMetadata() ? payload.sliceMetadata().retain() : null,
                        payload.sliceData().retain());
                payload.release();

                sendProcessor.onNext(requestFrame);
              }
            })
        .doFinally(
            signalType -> {
              int streamId = stream.getId();
              if (signalType == SignalType.CANCEL && !receiver.isTerminated()) {
                sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
              }
              removeReceiver(streamId);
            })
        .subscribeOn(transportScheduler);
  }

  private Flux handleRequestStream(final Payload payload) {
    final Throwable err = checkAllowed();
    if (err != null) {
      payload.release();
      return Flux.error(err);
    }

    final UnboundedProcessor sendProcessor = this.sendProcessor;
    final UnicastProcessor receiver = UnicastProcessor.create();
    final Stream stream = new Stream();

    return receiver
        .doOnRequest(
            new LongConsumer() {

              boolean isRequestSent;

              @Override
              public void accept(long requestN) {

                if (receiver.isDisposed()) {
                  if (!isRequestSent) {
                    payload.release();
                  }
                  return;
                }

                Throwable err = terminationError;
                if (err != null) {
                  receiver.onError(err);
                  if (!isRequestSent) {
                    payload.release();
                  }
                  return;
                }

                if (!isRequestSent) {
                  isRequestSent = true;

                  final int streamId = stream.setId(streamIdSupplier.nextStreamId(receivers));
                  receivers.put(streamId, receiver);

                  sendProcessor.onNext(
                      RequestStreamFrameFlyweight.encode(
                          allocator,
                          streamId,
                          requestN,
                          payload.hasMetadata() ? payload.sliceMetadata().retain() : null,
                          payload.sliceData().retain()));
                  payload.release();
                } else {
                  final int streamId = stream.getId();
                  sendProcessor.onNext(
                      RequestNFrameFlyweight.encode(allocator, streamId, requestN));
                }
              }
            })
        .doFinally(
            signalType -> {
              final int streamId = stream.getId();
              if (streamId < 0) {
                return;
              }
              if (signalType == SignalType.CANCEL && !receiver.isTerminated()) {
                sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
              }
              removeReceiver(streamId);
            })
        .subscribeOn(transportScheduler);
  }

  private Flux handleChannel(Flux request) {
    Throwable err = checkAllowed();
    if (err != null) {
      return Flux.error(err);
    }

    final UnboundedProcessor sendProcessor = this.sendProcessor;
    final UnicastProcessor receiver = UnicastProcessor.create();
    final Stream stream = new Stream();

    return receiver
        .doOnRequest(
            new LongConsumer() {

              boolean isRequestSent;

              @Override
              public void accept(long requestN) {
                if (receiver.isDisposed()) {
                  return;
                }
                if (!isRequestSent) {
                  isRequestSent = true;
                  request
                      .publishOn(transportScheduler)
                      .subscribe(
                          new BaseSubscriber() {

                            boolean isRequestSent;
                            Subscription sender;

                            @Override
                            protected void hookOnSubscribe(Subscription subscription) {
                              this.sender = subscription;
                              stream.setSender(subscription);
                              subscription.request(1);
                            }

                            @Override
                            protected void hookOnNext(Payload payload) {
                              final ByteBuf frame;

                              if (receiver.isDisposed()) {
                                payload.release();
                                return;
                              }

                              Throwable err = terminationError;
                              if (err != null) {
                                receiver.onError(err);
                                payload.release();
                                return;
                              }

                              if (!isRequestSent) {
                                isRequestSent = true;

                                int streamId =
                                    stream.setId(streamIdSupplier.nextStreamId(receivers));
                                senders.put(streamId, sender);
                                receivers.put(streamId, receiver);

                                frame =
                                    RequestChannelFrameFlyweight.encode(
                                        allocator,
                                        streamId,
                                        requestN,
                                        payload.hasMetadata()
                                            ? payload.sliceMetadata().retain()
                                            : null,
                                        payload.sliceData().retain());
                              } else {
                                final int streamId = stream.getId();
                                frame =
                                    PayloadFrameFlyweight.encodeNext(allocator, streamId, payload);
                              }

                              sendProcessor.onNext(frame);
                              payload.release();
                            }

                            @Override
                            protected void hookOnComplete() {
                              int streamId = stream.getId();
                              if (streamId < 0) {
                                receiver.onComplete();
                              } else if (!receiver.isDisposed()) {
                                sendProcessor.onNext(
                                    PayloadFrameFlyweight.encodeComplete(allocator, streamId));
                              }
                            }

                            @Override
                            protected void hookOnError(Throwable t) {
                              int streamId = stream.getId();
                              if (streamId > 0) {
                                sendProcessor.onNext(
                                    errorFrameMapper.streamErrorToFrame(
                                        streamId, StreamType.REQUEST, t));
                              }
                              // todo signal error wrapping request error
                              receiver.dispose();
                            }
                          });
                } else {
                  final int streamId = stream.getId();
                  sendProcessor.onNext(
                      RequestNFrameFlyweight.encode(allocator, streamId, requestN));
                }
              }
            })
        .doFinally(
            s -> {
              int streamId = stream.getId();
              if (streamId < 0) {
                Subscription sender = stream.getSender();
                if (sender != null) {
                  sender.cancel();
                }
                return;
              }
              if (s == SignalType.CANCEL && !receiver.isTerminated()) {
                sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
              }
              removeReceiverAndSender(streamId);
            });
  }

  private Mono handleMetadataPush(Payload payload) {
    Throwable err = this.terminationError;
    if (err != null) {
      payload.release();
      return Mono.error(err);
    }

    return Mono.create(
        new Consumer>() {

          boolean isRequestSent;

          @Override
          public void accept(MonoSink sink) {
            if (isRequestSent) {
              sink.error(new IllegalStateException("only a single Subscriber is allowed"));
              return;
            }
            isRequestSent = true;

            Throwable err = terminationError;
            if (err != null) {
              payload.release();
              sink.error(err);
              return;
            }

            ByteBuf metadataPushFrame =
                MetadataPushFrameFlyweight.encode(allocator, payload.sliceMetadata().retain());
            payload.release();

            sendProcessor.onNext(metadataPushFrame);
            sink.success();
          }
        });
  }

  private void handleIncomingFrames(ByteBuf frame) {
    try {
      int streamId = FrameHeaderFlyweight.streamId(frame);
      FrameType type = FrameHeaderFlyweight.strictFrameType(frame);
      if (streamId == 0) {
        handleZeroFrame(type, frame);
      } else {
        handleStreamFrame(streamId, type, frame);
      }
      frame.release();
    } catch (Throwable t) {
      ReferenceCountUtil.safeRelease(frame);
      throw reactor.core.Exceptions.propagate(t);
    }
  }

  private void handleZeroFrame(FrameType type, ByteBuf frame) {
    switch (type) {
      case ERROR:
        tryTerminate(frame);
        break;
      case LEASE:
        handleLeaseFrame(frame);
        break;
      case KEEPALIVE:
        keepAliveFramesAcceptor.accept(frame);
        break;
      default:
        // Ignore unknown frames. Throwing an error will close the socket.
        errorConsumer.accept(
            new IllegalStateException(
                "Client received supported frame on stream 0: " + frame.toString()));
    }
  }

  private void handleStreamFrame(int streamId, FrameType type, ByteBuf frame) {
    Subscriber receiver = receivers.get(streamId);
    if (receiver != null) {
      switch (type) {
        case ERROR:
          receiver.onError(errorFrameMapper.streamFrameToError(frame, StreamType.RESPONSE));
          receivers.remove(streamId);
          break;
        case NEXT_COMPLETE:
          receiver.onNext(payloadDecoder.apply(frame, type));
          receiver.onComplete();
          break;
        case CANCEL:
          {
            Subscription sender = senders.remove(streamId);
            if (sender != null) {
              sender.cancel();
            }
            break;
          }
        case NEXT:
          receiver.onNext(payloadDecoder.apply(frame, type));
          break;
        case REQUEST_N:
          {
            Subscription sender = senders.get(streamId);
            if (sender != null) {
              int n = RequestNFrameFlyweight.requestN(frame);
              sender.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n);
            }
            break;
          }
        case COMPLETE:
          receiver.onComplete();
          receivers.remove(streamId);
          break;
        default:
          throw new IllegalStateException(
              "Client received supported frame on stream " + streamId + ": " + frame.toString());
      }
    }
  }

  /*error is (int (KeepAlive timeout) | ClosedChannelException (dispose) | Error frame (0 error)) */
  /*no race because executed on transport scheduler */
  private void tryTerminate(Object error) {
    if (terminationError == null) {
      Exception e;
      if (error instanceof ClosedChannelException) {
        e = (Exception) error;
      } else if (error instanceof Integer) {
        Integer keepAliveTimeout = (Integer) error;
        e =
            new ConnectionErrorException(
                String.format("No keep-alive acks for %d ms", keepAliveTimeout));
      } else if (error instanceof ByteBuf) {
        ByteBuf errorFrame = (ByteBuf) error;
        e = Exceptions.from(errorFrame);
      } else {
        e = new IllegalStateException("Unknown termination token: " + error);
      }
      this.terminationError = e;
      terminate(e);
    }
  }

  private void terminate(Exception e) {
    connection.dispose();

    synchronized (receivers) {
      receivers
          .values()
          .forEach(
              receiver -> {
                try {
                  receiver.onError(e);
                } catch (Throwable t) {
                  errorConsumer.accept(t);
                }
              });
    }
    synchronized (senders) {
      senders
          .values()
          .forEach(
              sender -> {
                try {
                  sender.cancel();
                } catch (Throwable t) {
                  errorConsumer.accept(t);
                }
              });
    }
    senders.clear();
    receivers.clear();
    sendProcessor.dispose();
    errorConsumer.accept(e);
  }

  private void removeReceiver(int streamId) {
    /*on termination receivers are explicitly cleared to avoid removing from map while iterating over one
    of its views*/
    if (terminationError == null) {
      receivers.remove(streamId);
    }
  }

  private void removeReceiverAndSender(int streamId) {
    /*on termination senders & receivers are explicitly cleared to avoid removing from map while iterating over one
    of its views*/
    if (terminationError == null) {
      receivers.remove(streamId);
      Subscription sender = senders.remove(streamId);
      if (sender != null) {
        sender.cancel();
      }
    }
  }

  private void handleSendProcessorError(Throwable t) {
    connection.dispose();
  }

  private static class Stream {
    private volatile int id = -1;
    private volatile Subscription sender;

    public int setId(int streamId) {
      this.id = streamId;
      return streamId;
    }

    public int getId() {
      return id;
    }

    public Subscription getSender() {
      return sender;
    }

    public void setSender(Subscription sender) {
      this.sender = sender;
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy