All Downloads are FREE. Search and download functionalities are using the official Maven repository.

korolev.effect.io.SecureDataSocket.scala Maven / Gradle / Ivy

The newest version!
package korolev.effect.io

import jdk.jshell.spi.ExecutionControl.NotImplementedException
import korolev.data.BytesLike
import korolev.effect
import korolev.effect.Effect
import korolev.effect.syntax._

import java.nio.ByteBuffer
import java.util.concurrent.Executor
import javax.net.ssl.SSLEngineResult.{HandshakeStatus, Status}
import javax.net.ssl.{SSLEngine, SSLHandshakeException}
import scala.collection.mutable

class SecureDataSocket[F[_] : Effect, B: BytesLike] private(engine: SSLEngine,
                                                            socket: RawDataSocket[F, B]) extends DataSocket[F, B] {

  import SecureDataSocket.UnwrapStatus

  val stream: effect.Stream[F, B] = new effect.Stream[F, B] {

    def pull(): F[Option[B]] = Effect[F].delayAsync {
      if (!canceled) {
        doUnwrap() map {
          case UnwrapStatus.Ok =>
            peerAppBuffer.flip()
            val bytes = BytesLike[B].copyBuffer(peerAppBuffer)
            peerAppBuffer.compact()
            Some(bytes)
          case _ =>
            closeInbound()
            None
        }
      } else {
        Effect[F].pure(None)
      }
    }

    def cancel(): F[Unit] = Effect[F].delayAsync {
      if (!canceled) {
        canceled = true
        engine.closeOutbound()
        socket.stream.cancel()
      } else {
        Effect[F].unit
      }
    }
  }

  def write(bytes: B): F[Unit] = Effect[F].delayAsync {
    appBuffer.clear()
    BytesLike[B].copyToBuffer(bytes, appBuffer)
    appBuffer.flip()
    doWrap().unit
  }

  def onClose(): F[DataSocket.CloseReason] =
    socket.onClose()

  private val appBuffer: ByteBuffer = ByteBuffer.allocate(engine.getSession.getApplicationBufferSize)
  @volatile private var netBuffer: ByteBuffer = ByteBuffer.allocate(engine.getSession.getPacketBufferSize)
  @volatile private var peerAppBuffer: ByteBuffer = ByteBuffer.allocate(engine.getSession.getApplicationBufferSize)
  @volatile private var peerNetBuffer: ByteBuffer = ByteBuffer.allocate(engine.getSession.getPacketBufferSize)
  @volatile private var canceled = false

  private def enlargeBuffer(buffer: ByteBuffer, newRemaining: Int) = {
    if (buffer.capacity() < newRemaining) {
      val newBuffer = ByteBuffer.allocate(buffer.remaining() + newRemaining)
      buffer.flip()
      newBuffer.put(buffer)
      newBuffer
    } else {
      buffer
    }
  }

  private def handshake(blockingExecutor: Executor): F[Unit] = {

    def handshakeLoop(handshakeStatus: HandshakeStatus): F[Unit] = handshakeStatus match {
      case HandshakeStatus.NOT_HANDSHAKING | HandshakeStatus.FINISHED =>
        peerAppBuffer.clear()
        Effect[F].unit
      case HandshakeStatus.NEED_WRAP =>
        doWrap() flatMap {
          case true => handshakeLoop(engine.getHandshakeStatus)
          case false => throw new SSLHandshakeException("Invalid wrap status. CLOSED given but OK expected")
        }
      case HandshakeStatus.NEED_UNWRAP =>
        doUnwrap().flatMap {
          case UnwrapStatus.Ok => handshakeLoop(engine.getHandshakeStatus)
          case UnwrapStatus.ClosedByPeer => throw new SSLHandshakeException("Peer has closed connection during handshake")
          case UnwrapStatus.ClosedByEngine => throw new SSLHandshakeException("Invalid unwrap status. CLOSED given but OK expected")
        }
      case HandshakeStatus.NEED_TASK =>
        val effects = getDelegatedTasks.map { task =>
          Effect[F].promise[Unit] { cb =>
            blockingExecutor.execute {
              () => {
                task.run()
                cb(Right(()))
              }
            }
          }
        }
        Effect[F]
          .sequence(effects)
          .flatMap(_ => handshakeLoop(engine.getHandshakeStatus))
      case HandshakeStatus.NEED_UNWRAP_AGAIN =>
        throw new NotImplementedException("DTLS handshake doesn't supported yet")
    }

    handshakeLoop(engine.getHandshakeStatus)
  }

  private def doWrap(): F[Boolean] = Effect[F].delayAsync {
    netBuffer.clear()
    engine.wrap(appBuffer, netBuffer).getStatus match {
      case Status.OK =>
        netBuffer.flip()
        socket.write(netBuffer).as(true)
      case Status.BUFFER_OVERFLOW =>
        // netBuffer doesn't have required capacity
        netBuffer = ByteBuffer.allocate(engine.getSession.getPacketBufferSize)
        doWrap()
      case Status.CLOSED => Effect[F].pure(false)
    }
  }

  private def doUnwrap(): F[UnwrapStatus] =
    for {
      hasData <- {
        peerNetBuffer.position() match {
          case 0 => socket.read(peerNetBuffer).map(x => if (x < 0) false else true)
          case _ => Effect[F].pure(true)
        }
      }
      result <- if (!hasData) Effect[F].pure(UnwrapStatus.ClosedByPeer) else {
        peerNetBuffer.flip()
        val result = engine.unwrap(peerNetBuffer, peerAppBuffer)
        peerNetBuffer.compact()
        result.getStatus match {
          case Status.OK => Effect[F].pure(UnwrapStatus.Ok)
          case Status.CLOSED => Effect[F].pure(UnwrapStatus.ClosedByEngine)
          case Status.BUFFER_UNDERFLOW =>
            peerNetBuffer = enlargeBuffer(peerNetBuffer, engine.getSession.getPacketBufferSize)
            socket
              .read(peerNetBuffer)
              .flatMap {
                case -1 => Effect[F].pure(UnwrapStatus.ClosedByPeer: UnwrapStatus)
                case _ => doUnwrap()
              }
          case Status.BUFFER_OVERFLOW =>
            peerAppBuffer = enlargeBuffer(peerAppBuffer, engine.getSession.getApplicationBufferSize)
            doUnwrap()
        }
      }
    } yield result

  private def getDelegatedTasks: List[Runnable] = {
    val tasks = mutable.Buffer.empty[Runnable]
    var task: Runnable = null
    while ( {
      task = engine.getDelegatedTask;
      task != null
    }) {
      tasks += task
    }

    tasks.toList
  }

  private def closeInbound(): Unit = {
    try {
      engine.closeInbound()
    } catch {
      case _: Throwable =>
        // Do nothing.
        // This engine.closeInbound() always throws an exception as I see.
        // Every implementation invokes engine.closeInbound() on socket close.
    }
  }
}

object SecureDataSocket {

  def forClientMode[F[_] : Effect, B: BytesLike](socket: RawDataSocket[F, B],
                                                 engine: SSLEngine,
                                                 blockingExecutor: Executor): F[SecureDataSocket[F, B]] = {
    engine.setUseClientMode(true)
    apply(socket, engine, blockingExecutor)
  }

  def apply[F[_] : Effect, B: BytesLike](socket: RawDataSocket[F, B],
                                         engine: SSLEngine,
                                         blockingExecutor: Executor): F[SecureDataSocket[F, B]] = {
    engine.beginHandshake()
    val secureSocket = new SecureDataSocket[F, B](engine, socket)
    secureSocket.handshake(blockingExecutor).as(secureSocket)
  }

  private sealed trait UnwrapStatus

  private object UnwrapStatus {

    case object Ok extends UnwrapStatus

    case object ClosedByEngine extends UnwrapStatus

    case object ClosedByPeer extends UnwrapStatus
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy