All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.api.csharp.SQLUtils.scala Maven / Gradle / Ivy

The newest version!
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

package org.apache.spark.sql.api.csharp

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import org.apache.spark.SparkContext
import org.apache.spark.api.csharp.SerDe
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{DataType, FloatType, StructField, StructType}
import org.apache.spark.sql._

/**
 * Utility functions for DataFrame in SparkCLR
 * The implementation is mostly identical to the SQLUtils used by R
 * since CSharpSpark derives most of the design ideas and
 * implementation constructs from SparkR
 */
object SQLUtils {
  def createSQLContext(sc: SparkContext): SQLContext = {
    new SQLContext(sc)
  }

  def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = {
    new JavaSparkContext(sqlCtx.sparkContext)
  }

  def toSeq[T](arr: Array[T]): Seq[T] = {
    arr.toSeq
  }

  def createStructType(fields : Seq[StructField]): StructType = {
    StructType(fields)
  }

  def getSQLDataType(dataType: String): DataType = {
    dataType match {
      case "byte" => org.apache.spark.sql.types.ByteType
      case "integer" => org.apache.spark.sql.types.IntegerType
      case "float" => org.apache.spark.sql.types.FloatType
      case "double" => org.apache.spark.sql.types.DoubleType
      case "numeric" => org.apache.spark.sql.types.DoubleType
      case "character" => org.apache.spark.sql.types.StringType
      case "string" => org.apache.spark.sql.types.StringType
      case "binary" => org.apache.spark.sql.types.BinaryType
      case "raw" => org.apache.spark.sql.types.BinaryType
      case "logical" => org.apache.spark.sql.types.BooleanType
      case "boolean" => org.apache.spark.sql.types.BooleanType
      case "timestamp" => org.apache.spark.sql.types.TimestampType
      case "date" => org.apache.spark.sql.types.DateType
      case _ => throw new IllegalArgumentException(s"Invaid type $dataType")
    }
  }

  def createStructField(name: String, dataType: String, nullable: Boolean): StructField = {
    val dtObj = getSQLDataType(dataType)
    StructField(name, dtObj, nullable)
  }

  def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
    val num = schema.fields.size
    val rowRDD = rdd.map(bytesToRow(_, schema))
    sqlContext.createDataFrame(rowRDD, schema)
  }

  def dfToRowRDD(df: DataFrame): RDD[Array[Byte]] = {
    df.map(r => rowToCSharpBytes(r))
  }

  private[this] def doConversion(data: Object, dataType: DataType): Object = {
    data match {
      case d: java.lang.Double if dataType == FloatType =>
        new java.lang.Float(d)
      case _ => data
    }
  }

  private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
    val bis = new ByteArrayInputStream(bytes)
    val dis = new DataInputStream(bis)
    val num = SerDe.readInt(dis)
    Row.fromSeq((0 until num).map { i =>
      doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
    }.toSeq)
  }

  private[this] def rowToCSharpBytes(row: Row): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)

    SerDe.writeInt(dos, row.length)
    (0 until row.length).map { idx =>
      val obj: Object = row(idx).asInstanceOf[Object]
      SerDe.writeObject(dos, obj)
    }
    bos.toByteArray()
  }

  def dfToCols(df: DataFrame): Array[Array[Byte]] = {
    // localDF is Array[Row]
    val localDF = df.collect()
    val numCols = df.columns.length
    // dfCols is Array[Array[Any]]
    val dfCols = convertRowsToColumns(localDF, numCols)

    dfCols.map { col =>
      colToCSharpBytes(col)
    }
  }

  def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = {
    (0 until numCols).map { colIdx =>
      localDF.map { row =>
        row(colIdx)
      }
    }.toArray
  }

  def colToCSharpBytes(col: Array[Any]): Array[Byte] = {
    val numRows = col.length
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)

    SerDe.writeInt(dos, numRows)

    col.map { item =>
      val obj: Object = item.asInstanceOf[Object]
      SerDe.writeObject(dos, obj)
    }
    bos.toByteArray()
  }

  def saveMode(mode: String): SaveMode = {
    mode match {
      case "append" => SaveMode.Append
      case "overwrite" => SaveMode.Overwrite
      case "error" => SaveMode.ErrorIfExists
      case "ignore" => SaveMode.Ignore
    }
  }

  def loadDF(
              sqlContext: SQLContext,
              source: String,
              options: java.util.Map[String, String]): DataFrame = {
    sqlContext.read.format(source).options(options).load()
  }

  def loadDF(
              sqlContext: SQLContext,
              source: String,
              schema: StructType,
              options: java.util.Map[String, String]): DataFrame = {
    sqlContext.read.format(source).schema(schema).options(options).load()
  }

  def loadDF(
              sqlContext: SQLContext,
              source: String,
              schema: StructType): DataFrame = {
    sqlContext.read.format(source).schema(schema).load()
  }

  def loadTextFile(sqlContext: SQLContext, path: String, hasHeader: java.lang.Boolean, inferSchema: java.lang.Boolean) : DataFrame = {
    var dfReader = sqlContext.read.format("com.databricks.spark.csv")
    if (hasHeader)
    {
      dfReader = dfReader.option("header", "true")
    }
    if (inferSchema)
    {
      dfReader = dfReader.option("inferSchema", "true")
    }
      dfReader.load(path)
  }

  def loadTextFile(sqlContext: SQLContext, path: String, delimiter: String, schema: StructType) : DataFrame = {
    val stringRdd = sqlContext.sparkContext.textFile(path)
    val rowRdd = stringRdd.map{s =>
      val columns = s.split(delimiter)
      columns.length match {
        case 1 => RowFactory.create(columns(0))
        case 2 => RowFactory.create(columns(0),columns(1))
        case 3 => RowFactory.create(columns(0),columns(1),columns(2))
        case 4 => RowFactory.create(columns(0),columns(1),columns(2),columns(3))
        case 5 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4))
        case 6 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5))
        case 7 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6))
        case 8 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7))
        case 9 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8))
        case 10 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9))
        case 11 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10))
        case 12 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11))
        case 13 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12))
        case 14 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13))
        case 15 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14))
        case 16 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15))
        case 17 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16))
        case 18 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17))
        case 19 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18))
        case 20 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19))
        case 21 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19),columns(20))
        case 22 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19),columns(20),columns(21))
        case 23 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19),columns(20),columns(21),columns(22))
        case 24 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19),columns(20),columns(21),columns(22),columns(23))
        case 25 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19),columns(20),columns(21),columns(22),columns(23),columns(24))
        case 26 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19),columns(20),columns(21),columns(22),columns(23),columns(24),columns(25))
        case 27 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19),columns(20),columns(21),columns(22),columns(23),columns(24),columns(25),columns(26))
        case 28 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19),columns(20),columns(21),columns(22),columns(23),columns(24),columns(25),columns(26),columns(27))
        case 29 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19),columns(20),columns(21),columns(22),columns(23),columns(24),columns(25),columns(26),columns(27),columns(28))
        case 30 => RowFactory.create(columns(0),columns(1),columns(2),columns(3),columns(4),columns(5),columns(6),columns(7),columns(8),columns(9),columns(10),columns(11),columns(12),columns(13),columns(14),columns(15),columns(16),columns(17),columns(18),columns(19),columns(20),columns(21),columns(22),columns(23),columns(24),columns(25),columns(26),columns(27),columns(28),columns(29))
        case _ => throw new Exception("Text files with more than 30 columns currently not supported") //TODO - if requirement comes up, generate code for additional columns
      }
    }

    sqlContext.createDataFrame(rowRdd, schema)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy