All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.execution.streaming.continuous.HTTPSinkV2.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package org.apache.spark.sql.execution.streaming.continuous

import com.microsoft.ml.spark._
import com.microsoft.ml.spark.io.http.HTTPResponseData
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types._

import scala.collection.mutable

class HTTPSinkProviderV2 extends DataSourceV2
  with StreamWriteSupport
  with DataSourceRegister {

  override def createStreamWriter(queryId: String,
                                  schema: StructType,
                                  mode: OutputMode,
                                  options: DataSourceOptions): StreamWriter = {
    new HTTPWriter(schema, options)
  }

  def shortName(): String = "HTTPv2"
}

/** Common methods used to create writes for the the console sink */
class HTTPWriter(schema: StructType, options: DataSourceOptions)
  extends StreamWriter with Logging {

  protected val idCol: String = options.get("idCol").orElse("id")
  protected val replyCol: String = options.get("replyCol").orElse("reply")
  protected val name: String = options.get("name").get

  val idColIndex: Int = schema.fieldIndex(idCol)
  val replyColIndex: Int = schema.fieldIndex(replyCol)

  assert(SparkSession.getActiveSession.isDefined)

  def createWriterFactory(): DataWriterFactory[InternalRow] =
    HTTPWriterFactory(idColIndex, replyColIndex, name)

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

  def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    HTTPSourceStateHolder.cleanUp(name)
  }

}

private[streaming] case class HTTPWriterFactory(idColIndex: Int,
                                                replyColIndex: Int,
                                                name: String)
  extends DataWriterFactory[InternalRow] {
  def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = {
    new HTTPDataWriter(partitionId, idColIndex, replyColIndex, name, epochId)
  }
}

private[streaming] class HTTPDataWriter(partitionId: Int,
                                        val idColIndex: Int,
                                        val replyColIndex: Int,
                                        val name: String,
                                        epochId: Long)
  extends DataWriter[InternalRow] with Logging {
  logInfo(s"Creating writer on PID:$partitionId")
  HTTPSourceStateHolder.getServer(name).commit(epochId - 1, partitionId)

  private val ids: mutable.ListBuffer[(String, Int)] = new mutable.ListBuffer[(String, Int)]()

  private val fromRow = HTTPResponseData.makeFromInternalRowConverter

  override def write(row: InternalRow): Unit = {
    val id = row.getStruct(idColIndex, 2)
    val mid = id.getString(0)
    val rid = id.getString(1)
    val pid = id.getInt(2)
    val reply = fromRow(row.getStruct(replyColIndex, 4)) //scalastyle:ignore magic.number
    HTTPSourceStateHolder.getServer(name).replyTo(mid, rid, reply)
    ids.append((rid, pid))
  }

  override def commit(): HTTPCommitMessage = {
    val msg = HTTPCommitMessage(ids.toArray)
    ids.foreach { case (rid, pid) =>
      HTTPSourceStateHolder.getServer(name).commit(rid)
    }
    ids.clear()
    msg
  }

  override def abort(): Unit = {
    if (TaskContext.get().getKillReason().contains("Stage cancelled")) {
      HTTPSourceStateHolder.cleanUp(name)
    }
  }
}

private[streaming] case class HTTPCommitMessage(ids: Array[(String, Int)]) extends WriterCommitMessage




© 2015 - 2024 Weber Informatics LLC | Privacy Policy