All downloads are free. Search and download functionality uses the official Maven repository.

com.microsoft.ml.spark.io.http.HTTPTransformer.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.io.http

import com.microsoft.ml.spark.core.contracts.{HasInputCol, HasOutputCol, Wrappable}
import com.microsoft.ml.spark.io.http.HandlingUtils.HandlerFunc
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration

/** Mixin for stages that need a user-supplied request-handling function.
  *
  * The function is stored wrapped in a Spark `udf` (declared with a `StringType` result)
  * inside a [[UDFParam]]; [[getHandler]] unwraps it back to a raw `HandlerFunc`.
  */
trait HasHandler extends Params {
  val handler: UDFParam = new UDFParam(
    this, "handler", "Which strategy to use when handling requests")

  /** Returns the handler function held by the `handler` param (explicit or default).
    *
    * @group getParam
    */
  def getHandler: HandlerFunc = {
    val wrapped = getOrDefault(handler)
    wrapped.f.asInstanceOf[HandlerFunc]
  }

  /** Wraps `v` in a `udf` and stores it as the `handler` param.
    *
    * @param v the function used to handle requests
    * @return this stage, for chaining
    * @group setParam
    */
  def setHandler(v: HandlerFunc): HasHandler.this.type = {
    val wrappedHandler = udf(v, StringType)
    set(handler, wrappedHandler)
  }
}

/** Parameters controlling how HTTP requests are issued: the degree of concurrency and the
  * timeouts applied to connections and to concurrent futures.
  *
  * Defaults: `concurrency = 1`, `timeout = 60.0` seconds, `concurrentTimeout = 100.0` seconds.
  */
trait HTTPParams extends Wrappable {
  // Validated to be >= 1: a count of concurrent calls below 1 is meaningless, and the HTTP
  // client construction in this file only handles 1 or more. Rejecting the value here
  // surfaces the error when the param is set instead of as a MatchError at transform time.
  val concurrency: Param[Int] = new IntParam(
    this, "concurrency", "max number of concurrent calls", ParamValidators.gtEq(1))

  /** @group getParam */
  def getConcurrency: Int = $(concurrency)

  /** @group setParam */
  def setConcurrency(value: Int): this.type = set(concurrency, value)

  val timeout: Param[Double] = new DoubleParam(
    this, "timeout", "number of seconds to wait before closing the connection")

  /** @group getParam */
  def getTimeout: Double = $(timeout)

  /** @group setParam */
  def setTimeout(value: Double): this.type = set(timeout, value)

  val concurrentTimeout: Param[Double] = new DoubleParam(
    this, "concurrentTimeout", "max number seconds to wait on futures if concurrency >= 1")

  /** @group getParam */
  def getConcurrentTimeout: Double = $(concurrentTimeout)

  /** @group setParam */
  def setConcurrentTimeout(value: Double): this.type = set(concurrentTimeout, value)

  setDefault(concurrency -> 1,
    timeout -> 60.0,
    concurrentTimeout -> 100.0)

}

/** Mixin providing the `url` parameter: the address of the service a stage talks to. */
trait HasURL extends Wrappable {

  val url: Param[String] = new Param[String](this, "url", "Url of the service")

  /** Returns the configured service URL.
    *
    * @group getParam
    */
  def getUrl: String = getOrDefault(url)

  /** Sets the service URL.
    *
    * @param value the URL to use
    * @return this stage, for chaining
    * @group setParam
    */
  def setUrl(value: String): this.type = set(url -> value)

}

/** Companion object mixing in `ComplexParamsReadable` for [[HTTPTransformer]]
  * (presumably supplying the `read`/`load` entry points for persisted instances).
  */
object HTTPTransformer extends ComplexParamsReadable[HTTPTransformer] {
}

/** Transformer that issues an HTTP request for every row of `inputCol` and appends the
  * corresponding HTTP response as a new `outputCol`.
  *
  * The input column must have type `HTTPSchema.Request`. The output column has type
  * `HTTPSchema.Response` and is null when no response is available for a row.
  *
  * @param uid unique identifier of this stage
  */
class HTTPTransformer(val uid: String)
  extends Transformer with HTTPParams with HasInputCol
    with HasOutputCol with HasHandler
    with ComplexParamsWritable {

  setDefault(handler -> HandlingUtils.advancedUDF(100,500,1000)) //scalastyle:ignore magic.number

  def this() = this(Identifiable.randomUID("HTTPTransformer"))

  /** Holder for the HTTP client used by [[transform]], built from the current params.
    *
    * The client type depends on `concurrency`:
    *  - `1` yields a `SingleThreadedHTTPClient`.
    *  - A larger value yields an `AsyncHTTPClient` on the global execution context, bounded by
    *    `concurrentTimeout` seconds.
    *
    * Both clients receive `timeout` converted to milliseconds. Any other concurrency value is
    * rejected with a descriptive `IllegalArgumentException`, not an opaque `MatchError`.
    */
  val clientHolder = SharedVariable {
    val timeoutMillis = (getTimeout * 1000).toInt
    getConcurrency match {
      case 1 =>
        new SingleThreadedHTTPClient(getHandler, timeoutMillis)
      case n if n > 1 =>
        val dur = Duration.fromNanos((getConcurrentTimeout * math.pow(10, 9)).toLong) //scalastyle:ignore magic.number
        val ec = ExecutionContext.global
        new AsyncHTTPClient(getHandler, n, dur, timeoutMillis)(ec)
      case n =>
        throw new IllegalArgumentException(
          s"concurrency must be at least 1, but was $n")
    }
  }

  /** Sends the request stored in `inputCol` of every row and appends the response.
    *
    * Rows are processed partition by partition, and each non-empty partition obtains its
    * client from [[clientHolder]]. A null request struct is handed to the client as `None`.
    * A missing or null response produces a null `outputCol` value.
    *
    * @param dataset the input dataset; its `inputCol` must have type `HTTPSchema.Request`
    * @return the input columns followed by `outputCol` of type `HTTPSchema.Response`
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val df = dataset.toDF()
    val enc = RowEncoder(transformSchema(df.schema))
    val colIndex = df.schema.fieldNames.indexOf(getInputCol)
    val fromRow = HTTPRequestData.makeFromRowConverter
    val toRow = HTTPResponseData.makeToRowConverter
    df.mapPartitions { it =>
      if (!it.hasNext) {
        Iterator()
      } else {
        val c = clientHolder.get
        // The original row travels with its request as context so that it can be
        // re-attached to the matching response below.
        val responsesWithContext = c.sendRequestsWithContext(it.map { row =>
          c.RequestWithContext(Option(row.getStruct(colIndex)).map(fromRow), Some(row))
        })
        responsesWithContext.map { rwc =>
          Row.merge(rwc.context.get.asInstanceOf[Row],
            Row(rwc.response.flatMap(Option(_)).map(toRow).orNull))
        }
      }
    }(enc)
  }

  /** Copies this stage with the given extra params, keeping the same uid. */
  def copy(extra: ParamMap): HTTPTransformer = defaultCopy(extra)

  /** Validates the input column type and appends the nullable response column.
    *
    * @param schema schema of the input dataset
    * @return `schema` with `outputCol` of type `HTTPSchema.Response` appended
    */
  def transformSchema(schema: StructType): StructType = {
    val inputType = schema(getInputCol).dataType
    assert(inputType == HTTPSchema.Request,
      s"Input column '${getInputCol}' must have type HTTPSchema.Request, " +
        s"but has type ${inputType.simpleString}")
    schema.add(getOutputCol, HTTPSchema.Response, nullable = true)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy