All Downloads are FREE. Search and download functionalities are using the official Maven repository.

tcp.tcp.scala Maven / Gradle / Ivy

package otoroshi.tcp

import java.net.{InetAddress, InetSocketAddress}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference}
import java.util.regex.MatchResult
import otoroshi.actions.ApiAction
import akka.actor.{ActorSystem, Cancellable}
import akka.http.scaladsl.settings.ServerSettings
import akka.http.scaladsl.util.FastFuture
import akka.stream.TLSProtocol.NegotiateNewSession
import akka.stream.scaladsl.{Flow, Keep, Sink, Source, Tcp}
import akka.stream.{IgnoreComplete, Materializer}
import akka.util.ByteString
import akka.{AwesomeIncomingConnection, Done, TcpUtils}
import otoroshi.env.Env
import otoroshi.events.{DataInOut, Location, TcpEvent}

import javax.net.ssl._
import otoroshi.models.{IpFiltering, ServiceDescriptor}
import org.joda.time.DateTime
import play.api.Logger
import play.api.libs.json._
import play.api.mvc.{AbstractController, ControllerComponents}
import redis.RedisClientMasterSlaves
import otoroshi.security.IdGenerator
import otoroshi.ssl.{ClientAuth, CustomSSLEngine, DynamicSSLEngineProvider}
import otoroshi.storage.{BasicStore, RedisLike, RedisLikeStore}
import otoroshi.utils.{RegexPool, SchedulerHelper}

import scala.concurrent.duration.Duration
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}
import otoroshi.utils.syntax.implicits._

/**
 * A TCP proxy definition: otoroshi listens on `interface:port` and forwards the raw TCP traffic to
 * the targets declared in `rules`, optionally handling TLS (see [[TlsMode]]) and routing on the SNI
 * domain (see [[SniSettings]]).
 *
 * Feature checklist:
 *
 * - [x] TCP service can be disabled
 * - [x] TCP service without sni is defined on a port and forwards to targets
 * - [x] Target can define their own dns resolving
 * - [x] TCP service can match a sni domain for a same port (need to catch sni name per request)
 * - [x] TCP service can forward non matching request to local http server (only for sni request)
 * - [x] TCP service can be exposed over tls using dyn tls stuff
 * - [x] TCP service can passthrough tls
 * - [x] TCP service can specify if it needs or wants mtls
 * - [x] Passthrough + SNI
 * - [x] rules
 *   if no sni matching, then only one Tcp service can exist with a specific port number
 *   if sni matching, then multiple Tcp services can exist with the same port number
 *   if sni matching, then all Tcp services using the same port number must have the same Tls mode
 * - [x] We need a new datastore for tcp services
 * - [x] We need to include tcp services in backup/restore
 * - [x] We need a new admin api for tcp services
 * - [x] We need a new UI for tcp services
 * - [x] We need to wire routexxx functions to the new datastore
 * - [x] We need to generate access events
 * - [x] A job will request all tcp services with unique ports and start tcp servers. Servers will be shut down with the otoroshi app
 * - [ ] add api in swagger when feature is ready
 * - [ ] support ClientConfig for tcp
 * - [ ] support ClientValidator for tcp
 * - [ ] support IpFiltering
 * - [ ] support healthCheck for tcp (+UI)
 * - [ ] support snowMonkey for tcp
 * - [ ] support live metrics (+UI)
 * - [ ] support analytics in UI (metrics + events)
 */
case class TcpService(
    id: String = IdGenerator.token,
    name: String = "TCP Proxy",
    description: String = "A TCP Proxy",
    enabled: Boolean,
    tls: TlsMode,
    sni: SniSettings,
    clientAuth: ClientAuth,
    port: Int,
    interface: String = "0.0.0.0",
    rules: Seq[TcpRule],
    tags: Seq[String],
    metadata: Map[String, String],
    location: otoroshi.models.EntityLocation = otoroshi.models.EntityLocation()
    // Not supported yet (see the unchecked items of the checklist above):
    //   clientValidatorRef: Option[String], clientConfig: ClientConfig, ipFiltering: IpFiltering,
    //   healthCheck, snowMonkey
) extends otoroshi.models.EntityLocationSupport {

  // Generic entity accessors (presumably required by EntityLocationSupport — confirm there)
  def internalId: String               = this.id
  def theName: String                  = this.name
  def theDescription: String           = this.description
  def theTags: Seq[String]             = this.tags
  def theMetadata: Map[String, String] = this.metadata

  /** JSON representation of this service, produced by `TcpService.fmt`. */
  def json: JsValue = Json.toJson(this)(TcpService.fmt)

  /** Stores this service through the TCP service datastore. */
  def save()(implicit ec: ExecutionContext, env: Env) = env.datastores.tcpServiceDataStore.set(this)
}
/**
 * SNI routing settings of a [[TcpService]].
 *
 * @param enabled          whether connections are routed according to their SNI domain
 * @param forwardIfNoMatch whether connections whose SNI domain matches no rule are sent to `forwardsTo`
 * @param forwardsTo       fallback target for non-matching domains (defaults to 127.0.0.1:8080 with `tls = false`)
 */
case class SniSettings(
    enabled: Boolean,
    forwardIfNoMatch: Boolean,
    forwardsTo: TcpTarget = TcpTarget("127.0.0.1", None, 8080, false)
) {

  /** JSON representation produced by `SniSettings.fmt`. */
  def json: JsValue = Json.toJson(this)(SniSettings.fmt)
}
object SniSettings {

  /**
   * JSON format for [[SniSettings]].
   *
   * Missing booleans default to `false`. A missing or unreadable `forwardsTo` defaults to
   * 127.0.0.1:8080 with `tls = false`. Any exception raised while reading becomes a `JsError`.
   * Defined as a `val` so that each read does not allocate a new format.
   */
  val fmt: Format[SniSettings] =
    new Format[SniSettings] {
      override def writes(o: SniSettings): JsValue =
        Json.obj(
          "enabled"          -> o.enabled,
          "forwardIfNoMatch" -> o.forwardIfNoMatch,
          "forwardsTo"       -> o.forwardsTo.json
        )

      override def reads(json: JsValue): JsResult[SniSettings] =
        Try {
          SniSettings(
            enabled = (json \ "enabled").asOpt[Boolean].getOrElse(false),
            forwardIfNoMatch = (json \ "forwardIfNoMatch").asOpt[Boolean].getOrElse(false),
            forwardsTo =
              (json \ "forwardsTo").asOpt(TcpTarget.fmt).getOrElse(TcpTarget("127.0.0.1", None, 8080, false))
          )
        } match {
          case Success(settings) => JsSuccess(settings)
          case Failure(e)        => JsError(e.getMessage)
        }
    }
}
/**
 * A backend that a TCP connection can be forwarded to.
 *
 * @param host host name of the backend
 * @param ip   optional address (parsed with `InetAddress.getByName`) used for the connection instead of resolving `host`
 * @param port backend port
 * @param tls  whether the outgoing connection is wrapped in TLS
 */
case class TcpTarget(host: String, ip: Option[String], port: Int, tls: Boolean) {

  /** JSON representation produced by `TcpTarget.fmt`. */
  def json: JsValue = Json.toJson(this)(TcpTarget.fmt)

  /** Renders the target as `host:port`, or `host/ip:port` when an ip is pinned. */
  def toAnalyticsString: String = {
    val pinnedIp = ip.fold("")(value => s"/$value")
    s"$host$pinnedIp:$port"
  }
}
object TcpTarget {

  /**
   * JSON format for [[TcpTarget]].
   *
   * `host` is required. `port` defaults to 8080 and `tls` to `false`. A missing `ip` is written as
   * `null`. Any exception raised while reading becomes a `JsError`. Defined as a `val` so that each
   * read does not allocate a new format.
   */
  val fmt: Format[TcpTarget] =
    new Format[TcpTarget] {
      override def writes(o: TcpTarget): JsValue =
        Json.obj(
          "host" -> o.host,
          "ip"   -> o.ip.fold[JsValue](JsNull)(JsString.apply),
          "port" -> o.port,
          "tls"  -> o.tls
        )

      override def reads(json: JsValue): JsResult[TcpTarget] =
        Try {
          TcpTarget(
            host = (json \ "host").as[String],
            ip = (json \ "ip").asOpt[String],
            port = (json \ "port").asOpt[Int].getOrElse(8080),
            tls = (json \ "tls").asOpt[Boolean].getOrElse(false)
          )
        } match {
          case Success(target) => JsSuccess(target)
          case Failure(e)      => JsError(e.getMessage)
        }
    }
}
/**
 * Routing rule of a [[TcpService]].
 *
 * @param domain  domain pattern compared with the SNI domain through `TcpService.domainMatch`
 * @param targets backends that matching connections are spread across, round-robin
 */
case class TcpRule(domain: String, targets: Seq[TcpTarget]) {

  /** JSON representation produced by `TcpRule.fmt`. */
  def json: JsValue = Json.toJson(this)(TcpRule.fmt)
}
object TcpRule {

  /**
   * JSON format for [[TcpRule]].
   *
   * A missing `domain` defaults to `"*"`. A missing or unreadable `targets` array defaults to no
   * targets. Any exception raised while reading becomes a `JsError`. Defined as a `val` so that
   * each read does not allocate a new format.
   */
  val fmt: Format[TcpRule] =
    new Format[TcpRule] {
      override def writes(o: TcpRule): JsValue =
        Json.obj(
          "domain"  -> o.domain,
          "targets" -> JsArray(o.targets.map(_.json))
        )

      override def reads(json: JsValue): JsResult[TcpRule] =
        Try {
          TcpRule(
            domain = (json \ "domain").asOpt[String].getOrElse("*"),
            targets = (json \ "targets").asOpt(Reads.seq(TcpTarget.fmt)).getOrElse(Seq.empty)
          )
        } match {
          case Success(rule) => JsSuccess(rule)
          case Failure(e)    => JsError(e.getMessage)
        }
    }
}
/** How a [[TcpService]] handles TLS on its listening port. */
sealed trait TlsMode {

  /** Canonical name, used in the JSON representation of a [[TcpService]]. */
  def name: String
}

/** TLS mode serialized as `Disabled`. */
case object TlsModeDisabled extends TlsMode {
  def name: String = "Disabled"
}

/** TLS mode serialized as `Enabled`. */
case object TlsModeEnabled extends TlsMode {
  def name: String = "Enabled"
}

/** TLS mode serialized as `PassThrough` (TLS presumably forwarded untouched to the targets). */
case object TlsModePassThrough extends TlsMode {
  def name: String = "PassThrough"
}

object TlsMode {

  val Disabled    = TlsModeDisabled
  val Enabled     = TlsModeEnabled
  val PassThrough = TlsModePassThrough

  /**
   * Parses a TLS mode name, ignoring case ("PassThrough", "passthrough" and "PASSTHROUGH" are all accepted).
   *
   * @param v the mode name, may be `null`
   * @return the matching mode, or `None` for `null` or an unknown name
   */
  def apply(v: String): Option[TlsMode] =
    Option(v).map(_.toLowerCase(java.util.Locale.ROOT)).collect {
      case "disabled"    => Disabled
      case "enabled"     => Enabled
      case "passthrough" => PassThrough
    }
}

object TcpService {

  private val reqCounter = new AtomicLong(0L)
  private val log        = Logger("otoroshi-tcp-proxy")

  /**
   * Deserializes a [[TcpService]] from JSON and throws when the payload cannot be read.
   * The offending payload is logged before the error is rethrown.
   */
  def fromJsons(value: JsValue): TcpService =
    try {
      fmt.reads(value).get
    } catch {
      // NonFatal: fatal JVM errors (OOM, interruption, ...) must propagate untouched
      case NonFatal(e) =>
        log.error(s"Try to deserialize ${Json.prettyPrint(value)}")
        throw e
    }

  /** Deserializes a [[TcpService]] from JSON, returning the validation errors on the left instead of throwing. */
  def fromJsonSafe(value: JsValue): Either[Seq[(JsPath, Seq[JsonValidationError])], TcpService] =
    fmt.reads(value).asEither

  /**
   * JSON format for [[TcpService]].
   *
   * `id`, `name`, `description` and `port` are required. Every other field falls back to a
   * default when missing. Any exception raised while reading becomes a `JsError`.
   */
  val fmt: Format[TcpService] = new Format[TcpService] {
    override def reads(json: JsValue): JsResult[TcpService] =
      Try {
        TcpService(
          location = otoroshi.models.EntityLocation.readFromKey(json),
          id = (json \ "id").as[String],
          name = (json \ "name").as[String],
          description = (json \ "description").as[String],
          port = (json \ "port").as[Int],
          interface = (json \ "interface").asOpt[String].getOrElse("0.0.0.0"),
          enabled = (json \ "enabled").asOpt[Boolean].getOrElse(false),
          tls = (json \ "tls").asOpt[String].flatMap(TlsMode.apply).getOrElse(TlsMode.Disabled),
          sni = (json \ "sni").asOpt(SniSettings.fmt).getOrElse(SniSettings(false, false)),
          clientAuth = (json \ "clientAuth").asOpt[String].flatMap(ClientAuth.apply).getOrElse(ClientAuth.None),
          rules = (json \ "rules").asOpt(Reads.seq(TcpRule.fmt)).getOrElse(Seq.empty),
          metadata = (json \ "metadata").asOpt[Map[String, String]].getOrElse(Map.empty),
          tags = (json \ "tags").asOpt[Seq[String]].getOrElse(Seq.empty[String])
        )
      } match {
        case Success(service) => JsSuccess(service)
        case Failure(e)       => JsError(e.getMessage)
      }

    override def writes(o: TcpService): JsValue =
      o.location.jsonWithKey ++ Json.obj(
        "id"          -> o.id,
        "name"        -> o.name,
        "description" -> o.description,
        "enabled"     -> o.enabled,
        "tls"         -> o.tls.name,
        "sni"         -> o.sni.json,
        "clientAuth"  -> o.clientAuth.name,
        "port"        -> o.port,
        "interface"   -> o.interface,
        "rules"       -> JsArray(o.rules.map(_.json)),
        "metadata"    -> o.metadata,
        "tags"        -> JsArray(o.tags.map(JsString.apply))
      )
  }

  /** Creates a [[RunningServers]] for `env` and starts it. */
  def runServers(env: Env): RunningServers = {
    new RunningServers(env).start()
  }

  /** All TCP services from the TCP service datastore. */
  def findAll()(implicit ec: ExecutionContext, env: Env): Future[Seq[TcpService]] =
    env.datastores.tcpServiceDataStore.findAll()

  /** The first datastore service whose `port` equals `port`, if any. */
  def findByPort(port: Int)(implicit ec: ExecutionContext, env: Env): Future[Option[TcpService]] =
    findAll().map(_.find(_.port == port))

  /** All TCP services known to `env.proxyState`, wrapped in a future. */
  def findAllFromState()(implicit ec: ExecutionContext, env: Env): Future[Seq[TcpService]] =
    env.proxyState.allTcpServices().vfuture

  /** The first service from `env.proxyState` whose `port` equals `port`, if any. */
  def findByPortFromState(port: Int)(implicit ec: ExecutionContext, env: Env): Future[Option[TcpService]] =
    findAllFromState().map(_.find(_.port == port))

  /** Whether `domain` matches the `matchRule` pattern, as interpreted by `RegexPool`. */
  def domainMatch(matchRule: String, domain: String): Boolean = {
    RegexPool(matchRule).matches(domain)
  }

  /**
   * Picks the next target round-robin.
   *
   * Throws an explicit error when `targets` is empty. `floorMod` keeps the index non-negative even
   * after the shared counter overflows.
   */
  private def pickTarget(targets: Seq[TcpTarget]): TcpTarget = {
    if (targets.isEmpty) throw new IllegalStateException("No target defined")
    val index = java.lang.Math.floorMod(reqCounter.incrementAndGet(), targets.size.toLong)
    targets(index.toInt)
  }

  /**
   * Socket address of `target`.
   *
   * When `target.ip` is set, the connection uses that address while keeping `target.host` as the
   * host name. Otherwise `host` is resolved by `InetSocketAddress`.
   */
  private def resolveAddress(target: TcpTarget): InetSocketAddress =
    target.ip match {
      case Some(ip) =>
        new InetSocketAddress(
          InetAddress.getByAddress(target.host, InetAddress.getByName(ip).getAddress),
          target.port
        )
      case None     => new InetSocketAddress(target.host, target.port)
    }

  /** Outgoing connection flow to `target`, wrapped in client TLS when `target.tls` is set. */
  private def connectTo(
      target: TcpTarget
  )(implicit actorSystem: ActorSystem, env: Env): Flow[ByteString, ByteString, Future[Tcp.OutgoingConnection]] = {
    val remoteAddress = resolveAddress(target)
    if (target.tls) {
      Tcp().outgoingConnectionWithTls(
        remoteAddress,
        () => DynamicSSLEngineProvider.createSSLEngine(ClientAuth.None, None, None, None, env)
      )
    } else {
      Tcp().outgoingConnection(remoteAddress)
    }
  }

  /**
   * Joins the client flow with the outgoing flow, counting bytes in each direction.
   *
   * `cb(bytesIn, bytesOut)` is called when the outgoing side completes. Returns the outgoing flow's
   * materialized future.
   */
  private def runProxy(
      outgoing: Flow[ByteString, ByteString, Future[Any]],
      incomingFlow: Flow[ByteString, ByteString, Any],
      dataIn: AtomicLong,
      dataOut: AtomicLong
  )(cb: (Long, Long) => Unit)(implicit materializer: Materializer): Future[Any] =
    outgoing
      .alsoTo(Sink.foreach(bs => dataOut.addAndGet(bs.size)))
      .alsoTo(Sink.onComplete(_ => cb(dataIn.get(), dataOut.get())))
      .joinMat(incomingFlow.alsoTo(Sink.foreach(bs => dataIn.addAndGet(bs.size))))(Keep.left)
      .run()

  /**
   * Builds the analytics event of a proxied connection.
   *
   * Without a `service`, the event is attributed to "otoroshi". Without a `target`, the target is
   * reported as "--".
   */
  private def buildEvent(
      reqId: String,
      tls: Boolean,
      to: String,
      target: Option[TcpTarget],
      remote: String,
      local: String,
      overhead: Long,
      data: DataInOut,
      error: Option[String],
      port: Int,
      service: Option[TcpService]
  )(implicit env: Env): TcpEvent = {
    val protocol = if (tls) "Tcp/Tls" else "Tcp"
    TcpEvent(
      `@id` = env.snowflakeGenerator.nextIdStr(),
      `@timestamp` = DateTime.now(),
      reqId = reqId,
      protocol = protocol,
      to = Location(to, protocol, ""),
      target = Location(
        target.map(_.toAnalyticsString).getOrElse("--"),
        if (target.exists(_.tls)) "Tcp/Tls" else "Tcp",
        ""
      ),
      remote = remote,
      local = local,
      duration = 0L,
      overhead = overhead,
      data = data,
      gwError = error,
      `@serviceId` = service.map(_.id).getOrElse("otoroshi"),
      `@service` = service.map(_.name).getOrElse("otoroshi"),
      port = port,
      service = service
    )
  }

  /**
   * Proxies a connection without SNI routing.
   *
   * The connection goes to the enabled service bound to the connection's local port. Its target is
   * picked round-robin across the targets of all the service's rules.
   *
   * @param incoming the accepted connection
   * @param port     listening port, reported in the event
   * @param id       request id reported in the event
   * @param tls      whether the listener uses TLS (only used for the protocol reported in the event)
   * @param start    accept timestamp (millis), used to compute the overhead
   * @param debugger not used by this method
   * @param cb       receives the bytes read from the client and sent to it once the proxied stream completes
   * @return the connection event, completed once the outgoing connection is established or has
   *         failed. Routing errors are reported through `gwError`.
   */
  def routeWithoutSNI(
      incoming: Tcp.IncomingConnection,
      port: Int,
      id: String,
      tls: Boolean,
      start: Long,
      debugger: String => Sink[ByteString, Future[Done]]
  )(cb: (Long, Long) => Unit)(implicit
      ec: ExecutionContext,
      actorSystem: ActorSystem,
      materializer: Materializer,
      env: Env
  ): Future[TcpEvent] = {
    val dataIn    = new AtomicLong(0L)
    val dataOut   = new AtomicLong(0L)
    val targetRef = new AtomicReference[TcpTarget]()

    // Event snapshot using whatever target and byte counts are known at call time
    def event(overhead: Long, error: Option[String], service: Option[TcpService]): TcpEvent =
      buildEvent(
        reqId = id,
        tls = tls,
        to = "*",
        target = Option(targetRef.get()),
        remote = incoming.remoteAddress.toString,
        local = incoming.localAddress.toString,
        overhead = overhead,
        data = DataInOut(dataIn.get(), dataOut.get()),
        error = error,
        port = port,
        service = service
      )

    TcpService.findByPortFromState(incoming.localAddress.getPort).flatMap {
      case Some(service) if service.enabled =>
        try {
          log.info(s"local: ${incoming.localAddress}, remote: ${incoming.remoteAddress}")
          val target    = pickTarget(service.rules.flatMap(_.targets))
          targetRef.set(target)
          val fullLayer = connectTo(target)
          // measured after target selection and address resolution, which may hit DNS
          val overhead  = System.currentTimeMillis() - start
          runProxy(fullLayer, incoming.flow, dataIn, dataOut)(cb)
            .map(_ => event(overhead, None, Some(service)))
            .recover { case NonFatal(ex) => event(overhead, Some(ex.getMessage), Some(service)) }
        } catch {
          case NonFatal(e) =>
            log.error(s"Could not materialize handling flow for ${incoming}", e)
            Future.successful(event(0L, Some(s"Could not materialize handling flow for ${incoming}: $e"), None))
        }
      case _                                =>
        // NOTE(review): the incoming flow is never materialized here; confirm the binding closes the connection
        Future.successful(event(0L, Some("No matching service !"), None))
    }
  }

  /**
   * Proxies a connection according to its SNI domain.
   *
   * The service must be bound to the connection's local port with SNI enabled. The first rule
   * whose domain matches wins, and its target is picked round-robin. When no rule matches and
   * `sni.forwardIfNoMatch` is set, the connection goes to `sni.forwardsTo`, honouring its `tls`
   * flag like any other target. Otherwise the stream fails with "No domain matches".
   *
   * @param incoming the accepted connection, exposing its SNI domain
   * @param port     listening port, reported in the event
   * @param id       request id reported in the event
   * @param tls      whether the listener uses TLS (only used for the protocol reported in the event)
   * @param start    accept timestamp (millis), used to compute the overhead
   * @param debugger not used by this method
   * @param cb       receives the bytes read from the client and sent to it once the proxied stream completes
   * @return the connection event. Routing errors are reported through `gwError`.
   */
  def routeWithSNI(
      incoming: AwesomeIncomingConnection,
      port: Int,
      id: String,
      tls: Boolean,
      start: Long,
      debugger: String => Sink[ByteString, Future[Done]]
  )(cb: (Long, Long) => Unit)(implicit
      ec: ExecutionContext,
      actorSystem: ActorSystem,
      materializer: Materializer,
      env: Env
  ): Future[TcpEvent] = {
    val dataIn    = new AtomicLong(0L)
    val dataOut   = new AtomicLong(0L)
    val targetRef = new AtomicReference[TcpTarget]()
    // "<sni domain>:<port>" once the SNI domain is known
    val ref       = new AtomicReference[String]()

    // Event snapshot using whatever domain, target and byte counts are known at call time
    def event(overhead: Long, error: Option[String], service: Option[TcpService]): TcpEvent =
      buildEvent(
        reqId = id,
        tls = tls,
        to = Option(ref.get()).getOrElse("no-sni"),
        target = Option(targetRef.get()),
        remote = incoming.remoteAddress.toString,
        local = incoming.localAddress.toString,
        overhead = overhead,
        data = DataInOut(dataIn.get(), dataOut.get()),
        error = error,
        port = port,
        service = service
      )

    TcpService.findByPortFromState(incoming.localAddress.getPort).flatMap {
      case Some(service) if service.enabled && service.sni.enabled =>
        try {
          // The target depends on the SNI domain, only known after the handshake, hence the lazy flow
          val fullLayer: Flow[ByteString, ByteString, Future[_]] = Flow.lazyFutureFlow { () =>
            incoming.domain.map { sniDomain =>
              ref.set(sniDomain + ":" + port)
              log.info(s"domain: $sniDomain, local: ${incoming.localAddress}, remote: ${incoming.remoteAddress}")
              service.rules.find(r => domainMatch(r.domain, sniDomain)) match {
                case Some(rule)                           =>
                  val target = pickTarget(rule.targets)
                  targetRef.set(target)
                  connectTo(target)
                case None if service.sni.forwardIfNoMatch =>
                  val target = service.sni.forwardsTo
                  targetRef.set(target)
                  connectTo(target)
                case None                                 =>
                  Flow[ByteString].flatMapConcat(_ => Source.failed(new RuntimeException("No domain matches")))
              }
            } recover { case e =>
              log.error("SNI failed", e)
              Flow[ByteString].flatMapConcat(_ => Source.failed(e))
            }
          }
          val overhead                                           = System.currentTimeMillis() - start
          runProxy(fullLayer, incoming.flow, dataIn, dataOut)(cb)
            .map(_ => event(overhead, None, Some(service)))
            .recover { case NonFatal(ex) => event(overhead, Some(ex.getMessage), Some(service)) }
        } catch {
          case NonFatal(e) =>
            log.error(s"Could not materialize handling flow for ${incoming}", e)
            Future.successful(event(0L, Some(s"Could not materialize handling flow for ${incoming}: $e"), None))
        }
      case _                                                       =>
        // NOTE(review): the incoming flow is never materialized here; confirm the binding closes the connection
        Future.successful(event(0L, Some("No matching service !"), None))
    }
  }
}

/** Creates the server-side [[SSLEngine]]s used by TCP services that terminate TLS. */
class TcpEngineProvider {

  /**
   * Creates a server [[SSLEngine]] from `DynamicSSLEngineProvider.currentServer`.
   *
   * The engine is wrapped in a [[CustomSSLEngine]] that receives the banned ALPN protocols of the
   * global TLS settings. Cipher suites and protocols come from `otoroshi.ssl.cipherSuites` and
   * `otoroshi.ssl.protocols` when those settings are present and non-empty; otherwise the raw
   * engine's defaults are kept. Every SNI server name is accepted, and host names are recorded on
   * the engine.
   *
   * @param clientAuth `Want` or `Need` request client certificates; any other value leaves them off
   * @param env        otoroshi environment providing configuration and global config
   * @return the configured engine
   */
  def createSSLEngine(clientAuth: ClientAuth, env: Env): SSLEngine = {
    // Read lazily, i.e. on first use below, once the raw engine exists
    lazy val configuredCipherSuites =
      env.configuration.getOptionalWithFileSupport[Seq[String]]("otoroshi.ssl.cipherSuites").filterNot(_.isEmpty)
    lazy val configuredProtocols    =
      env.configuration.getOptionalWithFileSupport[Seq[String]]("otoroshi.ssl.protocols").filterNot(_.isEmpty)

    val sslContext: SSLContext = DynamicSSLEngineProvider.currentServer
    if (DynamicSSLEngineProvider.logger.isDebugEnabled) {
      DynamicSSLEngineProvider.logger.debug(s"Create SSLEngine from: $sslContext")
    }
    val underlying   = sslContext.createSSLEngine()
    val customEngine = new CustomSSLEngine(
      underlying,
      None,
      env.datastores.globalConfigDataStore.latestUnsafe.tlsSettings.bannedAlpnProtocols
    )

    // Defaults captured before any configured override is applied
    val defaultCipherSuites = underlying.getEnabledCipherSuites.toSeq
    val defaultProtocols    = underlying.getEnabledProtocols.toSeq
    configuredCipherSuites.foreach(suites => underlying.setEnabledCipherSuites(suites.toArray))
    configuredProtocols.foreach(protos => underlying.setEnabledProtocols(protos.toArray))

    val params = new SSLParameters
    clientAuth match {
      case ClientAuth.Want =>
        customEngine.setWantClientAuth(true)
        params.setWantClientAuth(true)
      case ClientAuth.Need =>
        customEngine.setNeedClientAuth(true)
        params.setNeedClientAuth(true)
      case _               =>
    }

    // Accepts every server name; host names are remembered on the engine
    val acceptAnyServerName = new SNIMatcher(StandardConstants.SNI_HOST_NAME) {
      override def matches(sniServerName: SNIServerName): Boolean = {
        sniServerName match {
          case hostName: SNIHostName =>
            val asciiName = hostName.getAsciiName
            if (DynamicSSLEngineProvider.logger.isDebugEnabled) {
              DynamicSSLEngineProvider.logger.debug(s"createSSLEngine - for $asciiName")
            }
            customEngine.setEngineHostName(asciiName)
          case other                 =>
            if (DynamicSSLEngineProvider.logger.isDebugEnabled) {
              DynamicSSLEngineProvider.logger.debug(s"Not a hostname :( ${other.toString}")
            }
        }
        true
      }
    }
    params.setSNIMatchers(java.util.Collections.singletonList[SNIMatcher](acceptAnyServerName))
    params.setCipherSuites(configuredCipherSuites.getOrElse(defaultCipherSuites).toArray)
    params.setProtocols(configuredProtocols.getOrElse(defaultProtocols).toArray)
    customEngine.setSSLParameters(params)
    customEngine
  }
}

/** Factories for [[TcpProxy]]. */
object TcpProxy {

  /** Proxy listening on the interface and port of `tcp`, with its TLS mode, SNI flag and client auth; debug off. */
  def apply(tcp: TcpService)(implicit system: ActorSystem, mat: Materializer): TcpProxy =
    apply(tcp.interface, tcp.port, tcp.tls, tcp.sni.enabled, tcp.clientAuth, false)(system, mat)

  /** Proxy listening on `interface:port`; `debug` turns on logging of the traffic. */
  def apply(interface: String, port: Int, tls: TlsMode, sni: Boolean, clientAuth: ClientAuth, debug: Boolean = false)(
      implicit
      system: ActorSystem,
      mat: Materializer
  ): TcpProxy = {
    val proxy = new TcpProxy(interface, port, tls, sni, clientAuth, debug)(system, mat)
    proxy
  }
}

/**
 * A TCP listener bound on `interface:port` that hands every accepted connection to
 * `TcpService.routeWithSNI` or `TcpService.routeWithoutSNI`.
 *
 * The binding flavour is chosen in [[start]] from `tls` and `sni`:
 *  - `TlsMode.Disabled`: plain `Tcp().bind`, routed without SNI;
 *  - `TlsMode.PassThrough`: plain `Tcp().bind`; when `sni` is set, the first received chunk is
 *    scanned for a domain name that is then used for routing;
 *  - `TlsMode.Enabled`: bound through `TcpUtils` with an `SSLEngine` built by [[TcpEngineProvider]].
 *
 * @param interface  address to bind on
 * @param port       port to bind on
 * @param tls        how TLS is handled on this listener
 * @param sni        whether connections are routed by server name
 * @param clientAuth client certificate policy given to the SSL engine (TLS modes only)
 * @param debug      when true, the sinks handed to the routing functions log every chunk at info level
 */
class TcpProxy(
    interface: String,
    port: Int,
    tls: TlsMode,
    sni: Boolean,
    clientAuth: ClientAuth,
    debug: Boolean = false
)(implicit
    system: ActorSystem,
    mat: Materializer
) {

  private val log                           = Logger("otoroshi-tcp-proxy")
  private implicit val ec: ExecutionContext = system.dispatcher
  // Shared by both TLS binding modes; a fresh SSLEngine is still created per accepted connection.
  private val provider                      = new TcpEngineProvider()

  /** Sink used to inspect traffic: logs each chunk prefixed by `title` when `debug` is on, else discards it. */
  private def debugger(title: String): Sink[ByteString, Future[Done]] =
    if (debug) Sink.foreach[ByteString](bs => log.info(title + bs.utf8String))
    else Sink.ignore

  /**
   * Emits the analytics event of a closed connection.
   *
   * The event comes from `ref`, which is filled once routing has completed. If the connection ends
   * before that (for example when routing failed), there is no event to enrich, and nothing is sent
   * instead of throwing a NullPointerException.
   */
  private def reportClosed(ref: AtomicReference[TcpEvent], start: Long, data: DataInOut, env: Env): Unit =
    Option(ref.get()).foreach { evt =>
      evt.copy(duration = System.currentTimeMillis() - start, data = data).toAnalytics()(env)
    }

  /**
   * Stores the event produced by routing a connection into `ref`, and turns a routing failure into a
   * logged, successfully completed future.
   *
   * Without the recovery, a single failed connection would fail the `mapAsyncUnordered` stage. Under
   * the default supervision this completes the accept stream, so the listener would stop accepting
   * new connections.
   */
  private def settleConnection(id: String, ref: AtomicReference[TcpEvent], routed: Future[TcpEvent]): Future[Unit] =
    routed
      .map(evt => ref.set(evt))
      .recover { case NonFatal(e) =>
        log.error(s"Error while handling tcp connection $id on $interface:$port", e)
      }

  /**
   * Wraps a plain connection so that its first received chunk is scanned with
   * `akka.TcpUtils.domainNamePattern`. The first match completes the SNI future of the returned
   * connection.
   *
   * The future fails with "SNI not found !" if the first chunk has no match, or if the inbound side
   * terminates before any chunk arrives. Routing therefore never waits on it forever.
   */
  private def sniffSni(incomingConnection: Tcp.IncomingConnection): AwesomeIncomingConnection = {
    val promise    = Promise[String]()
    val firstChunk = new AtomicBoolean(false)
    val sniffer    = Sink
      .foreach[ByteString] { bs =>
        // only the very first chunk can carry the ClientHello we are looking for
        if (firstChunk.compareAndSet(false, true)) {
          val matcher = akka.TcpUtils.domainNamePattern.matcher(bs.utf8String)
          if (matcher.find()) promise.trySuccess(matcher.group())
          else promise.tryFailure(new RuntimeException("SNI not found !"))
        }
      }
      // no-op if the first chunk already settled the promise
      .mapMaterializedValue(_.onComplete(_ => promise.tryFailure(new RuntimeException("SNI not found !"))))
    AwesomeIncomingConnection(
      incomingConnection.copy(flow = incomingConnection.flow.alsoTo(sniffer)),
      promise.future
    )
  }

  /** Binds with TLS handled here and routes each connection by its SNI server name. */
  private def tcpBindTlsAndSNI(settings: ServerSettings, env: Env): Future[Tcp.ServerBinding] =
    TcpUtils
      .bindTlsWithSSLEngineAndSNI(
        interface = interface,
        port = port,
        createSSLEngine = () => provider.createSSLEngine(clientAuth, env),
        backlog = settings.backlog,
        options = settings.socketOptions,
        idleTimeout = Duration.Inf,
        verifySession = _ => Success(()),
        closing = IgnoreComplete
      )
      .mapAsyncUnordered(settings.maxConnections) { incoming =>
        val id    = env.snowflakeGenerator.nextIdStr()
        val start = System.currentTimeMillis()
        val ref   = new AtomicReference[TcpEvent]()
        settleConnection(
          id,
          ref,
          TcpService.routeWithSNI(incoming, port, id, true, start, debugger) { case (in, out) =>
            reportClosed(ref, start, DataInOut(in, out), env)
          }(ec, system, mat, env)
        )
      }
      .to(Sink.ignore)
      .run()

  /** Binds with TLS handled here and routes each connection without SNI. */
  private def tcpBindTls(settings: ServerSettings, env: Env): Future[Tcp.ServerBinding] =
    TcpUtils
      .bindTlsWithSSLEngine(
        interface = interface,
        port = port,
        createSSLEngine = () => provider.createSSLEngine(clientAuth, env),
        backlog = settings.backlog,
        options = settings.socketOptions,
        idleTimeout = Duration.Inf,
        verifySession = _ => Success(()),
        closing = IgnoreComplete
      )
      .mapAsyncUnordered(settings.maxConnections) { incoming =>
        val id    = env.snowflakeGenerator.nextIdStr()
        val start = System.currentTimeMillis()
        val ref   = new AtomicReference[TcpEvent]()
        settleConnection(
          id,
          ref,
          TcpService.routeWithoutSNI(incoming, port, id, true, start, debugger) { case (in, out) =>
            reportClosed(ref, start, DataInOut(in, out), env)
          }(ec, system, mat, env)
        )
      }
      .to(Sink.ignore)
      .run()

  /** Plain TCP bind; each connection is routed without SNI. */
  private def tcpBindNoTls(settings: ServerSettings, env: Env): Future[Tcp.ServerBinding] =
    Tcp()
      .bind(
        interface = interface,
        port = port,
        halfClose = false,
        backlog = settings.backlog,
        options = settings.socketOptions,
        idleTimeout = Duration.Inf
      )
      .mapAsyncUnordered(settings.maxConnections) { incoming =>
        val id    = env.snowflakeGenerator.nextIdStr()
        val start = System.currentTimeMillis()
        val ref   = new AtomicReference[TcpEvent]()
        settleConnection(
          id,
          ref,
          TcpService.routeWithoutSNI(incoming, port, id, false, start, debugger) { case (in, out) =>
            reportClosed(ref, start, DataInOut(in, out), env)
          }(ec, system, mat, env)
        )
      }
      .to(Sink.ignore)
      .run()

  /** Plain TCP bind; the server name is sniffed from the first chunk and used for routing. */
  private def tcpBindNoTlsAndSNI(settings: ServerSettings, env: Env): Future[Tcp.ServerBinding] =
    Tcp()
      .bind(
        interface = interface,
        port = port,
        halfClose = false,
        backlog = settings.backlog,
        options = settings.socketOptions,
        idleTimeout = Duration.Inf
      )
      .map(sniffSni)
      .mapAsyncUnordered(settings.maxConnections) { incoming =>
        val id    = env.snowflakeGenerator.nextIdStr()
        val start = System.currentTimeMillis()
        val ref   = new AtomicReference[TcpEvent]()
        settleConnection(
          id,
          ref,
          TcpService.routeWithSNI(incoming, port, id, false, start, debugger) { case (in, out) =>
            reportClosed(ref, start, DataInOut(in, out), env)
          }(ec, system, mat, env)
        )
      }
      .to(Sink.ignore)
      .run()

  /**
   * Binds the listener according to `tls` and `sni`, and logs the outcome.
   *
   * @return the server binding; a failed future if the port could not be bound
   */
  def start(env: Env): Future[Tcp.ServerBinding] = {
    val settings = ServerSettings(env.configuration.underlying)
    val binding  = tls match {
      case TlsMode.Disabled    => tcpBindNoTls(settings, env)
      case TlsMode.PassThrough => if (sni) tcpBindNoTlsAndSNI(settings, env) else tcpBindNoTls(settings, env)
      case TlsMode.Enabled     => if (sni) tcpBindTlsAndSNI(settings, env) else tcpBindTls(settings, env)
    }
    binding.andThen {
      case Success(_) if tls == TlsMode.Enabled => log.info(s"Tcp/Tls proxy listening on $interface:$port")
      case Success(_)                           => log.info(s"Tcp     proxy listening on $interface:$port")
      case Failure(e) if tls == TlsMode.Enabled =>
        log.error(s"Error while binding Tcp/Tls proxy on $interface:$port", e)
      case Failure(e)                           => log.error(s"Error while binding Tcp     proxy on $interface:$port", e)
    }
  }
}

/**
 * A TCP listener started by [[RunningServers]].
 *
 * @param port       port the listener was started on
 * @param oldService the service definition the listener was started from, compared on later syncs
 * @param binding    the pending or completed binding of the listener
 */
case class RunningServer(
    port: Int,
    oldService: TcpService,
    binding: Future[Tcp.ServerBinding]
)

/**
 * Keeps the bound [[TcpProxy]] listeners in sync with the TCP services returned by
 * `TcpService.findAllFromState()`.
 *
 * Once started, a scheduled task (first run after 1 second, then every 10 seconds) does three
 * things: it binds ports that are not running yet, rebinds ports whose listener settings changed,
 * and unbinds ports that no longer have a service.
 */
class RunningServers(env: Env) {

  import scala.concurrent.duration._

  private implicit val system: ActorSystem  = env.otoroshiActorSystem
  private implicit val ec: ExecutionContext = env.otoroshiExecutionContext
  private implicit val mat: Materializer    = env.otoroshiMaterializer
  private implicit val ev: Env              = env
  private val ref                           = new AtomicReference[Cancellable]()
  private val running                       = new AtomicBoolean(false)
  // Set while a synchronization is in flight, so that a slow one is never overlapped by the next tick.
  private val syncing                       = new AtomicBoolean(false)
  private val runningServers                = new AtomicReference[Seq[RunningServer]](Seq.empty)
  private val log                           = Logger("otoroshi-tcp-proxy")

  /** True when the settings that require a new socket binding differ between the two definitions. */
  private def listenerChanged(previous: TcpService, current: TcpService): Boolean =
    current.interface != previous.interface ||
    current.sni != previous.sni ||
    current.tls != previous.tls ||
    current.clientAuth != previous.clientAuth

  /**
   * Unbinds a running server and never fails.
   *
   * A server whose bind had failed cannot be unbound. Without the recovery, that single failure
   * would abort the whole synchronization (or `stop()`).
   */
  private def unbindQuietly(server: RunningServer): Future[Unit] =
    server.binding
      .flatMap(_.unbind())
      .recover { case NonFatal(e) =>
        log.warn(s"Error while stopping Tcp proxy on ${server.oldService.interface}:${server.oldService.port}", e)
      }

  /**
   * Reconciles the running servers with `services`.
   *
   * Obsolete and changed servers are unbound first; then changed and new services are started.
   * `runningServers` is updated once the unbinds have completed.
   */
  private def reconcile(services: Seq[TcpService]): Future[Unit] = {
    val actualServers = runningServers.get()
    val existingPorts = actualServers.map(_.port)
    log.debug(s"[RunningServer] existing $existingPorts")
    // services already bound on their port but with different listener settings, paired with that server
    val changed       = services.flatMap { s =>
      actualServers
        .find(_.port == s.port)
        .filter(server => listenerChanged(server.oldService, s))
        .map(server => (s, server))
    }
    log.debug(s"[RunningServer] changed ${changed.map(_._1.port)}")
    val notRunning    = services.filterNot(s => existingPorts.contains(s.port))
    log.debug(s"[RunningServer] notRunning ${notRunning.map(_.port)}")
    val willExistPort = services.map(_.port).distinct
    log.debug(s"[RunningServer] willExist ${willExistPort}")
    val toShutDown    = actualServers.filterNot(s => willExistPort.contains(s.port))
    log.debug(s"[RunningServer] toShutDown ${toShutDown.map(_.port)}")
    // both unbind batches are started right away and run concurrently
    val allDown1      = Future.sequence(toShutDown.map { s =>
      log.info(
        s"Stopping Tcp proxy on ${s.oldService.interface}:${s.oldService.port} because it does not exists anymore"
      )
      unbindQuietly(s)
    })
    val allDown2      = Future.sequence(changed.map { case (_, server) =>
      log.info(
        s"Stopping Tcp proxy on ${server.oldService.interface}:${server.oldService.port} because the service changed"
      )
      unbindQuietly(server)
    })
    for {
      _ <- allDown1
      _ <- allDown2
    } yield {
      val running1      = changed.map { case (s, _) => RunningServer(s.port, s, TcpProxy(s).start(env)) }
      val running2      = notRunning.map(s => RunningServer(s.port, s, TcpProxy(s).start(env)))
      val changedPorts  = changed.map(_._1.port)
      val shutdownPorts = toShutDown.map(_.port)
      val stayServers   =
        actualServers.filterNot(s => changedPorts.contains(s.port) || shutdownPorts.contains(s.port))
      runningServers.set(stayServers ++ running1 ++ running2)
    }
  }

  /**
   * Runs one synchronization, unless the manager is stopped or a synchronization is already running.
   *
   * `syncing` is released only when the whole reconciliation (unbinds and rebinds) has finished,
   * and also when it fails.
   */
  private def updateRunningServers(): Unit = {
    if (running.get() && syncing.compareAndSet(false, true)) {
      // Future.unit.flatMap turns a synchronous throw of the lookup into a failed future, so `syncing` is always released
      Future.unit
        .flatMap(_ => TcpService.findAllFromState())
        .flatMap(reconcile)
        .andThen { case Failure(e) =>
          log.error("[RunningServer] error while synchronizing tcp proxies", e)
        }
        .andThen { case _ =>
          syncing.set(false)
        }
    }
  }

  /** Starts the periodic synchronization if it is not already running. Returns `this`. */
  def start(): RunningServers = {
    if (running.compareAndSet(false, true)) {
      ref.set(
        system.scheduler.scheduleAtFixedRate(1.second, 10.seconds)(
          SchedulerHelper.runnable(
            updateRunningServers()
          )
        )
      )
    }
    this
  }

  /**
   * Stops the periodic synchronization and unbinds every running server.
   *
   * The running set is cleared, so a later `start()` binds the services again.
   * Individual unbind failures are logged and do not fail the returned future.
   */
  def stop(): Future[Unit] = {
    if (running.compareAndSet(true, false)) {
      Option(ref.get()).foreach(_.cancel())
      val servers = runningServers.getAndSet(Seq.empty)
      Future
        .sequence(servers.map { server =>
          log.info(s"Stopping Tcp proxy on ${server.oldService.interface}:${server.oldService.port}")
          unbindQuietly(server)
        })
        .map(_ => ())
    } else {
      FastFuture.successful(())
    }
  }
}

sealed trait TcpServiceDataStore extends BasicStore[TcpService] {

  /**
   * Builds the TCP service used as a starting point when creating a new one.
   *
   * The built-in default is a disabled-TLS, non-SNI service on port 4200 with a catch-all rule.
   * If the latest global config defines `templates.tcpService`, that JSON is deep-merged over the
   * default's JSON and read back as a [[TcpService]]. Reading it back throws if the merged JSON
   * cannot be parsed.
   */
  def template(env: Env): TcpService = {
    val catchAllRule   = TcpRule(
      domain = "*",
      targets = Seq(TcpTarget("42.42.42.42", None, 4200, false))
    )
    val defaultService = TcpService(
      id = IdGenerator.namedId("tcp_service", env),
      enabled = true,
      tls = TlsMode.Disabled,
      sni = SniSettings(false, false),
      clientAuth = ClientAuth.None,
      port = 4200,
      tags = Seq.empty,
      metadata = Map.empty,
      rules = Seq(catchAllRule)
    )
    val customTemplate = env.datastores.globalConfigDataStore
      .latest()(env.otoroshiExecutionContext, env)
      .templates
      .tcpService
    customTemplate.fold(defaultService) { overrides =>
      val merged = defaultService.json.asObject.deepMerge(overrides)
      TcpService.fmt.reads(merged).get
    }
  }
}

/**
 * Key/value backed store for [[TcpService]] entities.
 *
 * Each service is stored under `<storageRoot>:tcp:services:<id>` through the given [[RedisLike]]
 * client, serialized with `TcpService.fmt`.
 */
class KvTcpServiceDataStoreDataStore(redisCli: RedisLike, env: Env)
    extends TcpServiceDataStore
    with RedisLikeStore[TcpService] {

  override def redisLike(implicit env: Env): RedisLike = redisCli

  override def fmt: Format[TcpService] = TcpService.fmt

  override def extractId(value: TcpService): String = value.id

  override def key(id: String): String = s"${env.storageRoot}:tcp:services:$id"
}

/*

  private val _services = Seq(
    TcpService(
      enabled = true,
      tls = TlsMode.Disabled,
      sni = SniSettings(false, false),
      clientAuth = ClientAuth.None,
      port = 1201,
      rules = Seq(TcpRule(
        domain = "*",
        targets = Seq(
          TcpTarget(
            "localhost",
            None,
            1301,
            false
          ),
          TcpTarget(
            "localhost",
            None,
            1302,
            false
          )
        )
      ))
    ),
    TcpService(
      enabled = true,
      tls = TlsMode.PassThrough,
      sni = SniSettings(false, false),
      clientAuth = ClientAuth.None,
      port = 1202,
      rules = Seq(TcpRule(
        domain = "*",
        targets = Seq(
          TcpTarget(
            "ssl.ancelin.org",
            Some("127.0.0.1"),
            1303,
            false
          ),
          TcpTarget(
            "ssl.ancelin.org",
            Some("127.0.0.1"),
            1304,
            false
          )
        )
      ))
    ),
    TcpService(
      enabled = true,
      tls = TlsMode.Enabled,
      sni = SniSettings(false, false),
      clientAuth = ClientAuth.None,
      port = 1203,
      rules = Seq(TcpRule(
        domain = "*",
        targets = Seq(
          TcpTarget(
            "localhost",
            None,
            1301,
            false
          ),
          TcpTarget(
            "localhost",
            None,
            1302,
            false
          )
        )
      ))
    ),
    TcpService(
      enabled = true,
      tls = TlsMode.Enabled,
      sni = SniSettings(true, false),
      clientAuth = ClientAuth.None,
      port = 1204,
      rules = Seq(
        TcpRule(
          domain = "ssl.ancelin.org",
          targets = Seq(
            TcpTarget(
              "localhost",
              None,
              1301,
              false
            )
          )
        ),
        TcpRule(
          domain = "ssl2.ancelin.org",
          targets = Seq(
            TcpTarget(
              "localhost",
              None,
              1302,
              false
            )
          )
        )
      )
    ),
    TcpService(
      enabled = true,
      tls = TlsMode.PassThrough,
      sni = SniSettings(true, false),
      clientAuth = ClientAuth.None,
      port = 1205,
      // test with
      // curl -v --resolve www.google.fr:1205:127.0.0.1 https://www.google.fr:1205/
      // curl -v --resolve www.amazon.fr:1205:127.0.0.1 https://www.amazon.fr:1205/ --compressed
      rules = Seq(
        TcpRule(
          domain = "www.google.fr",
          targets = Seq(
            TcpTarget(
              "www.google.fr",
              None,
              443,
              false
            )
          )
        ),
        TcpRule(
          domain = "www.amazon.fr",
          targets = Seq(
            TcpTarget(
              "www.amazon.fr",
              None,
              443,
              false
            )
          )
        )
      )
    )
  )
 */




© 2015 - 2025 Weber Informatics LLC | Privacy Policy