// com.seeq.link.agent.DefaultWsPingManager (Maven / Gradle / Ivy)
package com.seeq.link.agent;
import static com.seeq.utilities.Locks.CloseableLock;
import static com.seeq.utilities.Locks.tryWithResourcesLock;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.google.common.annotations.VisibleForTesting;
import com.seeq.link.agent.interfaces.TooManyPingsLostException;
import com.seeq.link.agent.interfaces.WsPingManager;
import com.seeq.utilities.AutoResetEvent;
import com.seeq.utilities.SeeqNames;
import com.seeq.utilities.process.StackTraceInfo;
import lombok.extern.slf4j.Slf4j;
/**
 * Keeps a WebSocket connection's liveness in check by periodically sending ping frames (prefix bytes
 * followed by the send time in epoch milliseconds) and matching the echoed pongs against the pings
 * issued so far.
 *
 * <p>A dedicated thread sends one ping every {@code keepAliveInterval} while keep-alive is enabled.
 * {@link #checkStalePings()} reports the connection as broken once more than
 * {@value #MAX_ALLOWED_PING_TIMEOUTS} pings have gone unanswered for longer than {@code keepAliveTimeout}.
 */
@Slf4j
public class DefaultWsPingManager implements WsPingManager {
    /** Marker bytes that start every ping payload (and therefore every echoed pong). */
    private static final byte[] PING_PREFIX = SeeqNames.Agents.WebSockets.PingPrefix.getBytes(StandardCharsets.UTF_8);
    private static final int LONG_SIZE_IN_BYTES = Long.BYTES;
    /** Number of timed-out pings tolerated; one more makes {@link #checkStalePings()} throw. */
    private static final int MAX_ALLOWED_PING_TIMEOUTS = 3;

    private Thread pingThread = null;
    private final AutoResetEvent pingThreadWakeupEvent = new AutoResetEvent(false);
    /** Guards {@link #issuedPings} and groups keep-alive setting updates. */
    private final ReentrantLock lockObject = new ReentrantLock();

    // volatile: these are written by enableKeepAlive() (under lockObject) but read WITHOUT the lock by
    // the ping thread, checkStalePings() and handlePong(); without volatile those readers are not
    // guaranteed to ever observe the update (e.g. the ping thread might never start pinging).
    private volatile Duration keepAliveInterval;
    private volatile Duration keepAliveTimeout;
    private volatile boolean keepAliveEnabled = false;

    /**
     * Send times (epoch milliseconds) of pings that have not been answered yet, oldest first.
     * Guarded by {@link #lockObject}; {@code null} before {@link #start} and after {@link #stop()}.
     */
    @VisibleForTesting
    ArrayDeque<Long> issuedPings = null;
    private String connectionId = null;
    /** Transmits a ping payload; its return value is ignored. */
    private Function<byte[], ?> sendPingFn = null;

    /**
     * @param keepAliveInterval time to wait between consecutive pings
     * @param keepAliveTimeout  how long a ping may stay unanswered before it counts as lost
     * @param keepAliveEnabled  whether pings are sent/checked from the start
     */
    public DefaultWsPingManager(Duration keepAliveInterval, Duration keepAliveTimeout, boolean keepAliveEnabled) {
        this.keepAliveInterval = keepAliveInterval;
        this.keepAliveTimeout = keepAliveTimeout;
        this.keepAliveEnabled = keepAliveEnabled;
    }

    /**
     * Starts the ping thread for a connection.
     *
     * <p>The parameter is declared with a raw {@code Function} so that it overrides the interface
     * method regardless of its exact type arguments; it is only ever invoked with a {@code byte[]}.
     *
     * @param connectionId identifier used in the ping thread's name and in log messages
     * @param sendPingFn   function that transmits a ping payload over the connection
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes", "try"})
    public void start(String connectionId, Function sendPingFn) {
        this.connectionId = connectionId;
        this.sendPingFn = sendPingFn;
        try (CloseableLock ignored = tryWithResourcesLock(this.lockObject)) {
            this.issuedPings = new ArrayDeque<>();
        }
        // Thread.start() publishes the fields written above to the new thread.
        this.pingThread = new Thread(this::pingThreadRun);
        this.pingThread.start();
    }

    /**
     * Verifies that not too many pings have gone unanswered.
     *
     * @throws TooManyPingsLostException if more than {@value #MAX_ALLOWED_PING_TIMEOUTS} pings are
     *                                   older than the keep-alive timeout
     */
    @Override
    @SuppressWarnings("try")
    public void checkStalePings() throws TooManyPingsLostException {
        if (!this.keepAliveEnabled) {
            return;
        }
        long now = Instant.now().toEpochMilli();
        long threshold = now - this.keepAliveTimeout.toMillis();
        List<Long> stalePings;
        try (CloseableLock ignored = tryWithResourcesLock(this.lockObject)) {
            if (this.issuedPings == null) {
                // Not started yet, or stop() already ran concurrently: nothing to check.
                return;
            }
            stalePings = this.issuedPings.stream().filter(p -> p < threshold).collect(Collectors.toList());
        }
        // Logging and throwing happen outside the lock to keep the critical section short.
        if (!stalePings.isEmpty()) {
            LOG.debug("Ping timeouts: {}", stalePings.stream()
                    .map(p -> (now - p) + "ms")
                    .collect(Collectors.joining(", ")));
        }
        if (stalePings.size() > MAX_ALLOWED_PING_TIMEOUTS) {
            throw new TooManyPingsLostException(
                    stalePings.size() + " pings detected without a corresponding pong.");
        }
    }

    /** Stops the ping thread (waiting for it to finish) and clears all per-connection state. */
    @Override
    @SuppressWarnings("try")
    public void stop() {
        shutdownThread(this.pingThread);
        this.pingThread = null;
        try (CloseableLock ignored = tryWithResourcesLock(this.lockObject)) {
            this.issuedPings = null;
        }
        this.connectionId = null;
        this.sendPingFn = null;
    }

    /**
     * @return {@code true} if the message has exactly the ping-prefix + 8-byte timestamp layout
     */
    @Override
    public boolean hasPongFormat(byte[] messageArray) {
        return (messageArray.length == LONG_SIZE_IN_BYTES + PING_PREFIX.length) && startsWithPingPrefix(messageArray);
    }

    /** Processes a received pong message; a no-op while keep-alive is disabled. */
    @Override
    public void handlePong(byte[] messageArray) {
        if (!this.keepAliveEnabled) {
            return;
        }
        Long pingTime = extractPingTime(messageArray);
        if (pingTime != null) {
            this.handlePong(pingTime);
        }
    }

    /** Enables keep-alive pinging with new interval/timeout settings. */
    @Override
    @SuppressWarnings("try")
    public void enableKeepAlive(Duration newInterval, Duration newTimeout) {
        try (CloseableLock ignored = tryWithResourcesLock(this.lockObject)) {
            this.keepAliveInterval = newInterval;
            this.keepAliveTimeout = newTimeout;
            this.keepAliveEnabled = true;
        }
    }

    /**
     * Acknowledges the ping sent at {@code pingTime}. A pong that arrives later than the timeout is
     * only logged: its ping (and older ones) stay queued and still count toward stale pings.
     */
    private void handlePong(long pingTime) {
        long now = Instant.now().toEpochMilli();
        long timeDifference = now - pingTime;
        LOG.trace("Handling pong for {}. Time difference: {}ms", pingTime, timeDifference);
        long pingTimeoutMilliseconds = this.keepAliveTimeout.toMillis();
        if (timeDifference > pingTimeoutMilliseconds) {
            LOG.warn("Ping-Pong roundtrip took {}s, which is longer than " +
                            "the expected roundtrip of {}s. Please check your network environment! " +
                            "You may suppress this message by increasing the wsKeepAliveTimeoutSeconds in your agent " +
                            "configuration.",
                    timeDifference / 1000, pingTimeoutMilliseconds / 1000);
        } else {
            this.removePingsBeforeOrAt(pingTime);
        }
    }

    /** Drops every tracked ping sent at or before {@code pingTime} (the queue is in send order). */
    @SuppressWarnings("try")
    private void removePingsBeforeOrAt(long pingTime) {
        try (CloseableLock ignored = tryWithResourcesLock(this.lockObject)) {
            ArrayDeque<Long> pings = this.issuedPings;
            if (pings == null) {
                // A pong can race with stop(); there is nothing left to acknowledge.
                return;
            }
            while (!pings.isEmpty() && pings.getFirst() <= pingTime) {
                pings.removeFirst();
            }
        }
    }

    /** Body of the ping thread: sends a ping per interval until the thread is interrupted. */
    private void pingThreadRun() {
        Thread.currentThread().setName("Ping - " + this.connectionId);
        try {
            while (!Thread.currentThread().isInterrupted()) {
                if (this.keepAliveEnabled) {
                    this.sendPing();
                }
                this.pingThreadWakeupEvent.waitOne(this.keepAliveInterval);
            }
        } catch (InterruptedException e) {
            LOG.debug("Ping thread interrupted. Shutting it down.");
        }
    }

    /** Records and transmits a single ping stamped with the current time. */
    @SuppressWarnings("try")
    private void sendPing() {
        long pingTime = Instant.now().toEpochMilli();
        // Track the ping BEFORE sending it: otherwise a fast pong handled on another thread could run
        // removePingsBeforeOrAt() first, and this ping would later be counted as lost.
        try (CloseableLock ignored = tryWithResourcesLock(this.lockObject)) {
            if (this.issuedPings == null) {
                return;
            }
            this.issuedPings.add(pingTime);
        }
        LOG.trace("Sending ping {}", pingTime);
        try {
            this.sendPingFn.apply(createPingMessage(pingTime));
        } catch (RuntimeException e) {
            // Keep the thread alive: a failed send stays tracked as an unanswered ping, so
            // checkStalePings() will eventually report the broken connection instead of pinging
            // silently stopping forever.
            LOG.warn("Failed to send ping {} on connection {}", pingTime, this.connectionId, e);
        }
    }

    /** Builds a ping payload: {@link #PING_PREFIX} followed by {@code pingTime} as a big-endian long. */
    private static byte[] createPingMessage(long pingTime) {
        byte[] bytes = ByteBuffer.allocate(Long.BYTES).putLong(pingTime).array();
        byte[] message = new byte[PING_PREFIX.length + bytes.length];
        System.arraycopy(PING_PREFIX, 0, message, 0, PING_PREFIX.length);
        System.arraycopy(bytes, 0, message, PING_PREFIX.length, bytes.length);
        return message;
    }

    private static boolean startsWithPingPrefix(byte[] source) {
        if (source.length < PING_PREFIX.length) {
            return false;
        }
        for (int i = 0; i < PING_PREFIX.length; i++) {
            if (source[i] != PING_PREFIX[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * @return the timestamp embedded after the prefix, or {@code null} (with a warning) if the
     *         message does not have the expected length
     */
    private static Long extractPingTime(byte[] messageArray) {
        if (messageArray.length != LONG_SIZE_IN_BYTES + PING_PREFIX.length) {
            LOG.warn("Unexpected number of bytes received in the pong message. Ignoring pong.");
            return null;
        }
        ByteBuffer buffer = ByteBuffer.wrap(messageArray, PING_PREFIX.length, LONG_SIZE_IN_BYTES);
        return buffer.getLong();
    }

    /**
     * Interrupts {@code thread} and waits (in 10 s rounds, logging its stack between rounds) until it
     * terminates. If the calling thread is itself interrupted, gives up waiting and restores the
     * caller's interrupt status.
     */
    private static void shutdownThread(Thread thread) {
        if (thread == null) {
            return;
        }
        final int timeoutInMilliseconds = 10000;
        LOG.info("Shutting down thread: {}", thread.getName());
        while (true) {
            thread.interrupt();
            try {
                thread.join(timeoutInMilliseconds);
            } catch (InterruptedException e) {
                // Don't swallow the caller's interruption, and don't claim success we did not observe.
                Thread.currentThread().interrupt();
                LOG.warn("Interrupted while shutting down thread '{}'; it may still be running.", thread.getName());
                return;
            }
            if (!thread.isAlive()) {
                break;
            }
            LOG.info("Continuing to try to shut down thread '{}'. Current stack trace:\n{}",
                    thread.getName(), StackTraceInfo.getFullStackTrace(thread.getStackTrace()));
        }
        LOG.info("Successfully shut down thread: {}", thread.getName());
    }
}