All Downloads are FREE. Search and download functionalities are using the official Maven repository.

it.auties.whatsapp.net.SocketClient Maven / Gradle / Ivy

package it.auties.whatsapp.net;

import it.auties.whatsapp.util.Proxies;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import java.io.*;
import java.net.*;
import java.nio.ByteBuffer;
import java.nio.InvalidMarkException;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

@SuppressWarnings("unused")
public class SocketClient extends Socket implements AutoCloseable {
    private static final int DEFAULT_CONNECTION_TIMEOUT = 300;

    public static SocketClient newPlainClient(URI proxy) throws IOException {
        var channel = AsynchronousSocketChannel.open();
        var layerSupport = new SocketTransport.Plain(channel);
        var proxySupport = SocketConnection.of(channel, layerSupport, proxy);
        return new SocketClient(channel, proxySupport, layerSupport);
    }

    public static SocketClient newSecureClient(SSLEngine sslEngine, URI proxy) throws IOException {
        var channel = AsynchronousSocketChannel.open();
        var layerSupport = new SocketTransport.Secure(channel, sslEngine);
        var proxySupport = SocketConnection.of(channel, layerSupport, proxy);
        return new SocketClient(channel, proxySupport, layerSupport);
    }

    final AsynchronousSocketChannel channel;
    final SocketConnection socketConnection;
    SocketTransport socketTransport;
    private SocketClient(AsynchronousSocketChannel channel, SocketConnection socketConnection, SocketTransport socketTransport) {
        this.channel = channel;
        this.socketConnection = socketConnection;
        this.socketTransport = socketTransport;
    }

    public CompletableFuture upgradeToSsl(SSLEngine sslEngine) {
        if(!isConnected()) {
            throw new IllegalArgumentException("The socket is not connected");
        }

        if(socketTransport.isSecure()) {
            throw new IllegalStateException("This socket is already using a secure connection");
        }

        this.socketTransport = new SocketTransport.Secure(channel, sslEngine);
        return socketTransport.handshake(); // Upgrading a websocket is not supported, so path is always null
    }

    @Override
    public void connect(SocketAddress endpoint) throws IOException {
        connect(endpoint, DEFAULT_CONNECTION_TIMEOUT);
    }

    @Override
    public void connect(SocketAddress endpoint, int timeout) throws IOException {
        if(!(endpoint instanceof InetSocketAddress inetSocketAddress)) {
            throw new IllegalArgumentException("Unsupported address type");
        }

        var future = connectAsync(inetSocketAddress, timeout);
        future.join();
    }

    public CompletableFuture connectAsync(InetSocketAddress address) {
        return connectAsync(address, DEFAULT_CONNECTION_TIMEOUT);
    }

    public CompletableFuture connectAsync(InetSocketAddress address, int timeout) {
        return socketConnection.connectAsync(address, timeout)
                .thenComposeAsync(ignored -> socketTransport.handshake())
                .exceptionallyComposeAsync(this::closeSocketOnError);
    }

    private CompletableFuture closeSocketOnError(Throwable error) {
        try {
            close();
        }catch (Throwable ignored) {

        }

        return CompletableFuture.failedFuture(error);
    }

    @Override
    public InputStream getInputStream() throws IOException {
        if(!isConnected()) {
            throw new IOException("Connection is closed");
        }

        return new InputStream() {
            @Override
            public int read() throws IOException {
                var data = new byte[1];
                var result = read(data);
                if(result == -1) {
                    close();
                    throw new EOFException();
                }

                return Byte.toUnsignedInt(data[0]);
            }

            @Override
            public int read(byte[] b) {
                return read(b, 0, b.length);
            }

            @Override
            public int read(byte[] b, int off, int len) {
                if(len == 0) {
                    return 0;
                }

                return readAsync(ByteBuffer.wrap(b, off, len))
                        .join();
            }

            @Override
            public void close() throws IOException {
                SocketClient.this.close();
            }
        };
    }

    @Override
    public OutputStream getOutputStream() throws IOException {
        if(!isConnected()) {
            throw new IOException("Connection is closed");
        }

        return new OutputStream() {
            @Override
            public void write(int b) {
                write(new byte[]{(byte) b}, 0, 1);
            }

            @Override
            public void write(byte[] b, int off, int len) {
                writeAsync(b, off, len).join();
            }

            @Override
            public void close() throws IOException {
                SocketClient.this.close();
            }
        };
    }

    @Override
    public void close() throws IOException {
        channel.close();
    }

    @Override
    public boolean isBound() {
        return channel.isOpen();
    }

    @Override
    public boolean isConnected() {
        try {
            return channel.getRemoteAddress() != null;
        } catch (IOException e) {
            return false;
        }
    }

    @Override
    public boolean isOutputShutdown() {
        return !isConnected();
    }

    @Override
    public boolean isInputShutdown() {
        return !isConnected();
    }

    @Override
    public boolean isClosed() {
        return !channel.isOpen();
    }

    @Override
    public int getReceiveBufferSize() {
        try {
            return channel.getOption(StandardSocketOptions.SO_RCVBUF);
        } catch (Throwable e) {
            return 0;
        }
    }

    @Override
    public int getSendBufferSize() {
        try {
            return channel.getOption(StandardSocketOptions.SO_SNDBUF);
        } catch (Throwable e) {
            return 0;
        }
    }

    @Override
    public void bind(SocketAddress endpoint) throws IOException {
        throw new UnsupportedOperationException("Client socket");
    }

    @Override
    public int getPort() {
        try {
            if(channel.getRemoteAddress() instanceof InetSocketAddress inetSocketAddress) {
                return inetSocketAddress.getPort();
            }

            return -1;
        } catch (Throwable e) {
            return -1;
        }
    }

    @Override
    public int getLocalPort() {
        return -1;
    }

    @Override
    public InetAddress getInetAddress() {
        try {
            if(channel.getRemoteAddress() instanceof InetSocketAddress inetSocketAddress) {
                return inetSocketAddress.getAddress();
            }

            return null;
        } catch (Throwable e) {
            return null;
        }
    }

    @Override
    public InetAddress getLocalAddress() {
        return null;
    }

    @Override
    public SocketAddress getRemoteSocketAddress() {
        try {
            return channel.getRemoteAddress();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    @Override
    public SocketAddress getLocalSocketAddress() {
        return null;
    }

    @Override
    public void setTcpNoDelay(boolean on) throws SocketException {
        if(!supportedOptions().contains(StandardSocketOptions.TCP_NODELAY)) {
            return;
        }

        try {
            channel.setOption(StandardSocketOptions.TCP_NODELAY, on);
        } catch (IOException e) {
            throw new SocketException(e);
        }
    }

    @Override
    public boolean getTcpNoDelay() throws SocketException {
        try {
            return getOption(StandardSocketOptions.TCP_NODELAY);
        } catch (IOException e) {
            throw new SocketException(e);
        }
    }

    @Override
    public void setSoLinger(boolean on, int linger) throws SocketException {
        if(!supportedOptions().contains(StandardSocketOptions.SO_LINGER)) {
            return;
        }

        try {
            if(on) {
                channel.setOption(StandardSocketOptions.SO_LINGER, linger);
            }else {
                channel.setOption(StandardSocketOptions.SO_LINGER, -1);
            }
        }catch (IOException e) {
            throw new SocketException(e);
        }
    }

    @Override
    public int getSoLinger() {
        try {
            return getOption(StandardSocketOptions.SO_LINGER);
        } catch (Throwable ignored) {
            return 0;
        }
    }

    @Override
    public void sendUrgentData(int data) throws SocketException {
        try {
            var future = writeAsync(new byte[]{(byte) data});
            future.join();
        }catch (Throwable throwable) {
            throw new SocketException(throwable);
        }
    }

    @Override
    public void setOOBInline(boolean on) {

    }

    @Override
    public boolean getOOBInline() {
        return false;
    }

    @Override
    public void setSoTimeout(int timeout) {

    }

    @Override
    public int getSoTimeout() {
        return 0;
    }

    @Override
    public void setSendBufferSize(int size) throws SocketException {
        if(!supportedOptions().contains(StandardSocketOptions.SO_SNDBUF)) {
            return;
        }

        try {
            channel.setOption(StandardSocketOptions.SO_SNDBUF, size);
        } catch (IOException e) {
            throw new SocketException(e);
        }
    }

    @Override
    public void setReceiveBufferSize(int size) throws SocketException {
        if(!supportedOptions().contains(StandardSocketOptions.SO_RCVBUF)) {
            return;
        }

        try {
            channel.setOption(StandardSocketOptions.SO_RCVBUF, size);
        } catch (IOException e) {
            throw new SocketException(e);
        }
    }

    @Override
    public void setKeepAlive(boolean on) throws SocketException {
        if(!supportedOptions().contains(StandardSocketOptions.SO_KEEPALIVE)) {
            return;
        }

        try {
            channel.setOption(StandardSocketOptions.SO_KEEPALIVE, on);
        } catch (IOException e) {
            throw new SocketException(e);
        }
    }

    @Override
    public boolean getKeepAlive() {
        try {
            return getOption(StandardSocketOptions.SO_KEEPALIVE);
        } catch (IOException ignored) {
            return false;
        }
    }

    @Override
    public void setTrafficClass(int tc) throws SocketException {
        if(!supportedOptions().contains(StandardSocketOptions.IP_TOS)) {
            return;
        }

        try {
            channel.setOption(StandardSocketOptions.IP_TOS, tc);
        } catch (IOException e) {
            throw new SocketException(e);
        }
    }

    @Override
    public int getTrafficClass() throws SocketException {
        try {
            return getOption(StandardSocketOptions.IP_TOS);
        } catch (IOException e) {
            throw new SocketException(e);
        }
    }

    @Override
    public void setReuseAddress(boolean on) throws SocketException {
        if(!supportedOptions().contains(StandardSocketOptions.SO_REUSEADDR)) {
            return;
        }

        try {
            channel.setOption(StandardSocketOptions.SO_REUSEADDR, on);
        } catch (IOException e) {
            throw new SocketException(e);
        }
    }

    @Override
    public boolean getReuseAddress() {
        try {
            return getOption(StandardSocketOptions.SO_REUSEADDR);
        } catch (IOException ignored) {
            return false;
        }
    }

    @Override
    public void shutdownInput() throws IOException {
        channel.shutdownInput();
    }

    @Override
    public void shutdownOutput() throws IOException {
        channel.shutdownOutput();
    }

    @Override
    public  Socket setOption(SocketOption name, T value) throws IOException {
        channel.setOption(name, value);
        return this;
    }

    @Override
    public  T getOption(SocketOption name) throws IOException {
        return channel.getOption(name);
    }

    @Override
    public Set> supportedOptions() {
        return channel.supportedOptions();
    }

    public CompletableFuture writeAsync(byte[] data) {
        return writeAsync(data, 0, data.length);
    }

    public CompletableFuture writeAsync(byte[] data, int offset, int length) {
        return writeAsync(ByteBuffer.wrap(data, offset, length));
    }

    public CompletableFuture writeAsync(ByteBuffer buffer) {
        var future = new Response.Future();
        return socketTransport.write(buffer, future);
    }

    public void readFullyAsync(int length, Response.Callback callback) {
        if (length < 0) {
            throw new IllegalArgumentException("Cannot read %s bytes from socket".formatted(length));
        }

        var buffer = ByteBuffer.allocate(length);
        socketTransport.readFully(buffer, callback);
    }

    public CompletableFuture readFullyAsync(int length) {
        if (length < 0) {
            return CompletableFuture.failedFuture(new IllegalArgumentException("Cannot read %s bytes from socket".formatted(length)));
        }

        var buffer = ByteBuffer.allocate(length);
        var future = new Response.Future();
        socketTransport.readFully(buffer, future);
        return future;
    }

    public CompletableFuture readAsync(int length) {
        var future = new Response.Future();
        return readAsyncBuffer(length, future);
    }

    public void readAsync(int length, Response.Callback callback) {
        readAsyncBuffer(length, callback);
    }

    private > R readAsyncBuffer(int length, R result) {
        if (length < 0) {
            result.completeExceptionally(new IllegalArgumentException("Cannot read %s bytes from socket".formatted(length)));
            return result;
        }

        var buffer = ByteBuffer.allocate(length);
        readAsync(buffer, (bytesRead, error) -> {
            if(error != null) {
                result.completeExceptionally(error);
                return;
            }

            result.complete(buffer);
        });
        return result;
    }

    public CompletableFuture readAsync(ByteBuffer buffer) {
        var future = new Response.Future();
        return socketTransport.read(buffer, true, future);
    }

    public void readAsync(ByteBuffer buffer, Response.Callback callback) {
        socketTransport.read(buffer, true, callback);
    }

    private static sealed abstract class SocketTransport {
        final AsynchronousSocketChannel channel;
        private SocketTransport(AsynchronousSocketChannel channel) {
            this.channel = channel;
        }

        abstract CompletableFuture handshake();

        abstract boolean isSecure();

        abstract > R write(ByteBuffer buffer, R result);

        abstract > R read(ByteBuffer buffer, boolean lastRead, R result);

        > R readPlain(ByteBuffer buffer, boolean lastRead, R result) {
            var outerCaller = new RuntimeException();
            channel.read(buffer, null, new CompletionHandler<>() {
                @Override
                public void completed(Integer bytesRead, Object attachment) {
                    if(bytesRead == -1) {
                        var eof = new EOFException();
                        eof.addSuppressed(outerCaller);
                        result.completeExceptionally(eof);
                        return;
                    }

                    if(lastRead) {
                        buffer.flip();
                    }

                    result.complete(bytesRead);
                }

                @Override
                public void failed(Throwable exc, Object attachment) {
                    exc.addSuppressed(outerCaller);
                    result.completeExceptionally(exc);
                }
            });
            return result;
        }

        > R writePlain(ByteBuffer buffer, R result) {
            var outerCaller = new RuntimeException();
            channel.write(buffer, null, new CompletionHandler<>() {
                @Override
                public void completed(Integer bytesWritten, Object attachment) {
                    if(bytesWritten == -1) {
                        result.completeExceptionally(new SocketException());
                        return;
                    }

                    if(buffer.hasRemaining()) {
                        writePlain(buffer, result);
                        return;
                    }

                    result.complete(null);
                }

                @Override
                public void failed(Throwable exc, Object attachment) {
                    exc.addSuppressed(outerCaller);
                    result.completeExceptionally(exc);
                }
            });
            return result;
        }

        public void readFully(ByteBuffer buffer, Response result) {
            read(buffer, false, (Response.Callback) (readResult, error) -> {
                if (error != null) {
                    result.completeExceptionally(error);
                    return;
                }

                if(buffer.hasRemaining()) {
                    readFully(buffer, result);
                    return;
                }

                buffer.flip();
                result.complete(buffer);
            });
        }

        private static final class Plain extends SocketTransport {
            private Plain(AsynchronousSocketChannel channel) {
                super(channel);
            }

            @Override
            boolean isSecure() {
                return false;
            }

            @Override
            > R read(ByteBuffer buffer, boolean lastRead, R result) {
                return readPlain(buffer, lastRead, result);
            }

            @Override
            > R write(ByteBuffer buffer, R result) {
                return writePlain(buffer, result);
            }

            @Override
            CompletableFuture handshake() {
                return CompletableFuture.completedFuture(null);
            }
        }

        private static final class Secure extends SocketTransport {
            private final AtomicBoolean sslHandshakeCompleted;
            private final Object sslHandshakeLock;
            private final SSLEngine sslEngine;
            private final ByteBuffer sslReadBuffer;
            private final ByteBuffer sslWriteBuffer;
            private final ByteBuffer sslOutputBuffer;
            private Response.Future sslHandshake;
            private Secure(AsynchronousSocketChannel channel, SSLEngine sslEngine) {
                super(channel);
                this.sslHandshakeCompleted = new AtomicBoolean();
                this.sslHandshakeLock = new Object();
                sslHandshakeCompleted.set(sslEngine == null);
                this.sslEngine = sslEngine;
                var bufferSize = sslEngine.getSession().getPacketBufferSize();
                this.sslReadBuffer = ByteBuffer.allocate(bufferSize);
                this.sslWriteBuffer = ByteBuffer.allocate(bufferSize);
                this.sslOutputBuffer = ByteBuffer.allocate(bufferSize);
            }

            @Override
            boolean isSecure() {
                return true;
            }

            @Override
            CompletableFuture handshake() {
                try {
                    if(sslEngine == null) {
                        return CompletableFuture.completedFuture(null);
                    }

                    if(sslHandshakeCompleted.get()) {
                        return CompletableFuture.completedFuture(null);
                    }

                    if(sslHandshake != null) {
                        return sslHandshake;
                    }

                    synchronized (sslHandshakeLock) {
                        if(sslHandshake != null) {
                            return sslHandshake;
                        }

                        this.sslHandshake = new Response.Future<>();
                        sslEngine.beginHandshake();
                        sslReadBuffer.position(sslReadBuffer.limit());
                        handleSslHandshakeStatus(null);
                        return sslHandshake;
                    }
                } catch (Throwable throwable) {
                    return CompletableFuture.failedFuture(throwable);
                }
            }

            private void handleSslHandshakeStatus(Status status){
                switch (sslEngine.getHandshakeStatus()) {
                    case NEED_WRAP -> doSslHandshakeWrap();
                    case NEED_UNWRAP, NEED_UNWRAP_AGAIN -> doSslHandshakeUnwrap(status == Status.BUFFER_UNDERFLOW);
                    case NEED_TASK -> doSslHandshakeTasks();
                    case FINISHED -> finishSslHandshake();
                    case NOT_HANDSHAKING -> sslHandshake.completeExceptionally(new IOException("Cannot complete handshake"));
                }
            }

            private void finishSslHandshake() {
                sslHandshakeCompleted.set(true);
                sslOutputBuffer.clear();
                sslHandshake.complete(null);
            }

            private void doSslHandshakeTasks() {
                Runnable runnable;
                while ((runnable = sslEngine.getDelegatedTask()) != null) {
                    runnable.run();
                }

                handleSslHandshakeStatus(null);
            }

            private void doSslHandshakeUnwrap(boolean forceRead) {
                sslReadBuffer.compact();
                if (!forceRead && sslReadBuffer.position() != 0) {
                    sslReadBuffer.flip();
                    doSSlHandshakeUnwrapOperation();
                    return;
                }

                readPlain(sslReadBuffer, true, (Response.Callback) (ignored, error) -> {
                    if(error != null) {
                        sslHandshake.completeExceptionally(error);
                        return;
                    }

                    doSSlHandshakeUnwrapOperation();
                });
            }

            private void doSSlHandshakeUnwrapOperation() {
                try {
                    var result = sslEngine.unwrap(sslReadBuffer, sslOutputBuffer);
                    if(isHandshakeFinished(result, false)) {
                        finishSslHandshake();
                    }else {
                        handleSslHandshakeStatus(result.getStatus());
                    }
                }catch(Throwable throwable) {
                    sslHandshake.completeExceptionally(throwable);
                }
            }

            private void doSslHandshakeWrap() {
                try {
                    sslWriteBuffer.clear();
                    var result = sslEngine.wrap(sslOutputBuffer, sslWriteBuffer);
                    var isHandshakeFinished = isHandshakeFinished(result, true);
                    sslWriteBuffer.flip();
                    writePlain(sslWriteBuffer, (Response.Callback) (ignored, error) -> {
                        if(error != null) {
                            sslHandshake.completeExceptionally(error);
                            return;
                        }

                        if(isHandshakeFinished) {
                            finishSslHandshake();
                        }else {
                            handleSslHandshakeStatus(null);
                        }
                    });
                }catch (Throwable throwable) {
                    sslHandshake.completeExceptionally(throwable);
                }
            }

            private boolean isHandshakeFinished(SSLEngineResult result, boolean wrap) {
                var sslEngineStatus = result.getStatus();
                if (sslEngineStatus != Status.OK && (wrap || sslEngineStatus != Status.BUFFER_UNDERFLOW)) {
                    throw new IllegalStateException("SSL handshake operation failed with status: " + sslEngineStatus);
                }

                if (wrap && result.bytesConsumed() != 0) {
                    throw new IllegalStateException("SSL handshake operation failed with status: no bytes consumed");
                }

                if (!wrap && result.bytesProduced() != 0) {
                    throw new IllegalStateException("SSL handshake operation failed with status: no bytes produced");
                }

                var sslHandshakeStatus = result.getHandshakeStatus();
                return sslHandshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED;
            }

            @Override
            > R read(ByteBuffer buffer, boolean lastRead, R result) {
                try {
                    if(!sslHandshakeCompleted.get()) {
                        return readPlain(buffer, lastRead, result);
                    }

                    var bytesCopied = readFromBufferedOutput(buffer, lastRead);
                    if(bytesCopied != 0) {
                        result.complete(bytesCopied);
                    }else if (sslReadBuffer.hasRemaining()) {
                        decodeSslBuffer(buffer, lastRead, result);
                    }else {
                        fillSslBuffer(buffer, lastRead, result);
                    }

                    return result;
                }catch (Throwable throwable) {
                    result.completeExceptionally(throwable);
                    return result;
                }
            }

            private > void fillSslBuffer(ByteBuffer buffer, boolean lastRead, R result) {
                sslReadBuffer.compact();
                readPlain(sslReadBuffer, true, (Response.Callback) (ignored, error) -> {
                    if (error != null) {
                        result.completeExceptionally(error);
                        return;
                    }

                    decodeSslBuffer(buffer, lastRead, result);
                });
            }

            private void decodeSslBuffer(ByteBuffer buffer, boolean lastRead, Response result) {
                try {
                    var unwrapResult = sslEngine.unwrap(sslReadBuffer, sslOutputBuffer);
                    switch (unwrapResult.getStatus()) {
                        case OK -> {
                            if (unwrapResult.bytesProduced() == 0) {
                                sslOutputBuffer.mark();
                                read(buffer, lastRead , result);
                            } else {
                                var bytesCopied = readFromBufferedOutput(buffer, lastRead);
                                result.complete(bytesCopied);
                            }
                        }
                        case BUFFER_UNDERFLOW -> fillSslBuffer(buffer, lastRead, result);
                        case BUFFER_OVERFLOW -> result.completeExceptionally(new IllegalStateException("SSL output buffer overflow"));
                        case CLOSED -> result.completeExceptionally(new EOFException());
                    }
                }catch (Throwable throwable) {
                    result.completeExceptionally(throwable);
                }
            }

            private int readFromBufferedOutput(ByteBuffer buffer, boolean lastRead) {
                var writePosition = sslOutputBuffer.position();
                if(writePosition == 0) {
                    return 0;
                }

                var bytesRead = 0;
                var writeLimit = sslOutputBuffer.limit();
                sslOutputBuffer.limit(writePosition);
                try {
                    sslOutputBuffer.reset(); // Go back to last read position
                }catch (InvalidMarkException exception) {
                    sslOutputBuffer.flip(); // This can happen if unwrapResult.bytesProduced() != 0 on the first call
                }
                while (buffer.hasRemaining() && sslOutputBuffer.hasRemaining()) {
                    buffer.put(sslOutputBuffer.get());
                    bytesRead++;
                }

                if(!sslOutputBuffer.hasRemaining()) {
                    sslOutputBuffer.clear();
                    sslOutputBuffer.mark();
                }else {
                    sslOutputBuffer.limit(writeLimit);
                    sslOutputBuffer.mark();
                    sslOutputBuffer.position(writePosition);
                }

                if(lastRead) {
                    buffer.flip();
                }

                return bytesRead;
            }

            @Override
            > R write(ByteBuffer buffer, R result) {
                if(!sslHandshakeCompleted.get()) {
                    return writePlain(buffer, result);
                }

                writeSecure(buffer, result);
                return result;
            }

            private > void writeSecure(ByteBuffer buffer, R result) {
                if(!buffer.hasRemaining()) {
                    result.complete(null);
                    return;
                }

                try {
                    sslWriteBuffer.clear();
                    var wrapResult = sslEngine.wrap(buffer, sslWriteBuffer);
                    var status = wrapResult.getStatus();
                    if (status != Status.OK && status != Status.BUFFER_OVERFLOW) {
                        throw new IllegalStateException("SSL wrap failed with status: " + status);
                    }

                    sslWriteBuffer.flip();
                    writePlain(sslWriteBuffer, (Response.Callback) (ignored, error) -> {
                        if(error != null) {
                            result.completeExceptionally(error);
                            return;
                        }

                        writeSecure(buffer, result);
                    });
                }catch (SSLException exception) {
                    result.completeExceptionally(exception);
                }
            }
        }
    }

    private sealed static abstract class SocketConnection {
        // Necessary because of a bug in the JDK
        // The number 50 was reached after debugging to find the best possible number to minimize connection time
        // The problem with this approach is that 50 threads ask for a bad connection, they block other instances
        // Shouldn't be a problem because the connection would error out immediately, but it should be looked into
        // Also maybe find a fix, so I can report it to Oracle
        private static final Semaphore CONNECTION_SEMAPHORE = new Semaphore(50, true);
        final AsynchronousSocketChannel channel;
        final SocketTransport socketTransport;
        final URI proxy;
        private SocketConnection(AsynchronousSocketChannel channel, SocketTransport socketTransport, URI proxy) {
            this.channel = channel;
            this.socketTransport = socketTransport;
            this.proxy = proxy;
        }

        private static SocketConnection of(AsynchronousSocketChannel channel, SocketTransport socketTransport, URI proxy) {
            return switch (Proxies.toProxy(proxy).type()) {
                case DIRECT -> new NoProxy(channel);
                case HTTP -> new HttpProxy(channel, socketTransport, proxy);
                case SOCKS -> new SocksProxy(channel, socketTransport, proxy);
            };
        }

        CompletableFuture connectAsync(InetSocketAddress address, int timeout) {
            return CompletableFuture.runAsync(() -> connectSync(address), Thread::startVirtualThread)
                    .orTimeout(timeout > 0 ? timeout : DEFAULT_CONNECTION_TIMEOUT, TimeUnit.SECONDS);
        }

        private void connectSync(InetSocketAddress address) {
            try {
                CONNECTION_SEMAPHORE.acquire();
                var start = System.currentTimeMillis();
                var future = channel.connect(address);
                future.get();
            }catch (Throwable throwable) {
                throw new RuntimeException("Cannot connect to " + address, throwable);
            }finally {
                CONNECTION_SEMAPHORE.release();
            }
        }

        private static final class NoProxy extends SocketConnection {
            private NoProxy(AsynchronousSocketChannel channel) {
                super(channel, null, null);
            }

            @Override
            public CompletableFuture connectAsync(InetSocketAddress address, int timeout) {
                return super.connectAsync(address, timeout);
            }
        }

        private static final class HttpProxy extends SocketConnection {
            private static final int DEFAULT_RCV_BUF = 8192;
            private static final int OK_STATUS_CODE = 200;

            private HttpProxy(AsynchronousSocketChannel channel, SocketTransport socketTransport, URI proxy) {
                super(channel, socketTransport, proxy);
            }

            @Override
            public CompletableFuture connectAsync(InetSocketAddress address, int timeout) {
                return super.connectAsync(new InetSocketAddress(proxy.getHost(), proxy.getPort()), timeout)
                        .thenComposeAsync(openResult -> sendAuthentication(address))
                        .thenComposeAsync(connectionResult -> readAuthenticationResponse())
                        .thenComposeAsync(this::handleAuthentication);
            }

            private CompletableFuture handleAuthentication(String response) {
                var responseParts = response.split(" ");
                if(responseParts.length < 2) {
                    return CompletableFuture.failedFuture(new SocketException("HTTP : Cannot connect to proxy, malformed response: " + response));
                }

                var statusCodePart = responseParts[1];
                try {
                    var statusCode = statusCodePart == null ? -1 : Integer.parseUnsignedInt(statusCodePart);
                    if(statusCode != OK_STATUS_CODE) {
                        return CompletableFuture.failedFuture(new SocketException("HTTP : Cannot connect to proxy, status code " + statusCode));
                    }

                    return CompletableFuture.completedFuture(null);
                }catch (Throwable throwable) {
                    return CompletableFuture.failedFuture(new SocketException("HTTP : Cannot connect to proxy: " + response));
                }
            }

            private CompletableFuture readAuthenticationResponse() {
                var future = new CompletableFuture();
                var buffer = ByteBuffer.allocate(readReceiveBufferSize());
                socketTransport.read(buffer, true, (Response.Callback) (result, error) -> {
                    if (error != null) {
                        future.completeExceptionally(new SocketException("HTTP : Cannot read authentication response", error));
                        return;
                    }

                    var data = new byte[result];
                    buffer.get(data);
                    future.complete(new String(data));
                });
                return future;
            }

            private int readReceiveBufferSize() {
                try {
                    return channel.getOption(StandardSocketOptions.SO_RCVBUF);
                }catch (IOException exception) {
                    return DEFAULT_RCV_BUF;
                }
            }

            private CompletableFuture sendAuthentication(InetSocketAddress endpoint) {
                var builder = new StringBuilder();
                builder.append("CONNECT ")
                        .append(endpoint.getHostName())
                        .append(":")
                        .append(endpoint.getPort())
                        .append(" HTTP/1.1\r\n");
                builder.append("Host: ")
                        .append(endpoint.getHostName())
                        .append(":")
                        .append(endpoint.getPort())
                        .append("\r\n");
                var authInfo = proxy.getUserInfo();
                if (authInfo != null) {
                    builder.append("Proxy-Authorization: Basic ")
                            .append(Base64.getEncoder().encodeToString(authInfo.getBytes()))
                            .append("\r\n");
                }
                builder.append("\r\n");
                var result = new Response.Future();
                socketTransport.write(ByteBuffer.wrap(builder.toString().getBytes()), result);
                return result;
            }
        }

        private static final class SocksProxy extends SocketConnection {
            private static final byte VERSION_5 = 5;

            private static final int NO_AUTH = 0;
            private static final int USER_PASSW = 2;
            private static final int NO_METHODS = -1;

            private static final int CONNECT = 1;

            private static final int IPV4 = 1;
            private static final int DOMAIN_NAME = 3;
            private static final int IPV6 = 4;

            private static final int REQUEST_OK = 0;
            private static final int GENERAL_FAILURE = 1;
            private static final int NOT_ALLOWED = 2;
            private static final int NET_UNREACHABLE = 3;
            private static final int HOST_UNREACHABLE = 4;
            private static final int CONN_REFUSED = 5;
            private static final int TTL_EXPIRED = 6;
            private static final int CMD_NOT_SUPPORTED = 7;
            private static final int ADDR_TYPE_NOT_SUP = 8;

            private SocksProxy(AsynchronousSocketChannel channel, SocketTransport socketTransport, URI proxy) {
                super(channel, socketTransport, proxy);
            }


            @Override
            public CompletableFuture connectAsync(InetSocketAddress address, int timeout) {
                return super.connectAsync(new InetSocketAddress(proxy.getHost(), proxy.getPort()), timeout)
                        .thenComposeAsync(openResult -> sendAuthenticationRequest())
                        .thenComposeAsync(response -> sendAuthenticationData(address, response));
            }

            private CompletableFuture sendAuthenticationRequest() {
                var connectionPayload = new ByteArrayOutputStream();
                connectionPayload.write(VERSION_5);
                connectionPayload.write(2);
                connectionPayload.write(NO_AUTH);
                connectionPayload.write(USER_PASSW);
                var result = new Response.Future();
                socketTransport.write(ByteBuffer.wrap(connectionPayload.toByteArray()), result);
                return result.thenComposeAsync(connectionResult -> readServerResponse(2, "Cannot read authentication request response"));
            }

            private CompletionStage sendAuthenticationData(InetSocketAddress address, ByteBuffer response) {
                var socksVersion = response.get();
                if (socksVersion != VERSION_5) {
                    return CompletableFuture.failedFuture(new SocketException("SOCKS : Invalid version"));
                }

                var method = response.get();
                if (method == NO_METHODS) {
                    return CompletableFuture.failedFuture(new SocketException("SOCKS : No acceptable methods"));
                }

                if (method == NO_AUTH) {
                    return sendConnectionData(address, null);
                }

                if (method != USER_PASSW) {
                    return CompletableFuture.failedFuture(new SocketException("SOCKS : authentication failed"));
                }

                var userInfo = Proxies.parseUserInfo(proxy.getUserInfo());
                if (userInfo == null) {
                    return CompletableFuture.failedFuture(new SocketException("SOCKS : missing user info"));
                }

                var outputStream = new ByteArrayOutputStream();
                outputStream.write(1);
                outputStream.write(userInfo.username().length());
                outputStream.writeBytes(userInfo.username().getBytes(StandardCharsets.ISO_8859_1));
                if (userInfo.password() != null) {
                    outputStream.write(userInfo.password().length());
                    outputStream.writeBytes(userInfo.password().getBytes(StandardCharsets.ISO_8859_1));
                } else {
                    outputStream.write(0);
                }
                var result = new Response.Future();
                socketTransport.write(ByteBuffer.wrap(outputStream.toByteArray()), result);
                return result.thenComposeAsync(connectionResult -> readServerResponse(2, "Cannot read authentication data response"))
                        .thenComposeAsync(connectionResponse -> sendConnectionData(address, connectionResponse));
            }

            private CompletableFuture sendConnectionData(InetSocketAddress address, ByteBuffer connectionResponse) {
                if(connectionResponse != null && connectionResponse.get(1) != 0) {
                    return CompletableFuture.failedFuture(new SocketException("SOCKS : authentication failed"));
                }

                var outputStream = new ByteArrayOutputStream();
                outputStream.write(VERSION_5);
                outputStream.write(CONNECT);
                outputStream.write(0);
                outputStream.write(DOMAIN_NAME);
                outputStream.write(address.getHostName().length());
                outputStream.writeBytes(address.getHostName().getBytes(StandardCharsets.ISO_8859_1));
                outputStream.write((address.getPort() >> 8) & 0xff);
                outputStream.write((address.getPort()) & 0xff);
                var result = new Response.Future();
                socketTransport.write(ByteBuffer.wrap(outputStream.toByteArray()), result);
                return result.thenComposeAsync(authenticationResult -> readServerResponse(4, "Cannot read connection data response"))
                        .thenComposeAsync(this::onConnected);
            }

            private CompletableFuture onConnected(ByteBuffer authenticationResponse) {
                if(authenticationResponse.limit() < 2) {
                    return CompletableFuture.failedFuture(new SocketException("SOCKS malformed response"));
                }

                return switch (authenticationResponse.get(1)) {
                    case REQUEST_OK -> onConnected(authenticationResponse.get(3));
                    case GENERAL_FAILURE -> CompletableFuture.failedFuture(new SocketException("SOCKS server general failure"));
                    case NOT_ALLOWED -> CompletableFuture.failedFuture(new SocketException("SOCKS: Connection not allowed by ruleset"));
                    case NET_UNREACHABLE -> CompletableFuture.failedFuture(new SocketException("SOCKS: Network unreachable"));
                    case HOST_UNREACHABLE -> CompletableFuture.failedFuture(new SocketException("SOCKS: Host unreachable"));
                    case CONN_REFUSED -> CompletableFuture.failedFuture(new SocketException("SOCKS: Connection refused"));
                    case TTL_EXPIRED -> CompletableFuture.failedFuture(new SocketException("SOCKS: TTL expired"));
                    case CMD_NOT_SUPPORTED -> CompletableFuture.failedFuture(new SocketException("SOCKS: Command not supported"));
                    case ADDR_TYPE_NOT_SUP -> CompletableFuture.failedFuture(new SocketException("SOCKS: address type not supported"));
                    default -> CompletableFuture.failedFuture(new SocketException("SOCKS: unhandled error"));
                };
            }

            private CompletableFuture onConnected(byte authenticationType) {
                return switch (authenticationType) {
                    case IPV4 -> readServerResponse(4, "Cannot read IPV4 address")
                            .thenComposeAsync(ipResult -> readServerResponse(2, "Cannot read IPV4 port"))
                            .thenRun(() -> {});
                    case IPV6 -> readServerResponse(16, "Cannot read IPV6 address")
                            .thenComposeAsync(ipResult -> readServerResponse(2, "Cannot read IPV6 port"))
                            .thenRun(() -> {});
                    case DOMAIN_NAME -> readServerResponse(1, "Cannot read domain name")
                            .thenComposeAsync(domainLengthBuffer -> readServerResponse(Byte.toUnsignedInt(domainLengthBuffer.get()), "Cannot read domain hostname"))
                            .thenComposeAsync(ipResult -> readServerResponse(2, "Cannot read domain port"))
                            .thenRun(() -> {});
                    default -> CompletableFuture.failedFuture(new SocketException("Reply from SOCKS server contains wrong code"));
                };
            }

            private CompletableFuture readServerResponse(int length, String errorMessage) {
                var buffer = ByteBuffer.allocate(length);
                var result = new Response.Future();
                socketTransport.readFully(buffer, result);
                return result.exceptionallyCompose(error -> CompletableFuture.failedFuture(new SocketException(errorMessage, error)));
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy