
io.helidon.webserver.websocket.WsConnection Maven / Gradle / Ivy
/*
* Copyright (c) 2022, 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.helidon.webserver.websocket;
import java.lang.System.Logger.Level;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.ZonedDateTime;
import java.util.Arrays;
import java.util.Optional;
import java.util.concurrent.Semaphore;

import io.helidon.common.buffers.BufferData;
import io.helidon.common.buffers.DataReader;
import io.helidon.common.concurrency.limits.FixedLimit;
import io.helidon.common.concurrency.limits.Limit;
import io.helidon.common.concurrency.limits.LimitException;
import io.helidon.common.socket.SocketContext;
import io.helidon.http.DateTime;
import io.helidon.http.Headers;
import io.helidon.http.HttpPrologue;
import io.helidon.webserver.CloseConnectionException;
import io.helidon.webserver.ConnectionContext;
import io.helidon.webserver.spi.ServerConnection;
import io.helidon.websocket.ClientWsFrame;
import io.helidon.websocket.ServerWsFrame;
import io.helidon.websocket.WsCloseCodes;
import io.helidon.websocket.WsCloseException;
import io.helidon.websocket.WsListener;
import io.helidon.websocket.WsOpCode;
import io.helidon.websocket.WsSession;
/**
* WebSocket connection, server side session implementation.
*/
public class WsConnection implements ServerConnection, WsSession {
private static final System.Logger LOGGER = System.getLogger(WsConnection.class.getName());
static final int MAX_FRAME_LENGTH = 1048576;
private final ConnectionContext ctx;
private final HttpPrologue prologue;
private final Headers upgradeHeaders;
private final String wsKey;
private final WsListener listener;
private final WsConfig wsConfig;
private final BufferData sendBuffer = BufferData.growing(1024);
private final DataReader dataReader;
private ContinuationType recvContinuation = ContinuationType.NONE;
private boolean sendContinuation;
private boolean closeSent;
private volatile Thread myThread;
private volatile boolean canRun = true;
private volatile boolean readingNetwork;
private volatile ZonedDateTime lastRequestTimestamp;
private WsConnection(ConnectionContext ctx,
HttpPrologue prologue,
Headers upgradeHeaders,
String wsKey,
WsListener wsListener) {
this.ctx = ctx;
this.prologue = prologue;
this.upgradeHeaders = upgradeHeaders;
this.wsKey = wsKey;
this.listener = wsListener;
this.dataReader = ctx.dataReader();
this.lastRequestTimestamp = DateTime.timestamp();
this.wsConfig = (WsConfig) ctx.listenerContext()
.config()
.protocols()
.stream()
.filter(p -> p instanceof WsConfig)
.findFirst()
.orElseThrow(() -> new InternalError("Unable to find WebSocket config"));
}
/**
* Create a new connection using a listener.
*
* @param ctx server connection context
* @param prologue prologue of this request
* @param upgradeHeaders headers for
* @param wsKey ws key
* @param wsListener a ws listener
* @return a new connection
*/
public static WsConnection create(ConnectionContext ctx,
HttpPrologue prologue,
Headers upgradeHeaders,
String wsKey,
WsListener wsListener) {
return new WsConnection(ctx, prologue, upgradeHeaders, wsKey, wsListener);
}
/**
* Create a new connection using a route.
*
* @param ctx server connection context
* @param prologue prologue of this request
* @param upgradeHeaders headers for
* @param wsKey ws key
* @param wsRoute route to use
* @return a new connection
*/
public static WsConnection create(ConnectionContext ctx,
HttpPrologue prologue,
Headers upgradeHeaders,
String wsKey,
WsRoute wsRoute) {
return new WsConnection(ctx, prologue, upgradeHeaders, wsKey, wsRoute.listener());
}
@SuppressWarnings("removal")
@Override
public void handle(Semaphore requestSemaphore) {
handle(FixedLimit.create(requestSemaphore));
}
@Override
public void handle(Limit limit) {
myThread = Thread.currentThread();
try {
limit.invoke(() -> listener.onOpen(this));
} catch (LimitException e) {
close(WsCloseCodes.TRY_AGAIN_LATER, "Too Many Concurrent Requests");
return;
} catch (Exception e) {
close(WsCloseCodes.UNEXPECTED_CONDITION, e.getMessage());
return;
}
while (canRun) {
readingNetwork = true;
ClientWsFrame frame = readFrame();
readingNetwork = false;
lastRequestTimestamp = DateTime.timestamp();
try {
boolean result = limit.invoke(() -> processFrame(frame));
if (!result) {
lastRequestTimestamp = DateTime.timestamp();
return;
}
lastRequestTimestamp = DateTime.timestamp();
} catch (LimitException e) {
listener.onClose(this, WsCloseCodes.TRY_AGAIN_LATER, "Too Many Concurrent Requests");
close(WsCloseCodes.TRY_AGAIN_LATER, "Too Many Concurrent Requests");
return;
} catch (CloseConnectionException e) {
throw e;
} catch (Exception e) {
listener.onError(this, e);
this.close(WsCloseCodes.UNEXPECTED_CONDITION, e.getMessage());
return;
}
}
this.close(WsCloseCodes.NORMAL_CLOSE, "Idle timeout");
}
@Override
public WsSession send(String text, boolean last) {
return send(ServerWsFrame.data(text, last));
}
@Override
public WsSession send(BufferData bufferData, boolean last) {
return send(ServerWsFrame.data(bufferData, last));
}
@Override
public WsSession ping(BufferData bufferData) {
return send(ServerWsFrame.control(WsOpCode.PING, bufferData));
}
@Override
public WsSession pong(BufferData bufferData) {
return send(ServerWsFrame.control(WsOpCode.PONG, bufferData));
}
@Override
public WsSession close(int code, String reason) {
closeSent = true;
byte[] reasonBytes = reason.getBytes(StandardCharsets.UTF_8);
BufferData bufferData = BufferData.create(2 + reasonBytes.length);
bufferData.writeInt16(code);
bufferData.write(reasonBytes);
return send(ServerWsFrame.control(WsOpCode.CLOSE, bufferData));
}
@Override
public WsSession terminate() {
close(WsCloseCodes.NORMAL_CLOSE, "Terminate");
throw new CloseConnectionException("Terminate from WebSocket");
}
@Override
public Optional subProtocol() {
return upgradeHeaders.first(WsUpgrader.PROTOCOL);
}
@Override
public SocketContext socketContext() {
return ctx;
}
@Override
public Duration idleTime() {
return Duration.between(lastRequestTimestamp, DateTime.timestamp());
}
@Override
public void close(boolean interrupt) {
// either way, finish
this.canRun = false;
if (interrupt) {
// interrupt regardless of current state
if (myThread != null) {
myThread.interrupt();
}
} else if (readingNetwork) {
// only interrupt when not processing a request (there is a chance of a race condition, this edge case
// is ignored
myThread.interrupt();
}
}
private boolean processFrame(ClientWsFrame frame) {
BufferData payload = frame.payloadData();
switch (frame.opCode()) {
case CONTINUATION -> {
boolean finalFrame = frame.fin();
ContinuationType ct = recvContinuation;
if (finalFrame) {
recvContinuation = ContinuationType.NONE;
}
switch (ct) {
case TEXT -> listener.onMessage(this, payload.readString(payload.available(), StandardCharsets.UTF_8), finalFrame);
case BINARY -> listener.onMessage(this, payload, finalFrame);
default -> {
close(WsCloseCodes.PROTOCOL_ERROR, "Unexpected continuation received");
throw new CloseConnectionException("Websocket unexpected continuation");
}
}
}
case TEXT -> {
recvContinuation = ContinuationType.TEXT;
listener.onMessage(this, payload.readString(payload.available(), StandardCharsets.UTF_8), frame.fin());
}
case BINARY -> {
recvContinuation = ContinuationType.BINARY;
listener.onMessage(this, payload, frame.fin());
}
case CLOSE -> {
int status = WsCloseCodes.NORMAL_CLOSE;
String reason = "normal";
if (payload.available() > 0) {
status = payload.readInt16();
if (payload.available() > 0) {
reason = payload.readString(payload.available(), StandardCharsets.UTF_8);
}
}
listener.onClose(this, status, reason);
if (!closeSent) {
close(WsCloseCodes.NORMAL_CLOSE, "normal");
}
return false;
}
case PING -> listener.onPing(this, payload);
case PONG -> listener.onPong(this, payload);
default -> throw new IllegalStateException("Invalid frame opCode: " + frame.opCode());
}
return true;
}
private ClientWsFrame readFrame() {
try {
return ClientWsFrame.read(ctx, dataReader, wsConfig.maxFrameLength());
} catch (DataReader.InsufficientDataAvailableException e) {
throw new CloseConnectionException("Socket closed by the other side", e);
} catch (WsCloseException e) {
close(e.closeCode(), e.getMessage());
throw new CloseConnectionException("WebSocket failed to read client frame", e);
}
}
private WsSession send(ServerWsFrame frame) {
WsOpCode usedCode = frame.opCode();
if (frame.isPayload()) {
// check if continuation or set continuation
if (sendContinuation) {
usedCode = WsOpCode.CONTINUATION;
}
// do not change type for the first frame
sendContinuation = !frame.fin();
}
frame.opCode(usedCode);
if (LOGGER.isLoggable(Level.TRACE)) {
ctx.log(LOGGER, Level.TRACE, "ws server frame send %s", frame);
}
sendBuffer.clear();
int opCodeFull = frame.fin() ? 0b10000000 : 0;
opCodeFull |= usedCode.code();
sendBuffer.write(opCodeFull);
long length = frame.payloadLength();
if (length < 126) {
sendBuffer.write((int) length);
} else if (length < 1 << 16) {
sendBuffer.write(126);
sendBuffer.write((int) (length >>> 8));
sendBuffer.write((int) (length & 0xFF));
} else {
sendBuffer.write(127);
for (int i = 56; i >= 0; i -= 8){
sendBuffer.write((int) (length >>> i) & 0xFF);
}
}
sendBuffer.write(frame.payloadData());
ctx.dataWriter().writeNow(sendBuffer);
return this;
}
private enum ContinuationType {
NONE,
TEXT,
BINARY
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy