All downloads are free. The search and download functionality uses the official Maven repository.

com.springml.spark.salesforce.DatasetRelation.scala Maven / Gradle / Ivy

The newest version!
package com.springml.spark.salesforce

import java.math.BigDecimal
import java.net.URLEncoder
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Random

import com.springml.salesforce.wave.api.{ForceAPI, WaveAPI}
import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DateType, DecimalType}
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, LongType, ShortType}
import org.apache.spark.sql.types.{StructField, StructType, StringType, TimestampType}

import scala.collection.JavaConversions.{asScalaBuffer, mapAsScalaMap}

/**
 * Relation that reads records from Salesforce (through either the Wave API or
 * the Force API) and exposes them to Spark SQL as an `RDD[Row]`.
 *
 * The query is executed once, while the relation is being constructed, and the
 * resulting records are kept on the driver in [[records]]. Java collections
 * returned by the APIs are accessed through the Java collection API directly,
 * so this class does not depend on implicit Java/Scala conversions.
 *
 * @param waveAPI                  Wave client; when non-null it is used to run `query` as SAQL
 * @param forceAPI                 Force client; used only when `waveAPI` is null
 * @param query                    query text handed to the selected API
 * @param userSchema               schema supplied by the caller; when non-null it is used as is
 * @param sqlContext               context used to turn the driver-side records into RDDs
 * @param resultVariable           when defined, the Wave query is executed with pagination
 * @param pageSize                 page size passed to the paginated Wave query
 * @param sampleSize               number of records sampled to derive the header / infer the schema
 * @param encodeFields             comma-separated names of string fields to URL-encode (UTF-8)
 * @param inferSchema              whether column types are inferred from a sample of the records
 * @param replaceDatasetNameWithId whether dataset names in `load` statements are replaced by ids
 * @param sdf                      format used to parse timestamps; when null,
 *                                 `java.sql.Timestamp.valueOf` is used instead
 */
case class DatasetRelation(
    waveAPI: WaveAPI,
    forceAPI: ForceAPI,
    query: String,
    userSchema: StructType,
    sqlContext: SQLContext,
    resultVariable: Option[String],
    pageSize: Int,
    sampleSize: Int,
    encodeFields: Option[String],
    inferSchema: Boolean,
    replaceDatasetNameWithId: Boolean,
    sdf: SimpleDateFormat) extends BaseRelation with TableScan {

  private val logger = Logger.getLogger(classOf[DatasetRelation])

  /** Records fetched from Salesforce at construction time; never null. */
  val records: java.util.List[java.util.Map[String, String]] = read()

  /**
   * Names of the string fields that must be URL-encoded, parsed once from
   * `encodeFields`. Entries are trimmed so that "a, b" matches field "b".
   */
  private lazy val fieldsToEncode: Set[String] = encodeFields match {
    case Some(fields) => fields.split(",").map(_.trim).filter(_.nonEmpty).toSet
    case _ => Set.empty[String] // covers both None and a null Option
  }

  /**
   * Schema resolved once and cached. The header is derived from a random
   * sample, so recomputing it on every call could produce a schema that differs
   * from the one already handed to Spark.
   */
  private lazy val resolvedSchema: StructType = {
    if (userSchema != null) {
      userSchema
    } else if (records == null || records.isEmpty()) {
      new StructType()
    } else {
      val columns = header
      logger.debug("header size " + columns.length)
      if (inferSchema) {
        // The same header drives both the sample layout and the inferred field names
        InferSchema(sampleRDD(columns), columns, sdf)
      } else {
        StructType(columns.map(name => StructField(name, StringType, nullable = true)))
      }
    }
  }

  /**
   * Executes the query against the Wave API when `waveAPI` is set, otherwise
   * against the Force API when `forceAPI` is set.
   *
   * @return the fetched records; an empty list when neither API is configured
   *         or the API returned null
   */
  def read(): java.util.List[java.util.Map[String, String]] = {
    // Query getting executed here
    val fetched: java.util.List[java.util.Map[String, String]] =
      if (waveAPI != null) {
        queryWave()
      } else if (forceAPI != null) {
        querySF()
      } else {
        logger.warn("Neither WaveAPI nor ForceAPI is configured. No records are read")
        null
      }

    if (fetched == null) new java.util.ArrayList[java.util.Map[String, String]]() else fetched
  }

  /** Runs the SAQL query through the Wave API, following pagination when requested. */
  private def queryWave(): java.util.List[java.util.Map[String, String]] = {
    val saql =
      if (replaceDatasetNameWithId) {
        logger.debug("Original Query " + query)
        val modified = replaceDatasetNameWithId(query, 0)
        logger.debug("Modified Query " + modified)
        modified
      } else {
        query
      }

    resultVariable match {
      case Some(variable) =>
        // Accumulate pages in a list we own instead of mutating the API's list
        val collected = new java.util.ArrayList[java.util.Map[String, String]]()
        var resultSet = waveAPI.queryWithPagination(saql, variable, pageSize)
        collected.addAll(resultSet.getResults.getRecords)
        while (!resultSet.isDone()) {
          resultSet = waveAPI.queryMore(resultSet)
          collected.addAll(resultSet.getResults.getRecords)
        }
        collected
      case _ =>
        // None, or a null Option: single non-paginated query
        waveAPI.query(saql).getResults.getRecords
    }
  }

  /**
   * Replaces the double-quoted dataset name that follows each `load` in the
   * query with the id returned by `waveAPI.getDatasetId`. Names for which no id
   * is returned are left untouched.
   *
   * Only the quoted argument itself is rewritten, as literal text; other
   * occurrences of the same text elsewhere in the query are not modified.
   *
   * NOTE(review): `load` is located with a plain substring search, so it also
   * matches inside longer words (e.g. "workload") - confirm whether that matters
   * for the queries in use.
   *
   * @param query      SAQL query text
   * @param startIndex index from which the search for `load` starts (null is treated as 0)
   * @return the query with dataset names replaced by their ids
   */
  def replaceDatasetNameWithId(query : String, startIndex : Integer) : String = {
    @scala.annotation.tailrec
    def replaceFrom(current: String, from: Int): String = {
      logger.debug("start Index : " + from)
      logger.debug("query : " + current)
      val loadIndex = current.indexOf("load", from)
      logger.debug("loadIndex : " + loadIndex + "\n")
      if (loadIndex == -1) {
        current
      } else {
        val openQuote = current.indexOf('"', loadIndex + 1)
        val closeQuote = if (openQuote == -1) -1 else current.indexOf('"', openQuote + 1)
        if (closeQuote == -1) {
          // No complete quoted dataset name after this load; nothing more to replace
          current
        } else {
          val datasetName = current.substring(openQuote + 1, closeQuote)
          val datasetId = waveAPI.getDatasetId(datasetName)
          if (datasetId == null) {
            replaceFrom(current, closeQuote + 1)
          } else {
            val replaced =
              current.substring(0, openQuote + 1) + datasetId + current.substring(closeQuote)
            // Resume just after the closing quote of the (possibly longer or shorter) id
            replaceFrom(replaced, openQuote + 1 + datasetId.length + 1)
          }
        }
      }
    }

    val from: Int = if (startIndex == null) 0 else startIndex.intValue()
    replaceFrom(query, from)
  }

  /** Runs the query through the Force API and collects every page of filtered records. */
  private def querySF(): java.util.List[java.util.Map[String, String]] = {
    // Accumulate pages in a list we own instead of mutating the API's list
    val collected = new java.util.ArrayList[java.util.Map[String, String]]()

    var resultSet = forceAPI.query(query)
    collected.addAll(resultSet.filterRecords())

    while (!resultSet.isDone()) {
      resultSet = forceAPI.queryMore(resultSet)
      collected.addAll(resultSet.filterRecords())
    }

    collected
  }

  /**
   * Converts the textual field value into the JVM value matching `toType`.
   *
   * A null value yields null for string fields and nullable fields. An empty
   * string yields null for nullable non-string fields. String values are
   * URL-encoded when the field is listed in `encodeFields`.
   *
   * @throws IllegalArgumentException when the value is null for a non-nullable, non-string field
   * @throws RuntimeException when `toType` is not one of the supported types
   */
  private def cast(fieldValue: String, toType: DataType,
      nullable: Boolean = true, fieldName: String): Any = {
    val isString = toType match {
      case _: StringType => true
      case _ => false
    }

    if (fieldValue == null) {
      if (isString || nullable) {
        null
      } else {
        throw new IllegalArgumentException(
          s"Null value found for non-nullable field $fieldName of type ${toType.typeName}")
      }
    } else if (fieldValue.isEmpty && nullable && !isString) {
      null
    } else {
      toType match {
        case _: ByteType => fieldValue.toByte
        case _: ShortType => fieldValue.toShort
        case _: IntegerType => fieldValue.toInt
        case _: LongType => fieldValue.toLong
        case _: FloatType => fieldValue.toFloat
        case _: DoubleType => fieldValue.toDouble
        case _: BooleanType => fieldValue.toBoolean
        // Thousands separators are stripped before parsing
        case _: DecimalType => new BigDecimal(fieldValue.replace(",", ""))
        case _: TimestampType =>
          if (sdf != null) {
            new Timestamp(sdf.parse(fieldValue).getTime)
          } else {
            Timestamp.valueOf(fieldValue)
          }
        case _: DateType => Date.valueOf(fieldValue)
        case _: StringType => encode(fieldValue, fieldName)
        case _ => throw new RuntimeException(s"Unsupported data type: ${toType.typeName}")
      }
    }
  }

  /** URL-encodes `value` as UTF-8 when `fieldName` is configured for encoding. */
  private def encode(value: String, fieldName: String): String = {
    if (shouldEncode(fieldName)) URLEncoder.encode(value, "UTF-8") else value
  }

  /** Whether `fieldName` is one of the fields listed in `encodeFields`. */
  private def shouldEncode(fieldName: String) : Boolean = fieldsToEncode.contains(fieldName)

  /**
   * Builds the RDD of sample rows used for schema inference: the first
   * `getSampleSize` records, each laid out in the order of `columns` so that
   * every value sits under the header entry it belongs to.
   */
  private def sampleRDD(columns: Array[String]): RDD[Array[String]] = {
    val sampleCount = getSampleSize
    logger.debug("Sample Size : " + sampleCount)
    val sampleRows = Array.tabulate(sampleCount) { i =>
      val record = records.get(i)
      columns.map(name => fieldValue(record, name))
    }

    // Converting the Array into RDD
    sqlContext.sparkContext.parallelize(sampleRows)
  }

  /**
   * Number of records to sample: `sampleSize` capped at the number of records,
   * and at least 1 when any record exists so that a header can always be derived.
   */
  private def getSampleSize : Int = {
    val totalRecordsSize = records.size()
    logger.debug("Total Record Size: " + totalRecordsSize)
    if (totalRecordsSize < sampleSize) {
      logger.debug("Total Record Size " + totalRecordsSize
          + " is Smaller than Sample Size "
          + sampleSize + ". So total records are used for sampling")
    }

    math.min(totalRecordsSize, math.max(sampleSize, 1))
  }

  /**
   * Column names taken from the sampled record that has the most keys.
   *
   * @return the key names in the map's iteration order; empty when nothing was sampled
   */
  private def header: Array[String] = {
    var widest = Array.empty[String]
    val sampled = sample.iterator()
    while (sampled.hasNext) {
      val current = sampled.next()
      logger.debug("record size " + current.size())
      if (current.size() > widest.length) {
        widest = current.keySet().toArray(new Array[String](0))
      }
    }

    widest
  }

  /**
   * Picks the records used to derive the header. All records are returned when
   * the sample covers the whole data set; otherwise `getSampleSize` records are
   * drawn at random (with replacement).
   */
  private def sample: java.util.List[java.util.Map[String, String]] = {
    val totalSize = records.size()
    val wanted = getSampleSize
    if (wanted >= totalSize) {
      records
    } else {
      val random = new Random()
      val sampleRecords = new java.util.ArrayList[java.util.Map[String, String]](wanted)
      (0 until wanted).foreach { _ =>
        sampleRecords.add(records.get(random.nextInt(totalSize)))
      }
      sampleRecords
    }
  }

  /**
   * Schema of the relation: the user-supplied schema when present, an empty
   * schema when there are no records, an inferred schema when `inferSchema` is
   * set, and otherwise one nullable string column per header entry.
   */
  override def schema: StructType = resolvedSchema

  /** Converts every fetched record into a `Row` matching [[schema]] and parallelizes them. */
  override def buildScan(): RDD[Row] = {
    val schemaFields = schema.fields
    logger.info("Total records size : " + records.size())
    // Checked once so per-field messages are not built when debug logging is off
    val debugEnabled = logger.isDebugEnabled

    val rowArray = new Array[Row](records.size())
    val remaining = records.iterator()
    var rowIndex: Int = 0
    while (remaining.hasNext) {
      val record = remaining.next()
      val values = schemaFields.map { field =>
        val value = fieldValue(record, field.name)
        if (debugEnabled) {
          logger.debug("fieldValue " + value)
        }
        cast(value, field.dataType, field.nullable, field.name)
      }

      if (debugEnabled) {
        logger.debug("rowIndex : " + rowIndex)
      }
      rowArray(rowIndex) = Row.fromSeq(values)
      rowIndex += 1
    }

    sqlContext.sparkContext.parallelize(rowArray)
  }

  /** Value stored under `name` in the record, or an empty string when the key is absent. */
  private def fieldValue(row: java.util.Map[String, String], name: String) : String = {
    if (row.containsKey(name)) {
      row.get(name)
    } else {
      logger.debug("Value not found for " + name)
      ""
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy