All downloads are free. Search and download functionality uses the official Maven repository.

spark.scheduler.ResultTask.scala Maven / Gradle / Ivy

The newest version!
package spark.scheduler

import spark._
import java.io._
import util.{MetadataCleaner, TimeStampedHashMap}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

/**
 * Companion helpers for [[ResultTask]]: serialization of the (RDD, function) pair of a stage
 * into a gzip-compressed byte array, with a per-stage cache of the serialized form.
 */
private[spark] object ResultTask {

  // A simple map from the stage id to the serialized byte array of a task.
  // Serves as a cache for task serialization because serialization can be
  // expensive on the master node if it needs to launch thousands of tasks.
  val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]

  // Cleaner wired to the cache's clearOldValues; presumably evicts stale entries
  // periodically -- the schedule is defined by MetadataCleaner, not here.
  val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)

  /**
   * Returns the serialized form of `rdd` and `func` for the given stage.
   *
   * The bytes are looked up in `serializedInfoCache` first; on a miss the pair is written
   * through the closure serializer into a gzip stream, cached under `stageId`, and returned.
   *
   * @param stageId key under which the serialized bytes are cached
   * @param rdd     the RDD to serialize (written first)
   * @param func    the function to serialize (written second)
   * @return the gzip-compressed serialized bytes
   */
  def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
    synchronized {
      serializedInfoCache.get(stageId) match {
        case Some(cached) =>
          cached
        case None =>
          val out = new ByteArrayOutputStream
          val ser = SparkEnv.get.closureSerializer.newInstance
          val objOut = ser.serializeStream(new GZIPOutputStream(out))
          try {
            objOut.writeObject(rdd)
            objOut.writeObject(func)
          } finally {
            // Must be closed before reading the buffer so the gzip trailer is flushed,
            // and must also be closed on failure so the deflater is released.
            objOut.close()
          }
          val bytes = out.toByteArray
          serializedInfoCache.put(stageId, bytes)
          bytes
      }
    }
  }

  /**
   * Reverses [[serializeInfo]]: reads the RDD and then the function from gzip-compressed bytes.
   *
   * @param stageId stage the bytes belong to (not used by this method)
   * @param bytes   bytes previously produced by [[serializeInfo]]
   * @return the deserialized (RDD, function) pair
   */
  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
    synchronized {
      val ser = SparkEnv.get.closureSerializer.newInstance
      val objIn = ser.deserializeStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
      try {
        // Read order must mirror the write order in serializeInfo.
        val rdd = objIn.readObject().asInstanceOf[RDD[_]]
        val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
        (rdd, func)
      } finally {
        // Release the underlying gzip stream instead of leaving it to finalization.
        objIn.close()
      }
    }
  }

  /** Removes every entry from `serializedInfoCache`. */
  def clearCache(): Unit = {
    synchronized {
      serializedInfoCache.clear()
    }
  }
}


/**
 * A task that applies `func` to the iterator over one partition of `rdd` and returns the result.
 *
 * The task is Externalizable: the RDD and function are written through
 * `ResultTask.serializeInfo`, followed by the partition index, the output id and the
 * partition object itself.
 *
 * @param stageId   id of the stage this task belongs to
 * @param rdd       RDD whose partition is computed
 * @param func      function applied to the task context and the partition's iterator
 * @param partition index into `rdd.partitions` of the partition to compute
 * @param locs      preferred locations; transient, so not available after deserialization
 * @param outputId  output id of this task; mutable so that `readExternal` can restore it
 */
private[spark] class ResultTask[T, U](
    stageId: Int,
    var rdd: RDD[T],
    var func: (TaskContext, Iterator[T]) => U,
    var partition: Int,
    @transient locs: Seq[String],
    var outputId: Int)
  extends Task[U](stageId) with Externalizable {

  /** No-arg constructor required by Externalizable; fields are filled in by `readExternal`. */
  def this() = this(0, null, null, 0, null, 0)

  // Partition object for this task; null when constructed without an RDD
  // (i.e. through the no-arg constructor) until readExternal sets it.
  var split = if (rdd == null) {
    null
  } else {
    rdd.partitions(partition)
  }

  /**
   * Runs the task: builds a TaskContext, records its metrics on the task, and applies `func`
   * to the partition's iterator. Completion callbacks are executed even if `func` throws.
   *
   * @param attemptId id of this execution attempt, passed to the TaskContext
   * @return the value produced by `func`
   */
  override def run(attemptId: Long): U = {
    val context = new TaskContext(stageId, partition, attemptId)
    metrics = Some(context.taskMetrics)
    try {
      func(context, rdd.iterator(split, context))
    } finally {
      context.executeOnCompleteCallbacks()
    }
  }

  override def preferredLocations: Seq[String] = locs

  override def toString: String = "ResultTask(" + stageId + ", " + partition + ")"

  /**
   * Writes, in order: stage id, length-prefixed serialized (rdd, func) bytes, partition index,
   * output id, and the partition object. `readExternal` reads them back in the same order.
   */
  override def writeExternal(out: ObjectOutput): Unit = {
    RDDCheckpointData.synchronized {
      // Re-resolve the partition under the RDDCheckpointData lock before writing it.
      split = rdd.partitions(partition)
      out.writeInt(stageId)
      val bytes = ResultTask.serializeInfo(
        stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
      out.writeInt(bytes.length)
      out.write(bytes)
      out.writeInt(partition)
      out.writeInt(outputId)
      out.writeObject(split)
    }
  }

  /** Restores the fields written by `writeExternal`. */
  override def readExternal(in: ObjectInput): Unit = {
    // NOTE(review): the serialized stage id is only forwarded to deserializeInfo; the
    // constructor-supplied stageId of this task is not updated here -- confirm whether
    // Task restores it elsewhere.
    val serializedStageId = in.readInt()
    val numBytes = in.readInt()
    val bytes = new Array[Byte](numBytes)
    in.readFully(bytes)
    val (rdd_, func_) = ResultTask.deserializeInfo(serializedStageId, bytes)
    rdd = rdd_.asInstanceOf[RDD[T]]
    func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
    partition = in.readInt()
    // Assign the member rather than a shadowing local so the value survives deserialization.
    outputId = in.readInt()
    split = in.readObject().asInstanceOf[Partition]
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy