All Downloads are FREE. Search and download functionalities are using the official Maven repository.

source.rest_json_after_2.source Maven / Gradle / Ivy

The newest version!
package org.apache.spark.sql.execution.datasources.rest.json

import java.net.URL
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference}

import com.google.common.base.Objects
import net.sf.json.{JSONArray, JSONObject}
import org.apache.http.client.fluent.Request
import org.apache.http.util.EntityUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JacksonParser}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
import streaming.common.JSONPath

/**
 * Data source provider registered under the short name `restJSON`.
 *
 * Recognised options:
 *  - `url` (falling back to `path`, then to the empty string): the endpoint to fetch.
 *  - `xPath`: path expression handed to the relation, `$` when absent.
 *  - `samplingRatio`: parsed as a `Double`, `1.0` when absent.
 *  - `tryTimes`: parsed as an `Int`, `3` when absent.
 */
class DefaultSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "restJSON"

  /**
   * Builds a [[RestJSONRelation]] from the supplied options. No pre-built input RDD
   * and no user-specified schema are passed to the relation.
   *
   * @param sqlContext the active SQL context, forwarded to the relation
   * @param parameters the data source options described on the class
   * @return the relation for the configured endpoint
   * @throws NumberFormatException if `samplingRatio` or `tryTimes` is not numeric
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // Parsing order matters only for which NumberFormatException surfaces first:
    // the sampling ratio is parsed before the retry count.
    val ratio = parameters.get("samplingRatio") match {
      case Some(raw) => raw.toDouble
      case None => 1.0
    }
    val endpoint = parameters.get("url").orElse(parameters.get("path")).getOrElse("")
    val pathExpression = parameters.get("xPath").getOrElse("$")
    val attempts = parameters.get("tryTimes") match {
      case Some(raw) => raw.toInt
      case None => 3
    }

    new RestJSONRelation(None, endpoint, pathExpression, ratio, attempts, None)(sqlContext)
  }

}

/**
 * A table-scan relation whose rows come from a JSON document fetched with an HTTP GET.
 *
 * The elements selected from the response body by `xPath` are each serialised back to a
 * JSON string and parsed with Spark's JSON machinery. When `inputRDD` is defined it is
 * used instead of fetching `url`.
 *
 * @param inputRDD        optional pre-built RDD of JSON strings; bypasses the HTTP fetch
 * @param url             endpoint fetched when `inputRDD` is empty
 * @param xPath           path expression passed to `JSONPath.read` to select the records
 * @param samplingRatio   fraction of records used for schema inference
 * @param tryFetchTimes   maximum number of fetch attempts (at least one is always made)
 * @param maybeDataSchema optional explicit schema; inferred from the data when empty
 * @param sqlContext      the SQL context used to build RDDs and read configuration
 */
private[sql] class RestJSONRelation(
    val inputRDD: Option[RDD[String]],
    val url: String,
    val xPath: String,
    val samplingRatio: Double,
    val tryFetchTimes: Int,
    val maybeDataSchema: Option[StructType]
)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  /**
   * Rejects schemas that contain the same column name more than once.
   *
   * @throws AnalysisException listing every duplicated column name
   */
  private def checkConstraints(schema: StructType): Unit = {
    val names = schema.fieldNames
    if (names.length != names.distinct.length) {
      val duplicateColumns = names.groupBy(identity).collect {
        case (name, occurrences) if occurrences.length > 1 => "\"" + name + "\""
      }.mkString(", ")
      throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " +
        s"cannot save to JSON format")
    }
  }

  // buildScan hands back the parser's output cast to RDD[Row], so Spark must not
  // run its own Row conversion on it.
  override val needConversion: Boolean = false

  override def schema: StructType = dataSchema

  /**
   * Fetches the first path in `inputPaths`, retrying up to `tryFetchTimes` times.
   *
   * An attempt counts as failed when the server answers with a non-200 status or when
   * the request raises an `IOException` (other than a malformed URL, which cannot
   * succeed on retry). Any other exception propagates immediately.
   *
   * @throws RuntimeException    if the last attempt still returned a non-200 status
   * @throws java.io.IOException if the last attempt failed with an I/O error
   */
  private def createBaseRdd(inputPaths: Array[String]): RDD[String] = {
    // Mirror the original do/while: at least one attempt even for tryFetchTimes <= 0.
    val maxAttempts = math.max(tryFetchTimes, 1)

    @scala.annotation.tailrec
    def attempt(attemptNo: Int): RDD[String] = {
      val outcome: Either[java.io.IOException, Option[RDD[String]]] =
        try {
          Right(tryCreateBaseRDD(inputPaths))
        } catch {
          case e: java.net.MalformedURLException => throw e // deterministic, do not retry
          case e: java.io.IOException => Left(e)
        }

      outcome match {
        case Right(Some(rdd)) => rdd
        case _ if attemptNo < maxAttempts => attempt(attemptNo + 1)
        // Out of attempts: surface the last failure with its original type.
        case Left(lastError) => throw lastError
        case Right(None) =>
          throw new RuntimeException(
            s"try fetch ${inputPaths(0)} ${tryFetchTimes} times,still fail.")
      }
    }

    attempt(1)
  }

  /**
   * Performs a single GET against the first path in `inputPaths`.
   *
   * @return `Some(rdd)` holding one JSON string per element selected by `xPath` when the
   *         server answers 200 with a body; `None` for any other status or a missing body
   */
  private def tryCreateBaseRDD(inputPaths: Array[String]): Option[RDD[String]] = {
    val response = Request.Get(new URL(inputPaths.head).toURI).execute().returnResponse()
    val succeeded = response != null && response.getStatusLine.getStatusCode == 200

    // Only touch the entity once the status is known to be 200; a body-less error
    // reply must yield None rather than blow up inside EntityUtils.
    val body: Option[String] =
      if (succeeded) {
        // UTF-8 is used only when the response does not declare its own charset.
        Option(response.getEntity).map(entity => EntityUtils.toString(entity, "UTF-8"))
      } else {
        None
      }

    body.map { content =>
      val selected = JSONArray.fromObject(JSONPath.read(content, xPath))
      // Index-based traversal avoids the deprecated implicit Java collection conversions.
      val records = (0 until selected.size()).map { i =>
        JSONObject.fromObject(selected.get(i)).toString
      }
      sqlContext.sparkContext.makeRDD(records)
    }
  }

  /**
   * The relation's schema: `maybeDataSchema` when supplied, otherwise inferred from
   * `inputRDD` or from a fresh fetch of `url`, honouring `samplingRatio`. Computed once.
   */
  lazy val dataSchema: StructType = {
    val jsonSchema = maybeDataSchema.getOrElse {
      // Forward the configured ratio; an empty option map would silently ignore it.
      val inferenceOptions: JSONOptions =
        new JSONOptions(Map[String, String]("samplingRatio" -> samplingRatio.toString))
      org.apache.spark.sql.execution.datasources.json.InferSchema.infer(
        inputRDD.getOrElse(createBaseRdd(Array(url))),
        sqlContext.conf.columnNameOfCorruptRecord, inferenceOptions)
    }
    checkConstraints(jsonSchema)

    jsonSchema
  }

  /**
   * Parses the JSON records (from `inputRDD`, or from a fresh fetch of `url`) against
   * `dataSchema`.
   */
  def buildScan(): RDD[Row] = {
    val parsedOptions: JSONOptions = new JSONOptions(Map[String, String]())
    val records = inputRDD.getOrElse(createBaseRdd(Array(url)))
    JacksonParser.parse(
      records,
      dataSchema,
      sqlContext.conf.columnNameOfCorruptRecord,
      parsedOptions
    ).asInstanceOf[RDD[Row]]
  }

  /**
   * Two relations are equal when they share the same input RDD instance (or both have
   * none), the same `url` and `xPath`, and the same schema. Comparing schemas may force
   * `dataSchema`, but only after the cheaper comparisons have passed.
   */
  override def equals(other: Any): Boolean = other match {
    case that: RestJSONRelation =>
      val sameInput = (inputRDD, that.inputRDD) match {
        case (Some(mine), Some(theirs)) => mine eq theirs
        case (None, None) => true
        case _ => false
      }
      sameInput &&
        url == that.url &&
        xPath == that.xPath &&
        dataSchema == that.dataSchema
    case _ => false
  }

  /** Hashes exactly the fields compared by `equals`. */
  override def hashCode(): Int = {
    Objects.hashCode(
      inputRDD,
      url,
      xPath,
      dataSchema)
  }

}







© 2015 - 2024 Weber Informatics LLC | Privacy Policy