All Downloads are FREE. Search and download functionalities are using the official Maven repository.

okhttp3.internal.ws.WebSocketReader.kt Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (C) 2014 Square, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package okhttp3.internal.ws

import java.io.Closeable
import java.io.IOException
import java.net.ProtocolException
import java.util.concurrent.TimeUnit
import okhttp3.internal.and
import okhttp3.internal.toHexString
import okhttp3.internal.ws.WebSocketProtocol.B0_FLAG_FIN
import okhttp3.internal.ws.WebSocketProtocol.B0_FLAG_RSV1
import okhttp3.internal.ws.WebSocketProtocol.B0_FLAG_RSV2
import okhttp3.internal.ws.WebSocketProtocol.B0_FLAG_RSV3
import okhttp3.internal.ws.WebSocketProtocol.B0_MASK_OPCODE
import okhttp3.internal.ws.WebSocketProtocol.B1_FLAG_MASK
import okhttp3.internal.ws.WebSocketProtocol.B1_MASK_LENGTH
import okhttp3.internal.ws.WebSocketProtocol.CLOSE_NO_STATUS_CODE
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_BINARY
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_CONTINUATION
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_CONTROL_CLOSE
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PING
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PONG
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_FLAG_CONTROL
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_TEXT
import okhttp3.internal.ws.WebSocketProtocol.PAYLOAD_BYTE_MAX
import okhttp3.internal.ws.WebSocketProtocol.PAYLOAD_LONG
import okhttp3.internal.ws.WebSocketProtocol.PAYLOAD_SHORT
import okhttp3.internal.ws.WebSocketProtocol.toggleMask
import okio.Buffer
import okio.BufferedSource
import okio.ByteString

/**
 * An [RFC 6455][rfc_6455]-compatible WebSocket frame reader.
 *
 * This class is not thread safe.
 *
 * [rfc_6455]: http://tools.ietf.org/html/rfc6455
 */
class WebSocketReader(
  private val isClient: Boolean,
  val source: BufferedSource,
  private val frameCallback: FrameCallback,
  private val perMessageDeflate: Boolean,
  private val noContextTakeover: Boolean
) : Closeable {
  private var closed = false

  // Stateful data about the current frame.
  private var opcode = 0
  private var frameLength = 0L
  private var isFinalFrame = false
  private var isControlFrame = false
  private var readingCompressedMessage = false

  private val controlFrameBuffer = Buffer()
  private val messageFrameBuffer = Buffer()

  /** Lazily initialized on first use. */
  private var messageInflater: MessageInflater? = null

  // Masks are only a concern for server writers.
  private val maskKey: ByteArray? = if (isClient) null else ByteArray(4)
  private val maskCursor: Buffer.UnsafeCursor? = if (isClient) null else Buffer.UnsafeCursor()

  interface FrameCallback {
    @Throws(IOException::class)
    fun onReadMessage(text: String)

    @Throws(IOException::class)
    fun onReadMessage(bytes: ByteString)

    fun onReadPing(payload: ByteString)
    fun onReadPong(payload: ByteString)
    fun onReadClose(code: Int, reason: String)
  }

  /**
   * Process the next protocol frame.
   *
   *  * If it is a control frame this will result in a single call to [FrameCallback].
   *  * If it is a message frame this will result in a single call to [FrameCallback.onReadMessage].
   *    If the message spans multiple frames, each interleaved control frame will result in a
   *    corresponding call to [FrameCallback].
   */
  @Throws(IOException::class)
  fun processNextFrame() {
    readHeader()
    if (isControlFrame) {
      readControlFrame()
    } else {
      readMessageFrame()
    }
  }

  @Throws(IOException::class, ProtocolException::class)
  private fun readHeader() {
    if (closed) throw IOException("closed")

    // Disable the timeout to read the first byte of a new frame.
    val b0: Int
    val timeoutBefore = source.timeout().timeoutNanos()
    source.timeout().clearTimeout()
    try {
      b0 = source.readByte() and 0xff
    } finally {
      source.timeout().timeout(timeoutBefore, TimeUnit.NANOSECONDS)
    }

    opcode = b0 and B0_MASK_OPCODE
    isFinalFrame = b0 and B0_FLAG_FIN != 0
    isControlFrame = b0 and OPCODE_FLAG_CONTROL != 0

    // Control frames must be final frames (cannot contain continuations).
    if (isControlFrame && !isFinalFrame) {
      throw ProtocolException("Control frames must be final.")
    }

    val reservedFlag1 = b0 and B0_FLAG_RSV1 != 0
    when (opcode) {
      OPCODE_TEXT, OPCODE_BINARY -> {
        readingCompressedMessage = if (reservedFlag1) {
          if (!perMessageDeflate) throw ProtocolException("Unexpected rsv1 flag")
          true
        } else {
          false
        }
      }
      else -> {
        if (reservedFlag1) throw ProtocolException("Unexpected rsv1 flag")
      }
    }

    val reservedFlag2 = b0 and B0_FLAG_RSV2 != 0
    if (reservedFlag2) throw ProtocolException("Unexpected rsv2 flag")

    val reservedFlag3 = b0 and B0_FLAG_RSV3 != 0
    if (reservedFlag3) throw ProtocolException("Unexpected rsv3 flag")

    val b1 = source.readByte() and 0xff

    val isMasked = b1 and B1_FLAG_MASK != 0
    if (isMasked == isClient) {
      // Masked payloads must be read on the server. Unmasked payloads must be read on the client.
      throw ProtocolException(if (isClient) {
        "Server-sent frames must not be masked."
      } else {
        "Client-sent frames must be masked."
      })
    }

    // Get frame length, optionally reading from follow-up bytes if indicated by special values.
    frameLength = (b1 and B1_MASK_LENGTH).toLong()
    if (frameLength == PAYLOAD_SHORT.toLong()) {
      frameLength = (source.readShort() and 0xffff).toLong() // Value is unsigned.
    } else if (frameLength == PAYLOAD_LONG.toLong()) {
      frameLength = source.readLong()
      if (frameLength < 0L) {
        throw ProtocolException(
            "Frame length 0x${frameLength.toHexString()} > 0x7FFFFFFFFFFFFFFF")
      }
    }

    if (isControlFrame && frameLength > PAYLOAD_BYTE_MAX) {
      throw ProtocolException("Control frame must be less than ${PAYLOAD_BYTE_MAX}B.")
    }

    if (isMasked) {
      // Read the masking key as bytes so that they can be used directly for unmasking.
      source.readFully(maskKey!!)
    }
  }

  @Throws(IOException::class)
  private fun readControlFrame() {
    if (frameLength > 0L) {
      source.readFully(controlFrameBuffer, frameLength)

      if (!isClient) {
        controlFrameBuffer.readAndWriteUnsafe(maskCursor!!)
        maskCursor.seek(0)
        toggleMask(maskCursor, maskKey!!)
        maskCursor.close()
      }
    }

    when (opcode) {
      OPCODE_CONTROL_PING -> {
        frameCallback.onReadPing(controlFrameBuffer.readByteString())
      }
      OPCODE_CONTROL_PONG -> {
        frameCallback.onReadPong(controlFrameBuffer.readByteString())
      }
      OPCODE_CONTROL_CLOSE -> {
        var code = CLOSE_NO_STATUS_CODE
        var reason = ""
        val bufferSize = controlFrameBuffer.size
        if (bufferSize == 1L) {
          throw ProtocolException("Malformed close payload length of 1.")
        } else if (bufferSize != 0L) {
          code = controlFrameBuffer.readShort().toInt()
          reason = controlFrameBuffer.readUtf8()
          val codeExceptionMessage = WebSocketProtocol.closeCodeExceptionMessage(code)
          if (codeExceptionMessage != null) throw ProtocolException(codeExceptionMessage)
        }
        frameCallback.onReadClose(code, reason)
        closed = true
      }
      else -> {
        throw ProtocolException("Unknown control opcode: " + opcode.toHexString())
      }
    }
  }

  @Throws(IOException::class)
  private fun readMessageFrame() {
    val opcode = this.opcode
    if (opcode != OPCODE_TEXT && opcode != OPCODE_BINARY) {
      throw ProtocolException("Unknown opcode: ${opcode.toHexString()}")
    }

    readMessage()

    if (readingCompressedMessage) {
      val messageInflater = this.messageInflater
          ?: MessageInflater(noContextTakeover).also { this.messageInflater = it }
      messageInflater.inflate(messageFrameBuffer)
    }

    if (opcode == OPCODE_TEXT) {
      frameCallback.onReadMessage(messageFrameBuffer.readUtf8())
    } else {
      frameCallback.onReadMessage(messageFrameBuffer.readByteString())
    }
  }

  /** Read headers and process any control frames until we reach a non-control frame. */
  @Throws(IOException::class)
  private fun readUntilNonControlFrame() {
    while (!closed) {
      readHeader()
      if (!isControlFrame) {
        break
      }
      readControlFrame()
    }
  }

  /**
   * Reads a message body into across one or more frames. Control frames that occur between
   * fragments will be processed. If the message payload is masked this will unmask as it's being
   * processed.
   */
  @Throws(IOException::class)
  private fun readMessage() {
    while (true) {
      if (closed) throw IOException("closed")

      if (frameLength > 0L) {
        source.readFully(messageFrameBuffer, frameLength)

        if (!isClient) {
          messageFrameBuffer.readAndWriteUnsafe(maskCursor!!)
          maskCursor.seek(messageFrameBuffer.size - frameLength)
          toggleMask(maskCursor, maskKey!!)
          maskCursor.close()
        }
      }

      if (isFinalFrame) break // We are exhausted and have no continuations.

      readUntilNonControlFrame()
      if (opcode != OPCODE_CONTINUATION) {
        throw ProtocolException("Expected continuation opcode. Got: ${opcode.toHexString()}")
      }
    }
  }

  @Throws(IOException::class)
  override fun close() {
    messageInflater?.close()
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy