All Downloads are FREE. Search and download functionalities are using the official Maven repository.

discord4j.gateway.DefaultGatewayClient Maven / Gradle / Ivy

/*
 * This file is part of Discord4J.
 *
 * Discord4J is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Discord4J is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Discord4J. If not, see <http://www.gnu.org/licenses/>.
 */
package discord4j.gateway;

import discord4j.common.GitProperties;
import discord4j.common.LogUtil;
import discord4j.common.ResettableInterval;
import discord4j.common.close.CloseException;
import discord4j.common.close.CloseStatus;
import discord4j.common.close.DisconnectBehavior;
import discord4j.common.operator.RateLimitOperator;
import discord4j.common.retry.ReconnectContext;
import discord4j.common.retry.ReconnectOptions;
import discord4j.discordjson.json.gateway.*;
import discord4j.gateway.json.GatewayPayload;
import discord4j.gateway.limiter.PayloadTransformer;
import discord4j.gateway.payload.PayloadReader;
import discord4j.gateway.payload.PayloadWriter;
import discord4j.gateway.retry.GatewayException;
import discord4j.gateway.retry.GatewayStateChange;
import discord4j.gateway.retry.PartialDisconnectException;
import discord4j.gateway.retry.ReconnectException;
import io.netty.buffer.ByteBuf;
import io.netty.util.IllegalReferenceCountException;
import org.reactivestreams.Publisher;
import reactor.core.publisher.*;
import reactor.netty.ConnectionObserver;
import reactor.netty.http.client.WebsocketClientSpec;
import reactor.retry.Retry;
import reactor.util.Logger;
import reactor.util.Loggers;
import reactor.util.annotation.Nullable;
import reactor.util.context.Context;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static discord4j.common.LogUtil.format;
import static io.netty.handler.codec.http.HttpHeaderNames.USER_AGENT;

/**
 * Represents a Discord WebSocket client, called Gateway, implementing its lifecycle.
 * 

* Keeps track of a single websocket session by wrapping an instance of {@link GatewayWebsocketHandler} each time a * new WebSocket connection to Discord is made, therefore only one instance of this class is enough to * handle the lifecycle of the Gateway operations, that could span multiple WebSocket sessions over time. *

* Provides automatic reconnecting through a configurable retry policy, allows consumers to receive inbound events * through {@link #dispatch()}, mapped payloads through {@link #receiver()} and allows a producer to * submit events through {@link #sender()}. *

* Provides sending raw {@link ByteBuf} payloads through {@link #sendBuffer(Publisher)} and receiving raw * {@link ByteBuf} payloads mapped in-flight using a specified mapper using {@link #receiver(Function)}. */ public class DefaultGatewayClient implements GatewayClient { private static final Logger log = Loggers.getLogger(DefaultGatewayClient.class); private static final Logger senderLog = Loggers.getLogger("discord4j.gateway.protocol.sender"); private static final Logger receiverLog = Loggers.getLogger("discord4j.gateway.protocol.receiver"); // basic properties private final GatewayReactorResources reactorResources; private final PayloadReader payloadReader; private final PayloadWriter payloadWriter; private final ReconnectOptions reconnectOptions; private final ReconnectContext reconnectContext; private final IdentifyOptions identifyOptions; private final String token; private final GatewayObserver observer; private final PayloadTransformer identifyLimiter; private final ResettableInterval heartbeat; private final int maxMissedHeartbeatAck; // reactive pipelines private final EmitterProcessor receiver = EmitterProcessor.create(false); private final EmitterProcessor sender = EmitterProcessor.create(false); private final EmitterProcessor dispatch = EmitterProcessor.create(false); private final EmitterProcessor> outbound = EmitterProcessor.create(false); private final EmitterProcessor> heartbeats = EmitterProcessor.create(false); private final FluxSink receiverSink; private final FluxSink senderSink; private final FluxSink dispatchSink; private final FluxSink> outboundSink; private final FluxSink> heartbeatSink; // mutable state, modified here and at PayloadHandlers private final AtomicBoolean connected = new AtomicBoolean(false); private final AtomicBoolean allowResume = new AtomicBoolean(false); private final AtomicInteger sequence = new AtomicInteger(0); private final AtomicReference sessionId = new AtomicReference<>(""); private final AtomicLong lastSent = new 
AtomicLong(0); private final AtomicLong lastAck = new AtomicLong(0); private final AtomicInteger missedAck = new AtomicInteger(0); private final AtomicLong responseTime = new AtomicLong(0); private volatile MonoProcessor disconnectNotifier; private volatile GatewayWebsocketHandler sessionHandler; /** * Initializes a new GatewayClient. * * @param options the {@link GatewayOptions} to configure this client */ public DefaultGatewayClient(GatewayOptions options) { this.token = Objects.requireNonNull(options.getToken()); this.reactorResources = Objects.requireNonNull(options.getReactorResources()); this.payloadReader = Objects.requireNonNull(options.getPayloadReader()); this.payloadWriter = Objects.requireNonNull(options.getPayloadWriter()); this.reconnectOptions = options.getReconnectOptions(); this.reconnectContext = new ReconnectContext( this.reconnectOptions.getFirstBackoff(), this.reconnectOptions.getMaxBackoffInterval()); this.identifyOptions = Objects.requireNonNull(options.getIdentifyOptions()); this.observer = options.getInitialObserver(); this.identifyLimiter = Objects.requireNonNull(options.getIdentifyLimiter()); this.maxMissedHeartbeatAck = Math.max(0, options.getMaxMissedHeartbeatAck()); // TODO: consider exposing OverflowStrategy to GatewayOptions this.receiverSink = receiver.sink(FluxSink.OverflowStrategy.BUFFER); this.senderSink = sender.sink(FluxSink.OverflowStrategy.ERROR); this.dispatchSink = dispatch.sink(FluxSink.OverflowStrategy.BUFFER); this.outboundSink = outbound.sink(FluxSink.OverflowStrategy.ERROR); this.heartbeatSink = heartbeats.sink(FluxSink.OverflowStrategy.ERROR); this.heartbeat = new ResettableInterval(this.reactorResources.getTimerTaskScheduler()); this.identifyOptions.getResumeSession().ifPresent(resumeSession -> { this.allowResume.set(true); this.sequence.set(resumeSession.getSequence()); this.sessionId.set(resumeSession.getId()); }); } @Override public Mono execute(String gatewayUrl) { return Mono.deferWithContext( context -> { 
disconnectNotifier = MonoProcessor.create(); lastAck.set(0); lastSent.set(0); missedAck.set(0); MonoProcessor ping = MonoProcessor.create(); // Setup the sending logic from multiple sources into one merged Flux Flux heartbeatFlux = heartbeats.flatMap(payload -> Flux.from(payloadWriter.write(payload))); Flux identifyFlux = outbound.filter(payload -> Opcode.IDENTIFY.equals(payload.getOp())) .delayUntil(payload -> ping) .flatMap(payload -> Flux.from(payloadWriter.write(payload))) .transform(identifyLimiter); Flux payloadFlux = outbound.filter(payload -> !Opcode.IDENTIFY.equals(payload.getOp())) .flatMap(payload -> Flux.from(payloadWriter.write(payload))) .transform(buf -> Flux.merge(buf, sender)); RateLimitOperator outLimiter = new RateLimitOperator<>(outboundLimiterCapacity(), Duration.ofSeconds(60), reactorResources.getTimerTaskScheduler(), reactorResources.getPayloadSenderScheduler()); Flux outFlux = Flux.merge(heartbeatFlux, identifyFlux, payloadFlux) .transform(outLimiter) .doOnNext(buf -> logPayload(senderLog, context, buf)); sessionHandler = new GatewayWebsocketHandler(receiverSink, outFlux, context); Mono readyHandler = dispatch.filter(DefaultGatewayClient::isReadyOrResume) .doOnNext(event -> { connected.set(true); ConnectionObserver.State state; if (reconnectContext.getResetCount() == 0) { log.info(format(context, "Connected to Gateway")); dispatchSink.next(GatewayStateChange.connected()); state = GatewayObserver.CONNECTED; } else { log.info(format(context, "Reconnected to Gateway")); dispatchSink.next(GatewayStateChange.retrySucceeded(reconnectContext.getAttempts())); state = GatewayObserver.RETRY_SUCCEEDED; } reconnectContext.reset(); allowResume.set(true); notifyObserver(state); }) .then(); // Subscribe the receiver to process and transform the inbound payloads into Dispatch events Mono receiverFuture = receiver.map(ByteBuf::retain) .doOnDiscard(ByteBuf.class, DefaultGatewayClient::safeRelease) .doOnNext(buf -> logPayload(receiverLog, context, buf)) 
.flatMap(payloadReader::read) .doOnNext(payload -> { if (Opcode.HEARTBEAT_ACK.equals(payload.getOp())) { ping.onComplete(); } }) .map(this::updateSequence) .map(payload -> new PayloadContext<>(payload, sessionHandler, this, context)) .doOnNext(PayloadHandlers::handle) .then(); // Subscribe the handler's outbound exchange with our outbound signals // routing completion signals to close the gateway Mono senderFuture = outbound.doOnComplete(sessionHandler::close) .doOnNext(payload -> { if (Opcode.RECONNECT.equals(payload.getOp())) { sessionHandler.error( new GatewayException(context, "Reconnecting due to user action")); } }) .then(); // Create the heartbeat loop, and subscribe it using the sender sink Mono heartbeatHandler = heartbeat.ticks() .flatMap(t -> { long now = System.nanoTime(); lastAck.compareAndSet(0, now); long delay = now - lastAck.get(); if (lastSent.get() - lastAck.get() > 0) { if (missedAck.incrementAndGet() > maxMissedHeartbeatAck) { log.warn(format(context, "Missing heartbeat ACK for {} (tick: {}, seq: {})"), Duration.ofNanos(delay), t, sequence.get()); sessionHandler.error(new GatewayException(context, "Reconnecting due to zombie or failed connection")); return Mono.empty(); } } log.debug(format(context, "Sending heartbeat {} after last ACK"), Duration.ofNanos(delay)); lastSent.set(now); return Mono.just(GatewayPayload.heartbeat(ImmutableHeartbeat.of(sequence.get()))); }) .doOnNext(heartbeatSink::next) .then(); Mono httpFuture = reactorResources.getHttpClient() .headers(headers -> headers.add(USER_AGENT, initUserAgent())) .observe(getObserver(context)) .websocket(WebsocketClientSpec.builder() .maxFramePayloadLength(Integer.MAX_VALUE) .build()) .uri(gatewayUrl) .handle(sessionHandler::handle) .subscriberContext(LogUtil.clearContext()) .flatMap(t2 -> handleClose(t2.getT1(), t2.getT2())) .then(); return Mono.zip(httpFuture, readyHandler, receiverFuture, senderFuture, heartbeatHandler) .doOnError(t -> { if (t instanceof ReconnectException) { 
log.info(format(context, "{}"), t.getMessage()); } else { if (log.isTraceEnabled()) { log.error(format(context, "Gateway client error"), t); } else { log.error(format(context, "{}"), t.toString()); } } }) .doOnTerminate(heartbeat::stop) .doOnCancel(() -> sessionHandler.close()) .then(); }) .subscriberContext(ctx -> ctx.put(LogUtil.KEY_SHARD_ID, identifyOptions.getShardInfo().getIndex())) .retryWhen(retryFactory()) .then(Mono.defer(() -> disconnectNotifier.then())) .doOnSubscribe(s -> { if (disconnectNotifier != null) { throw new IllegalStateException("execute can only be subscribed once"); } }); } private String initUserAgent() { final Properties properties = GitProperties.getProperties(); final String version = properties.getProperty(GitProperties.APPLICATION_VERSION, "3"); final String url = properties.getProperty(GitProperties.APPLICATION_URL, "https://discord4j.com"); return "DiscordBot(" + url + ", " + version + ")"; } private void logPayload(Logger logger, Context context, ByteBuf buf) { logger.trace(format(context, buf.toString(StandardCharsets.UTF_8) .replaceAll("(\"token\": ?\")([A-Za-z0-9._-]*)(\")", "$1hunter2$3"))); } private static boolean isReadyOrResume(Dispatch d) { return Ready.class.isAssignableFrom(d.getClass()) || Resumed.class.isAssignableFrom(d.getClass()); } private GatewayPayload updateSequence(GatewayPayload payload) { if (payload.getSequence() != null) { sequence.set(payload.getSequence()); notifyObserver(GatewayObserver.SEQUENCE); } return payload; } private Retry retryFactory() { return Retry.onlyIf(t -> isRetryable(t.exception())) .withApplicationContext(reconnectContext) .withBackoffScheduler(reconnectOptions.getBackoffScheduler()) .backoff(reconnectOptions.getBackoff()) .jitter(reconnectOptions.getJitter()) .retryMax(reconnectOptions.getMaxRetries()) .doOnRetry(retryContext -> { connected.set(false); int attempt = retryContext.applicationContext().getAttempts(); Duration backoff = retryContext.backoff(); 
log.info(format(getContextFromException(retryContext.exception()), "Reconnect attempt {} in {}"), attempt, backoff); if (attempt == 1) { if (!allowResume.get() || !canResume(retryContext.exception())) { dispatchSink.next(GatewayStateChange.retryStarted(backoff)); allowResume.set(false); notifyObserver(GatewayObserver.RETRY_STARTED); } else { dispatchSink.next(GatewayStateChange.retryStartedResume(backoff)); notifyObserver(GatewayObserver.RETRY_RESUME_STARTED); } } else { dispatchSink.next(GatewayStateChange.retryFailed(attempt - 1, backoff)); notifyObserver(GatewayObserver.RETRY_FAILED); allowResume.set(false); } retryContext.applicationContext().next(); }); } private static final List nonRetryableStatusCodes = Arrays.asList( 4004, // Authentication failed 4010, // Invalid shard 4011, // Sharding required 4012, // Invalid API version 4013, // Invalid intent(s) 4014 // Disallowed intent(s) ); private boolean isRetryable(@Nullable Throwable t) { if (t instanceof CloseException) { CloseException closeException = (CloseException) t; return !nonRetryableStatusCodes.contains(closeException.getCode()); } return !(t instanceof PartialDisconnectException); } private boolean canResume(Throwable t) { if (t instanceof CloseException) { CloseException closeException = (CloseException) t; return closeException.getCode() < 4000; } return true; } private Context getContextFromException(Throwable t) { if (t instanceof CloseException) { return ((CloseException) t).getContext(); } if (t instanceof GatewayException) { return ((GatewayException) t).getContext(); } return Context.empty(); } private Mono handleClose(DisconnectBehavior sourceBehavior, CloseStatus closeStatus) { return Mono.deferWithContext(ctx -> { DisconnectBehavior behavior; if (nonRetryableStatusCodes.contains(closeStatus.getCode())) { // non-retryable close codes are non-transient errors therefore stopping is the only choice behavior = DisconnectBehavior.stop(sourceBehavior.getCause()); } else { behavior = 
sourceBehavior; } log.info(format(ctx, "Handling close {} with behavior: {}"), closeStatus, behavior); heartbeat.stop(); reconnectContext.clear(); connected.set(false); lastSent.set(0); lastAck.set(0); responseTime.set(0); if (behavior.getAction() == DisconnectBehavior.Action.STOP_ABRUPTLY) { dispatchSink.next(GatewayStateChange.disconnectedResume()); notifyObserver(GatewayObserver.DISCONNECTED_RESUME); } else if (behavior.getAction() == DisconnectBehavior.Action.STOP) { dispatchSink.next(GatewayStateChange.disconnected()); allowResume.set(false); sequence.set(0); sessionId.set(""); notifyObserver(GatewayObserver.DISCONNECTED); } switch (behavior.getAction()) { case STOP_ABRUPTLY: case STOP: if (behavior.getCause() != null) { return Mono.just(new CloseException(closeStatus, ctx, behavior.getCause())) .flatMap(ex -> { disconnectNotifier.onError(ex); return Mono.error(ex); }); } return Mono.just(closeStatus).doOnNext(status -> disconnectNotifier.onNext(closeStatus)); case RETRY_ABRUPTLY: case RETRY: default: return Mono.error(new CloseException(closeStatus, ctx, behavior.getCause())); } }); } private ConnectionObserver getObserver(Context context) { return (connection, newState) -> { log.debug(format(context, "{} {}"), newState, connection); notifyObserver(newState); }; } private void notifyObserver(ConnectionObserver.State state) { observer.onStateChange(state, this); } @Override public Mono close(boolean allowResume) { return Mono.defer(() -> { if (sessionHandler == null || disconnectNotifier == null) { return Mono.error(new IllegalStateException("Gateway client is not active!")); } if (!disconnectNotifier.isTerminated()) { if (allowResume) { sessionHandler.close(DisconnectBehavior.stopAbruptly(null)); } else { sessionHandler.close(DisconnectBehavior.stop(null)); } } return disconnectNotifier.then(); }); } @Override public Flux dispatch() { return dispatch; } @Override public Flux> receiver() { return receiver(payloadReader::read); } @Override public Flux 
receiver(Function> mapper) { return receiver.map(ByteBuf::retainedDuplicate) .doOnDiscard(ByteBuf.class, DefaultGatewayClient::safeRelease) .flatMap(mapper); } private static void safeRelease(ByteBuf buf) { if (buf.refCnt() > 0) { try { buf.release(); } catch (IllegalReferenceCountException e) { if (log.isDebugEnabled()) { log.debug("", e); } } } } @Override public FluxSink> sender() { return outboundSink; } @Override public Mono sendBuffer(Publisher publisher) { return Flux.from(publisher).doOnNext(senderSink::next).then(); } @Override public int getShardCount() { return identifyOptions.getShardInfo().getCount(); } @Override public String getSessionId() { return sessionId.get(); } @Override public int getSequence() { return sequence.get(); } @Override public boolean isConnected() { return connected.get(); } @Override public Duration getResponseTime() { return Duration.ofNanos(responseTime.get()); } ///////////////////////////////// // Methods for PayloadHandlers // ///////////////////////////////// void ackHeartbeat() { missedAck.set(0); responseTime.set(lastAck.updateAndGet(x -> System.nanoTime()) - lastSent.get()); } //////////////////////////////// // Fields for PayloadHandlers // //////////////////////////////// /** * Obtains the FluxSink to send Dispatch events towards GatewayClient's users. * * @return a {@link FluxSink} for {@link Dispatch} * objects */ FluxSink dispatchSink() { return dispatchSink; } /** * Gets the atomic reference for the current heartbeat sequence. * * @return an AtomicInteger representing the current gateway sequence */ AtomicInteger sequence() { return sequence; } /** * Gets the atomic reference for the current session ID. * * @return an AtomicReference of the String representing the current session ID */ AtomicReference sessionId() { return sessionId; } /** * Gets the heartbeat manager bound to this GatewayClient. 
* * @return a {@link ResettableInterval} to manipulate heartbeat operations */ ResettableInterval heartbeat() { return heartbeat; } /** * Gets the token used to connect to the gateway. * * @return a token String */ String token() { return token; } /** * An boolean value indicating if this client will attempt to RESUME. * * @return an AtomicBoolean representing resume capabilities */ AtomicBoolean allowResume() { return allowResume; } /** * Gets the configuration object for gateway identifying procedure. * * @return an IdentifyOptions configuration object */ IdentifyOptions identifyOptions() { return identifyOptions; } /** * JVM property that allows modifying the number of outbound payloads permitted before activating the * rate-limiter and delaying every following payload for 60 seconds. Default value: 115 permits */ private static final String OUTBOUND_CAPACITY_PROPERTY = "discord4j.gateway.outbound.capacity"; private int outboundLimiterCapacity() { String capacityValue = System.getProperty(OUTBOUND_CAPACITY_PROPERTY); if (capacityValue != null) { try { int capacity = Integer.parseInt(capacityValue); log.info("Overriding default outbound limiter capacity: {}", capacity); } catch (NumberFormatException e) { log.warn("Invalid custom outbound limiter capacity: {}", capacityValue); } } return 120; } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy