// Source listing: com.microsoft.ml.spark.io.http.HTTPSchema.scala (Maven / Gradle / Ivy)
// The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.
package com.microsoft.ml.spark.io.http
import java.net.{SocketException, URI}
import com.microsoft.ml.spark.core.env.StreamUtilities.using
import com.microsoft.ml.spark.core.schema.SparkBindings
import com.sun.net.httpserver.HttpExchange
import org.apache.commons.io.IOUtils
import org.apache.http._
import org.apache.http.client.methods._
import org.apache.http.entity.{ByteArrayEntity, StringEntity}
import org.apache.http.message.BasicHeader
import org.apache.spark.internal.{Logging => SLogging}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, lit, struct, typedLit, udf}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.sql.{Column, Row}
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
/** Plain-data counterpart of an HttpCore [[Header]]: one name/value pair.
  *
  * @param name  the header name
  * @param value the header value
  */
case class HeaderData(name: String, value: String) {

  /** Copies the name and value out of an existing HttpCore header. */
  def this(h: Header) = this(h.getName, h.getValue)

  /** Creates a new [[BasicHeader]] carrying this name and value. */
  def toHTTPCore: Header = {
    val rebuilt = new BasicHeader(name, value)
    rebuilt
  }

}

/** Companion exposing the [[SparkBindings]] members for [[HeaderData]]. */
object HeaderData extends SparkBindings[HeaderData]
/** Plain-data snapshot of an HttpCore [[HttpEntity]] with its body fully buffered in memory.
  *
  * @param content         the raw body bytes
  * @param contentEncoding the entity's Content-Encoding header, when it has one
  * @param contentLength   the length reported by the entity.
  *                        NOTE(review): whatever `getContentLength` reports is stored as-is,
  *                        including any "unknown length" sentinel - confirm callers expect that
  * @param contentType     the entity's Content-Type header, when it has one
  * @param isChunked       value of `HttpEntity.isChunked` at capture time
  * @param isRepeatable    value of `HttpEntity.isRepeatable` at capture time
  * @param isStreaming     value of `HttpEntity.isStreaming` at capture time
  */
case class EntityData(content: Array[Byte],
                      contentEncoding: Option[HeaderData],
                      contentLength: Option[Long],
                      contentType: Option[HeaderData],
                      isChunked: Boolean,
                      isRepeatable: Boolean,
                      isStreaming: Boolean) {

  /** Captures an HttpCore entity: drains (and closes) its content stream and copies its metadata. */
  def this(e: HttpEntity) = {
    this(
      EntityData.readContent(e),
      Option(e.getContentEncoding).map(new HeaderData(_)),
      Option(e.getContentLength),
      Option(e.getContentType).map(new HeaderData(_)),
      e.isChunked,
      e.isRepeatable,
      e.isStreaming)
  }

  /** Rebuilds an HttpCore entity backed by the buffered bytes.
    *
    * The result is a [[ByteArrayEntity]]; the assertions check that the recorded
    * length / repeatable / streaming values agree with what that entity reports.
    *
    * @return a new entity carrying `content` plus the recorded encoding, type and chunked flag
    */
  def toHttpCore: HttpEntity = {
    val e = new ByteArrayEntity(content)
    contentEncoding.foreach { ce => e.setContentEncoding(ce.toHTTPCore) }
    contentLength.foreach { cl => assert(e.getContentLength == cl) }
    contentType.foreach(h => e.setContentType(h.toHTTPCore))
    e.setChunked(isChunked)
    assert(e.isRepeatable == isRepeatable)
    assert(e.isStreaming == isStreaming)
    e
  }

}

/** Companion exposing the [[SparkBindings]] members for [[EntityData]]. */
object EntityData extends SparkBindings[EntityData] {

  /** Reads the whole content stream of `e` into memory and closes the stream afterwards.
    *
    * A `SocketException` while obtaining or reading the stream yields an empty array
    * instead of failing.
    */
  private def readContent(e: HttpEntity): Array[Byte] = {
    try {
      val in = e.getContent
      try {
        IOUtils.toByteArray(in)
      } finally {
        // Best-effort close: a failure to close must not discard bytes that were already read.
        try {
          Option(in).foreach(_.close())
        } catch {
          case _: java.io.IOException =>
        }
      }
    } catch {
      case _: SocketException => Array.empty[Byte] //TODO investigate why sockets fail sometimes
    }
  }

}
/** Plain-data counterpart of an HttpCore [[StatusLine]].
  *
  * @param protocolVersion protocol name and version of the response
  * @param statusCode      numeric status code
  * @param reasonPhrase    reason phrase accompanying the status code
  */
case class StatusLineData(protocolVersion: ProtocolVersionData,
                          statusCode: Int,
                          reasonPhrase: String) {

  /** Copies protocol version, status code and reason phrase out of an HttpCore status line. */
  def this(s: StatusLine) = this(
    new ProtocolVersionData(s.getProtocolVersion),
    s.getStatusCode,
    s.getReasonPhrase)

}

/** Companion exposing the [[SparkBindings]] members for [[StatusLineData]]. */
object StatusLineData extends SparkBindings[StatusLineData]
/** Plain-data representation of an HTTP response.
  *
  * @param headers    the response headers
  * @param entity     the response body and its metadata, if any
  * @param statusLine protocol version, status code and reason phrase
  * @param locale     string form of the response locale
  */
case class HTTPResponseData(headers: Array[HeaderData],
                            entity: Option[EntityData],
                            statusLine: StatusLineData,
                            locale: String) {

  /** Captures headers, entity (if present), status line and locale from an HttpClient response. */
  def this(response: CloseableHttpResponse) = this(
    response.getAllHeaders.map(new HeaderData(_)),
    Option(response.getEntity).map(new EntityData(_)),
    new StatusLineData(response.getStatusLine),
    response.getLocale.toString)

  /** Writes this response onto the given exchange.
    *
    * Adds `headers` plus the entity's content-type / content-encoding headers, sends the
    * status code with the entity's recorded content length (0 when absent), then writes the
    * entity bytes to the response body. An `IOException` in either sending step is logged
    * as a warning and not rethrown.
    *
    * @param request the exchange to answer
    */
  def respondToHTTPExchange(request: HttpExchange): Unit = {
    val outgoing = request.getResponseHeaders

    // Entity-level headers go after the explicit ones: content type first, then encoding.
    val entityHeaders = entity.toSeq.flatMap { body =>
      Seq(body.contentType, body.contentEncoding).flatten
    }
    (headers ++ entityHeaders).foreach { h =>
      outgoing.add(h.name, h.value)
    }

    val declaredLength = entity.flatMap(_.contentLength).getOrElse(0L)
    try {
      request.sendResponseHeaders(statusLine.statusCode, declaredLength)
    } catch {
      case e: java.io.IOException =>
        HTTPResponseData.warn(s"Could not write headers: ${e.getMessage}")
    }

    try {
      entity.foreach { body =>
        // `using` hands back its outcome wrapped; `.get` rethrows any failure so it is caught below.
        using(request.getResponseBody) { out =>
          out.write(body.content)
        }.get
      }
    } catch {
      case e: java.io.IOException =>
        HTTPResponseData.warn(s"Could not send bytes: ${e.getMessage}")
    }
  }

}
/** Companion exposing the [[SparkBindings]] members for [[HTTPResponseData]] plus a logging hook. */
object HTTPResponseData extends SparkBindings[HTTPResponseData] with SLogging {

  /** Forwards `msg` to `logWarning`; the message is passed by name.
    *
    * @param msg the warning text
    */
  def warn(msg: => String): Unit = {
    logWarning(msg)
  }

}
/** Plain-data counterpart of an HttpCore [[ProtocolVersion]].
  *
  * @param protocol protocol name
  * @param major    major version number
  * @param minor    minor version number
  */
case class ProtocolVersionData(protocol: String, major: Int, minor: Int) {

  /** Copies protocol name, major and minor version out of an HttpCore protocol version. */
  def this(v: ProtocolVersion) = this(v.getProtocol, v.getMajor, v.getMinor)

  /** Creates a new HttpCore [[ProtocolVersion]] from the stored name and version numbers. */
  def toHTTPCore: ProtocolVersion = new ProtocolVersion(protocol, major, minor)

}

/** Companion exposing the [[SparkBindings]] members for [[ProtocolVersionData]]. */
object ProtocolVersionData extends SparkBindings[ProtocolVersionData]
/** Plain-data counterpart of an HttpCore [[RequestLine]].
  *
  * @param method         the request method
  * @param uri            the request URI as a string
  * @param protoclVersion the protocol version, if known (the field name's spelling is part of
  *                       the public interface and is kept as-is)
  */
case class RequestLineData(method: String,
                           uri: String,
                           protoclVersion: Option[ProtocolVersionData]) {

  /** Copies method, URI and protocol version out of an HttpCore request line. */
  def this(l: RequestLine) = this(
    l.getMethod,
    l.getUri,
    Some(new ProtocolVersionData(l.getProtocolVersion)))

}

/** Companion exposing the [[SparkBindings]] members for [[RequestLineData]]. */
object RequestLineData extends SparkBindings[RequestLineData]
/** Plain-data representation of an HTTP request.
  *
  * @param requestLine method, URI and optional protocol version
  * @param headers     the request headers
  * @param entity      the request body and its metadata, if any
  */
case class HTTPRequestData(requestLine: RequestLineData,
                           headers: Array[HeaderData],
                           entity: Option[EntityData]) {

  /** Captures an HttpClient request; the entity is taken only from entity-enclosing requests. */
  def this(r: HttpRequestBase) = {
    this(new RequestLineData(r.getRequestLine),
      r.getAllHeaders.map(new HeaderData(_)),
      r match {
        case re: HttpEntityEnclosingRequestBase => Option(re.getEntity).map(new EntityData(_))
        case _ => None
      })
  }

  /** Rebuilds an executable HttpClient request from this data.
    *
    * The method name is matched case-insensitively against GET, HEAD, DELETE, OPTIONS,
    * TRACE, POST, PUT and PATCH.
    *
    * @return a request with URI, protocol version (when present), headers and entity applied
    * @throws IllegalArgumentException if the method is not one of the supported verbs, or if an
    *                                  entity is defined for a method that cannot carry one
    */
  def toHTTPCore: HttpRequestBase = {
    // Locale.ROOT keeps the upper-casing independent of the JVM default locale
    // (e.g. "options" must not become "OPT\u0130ONS" under a Turkish locale).
    val request: HttpRequestBase = requestLine.method.toUpperCase(java.util.Locale.ROOT) match {
      case "GET" => new HttpGet()
      case "HEAD" => new HttpHead()
      case "DELETE" => new HttpDelete()
      case "OPTIONS" => new HttpOptions()
      case "TRACE" => new HttpTrace()
      case "POST" => new HttpPost()
      case "PUT" => new HttpPut()
      case "PATCH" => new HttpPatch()
      case _ =>
        // Explicit failure instead of an opaque scala.MatchError.
        throw new IllegalArgumentException(s"Unsupported HTTP method: ${requestLine.method}")
    }
    request match {
      case re: HttpEntityEnclosingRequestBase =>
        entity.foreach(e => re.setEntity(e.toHttpCore))
      case _ if entity.isDefined =>
        throw new IllegalArgumentException(s"Entity is defined but method is ${requestLine.method}")
      case _ =>
    }
    request.setURI(new URI(requestLine.uri))
    requestLine.protoclVersion.foreach(pv =>
      request.setProtocolVersion(pv.toHTTPCore))
    request.setHeaders(headers.map(_.toHTTPCore))
    request
  }

}
/** Companion exposing the [[SparkBindings]] members for [[HTTPRequestData]] plus a factory
  * for requests received through an [[HttpExchange]].
  */
object HTTPRequestData extends SparkBindings[HTTPRequestData] {

  /** Parses a protocol string of the form `name/major.minor`.
    *
    * @return the parsed version, or `None` when the string does not have that shape
    */
  private def parseProtocol(protocol: String): Option[ProtocolVersionData] = {
    protocol.split('/') match {
      case Array(name, version) =>
        version.split('.') match {
          case Array(major, minor) =>
            scala.util.Try(ProtocolVersionData(name, major.toInt, minor.toInt)).toOption
          case _ => None
        }
      case _ => None
    }
  }

  /** Builds request data from an incoming server-side exchange.
    *
    * Every header value becomes its own [[HeaderData]]. The request body is read fully
    * into memory and is always wrapped in an entity, even when it is empty.
    *
    * @param httpe the exchange to read the request from
    * @return the captured request
    */
  def fromHTTPExchange(httpe: HttpExchange): HTTPRequestData = {
    val requestHeaders = httpe.getRequestHeaders
    val isChunked = Option(requestHeaders.getFirst("Transfer-Encoding")).contains("chunked")

    val requestLine = RequestLineData(
      httpe.getRequestMethod,
      httpe.getRequestURI.toString,
      Option(httpe.getProtocol).flatMap(parseProtocol))

    // Explicit .asScala conversions; no reliance on the deprecated implicit JavaConversions.
    val headers = requestHeaders.asScala.flatMap {
      case (name, values) => values.asScala.map(value => HeaderData(name, value))
    }.toArray

    val entity = EntityData(
      IOUtils.toByteArray(httpe.getRequestBody),
      Option(requestHeaders.getFirst("Content-Encoding")).map(HeaderData("Content-Encoding", _)),
      Option(requestHeaders.getFirst("Content-Length")).map(_.toLong),
      Option(requestHeaders.getFirst("Content-Type")).map(HeaderData("Content-Type", _)),
      isChunked = isChunked,
      isRepeatable = false,
      isStreaming = isChunked
    )

    HTTPRequestData(requestLine, headers, Some(entity))
  }

}
/** Spark schemas for HTTP requests/responses and column-level helpers to build and parse them. */
object HTTPSchema {

  /** Spark schema of [[HTTPResponseData]]. */
  val Response: DataType = HTTPResponseData.schema

  /** Spark schema of [[HTTPRequestData]]. */
  val Request: DataType = HTTPRequestData.schema

  //Convenience Functions for making and parsing HTTP objects
  //scalastyle:off

  /** Wraps a string in an entity encoded as UTF-8. */
  private def stringToEntity(s: String): EntityData = {
    new EntityData(new StringEntity(s, "UTF-8"))
  }

  /** Wraps raw bytes in an entity. */
  private def binaryToEntity(arr: Array[Byte]): EntityData = {
    new EntityData(new ByteArrayEntity(arr))
  }

  /** Chooses the charset used to decode an entity's bytes.
    *
    * Order of preference:
    *  1. the Content-Encoding value, if it names a supported charset (behaviour kept from
    *     earlier versions, which treated that header as the charset name);
    *  2. the `charset` parameter of the Content-Type header;
    *  3. UTF-8.
    *
    * Content codings that are not charsets (for example a compression scheme) are ignored
    * here; the bytes are NOT decoded from such codings.
    */
  private def entityCharset(e: EntityData): java.nio.charset.Charset = {
    import java.nio.charset.{Charset, StandardCharsets}
    import scala.util.Try

    val legacyCharset = e.contentEncoding.flatMap(h => Try(Charset.forName(h.value)).toOption)
    val declaredCharset = e.contentType.flatMap { h =>
      // parse may throw on malformed values and getCharset may be null when no charset is given
      Try(Option(org.apache.http.entity.ContentType.parse(h.value).getCharset)).toOption.flatten
    }
    legacyCharset.orElse(declaredCharset).getOrElse(StandardCharsets.UTF_8)
  }

  /** Decodes the entity body as text; an empty body yields `None`. */
  private def entityToString(e: EntityData): Option[String] = {
    if (e.content.isEmpty) {
      None
    } else {
      Some(new String(e.content, entityCharset(e)))
    }
  }

  /** UDF turning an entity struct into its text; null rows and empty bodies give null. */
  private def entity_to_string_udf: UserDefinedFunction = {
    val fromRow = EntityData.makeFromRowConverter
    udf({ x: Row =>
      val sOpt = Option(x).flatMap(r => entityToString(fromRow(r)))
      sOpt.orNull
    }, StringType)
  }

  /** Column expression decoding an entity column to a string column. */
  def entity_to_string(c: Column): Column = entity_to_string_udf(c)

  /** UDF wrapping a string in an entity struct; a null string gives a null entity. */
  private val string_to_entity_udf: UserDefinedFunction =
    udf({ x: String => Option(x).map(stringToEntity).orNull }, EntityData.schema)

  /** Column expression wrapping a string column into an entity column. */
  def string_to_entity(c: Column): Column = string_to_entity_udf(c)

  /** UDF extracting the body text of a request struct; null when there is no non-empty body. */
  private def request_to_string_udf: UserDefinedFunction = {
    val fromRow = HTTPRequestData.makeFromRowConverter
    udf({ x: Row =>
      // flatMap (not map) so the UDF yields a plain String rather than a nested Option
      val sOpt = Option(x)
        .flatMap(r => fromRow(r).entity)
        .flatMap(entityToString)
      sOpt.orNull
    }, StringType)
  }

  /** Column expression extracting the body text from a request column. */
  def request_to_string(c: Column): Column = request_to_string_udf(c)

  /** Shared builder: no headers, no protocol version, locale "en". */
  private def buildResponse(entity: Option[EntityData], code: Int, reason: String): HTTPResponseData = {
    HTTPResponseData(
      Array.empty[HeaderData],
      entity,
      StatusLineData(null, code, reason),
      "en")
  }

  /** Builds a response whose body is the given string.
    *
    * @param x      body text
    * @param code   status code
    * @param reason reason phrase
    */
  def stringToResponse(x: String, code: Int, reason: String): HTTPResponseData = {
    buildResponse(Some(stringToEntity(x)), code, reason)
  }

  private val string_to_response_udf: UserDefinedFunction =
    udf(stringToResponse _, HTTPResponseData.schema)

  /** Column expression building a response from a string body column. */
  def string_to_response(str: Column, code: Column = lit(200), reason: Column = lit("Success")): Column =
    string_to_response_udf(str, code, reason)

  /** Builds a response without an entity.
    *
    * @param code   status code
    * @param reason reason phrase
    */
  def emptyResponse(code: Int, reason: String): HTTPResponseData = {
    buildResponse(None, code, reason)
  }

  private val empty_response_udf: UserDefinedFunction =
    udf(emptyResponse _, HTTPResponseData.schema)

  /** Column expression building a body-less response. */
  def empty_response(code: Column = lit(200), reason: Column = lit("Success")): Column =
    empty_response_udf(code, reason)

  /** Builds a response whose body is the given bytes.
    *
    * @param x      body bytes
    * @param code   status code
    * @param reason reason phrase
    */
  def binaryToResponse(x: Array[Byte], code: Int, reason: String): HTTPResponseData = {
    buildResponse(Some(binaryToEntity(x)), code, reason)
  }

  private val binary_to_response_udf: UserDefinedFunction =
    udf(binaryToResponse _, HTTPResponseData.schema)

  /** Column expression building a response from a binary body column. */
  def binary_to_response(ba: Column, code: Column = lit(200), reason: Column = lit("Success")): Column =
    binary_to_response_udf(ba, code, reason)

  /** Assembles a request column from its parts and casts it to [[Request]].
    *
    * @param urlCol        column holding the request URI
    * @param headersCol    column holding the headers
    * @param methodCol     column holding the HTTP method
    * @param jsonEntityCol column holding the body text, wrapped via [[string_to_entity]]
    */
  def to_http_request(urlCol: Column, headersCol: Column, methodCol: Column, jsonEntityCol: Column): Column = {
    val pvd: Option[ProtocolVersionData] = None
    struct(
      struct(
        methodCol.alias("method"),
        urlCol.alias("uri"),
        typedLit(pvd).alias("protocolVersion")).alias("requestLine"),
      headersCol.alias("headers"),
      string_to_entity(jsonEntityCol).alias("entity")
    ).cast(Request)
  }

  /** Same as the column-based overload, resolving each argument as a column name. */
  def to_http_request(urlCol: String, headersCol: String, methodCol: String, jsonEntityCol: String): Column = {
    to_http_request(col(urlCol), col(headersCol), col(methodCol), col(jsonEntityCol))
  }
  //scalastyle:on
}