All downloads are free. Search and download functionality uses the official Maven repository.

tech.ytsaurus.spyt.wrapper.TcpProxyService.scala Maven / Gradle / Ivy

The newest version!
package tech.ytsaurus.spyt.wrapper

import org.slf4j.{Logger, LoggerFactory}
import tech.ytsaurus.client.CompoundClient
import tech.ytsaurus.client.request.CreateNode
import tech.ytsaurus.core.cypress.{CypressNodeType, YPath}
import tech.ytsaurus.spyt.HostAndPort
import tech.ytsaurus.spyt.wrapper.discovery.Address
import tech.ytsaurus.ysontree.{YTree, YTreeNode}

import scala.annotation.tailrec
import scala.concurrent.duration.DurationInt
import scala.language.postfixOps


/**
 * Reserves ports on the YTsaurus TCP proxy and routes them to internal addresses.
 *
 * For every internal address, a map node `<DEFAULT_ROUTES>/<port>` is created in Cypress. The node
 * carries an `endpoints` attribute holding that address and an `expiration_timeout` attribute.
 * Candidate ports are tried in the half-open range `[startPort, endPort)`.
 *
 * @param isEnabled when `false`, [[register]] does nothing and returns `None`
 * @param startPort first candidate port (inclusive)
 * @param endPort   end of the candidate range (exclusive)
 */
class TcpProxyService(isEnabled: Boolean, startPort: Int, endPort: Int) {
  import TcpProxyService._

  /** Delay between two consecutive checks of the reserved port nodes. */
  private val pingInterval = 30.seconds

  /**
   * Reads the first entry of the `external_addresses` attribute of the default routes node.
   *
   * @return the external proxy address, or `None` if the service is disabled, the attribute list
   *         is empty, or reading it failed (the failure is logged as a warning)
   */
  private def proxyAddress(implicit yt: CompoundClient): Option[String] = {
    if (isEnabled) {
      try {
        val externalAddresses = YtWrapper.attribute(DEFAULT_ROUTES, "external_addresses", None).asList()
        if (externalAddresses.isEmpty) None else Some(externalAddresses.get(0).stringValue())
      } catch {
        case e: Exception =>
          log.warn(f"Error while get external addresses request", e)
          None
      }
    } else {
      None
    }
  }

  /** A port counts as busy when its route node already exists in Cypress. */
  private def isPortBusy(port: Int)(implicit yt: CompoundClient): Boolean = YtWrapper.exists(portYPath(port))

  /**
   * Creates the route map node for `port`, with `expiration_timeout` and `endpoints = [address]`.
   * Blocks until the request completes. A failed request surfaces as an exception from `join()`.
   * NOTE(review): presumably this includes the case where a concurrent caller created the node
   * first; confirm against the CreateNode defaults.
   */
  private def createPortNode(address: String, port: Int)(implicit yt: CompoundClient): Unit = {
    val attributes = java.util.Map.of(
      "expiration_timeout", YTree.longNode(EXPIRATION_TIMEOUT),
      "endpoints", buildEndpointsNode(address)
    )
    val createRequest = CreateNode.builder()
      .setPath(portYPath(port)).setType(CypressNodeType.MAP).setAttributes(attributes)
      .build()
    yt.createNode(createRequest).join()
  }

  /**
   * Tries to reserve `port` for `address`.
   *
   * @return `true` if the route node was created; `false` if the port is already taken or the
   *         creation request failed (the failure is logged)
   */
  private def tryTakePort(address: String, port: Int)(implicit yt: CompoundClient): Boolean = {
    if (!isPortBusy(port)) {
      try {
        createPortNode(address, port)
        true
      } catch {
        case e: Exception =>
          log.warn(f"Error while creating port $port map node request", e)
          false
      }
    } else {
      false
    }
  }

  /**
   * Scans ports upward from `current` and reserves the first free one.
   *
   * @throws IllegalStateException if no port in `[current, endPort)` could be reserved
   */
  @tailrec
  private def takeFreePortIterative(address: String, current: Int)(implicit yt: CompoundClient): Int = {
    if (current >= endPort) {
      throw new IllegalStateException("No free ports found")
    }
    if (tryTakePort(address, current)) {
      current
    } else {
      log.debug(f"Port $current is busy")
      takeFreePortIterative(address, current + 1)
    }
  }

  /** Reserves the first free port in `[startPort, endPort)` for `address`. */
  private def takeFreePort(address: String)(implicit yt: CompoundClient): Int = {
    log.debug(f"Search free port for address $address")
    takeFreePortIterative(address, startPort)
  }

  /** Reserves one port per address and returns the `address -> port` mapping. */
  private def takeFreePorts(addresses: Seq[String])(implicit yt: CompoundClient): Map[String, Int] = {
    addresses.map(x => x -> takeFreePort(x)).toMap
  }

  /**
   * Checks that the route node for `port` still exists and logs a warning if it has disappeared.
   * Errors are logged and swallowed so that the ping loop keeps running.
   * NOTE(review): whether this read also refreshes the node's `expiration_timeout` depends on
   * YT access tracking. Verify this.
   */
  private def pingPortNode(address: String, port: Int)(implicit yt: CompoundClient): Unit = {
    try {
      if (!isPortBusy(port)) log.warn(f"Reserved port $port for address $address is free now")
    } catch {
      case e: Exception => log.warn(f"Error while ping port $port map node", e)
    }
  }

  /**
   * Starts a daemon thread that checks all reserved port nodes of `tcpRouter` every
   * [[pingInterval]]. The thread stops quietly when it is interrupted.
   */
  private def startPingThread(tcpRouter: TcpRouter)(implicit yt: CompoundClient): Unit = {
    val thread = new Thread(() => {
      try {
        while (!Thread.currentThread().isInterrupted) {
          Thread.sleep(pingInterval.toMillis)
          log.debug(f"Ping proxy port nodes")
          tcpRouter.mapping.foreach { case (address, port) => pingPortNode(address, port) }
        }
      } catch {
        // Interruption is the only way to stop this loop; exit instead of dying with a stack trace.
        case _: InterruptedException =>
          log.info("TcpProxyService ping thread interrupted, stopping")
      }
    })
    thread.setName("tcp-proxy-ping")
    thread.setDaemon(true)
    thread.start()
  }

  /**
   * Reserves a proxy port for each of `addresses` and starts the background ping thread.
   *
   * @return the router describing the reservations, or `None` when the service is disabled
   * @throws NoSuchElementException if the service is enabled but no external proxy address is available
   * @throws IllegalStateException  if the port range is exhausted
   */
  def register(addresses: String*)(implicit yt: CompoundClient): Option[TcpRouter] = {
    if (isEnabled) {
      val externalAddress = proxyAddress.getOrElse(throw new NoSuchElementException(
        s"TCP proxy is enabled but no external address could be read from $DEFAULT_ROUTES"))
      // Deduplicate: otherwise a duplicate would reserve a port that `toMap` silently drops,
      // leaving an untracked, never-pinged route node behind.
      val addressesWithPorts = takeFreePorts(addresses.distinct)
      log.info(f"External address: $externalAddress. Ports for given addresses $addressesWithPorts")
      val tcpRouter = TcpRouter(externalAddress, addressesWithPorts)
      startPingThread(tcpRouter)
      log.info("TcpProxyService started")
      Some(tcpRouter)
    } else {
      None
    }
  }

  /** Registers the main, web UI and REST endpoints of `address`. */
  def register(address: Address)(implicit yt: CompoundClient): Option[TcpRouter] = {
    register(address.hostAndPort.toString, address.webUiHostAndPort.toString, address.restHostAndPort.toString)
  }
}

object TcpProxyService {
  private val log: Logger = LoggerFactory.getLogger(getClass)

  /** First candidate port used when `SPARK_YT_TCP_PROXY_RANGE_START` is not set. */
  private val DEFAULT_RANGE_START: Int = 30000

  /** Range size used when `SPARK_YT_TCP_PROXY_RANGE_SIZE` is not set. */
  private val DEFAULT_RANGE_SIZE: Int = 1000

  /**
   * Builds a service configured from environment variables:
   *  - `SPARK_YT_TCP_PROXY_ENABLED` (boolean; defaults to `false`)
   *  - `SPARK_YT_TCP_PROXY_RANGE_START` (int; defaults to 30000)
   *  - `SPARK_YT_TCP_PROXY_RANGE_SIZE` (int; defaults to 1000)
   *
   * @throws IllegalArgumentException if a variable is set to a value that cannot be parsed
   */
  def apply(): TcpProxyService = {
    val isEnabled = envValue("SPARK_YT_TCP_PROXY_ENABLED")(_.toBoolean).getOrElse(false)
    val startPort = envValue("SPARK_YT_TCP_PROXY_RANGE_START")(_.toInt).getOrElse(DEFAULT_RANGE_START)
    val portRange = envValue("SPARK_YT_TCP_PROXY_RANGE_SIZE")(_.toInt).getOrElse(DEFAULT_RANGE_SIZE)
    apply(isEnabled, startPort, portRange)
  }

  /**
   * Builds a service that reserves ports in `[startPort, startPort + portRange)`.
   *
   * @param portRange number of candidate ports
   */
  def apply(isEnabled: Boolean, startPort: Int, portRange: Int): TcpProxyService = {
    val endPort: Int = startPort + portRange
    new TcpProxyService(isEnabled, startPort, endPort)
  }

  /**
   * Reads and parses an environment variable.
   *
   * @return `None` when the variable is unset, otherwise the parsed value of its trimmed content
   * @throws IllegalArgumentException naming the variable when `parse` rejects the value
   */
  private def envValue[T](name: String)(parse: String => T): Option[T] =
    sys.env.get(name).map { raw =>
      try parse(raw.trim)
      catch {
        // NumberFormatException (from toInt) is a subclass of IllegalArgumentException (from toBoolean).
        case e: IllegalArgumentException =>
          throw new IllegalArgumentException(s"Invalid value '$raw' of environment variable $name", e)
      }
    }

  /** Value of the `expiration_timeout` attribute set on created port nodes: 10 minutes, in milliseconds. */
  private val EXPIRATION_TIMEOUT: Long = (10 minutes).toMillis

  /** Cypress directory holding one map node per reserved proxy port. */
  private val DEFAULT_ROUTES: YPath = YPath.simple("//sys/tcp_proxies/routes/default")

  /** Path of the route node for `port`. */
  private def portYPath(port: Int): YPath = DEFAULT_ROUTES.child(port.toString)

  /** Builds the `endpoints` attribute value: a single-element list containing `address`. */
  private def buildEndpointsNode(address: String): YTreeNode = YTree.listBuilder().value(address).endList().build()

  /**
   * Points the already reserved `port` at a new internal `address` by rewriting its `endpoints`
   * attribute. This is best effort: failures are logged as warnings and not rethrown.
   */
  def updateTcpAddress(address: String, port: Int)(implicit yt: CompoundClient): Unit = {
    log.info(f"Update address $address request for port $port")
    try {
      YtWrapper.setAttribute(portYPath(port).toString, "endpoints", buildEndpointsNode(address))
    } catch {
      case e: Exception => log.warn(f"Error while updating port $port map node to address $address", e)
    }
  }

  /**
   * Result of a registration.
   *
   * @param externalAddress host of the TCP proxy that clients should connect to
   * @param mapping         internal address -> reserved proxy port
   */
  case class TcpRouter(externalAddress: String, mapping: Map[String, Int]) {
    /** External host and port for `internalName`; throws `NoSuchElementException` if it was not registered. */
    def getExternalAddress(internalName: String): HostAndPort = HostAndPort(externalAddress, getPort(internalName))

    /** Reserved port for `internalName`; throws `NoSuchElementException` if it was not registered. */
    def getPort(internalName: String): Int = mapping(internalName)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy