// org/apache/spark/dataset/DataCsvtLoader.scala
package org.apache.spark.dataset

import java.util.regex.Pattern

import com.datastax.insight.core.driver.SparkContextBuilder
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
object DataCsvtLoader {

  /** Loads a delimited text file into a DataFrame of nullable string columns.
    *
    * @param dataPath  path to the input file
    * @param header    whether the first line holds the column names
    * @param separator field separator, treated as a literal string
    */
  def loader(dataPath: String, header: Boolean, separator: String): DataFrame = {
    val spark = SparkContextBuilder.getSession

    // Quote the separator so regex metacharacters ("|", ".", "*", ...) match
    // literally in String.split, instead of hand-escaping each one.
    val sep = Pattern.quote(separator)
    val sc = SparkContextBuilder.getContext
    val rdd: RDD[String] = sc.textFile(dataPath)

    // first() assumes a non-empty input; the first line is used both to drop
    // the header and to determine the column count.
    val headerLine = rdd.first()

    // When a header row is present, drop every line equal to it.
    val rddDropHeader = if (header) rdd.filter(_ != headerLine) else rdd
    // Split with limit -1 to keep trailing empty fields, so every row has the
    // same number of columns as the schema.
    val dataResult = rddDropHeader.map(line => Row.fromSeq(line.split(sep, -1)))
    // Column names come from the header line, or default to _c0, _c1, ...
    // (the same convention Spark's built-in CSV reader uses).
    val columns = if (header) {
      headerLine.split(sep, -1)
    } else {
      headerLine.split(sep, -1).indices.map("_c" + _).toArray
    }
    // Every column is loaded as a nullable string; callers can cast later.
    val schema = StructType(columns.map(name => StructField(name, StringType, nullable = true)))
    // Reuse the session obtained above rather than fetching it a second time.
    spark.createDataFrame(dataResult, schema)
  }
}
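
// A minimal usage sketch, assuming SparkContextBuilder has already initialized
// a session and that "/data/people.csv" is a hypothetical pipe-delimited file
// with a header row:
//
//   val df = DataCsvtLoader.loader("/data/people.csv", header = true, separator = "|")
//   df.printSchema()   // all fields arrive as nullable strings
//   df.show(5)
//
// For plain CSV loading, Spark's built-in reader is the usual alternative:
//
//   spark.read.option("header", "true").option("sep", "|").csv("/data/people.csv")
//
// This loader differs in that it splits on an arbitrary literal separator and
// always yields string columns, leaving type casting to the caller.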