package org.apache.spark.sql.execution.datasources.rest.json

import java.net.URL
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference}
import com.google.common.base.Objects
import net.sf.json.{JSONArray, JSONObject}
import org.apache.http.client.fluent.Request
import org.apache.http.util.EntityUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JacksonParser}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
import streaming.common.JSONPath

class DefaultSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "restJSON"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
    val url = parameters.getOrElse("url", parameters.getOrElse("path", ""))
    val xPath = parameters.getOrElse("xPath", "$")
    val tryTimes = parameters.get("tryTimes").map(_.toInt).getOrElse(3)
    new RestJSONRelation(None, url, xPath, samplingRatio, tryTimes, None)(sqlContext)
  }
}
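
// Usage sketch (not part of the original file): the relation registers under
// the short name "restJSON", so it can be loaded through the standard reader
// API. The endpoint URL below is a hypothetical placeholder; "xPath" and
// "tryTimes" fall back to "$" and 3 when omitted.
//
//   val df = sqlContext.read.format("restJSON")
//     .option("url", "http://example.com/api/items") // "path" works as an alias
//     .option("xPath", "$.data")                     // JSONPath into the response body
//     .option("tryTimes", "5")                       // fetch retries before failing
//     .load()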

private[sql] class RestJSONRelation(
    val inputRDD: Option[RDD[String]],
    val url: String,
    val xPath: String,
    val samplingRatio: Double,
    val tryFetchTimes: Int,
    val maybeDataSchema: Option[StructType])
    (@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  /** Constraints to be imposed on the schema to be stored. */
  private def checkConstraints(schema: StructType): Unit = {
    if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
      val duplicateColumns = schema.fieldNames.groupBy(identity).collect {
        case (x, ys) if ys.length > 1 => "\"" + x + "\""
      }.mkString(", ")
      throw new AnalysisException(s"Duplicate column(s) $duplicateColumns found, " +
        "cannot save to JSON format")
    }
  }

  override val needConversion: Boolean = false

  override def schema: StructType = dataSchema

  // Fetch from the REST endpoint, retrying up to `tryFetchTimes` times before
  // giving up with an error.
  private def createBaseRdd(inputPaths: Array[String]): RDD[String] = {
    val counter = new AtomicInteger()
    val success = new AtomicBoolean(false)
    val holder = new AtomicReference[RDD[String]]()
    do {
      counter.incrementAndGet()
      tryCreateBaseRDD(inputPaths) match {
        case Some(rdd) =>
          success.set(true)
          holder.set(rdd)
        case None => // fetch failed; loop again until the retry budget is spent
      }
    } while (!success.get() && counter.get() < tryFetchTimes)
    if (!success.get()) {
      throw new RuntimeException(
        s"Tried to fetch ${inputPaths(0)} $tryFetchTimes times, still failing.")
    }
    holder.get()
  }

  // Issue a single GET request; on a 200 response, slice the body with the
  // configured JSONPath expression into one JSON object per RDD record.
  private def tryCreateBaseRDD(inputPaths: Array[String]): Option[RDD[String]] = {
    val url = inputPaths.head
    val response = Request.Get(new URL(url).toURI).execute().returnResponse()
    if (response != null && response.getStatusLine.getStatusCode == 200) {
      val content = EntityUtils.toString(response.getEntity)
      import scala.collection.JavaConversions._
      val extractedContent = JSONArray.fromObject(JSONPath.read(content, xPath))
        .map(f => JSONObject.fromObject(f).toString).toSeq
      Some(sqlContext.sparkContext.makeRDD(extractedContent))
    } else {
      None
    }
  }
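
  // Illustration of the extraction above (a sketch, not part of the original
  // source): given a response body of {"data": [{"a": 1}, {"a": 2}]} and an
  // xPath of "$.data", JSONPath.read selects the inner array, and the resulting
  // RDD holds two records, {"a":1} and {"a":2}, each a self-contained JSON
  // object ready for Jackson parsing.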

  // Infer the schema by parsing the fetched JSON, unless one was supplied.
  lazy val dataSchema: StructType = {
    val parsedOptions: JSONOptions = new JSONOptions(Map[String, String]())
    val jsonSchema = maybeDataSchema.getOrElse {
      org.apache.spark.sql.execution.datasources.json.InferSchema.infer(
        inputRDD.getOrElse(createBaseRdd(Array(url))),
        sqlContext.conf.columnNameOfCorruptRecord,
        parsedOptions)
    }
    checkConstraints(jsonSchema)
    jsonSchema
  }

  override def buildScan(): RDD[Row] = {
    val parsedOptions: JSONOptions = new JSONOptions(Map[String, String]())
    JacksonParser.parse(
      inputRDD.getOrElse(createBaseRdd(Array(url))),
      dataSchema,
      sqlContext.conf.columnNameOfCorruptRecord,
      parsedOptions).asInstanceOf[RDD[Row]]
  }

  override def equals(other: Any): Boolean = other match {
    case that: RestJSONRelation =>
      ((inputRDD, that.inputRDD) match {
        case (Some(thizRdd), Some(thatRdd)) => thizRdd eq thatRdd
        case (None, None) => true
        case _ => false
      }) && url == that.url &&
        dataSchema == that.dataSchema &&
        schema == that.schema
    case _ => false
  }

  override def hashCode(): Int = {
    Objects.hashCode(inputRDD, url, dataSchema, schema)
  }
}
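
// The relation can also be declared through the Spark SQL data source DDL
// (a sketch; the table name and URL are hypothetical):
//
//   CREATE TEMPORARY TABLE events
//   USING restJSON
//   OPTIONS (url "http://example.com/api/events", xPath "$.data")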