com.microsoft.ml.spark.io.http.SimpleHTTPTransformer.scala

// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.io.http

import com.microsoft.ml.spark.core.contracts.{HasInputCol, HasOutputCol}
import com.microsoft.ml.spark.core.env.InternalWrapper
import com.microsoft.ml.spark.core.schema.DatasetExtensions.{findUnusedColumnName => newCol}
import com.microsoft.ml.spark.stages.{DropColumns, FlattenBatch, HasMiniBatcher, Lambda}
import org.apache.commons.io.IOUtils
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

object SimpleHTTPTransformer extends ComplexParamsReadable[SimpleHTTPTransformer]

trait HasErrorCol extends Params {
  val errorCol = new Param[String](this, "errorCol", "column to hold HTTP errors")

  def setErrorCol(v: String): this.type = set(errorCol, v)

  def getErrorCol: String = $(errorCol)

}

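// Utilities for splitting HTTP results into a response column and an error column:
// successful (HTTP 200) responses pass through untouched, while any other status is
// captured as an error row holding the response body and status line.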
object ErrorUtils extends Serializable {
  val ErrorSchema: StructType = new StructType()
    .add("response", StringType, nullable = true)
    .add("status", StatusLineData.schema, nullable = true)

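  // Returns None for null rows and successful (200) responses; otherwise wraps the
  // response body (decoded as UTF-8, when present) and the status line into a Row
  // matching ErrorSchema.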
  protected def addError(fromRow: Row => HTTPResponseData)(responseRow: Row): Option[Row] = {
    val respOpt = Option(responseRow).map(fromRow)
    respOpt match {
      case Some(resp) if resp.statusLine.statusCode == 200 => None
      case Some(resp) =>
        Some(Row(resp.entity.map(entity =>
          IOUtils.toString(entity.content, "UTF-8")), resp.statusLine))
      case None => None
    }
  }

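  // Passes the response through only when no error was recorded for it; otherwise
  // returns None so downstream output parsers see a null response.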
  def nullifyResponse(errorRow: Row, responseRow: Row): Option[Row] = {
    if (errorRow == null && responseRow != null) {
      Some(responseRow)
    } else {
      None
    }
  }

  def addErrorUDF: UserDefinedFunction = {
    val fromRow = HTTPResponseData.makeFromRowConverter
    udf(addError(fromRow) _, ErrorSchema)
  }

  val NullifyResponseUDF: UserDefinedFunction = udf(nullifyResponse _, HTTPSchema.Response)

}

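/**
  * Convenience transformer that chains an [[HTTPInputParser]], an [[HTTPTransformer]] client,
  * error extraction, and an [[HTTPOutputParser]] into a single pipeline. Requests are built
  * from the input column, parsed responses are written to the output column, and non-200
  * responses are surfaced in the error column. Optional mini-batching and batch flattening
  * are applied when a mini-batcher is set.
  */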
@InternalWrapper
class SimpleHTTPTransformer(val uid: String)
  extends Transformer with HTTPParams with HasMiniBatcher with HasHandler
    with HasInputCol with HasOutputCol with ComplexParamsWritable with HasErrorCol {

  def this() = this(Identifiable.randomUID("SimpleHTTPTransformer"))

  val flattenOutputBatches: Param[Boolean] = new BooleanParam(
    this, "flattenOutputBatches", "whether to flatten the output batches")

  /** @group getParam */
  def getFlattenOutputBatches: Boolean = $(flattenOutputBatches)

  /** @group setParam */
  def setFlattenOutputBatches(value: Boolean): this.type = set(flattenOutputBatches, value)

  val inputParser: Param[Transformer] = new TransformerParam(
    this, "inputParser", "parser used to convert the input column into HTTP requests",
    { case _: HTTPInputParser => true; case _ => false })

  /** @group getParam */
  def getInputParser: HTTPInputParser = $(inputParser).asInstanceOf[HTTPInputParser]

  /** @group setParam */
  def setInputParser(value: HTTPInputParser): this.type = set(inputParser, value)

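  // Defaults: JSON input parsing, the HandlingUtils.advancedUDF response handler, and an
  // error column name derived from this transformer's uid.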
  setDefault(
    inputParser -> new JSONInputParser(),
    handler -> HandlingUtils.advancedUDF(0, 50, 100, 500),
    errorCol -> (this.uid + "_errors"))

  def setUrl(url: String): this.type = {
    getInputParser match {
      case jip: JSONInputParser => setInputParser(jip.setUrl(url))
      case _ => throw new IllegalArgumentException("this setting is only available when using a JSONInputParser")
    }
  }

  val outputParser: Param[Transformer] = new TransformerParam(
    this, "outputParser", "parser used to convert HTTP responses into the output column",
    { case _: HTTPOutputParser => true; case _ => false })

  /** @group getParam */
  def getOutputParser: HTTPOutputParser = $(outputParser).asInstanceOf[HTTPOutputParser]

  /** @group setParam */
  def setOutputParser(value: HTTPOutputParser): this.type = set(outputParser, value)

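  // Assembles the internal pipeline: optional mini-batcher, input parser, HTTP client,
  // error extraction, output parser, removal of the intermediate columns, and (when
  // mini-batching) an optional FlattenBatch stage.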
  private def makePipeline(schema: StructType): PipelineModel = {
    val colsToAvoid = schema.fieldNames.toSet ++ Set(getOutputCol)

    val parsedInputCol = newCol("parsedInput")(colsToAvoid)
    val unparsedOutputCol = newCol("unparsedOutput")(colsToAvoid)

    val mb = get(miniBatcher)

    val inputParser = Some(getInputParser
      .setInputCol(getInputCol)
      .setOutputCol(parsedInputCol))

    val client = Some(new HTTPTransformer()
      .setHandler(getHandler)
      .setConcurrency(getConcurrency)
      .setConcurrentTimeout(getConcurrentTimeout)
      .setInputCol(parsedInputCol)
      .setOutputCol(unparsedOutputCol))

    val parseErrors = Some(Lambda(_
      .withColumn(getErrorCol, ErrorUtils.addErrorUDF(col(unparsedOutputCol)))
      .withColumn(unparsedOutputCol, ErrorUtils.NullifyResponseUDF(col(getErrorCol), col(unparsedOutputCol)))
    ))

    val outputParser =
      Some(getOutputParser
        .setInputCol(unparsedOutputCol)
        .setOutputCol(getOutputCol))

    val dropCols = Some(new DropColumns().setCols(Array(parsedInputCol, unparsedOutputCol)))

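    // Flatten mini-batched output rows unless flattenOutputBatches is explicitly false;
    // without a mini-batcher this stage is omitted entirely.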
    val flatten = mb.flatMap(_ =>
      get(flattenOutputBatches) match {
        case None => Some(new FlattenBatch())
        case Some(true) => Some(new FlattenBatch())
        case Some(false) => None
      }
    )

    NamespaceInjections.pipelineModel(Array(
      mb, inputParser, client, parseErrors, outputParser, dropCols, flatten
    ).flatten)

  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    makePipeline(dataset.schema).transform(dataset.toDF())
  }

  override def copy(extra: ParamMap): this.type = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    makePipeline(schema).transformSchema(schema)
  }

}
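
// A minimal usage sketch (illustrative only; the URL, column names, DataFrame `df`, and
// the `myOutputParser` value below are hypothetical and not part of this file):
//
//   val transformer = new SimpleHTTPTransformer()
//     .setUrl("http://localhost:8888/score")   // only valid with the default JSONInputParser
//     .setInputCol("input")
//     .setOutputCol("prediction")
//     .setOutputParser(myOutputParser)         // any HTTPOutputParser instance
//
//   val scored = transformer.transform(df)     // df must contain the "input" column
//   scored.select("prediction", transformer.getErrorCol)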



