All downloads are free. Search and download functionality uses the official Maven repository.

io.rsocket.core.DefaultRSocketClient Maven / Gradle / Ivy

/*
 * Copyright 2015-2020 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.rsocket.core;

import io.netty.util.IllegalReferenceCountException;
import io.netty.util.ReferenceCounted;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.frame.FrameType;
import java.util.AbstractMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.CorePublisher;
import reactor.core.CoreSubscriber;
import reactor.core.Scannable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoOperator;
import reactor.core.publisher.Operators;
import reactor.util.annotation.Nullable;
import reactor.util.context.Context;

/**
 * Default implementation of {@link RSocketClient}
 *
 * @since 1.0.1
 */
class DefaultRSocketClient extends ResolvingOperator
    implements CoreSubscriber, CorePublisher, RSocketClient {
  /**
   * Releases a discarded reference-counted element on a best-effort basis.
   *
   * <p>The element is released only while its reference count is still positive, and an {@link
   * IllegalReferenceCountException} thrown by {@code release()} is swallowed on purpose.
   */
  static final Consumer<ReferenceCounted> DISCARD_ELEMENTS_CONSUMER =
      referenceCounted -> {
        if (referenceCounted.refCnt() > 0) {
          try {
            referenceCounted.release();
          } catch (IllegalReferenceCountException ignored) {
            // best-effort: a failed release is not propagated to the caller
          }
        }
      };

  /**
   * Context key under which the on-discard hook is stored. It is read back once, at class
   * initialization, from a {@link Context} produced by {@link Operators#enableOnDiscard}.
   */
  static final Object ON_DISCARD_KEY;

  static {
    // The context is created from a null target, so its first entry is the discard hook.
    ON_DISCARD_KEY =
        Operators.enableOnDiscard(null, DISCARD_ELEMENTS_CONSUMER)
            .stream()
            .findFirst()
            .get()
            .getKey();
  }

  // Upstream that emits the RSocket; assigned in the constructor after unwrapReconnectMono(...).
  final Mono source;

  // Subscription to `source`; read directly and updated through the S field updater.
  volatile Subscription s;

  // Atomic accessor for the `s` field above.
  static final AtomicReferenceFieldUpdater S =
      AtomicReferenceFieldUpdater.newUpdater(DefaultRSocketClient.class, Subscription.class, "s");

  /**
   * Creates a client that resolves its {@link RSocket} from the given source.
   *
   * @param source the upstream to subscribe to; when it is a {@code ReconnectMono}, the result of
   *     its {@code getSource()} is stored instead (see {@code unwrapReconnectMono})
   */
  DefaultRSocketClient(Mono source) {
    this.source = unwrapReconnectMono(source);
  }

  /**
   * Returns the inner source of a {@code ReconnectMono}, or the argument itself for any other
   * {@link Mono}.
   *
   * @param source the candidate to unwrap
   * @return {@code getSource()} of the {@code ReconnectMono}, otherwise {@code source} unchanged
   */
  private Mono unwrapReconnectMono(Mono source) {
    if (source instanceof ReconnectMono) {
      return ((ReconnectMono) source).getSource();
    }
    return source;
  }

  /**
   * Exposes this client as a {@link Mono}.
   *
   * @return the result of {@link Mono#fromDirect} applied to this instance
   */
  @Override
  public Mono source() {
    return Mono.fromDirect(this);
  }

  /**
   * Builds a deferred {@link FrameType#REQUEST_FNF} interaction.
   *
   * <p>The request is not issued here; the returned operator subscribes to {@code payloadMono}
   * only when it is itself subscribed.
   *
   * @param payloadMono publisher of the request payload
   * @return an {@code RSocketClientMonoOperator} bound to this client
   */
  @Override
  public Mono fireAndForget(Mono payloadMono) {
    final FrameType interaction = FrameType.REQUEST_FNF;
    return new RSocketClientMonoOperator<>(this, interaction, payloadMono);
  }

  /**
   * Builds a deferred {@link FrameType#REQUEST_RESPONSE} interaction.
   *
   * <p>The request is not issued here; the returned operator subscribes to {@code payloadMono}
   * only when it is itself subscribed.
   *
   * @param payloadMono publisher of the request payload
   * @return an {@code RSocketClientMonoOperator} bound to this client
   */
  @Override
  public Mono requestResponse(Mono payloadMono) {
    final FrameType interaction = FrameType.REQUEST_RESPONSE;
    return new RSocketClientMonoOperator<>(this, interaction, payloadMono);
  }

  /**
   * Builds a deferred {@link FrameType#REQUEST_STREAM} interaction.
   *
   * <p>The request is not issued here; the returned operator subscribes to {@code payloadMono}
   * only when it is itself subscribed.
   *
   * @param payloadMono publisher of the request payload
   * @return an {@code RSocketClientFluxOperator} bound to this client
   */
  @Override
  public Flux requestStream(Mono payloadMono) {
    final FrameType interaction = FrameType.REQUEST_STREAM;
    return new RSocketClientFluxOperator<>(this, interaction, payloadMono);
  }

  /**
   * Builds a deferred {@link FrameType#REQUEST_CHANNEL} interaction.
   *
   * <p>The request is not issued here; on subscription the returned operator registers a {@code
   * RequestChannelInner} carrying {@code payloads} with this client.
   *
   * @param payloads publisher of the outgoing payloads
   * @return an {@code RSocketClientFluxOperator} bound to this client
   */
  @Override
  public Flux requestChannel(Publisher payloads) {
    final FrameType interaction = FrameType.REQUEST_CHANNEL;
    return new RSocketClientFluxOperator<>(this, interaction, payloads);
  }

  /**
   * Builds a deferred {@link FrameType#METADATA_PUSH} interaction.
   *
   * <p>The request is not issued here; the returned operator subscribes to {@code payloadMono}
   * only when it is itself subscribed.
   *
   * @param payloadMono publisher of the payload to push
   * @return an {@code RSocketClientMonoOperator} bound to this client
   */
  @Override
  public Mono metadataPush(Mono payloadMono) {
    final FrameType interaction = FrameType.METADATA_PUSH;
    return new RSocketClientMonoOperator<>(this, interaction, payloadMono);
  }

  /**
   * Subscribes the given subscriber to the {@link RSocket} resolved by this client.
   *
   * <p>A {@code MonoDeferredResolutionOperator} is created for the subscriber, handed to it via
   * {@code onSubscribe}, and then registered with {@code observe}.
   *
   * @param actual the downstream subscriber
   */
  @Override
  @SuppressWarnings("unchecked")
  public void subscribe(CoreSubscriber actual) {
    final ResolvingOperator.MonoDeferredResolutionOperator inner =
        new ResolvingOperator.MonoDeferredResolutionOperator<>(this, actual);
    // the subscriber receives its subscription before the operator is registered
    actual.onSubscribe(inner);

    this.observe(inner);
  }

  /**
   * Adapts a plain Reactive Streams subscriber and delegates to the {@link CoreSubscriber}
   * overload.
   *
   * @param s the downstream subscriber
   */
  @Override
  public void subscribe(Subscriber s) {
    final CoreSubscriber coreSubscriber = Operators.toCoreSubscriber(s);
    subscribe(coreSubscriber);
  }

  /**
   * Stores the upstream subscription and requests unbounded demand.
   *
   * <p>Nothing is requested when {@link Operators#setOnce} reports that the subscription could
   * not be stored.
   *
   * @param s the subscription offered by the source
   */
  @Override
  public void onSubscribe(Subscription s) {
    if (!Operators.setOnce(S, this, s)) {
      return;
    }
    s.request(Long.MAX_VALUE);
  }

  /**
   * Handles completion of the source.
   *
   * <p>If the subscription is already the cancelled marker, or clearing it with a compare-and-set
   * fails, only {@code doFinally()} runs. Otherwise the stored value is passed to {@code
   * complete}, or, when no value was stored, the client is terminated with an {@link
   * IllegalStateException}.
   */
  @Override
  public void onComplete() {
    final Subscription current = this.s;
    final RSocket resolved = this.value;

    final boolean alreadyCancelled = current == Operators.cancelledSubscription();
    if (alreadyCancelled || !S.compareAndSet(this, current, null)) {
      this.doFinally();
      return;
    }

    if (resolved != null) {
      this.complete(resolved);
      return;
    }

    this.terminate(new IllegalStateException("Source completed empty"));
  }

  /**
   * Handles an error from the source.
   *
   * <p>{@code doFinally()} always runs first. If the subscription was already cancelled, whether
   * observed on the field or on the atomic swap to the cancelled marker, the error is dropped via
   * {@link Operators#onErrorDropped}. Otherwise the client is terminated with the error.
   *
   * @param t the error signalled by the source
   */
  @Override
  public void onError(Throwable t) {
    final Subscription current = this.s;

    boolean dropped = current == Operators.cancelledSubscription();
    if (!dropped) {
      final Object previous = S.getAndSet(this, Operators.cancelledSubscription());
      dropped = previous == Operators.cancelledSubscription();
    }

    this.doFinally();

    if (dropped) {
      Operators.onErrorDropped(t, Context.empty());
      return;
    }

    // terminate upstream which means retryBackoff has exhausted
    this.terminate(t);
  }

  /**
   * Receives the resolved {@link RSocket} from the source.
   *
   * <p>When the subscription is already the cancelled marker, the value is handed to {@code
   * doOnValueExpired}. Otherwise it is stored in {@code value} and {@code doFinally()} is invoked.
   *
   * @param value the resolved RSocket
   */
  @Override
  public void onNext(RSocket value) {
    final boolean cancelled = this.s == Operators.cancelledSubscription();
    if (cancelled) {
      this.doOnValueExpired(value);
    } else {
      this.value = value;
      // volatile write and check on racing
      this.doFinally();
    }
  }

  /** Subscribes this client, acting as the subscriber, to the stored {@code source}. */
  @Override
  protected void doSubscribe() {
    this.source.subscribe(this);
  }

  /**
   * Hooks into the resolved RSocket's {@code onClose()} signal so that this client is invalidated
   * when that signal errors or completes.
   *
   * @param value the freshly resolved RSocket
   */
  @Override
  protected void doOnValueResolved(RSocket value) {
    value
        .onClose()
        .subscribe(null, closeError -> this.invalidate(), () -> this.invalidate());
  }

  /**
   * Disposes an RSocket that is no longer going to be used by this client.
   *
   * @param value the expired RSocket
   */
  @Override
  protected void doOnValueExpired(RSocket value) {
    value.dispose();
  }

  /** On dispose, delegates to {@link Operators#terminate} for the {@code s} subscription field. */
  @Override
  protected void doOnDispose() {
    Operators.terminate(S, this);
  }

  static final class FlatMapMain implements CoreSubscriber, Context, Scannable {

    final DefaultRSocketClient parent;
    final CoreSubscriber actual;

    final FlattingInner second;

    Subscription s;

    boolean done;

    FlatMapMain(
        DefaultRSocketClient parent, CoreSubscriber actual, FrameType requestType) {
      this.parent = parent;
      this.actual = actual;
      this.second = new FlattingInner<>(parent, this, actual, requestType);
    }

    @Override
    public Context currentContext() {
      return this;
    }

    @Override
    public Stream inners() {
      return Stream.of(this.second);
    }

    @Override
    @Nullable
    public Object scanUnsafe(Attr key) {
      if (key == Attr.PARENT) return this.s;
      if (key == Attr.CANCELLED) return this.second.isCancelled();
      if (key == Attr.TERMINATED) return this.done;

      return null;
    }

    @Override
    public void onSubscribe(Subscription s) {
      if (Operators.validate(this.s, s)) {
        this.s = s;
        this.actual.onSubscribe(this.second);
      }
    }

    @Override
    public void onNext(Payload payload) {
      if (this.done) {
        if (payload.refCnt() > 0) {
          try {
            payload.release();
          } catch (IllegalReferenceCountException e) {
            // ignored
          }
        }
        return;
      }
      this.done = true;

      final FlattingInner inner = this.second;

      if (inner.isCancelled()) {
        if (payload.refCnt() > 0) {
          try {
            payload.release();
          } catch (IllegalReferenceCountException e) {
            // ignored
          }
        }
        return;
      }

      inner.payload = payload;

      if (inner.isCancelled()) {
        if (FlattingInner.PAYLOAD.compareAndSet(inner, payload, null)) {
          if (payload.refCnt() > 0) {
            try {
              payload.release();
            } catch (IllegalReferenceCountException e) {
              // ignored
            }
          }
        }
        return;
      }

      this.parent.observe(inner);
    }

    @Override
    public void onError(Throwable t) {
      if (this.done) {
        Operators.onErrorDropped(t, this.actual.currentContext());
        return;
      }
      this.done = true;

      this.actual.onError(t);
    }

    @Override
    public void onComplete() {
      if (this.done) {
        return;
      }
      this.done = true;

      this.actual.onComplete();
    }

    void request(long n) {
      this.s.request(n);
    }

    void cancel() {
      this.s.cancel();
    }

    @Override
    @SuppressWarnings("unchecked")
    public  K get(Object key) {
      if (key == ON_DISCARD_KEY) {
        return (K) DISCARD_ELEMENTS_CONSUMER;
      }
      return this.actual.currentContext().get(key);
    }

    @Override
    public boolean hasKey(Object key) {
      if (key == ON_DISCARD_KEY) {
        return true;
      }
      return this.actual.currentContext().hasKey(key);
    }

    @Override
    public Context put(Object key, Object value) {
      return this.actual
          .currentContext()
          .put(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER)
          .put(key, value);
    }

    @Override
    public Context delete(Object key) {
      return this.actual
          .currentContext()
          .put(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER)
          .delete(key);
    }

    @Override
    public int size() {
      return this.actual.currentContext().size() + 1;
    }

    @Override
    public Stream> stream() {
      return Stream.concat(
          Stream.of(
              new AbstractMap.SimpleImmutableEntry<>(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER)),
          this.actual.currentContext().stream());
    }
  }

  /**
   * Deferred request that waits for both the request {@link Payload} (set by {@link FlatMapMain})
   * and the resolved {@link RSocket} (delivered to {@link #accept}), then issues the interaction
   * selected by {@code interactionType} and subscribes itself to the result.
   *
   * @param <T> element type delivered to the downstream subscriber
   */
  static final class FlattingInner<T> extends DeferredResolution<T, RSocket> {

    final FlatMapMain<T> main;
    final FrameType interactionType;

    volatile Payload payload;

    @SuppressWarnings("rawtypes")
    static final AtomicReferenceFieldUpdater<FlattingInner, Payload> PAYLOAD =
        AtomicReferenceFieldUpdater.newUpdater(FlattingInner.class, Payload.class, "payload");

    FlattingInner(
        DefaultRSocketClient parent,
        FlatMapMain<T> main,
        CoreSubscriber<? super T> actual,
        FrameType interactionType) {
      super(parent, actual);

      this.main = main;
      this.interactionType = interactionType;
    }

    /**
     * Receives the outcome of resolving the RSocket.
     *
     * @param rSocket the resolved RSocket; used only when {@code t} is {@code null}
     * @param t the resolution error, or {@code null}
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void accept(RSocket rSocket, Throwable t) {
      if (this.isCancelled()) {
        return;
      }

      // take ownership of the payload; cancel() competes for it through the same updater
      final Payload requestPayload = PAYLOAD.getAndSet(this, null);

      // means cancelled
      if (requestPayload == null) {
        return;
      }

      if (t != null) {
        releaseQuietly(requestPayload);
        onError(t);
        return;
      }

      CorePublisher source;
      switch (this.interactionType) {
        case REQUEST_FNF:
          source = rSocket.fireAndForget(requestPayload);
          break;
        case REQUEST_RESPONSE:
          source = rSocket.requestResponse(requestPayload);
          break;
        case REQUEST_STREAM:
          source = rSocket.requestStream(requestPayload);
          break;
        case METADATA_PUSH:
          source = rSocket.metadataPush(requestPayload);
          break;
        default:
          // the payload is not handed to any interaction on this path, so release it here
          releaseQuietly(requestPayload);
          this.onError(new IllegalStateException("Should never happen"));
          return;
      }

      source.subscribe((CoreSubscriber) this);
    }

    @Override
    public void request(long n) {
      // demand goes to the payload publisher first, then to the deferred resolution itself
      this.main.request(n);
      super.request(n);
    }

    @Override
    public void cancel() {
      long state = REQUESTED.getAndSet(this, STATE_CANCELLED);
      if (state == STATE_CANCELLED) {
        return;
      }

      this.main.cancel();

      if (state == STATE_SUBSCRIBED) {
        this.s.cancel();
      } else {
        this.parent.remove(this);
        final Payload pending = PAYLOAD.getAndSet(this, null);
        if (pending != null) {
          // same guarded release as every other discard path in this file
          releaseQuietly(pending);
        }
      }
    }

    /** Releases the payload if it is still referenced, ignoring a failed release. */
    private static void releaseQuietly(Payload payload) {
      if (payload.refCnt() > 0) {
        try {
          payload.release();
        } catch (IllegalReferenceCountException ignored) {
          // best-effort release
        }
      }
    }
  }

  /**
   * Deferred request-channel interaction: once the {@link RSocket} is resolved, it calls {@code
   * requestChannel} with the stored upstream and subscribes itself to the resulting {@link Flux}.
   */
  static final class RequestChannelInner extends DeferredResolution<Payload, RSocket> {

    final FrameType interactionType;
    final Publisher<Payload> upstream;

    RequestChannelInner(
        DefaultRSocketClient parent,
        Publisher<Payload> upstream,
        CoreSubscriber<? super Payload> actual,
        FrameType interactionType) {
      super(parent, actual);

      this.upstream = upstream;
      this.interactionType = interactionType;
    }

    /**
     * Receives the outcome of resolving the RSocket.
     *
     * @param rSocket the resolved RSocket; used only when {@code t} is {@code null}
     * @param t the resolution error, or {@code null}
     */
    @Override
    public void accept(RSocket rSocket, Throwable t) {
      if (this.isCancelled()) {
        return;
      }

      if (t != null) {
        onError(t);
        return;
      }

      if (this.interactionType != FrameType.REQUEST_CHANNEL) {
        this.onError(new IllegalStateException("Should never happen"));
        return;
      }

      final Flux<Payload> source = rSocket.requestChannel(this.upstream);
      source.subscribe(this);
    }
  }

  /**
   * {@link Mono} returned by the single-result interactions of {@link DefaultRSocketClient}.
   *
   * <p>On subscription it subscribes a new {@link FlatMapMain} to the payload source.
   *
   * @param <T> element type delivered to the downstream subscriber
   */
  static class RSocketClientMonoOperator<T> extends MonoOperator<Payload, T> {

    final DefaultRSocketClient parent;
    final FrameType requestType;

    public RSocketClientMonoOperator(
        DefaultRSocketClient parent, FrameType requestType, Mono<Payload> source) {
      super(source);
      this.parent = parent;
      this.requestType = requestType;
    }

    @Override
    public void subscribe(CoreSubscriber<? super T> actual) {
      this.source.subscribe(new FlatMapMain<T>(this.parent, actual, this.requestType));
    }
  }

  static class RSocketClientFluxOperator> extends Flux {

    final DefaultRSocketClient parent;
    final FrameType requestType;
    final ST source;

    public RSocketClientFluxOperator(
        DefaultRSocketClient parent, FrameType requestType, ST source) {
      this.parent = parent;
      this.requestType = requestType;
      this.source = source;
    }

    @Override
    public void subscribe(CoreSubscriber actual) {
      if (requestType == FrameType.REQUEST_CHANNEL) {
        RequestChannelInner inner =
            new RequestChannelInner(this.parent, source, actual, requestType);
        actual.onSubscribe(inner);
        this.parent.observe(inner);
      } else {
        this.source.subscribe(new FlatMapMain<>(this.parent, actual, this.requestType));
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy