All downloads are free. Search and download functionality uses the official Maven repository.

com.king.platform.net.http.netty.websocket.WebSocketClientImpl Maven / Gradle / Ivy

package com.king.platform.net.http.netty.websocket;

import com.king.platform.net.http.Headers;
import com.king.platform.net.http.WebSocketClient;
import com.king.platform.net.http.WebSocketListener;
import com.king.platform.net.http.netty.HttpRequestContext;
import com.king.platform.net.http.netty.eventbus.Event;
import com.king.platform.net.http.netty.requestbuilder.BuiltNettyClientRequest;
import com.king.platform.net.http.netty.util.AwaitLatch;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.concurrent.ScheduledFuture;
import org.slf4j.Logger;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CoderResult;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static org.slf4j.LoggerFactory.getLogger;

/**
 * Netty-backed {@link WebSocketClient}.
 *
 * <p>The constructor subscribes this instance to the request's event bus. The open event, received frames, errors
 * and completion are all delivered through those subscriptions. Frames that arrive before the open event are
 * retained and buffered, then replayed once the connection has been marked ready.
 *
 * <p>{@link WebSocketListener} callbacks are dispatched on the listener executor. The future returned by
 * {@link #connect()} is completed on the completable-future executor.
 */
public class WebSocketClientImpl implements WebSocketClient {
	private final Logger logger = getLogger(getClass());

	// Decoder shared by all fragments of the text message currently being received. REPORT makes malformed input
	// surface as a CharacterCodingException (see decodeUtf8).
	private final CharsetDecoder utf8Decoder = StandardCharsets.UTF_8.newDecoder().onMalformedInput(CodingErrorAction.REPORT);

	private final CopyOnWriteArrayList<WebSocketListener> listeners = new CopyOnWriteArrayList<>();
	private final BuiltNettyClientRequest builtNettyClientRequest;
	private final Executor callbackExecutor;
	private final Executor completableFutureExecutor;
	private final boolean autoPong;
	private final boolean autoCloseFrame;
	private final Duration pingEveryDuration;

	private Headers headers = new Headers(EmptyHttpHeaders.INSTANCE);

	// Kind of fragmented message in progress, or null when the last message was complete.
	private FragmentedFrameType expectedFragmentedFrameType;

	// Bytes of an incomplete UTF-8 sequence left at the end of the previous text fragment, or null if none.
	private ByteBuffer pendingUtf8Bytes;

	// Frames received before the connection was ready. Each carries an extra retain() taken when it was buffered.
	private final List<WebSocketFrame> bufferedFrames = new ArrayList<>();

	private volatile Channel channel;
	private volatile boolean ready;
	private volatile CompletableFuture<WebSocketClient> connectionFuture;
	private final AwaitLatch awaitLatch = new AwaitLatch();
	private ScheduledFuture<?> pingFuture;


	/**
	 * Creates a client and subscribes its handlers to the request's event bus.
	 *
	 * @param builtNettyClientRequest   the request executed by {@link #connect()}
	 * @param listenerExecutor          executor that runs all {@link WebSocketListener} callbacks
	 * @param completableFutureExecutor executor used to complete the future returned by {@link #connect()}
	 * @param autoPong                  when true, every received ping is answered with a pong carrying the same payload
	 * @param autoCloseFrame            when true, a received close frame is echoed with the same status code and reason
	 * @param pingEveryDuration         interval between automatic pings once connected, or {@code null} for none
	 */
	public WebSocketClientImpl(BuiltNettyClientRequest builtNettyClientRequest, Executor listenerExecutor, Executor completableFutureExecutor, boolean
		autoPong, boolean autoCloseFrame, Duration pingEveryDuration) {
		this.builtNettyClientRequest = builtNettyClientRequest;
		this.callbackExecutor = listenerExecutor;
		this.completableFutureExecutor = completableFutureExecutor;
		this.autoPong = autoPong;
		this.autoCloseFrame = autoCloseFrame;
		this.pingEveryDuration = pingEveryDuration;


		builtNettyClientRequest.withCustomCallbackSupplier(requestEventBus -> {
			requestEventBus.subscribe(Event.onWsOpen, WebSocketClientImpl.this::onOpen);
			requestEventBus.subscribe(Event.onWsFrame, WebSocketClientImpl.this::onWebSocketFrame);
			requestEventBus.subscribe(Event.ERROR, WebSocketClientImpl.this::onError);
			requestEventBus.subscribe(Event.COMPLETED, WebSocketClientImpl.this::onCompleted);
			requestEventBus.subscribe(Event.POPULATE_CONNECTION_SPECIFIC_HEADERS, WebSocketUtil::populateHeaders);
		});

	}


	/**
	 * Returns the headers received when the connection was opened (empty before that).
	 */
	@Override
	public Headers headers() {
		return headers;
	}

	/**
	 * Returns the {@code Sec-WebSocket-Protocol} value from the open-time headers, as reported by {@link Headers#get}.
	 */
	@Override
	public String getNegotiatedSubProtocol() {
		return headers.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);

	}

	/**
	 * Registers a listener that receives all subsequent callbacks.
	 */
	@Override
	public void addListener(WebSocketListener webSocketListener) {
		listeners.add(webSocketListener);
	}

	/**
	 * Starts connecting by executing the underlying request.
	 *
	 * @return a future completed with this client once the connection opens, or exceptionally if it fails
	 * @throws IllegalStateException if a connection attempt is already pending or the client is already connected
	 */
	@Override
	public CompletableFuture connect() {
		if (connectionFuture != null) {
			throw new IllegalStateException("Already trying to connect!");
		}
		if (ready) {
			throw new IllegalStateException("Already connected");
		}

		// Keep a local reference: onOpen/onError clear the field (possibly on another thread) before execute() has
		// returned, and the caller must still receive the future.
		CompletableFuture<WebSocketClient> future = new CompletableFuture<>();
		connectionFuture = future;
		builtNettyClientRequest.execute();
		return future;
	}

	/**
	 * Blocks until the connection has errored or completed. Returns immediately if no connection was ever attempted.
	 */
	@Override
	public void awaitClose() throws InterruptedException {
		if (connectionFuture == null && channel == null) {
			return;
		}

		awaitLatch.awaitClose();
	}

	@Override
	public boolean isConnected() {
		return ready;
	}

	/**
	 * Sends a text frame.
	 *
	 * @return a future completed when the write finishes, or failed with {@link IllegalStateException} if not connected
	 */
	@Override
	public CompletableFuture sendTextFrame(String text) {
		Channel current = connectedChannel();
		if (current == null) {
			return notConnected();
		}
		return convert(current.writeAndFlush(new TextWebSocketFrame(text)));
	}

	/**
	 * Sends a close frame with the given status code and reason.
	 *
	 * @return a future completed when the write finishes (or immediately if the channel is already closed), or
	 * failed with {@link IllegalStateException} if not connected
	 */
	@Override
	public CompletableFuture sendCloseFrame(int statusCode, String reason) {
		Channel current = connectedChannel();
		if (current == null) {
			return notConnected();
		}

		if (!current.isOpen()) {
			return CompletableFuture.completedFuture(null);
		}

		// Report the write outcome like every other send* method, instead of claiming success unconditionally.
		return convert(current.writeAndFlush(new CloseWebSocketFrame(statusCode, reason)));
	}

	/**
	 * Sends a close frame with status 1000 and an empty reason.
	 */
	@Override
	public CompletableFuture sendCloseFrame() {
		return sendCloseFrame(1000, "");
	}

	/**
	 * Sends a binary frame wrapping {@code payload} (not copied).
	 *
	 * @return a future completed when the write finishes, or failed with {@link IllegalStateException} if not connected
	 */
	@Override
	public CompletableFuture sendBinaryFrame(byte[] payload) {
		Channel current = connectedChannel();
		if (current == null) {
			return notConnected();
		}

		return convert(current.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(payload))));
	}

	/**
	 * Sends a ping frame with an empty payload.
	 *
	 * @return a future completed when the write finishes, or failed with {@link IllegalStateException} if not connected
	 */
	@Override
	public CompletableFuture sendPingFrame() {
		Channel current = connectedChannel();
		if (current == null) {
			return notConnected();
		}

		return convert(current.writeAndFlush(new PingWebSocketFrame()));
	}

	/**
	 * Sends a ping frame carrying a copy of {@code payload}.
	 *
	 * @return a future completed when the write finishes, or failed with {@link IllegalStateException} if not connected
	 */
	@Override
	public CompletableFuture sendPingFrame(byte[] payload) {
		Channel current = connectedChannel();
		if (current == null) {
			return notConnected();
		}

		return convert(current.writeAndFlush(new PingWebSocketFrame(Unpooled.copiedBuffer(payload))));
	}

	/**
	 * Returns the channel to write to, or {@code null} when not connected. The volatile {@code channel} field is
	 * read exactly once, so a concurrent {@code onError} cannot null it between the check and the write.
	 */
	private Channel connectedChannel() {
		Channel current = channel;
		return ready ? current : null;
	}

	/**
	 * Returns a future already failed with the "Not connected!" {@link IllegalStateException}.
	 */
	private static CompletableFuture<Void> notConnected() {
		CompletableFuture<Void> future = new CompletableFuture<>();
		future.completeExceptionally(new IllegalStateException("Not connected!"));
		return future;
	}

	/**
	 * Adapts a Netty {@link ChannelFuture} into a {@link CompletableFuture} that mirrors its success or failure cause.
	 */
	private CompletableFuture<Void> convert(ChannelFuture f) {
		CompletableFuture<Void> completableFuture = new CompletableFuture<>();
		f.addListener((ChannelFutureListener) future -> {
			if (future.isSuccess()) {
				completableFuture.complete(null);
			} else {
				completableFuture.completeExceptionally(future.cause());
			}
		});
		return completableFuture;
	}

	/**
	 * Runs {@code callback} for every registered listener, as a single task on the listener executor.
	 */
	private void notifyListeners(Consumer<WebSocketListener> callback) {
		callbackExecutor.execute(() -> {
			for (WebSocketListener webSocketListener : listeners) {
				callback.accept(webSocketListener);
			}
		});
	}

	/**
	 * Entry point for every received frame.
	 *
	 * <p>Before the connection is ready the frame is retained and buffered for replay by {@link #onOpen}.
	 * Afterwards it is dispatched by type and released once handled.
	 */
	private void onWebSocketFrame(WebSocketFrame frame) {
		if (!ready) {
			bufferedFrames.add(frame);
			frame.retain();
		} else {
			try {
				if (frame instanceof TextWebSocketFrame) {
					onTextFrame((TextWebSocketFrame) frame);
				} else if (frame instanceof BinaryWebSocketFrame) {
					onBinaryFrame((BinaryWebSocketFrame) frame);
				} else if (frame instanceof CloseWebSocketFrame) {
					onClose((CloseWebSocketFrame) frame);
				} else if (frame instanceof PingWebSocketFrame) {
					onPingFrame((PingWebSocketFrame) frame);
				} else if (frame instanceof PongWebSocketFrame) {
					onPongFrame((PongWebSocketFrame) frame);
				} else if (frame instanceof ContinuationWebSocketFrame) {
					onContinuationFrame((ContinuationWebSocketFrame) frame);
				} else {
					logger.error("Invalid message {}", frame);
				}
			} finally {
				frame.release();
			}
		}
	}

	/**
	 * Forwards a copy of the pong payload to the listeners.
	 */
	private void onPongFrame(PongWebSocketFrame pongWebSocketFrame) {
		byte[] bytes = getBytes(pongWebSocketFrame.content());
		notifyListeners(listener -> listener.onPongFrame(bytes));
	}

	/**
	 * Optionally answers a ping with a pong carrying the same payload, then forwards the payload to the listeners.
	 */
	private void onPingFrame(PingWebSocketFrame pingWebSocketFrame) {
		byte[] bytes = getBytes(pingWebSocketFrame.content());

		Channel current = channel;
		if (autoPong && current != null) {
			current.writeAndFlush(new PongWebSocketFrame(Unpooled.copiedBuffer(bytes)));
		}

		notifyListeners(listener -> listener.onPingFrame(bytes));
	}

	/**
	 * Routes a continuation frame to the handler for the fragmented message in progress. The in-progress state is
	 * cleared when the final fragment arrives.
	 */
	private void onContinuationFrame(ContinuationWebSocketFrame continuationWebSocketFrame) {
		if (expectedFragmentedFrameType == null) {
			logger.error("Received continuation frame when the last frame was completed!");
			return;
		}

		try {
			switch (expectedFragmentedFrameType) {
				case BINARY:
					handleBinaryFrame(continuationWebSocketFrame);
					break;
				case TEXT:
					handleTextFrame(continuationWebSocketFrame);
					break;
				default:
					sendCloseFrame(1002, "Incorrect continuation frame!");
			}
		} finally {
			if (continuationWebSocketFrame.isFinalFragment()) {
				expectedFragmentedFrameType = null;
			}
		}
	}

	/**
	 * Decodes a text (or text continuation) frame and forwards it to the listeners. On invalid UTF-8 it sends a
	 * 1007 close frame and resets the text decoding state instead of notifying.
	 */
	private void handleTextFrame(WebSocketFrame webSocketFrame) {
		boolean finalFragment = webSocketFrame.isFinalFragment();
		String text;
		try {
			text = decodeUtf8(webSocketFrame.content().nioBuffer(), finalFragment);
		} catch (CharacterCodingException e) {
			resetTextDecoding();
			sendCloseFrame(1007, "Invalid UTF-8 encoding");
			return;
		}

		int rsv = webSocketFrame.rsv();
		notifyListeners(listener -> listener.onTextFrame(text, finalFragment, rsv));
	}

	/**
	 * Decodes one text fragment. A multi-byte UTF-8 sequence may be split across fragments (RFC 6455 section 5.6),
	 * so an incomplete sequence at the end of a non-final fragment is kept and prepended to the next fragment.
	 * Only the final fragment requires the input to end on a complete sequence.
	 *
	 * @param fragment      the fragment's bytes; consumed by this call
	 * @param finalFragment whether this fragment ends the message
	 * @return the text decoded from this fragment (plus any carried-over bytes)
	 * @throws CharacterCodingException if the bytes are not valid UTF-8
	 */
	private String decodeUtf8(ByteBuffer fragment, boolean finalFragment) throws CharacterCodingException {
		ByteBuffer input = fragment;
		if (pendingUtf8Bytes != null) {
			input = ByteBuffer.allocate(pendingUtf8Bytes.remaining() + fragment.remaining());
			input.put(pendingUtf8Bytes);
			input.put(fragment);
			input.flip();
			pendingUtf8Bytes = null;
		}

		// UTF-8 never yields more chars than the bytes it consumes (4 bytes -> at most 2 chars), so this buffer
		// cannot overflow.
		CharBuffer decoded = CharBuffer.allocate(input.remaining());
		CoderResult result = utf8Decoder.decode(input, decoded, finalFragment);
		if (!result.isUnderflow()) {
			result.throwException();
		}

		if (finalFragment) {
			result = utf8Decoder.flush(decoded);
			if (!result.isUnderflow()) {
				result.throwException();
			}
			utf8Decoder.reset();
		} else if (input.hasRemaining()) {
			// Copy the incomplete trailing sequence: the source buffer belongs to a frame that is about to be released.
			pendingUtf8Bytes = ByteBuffer.allocate(input.remaining());
			pendingUtf8Bytes.put(input);
			pendingUtf8Bytes.flip();
		}

		decoded.flip();
		return decoded.toString();
	}

	/**
	 * Discards any partial UTF-8 state so the next text fragment is decoded from scratch.
	 */
	private void resetTextDecoding() {
		utf8Decoder.reset();
		pendingUtf8Bytes = null;
	}

	/**
	 * Forwards a copy of a binary (or binary continuation) frame's payload to the listeners.
	 */
	private void handleBinaryFrame(WebSocketFrame webSocketFrame) {
		byte[] bytes = getBytes(webSocketFrame.content());
		boolean finalFragment = webSocketFrame.isFinalFragment();
		int rsv = webSocketFrame.rsv();

		notifyListeners(listener -> listener.onBinaryFrame(bytes, finalFragment, rsv));
	}

	/**
	 * Handles the first frame of a binary message, remembering that continuations follow if it is not final.
	 */
	private void onBinaryFrame(BinaryWebSocketFrame binaryWebSocketFrame) {
		if (expectedFragmentedFrameType == null && !binaryWebSocketFrame.isFinalFragment()) {
			expectedFragmentedFrameType = FragmentedFrameType.BINARY;
		}
		handleBinaryFrame(binaryWebSocketFrame);
	}

	/**
	 * Handles the first frame of a text message, remembering that continuations follow if it is not final.
	 */
	private void onTextFrame(TextWebSocketFrame textWebSocketFrame) {
		if (expectedFragmentedFrameType == null && !textWebSocketFrame.isFinalFragment()) {
			expectedFragmentedFrameType = FragmentedFrameType.TEXT;
		}

		// A text frame always starts a new message, so drop decoder state left over from any previous one.
		resetTextDecoding();
		handleTextFrame(textWebSocketFrame);
	}

	/**
	 * Optionally echoes a received close frame, then forwards its status code and reason to the listeners.
	 */
	private void onClose(CloseWebSocketFrame closeWebSocketFrame) {
		int statusCode = closeWebSocketFrame.statusCode();
		String reasonText = closeWebSocketFrame.reasonText();

		if (autoCloseFrame) {
			sendCloseFrame(statusCode, reasonText);
		}

		notifyListeners(listener -> listener.onCloseFrame(statusCode, reasonText));
	}

	/**
	 * Open-event handler. It marks the client ready, completes the pending connect future, notifies listeners,
	 * replays frames buffered before the open event and, if configured, schedules periodic pings on the channel's
	 * event loop.
	 */
	private void onOpen(Channel channel, io.netty.handler.codec.http.HttpHeaders httpHeaders) {
		this.channel = channel;
		this.headers = new Headers(httpHeaders);
		this.ready = true;

		CompletableFuture<WebSocketClient> future = this.connectionFuture;
		if (future != null) {
			completableFutureExecutor.execute(() -> future.complete(this));
			this.connectionFuture = null;
		}

		notifyListeners(listener -> listener.onConnect(this));

		replayBufferedFrames();

		if (pingEveryDuration != null) {
			cancelPing();
			long periodMillis = pingEveryDuration.toMillis();
			pingFuture = channel.eventLoop().scheduleAtFixedRate(() -> sendPingFrame(), periodMillis, periodMillis, TimeUnit.MILLISECONDS);
		}

	}

	/**
	 * Delivers the frames buffered before the connection became ready, then empties the buffer so they can never be
	 * replayed (and released) a second time, e.g. after a reconnect.
	 */
	private void replayBufferedFrames() {
		List<WebSocketFrame> pending = new ArrayList<>(bufferedFrames);
		bufferedFrames.clear();
		for (WebSocketFrame bufferedFrame : pending) {
			try {
				onWebSocketFrame(bufferedFrame);
			} finally {
				// Balances the retain() taken when the frame was buffered; onWebSocketFrame released the other one.
				bufferedFrame.release();
			}
		}
	}

	/**
	 * Releases frames that were buffered but will never be delivered because the connection ended.
	 */
	private void discardBufferedFrames() {
		for (WebSocketFrame bufferedFrame : bufferedFrames) {
			// Mirror replayBufferedFrames(): one release for the reference a delivered frame gives up in
			// onWebSocketFrame, one for the retain() taken when it was buffered.
			bufferedFrame.release(2);
		}
		bufferedFrames.clear();
	}

	/**
	 * Cancels the periodic ping task, if one is scheduled.
	 */
	private void cancelPing() {
		if (pingFuture != null) {
			pingFuture.cancel(true);
			pingFuture = null;
		}
	}

	/**
	 * Error handler. It marks the client disconnected, fails a pending connect future, notifies listeners of the
	 * error (and of the disconnect if it was connected), stops pinging, drops buffered frames and wakes
	 * {@link #awaitClose()} callers.
	 */
	private void onError(HttpRequestContext httpRequestContext, Throwable throwable) {
		boolean wasConnected = ready;
		ready = false;
		channel = null;

		CompletableFuture<WebSocketClient> future = this.connectionFuture;
		if (future != null) {
			completableFutureExecutor.execute(() -> future.completeExceptionally(throwable));
			this.connectionFuture = null;
		}

		notifyListeners(listener -> listener.onError(throwable));

		if (wasConnected) {
			notifyListeners(listener -> listener.onDisconnect());
		}

		cancelPing();
		discardBufferedFrames();
		awaitLatch.closed();
	}

	/**
	 * Completion handler. It marks the client disconnected, notifies listeners if it was connected, stops pinging,
	 * drops buffered frames and wakes {@link #awaitClose()} callers.
	 */
	private void onCompleted(HttpRequestContext httpRequestContext) {
		boolean wasConnected = ready;

		ready = false;
		channel = null;

		if (wasConnected) {
			notifyListeners(listener -> listener.onDisconnect());
		}

		cancelPing();
		discardBufferedFrames();
		awaitLatch.closed();
	}

	/**
	 * Returns the readable bytes of {@code buf} without moving its reader index. When the buffer's backing array
	 * exactly matches the readable region, that array is returned directly instead of being copied.
	 */
	public byte[] getBytes(ByteBuf buf) {
		int readable = buf.readableBytes();
		int readerIndex = buf.readerIndex();
		if (buf.hasArray()) {
			byte[] array = buf.array();
			if (buf.arrayOffset() == 0 && readerIndex == 0 && array.length == readable) {
				return array;
			}
		}
		byte[] array = new byte[readable];
		buf.getBytes(readerIndex, array);
		return array;
	}

	/** Kind of data message a sequence of continuation frames belongs to. */
	private enum FragmentedFrameType {
		TEXT, BINARY;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy