All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonJvmMain.com.ditchoom.socket.SSLClientSocket.kt Maven / Gradle / Ivy

package com.ditchoom.socket

import com.ditchoom.buffer.AllocationZone
import com.ditchoom.buffer.JvmBuffer
import com.ditchoom.buffer.PlatformBuffer
import com.ditchoom.buffer.ReadBuffer
import com.ditchoom.buffer.ReadBuffer.Companion.EMPTY_BUFFER
import com.ditchoom.buffer.allocate
import com.ditchoom.socket.nio.ByteBufferClientSocket
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.security.NoSuchAlgorithmException
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLEngine
import javax.net.ssl.SSLEngineResult
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

/**
 * TLS client socket layered on top of [underlyingSocket] using a JSSE [SSLEngine].
 *
 * Plaintext passed to [write] is encrypted with [SSLEngine.wrap] and forwarded to the underlying
 * socket; bytes read from the underlying socket are decrypted with [SSLEngine.unwrap] before being
 * returned from [read].
 *
 * The underlying socket must be a [ByteBufferClientSocket]; the cast below fails at construction
 * time otherwise.
 */
class SSLClientSocket(
    private val underlyingSocket: ClientToServerSocket,
) : ClientToServerSocket {
    private val byteBufferClientSocket = underlyingSocket as ByteBufferClientSocket<*>

    // Upper bound used for the closing handshake performed by close().
    private val closeTimeout = 1.seconds
    private lateinit var engine: SSLEngine

    // Encrypted bytes that were read from the socket but not yet consumed by the engine.
    // When non-null it is in "read mode" and is consumed by the next unwrap() call before
    // any new socket read is issued.
    private var overflowEncryptedReadBuffer: JvmBuffer? = null

    /**
     * Creates the [SSLEngine] in client mode, opens the underlying socket and runs the TLS handshake.
     *
     * "TLSv1.3" is requested first; if that algorithm is unavailable, "TLSv1.2" is used instead.
     *
     * @param port remote port, also passed to the engine as the peer port
     * @param timeout applied to the underlying open and to every handshake read/write
     * @param hostname remote host, also passed to the engine as the peer host
     */
    override suspend fun open(
        port: Int,
        timeout: Duration,
        hostname: String?,
    ) {
        val context =
            try {
                SSLContext.getInstance("TLSv1.3")
            } catch (e: NoSuchAlgorithmException) {
                SSLContext.getInstance("TLSv1.2")
            }
        // Default key managers, trust managers and secure random.
        context.init(null, null, null)
        // NOTE(review): no endpoint identification algorithm is configured on the engine, so the
        // peer certificate is presumably not checked against `hostname`. Consider setting
        // SSLParameters.endpointIdentificationAlgorithm = "HTTPS" - confirm callers tolerate it.
        engine = context.createSSLEngine(hostname, port)
        engine.useClientMode = true
        engine.beginHandshake()
        underlyingSocket.open(port, timeout, hostname)
        doHandshake(timeout)
    }

    override fun isOpen(): Boolean = underlyingSocket.isOpen()

    override suspend fun localPort(): Int = underlyingSocket.localPort()

    override suspend fun remotePort(): Int = underlyingSocket.remotePort()

    /** Reads and decrypts data from the socket; see [unwrap] for the returned buffer's state. */
    override suspend fun read(timeout: Duration): ReadBuffer = unwrap(timeout)

    /**
     * Encrypts [buffer] and writes it to the underlying socket.
     *
     * [buffer] must be a [JvmBuffer]. Returns the number of *encrypted* bytes written, or -1 if
     * the underlying socket reported a negative write count.
     */
    override suspend fun write(
        buffer: ReadBuffer,
        timeout: Duration,
    ): Int = wrap(buffer as JvmBuffer, timeout)

    /**
     * Drives the engine until its handshake status is NOT_HANDSHAKING.
     *
     * Used both for the initial handshake in [open] and for the closing exchange in [close].
     *
     * @throws IllegalStateException if the engine reports BUFFER_OVERFLOW or CLOSED while
     * unwrapping, or if the socket reports end-of-stream (negative read count) mid-handshake.
     */
    private suspend fun doHandshake(timeout: Duration) {
        // Encrypted bytes already read from the socket but not yet consumed by the engine.
        var cachedBuffer: JvmBuffer? = null
        val emptyBuffer = EMPTY_BUFFER as JvmBuffer
        while (engine.handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
            when (engine.handshakeStatus) {
                SSLEngineResult.HandshakeStatus.NEED_WRAP -> wrap(emptyBuffer, timeout)
                SSLEngineResult.HandshakeStatus.NEED_TASK ->
                    withContext(Dispatchers.IO) {
                        // Drain every pending task; delegatedTask returns null when none remain.
                        var task: Runnable? = engine.delegatedTask
                        while (task != null) {
                            task.run()
                            task = engine.delegatedTask
                        }
                    }

                else -> { // UNWRAP + UNWRAP AGAIN
                    val dataRead: JvmBuffer =
                        cachedBuffer
                            // This buffer receives encrypted network data, so it is sized by the
                            // packet (not application) buffer size to fit a full TLS record.
                            ?: bufferFactory(engine.session.packetBufferSize).also { fresh ->
                                if (byteBufferClientSocket.read(fresh, timeout) < 0) {
                                    throw IllegalStateException("Socket closed during TLS handshake")
                                }
                                fresh.resetForRead()
                            }
                    val result = engine.unwrap(dataRead.byteBuffer, emptyBuffer.byteBuffer)
                    // Keep whatever the engine did not consume for the next iteration.
                    cachedBuffer = dataRead.takeIf { it.byteBuffer.hasRemaining() }
                    when (checkNotNull(result.status)) {
                        SSLEngineResult.Status.BUFFER_UNDERFLOW -> {
                            // Nothing buffered: loop around and read a fresh buffer.
                            val partial = cachedBuffer ?: continue
                            // Partial record: move it to the front and append more bytes.
                            partial.byteBuffer.compact()
                            val bytesRead = byteBufferClientSocket.read(partial, timeout)
                            partial.resetForRead()
                            if (bytesRead < 0) {
                                throw IllegalStateException("Socket closed during TLS handshake")
                            }
                        }

                        SSLEngineResult.Status.BUFFER_OVERFLOW ->
                            throw IllegalStateException("Unwrap Buffer Overflow")

                        // NOTE(review): close() also runs this loop; if the peer answers our
                        // close_notify the engine may report CLOSED here - confirm whether a
                        // clean shutdown should really surface as an exception.
                        SSLEngineResult.Status.CLOSED ->
                            throw IllegalStateException("SSLEngineResult Status Closed")

                        SSLEngineResult.Status.OK -> continue
                    }
                }
            }
        }
        // Bytes received after the final handshake record belong to the application stream;
        // hand them to unwrap() instead of dropping them.
        val leftover = cachedBuffer
        if (leftover != null && leftover.byteBuffer.hasRemaining()) {
            overflowEncryptedReadBuffer = leftover
        }
    }

    /**
     * Encrypts [plainText] and writes the resulting records to the underlying socket.
     *
     * A single [SSLEngine.wrap] call emits at most one record, so the engine is invoked repeatedly
     * while it keeps consuming plaintext. With an empty [plainText] (handshake use) exactly one
     * wrap is performed.
     *
     * @return total number of encrypted bytes written, or -1 if the underlying socket returned a
     * negative write count
     * @throws IllegalStateException if the engine reports BUFFER_UNDERFLOW or BUFFER_OVERFLOW
     */
    private suspend fun wrap(
        plainText: JvmBuffer,
        timeout: Duration,
    ): Int {
        var writtenBytes = 0
        do {
            val encrypted = bufferFactory(engine.session.packetBufferSize)
            val result = engine.wrap(plainText.byteBuffer, encrypted.byteBuffer)
            when (checkNotNull(result.status)) {
                SSLEngineResult.Status.BUFFER_UNDERFLOW ->
                    throw IllegalStateException("SSL Engine Buffer Underflow - wrap")

                SSLEngineResult.Status.BUFFER_OVERFLOW ->
                    throw IllegalStateException("SSL Engine Buffer Overflow - wrap")

                SSLEngineResult.Status.CLOSED,
                SSLEngineResult.Status.OK,
                -> {
                    encrypted.resetForRead()
                    // The socket may accept fewer bytes than offered; keep going until drained.
                    while (encrypted.hasRemaining()) {
                        val bytesWrote = underlyingSocket.write(encrypted, timeout)
                        if (bytesWrote < 0) {
                            return -1
                        }
                        writtenBytes += bytesWrote
                    }
                }
            }
            // Continue only while the engine is open, made progress and plaintext is left;
            // the progress check prevents spinning if the engine consumes nothing.
        } while (
            result.status == SSLEngineResult.Status.OK &&
                result.bytesConsumed() > 0 &&
                plainText.byteBuffer.hasRemaining()
        )
        return writtenBytes
    }

    /** Allocates a direct buffer of [size] bytes. */
    private fun bufferFactory(size: Int): JvmBuffer {
        return PlatformBuffer.allocate(size, AllocationZone.Direct) as JvmBuffer
    }

    /**
     * Reads encrypted bytes (or reuses bytes left over from a previous call) and decrypts them.
     *
     * @return a slice holding the decrypted bytes, positioned at its limit (see [slicePlainText]);
     * [EMPTY_BUFFER] if the initial socket read returned fewer than one byte
     */
    private suspend fun unwrap(timeout: Duration): ReadBuffer {
        val encryptedReadBuffer =
            overflowEncryptedReadBuffer
                ?: bufferFactory(engine.session.packetBufferSize).also {
                    val bytesRead = byteBufferClientSocket.read(it, timeout)
                    if (bytesRead < 1) {
                        return EMPTY_BUFFER
                    }
                    it.resetForRead()
                }
        val plainTextReadBuffer = bufferFactory(engine.session.applicationBufferSize)
        while (encryptedReadBuffer.hasRemaining()) {
            val result =
                engine.unwrap(encryptedReadBuffer.byteBuffer, plainTextReadBuffer.byteBuffer)
            when (checkNotNull(result.status)) {
                SSLEngineResult.Status.BUFFER_OVERFLOW -> {
                    // plaintext buffer is too small, cache the encrypted read buffer so we can use it for next time
                    overflowEncryptedReadBuffer = encryptedReadBuffer
                    return slicePlainText(plainTextReadBuffer)
                }

                SSLEngineResult.Status.BUFFER_UNDERFLOW -> {
                    if (plainTextReadBuffer.position() > 0) {
                        // Already decoded something: deliver it now rather than blocking on the
                        // socket; the partial record is completed on the next call.
                        overflowEncryptedReadBuffer = encryptedReadBuffer
                        return slicePlainText(plainTextReadBuffer)
                    }
                    // Partial record: move it to the front and append more bytes from the socket.
                    encryptedReadBuffer.byteBuffer.compact()
                    val bytesRead = byteBufferClientSocket.read(encryptedReadBuffer, timeout)
                    encryptedReadBuffer.resetForRead()
                    if (bytesRead < 0) {
                        // End of stream: the partial record can never complete, stop looping.
                        overflowEncryptedReadBuffer = null
                        return slicePlainText(plainTextReadBuffer)
                    }
                    overflowEncryptedReadBuffer = encryptedReadBuffer
                }

                SSLEngineResult.Status.OK -> {
                    overflowEncryptedReadBuffer = null
                }

                SSLEngineResult.Status.CLOSED -> {
                    overflowEncryptedReadBuffer = null
                    close()
                    return slicePlainText(plainTextReadBuffer)
                }
            }
        }
        return slicePlainText(plainTextReadBuffer)
    }

    /**
     * Returns a slice covering bytes `0 until position` of [plainText], i.e. what has been written
     * into it so far. The slice's position is left at its limit.
     */
    private fun slicePlainText(plainText: JvmBuffer): JvmBuffer {
        val position = plainText.position()
        plainText.position(0)
        plainText.setLimit(position)
        val slicedBuffer = plainText.slice()
        slicedBuffer.position(slicedBuffer.limit())
        return slicedBuffer
    }

    /**
     * Closes the TLS session's outbound side, runs the closing handshake bounded by
     * [closeTimeout], and closes the underlying socket.
     *
     * The underlying socket is closed even if the closing handshake throws; the exception is
     * still propagated. If [open] was never called, only the underlying socket is closed.
     */
    override suspend fun close() {
        try {
            if (::engine.isInitialized) {
                engine.closeOutbound()
                doHandshake(closeTimeout)
            }
        } finally {
            underlyingSocket.close()
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy