All Downloads are FREE. Search and download functionalities are using the official Maven repository.

spark.api.python.PythonRDD.scala Maven / Gradle / Ivy

package spark.api.python

import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._

import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
import spark._
import spark.rdd.PipedRDD


private[spark] class PythonRDD[T: ClassManifest](
    parent: RDD[T],
    command: Seq[String],
    envVars: JMap[String, String],
    preservePartitoning: Boolean,
    pythonExec: String,
    broadcastVars: JList[Broadcast[Array[Byte]]],
    accumulator: Accumulator[JList[Array[Byte]]])
  extends RDD[Array[Byte]](parent) {

  val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt

  // Similar to Runtime.exec(), if we are given a single string, split it into words
  // using a standard StringTokenizer (i.e. by spaces)
  def this(parent: RDD[T], command: String, envVars: JMap[String, String],
      preservePartitoning: Boolean, pythonExec: String,
      broadcastVars: JList[Broadcast[Array[Byte]]],
      accumulator: Accumulator[JList[Array[Byte]]]) =
    this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
      broadcastVars, accumulator)

  override def getPartitions = parent.partitions

  override val partitioner = if (preservePartitoning) parent.partitioner else None


  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
    val startTime = System.currentTimeMillis
    val env = SparkEnv.get
    val worker = env.createPythonWorker(pythonExec, envVars.toMap)

    // Start a thread to feed the process input from our parent's iterator
    new Thread("stdin writer for " + pythonExec) {
      override def run() {
        SparkEnv.set(env)
        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
        val dataOut = new DataOutputStream(stream)
        val printOut = new PrintWriter(stream)
        // Partition index
        dataOut.writeInt(split.index)
        // sparkFilesDir
        PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
        // Broadcast variables
        dataOut.writeInt(broadcastVars.length)
        for (broadcast <- broadcastVars) {
          dataOut.writeLong(broadcast.id)
          dataOut.writeInt(broadcast.value.length)
          dataOut.write(broadcast.value)
        }
        dataOut.flush()
        // Serialized user code
        for (elem <- command) {
          printOut.println(elem)
        }
        printOut.flush()
        // Data values
        for (elem <- parent.iterator(split, context)) {
          PythonRDD.writeAsPickle(elem, dataOut)
        }
        dataOut.flush()
        printOut.flush()
        worker.shutdownOutput()
      }
    }.start()

    // Return an iterator that read lines from the process's stdout
    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
    return new Iterator[Array[Byte]] {
      def next(): Array[Byte] = {
        val obj = _nextObj
        if (hasNext) {
          // FIXME: can deadlock if worker is waiting for us to
          // respond to current message (currently irrelevant because
          // output is shutdown before we read any input)
          _nextObj = read()
        }
        obj
      }

      private def read(): Array[Byte] = {
        try {
          stream.readInt() match {
            case length if length > 0 =>
              val obj = new Array[Byte](length)
              stream.readFully(obj)
              obj
            case -3 =>
              // Timing data from worker
              val bootTime = stream.readLong()
              val initTime = stream.readLong()
              val finishTime = stream.readLong()
              val boot = bootTime - startTime
              val init = initTime - bootTime
              val finish = finishTime - initTime
              val total = finishTime - startTime
              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
              read
            case -2 =>
              // Signals that an exception has been thrown in python
              val exLength = stream.readInt()
              val obj = new Array[Byte](exLength)
              stream.readFully(obj)
              throw new PythonException(new String(obj))
            case -1 =>
              // We've finished the data section of the output, but we can still
              // read some accumulator updates; let's do that, breaking when we
              // get a negative length record.
              var len2 = stream.readInt()
              while (len2 >= 0) {
                val update = new Array[Byte](len2)
                stream.readFully(update)
                accumulator += Collections.singletonList(update)
                len2 = stream.readInt()
              }
              new Array[Byte](0)
          }
        } catch {
          case eof: EOFException => {
            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
          }
          case e => throw e
        }
      }

      var _nextObj = read()

      def hasNext = _nextObj.length != 0
    }
  }

  val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}

/** Thrown for exceptions in user Python code. */
private class PythonException(msg: String) extends Exception(msg)

/**
 * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
 * This is used by PySpark's shuffle operations.
 */
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
  RDD[(Array[Byte], Array[Byte])](prev) {
  override def getPartitions = prev.partitions
  override def compute(split: Partition, context: TaskContext) =
    prev.iterator(split, context).grouped(2).map {
      case Seq(a, b) => (a, b)
      case x          => throw new SparkException("PairwiseRDD: unexpected value: " + x)
    }
  val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}

private[spark] object PythonRDD {

  /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
  def stripPickle(arr: Array[Byte]) : Array[Byte] = {
    arr.slice(2, arr.length - 1)
  }

  /**
   * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
   * The data format is a 32-bit integer representing the pickled object's length (in bytes),
   * followed by the pickled data.
   *
   * Pickle module:
   *
   *    http://docs.python.org/2/library/pickle.html
   *
   * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
   *
   *    http://hg.python.org/cpython/file/2.6/Lib/pickle.py
   *    http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
   *
   * @param elem the object to write
   * @param dOut a data output stream
   */
  def writeAsPickle(elem: Any, dOut: DataOutputStream) {
    if (elem.isInstanceOf[Array[Byte]]) {
      val arr = elem.asInstanceOf[Array[Byte]]
      dOut.writeInt(arr.length)
      dOut.write(arr)
    } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
      val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
      val length = t._1.length + t._2.length - 3 - 3 + 4  // stripPickle() removes 3 bytes
      dOut.writeInt(length)
      dOut.writeByte(Pickle.PROTO)
      dOut.writeByte(Pickle.TWO)
      dOut.write(PythonRDD.stripPickle(t._1))
      dOut.write(PythonRDD.stripPickle(t._2))
      dOut.writeByte(Pickle.TUPLE2)
      dOut.writeByte(Pickle.STOP)
    } else if (elem.isInstanceOf[String]) {
      // For uniformity, strings are wrapped into Pickles.
      val s = elem.asInstanceOf[String].getBytes("UTF-8")
      val length = 2 + 1 + 4 + s.length + 1
      dOut.writeInt(length)
      dOut.writeByte(Pickle.PROTO)
      dOut.writeByte(Pickle.TWO)
      dOut.write(Pickle.BINUNICODE)
      dOut.writeInt(Integer.reverseBytes(s.length))
      dOut.write(s)
      dOut.writeByte(Pickle.STOP)
    } else {
      throw new SparkException("Unexpected RDD type")
    }
  }

  def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
  JavaRDD[Array[Byte]] = {
    val file = new DataInputStream(new FileInputStream(filename))
    val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
    try {
      while (true) {
        val length = file.readInt()
        val obj = new Array[Byte](length)
        file.readFully(obj)
        objs.append(obj)
      }
    } catch {
      case eof: EOFException => {}
      case e => throw e
    }
    JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
  }

  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
    import scala.collection.JavaConverters._
    writeIteratorToPickleFile(items.asScala, filename)
  }

  def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
    val file = new DataOutputStream(new FileOutputStream(filename))
    for (item <- items) {
      writeAsPickle(item, file)
    }
    file.close()
  }

  def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
    implicit val cm : ClassManifest[T] = rdd.elementClassManifest
    rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
  }
}

private object Pickle {
  val PROTO: Byte = 0x80.toByte
  val TWO: Byte = 0x02.toByte
  val BINUNICODE: Byte = 'X'
  val STOP: Byte = '.'
  val TUPLE2: Byte = 0x86.toByte
  val EMPTY_LIST: Byte = ']'
  val MARK: Byte = '('
  val APPENDS: Byte = 'e'
}

private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
  override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}

/**
 * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
 * collects a list of pickled strings that we pass to Python through a socket.
 */
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
  extends AccumulatorParam[JList[Array[Byte]]] {

  val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
  
  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList

  override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
      : JList[Array[Byte]] = {
    if (serverHost == null) {
      // This happens on the worker node, where we just want to remember all the updates
      val1.addAll(val2)
      val1
    } else {
      // This happens on the master, where we pass the updates to Python through a socket
      val socket = new Socket(serverHost, serverPort)
      val in = socket.getInputStream
      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
      out.writeInt(val2.size)
      for (array <- val2) {
        out.writeInt(array.length)
        out.write(array)
      }
      out.flush()
      // Wait for a byte from the Python side as an acknowledgement
      val byteRead = in.read()
      if (byteRead == -1) {
        throw new SparkException("EOF reached before Python server acknowledged")
      }
      socket.close()
      null
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy