org.apache.spark.BarrierTaskContext.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.util.{Properties, Timer, TimerTask}
import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.{Failure, Success => ScalaSuccess, Try}
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._
/**
* :: Experimental ::
* A [[TaskContext]] with extra contextual info and tooling for tasks in a barrier stage.
* Use [[BarrierTaskContext#get]] to obtain the barrier context for a running barrier task.
*/
@Experimental
@Since("2.4.0")
class BarrierTaskContext private[spark] (
taskContext: TaskContext) extends TaskContext with Logging {
import BarrierTaskContext._
// Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls.
private val barrierCoordinator: RpcEndpointRef = {
val env = SparkEnv.get
RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv)
}
// Local barrierEpoch that identify a barrier() call from current task, it shall be identical
// with the driver side epoch.
private var barrierEpoch = 0
// Number of tasks of the current barrier stage, a barrier() call must collect enough requests
// from different tasks within the same barrier stage attempt to succeed.
private lazy val numTasks = getTaskInfos().size
private def runBarrier(message: String, requestMethod: RequestMethod.Value): Array[String] = {
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
logTrace("Current callSite: " + Utils.getCallSite())
val startTime = System.currentTimeMillis()
val timerTask = new TimerTask {
override def run(): Unit = {
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
s"under the global sync since $startTime, has been waiting for " +
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
s"current barrier epoch is $barrierEpoch.")
}
}
// Log the update of global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)
try {
val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]](
message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch, partitionId, message, requestMethod),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
timeout = new RpcTimeout(365.days, "barrierTimeout"))
// Wait the RPC future to be completed, but every 1 second it will jump out waiting
// and check whether current spark task is killed. If killed, then throw
// a `TaskKilledException`, otherwise continue wait RPC until it completes.
while (!abortableRpcFuture.future.isCompleted) {
try {
// wait RPC future for at most 1 second
Thread.sleep(1000)
} catch {
case _: InterruptedException => // task is killed by driver
} finally {
Try(taskContext.killTaskIfInterrupted()) match {
case ScalaSuccess(_) => // task is still running healthily
case Failure(e) => abortableRpcFuture.abort(e)
}
}
}
// messages which consist of all barrier tasks' messages. The future will return the
// desired messages if it is completed successfully. Otherwise, exception could be thrown.
val messages = abortableRpcFuture.future.value.get.get
barrierEpoch += 1
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
"global sync successfully, waited for " +
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
s"current barrier epoch is $barrierEpoch.")
messages
} catch {
case e: SparkException =>
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
"to perform global sync, waited for " +
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
s"current barrier epoch is $barrierEpoch.")
throw e
} finally {
timerTask.cancel()
timer.purge()
}
}
/**
* :: Experimental ::
* Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
* MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same
* stage have reached this routine.
*
* CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
* possible code branches. Otherwise, you may get the job hanging or a SparkException after
* timeout. Some examples of '''misuses''' are listed below:
* 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
* shall lead to timeout of the function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* if (context.partitionId() == 0) {
* // Do nothing.
* } else {
* context.barrier()
* }
* iter
* }
* }}}
*
* 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
* second function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* try {
* // Do something that might throw an Exception.
* doSomething()
* context.barrier()
* } catch {
* case e: Exception => logWarning("...", e)
* }
* context.barrier()
* iter
* }
* }}}
*/
@Experimental
@Since("2.4.0")
def barrier(): Unit = runBarrier("", RequestMethod.BARRIER)
/**
* :: Experimental ::
* Blocks until all tasks in the same stage have reached this routine. Each task passes in
* a message and returns with a list of all the messages passed in by each of those tasks.
*
* CAUTION! The allGather method requires the same precautions as the barrier method
*
* The message is type String rather than Array[Byte] because it is more convenient for
* the user at the cost of worse performance.
*/
@Experimental
@Since("3.0.0")
def allGather(message: String): Array[String] = runBarrier(message, RequestMethod.ALL_GATHER)
/**
* :: Experimental ::
* Returns [[BarrierTaskInfo]] for all tasks in this barrier stage, ordered by partition ID.
*/
@Experimental
@Since("2.4.0")
def getTaskInfos(): Array[BarrierTaskInfo] = {
val addressesStr = Option(taskContext.getLocalProperty("addresses")).getOrElse("")
addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_))
}
// delegate methods
override def isCompleted(): Boolean = taskContext.isCompleted()
override def isInterrupted(): Boolean = taskContext.isInterrupted()
override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
taskContext.addTaskCompletionListener(listener)
this
}
override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
taskContext.addTaskFailureListener(listener)
this
}
override def stageId(): Int = taskContext.stageId()
override def stageAttemptNumber(): Int = taskContext.stageAttemptNumber()
override def partitionId(): Int = taskContext.partitionId()
override def attemptNumber(): Int = taskContext.attemptNumber()
override def taskAttemptId(): Long = taskContext.taskAttemptId()
override def getLocalProperty(key: String): String = taskContext.getLocalProperty(key)
override def taskMetrics(): TaskMetrics = taskContext.taskMetrics()
override def getMetricsSources(sourceName: String): Seq[Source] = {
taskContext.getMetricsSources(sourceName)
}
override def resources(): Map[String, ResourceInformation] = taskContext.resources()
override def resourcesJMap(): java.util.Map[String, ResourceInformation] = {
resources().asJava
}
override private[spark] def killTaskIfInterrupted(): Unit = taskContext.killTaskIfInterrupted()
override private[spark] def getKillReason(): Option[String] = taskContext.getKillReason()
override private[spark] def taskMemoryManager(): TaskMemoryManager = {
taskContext.taskMemoryManager()
}
override private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {
taskContext.registerAccumulator(a)
}
override private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
taskContext.setFetchFailed(fetchFailed)
}
override private[spark] def markInterrupted(reason: String): Unit = {
taskContext.markInterrupted(reason)
}
override private[spark] def markTaskFailed(error: Throwable): Unit = {
taskContext.markTaskFailed(error)
}
override private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = {
taskContext.markTaskCompleted(error)
}
override private[spark] def fetchFailed: Option[FetchFailedException] = {
taskContext.fetchFailed
}
override private[spark] def getLocalProperties: Properties = taskContext.getLocalProperties
}
@Experimental
@Since("2.4.0")
object BarrierTaskContext {
/**
* :: Experimental ::
* Returns the currently active BarrierTaskContext. This can be called inside of user functions to
* access contextual information about running barrier tasks.
*/
@Experimental
@Since("2.4.0")
def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext]
private val timer = new Timer("Barrier task timer for barrier() calls.")
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy