All downloads are free. Search and download functionality uses the official Maven repository.

tech.ytsaurus.spyt.submit.SubmissionClient.scala Maven / Gradle / Ivy

The newest version!
package tech.ytsaurus.spyt.submit

import io.netty.channel.DefaultEventLoopGroup
import org.apache.commons.io.FileUtils
import org.apache.spark.deploy.rest._
import org.apache.spark.launcher.InProcessLauncher
import org.slf4j.LoggerFactory
import tech.ytsaurus.spyt.wrapper.client.{YtClientConfiguration, YtClientProvider}
import tech.ytsaurus.spyt.wrapper.config.Utils.{parseRemoteConfig, remoteClusterConfigPath, remoteGlobalConfigPath, remoteVersionConfigPath}
import tech.ytsaurus.spyt.wrapper.discovery.CypressDiscoveryService
import tech.ytsaurus.core.cypress.YPath

import java.io.File
import java.util.UUID
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.{Executors, ThreadFactory}
import scala.annotation.tailrec
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.{Failure, Success, Try}

/**
 * Client that submits Spark applications through an [[InProcessLauncher]] and queries the
 * state of submitted drivers and applications.
 *
 * The cluster description is resolved once at construction time via `SparkCluster.get` and is
 * re-resolved in the background (on a single daemon thread) whenever a request to the cluster
 * fails.
 *
 * @param proxy             proxy address passed to `SparkCluster.get`, the YT client configuration
 *                          and `spark.hadoop.yt.proxy`
 * @param baseDiscoveryPath discovery path used to locate the cluster, its remote config and
 *                          the event log path
 * @param user              user name passed to the YT client and `spark.hadoop.yt.user`
 * @param token             token passed to the YT client and `spark.hadoop.yt.token`
 */
class SubmissionClient(proxy: String,
                       baseDiscoveryPath: String,
                       user: String,
                       token: String) {
  private val log = LoggerFactory.getLogger(getClass)

  private val eventLogPath = CypressDiscoveryService.eventLogPath(baseDiscoveryPath)

  // Latest known cluster description; replaced by updateCluster().
  private val cluster = new AtomicReference[SparkCluster](SparkCluster.get(proxy, baseDiscoveryPath, user, token))

  // Daemon threads so that the background refresh loop never keeps the JVM alive.
  private val threadFactory = new ThreadFactory() {
    override def newThread(runnable: Runnable): Thread = {
      val thread = Executors.defaultThreadFactory().newThread(runnable)
      thread.setDaemon(true)
      thread
    }
  }
  private val loop = new DefaultEventLoopGroup(1, threadFactory).next()

  // JVM flag added to driver/executor options when IPv6 preference is enabled in remote config.
  private val preferIpv6JavaOption = "-Djava.net.preferIPv6Addresses=true"

  /** Creates a fresh launcher that can be configured by the caller and passed to `submit`. */
  def newLauncher(): InProcessLauncher = new InProcessLauncher()

  // for Java
  /** Submits the application using the default [[RetryConfig]]. */
  def submit(launcher: InProcessLauncher): Try[String] = {
    submit(launcher, RetryConfig())
  }

  /** Copies every entry of `config` into the launcher's Spark configuration. */
  private def addConf(launcher: InProcessLauncher, config: Map[String, String]): Unit = {
    config.foreach { case (key, value) => launcher.setConf(key, value) }
  }

  /**
   * Sets `key` on the launcher to the value from `remoteConfig` with `option` appended.
   *
   * The remote value is preserved instead of being overwritten; `option` is not duplicated
   * when the remote value already contains it. Values set on the launcher by the caller
   * cannot be read back here, so only the remote config is merged.
   */
  private def appendJavaOption(launcher: InProcessLauncher,
                               remoteConfig: Map[String, String],
                               key: String,
                               option: String): Unit = {
    val merged = remoteConfig.get(key).map(_.trim).filter(_.nonEmpty) match {
      case Some(existing) if existing.split("\\s+").contains(option) => existing
      case Some(existing) => s"$existing $option"
      case None => option
    }
    launcher.setConf(key, merged)
  }

  /**
   * Configures the launcher for cluster-mode submission and submits the application.
   *
   * Remote configs are merged in the order global, version, cluster (later entries win)
   * and applied to the launcher before the client-specific settings below.
   *
   * @param launcher    launcher to configure and start
   * @param retryConfig retry policy used by the submission loop
   * @return the submission id on success, or a `Failure` describing why submission failed
   */
  def submit(launcher: InProcessLauncher,
             retryConfig: RetryConfig): Try[String] = {
    val yt = YtClientProvider.ytClient(YtClientConfiguration.default(proxy, user, token))
    val remoteGlobalConfig = parseRemoteConfig(remoteGlobalConfigPath, yt)
    val remoteVersionConfig = parseRemoteConfig(remoteVersionConfigPath(cluster.get().version), yt)
    val remoteClusterConfig = parseRemoteConfig(remoteClusterConfigPath(baseDiscoveryPath), yt)
    val remoteConfig = remoteGlobalConfig ++ remoteVersionConfig ++ remoteClusterConfig

    launcher.setDeployMode("cluster")

    val ipv6Enabled = remoteConfig.get("spark.hadoop.yt.preferenceIpv6.enabled").exists(_.toBoolean)

    addConf(launcher, remoteConfig)

    if (ipv6Enabled) {
      log.debug("preferIPv6Addresses was added to extraJavaOptions")
      // Append rather than overwrite, so extraJavaOptions from the remote config survive.
      appendJavaOption(launcher, remoteConfig, "spark.driver.extraJavaOptions", preferIpv6JavaOption)
      appendJavaOption(launcher, remoteConfig, "spark.executor.extraJavaOptions", preferIpv6JavaOption)
    }

    launcher.setConf("spark.master.rest.enabled", "true")
    launcher.setConf("spark.master.rest.failover", "false")
    launcher.setConf("spark.rest.client.awaitTermination.enabled", "false")

    launcher.setConf("spark.hadoop.yt.proxy", proxy)
    launcher.setConf("spark.hadoop.yt.user", user)
    launcher.setConf("spark.hadoop.yt.token", token)

    launcher.setConf("spark.eventLog.dir", "ytEventLog:/" + eventLogPath)

    if (useDedicatedDriverOp) {
      launcher.setConf("spark.driver.resource.driverop.amount", "1")
    }

    submitInner(launcher, retryConfig)
  }

  /** Returns the driver state for the submission, or `UNDEFINED` when the request fails. */
  def getStatus(id: String): DriverState = {
    tryToGetStatus(id).getOrElse(DriverState.UNDEFINED)
  }

  /** Returns the application state for the driver, or `UNDEFINED` when the request fails. */
  def getApplicationStatus(driverId: String): ApplicationState = {
    tryToGetApplicationStatus(driverId).getOrElse(ApplicationState.UNDEFINED)
  }

  /** Same as `getStatus`, returning the name of the state. */
  def getStringStatus(id: String): String = {
    getStatus(id).name()
  }

  /** Same as `getApplicationStatus`, returning the name of the state. */
  def getStringApplicationStatus(driverId: String): String = {
    getApplicationStatus(driverId).name()
  }

  /**
   * Passes `result` through; on failure logs `description` together with the cause and
   * schedules a background refresh of the cluster description.
   */
  private def refreshClusterOnFailure[T](description: => String)(result: Try[T]): Try[T] = {
    result.failed.foreach { cause =>
      log.warn(description, cause)
      forceClusterUpdate()
    }
    result
  }

  /** Resolves the application id of the driver and then requests that application's status. */
  private def tryToGetApplicationStatus(id: String): Try[ApplicationState] = {
    refreshClusterOnFailure(s"Failed to get status of application $id") {
      Try {
        val responseAppId = RestSubmissionClientWrapper.requestApplicationId(cluster.get().client, id)
        val responseAppStatus = RestSubmissionClientWrapper.requestApplicationStatus(cluster.get().client, responseAppId.appId)
        if (responseAppStatus.success) {
          ApplicationState.valueOf(responseAppStatus.appState)
        } else {
          ApplicationState.UNKNOWN
        }
      }
    }
  }

  /** Requests the submission status of the driver from the cluster's REST client. */
  private def tryToGetStatus(id: String): Try[DriverState] = {
    refreshClusterOnFailure(s"Failed to get status of submission $id") {
      Try {
        val response = RestSubmissionClientWrapper.requestSubmissionStatus(cluster.get().client, id)
        if (response.success) {
          DriverState.valueOf(response.driverState)
        } else {
          DriverState.UNKNOWN // master restarted and can't find this submission id
        }
      }
    }
  }

  /** Ids of drivers whose state is not final. */
  def getActiveDrivers: Seq[String] = {
    getActiveDriverInfos.map(_.driverId)
  }

  /** Infos of drivers whose state is not final. */
  def getActiveDriverInfos: Seq[DriverInfo] = {
    getAllDriverInfos.filter(d => !DriverState.valueOf(d.status).isFinal)
  }

  /** Ids of drivers whose state is final. */
  def getCompletedDrivers: Seq[String] = {
    getCompletedDriverInfos.map(_.driverId)
  }

  /** Infos of drivers whose state is final. */
  def getCompletedDriverInfos: Seq[DriverInfo] = {
    getAllDriverInfos.filter(d => DriverState.valueOf(d.status).isFinal)
  }

  /** Ids of all drivers reported by the master. */
  def getAllDrivers: Seq[String] = {
    getAllDriverInfos.map(_.driverId)
  }

  /** All drivers reported by the master; an empty list (with a warning) when the query fails. */
  def getAllDriverInfos: Seq[DriverInfo] = {
    val response = MasterClient.allDrivers(cluster.get().masterHostAndPort)
    // Log the cause itself so the stack trace is not lost.
    response.failed.foreach(cause => log.warn("Failed to get list of drivers", cause))
    response.getOrElse(Nil)
  }

  /**
   * Whether any active worker reported by the master is a driver-op worker.
   * Returns `false` (with a warning) when the workers cannot be queried.
   */
  def useDedicatedDriverOp: Boolean =
    MasterClient.activeWorkers(cluster.get().masterHostAndPort) match {
      case Failure(err) =>
        log.warn(s"Failed to query list of active workers", err)
        false
      case Success(workers) =>
        workers.exists(_.isDriverOp)
    }

  /**
   * Requests the cluster to kill the submission.
   *
   * @return the `success` flag of the kill response
   */
  def kill(id: String): Boolean = {
    val response = RestSubmissionClientWrapper.killSubmission(cluster.get().client, id)
    response.success.booleanValue()
  }

  /**
   * Pair of temporary files through which the result of a submission is reported:
   * `id` is read as the submission id, `error` as the error message.
   */
  case class SubmissionFiles(id: File, error: File) {
    /** Removes whichever of the two files exist. */
    def delete(): Unit = {
      if (id.exists()) id.delete()
      if (error.exists()) error.delete()
    }
  }

  /** Re-resolves the cluster description and publishes it for subsequent requests. */
  private def updateCluster(): Unit = {
    log.debug(s"Update cluster addresses from $baseDiscoveryPath")
    cluster.set(SparkCluster.get(proxy, baseDiscoveryPath, user, token))
  }

  /** Schedules `updateCluster()` on the background loop without blocking the caller. */
  private def forceClusterUpdate(): Unit = {
    loop.submit(new Runnable {
      override def run(): Unit = updateCluster()
    })
  }

  /** Points the launcher at the currently known master, REST endpoint and cluster version. */
  private def configureCluster(launcher: InProcessLauncher): Unit = {
    val clusterValue = cluster.get()
    launcher.setMaster(clusterValue.master)
    launcher.setConf("spark.rest.master", clusterValue.masterRest)
    launcher.setConf("spark.yt.cluster.version", clusterValue.version)
  }

  /**
   * Chooses uniquely named id/error files in the temp directory and registers their paths
   * in the launcher configuration. The files themselves are not created here.
   */
  private def prepareSubmissionFiles(launcher: InProcessLauncher): SubmissionFiles = {
    val fileName = s"spark-submission-${UUID.randomUUID()}"
    val idFile = new File(FileUtils.getTempDirectory, s"$fileName-id")
    val errorFile = new File(FileUtils.getTempDirectory, s"$fileName-error")
    launcher.setConf("spark.rest.client.submissionIdFile", idFile.getAbsolutePath)
    launcher.setConf("spark.rest.client.submissionErrorFile", errorFile.getAbsolutePath)
    SubmissionFiles(idFile, errorFile)
  }

  /**
   * Starts the application and waits for its submission id, retrying according to
   * `retryConfig`. Every failed attempt schedules a cluster refresh, so a later attempt
   * can pick up new master addresses.
   *
   * @param retry 1-based number of the current attempt
   */
  @tailrec
  private def submitInner(launcher: InProcessLauncher,
                          retryConfig: RetryConfig = RetryConfig(),
                          retry: Int = 1): Try[String] = {
    val submissionFiles = prepareSubmissionFiles(launcher)
    val submissionId = Try {
      configureCluster(launcher)
      launcher.startApplication()
      getSubmissionId(submissionFiles, retryConfig.waitSubmissionIdRetryLimit)
    }
    // The result has been read (or the attempt failed), so the files are no longer needed.
    submissionFiles.delete()

    submissionId match {
      case Failure(e) if !retryConfig.enableRetry =>
        forceClusterUpdate()
        Failure(new RuntimeException("Failed to submit job and retry is disabled", e))
      case Failure(e) if retry >= retryConfig.retryLimit =>
        forceClusterUpdate()
        Failure(new RuntimeException(s"Failed to submit job and retry limit ${retryConfig.retryLimit} exceeded", e))
      case Failure(e) =>
        forceClusterUpdate()
        // Pass the exception itself: its message may be null and the stack trace is useful.
        log.warn("Failed to submit job and retry is enabled", e)
        log.info(s"Retry to submit job in ${retryConfig.retryInterval.toCoarsest}")
        Thread.sleep(retryConfig.retryInterval.toMillis)
        submitInner(launcher, retryConfig, retry + 1)
      case success => success
    }
  }

  /**
   * Polls every 5 seconds until either the id file or the error file exists.
   *
   * @param retryLimit number of additional polls allowed before giving up
   * @throws RuntimeException when neither file appears within the allowed polls
   */
  @tailrec
  private def waitSubmissionResultFile(submissionFiles: SubmissionFiles, retryLimit: Int): Unit = {
    if (!submissionFiles.id.exists() && !submissionFiles.error.exists()) {
      if (retryLimit <= 0) {
        throw new RuntimeException(s"Files with submission result were not created")
      } else {
        log.warn(s"Waiting for submission id in file: ${submissionFiles.id}")
        Thread.sleep(5.seconds.toMillis)
        waitSubmissionResultFile(submissionFiles, retryLimit - 1)
      }
    }
  }

  /**
   * Waits for the submission result and returns the content of the id file.
   *
   * @throws RuntimeException when the error file exists (its content is included in the
   *                          message) or when no result file appears in time
   */
  private def getSubmissionId(submissionFiles: SubmissionFiles, waitSubmissionIdRetryLimit: Int): String = {
    waitSubmissionResultFile(submissionFiles, waitSubmissionIdRetryLimit)
    // Explicit default charset: same behaviour as the deprecated single-argument overload.
    val charset = java.nio.charset.Charset.defaultCharset()
    if (submissionFiles.error.exists()) {
      val message = FileUtils.readFileToString(submissionFiles.error, charset)
      throw new RuntimeException(s"Spark submission finished with error: $message")
    }
    FileUtils.readFileToString(submissionFiles.id, charset)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy