// reactor.test.subscriber.DefaultTestSubscriber (source listing header: Maven / Gradle / Ivy)
/*
* Copyright (c) 2021 VMware Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package reactor.test.subscriber;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.AtomicReference;
import org.reactivestreams.Subscription;
import reactor.core.Fuseable;
import reactor.core.publisher.Operators;
import reactor.core.publisher.Signal;
import reactor.util.annotation.Nullable;
import reactor.util.context.Context;
/**
* Base version of a {@link TestSubscriber} (aka non-conditional).
*
* @author Simon Baslé
*/
class DefaultTestSubscriber implements TestSubscriber {
final long initialRequest;
final Context context;
final DefaultTestSubscriber.FusionRequirement fusionRequirement;
final int requestedFusionMode;
final int expectedFusionMode;
Subscription s;
@Nullable
Fuseable.QueueSubscription qs;
int fusionMode = -1;
// state tracking
final AtomicBoolean cancelled;
final List receivedOnNext;
final List receivedPostCancellation;
final List> protocolErrors;
final CountDownLatch doneLatch;
final AtomicReference subscriptionFailure;
@Nullable
volatile Signal terminalSignal;
volatile int state;
@SuppressWarnings("rawtypes")
static final AtomicIntegerFieldUpdater STATE =
AtomicIntegerFieldUpdater.newUpdater(DefaultTestSubscriber.class, "state");
volatile long requestedTotal;
@SuppressWarnings("rawtypes")
static final AtomicLongFieldUpdater REQUESTED_TOTAL =
AtomicLongFieldUpdater.newUpdater(DefaultTestSubscriber.class, "requestedTotal");
volatile long requestedPreSubscription;
@SuppressWarnings("rawtypes")
static final AtomicLongFieldUpdater REQUESTED_PRE_SUBSCRIPTION =
AtomicLongFieldUpdater.newUpdater(DefaultTestSubscriber.class, "requestedPreSubscription");
DefaultTestSubscriber(TestSubscriberBuilder options) {
this.initialRequest = options.initialRequest;
this.context = options.context;
this.fusionRequirement = options.fusionRequirement;
this.requestedFusionMode = options.requestedFusionMode;
this.expectedFusionMode = options.expectedFusionMode;
this.cancelled = new AtomicBoolean();
this.receivedOnNext = new CopyOnWriteArrayList<>();
this.receivedPostCancellation = new CopyOnWriteArrayList<>();
this.protocolErrors = new CopyOnWriteArrayList<>();
this.state = 0;
this.doneLatch = new CountDownLatch(1);
this.subscriptionFailure = new AtomicReference<>();
REQUESTED_PRE_SUBSCRIPTION.lazySet(this, initialRequest);
}
@Override
public Context currentContext() {
return this.context;
}
void internalCancel() {
Subscription s = this.s;
if (cancelled.compareAndSet(false, true) && s != null) {
s.cancel();
safeClearQueue(s);
}
}
void safeClearQueue(@Nullable Subscription s) {
if (s instanceof Fuseable.QueueSubscription) {
((Fuseable.QueueSubscription) s).clear();
}
}
void subscriptionFail(String message) {
if (this.subscriptionFailure.compareAndSet(null, new AssertionError(message))) {
internalCancel();
notifyDone();
}
}
final void notifyDone() {
doneLatch.countDown();
}
@Override
public void onSubscribe(Subscription s) {
if (cancelled.get()) {
s.cancel();
safeClearQueue(s);
return;
}
if (!Operators.validate(this.s, s)) {
safeClearQueue(s);
//s is already cancelled at that point, subscriptionFail will cancel this.s
subscriptionFail("TestSubscriber must not be reused, but Subscription has already been set.");
return;
}
this.s = s;
this.fusionMode = -1;
if (s instanceof Fuseable.QueueSubscription) {
if (fusionRequirement == FusionRequirement.NOT_FUSEABLE) {
subscriptionFail("TestSubscriber configured to reject QueueSubscription, got " + s);
return;
}
@SuppressWarnings("unchecked") //intermediate variable to suppress via annotation for compiler's benefit
Fuseable.QueueSubscription converted = (Fuseable.QueueSubscription) s;
this.qs = converted;
int negotiatedMode = qs.requestFusion(this.requestedFusionMode);
if (expectedFusionMode != negotiatedMode && expectedFusionMode != Fuseable.ANY) {
subscriptionFail("TestSubscriber negotiated fusion mode inconsistent, expected " +
Fuseable.fusionModeName(expectedFusionMode) + " got " + Fuseable.fusionModeName(negotiatedMode));
return;
}
this.fusionMode = negotiatedMode;
if (negotiatedMode == Fuseable.SYNC) {
for (;;) {
if (cancelled.get()) {
safeClearQueue(qs);
break;
}
T v = qs.poll();
if (v == null) {
onComplete();
break;
}
onNext(v);
}
}
else {
long rPre = REQUESTED_PRE_SUBSCRIPTION.getAndSet(this, -1L);
if (rPre > 0L) {
upstreamRequest(s, rPre);
}
}
}
else if (fusionRequirement == FusionRequirement.FUSEABLE) {
subscriptionFail("TestSubscriber configured to require QueueSubscription, got " + s);
}
else if (this.initialRequest > 0L) {
long rPre = REQUESTED_PRE_SUBSCRIPTION.getAndSet(this, -1L);
if (rPre > 0L) {
upstreamRequest(s, rPre);
}
}
}
@Override
public void onNext(@Nullable T t) {
int previousState = markOnNextStart();
boolean wasTerminated = isMarkedTerminated(previousState);
boolean wasOnNext = isMarkedOnNext(previousState);
if (wasTerminated || wasOnNext) {
//at this point, we know we haven't switched the markedOnNext bit. if it is set, let the other onNext unset it
if (t != null) {
this.protocolErrors.add(Signal.next(t));
}
else if (wasTerminated) {
this.protocolErrors.add(Signal.error(
new AssertionError("onNext(null) received despite SYNC fusion (which has already completed)")
));
}
else {
//due to the looping nature of SYNC fusion in onSubscribe, this shouldn't happen
this.protocolErrors.add(Signal.error(
new AssertionError("onNext(null) received despite SYNC fusion (with concurrent onNext)")
));
}
return;
}
if (t == null) {
if (this.fusionMode == Fuseable.ASYNC) {
drainAsync(false);
return;
}
else {
subscriptionFail("onNext(null) received while ASYNC fusion not established");
}
}
this.receivedOnNext.add(t);
if (cancelled.get()) {
this.receivedPostCancellation.add(t);
}
checkTerminatedAfterOnNext();
}
@Override
public void onComplete() {
Signal sig = Signal.complete();
int previousState = markTerminated();
if (isMarkedTerminated(previousState) || isMarkedTerminating(previousState)) {
this.protocolErrors.add(sig);
return;
}
if (isMarkedOnNext(previousState)) {
this.protocolErrors.add(sig);
this.terminalSignal = sig;
return; //isTerminating will be detected later, triggering the notifyDone()
}
this.terminalSignal = sig;
if (fusionMode == Fuseable.ASYNC) {
drainAsync(true);
return;
}
notifyDone();
}
@Override
public void onError(Throwable t) {
Signal sig = Signal.error(t);
int previousState = markTerminated();
if (isMarkedTerminated(previousState) || isMarkedTerminating(previousState)) {
this.protocolErrors.add(sig);
return;
}
if (isMarkedOnNext(previousState)) {
this.protocolErrors.add(sig);
this.terminalSignal = sig;
return; //isTerminating will be detected later, triggering the notifyDone()
}
this.terminalSignal = sig;
if (fusionMode == Fuseable.ASYNC) {
drainAsync(true);
return;
}
notifyDone();
}
/**
* Drain the subscriber in asynchronous fusion mode (assumes there is a this.qs).
*
* @param isTerminal is the draining happening from onComplete/onError?
*/
void drainAsync(boolean isTerminal) {
assert this.qs != null;
//onComplete and onError move to terminated/terminating and call drainAsync ONLY if no work in progress
int previousState = this.state;
if (isTerminal && isMarkedOnNext(previousState)) {
return;
}
if (isMarkedTerminated(previousState)) {
safeClearQueue(qs);
notifyDone();
return;
}
T t;
for (; ; ) {
if (cancelled.get()) {
safeClearQueue(qs);
notifyDone();
return;
}
long r = REQUESTED_TOTAL.get(this);
if (r != Long.MAX_VALUE && r - this.receivedOnNext.size() < 1) {
//no more request for data, until next request (or termination)
if (checkTerminatedAfterOnNext()) {
safeClearQueue(qs);
}
return;
}
t = qs.poll();
if (t == null) {
//no more available data, until next request (or termination)
if (checkTerminatedAfterOnNext()) {
safeClearQueue(qs);
}
return;
}
this.receivedOnNext.add(t);
}
}
@Nullable
@Override
public Object scanUnsafe(Attr key) {
if (key == Attr.TERMINATED) return terminalSignal != null || subscriptionFailure.get() != null;
if (key == Attr.CANCELLED) return cancelled.get();
if (key == Attr.ERROR) {
Throwable subFailure = subscriptionFailure.get();
Signal sig = terminalSignal;
if (sig != null && sig.getThrowable() != null) {
return sig.getThrowable();
}
else return subFailure; //simplified: ok to return null if subscriptionFailure holds null
}
if (key == Attr.PARENT) return s;
if (key == Attr.RUN_STYLE) return Attr.RunStyle.SYNC;
if (key == Attr.REQUESTED_FROM_DOWNSTREAM) return REQUESTED_TOTAL.get(this);
return null;
}
void upstreamRequest(Subscription s, long n) {
long prev = Operators.addCap(REQUESTED_TOTAL, this, n);
if (prev != Long.MAX_VALUE) {
s.request(n);
}
}
static final int MASK_TERMINATED = 0b1000;
static final int MASK_TERMINATING = 0b0100;
static final int MASK_ON_NEXT = 0b0001;
boolean checkTerminatedAfterOnNext() {
int donePreviousState = markOnNextDone();
if (isMarkedTerminating(donePreviousState)) {
notifyDone();
return true;
}
return false;
}
static boolean isMarkedTerminated(int state) {
return (state & MASK_TERMINATED) == MASK_TERMINATED;
}
static boolean isMarkedOnNext(int state) {
return (state & MASK_ON_NEXT) == MASK_ON_NEXT;
}
static boolean isMarkedTerminating(int state) {
return (state & MASK_TERMINATING) == MASK_TERMINATING
&& (state & MASK_TERMINATED) != MASK_TERMINATED;
}
/**
* Attempt to mark the TestSubscriber as terminated. Does nothing if already terminated.
* Mark as {@link #isMarkedTerminating(int)} if a concurrent onNext is detected.
*
* @return the previous state
*/
int markTerminated() {
for(;;) {
int state = this.state;
if (isMarkedTerminated(state) || isMarkedTerminating(state)) {
return state;
}
int newState;
if (isMarkedOnNext(state)) {
newState = state | MASK_TERMINATING;
}
else {
newState = MASK_TERMINATED;
}
if (STATE.compareAndSet(this, state, newState)) {
return state;
}
}
}
/**
* Mark that onNext processing has started (work in progress) and return the previous state.
* @return the previous state
*/
int markOnNextStart() {
for(;;) {
int state = this.state;
if (state != 0) {
return state;
}
if (STATE.compareAndSet(this, state, MASK_ON_NEXT)) {
return state;
}
}
}
/**
* Mark that onNext processing has terminated and return the previous state.
* @return the previous state
*/
int markOnNextDone() {
for(;;) {
int state = this.state;
int nextState = state & ~MASK_ON_NEXT;
if (STATE.compareAndSet(this, state, nextState)) {
return state;
}
}
}
// == public subscription-like methods
@Override
public void cancel() {
if (cancelled.compareAndSet(false, true)) {
if (this.s != null) {
this.s.cancel();
}
if (requestedFusionMode == Fuseable.ASYNC) {
int st = this.state;
Fuseable.QueueSubscription q = this.qs;
if (!isMarkedOnNext(st) && q != null) {
q.clear();
}
}
notifyDone();
}
}
@Override
public void request(long n) {
if (this.s == null) {
for (;;) {
long prevReq = REQUESTED_PRE_SUBSCRIPTION.get(this);
if (prevReq == -1L) {
request(n); //will propagate upstream
return;
}
long newReq = Operators.addCap(prevReq, n);
if (REQUESTED_PRE_SUBSCRIPTION.compareAndSet(this, prevReq, newReq)) {
return;
}
}
}
if (Operators.validate(n)) {
if (this.fusionMode == Fuseable.SYNC) {
internalCancel();
throw new IllegalStateException("Request is short circuited in SYNC fusion mode, and should not be explicitly used");
}
upstreamRequest(this.s, n);
}
}
void checkSubscriptionFailure() {
AssertionError subscriptionFailure = this.subscriptionFailure.get();
if (subscriptionFailure != null) {
throw subscriptionFailure;
}
}
// == public accessors
@Override
public boolean isTerminatedOrCancelled() {
checkSubscriptionFailure();
return doneLatch.getCount() == 0;
}
@Override
public boolean isTerminated() {
checkSubscriptionFailure();
return terminalSignal != null;
}
@Override
public boolean isTerminatedComplete() {
checkSubscriptionFailure();
Signal ts = this.terminalSignal;
return ts != null && ts.isOnComplete();
}
@Override
public boolean isTerminatedError() {
checkSubscriptionFailure();
Signal ts = this.terminalSignal;
return ts != null && ts.isOnError();
}
@Override
public boolean isCancelled() {
checkSubscriptionFailure();
return cancelled.get();
}
@Override
@Nullable
public Signal getTerminalSignal() {
checkSubscriptionFailure();
return this.terminalSignal;
}
@Override
public Signal expectTerminalSignal() {
checkSubscriptionFailure();
Signal sig = this.terminalSignal;
if (sig == null || (!sig.isOnError() && !sig.isOnComplete())) {
cancel();
throw new AssertionError("Expected subscriber to be terminated, but it has not been terminated yet: cancelling subscription.");
}
return sig;
}
@Override
public Throwable expectTerminalError() {
checkSubscriptionFailure();
Signal sig = this.terminalSignal;
if (sig == null) {
cancel();
throw new AssertionError("Expected subscriber to have errored, but it has not been terminated yet.");
}
if (sig.isOnComplete()) {
throw new AssertionError("Expected subscriber to have errored, but it has completed instead.");
}
Throwable terminal = sig.getThrowable();
if (terminal == null) {
cancel();
throw new AssertionError("Expected subscriber to have errored, got unexpected terminal signal <" + sig + ">.");
}
return terminal;
}
@Override
public List getReceivedOnNext() {
checkSubscriptionFailure();
return new ArrayList<>(this.receivedOnNext);
}
@Override
public List getReceivedOnNextAfterCancellation() {
checkSubscriptionFailure();
return new ArrayList<>(this.receivedPostCancellation);
}
@Override
public List> getProtocolErrors() {
checkSubscriptionFailure();
return new ArrayList<>(this.protocolErrors);
}
@Override
public int getFusionMode() {
checkSubscriptionFailure();
return this.fusionMode;
}
// == blocking and awaiting termination
@Override
public void block() {
try {
this.doneLatch.await();
checkSubscriptionFailure();
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new AssertionError("Block() interrupted", e);
}
}
@Override
public void block(Duration timeout) {
long timeoutMs = timeout.toMillis();
try {
boolean done = this.doneLatch.await(timeoutMs, TimeUnit.MILLISECONDS);
checkSubscriptionFailure();
if (!done) {
throw new AssertionError("TestSubscriber timed out, not terminated after " + timeout + " (" + timeoutMs + "ms)");
}
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new AssertionError("Block(" + timeout +") interrupted", e);
}
}
//TODO should we add a method to await the latch without throwing ?
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy