io.rsocket.core.DefaultRSocketClient Maven / Gradle / Ivy
/*
* Copyright 2015-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.rsocket.core;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.ReferenceCounted;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.frame.FrameType;
import java.util.AbstractMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.CorePublisher;
import reactor.core.CoreSubscriber;
import reactor.core.Scannable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoOperator;
import reactor.core.publisher.Operators;
import reactor.util.annotation.Nullable;
import reactor.util.context.Context;
/**
* Default implementation of {@link RSocketClient}
*
* @since 1.0.1
*/
class DefaultRSocketClient extends ResolvingOperator
implements CoreSubscriber, CorePublisher, RSocketClient {
static final Consumer DISCARD_ELEMENTS_CONSUMER =
referenceCounted -> {
if (referenceCounted.refCnt() > 0) {
try {
referenceCounted.release();
} catch (IllegalReferenceCountException e) {
// ignored
}
}
};
static final Object ON_DISCARD_KEY;
static {
Context discardAwareContext = Operators.enableOnDiscard(null, DISCARD_ELEMENTS_CONSUMER);
ON_DISCARD_KEY = discardAwareContext.stream().findFirst().get().getKey();
}
final Mono source;
volatile Subscription s;
static final AtomicReferenceFieldUpdater S =
AtomicReferenceFieldUpdater.newUpdater(DefaultRSocketClient.class, Subscription.class, "s");
DefaultRSocketClient(Mono source) {
this.source = unwrapReconnectMono(source);
}
private Mono unwrapReconnectMono(Mono source) {
return source instanceof ReconnectMono ? ((ReconnectMono) source).getSource() : source;
}
@Override
public Mono source() {
return Mono.fromDirect(this);
}
@Override
public Mono fireAndForget(Mono payloadMono) {
return new RSocketClientMonoOperator<>(this, FrameType.REQUEST_FNF, payloadMono);
}
@Override
public Mono requestResponse(Mono payloadMono) {
return new RSocketClientMonoOperator<>(this, FrameType.REQUEST_RESPONSE, payloadMono);
}
@Override
public Flux requestStream(Mono payloadMono) {
return new RSocketClientFluxOperator<>(this, FrameType.REQUEST_STREAM, payloadMono);
}
@Override
public Flux requestChannel(Publisher payloads) {
return new RSocketClientFluxOperator<>(this, FrameType.REQUEST_CHANNEL, payloads);
}
@Override
public Mono metadataPush(Mono payloadMono) {
return new RSocketClientMonoOperator<>(this, FrameType.METADATA_PUSH, payloadMono);
}
@Override
@SuppressWarnings("uncheked")
public void subscribe(CoreSubscriber super RSocket> actual) {
final ResolvingOperator.MonoDeferredResolutionOperator inner =
new ResolvingOperator.MonoDeferredResolutionOperator<>(this, actual);
actual.onSubscribe(inner);
this.observe(inner);
}
@Override
public void subscribe(Subscriber super RSocket> s) {
subscribe(Operators.toCoreSubscriber(s));
}
@Override
public void onSubscribe(Subscription s) {
if (Operators.setOnce(S, this, s)) {
s.request(Long.MAX_VALUE);
}
}
@Override
public void onComplete() {
final Subscription s = this.s;
final RSocket value = this.value;
if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) {
this.doFinally();
return;
}
if (value == null) {
this.terminate(new IllegalStateException("Source completed empty"));
} else {
this.complete(value);
}
}
@Override
public void onError(Throwable t) {
final Subscription s = this.s;
if (s == Operators.cancelledSubscription()
|| S.getAndSet(this, Operators.cancelledSubscription())
== Operators.cancelledSubscription()) {
this.doFinally();
Operators.onErrorDropped(t, Context.empty());
return;
}
this.doFinally();
// terminate upstream which means retryBackoff has exhausted
this.terminate(t);
}
@Override
public void onNext(RSocket value) {
if (this.s == Operators.cancelledSubscription()) {
this.doOnValueExpired(value);
return;
}
this.value = value;
// volatile write and check on racing
this.doFinally();
}
@Override
protected void doSubscribe() {
this.source.subscribe(this);
}
@Override
protected void doOnValueResolved(RSocket value) {
value.onClose().subscribe(null, t -> this.invalidate(), this::invalidate);
}
@Override
protected void doOnValueExpired(RSocket value) {
value.dispose();
}
@Override
protected void doOnDispose() {
Operators.terminate(S, this);
}
static final class FlatMapMain implements CoreSubscriber, Context, Scannable {
final DefaultRSocketClient parent;
final CoreSubscriber super R> actual;
final FlattingInner second;
Subscription s;
boolean done;
FlatMapMain(
DefaultRSocketClient parent, CoreSubscriber super R> actual, FrameType requestType) {
this.parent = parent;
this.actual = actual;
this.second = new FlattingInner<>(parent, this, actual, requestType);
}
@Override
public Context currentContext() {
return this;
}
@Override
public Stream extends Scannable> inners() {
return Stream.of(this.second);
}
@Override
@Nullable
public Object scanUnsafe(Attr key) {
if (key == Attr.PARENT) return this.s;
if (key == Attr.CANCELLED) return this.second.isCancelled();
if (key == Attr.TERMINATED) return this.done;
return null;
}
@Override
public void onSubscribe(Subscription s) {
if (Operators.validate(this.s, s)) {
this.s = s;
this.actual.onSubscribe(this.second);
}
}
@Override
public void onNext(Payload payload) {
if (this.done) {
if (payload.refCnt() > 0) {
try {
payload.release();
} catch (IllegalReferenceCountException e) {
// ignored
}
}
return;
}
this.done = true;
final FlattingInner inner = this.second;
if (inner.isCancelled()) {
if (payload.refCnt() > 0) {
try {
payload.release();
} catch (IllegalReferenceCountException e) {
// ignored
}
}
return;
}
inner.payload = payload;
if (inner.isCancelled()) {
if (FlattingInner.PAYLOAD.compareAndSet(inner, payload, null)) {
if (payload.refCnt() > 0) {
try {
payload.release();
} catch (IllegalReferenceCountException e) {
// ignored
}
}
}
return;
}
this.parent.observe(inner);
}
@Override
public void onError(Throwable t) {
if (this.done) {
Operators.onErrorDropped(t, this.actual.currentContext());
return;
}
this.done = true;
this.actual.onError(t);
}
@Override
public void onComplete() {
if (this.done) {
return;
}
this.done = true;
this.actual.onComplete();
}
void request(long n) {
this.s.request(n);
}
void cancel() {
this.s.cancel();
}
@Override
@SuppressWarnings("unchecked")
public K get(Object key) {
if (key == ON_DISCARD_KEY) {
return (K) DISCARD_ELEMENTS_CONSUMER;
}
return this.actual.currentContext().get(key);
}
@Override
public boolean hasKey(Object key) {
if (key == ON_DISCARD_KEY) {
return true;
}
return this.actual.currentContext().hasKey(key);
}
@Override
public Context put(Object key, Object value) {
return this.actual
.currentContext()
.put(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER)
.put(key, value);
}
@Override
public Context delete(Object key) {
return this.actual
.currentContext()
.put(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER)
.delete(key);
}
@Override
public int size() {
return this.actual.currentContext().size() + 1;
}
@Override
public Stream> stream() {
return Stream.concat(
Stream.of(
new AbstractMap.SimpleImmutableEntry<>(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER)),
this.actual.currentContext().stream());
}
}
static final class FlattingInner extends DeferredResolution {
final FlatMapMain main;
final FrameType interactionType;
volatile Payload payload;
@SuppressWarnings("rawtypes")
static final AtomicReferenceFieldUpdater PAYLOAD =
AtomicReferenceFieldUpdater.newUpdater(FlattingInner.class, Payload.class, "payload");
FlattingInner(
DefaultRSocketClient parent,
FlatMapMain main,
CoreSubscriber super T> actual,
FrameType interactionType) {
super(parent, actual);
this.main = main;
this.interactionType = interactionType;
}
@Override
@SuppressWarnings({"unchecked", "rawtypes"})
public void accept(RSocket rSocket, Throwable t) {
if (this.isCancelled()) {
return;
}
Payload payload = PAYLOAD.getAndSet(this, null);
// means cancelled
if (payload == null) {
return;
}
if (t != null) {
if (payload.refCnt() > 0) {
try {
payload.release();
} catch (IllegalReferenceCountException e) {
// ignored
}
}
onError(t);
return;
}
CorePublisher> source;
switch (this.interactionType) {
case REQUEST_FNF:
source = rSocket.fireAndForget(payload);
break;
case REQUEST_RESPONSE:
source = rSocket.requestResponse(payload);
break;
case REQUEST_STREAM:
source = rSocket.requestStream(payload);
break;
case METADATA_PUSH:
source = rSocket.metadataPush(payload);
break;
default:
this.onError(new IllegalStateException("Should never happen"));
return;
}
source.subscribe((CoreSubscriber) this);
}
@Override
public void request(long n) {
this.main.request(n);
super.request(n);
}
public void cancel() {
long state = REQUESTED.getAndSet(this, STATE_CANCELLED);
if (state == STATE_CANCELLED) {
return;
}
this.main.cancel();
if (state == STATE_SUBSCRIBED) {
this.s.cancel();
} else {
this.parent.remove(this);
Payload payload = PAYLOAD.getAndSet(this, null);
if (payload != null) {
payload.release();
}
}
}
}
static final class RequestChannelInner extends DeferredResolution {
final FrameType interactionType;
final Publisher upstream;
RequestChannelInner(
DefaultRSocketClient parent,
Publisher upstream,
CoreSubscriber super Payload> actual,
FrameType interactionType) {
super(parent, actual);
this.upstream = upstream;
this.interactionType = interactionType;
}
@Override
public void accept(RSocket rSocket, Throwable t) {
if (this.isCancelled()) {
return;
}
if (t != null) {
onError(t);
return;
}
Flux source;
if (this.interactionType == FrameType.REQUEST_CHANNEL) {
source = rSocket.requestChannel(this.upstream);
} else {
this.onError(new IllegalStateException("Should never happen"));
return;
}
source.subscribe(this);
}
}
static class RSocketClientMonoOperator extends MonoOperator {
final DefaultRSocketClient parent;
final FrameType requestType;
public RSocketClientMonoOperator(
DefaultRSocketClient parent, FrameType requestType, Mono source) {
super(source);
this.parent = parent;
this.requestType = requestType;
}
@Override
public void subscribe(CoreSubscriber super T> actual) {
this.source.subscribe(new FlatMapMain(this.parent, actual, this.requestType));
}
}
static class RSocketClientFluxOperator> extends Flux {
final DefaultRSocketClient parent;
final FrameType requestType;
final ST source;
public RSocketClientFluxOperator(
DefaultRSocketClient parent, FrameType requestType, ST source) {
this.parent = parent;
this.requestType = requestType;
this.source = source;
}
@Override
public void subscribe(CoreSubscriber super Payload> actual) {
if (requestType == FrameType.REQUEST_CHANNEL) {
RequestChannelInner inner =
new RequestChannelInner(this.parent, source, actual, requestType);
actual.onSubscribe(inner);
this.parent.observe(inner);
} else {
this.source.subscribe(new FlatMapMain<>(this.parent, actual, this.requestType));
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy