/*
 * Copyright 2019 The Glow Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.projectglow.transformers.pipe

import java.io._
import java.lang.{IllegalStateException => ISE}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.io.Source

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLUtils, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.storage.StorageLevel

import io.projectglow.common.{GlowLogging, WithUtils}
/**
 * Based on Spark's PipedRDD with the following modifications:
 * - Acts only on DataFrames instead of generic RDDs
 * - Uses the input and output formatters to determine the output schema
 * - Uses the input and output formatters to return a DataFrame
 */
private[projectglow] object Piper extends GlowLogging {
  private val cachedRdds = mutable.ListBuffer[RDD[_]]()

  def clearCache(): Unit = cachedRdds.synchronized {
    SparkSession.getActiveSession match {
      case None => // No active session, so there is nothing to unpersist
      case Some(spark) =>
        cachedRdds.foreach { rdd =>
          if (rdd.sparkContext == spark.sparkContext) {
            rdd.unpersist()
          }
        }
    }
    cachedRdds.clear()
  }
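
  // A minimal usage sketch (illustrative only: `myInputFormatter`, `myOutputFormatter`, and
  // `inputDf` are hypothetical stand-ins for an InputFormatter, an OutputFormatter, and the
  // DataFrame to pipe):
  //
  //   val piped = Piper.pipe(
  //     informatter = myInputFormatter,      // serializes rows to the subprocess's stdin
  //     outputformatter = myOutputFormatter, // parses the subprocess's stdout
  //     cmd = Seq("bash", "-c", "cat"),      // identity command, for illustration
  //     env = Map.empty,
  //     df = inputDf)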

  // Pipes the input DataFrame through the command once, caching the result to disk so that the
  // output schema can be inferred before the final DataFrame is constructed.
  def pipe(
      informatter: InputFormatter,
      outputformatter: OutputFormatter,
      cmd: Seq[String],
      env: Map[String, String],
      df: DataFrame,
      quarantineLocation: Option[(String, String)] = None): DataFrame = {
    logger.info(s"Beginning pipe with cmd $cmd")

    val quarantineInfo = quarantineLocation.map { case (location, flavor) =>
      PipeIterator.QuarantineInfo(df, location, PipeIterator.QuarantineWriter(flavor))
    }

    val rawRdd = df.queryExecution.toRdd
    val inputRdd = if (rawRdd.getNumPartitions == 0) {
      logger.warn("Not piping any rows, as the input DataFrame has zero partitions.")
      SQLUtils.createEmptyRDD(df.sparkSession)
    } else {
      rawRdd
    }

    // Each non-empty partition's iterator yields the output schema first, followed by
    // [[InternalRow]]s conforming to that schema.
    val schemaInternalRowRDD = inputRdd
      .mapPartitions { it =>
        if (it.isEmpty) {
          Iterator.empty
        } else {
          new PipeIterator(cmd, env, it, informatter, outputformatter)
        }
      }
      .persist(StorageLevel.DISK_ONLY)

    cachedRdds.synchronized {
      cachedRdds.append(schemaInternalRowRDD)
    }

    // Quarantining is potentially very wasteful due to the throw-based control
    // flow implemented at the level below.
    quarantineInfo.foreach { quarantineInfo =>
      try {
        // Force full evaluation of the piped RDD so that a subprocess failure surfaces here as
        // an exception and triggers the quarantine write.
        schemaInternalRowRDD
          .mapPartitions { it =>
            if (it.nonEmpty && it.asInstanceOf[PipeIterator].error) {
              Iterator(true)
            } else {
              Iterator.empty
            }
          }
          .filter(identity)
          .take(1)
          .nonEmpty
      } catch {
        case _: Throwable => quarantineInfo.flavor.quarantine(quarantineInfo)
      }
    }
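
    // Each non-empty partition emits its inferred output schema as the first element; collect
    // the distinct schemas and require that exactly one was produced.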
    val schemaSeq = schemaInternalRowRDD
      .mapPartitions { it =>
        if (it.hasNext) {
          Iterator(it.next().asInstanceOf[StructType])
        } else {
          Iterator.empty
        }
      }
      .collect()
      .distinct
    if (schemaSeq.length != 1) {
      throw new IllegalStateException(
        s"Cannot infer schema: saw ${schemaSeq.length} distinct schemas.")
    }
    val schema = schemaSeq.head
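
    // Drop the leading schema element from each partition, leaving only the data rows.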
    val internalRowRDD = schemaInternalRowRDD.mapPartitions { it =>
      it.drop(1).asInstanceOf[Iterator[InternalRow]]
    }

    SQLUtils.internalCreateDataFrame(df.sparkSession, internalRowRDD, schema, isStreaming = false)
  }
}
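
/**
 * Runs a subprocess for a single task: starts the command, feeds its stdin from `inputFn` on a
 * dedicated writer thread, drains its stderr on a reader thread, and records any exception
 * thrown on either thread so that [[propagateChildException]] can rethrow it on the caller's
 * thread.
 */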
private[projectglow] class ProcessHelper(
    cmd: Seq[String],
    environment: Map[String, String],
    inputFn: OutputStream => Unit,
    context: TaskContext)
    extends GlowLogging {

  private val _childThreadException = new AtomicReference[Throwable](null)
  private var process: Process = _

  def startProcess(): BufferedInputStream = {
    val pb = new ProcessBuilder(cmd.asJava)
    val pbEnv = pb.environment()
    environment.foreach { case (k, v) => pbEnv.put(k, v) }
    process = pb.start()

    val stdinWriterThread = new Thread(s"${ProcessHelper.STDIN_WRITER_THREAD_PREFIX} for $cmd") {
      override def run(): Unit = {
        SQLUtils.setTaskContext(context)
        val out = process.getOutputStream
        try {
          inputFn(out)
        } catch {
          case t: Throwable => _childThreadException.set(t)
        } finally {
          out.close()
        }
      }
    }
    stdinWriterThread.start()

    val stderrReaderThread = new Thread(s"${ProcessHelper.STDERR_READER_THREAD_PREFIX} for $cmd") {
      override def run(): Unit = {
        val err = process.getErrorStream
        try {
          for (line <- Source.fromInputStream(err).getLines()) {
            logger.info("Got stderr line")
            // scalastyle:off println
            System.err.println(line)
            // scalastyle:on println
          }
        } catch {
          case t: Throwable => _childThreadException.set(t)
        } finally {
          err.close()
        }
      }
    }
    stderrReaderThread.start()

    new BufferedInputStream(process.getInputStream)
  }

  def waitForProcess(): Int = {
    if (process == null) {
      throw new IllegalStateException("Process hasn't been started yet")
    }
    process.waitFor()
  }

  def propagateChildException(): Unit = {
    childThreadExceptionO.foreach { t =>
      processO.foreach(_.destroy())
      throw t
    }
  }

  def childThreadExceptionO: Option[Throwable] = Option(_childThreadException.get())

  def processO: Option[Process] = Option(process)
}

object ProcessHelper {
  val STDIN_WRITER_THREAD_PREFIX = "stdin writer"
  val STDERR_READER_THREAD_PREFIX = "stderr reader"
}
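
/**
 * Iterates over a subprocess's output for a single partition. The input rows are written to the
 * process through the [[InputFormatter]] on a background thread, and the elements yielded are
 * whatever the [[OutputFormatter]] parses from the process's stdout: the output schema first,
 * then [[InternalRow]]s. A nonzero exit status surfaces as an IllegalStateException from
 * [[hasNext]].
 */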
class PipeIterator(
    cmd: Seq[String],
    environment: Map[String, String],
    _input: Iterator[InternalRow],
    inputFormatter: InputFormatter,
    outputFormatter: OutputFormatter)
    extends Iterator[Any] {

  import PipeIterator.illegalStateException

  private val input = _input.toSeq
  private val processHelper = new ProcessHelper(cmd, environment, writeInput, TaskContext.get)
  private val inputStream = processHelper.startProcess()
  private val baseIterator = outputFormatter.makeIterator(inputStream)

  private def writeInput(stream: OutputStream): Unit = {
    WithUtils.withCloseable(inputFormatter) { informatter =>
      informatter.init(stream)
      input.foreach(informatter.write)
    }
  }

  override def hasNext: Boolean = {
    val result = if (baseIterator.hasNext) {
      true
    } else {
      val exitStatus = processHelper.waitForProcess()
      if (exitStatus != 0) {
        throw illegalStateException(s"Subprocess exited with status $exitStatus")
      }
      false
    }
    processHelper.propagateChildException()
    result
  }

  override def next(): Any = baseIterator.next()

  // True if the subprocess exited with a nonzero status.
  def error: Boolean = processHelper.waitForProcess() != 0
}

object PipeIterator {

  // This would typically be a typeclass, but this codebase favors subtype polymorphism
  // over typeclasses.
  trait QuarantineWriter extends Product with Serializable {
    def quarantine(qi: QuarantineInfo): Unit
  }
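
  // For contrast, a typeclass-style encoding would look roughly like this (illustrative
  // sketch only; not used in this codebase):
  //
  //   trait Quarantiner[F] {
  //     def quarantine(flavor: F, qi: QuarantineInfo): Unit
  //   }
  //   def quarantine[F](flavor: F, qi: QuarantineInfo)(implicit q: Quarantiner[F]): Unit =
  //     q.quarantine(flavor, qi)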

  object QuarantineWriter {
    def apply(flavor: String): QuarantineWriter = flavor match {
      case "delta" => QuarantineWriterDelta
      case "csv" => QuarantineWriterCsv
      case _ => throw illegalStateException(s"unknown QuarantineWriter flavor: $flavor")
    }
  }
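
  // e.g. QuarantineWriter("delta") == QuarantineWriterDelta; unrecognized flavors throw.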

  case object QuarantineWriterDelta extends QuarantineWriter {
    override def quarantine(qi: QuarantineInfo): Unit = {
      qi.df.write.format("delta").mode("append").saveAsTable(qi.location)
    }
  }

  case object QuarantineWriterCsv extends QuarantineWriter {
    override def quarantine(qi: QuarantineInfo): Unit = {
      qi.df.write.mode("append").csv(qi.location)
    }
  }

  /**
   * Data for quarantining records which fail in process.
   *
   * @param df The [[DataFrame]] being processed.
   * @param location The destination to write to: a Delta table, typically of the form
   *                 `classifier.tableName`, or a CSV output path, depending on the flavor.
   * @param flavor The [[QuarantineWriter]] used to persist the quarantined records.
   */
  final case class QuarantineInfo(df: DataFrame, location: String, flavor: QuarantineWriter)

  def illegalStateException(message: String): IllegalStateException =
    new ISE("[PipeIterator] " + message)
}