// All downloads are free. Search and download functionality uses the official Maven repository.

// com.lark.oapi.okhttp.internal.ws.RealWebSocket (Maven / Gradle / Ivy)

/*
 *
 *  * Copyright (C) 2015 Square, Inc.
 *  *
 *  * Licensed under the Apache License, Version 2.0 (the "License");
 *  * you may not use this file except in compliance with the License.
 *  * You may obtain a copy of the License at
 *  *
 *  *      http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS,
 *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  * See the License for the specific language governing permissions and
 *  * limitations under the License.
 *
 */
package com.lark.oapi.okhttp.internal.ws;

import com.lark.oapi.okhttp.*;
import com.lark.oapi.okhttp.internal.Internal;
import com.lark.oapi.okhttp.internal.Util;
import com.lark.oapi.okhttp.internal.connection.Exchange;
import com.lark.oapi.okio.BufferedSink;
import com.lark.oapi.okio.BufferedSource;
import com.lark.oapi.okio.ByteString;
import com.lark.oapi.okio.Okio;

import javax.annotation.Nullable;
import java.io.Closeable;
import java.io.IOException;
import java.net.ProtocolException;
import java.net.SocketTimeoutException;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import static com.lark.oapi.okhttp.internal.ws.WebSocketProtocol.*;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

public final class RealWebSocket implements WebSocket, WebSocketReader.FrameCallback {

    private static final List ONLY_HTTP1 = Collections.singletonList(Protocol.HTTP_1_1);

    /**
     * The maximum number of bytes to enqueue. Rather than enqueueing beyond this limit we tear down
     * the web socket! It's possible that we're writing faster than the peer can read.
     */
    private static final long MAX_QUEUE_SIZE = 16 * 1024 * 1024; // 16 MiB.

    /**
     * The maximum amount of time after the client calls {@link #close} to wait for a graceful
     * shutdown. If the server doesn't respond the websocket will be canceled.
     */
    private static final long CANCEL_AFTER_CLOSE_MILLIS = 60 * 1000;
    final WebSocketListener listener;
    /**
     * The application's original request unadulterated by web socket headers.
     */
    private final Request originalRequest;
    private final Random random;
    private final long pingIntervalMillis;
    private final String key;
    /**
     * This runnable processes the outgoing queues. Call {@link #runWriter()} to after enqueueing.
     */
    private final Runnable writerRunnable;
    /**
     * Outgoing pongs in the order they should be written.
     */
    private final ArrayDeque pongQueue = new ArrayDeque<>();
    /**
     * Outgoing messages and close frames in the order they should be written.
     */
    private final ArrayDeque messageAndCloseQueue = new ArrayDeque<>();

    // All mutable web socket state is guarded by this.
    /**
     * Non-null for client web sockets. These can be canceled.
     */
    private Call call;
    /**
     * Null until this web socket is connected. Only accessed by the reader thread.
     */
    private WebSocketReader reader;
    /**
     * Null until this web socket is connected. Note that messages may be enqueued before that.
     */
    private WebSocketWriter writer;
    /**
     * Null until this web socket is connected. Used for writes, pings, and close timeouts.
     */
    private ScheduledExecutorService executor;
    /**
     * The streams held by this web socket. This is non-null until all incoming messages have been
     * read and all outgoing messages have been written. It is closed when both reader and writer are
     * exhausted, or if there is any failure.
     */
    private Streams streams;
    /**
     * The total size in bytes of enqueued but not yet transmitted messages.
     */
    private long queueSize;

    /**
     * True if we've enqueued a close frame. No further message frames will be enqueued.
     */
    private boolean enqueuedClose;

    /**
     * When executed this will cancel this websocket. This future itself should be canceled if that is
     * unnecessary because the web socket is already closed or canceled.
     */
    private ScheduledFuture cancelFuture;

    /**
     * The close code from the peer, or -1 if this web socket has not yet read a close frame.
     */
    private int receivedCloseCode = -1;

    /**
     * The close reason from the peer, or null if this web socket has not yet read a close frame.
     */
    private String receivedCloseReason;

    /**
     * True if this web socket failed and the listener has been notified.
     */
    private boolean failed;

    /**
     * Total number of pings sent by this web socket.
     */
    private int sentPingCount;

    /**
     * Total number of pings received by this web socket.
     */
    private int receivedPingCount;

    /**
     * Total number of pongs received by this web socket.
     */
    private int receivedPongCount;

    /**
     * True if we have sent a ping that is still awaiting a reply.
     */
    private boolean awaitingPong;

    public RealWebSocket(Request request, WebSocketListener listener, Random random,
                         long pingIntervalMillis) {
        if (!"GET".equals(request.method())) {
            throw new IllegalArgumentException("Request must be GET: " + request.method());
        }
        this.originalRequest = request;
        this.listener = listener;
        this.random = random;
        this.pingIntervalMillis = pingIntervalMillis;

        byte[] nonce = new byte[16];
        random.nextBytes(nonce);
        this.key = ByteString.of(nonce).base64();

        this.writerRunnable = () -> {
            try {
                while (writeOneFrame()) {
                }
            } catch (IOException e) {
                failWebSocket(e, null);
            }
        };
    }

    @Override
    public Request request() {
        return originalRequest;
    }

    @Override
    public synchronized long queueSize() {
        return queueSize;
    }

    @Override
    public void cancel() {
        call.cancel();
    }

    public void connect(OkHttpClient client) {
        client = client.newBuilder()
                .eventListener(EventListener.NONE)
                .protocols(ONLY_HTTP1)
                .build();
        final Request request = originalRequest.newBuilder()
                .header("Upgrade", "websocket")
                .header("Connection", "Upgrade")
                .header("Sec-WebSocket-Key", key)
                .header("Sec-WebSocket-Version", "13")
                .build();
        call = Internal.instance.newWebSocketCall(client, request);
        call.enqueue(new Callback() {
            @Override
            public void onResponse(Call call, Response response) {
                Exchange exchange = Internal.instance.exchange(response);
                Streams streams;
                try {
                    checkUpgradeSuccess(response, exchange);
                    streams = exchange.newWebSocketStreams();
                } catch (IOException e) {
                    if (exchange != null) {
                        exchange.webSocketUpgradeFailed();
                    }
                    failWebSocket(e, response);
                    Util.closeQuietly(response);
                    return;
                }

                // Process all web socket messages.
                try {
                    String name = "OkHttp WebSocket " + request.url().redact();
                    initReaderAndWriter(name, streams);
                    listener.onOpen(RealWebSocket.this, response);
                    loopReader();
                } catch (Exception e) {
                    failWebSocket(e, null);
                }
            }

            @Override
            public void onFailure(Call call, IOException e) {
                failWebSocket(e, null);
            }
        });
    }

    void checkUpgradeSuccess(Response response, @Nullable Exchange exchange) throws IOException {
        if (response.code() != 101) {
            throw new ProtocolException("Expected HTTP 101 response but was '"
                    + response.code() + " " + response.message() + "'");
        }

        String headerConnection = response.header("Connection");
        if (!"Upgrade".equalsIgnoreCase(headerConnection)) {
            throw new ProtocolException("Expected 'Connection' header value 'Upgrade' but was '"
                    + headerConnection + "'");
        }

        String headerUpgrade = response.header("Upgrade");
        if (!"websocket".equalsIgnoreCase(headerUpgrade)) {
            throw new ProtocolException(
                    "Expected 'Upgrade' header value 'websocket' but was '" + headerUpgrade + "'");
        }

        String headerAccept = response.header("Sec-WebSocket-Accept");
        String acceptExpected = ByteString.encodeUtf8(key + WebSocketProtocol.ACCEPT_MAGIC)
                .sha1().base64();
        if (!acceptExpected.equals(headerAccept)) {
            throw new ProtocolException("Expected 'Sec-WebSocket-Accept' header value '"
                    + acceptExpected + "' but was '" + headerAccept + "'");
        }

        if (exchange == null) {
            throw new ProtocolException("Web Socket exchange missing: bad interceptor?");
        }
    }

    public void initReaderAndWriter(String name, Streams streams) throws IOException {
        synchronized (this) {
            this.streams = streams;
            this.writer = new WebSocketWriter(streams.client, streams.sink, random);
            this.executor = new ScheduledThreadPoolExecutor(1, Util.threadFactory(name, false));
            if (pingIntervalMillis != 0) {
                executor.scheduleAtFixedRate(
                        new PingRunnable(), pingIntervalMillis, pingIntervalMillis, MILLISECONDS);
            }
            if (!messageAndCloseQueue.isEmpty()) {
                runWriter(); // Send messages that were enqueued before we were connected.
            }
        }

        reader = new WebSocketReader(streams.client, streams.source, this);
    }

    /**
     * Receive frames until there are no more. Invoked only by the reader thread.
     */
    public void loopReader() throws IOException {
        while (receivedCloseCode == -1) {
            // This method call results in one or more onRead* methods being called on this thread.
            reader.processNextFrame();
        }
    }

    /**
     * For testing: receive a single frame and return true if there are more frames to read. Invoked
     * only by the reader thread.
     */
    boolean processNextFrame() throws IOException {
        try {
            reader.processNextFrame();
            return receivedCloseCode == -1;
        } catch (Exception e) {
            failWebSocket(e, null);
            return false;
        }
    }

    /**
     * For testing: wait until the web socket's executor has terminated.
     */
    void awaitTermination(int timeout, TimeUnit timeUnit) throws InterruptedException {
        executor.awaitTermination(timeout, timeUnit);
    }

    /**
     * For testing: force this web socket to release its threads.
     */
    void tearDown() throws InterruptedException {
        if (cancelFuture != null) {
            cancelFuture.cancel(false);
        }
        executor.shutdown();
        executor.awaitTermination(10, TimeUnit.SECONDS);
    }

    synchronized int sentPingCount() {
        return sentPingCount;
    }

    synchronized int receivedPingCount() {
        return receivedPingCount;
    }

    synchronized int receivedPongCount() {
        return receivedPongCount;
    }

    @Override
    public void onReadMessage(String text) throws IOException {
        listener.onMessage(this, text);
    }

    @Override
    public void onReadMessage(ByteString bytes) throws IOException {
        listener.onMessage(this, bytes);
    }

    @Override
    public synchronized void onReadPing(ByteString payload) {
        // Don't respond to pings after we've failed or sent the close frame.
        if (failed || (enqueuedClose && messageAndCloseQueue.isEmpty())) {
            return;
        }

        pongQueue.add(payload);
        runWriter();
        receivedPingCount++;
    }

    @Override
    public synchronized void onReadPong(ByteString buffer) {
        // This API doesn't expose pings.
        receivedPongCount++;
        awaitingPong = false;
    }

    @Override
    public void onReadClose(int code, String reason) {
        if (code == -1) {
            throw new IllegalArgumentException();
        }

        Streams toClose = null;
        synchronized (this) {
            if (receivedCloseCode != -1) {
                throw new IllegalStateException("already closed");
            }
            receivedCloseCode = code;
            receivedCloseReason = reason;
            if (enqueuedClose && messageAndCloseQueue.isEmpty()) {
                toClose = this.streams;
                this.streams = null;
                if (cancelFuture != null) {
                    cancelFuture.cancel(false);
                }
                this.executor.shutdown();
            }
        }

        try {
            listener.onClosing(this, code, reason);

            if (toClose != null) {
                listener.onClosed(this, code, reason);
            }
        } finally {
            Util.closeQuietly(toClose);
        }
    }

    // Writer methods to enqueue frames. They'll be sent asynchronously by the writer thread.

    @Override
    public boolean send(String text) {
        if (text == null) {
            throw new NullPointerException("text == null");
        }
        return send(ByteString.encodeUtf8(text), OPCODE_TEXT);
    }

    @Override
    public boolean send(ByteString bytes) {
        if (bytes == null) {
            throw new NullPointerException("bytes == null");
        }
        return send(bytes, OPCODE_BINARY);
    }

    private synchronized boolean send(ByteString data, int formatOpcode) {
        // Don't send new frames after we've failed or enqueued a close frame.
        if (failed || enqueuedClose) {
            return false;
        }

        // If this frame overflows the buffer, reject it and close the web socket.
        if (queueSize + data.size() > MAX_QUEUE_SIZE) {
            close(CLOSE_CLIENT_GOING_AWAY, null);
            return false;
        }

        // Enqueue the message frame.
        queueSize += data.size();
        messageAndCloseQueue.add(new Message(formatOpcode, data));
        runWriter();
        return true;
    }

    synchronized boolean pong(ByteString payload) {
        // Don't send pongs after we've failed or sent the close frame.
        if (failed || (enqueuedClose && messageAndCloseQueue.isEmpty())) {
            return false;
        }

        pongQueue.add(payload);
        runWriter();
        return true;
    }

    @Override
    public boolean close(int code, String reason) {
        return close(code, reason, CANCEL_AFTER_CLOSE_MILLIS);
    }

    synchronized boolean close(int code, String reason, long cancelAfterCloseMillis) {
        validateCloseCode(code);

        ByteString reasonBytes = null;
        if (reason != null) {
            reasonBytes = ByteString.encodeUtf8(reason);
            if (reasonBytes.size() > CLOSE_MESSAGE_MAX) {
                throw new IllegalArgumentException("reason.size() > " + CLOSE_MESSAGE_MAX + ": " + reason);
            }
        }

        if (failed || enqueuedClose) {
            return false;
        }

        // Immediately prevent further frames from being enqueued.
        enqueuedClose = true;

        // Enqueue the close frame.
        messageAndCloseQueue.add(new Close(code, reasonBytes, cancelAfterCloseMillis));
        runWriter();
        return true;
    }

    private void runWriter() {
        assert (Thread.holdsLock(this));

        if (executor != null) {
            executor.execute(writerRunnable);
        }
    }

    /**
     * Attempts to remove a single frame from a queue and send it. This prefers to write urgent pongs
     * before less urgent messages and close frames. For example it's possible that a caller will
     * enqueue messages followed by pongs, but this sends pongs followed by messages. Pongs are always
     * written in the order they were enqueued.
     *
     * 

If a frame cannot be sent - because there are none enqueued or because the web socket is * not connected - this does nothing and returns false. Otherwise this returns true and the caller * should immediately invoke this method again until it returns false. * *

This method may only be invoked by the writer thread. There may be only thread invoking * this method at a time. */ boolean writeOneFrame() throws IOException { WebSocketWriter writer; ByteString pong; Object messageOrClose = null; int receivedCloseCode = -1; String receivedCloseReason = null; Streams streamsToClose = null; synchronized (RealWebSocket.this) { if (failed) { return false; // Failed web socket. } writer = this.writer; pong = pongQueue.poll(); if (pong == null) { messageOrClose = messageAndCloseQueue.poll(); if (messageOrClose instanceof Close) { receivedCloseCode = this.receivedCloseCode; receivedCloseReason = this.receivedCloseReason; if (receivedCloseCode != -1) { streamsToClose = this.streams; this.streams = null; this.executor.shutdown(); } else { // When we request a graceful close also schedule a cancel of the websocket. cancelFuture = executor.schedule(new CancelRunnable(), ((Close) messageOrClose).cancelAfterCloseMillis, MILLISECONDS); } } else if (messageOrClose == null) { return false; // The queue is exhausted. } } } try { if (pong != null) { writer.writePong(pong); } else if (messageOrClose instanceof Message) { ByteString data = ((Message) messageOrClose).data; BufferedSink sink = Okio.buffer(writer.newMessageSink( ((Message) messageOrClose).formatOpcode, data.size())); sink.write(data); sink.close(); synchronized (this) { queueSize -= data.size(); } } else if (messageOrClose instanceof Close) { Close close = (Close) messageOrClose; writer.writeClose(close.code, close.reason); // We closed the writer: now both reader and writer are closed. if (streamsToClose != null) { listener.onClosed(this, receivedCloseCode, receivedCloseReason); } } else { throw new AssertionError(); } return true; } finally { Util.closeQuietly(streamsToClose); } } void writePingFrame() { WebSocketWriter writer; int failedPing; synchronized (this) { if (failed) { return; } writer = this.writer; failedPing = awaitingPong ? 
sentPingCount : -1; sentPingCount++; awaitingPong = true; } if (failedPing != -1) { failWebSocket(new SocketTimeoutException("sent ping but didn't receive pong within " + pingIntervalMillis + "ms (after " + (failedPing - 1) + " successful ping/pongs)"), null); return; } try { writer.writePing(ByteString.EMPTY); } catch (IOException e) { failWebSocket(e, null); } } public void failWebSocket(Exception e, @Nullable Response response) { Streams streamsToClose; synchronized (this) { if (failed) { return; // Already failed. } failed = true; streamsToClose = this.streams; this.streams = null; if (cancelFuture != null) { cancelFuture.cancel(false); } if (executor != null) { executor.shutdown(); } } try { listener.onFailure(this, e, response); } finally { Util.closeQuietly(streamsToClose); } } static final class Message { final int formatOpcode; final ByteString data; Message(int formatOpcode, ByteString data) { this.formatOpcode = formatOpcode; this.data = data; } } static final class Close { final int code; final ByteString reason; final long cancelAfterCloseMillis; Close(int code, ByteString reason, long cancelAfterCloseMillis) { this.code = code; this.reason = reason; this.cancelAfterCloseMillis = cancelAfterCloseMillis; } } public abstract static class Streams implements Closeable { public final boolean client; public final BufferedSource source; public final BufferedSink sink; public Streams(boolean client, BufferedSource source, BufferedSink sink) { this.client = client; this.source = source; this.sink = sink; } } private final class PingRunnable implements Runnable { PingRunnable() { } @Override public void run() { writePingFrame(); } } final class CancelRunnable implements Runnable { @Override public void run() { cancel(); } } }