All downloads are free. Search and download functionality uses the official Maven repository.

io.github.erikvanzijst.scalatlsproxy.TlsProxyHandler.scala Maven / Gradle / Ivy

Go to download

A very simple HTTPS proxy server library written in Scala 2.12 with minimal external dependencies.

The newest version!
package io.github.erikvanzijst.scalatlsproxy

import java.io.IOException
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{CancelledKeyException, SelectionKey, Selector, SocketChannel, UnresolvedAddressException}
import java.nio.charset.StandardCharsets
import com.typesafe.scalalogging.Logger
import org.slf4j.LoggerFactory

import scala.util.Try
import scala.util.matching.Regex

/**
 * Phases a [[TlsProxyHandler]] moves through while serving a single client connection.
 * Declaration order fixes each value's `id`, so the order of the vals below must not change.
 */
object ProxyPhase extends Enumeration {
  type ProxyPhase = Value

  val Destination: Value = Value // waiting for the CONNECT request line
  val Headers: Value = Value     // skipping request headers up to the terminating blank line
  val Response: Value = Value    // flushing the proxy's "200 Connection Accepted" to the client
  val Connecting: Value = Value  // waiting for the non-blocking upstream connect to complete
  val Established: Value = Value // relaying bytes between client and upstream server
  val Error: Value = Value       // flushing an error response, after which the connection is closed
}

/** Constants shared by every [[TlsProxyHandler]] instance. */
object TlsProxyHandler {
  /**
   * Matches a CONNECT request line: group 1 captures the host, group 2 the port digits.
   * NOTE(review): the pattern is unanchored (used with findFirstMatchIn) and the '.' in
   * "HTTP/1.1" is unescaped, so it is slightly more permissive than the literal text suggests.
   */
  val destPattern: Regex = new Regex("CONNECT ([^:]+):([0-9]+) HTTP/1.1")

  /** Identification string placed in the `Proxy-Agent` header of the proxy's own responses. */
  val userAgent: String = s"${NAME}/${VERSION}: ${BUILD_DATE} (github.com/erikvanzijst/scala_tlsproxy)"
}

/**
 * Proxies one client connection: parses the client's `CONNECT host:port` request line, opens a
 * non-blocking connection to that destination, sends the client a status response and then relays
 * bytes in both directions through a pair of [[Pipe]]s.
 *
 * The handler attaches itself to both the client's and the upstream server's [[SelectionKey]] and
 * advances through [[ProxyPhase]] each time [[process]] is called. NOTE(review): presumably the
 * selector loop calls `process()` whenever one of those keys is selected — confirm against the caller.
 *
 * @param selector      selector both the client and the upstream channel are registered with
 * @param clientChannel accepted client connection; it is switched to non-blocking mode here
 * @param config        proxy configuration; its `forwardFilter` decides which destinations are allowed
 */
class TlsProxyHandler(selector: Selector, clientChannel: SocketChannel, config: Config) extends KeyHandler {
  import ProxyPhase._
  import TlsProxyHandler._

  protected val logger: Logger = Logger(LoggerFactory.getLogger("io.github.erikvanzijst.scalatlsproxy.TlsProxyHandler"))

  clientChannel.configureBlocking(false)
  val clientAddress: InetSocketAddress = clientChannel.getRemoteAddress.asInstanceOf[InetSocketAddress]

  private val clientKey = clientChannel.register(selector, SelectionKey.OP_READ, this)  // client initiating the connection
  private val clientBuffer = ByteBuffer.allocate(1 << 15) // client-to-server

  private var serverKey: SelectionKey = _   // the upstream server
  private var serverChannel: SocketChannel = _
  // server-to-client; before the tunnel is established it holds the proxy's own status response
  private val serverBuffer = ByteBuffer.allocate(1 << 15)

  private var upstreamPipe: Pipe = _
  private var downstreamPipe: Pipe = _

  private var shutdown = false               // set by close(); process() does nothing afterwards
  private var destination: (String, Int) = _ // (host, port) parsed from the CONNECT request line

  private var phase = Destination

  /**
   * Remote address of the upstream connection. Falls back to the requested `host:port`, or to
   * "unconnected" when no destination has been parsed yet.
   */
  def getServerAddress: String =
    Try(serverChannel.getRemoteAddress.toString)
      .recover { case _ => destination._1 + ":" + destination._2 }
      .recover { case _ => "unconnected" }
      .get

  /**
   * Reads whatever the client has sent into `clientBuffer`, if the client key is currently readable.
   *
   * @throws IOException on EOF from the client, or when `clientBuffer` has filled up completely
   */
  private def readClient(): Unit = {
    if (clientKey.isValid && clientKey.isReadable && clientChannel.read(clientBuffer) == -1)
      throw new IOException(s"$clientAddress unexpected EOF from client")
    if (!clientBuffer.hasRemaining)
      throw new IOException(s"$clientAddress handshake overflow")
  }

  /**
   * Reads from the client, then removes and returns the first CRLF-terminated line in
   * `clientBuffer` (without the CRLF). Returns None, leaving the buffer intact, if no complete line
   * has arrived yet. US-ASCII decoding yields one char per byte, so char offsets equal byte offsets.
   */
  private def readLine(): Option[String] = {
    readClient()
    clientBuffer.flip()

    val s = StandardCharsets.US_ASCII.decode(clientBuffer).toString
    s.indexOf("\r\n") match {
      case eol if eol != -1 =>
        clientBuffer.position(eol + 2)
        clientBuffer.compact()
        Some(s.substring(0, eol))
      case _ =>
        clientBuffer.position(0)
        clientBuffer.compact()
        None
    }
  }

  /**
   * Queues an HTTP status response in `serverBuffer` and switches the client key to OP_WRITE so
   * `flushResponse()` can send it.
   */
  private def startResponse(statusCode: Int, statusLine: String, body: String): Unit = {
    serverBuffer.put(response(statusCode, statusLine, body))
    clientKey.interestOps(SelectionKey.OP_WRITE)
  }

  /** Encodes a complete plain-text HTTP/1.1 response carrying `body`. */
  private def response(statusCode: Int, statusLine: String, body: String): Array[Byte] = {
    // The body may echo client-supplied text (e.g. a host name almost as large as clientBuffer) and
    // must fit into serverBuffer together with the headers, so it is capped.
    val maxBodyChars = 1024
    // Encode the body first so Content-Length counts the bytes actually sent, not UTF-16 chars.
    val payload = body.take(maxBodyChars).getBytes(StandardCharsets.US_ASCII)
    val head = (s"HTTP/1.1 $statusCode $statusLine\r\n" +
      s"Proxy-Agent: $userAgent\r\n" +
      "Content-Type: text/plain; charset=us-ascii\r\n" +
      s"Content-Length: ${payload.length}\r\n" +
      "\r\n").getBytes(StandardCharsets.US_ASCII)
    head ++ payload
  }

  /** Writes as much of the queued status response as the client accepts; true once all is sent. */
  private def flushResponse(): Boolean = {
    serverBuffer.flip()
    clientChannel.write(serverBuffer)
    serverBuffer.compact()
    serverBuffer.position() == 0
  }

  /**
   * Advances the connection as far as the currently available I/O allows. Several phases may be
   * passed within a single call. I/O failures and cancelled keys are logged and close the connection.
   */
  override def process(): Unit = if (!shutdown) {
    try {
      if (phase == Destination)
        readLine().foreach(handleRequestLine)
      if (phase == Headers)
        consumeHeaders()
      if (phase == Connecting && serverKey.isConnectable)
        finishConnecting()
      if (phase == Response && clientKey.isWritable && flushResponse())
        startRelaying()
      if (phase == Established)
        relay()
      if (phase == Error && clientKey.isWritable && flushResponse())
        close()
    } catch {
      case e @ (_: IOException | _: CancelledKeyException) =>
        logger.warn(s"$clientAddress -> $getServerAddress" +
          (if (phase == Established) s" (up: ${upstreamPipe.bytes} down: ${downstreamPipe.bytes})" else "") +
          s" connection failed: ${e.getClass.getSimpleName}: ${e.getMessage}")
        close()
    }
  }

  /**
   * Parses the request line into `destination` and moves to `Headers`. A port that does not fit
   * in an Int gets a 400 response.
   *
   * @throws IOException if the line is not a CONNECT request
   */
  private def handleRequestLine(line: String): Unit =
    destPattern.findFirstMatchIn(line) match {
      case Some(m) =>
        // [0-9]+ can still overflow Int; the NumberFormatException would escape process()'s handler
        Try(m.group(2).toInt).toOption match {
          case Some(port) =>
            destination = (m.group(1), port)
            logger.debug("{} wants to connect to {}:{}...", clientAddress, destination._1, destination._2)
            phase = Headers
          case None =>
            startResponse(400, "Bad Request", s"Invalid port ${m.group(2)}\n")
            phase = Error
        }
      case None => throw new IOException(s"Unsupported method: ${line.split(' ')(0)}")
    }

  /**
   * Discards request headers until the blank line that ends the request, then starts the upstream
   * connection. It stops as soon as the phase changes, so bytes the client sends after the request
   * stay in `clientBuffer` and are never parsed as further headers.
   */
  private def consumeHeaders(): Unit =
    if (phase == Headers) readLine() match {
      case Some("") =>
        logger.debug("{} all headers consumed, initiating upstream connection...", clientAddress)
        connectUpstream()
      case Some(header) =>
        logger.debug("{} ignoring header {}", clientAddress, header)
        consumeHeaders() // self tail call in a private method, compiled into a loop
      case None =>
    }

  /**
   * Applies `config.forwardFilter` and, if the destination is allowed, starts a non-blocking
   * connect to it. Sets `phase` to Response (connected immediately), Connecting (connect pending)
   * or Error (an error response has been queued).
   */
  private def connectUpstream(): Unit = {
    val (host, port) = destination
    if (!config.forwardFilter(clientAddress, destination)) {
      logger.info("{} denied service to {}:{}", clientAddress, host, port)
      startResponse(403, "Forbidden", s"$host not allowed")
      phase = Error
    } else {
      serverChannel = SocketChannel.open()
      serverChannel.configureBlocking(false)
      // No interest yet: OP_CONNECT is only requested while a connect is actually pending.
      serverKey = serverChannel.register(selector, 0, this)
      clientKey.interestOps(0)  // stop reading while we connect upstream or serve a response

      phase = Try {
        if (serverChannel.connect(new InetSocketAddress(host, port))) {
          startResponse(200, "Connection Accepted", "")
          Response
        } else {
          serverKey.interestOps(SelectionKey.OP_CONNECT)
          Connecting
        }
      }.recover {
        case _: UnresolvedAddressException =>
          logger.info("{} cannot resolve {}", clientAddress, host)
          startResponse(502, "Bad Gateway", s"Failed to resolve $host")
          Error
        case iae: IllegalArgumentException =>
          startResponse(400, "Bad Request", s"${iae.getMessage}\n")
          Error
        case ioe: IOException =>
          // an immediate connect failure is answered the same way as a failed finishConnect()
          startResponse(502, "Gateway Error", s"${ioe.getMessage}\n")
          Error
      }.get
    }
  }

  /**
   * Completes a pending upstream connect and queues the 200 response, or a 502 response if the
   * connect failed. Stays in Connecting when `finishConnect()` reports the connect is not yet done.
   */
  private def finishConnecting(): Unit =
    phase = Try {
      if (serverChannel.finishConnect()) {
        serverKey.interestOps(0)  // connected: stop selecting for OP_CONNECT; OP_READ is set in startRelaying()
        startResponse(200, "Connection Accepted", "")
        Response
      } else {
        Connecting
      }
    }.recover { case ioe: IOException =>
      startResponse(502, "Gateway Error", s"${ioe.getMessage}\n")
      Error
    }.get

  /** Called once the 200 response has been flushed: wires up both relay pipes and resumes reading. */
  private def startRelaying(): Unit = {
    clientKey.interestOps(SelectionKey.OP_READ)
    serverKey.interestOps(SelectionKey.OP_READ)

    upstreamPipe = new Pipe(clientBuffer, clientKey, clientChannel, serverKey, serverChannel)
    downstreamPipe = new Pipe(serverBuffer, serverKey, serverChannel, clientKey, clientChannel)

    logger.debug("{} 200 OK sent to client -- TLS connection to {} ready", clientAddress, getServerAddress)
    phase = Established
  }

  /** Pumps both pipes and closes the connection once both report they are closed. */
  private def relay(): Unit = {
    upstreamPipe.process()
    downstreamPipe.process()
    if (upstreamPipe.isClosed && downstreamPipe.isClosed) {
      logger.info("{} -> {} finished (up: {}, down: {})",
        clientAddress, getServerAddress, upstreamPipe.bytes, downstreamPipe.bytes)
      close()
    }
  }

  /**
   * Cancels both keys and closes both channels. Errors while closing are logged at debug level and
   * never thrown. Safe to call more than once.
   */
  def close(): Unit = {
    shutdown = true
    // Close the channels even when their keys are already invalid, so neither socket can leak.
    // SelectionKey.cancel() is a no-op on an already-cancelled key.
    clientKey.cancel()
    closeQuietly(clientChannel)
    if (serverKey != null) serverKey.cancel()
    if (serverChannel != null) closeQuietly(serverChannel)
    logger.debug("{} connection closed", clientAddress)
  }

  /** Closes `channel`, logging instead of propagating any failure. */
  private def closeQuietly(channel: SocketChannel): Unit =
    Try(channel.close()).failed.foreach(e =>
      logger.debug("{} error closing channel: {}", clientAddress, e.getMessage))

  override def toString: String = s"TlsProxyHandler(${clientAddress} -> ${getServerAddress})@${Integer.toHexString(hashCode)}"
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy