All downloads are free. Search and download functionality uses the official Maven repository.

spark.rdd.CheckpointRDD.scala Maven / Gradle / Ivy

The newest version!
package spark.rdd

import spark._
import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{NullWritable, BytesWritable}
import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat

/**
 * A partition of a checkpointed RDD. It carries nothing but its index, which selects the
 * `part-NNNNN` checkpoint file the partition is read from.
 */
private[spark] class CheckpointRDDPartition(val index: Int)
  extends Partition

/**
 * This RDD represents an RDD checkpoint directory (similar to HadoopRDD).
 *
 * The directory at `checkpointPath` is expected to hold one file per partition, named by
 * `CheckpointRDD.splitIdToFile` ("part-00000", "part-00001", ...). Each partition is read
 * back with `CheckpointRDD.readFromFile`.
 *
 * @param sc             the SparkContext; its Hadoop configuration resolves the file system
 * @param checkpointPath directory containing the checkpointed partition files
 */
private[spark]
class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
  extends RDD[T](sc, Nil) {

  @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)

  /**
   * Creates one partition per `part-NNNNN` file found under `checkpointPath`.
   *
   * Only files whose *name* starts with "part-" are counted. Hidden temporary files written
   * by `CheckpointRDD.writeToFile` (".part-NNNNN-attempt-N") are therefore ignored. Every
   * file is checked against its expected index, so a missing partition is reported here
   * rather than later in `compute`.
   *
   * @throws SparkException if the directory cannot be listed or the partition files are
   *                        not exactly part-00000 .. part-(n-1)
   */
  override def getPartitions: Array[Partition] = {
    // Hadoop 1.x file systems may return null (rather than throw) for a missing path.
    val dirContents = Option(fs.listStatus(new Path(checkpointPath))).getOrElse {
      throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
    }
    val partitionFiles = dirContents
      .map(_.getPath)
      .filter(_.getName.startsWith("part-"))
      .sortBy(_.getName)
    partitionFiles.zipWithIndex.foreach { case (file, i) =>
      if (file.getName != CheckpointRDD.splitIdToFile(i)) {
        throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
      }
    }
    Array.tabulate(partitionFiles.length)(i => new CheckpointRDDPartition(i))
  }

  checkpointData = Some(new RDDCheckpointData[T](this))
  checkpointData.get.cpFile = Some(checkpointPath)

  /**
   * Returns the hosts storing the first block of this partition's checkpoint file, excluding
   * "localhost".
   */
  override def getPreferredLocations(split: Partition): Seq[String] = {
    // Look up the partition's own file; the directory itself has no meaningful block locations.
    val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
    val status = fs.getFileStatus(file)
    val locations = fs.getFileBlockLocations(status, 0, status.getLen)
    locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
  }

  /** Deserializes the records of `split` from its `part-NNNNN` checkpoint file. */
  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
    CheckpointRDD.readFromFile(file, context)
  }

  override def checkpoint() {
    // Do nothing. CheckpointRDD should not be checkpointed.
  }
}

private[spark] object CheckpointRDD extends Logging {

  /**
   * Returns the checkpoint file name for partition `splitId`: "part-" followed by the index
   * zero-padded to at least five digits (e.g. 3 becomes "part-00003").
   */
  def splitIdToFile(splitId: Int): String = {
    "part-%05d".format(splitId)
  }

  /**
   * Serializes the records of one partition into the checkpoint directory `path`. The method
   * is intended to be partially applied and passed to `runJob`.
   *
   * Records are first written to a hidden temporary file
   * (".part-NNNNN-attempt-<attemptId>"), which is then renamed to "part-NNNNN". If the rename
   * fails but the final file already exists, another attempt of the same task is assumed to
   * have finished first. In that case this attempt's temporary file is deleted.
   *
   * @param path      checkpoint directory to write into
   * @param blockSize block size for the output file; a negative value (the default) uses the
   *                  file system's default. Mainly for testing.
   * @param ctx       context of the running task; its splitId selects the output file name
   * @param iterator  the partition's records
   * @throws IOException if the temporary file already exists, or if the rename fails and the
   *                     final output file does not exist
   */
  def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
    val outputDir = new Path(path)
    val fs = outputDir.getFileSystem(new Configuration())

    val finalOutputName = splitIdToFile(ctx.splitId)
    val finalOutputPath = new Path(outputDir, finalOutputName)
    val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)

    if (fs.exists(tempOutputPath)) {
      throw new IOException("Checkpoint failed: temporary path " +
        tempOutputPath + " already exists")
    }
    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt

    val fileOutputStream = if (blockSize < 0) {
      fs.create(tempOutputPath, false, bufferSize)
    } else {
      // This is mainly for testing purpose
      fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
    }
    val serializer = SparkEnv.get.serializer.newInstance()
    val serializeStream = serializer.serializeStream(fileOutputStream)
    // Close the stream even when serialization fails, so the underlying file stream is not leaked.
    try {
      serializeStream.writeAll(iterator)
    } finally {
      serializeStream.close()
    }

    if (!fs.rename(tempOutputPath, finalOutputPath)) {
      if (!fs.exists(finalOutputPath)) {
        fs.delete(tempOutputPath, false)
        throw new IOException("Checkpoint failed: failed to save output of task: "
          + ctx.attemptId + " and final output path does not exist")
      } else {
        // Some other copy of this task must've finished before us and renamed it
        logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
        fs.delete(tempOutputPath, false)
      }
    }
  }

  /**
   * Opens the checkpoint file at `path` and returns an iterator over its deserialized records.
   *
   * The deserialization stream is closed by a callback registered on `context` when the task
   * completes. The iterator should therefore be consumed within that task.
   */
  def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
    val fs = path.getFileSystem(new Configuration())
    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
    val fileInputStream = fs.open(path, bufferSize)
    val serializer = SparkEnv.get.serializer.newInstance()
    val deserializeStream = serializer.deserializeStream(fileInputStream)

    // Register an on-task-completion callback to close the input stream.
    context.addOnCompleteCallback(() => deserializeStream.close())

    deserializeStream.asIterator.asInstanceOf[Iterator[T]]
  }

  /**
   * Manual test: checks that CheckpointRDD produces the expected number of partitions even
   * when each partition file has multiple blocks. This needs to be run on a cluster (mesos or
   * standalone) using HDFS.
   *
   * Usage: CheckpointRDD <cluster> <hdfsPath>
   */
  def main(args: Array[String]) {
    if (args.length != 2) {
      System.err.println("Usage: CheckpointRDD <cluster> <hdfsPath>")
      System.exit(1)
    }
    val Array(cluster, hdfsPath) = args
    val sc = new SparkContext(cluster, "CheckpointRDD Test")
    val path = new Path(hdfsPath, "temp")
    val fs = path.getFileSystem(new Configuration())
    try {
      val rdd = sc.makeRDD(1 to 10, 10).flatMap(_ => 1 to 10000)
      // A tiny (1024-byte) block size is requested so each partition file spans several blocks.
      sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
      val cpRDD = new CheckpointRDD[Int](sc, path.toString)
      assert(cpRDD.partitions.length == rdd.partitions.length,
        "Number of partitions is not the same")
      assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
    } finally {
      // Remove the checkpoint directory recursively and release the context, even on failure.
      fs.delete(path, true)
      sc.stop()
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy