package tech.mlsql.arrow.python.runner

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch}
import tech.mlsql.arrow.context.CommonTaskContext
import tech.mlsql.arrow.{ArrowUtils, ArrowWriter, Utils}

import java.io._
import java.net._
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._


/**
 * Similar to `PythonUDFRunner`, but exchanges data with the Python worker via Arrow streams.
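 *
 * @param funcs      chained Python functions whose command is sent to the worker
 * @param schema     schema of the input rows written to the worker
 * @param timeZoneId session time zone id, used when building the Arrow schema
 * @param conf       key/value configuration forwarded to the worker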
 */
class ArrowPythonRunner(
                         funcs: Seq[ChainedPythonFunctions],
                         schema: StructType,
                         timeZoneId: String,
                         conf: Map[String, String])
  extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
    funcs, conf) {

  protected override def newWriterThread(
                                          worker: Socket,
                                          inputIterator: Iterator[Iterator[InternalRow]],
                                          partitionIndex: Int,
                                          context: CommonTaskContext): WriterThread = {
    new WriterThread(worker, inputIterator, partitionIndex, context) {

      protected override def writeCommand(dataOut: DataOutputStream): Unit = {

        // Write the worker config as length-prefixed key/value string pairs;
        // the count is conf.size + 1 because the time zone entry is appended below.
        dataOut.writeInt(conf.size + 1)
        for ((k, v) <- conf) {
          Utils.writeUTF(k, dataOut)
          Utils.writeUTF(v, dataOut)
        }
        Utils.writeUTF("timezone", dataOut)
        Utils.writeUTF(timeZoneId, dataOut)

        val command = funcs.head.funcs.head.command
        Utils.writeUTF(command, dataOut)

      }
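      // Header wire format, as written above (a hedged sketch: it assumes
      // Utils.writeUTF emits a 4-byte length followed by the string's UTF-8
      // bytes, the convention used by Spark's PythonRDD.writeUTF):
      //
      //   [numPairs: Int] ([key][value]) * numPairs [command]
      //
      // numPairs is conf.size + 1 because the "timezone" entry is appended
      // after the user-supplied conf. Decode and round-trip sketches appear
      // after the class at the end of this file.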

      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
        val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
          s"stdout writer for $pythonExec", 0, Long.MaxValue)
        val root = VectorSchemaRoot.create(arrowSchema, allocator)

        Utils.tryWithSafeFinally {
          val arrowWriter = ArrowWriter.create(root)
          val writer = new ArrowStreamWriter(root, null, dataOut)
          writer.start()

          while (inputIterator.hasNext) {
            val nextBatch = inputIterator.next()

            while (nextBatch.hasNext) {
              arrowWriter.write(nextBatch.next())
            }

            arrowWriter.finish()
            writer.writeBatch()
            arrowWriter.reset()
          }
          // `end` writes the footer to the output stream and doesn't clean up
          // any resources. It can throw an exception if the output stream is
          // already closed, so it must stay inside the try block.
          writer.end()
        } {
          // If we closed root and allocator in a TaskCompletionListener, there
          // could be a race condition where the writer thread keeps writing to
          // the VectorSchemaRoot while it is being closed by the listener.
          // Closing root and allocator here is cleaner because they are owned
          // by the writer thread and are only visible to the writer thread.
          //
          // If the writer thread is interrupted by the TaskCompletionListener,
          // it is either (1) in the try block, in which case it will get an
          // InterruptedException while performing IO and fall through to this
          // finally block, or (2) already in the finally block, in which case
          // it will ignore the interruption and close the resources.
          root.close()
          allocator.close()
        }
      }
    }
  }

  protected override def newReaderIterator(
                                            stream: DataInputStream,
                                            writerThread: WriterThread,
                                            startTime: Long,
                                            worker: Socket,
                                            releasedOrClosed: AtomicBoolean,
                                            context: CommonTaskContext): Iterator[ColumnarBatch] = {
    new ReaderIterator(stream, writerThread, startTime, worker, releasedOrClosed, context) {

      private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
        s"stdin reader for $pythonExec", 0, Long.MaxValue)

      private var reader: ArrowStreamReader = _
      private var root: VectorSchemaRoot = _
      private var schema: StructType = _
      private var vectors: Array[ColumnVector] = _
      context.readerRegister(() => {})(reader, allocator)

      private var batchLoaded = true

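      // read() alternates between two modes: while an Arrow stream is open it
      // loads the next record batch into `root` and wraps the shared vectors
      // in a ColumnarBatch; once loadNextBatch() returns false it falls back
      // to reading control codes (SpecialLengths) from the raw stream, which
      // signal a new Arrow stream, a stream crash, a Python exception, or the
      // end of the data section.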
      protected override def read(): ColumnarBatch = {
        if (writerThread.exception.isDefined) {
          throw writerThread.exception.get
        }
        try {
          if (reader != null && batchLoaded) {
            batchLoaded = reader.loadNextBatch()
            if (batchLoaded) {
              val batch = new ColumnarBatch(vectors)
              batch.setNumRows(root.getRowCount)
              batch
            } else {
              reader.close(false)
              allocator.close()
              // Reach end of stream. Call `read()` again to read control data.
              read()
            }
          } else {
            stream.readInt() match {
              case SpecialLengths.START_ARROW_STREAM =>
                try {
                  reader = new ArrowStreamReader(stream, allocator)
                  root = reader.getVectorSchemaRoot()
                  schema = ArrowUtils.fromArrowSchema(root.getSchema())
                  vectors = root.getFieldVectors().asScala.map { vector =>
                    new ArrowColumnVector(vector)
                  }.toArray[ColumnVector]
                  read()
                } catch {
                  case e: IOException if (e.getMessage.contains("Missing schema") ||
                      e.getMessage.contains("Expected schema but header was")) =>
                    logInfo("Failed to read Arrow schema", e)
                    reader = null
                    read()
                }

              case SpecialLengths.ARROW_STREAM_CRASH =>
                read()

              case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
                throw handlePythonException()

              case SpecialLengths.END_OF_DATA_SECTION =>
                handleEndOfDataSection()
                null
            }
          }
        } catch handleException
      }
    }
  }
}
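
// ---------------------------------------------------------------------------
// Illustrative sketches; not part of the runner itself.
//
// First, a hedged sketch of decoding the config header produced by
// `writeCommand` above. It assumes Utils.writeUTF writes a 4-byte length
// followed by the string's UTF-8 bytes (the convention used by Spark's
// PythonRDD.writeUTF); the object and method names below are hypothetical.
object ConfigHeaderDecodeSketch {

  import java.nio.charset.StandardCharsets

  // Mirror of the assumed Utils.writeUTF wire format:
  // [length: Int][`length` bytes of UTF-8].
  def readUTF(in: DataInputStream): String = {
    val bytes = new Array[Byte](in.readInt())
    in.readFully(bytes)
    new String(bytes, StandardCharsets.UTF_8)
  }

  // Reads what writeCommand wrote: a pair count, that many key/value pairs
  // (the last one being the appended "timezone" entry), then the command.
  def readHeader(in: DataInputStream): (Map[String, String], String) = {
    val numPairs = in.readInt()
    val conf = (0 until numPairs).map(_ => readUTF(in) -> readUTF(in)).toMap
    (conf, readUTF(in))
  }
}

// Second, a minimal self-contained Arrow stream round-trip that mirrors the
// start()/writeBatch()/end() sequence in writeIteratorToStream and the
// loadNextBatch() loop in read(). It uses only the public Arrow Java API and
// makes no assumptions about this project's helpers.
object ArrowRoundTripSketch {

  import java.util.Collections

  import org.apache.arrow.memory.RootAllocator
  import org.apache.arrow.vector.IntVector
  import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)
    val arrowSchema = new Schema(Collections.singletonList(
      new Field("x", FieldType.nullable(new ArrowType.Int(32, true)), null)))

    // Write side: one record batch of three ints, framed as an Arrow stream.
    val out = new ByteArrayOutputStream()
    val root = VectorSchemaRoot.create(arrowSchema, allocator)
    val writer = new ArrowStreamWriter(root, null, out)
    writer.start()
    val vec = root.getVector("x").asInstanceOf[IntVector]
    (0 until 3).foreach(i => vec.setSafe(i, i * 10))
    root.setRowCount(3)
    writer.writeBatch()
    writer.end()
    root.close()

    // Read side: load batches until the stream is exhausted.
    val reader = new ArrowStreamReader(new ByteArrayInputStream(out.toByteArray), allocator)
    while (reader.loadNextBatch()) {
      println(s"rows = ${reader.getVectorSchemaRoot.getRowCount}")
    }
    reader.close()
    allocator.close()
  }
}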