All downloads are FREE. Search and download functionality uses the official Maven repository.

com.jauntsdn.rsocket.RSocketResponder Maven / Gradle / Ivy

There is a newer version: 0.9.8
Show newest version
/*
 * Copyright 2015-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.jauntsdn.rsocket;

import static com.jauntsdn.rsocket.StreamErrorMappers.*;

import com.jauntsdn.rsocket.frame.*;
import com.jauntsdn.rsocket.frame.decoder.PayloadDecoder;
import com.jauntsdn.rsocket.internal.SynchronizedIntObjectHashMap;
import com.jauntsdn.rsocket.internal.UnboundedProcessor;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.collection.IntObjectMap;
import java.nio.channels.ClosedChannelException;
import java.util.function.Consumer;
import org.reactivestreams.Processor;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Exceptions;
import reactor.core.publisher.*;

/** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */
class RSocketResponder implements ResponderRSocket {
  private static final Logger logger = LoggerFactory.getLogger(RSocketResponder.class);
  private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();

  static {
    CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]);
  }

  private final DuplexConnection connection;
  private final RSocket requestHandler;
  private final ResponderRSocket responderRSocket;
  private final PayloadDecoder payloadDecoder;
  private final Consumer errorConsumer;
  private final ErrorFrameMapper errorFrameMapper;

  private final IntObjectMap senders;
  private final IntObjectMap> receivers;

  private final UnboundedProcessor sendProcessor;
  private final ByteBufAllocator allocator;

  private volatile Throwable terminationError;

  RSocketResponder(
      ByteBufAllocator allocator,
      DuplexConnection connection,
      RSocket requestHandler,
      PayloadDecoder payloadDecoder,
      Consumer errorConsumer,
      ErrorFrameMapper errorFrameMapper) {
    this.allocator = allocator;
    this.connection = connection;

    this.requestHandler = requestHandler;
    this.responderRSocket =
        (requestHandler instanceof ResponderRSocket) ? (ResponderRSocket) requestHandler : null;

    this.payloadDecoder = payloadDecoder;
    this.errorConsumer = errorConsumer;
    this.errorFrameMapper = errorFrameMapper;
    this.senders = new SynchronizedIntObjectHashMap<>();
    this.receivers = new SynchronizedIntObjectHashMap<>();

    this.sendProcessor = new UnboundedProcessor<>();

    connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError);

    connection.receive().subscribe(this::handleFrame, errorConsumer);

    this.connection.onClose().doFinally(s -> terminate()).subscribe(null, errorConsumer);
  }

  @Override
  public Mono fireAndForget(Payload payload) {
    try {
      return requestHandler.fireAndForget(payload);
    } catch (Throwable t) {
      return Mono.error(t);
    }
  }

  @Override
  public Mono requestResponse(Payload payload) {
    try {
      return requestHandler.requestResponse(payload);
    } catch (Throwable t) {
      return Mono.error(t);
    }
  }

  @Override
  public Flux requestStream(Payload payload) {
    try {
      return requestHandler.requestStream(payload);
    } catch (Throwable t) {
      return Flux.error(t);
    }
  }

  @Override
  public Flux requestChannel(Publisher payloads) {
    return requestHandler.requestChannel(payloads);
  }

  @Override
  public Flux requestChannel(Payload payload, Publisher payloads) {
    try {
      return responderRSocket != null
          ? responderRSocket.requestChannel(payload, payloads)
          : requestChannel(payloads);
    } catch (Throwable t) {
      return Flux.error(t);
    }
  }

  @Override
  public Mono metadataPush(Payload payload) {
    try {
      return requestHandler.metadataPush(payload);
    } catch (Throwable t) {
      return Mono.error(t);
    }
  }

  @Override
  public void dispose() {
    connection.dispose();
  }

  @Override
  public boolean isDisposed() {
    return connection.isDisposed();
  }

  @Override
  public Mono onClose() {
    return connection.onClose();
  }

  void sendFrame(ByteBuf frame) {
    sendProcessor.onNext(frame);
  }

  void terminate() {
    Throwable e = this.terminationError;
    if (e == null) {
      this.terminationError = e = CLOSED_CHANNEL_EXCEPTION;
    }
    final Throwable err = e;

    synchronized (receivers) {
      receivers
          .values()
          .forEach(
              receiver -> {
                try {
                  receiver.onError(err);
                } catch (Throwable t) {
                  errorConsumer.accept(t);
                }
              });
    }
    synchronized (senders) {
      senders
          .values()
          .forEach(
              sender -> {
                try {
                  sender.cancel();
                } catch (Throwable t) {
                  errorConsumer.accept(t);
                }
              });
    }
    senders.clear();
    receivers.clear();

    requestHandler.dispose();
    sendProcessor.dispose();
  }

  private void handleSendProcessorError(Throwable t) {
    connection.dispose();
  }

  private void handleFrame(ByteBuf frame) {
    try {
      int streamId = FrameHeaderFlyweight.streamId(frame);
      Subscriber receiver;
      FrameType frameType = FrameHeaderFlyweight.strictFrameType(frame);
      switch (frameType) {
        case REQUEST_FNF:
          handleFireAndForget(streamId, fireAndForget(payloadDecoder.apply(frame, frameType)));
          break;
        case REQUEST_RESPONSE:
          handleRequestResponse(streamId, requestResponse(payloadDecoder.apply(frame, frameType)));
          break;
        case REQUEST_N:
          handleRequestN(streamId, frame);
          break;
        case REQUEST_STREAM:
          int streamInitialRequestN = RequestStreamFrameFlyweight.initialRequestN(frame);
          Payload streamPayload = payloadDecoder.apply(frame, frameType);
          handleStream(streamId, requestStream(streamPayload), streamInitialRequestN);
          break;
        case REQUEST_CHANNEL:
          int channelInitialRequestN = RequestChannelFrameFlyweight.initialRequestN(frame);
          Payload channelPayload = payloadDecoder.apply(frame, frameType);
          handleChannel(streamId, channelPayload, channelInitialRequestN);
          break;
        case METADATA_PUSH:
          handleMetadataPush(metadataPush(payloadDecoder.apply(frame, frameType)));
          break;
        case NEXT:
          receiver = receivers.get(streamId);
          if (receiver != null) {
            receiver.onNext(payloadDecoder.apply(frame, frameType));
          }
          break;
        case NEXT_COMPLETE:
          receiver = receivers.get(streamId);
          if (receiver != null) {
            receiver.onNext(payloadDecoder.apply(frame, frameType));
            receiver.onComplete();
          }
          break;
        case COMPLETE:
          receiver = receivers.get(streamId);
          if (receiver != null) {
            receiver.onComplete();
          }
          break;
        case CANCEL:
          Subscription sender = senders.remove(streamId);
          if (sender != null) {
            sender.cancel();
          }
          break;
        case ERROR:
          receiver = receivers.get(streamId);
          if (receiver != null) {
            receiver.onError(errorFrameMapper.streamFrameToError(frame, StreamType.REQUEST));
          }
          break;
        case SETUP:
          disposeConnection(new IllegalStateException("SETUP frame received post setup"));
          break;
        case PAYLOAD:
          disposeConnection(
              new IllegalStateException(
                  "Unexpected PAYLOAD frame received: expect NEXT, NEXT_COMPLETE, COMPLETE"));
          break;
        case LEASE:
          disposeConnection(new IllegalStateException("Unexpected LEASE frame received"));
          break;
        default:
          if (logger.isDebugEnabled()) {
            logger.debug("Unexpected frame received: {}", frameType);
            logger.debug(ByteBufUtil.hexDump(frame));
          }
          break;
      }
      ReferenceCountUtil.safeRelease(frame);
    } catch (Throwable t) {
      ReferenceCountUtil.safeRelease(frame);
      throw Exceptions.propagate(t);
    }
  }

  private void handleFireAndForget(int streamId, Mono result) {
    result.subscribe(
        new BaseSubscriber() {
          @Override
          protected void hookOnSubscribe(Subscription subscription) {
            senders.put(streamId, subscription);
            subscription.request(Long.MAX_VALUE);
          }

          @Override
          protected void hookOnError(Throwable throwable) {
            errorConsumer.accept(throwable);
          }

          @Override
          protected void hookFinally(SignalType type) {
            removeSender(streamId);
          }
        });
  }

  private void handleRequestResponse(int streamId, Mono response) {
    response.subscribe(
        new BaseSubscriber() {
          private boolean isEmpty = true;

          @Override
          protected void hookOnSubscribe(Subscription subscription) {
            senders.put(streamId, subscription);
            subscription.request(Long.MAX_VALUE);
          }

          @Override
          protected void hookOnNext(Payload payload) {
            if (isEmpty) {
              isEmpty = false;
            }

            ByteBuf byteBuf;
            try {
              byteBuf = PayloadFrameFlyweight.encodeNextComplete(allocator, streamId, payload);
            } catch (Throwable t) {
              payload.release();
              throw Exceptions.propagate(t);
            }

            payload.release();

            sendProcessor.onNext(byteBuf);
          }

          @Override
          protected void hookOnError(Throwable throwable) {
            sendProcessor.onNext(
                errorFrameMapper.streamErrorToFrame(streamId, StreamType.RESPONSE, throwable));
          }

          @Override
          protected void hookOnComplete() {
            if (isEmpty) {
              sendProcessor.onNext(PayloadFrameFlyweight.encodeComplete(allocator, streamId));
            }
          }

          @Override
          protected void hookFinally(SignalType type) {
            removeSender(streamId);
          }
        });
  }

  private void handleStream(int streamId, Flux response, int initialRequestN) {
    response.subscribe(
        new BaseSubscriber() {

          @Override
          protected void hookOnSubscribe(Subscription subscription) {
            senders.put(streamId, subscription);
            subscription.request(
                initialRequestN == Integer.MAX_VALUE ? Long.MAX_VALUE : initialRequestN);
          }

          @Override
          protected void hookOnNext(Payload payload) {
            ByteBuf byteBuf;
            try {
              byteBuf = PayloadFrameFlyweight.encodeNext(allocator, streamId, payload);
            } catch (Throwable t) {
              payload.release();
              throw Exceptions.propagate(t);
            }

            payload.release();

            sendProcessor.onNext(byteBuf);
          }

          @Override
          protected void hookOnComplete() {
            sendProcessor.onNext(PayloadFrameFlyweight.encodeComplete(allocator, streamId));
          }

          @Override
          protected void hookOnError(Throwable throwable) {
            sendProcessor.onNext(
                errorFrameMapper.streamErrorToFrame(streamId, StreamType.RESPONSE, throwable));
          }

          @Override
          protected void hookFinally(SignalType type) {
            removeSender(streamId);
          }
        });
  }

  private void handleChannel(int streamId, Payload payload, int initialRequestN) {
    UnicastProcessor frames = UnicastProcessor.create();
    receivers.put(streamId, frames);

    Flux payloads =
        frames
            .doOnCancel(
                () -> sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)))
            .doOnRequest(
                l -> sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, l)))
            .doFinally(signalType -> removeReceiver(streamId));

    // not chained, as the payload should be enqueued in the Unicast processor before this method
    // returns
    // and any later payload can be processed
    frames.onNext(payload);

    handleStream(streamId, requestChannel(payload, payloads), initialRequestN);
  }

  private void handleMetadataPush(Mono result) {
    result.subscribe(
        new BaseSubscriber() {
          @Override
          protected void hookOnSubscribe(Subscription subscription) {
            subscription.request(Long.MAX_VALUE);
          }

          @Override
          protected void hookOnError(Throwable throwable) {
            errorConsumer.accept(throwable);
          }
        });
  }

  private void disposeConnection(Throwable t) {
    terminationError = t;

    connection
        .sendOne(ErrorFrameFlyweight.encode(allocator, 0, t))
        .doFinally(s -> connection.dispose())
        .subscribe(null, errorConsumer);
    errorConsumer.accept(t);
  }

  private void handleRequestN(int streamId, ByteBuf frame) {
    Subscription sender = senders.get(streamId);
    if (sender != null) {
      int n = RequestNFrameFlyweight.requestN(frame);
      sender.request(n == Integer.MAX_VALUE ? Long.MAX_VALUE : n);
    }
  }

  private void removeReceiver(int streamId) {
    /*on termination receivers are explicitly cleared to avoid removing from map while iterating over one
    of its views*/
    if (terminationError == null) {
      receivers.remove(streamId);
    }
  }

  private void removeSender(int streamId) {
    /*on termination receivers are explicitly cleared to avoid removing from map while iterating over one
    of its views*/
    if (terminationError == null) {
      senders.remove(streamId);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy