All downloads are free. The search and download features use the official Maven repository.

org.apache.spark.sql.api.r.SQLUtils.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.api.r

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.{Locale, Map => JMap}

import scala.collection.JavaConverters._
import scala.util.matching.Regex

import org.apache.spark.SparkContext
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericRowWithSchema}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.ShowTablesCommand
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.types._

private[sql] object SQLUtils extends Logging {
  // Runs when this object is initialized: registers readSqlObject and writeSqlObject (defined
  // below) with SerDe as its SQL read/write callbacks.
  SerDe.setSQLReadObject(readSqlObject).setSQLWriteObject(writeSqlObject)

  /**
   * Sets the catalog implementation entry on the given context's conf to "hive" and hands
   * the same context back, so the call can be used inline.
   */
  private[this] def withHiveExternalCatalog(sc: SparkContext): SparkContext = {
    val catalogKey = CATALOG_IMPLEMENTATION.key
    sc.conf.set(catalogKey, "hive")
    sc
  }

  /**
   * Gets or creates the SparkSession for the given context and then applies the supplied
   * configuration entries to it.
   *
   * The Hive-backed catalog is used only when Hive support is requested, the catalog
   * implementation setting (read with a default of "hive") equals "hive" ignoring case, and
   * the Hive classes are present. When Hive support was requested but one of the other
   * conditions fails, a warning is logged and a session without Hive support is used.
   *
   * @param jsc the Java Spark context wrapping the SparkContext to build the session on
   * @param sparkConfigMap configuration entries to apply after the session is obtained
   * @param enableHiveSupport whether the caller asked for Hive support
   * @return the SparkSession
   */
  def getOrCreateSparkSession(
      jsc: JavaSparkContext,
      sparkConfigMap: JMap[Object, Object],
      enableHiveSupport: Boolean): SparkSession = {
    val sparkContext = jsc.sc

    // A `def` so that the conf is only consulted when `enableHiveSupport` is true.
    def catalogIsHive: Boolean = {
      val implementation = sparkContext.conf.get(CATALOG_IMPLEMENTATION.key, "hive")
      implementation.toLowerCase(Locale.ROOT) == "hive"
    }

    // Note that the order of the conditions here is on purpose.
    // `SparkSession.hiveClassesArePresent` checks if Hive's `HiveConf` is loadable or not;
    // however, `HiveConf` itself has some static logic to check if the Hadoop version is
    // supported or not, which throws an `IllegalArgumentException` if unsupported.
    // If this were checked first, there would be no way to disable Hive support in that case.
    // So the Hive classes are probed last, and only reached via short-circuiting when Hive
    // support is explicitly enabled. See also SPARK-26422.
    val useHive = enableHiveSupport && catalogIsHive && SparkSession.hiveClassesArePresent

    val spark = if (useHive) {
      SparkSession.builder().sparkContext(withHiveExternalCatalog(sparkContext)).getOrCreate()
    } else {
      if (enableHiveSupport) {
        logWarning("SparkR: enableHiveSupport is requested for SparkSession but " +
          s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to " +
          "'hive', falling back to without Hive support.")
      }
      SparkSession.builder().sparkContext(sparkContext).getOrCreate()
    }

    setSparkContextSessionConf(spark, sparkConfigMap)
    spark
  }

  /**
   * Applies every entry of `sparkConfigMap` to the session's SQL conf and then to the conf of
   * the session's SparkContext. Keys and values are converted with `toString`.
   *
   * All session-level settings are written before any SparkContext-level setting.
   */
  def setSparkContextSessionConf(
      spark: SparkSession,
      sparkConfigMap: JMap[Object, Object]): Unit = {
    val entries = sparkConfigMap.asScala

    // First pass: session state conf.
    entries.foreach { case (key, setting) =>
      spark.sessionState.conf.setConfString(key.toString, setting.toString)
    }

    // Second pass: SparkContext conf.
    entries.foreach { case (key, setting) =>
      spark.sparkContext.conf.set(key.toString, setting.toString)
    }
  }

  /** Returns all settings of the session's runtime conf as a Java map. */
  def getSessionConf(spark: SparkSession): JMap[String, String] = {
    val settings = spark.conf.getAll
    settings.asJava
  }

  /** Wraps the session's SparkContext in a new JavaSparkContext. */
  def getJavaSparkContext(spark: SparkSession): JavaSparkContext = {
    val underlying = spark.sparkContext
    new JavaSparkContext(underlying)
  }

  /** Builds a StructType from the given sequence of fields. */
  def createStructType(fields: Seq[StructField]): StructType = StructType(fields)

  // Adds an `r` interpolator to StringContext so regexes can be written as interpolated
  // strings. The pattern is the concatenation of the literal parts, and one group name "x" is
  // supplied for every part after the first.
  private[this] implicit class RegexContext(sc: StringContext) {
    def r: Regex = {
      val pattern = sc.parts.mkString
      val groupNames = for (_ <- sc.parts.tail) yield "x"
      new Regex(pattern, groupNames: _*)
    }
  }

  /**
   * Builds a StructField whose data type is obtained by parsing the `dataType` string with
   * CatalystSqlParser.
   *
   * @param name field name
   * @param dataType textual data type, handed to `CatalystSqlParser.parseDataType`
   * @param nullable whether the field is nullable
   */
  def createStructField(name: String, dataType: String, nullable: Boolean): StructField =
    StructField(name, CatalystSqlParser.parseDataType(dataType), nullable)

  /**
   * Creates a DataFrame from an RDD of serialized rows.
   *
   * Each byte array is decoded into a Row with `bytesToRow` using `schema`, and the resulting
   * row RDD is passed to `sparkSession.createDataFrame` together with the same schema.
   *
   * @param rdd RDD whose elements are serialized rows
   * @param schema schema used both for decoding and for the resulting DataFrame
   * @param sparkSession session used to create the DataFrame
   */
  def createDF(rdd: RDD[Array[Byte]], schema: StructType, sparkSession: SparkSession): DataFrame = {
    val rowRDD = rdd.map(bytesToRow(_, schema))
    sparkSession.createDataFrame(rowRDD, schema)
  }

  /** Serializes every row of the DataFrame with `rowToRBytes` and returns the result as a JavaRDD. */
  def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
    val serializedRows = df.rdd.map(row => rowToRBytes(row))
    serializedRows
  }

  /**
   * Adjusts a deserialized value to the external type expected for `dataType`:
   *  - a `java.lang.Double` destined for a FloatType column becomes a `java.lang.Float`;
   *  - a `java.util.Map` is wrapped as a Scala map;
   *  - anything else is returned unchanged.
   */
  private[this] def doConversion(data: Object, dataType: DataType): Object = {
    (data, dataType) match {
      case (boxed: java.lang.Double, FloatType) =>
        java.lang.Float.valueOf(boxed.toFloat)
      // Scala Map is the only allowed external type of map type in Row.
      case (javaMap: java.util.Map[_, _], _) =>
        javaMap.asScala
      case _ =>
        data
    }
  }

  /**
   * Decodes one serialized row.
   *
   * The byte array starts with an int giving the number of values, followed by that many
   * objects readable by `SerDe.readObject`. Each value is passed through `doConversion` with
   * the data type of the schema field at the same position.
   */
  private[sql] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
    val in = new DataInputStream(new ByteArrayInputStream(bytes))
    val valueCount = SerDe.readInt(in)
    val values = Seq.tabulate(valueCount) { idx =>
      val raw = SerDe.readObject(in, jvmObjectTracker = null)
      doConversion(raw, schema.fields(idx).dataType)
    }
    Row.fromSeq(values)
  }

  /**
   * Serializes a row: its values are collected into an `Array[Object]`, written with
   * `SerDe.writeObject`, and the bytes produced are returned.
   */
  private[sql] def rowToRBytes(row: Row): Array[Byte] = {
    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)

    val cells = Array.tabulate[Object](row.length)(idx => row(idx).asInstanceOf[Object])
    SerDe.writeObject(out, cells, jvmObjectTracker = null)
    buffer.toByteArray()
  }

  // Schema for a DataFrame of serialized R data: a single binary column named "R".
  // TODO: introduce a user defined type for serialized R data.
  val SERIALIZED_R_DATA_SCHEMA = StructType(StructField("R", BinaryType) :: Nil)

  /**
   * The helper function for dapply() on the R side.
   *
   * Casts each broadcast variable to `Broadcast[Object]` and calls `df.mapPartitionsInR`.
   * When `schema` is null, `SERIALIZED_R_DATA_SCHEMA` is used as the output schema.
   */
  def dapply(
      df: DataFrame,
      func: Array[Byte],
      packageNames: Array[Byte],
      broadcastVars: Array[Object],
      schema: StructType): DataFrame = {
    val broadcasts = broadcastVars.map { case v => v.asInstanceOf[Broadcast[Object]] }
    val outputSchema = Option(schema).getOrElse(SERIALIZED_R_DATA_SCHEMA)
    df.mapPartitionsInR(func, packageNames, broadcasts, outputSchema)
  }

  /**
   * The helper function for gapply() on the R side.
   *
   * Casts each broadcast variable to `Broadcast[Object]` and calls `gd.flatMapGroupsInR`.
   * When `schema` is null, `SERIALIZED_R_DATA_SCHEMA` is used as the output schema.
   */
  def gapply(
      gd: RelationalGroupedDataset,
      func: Array[Byte],
      packageNames: Array[Byte],
      broadcastVars: Array[Object],
      schema: StructType): DataFrame = {
    val broadcasts = broadcastVars.map { case v => v.asInstanceOf[Broadcast[Object]] }
    val outputSchema = Option(schema).getOrElse(SERIALIZED_R_DATA_SCHEMA)
    gd.flatMapGroupsInR(func, packageNames, broadcasts, outputSchema)
  }


  /**
   * Collects the DataFrame and returns its contents column by column: the outer array has one
   * entry per column, and each inner array holds that column's value for every collected row.
   */
  def dfToCols(df: DataFrame): Array[Array[Any]] = {
    val rows: Array[Row] = df.collect()
    val columnCount = df.columns.length
    val rowCount = rows.length

    Array.tabulate[Array[Any]](columnCount) { col =>
      Array.tabulate[Any](rowCount)(row => rows(row)(col))
    }
  }

  /**
   * SQL-specific read callback. For type tag 's' a list is read from the stream and returned
   * as a Row; for any other tag null is returned.
   */
  def readSqlObject(dis: DataInputStream, dataType: Char): Object = {
    if (dataType == 's') {
      // Read StructType for DataFrame
      val values = SerDe.readList(dis, jvmObjectTracker = null).asInstanceOf[Array[Object]]
      Row.fromSeq(values)
    } else {
      null
    }
  }

  /**
   * SQL-specific write callback. A GenericRowWithSchema is written as the type tag 's'
   * followed by its field names and its values.
   *
   * @return true if `obj` was written here, false if it is not a GenericRowWithSchema
   */
  def writeSqlObject(dos: DataOutputStream, obj: Object): Boolean = {
    // Handle struct type in DataFrame
    val structRow = obj match {
      case row: GenericRowWithSchema => Some(row)
      case _ => None
    }
    structRow.foreach { row =>
      dos.writeByte('s')
      SerDe.writeObject(dos, row.schema.fieldNames, jvmObjectTracker = null)
      SerDe.writeObject(dos, row.values, jvmObjectTracker = null)
    }
    structRow.isDefined
  }

  /**
   * Runs a ShowTablesCommand and returns its result as a DataFrame. The database name is
   * passed to the command only when it is non-null and not blank; otherwise no database is
   * given.
   */
  def getTables(sparkSession: SparkSession, databaseName: String): DataFrame = {
    val database = Option(databaseName).filter(_.trim.nonEmpty)
    Dataset.ofRows(sparkSession, ShowTablesCommand(database, None))
  }

  /**
   * Lists the table names of a database through the session catalog. When `databaseName` is
   * null or blank, the session's current database is used instead.
   */
  def getTableNames(sparkSession: SparkSession, databaseName: String): Array[String] = {
    val db = Option(databaseName)
      .filter(_.trim.nonEmpty)
      .getOrElse(sparkSession.catalog.currentDatabase)
    val tables = sparkSession.sessionState.catalog.listTables(db)
    tables.map(_.table).toArray
  }

  /**
   * Builds an ArrayType whose element type is obtained by evaluating the column's expression
   * with `ExprUtils.evalTypeExpr`; the array is marked as possibly containing nulls.
   */
  def createArrayType(column: Column): ArrayType = {
    val elementType = ExprUtils.evalTypeExpr(column.expr)
    ArrayType(elementType, containsNull = true)
  }

  /**
   * R-callable entry point that reads a file in Arrow stream format and creates an `RDD` in
   * which each serialized ArrowRecordBatch is a partition. Delegates to
   * `ArrowConverters.readArrowStreamFromFile` with the session's SQLContext.
   */
  def readArrowStreamFromFile(
      sparkSession: SparkSession,
      filename: String): JavaRDD[Array[Byte]] = {
    val sqlContext = sparkSession.sqlContext
    ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
  }

  /**
   * R-callable entry point that creates a `DataFrame` from a `JavaRDD` of serialized
   * ArrowRecordBatches. The schema is handed to `ArrowConverters.toDataFrame` as its JSON
   * string, together with the session's SQLContext.
   */
  def toDataFrame(
      arrowBatchRDD: JavaRDD[Array[Byte]],
      schema: StructType,
      sparkSession: SparkSession): DataFrame = {
    val schemaJson = schema.json
    val sqlContext = sparkSession.sqlContext
    ArrowConverters.toDataFrame(arrowBatchRDD, schemaJson, sqlContext)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy