package com.exasol.spark.util

import java.sql.PreparedStatement
import java.sql.ResultSet

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
 * A helper object with functions to convert JDBC [[java.sql.ResultSet]] into
 * Spark [[org.apache.spark.sql.Row]] or vice versa.
 *
 * Most of the functions here are adapted from Spark's
 * `spark/sql/execution/datasources/jdbc/JdbcUtils.scala`.
 */
object Converter extends Logging {

  /**
   * Converts a [[java.sql.ResultSet]] into an iterator of
   * [[org.apache.spark.sql.Row]]s.
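   *
   * A minimal usage sketch (the `resultSet` below is an illustrative
   * assumption, e.g. obtained from a JDBC query against an Exasol table):
   *
   * {{{
   * val schema = StructType(Seq(
   *   StructField("ID", LongType),
   *   StructField("NAME", StringType)
   * ))
   * val rows: Iterator[Row] = Converter.resultSetToRows(resultSet, schema)
   * }}}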
   */
  def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
    val encoder = RowEncoder(schema).resolveAndBind()
    val internalRows = resultSetToSparkInternalRows(resultSet, schema)
    internalRows.map(encoder.fromRow)
  }

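  /**
   * Wraps a [[java.sql.ResultSet]] as an iterator of
   * [[org.apache.spark.sql.catalyst.InternalRow]]s. A single mutable row is
   * reused across iterations, and the result set is closed once iteration
   * finishes.
   */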
  @SuppressWarnings(Array("org.wartremover.warts.AsInstanceOf"))
  def resultSetToSparkInternalRows(
    resultSet: ResultSet,
    schema: StructType
  ): Iterator[InternalRow] = new NextIterator[InternalRow] {
    private[this] val rs = resultSet
    private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
    private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))

    override protected def close(): Unit =
      try {
        rs.close()
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }

    override protected def getNext(): InternalRow =
      if (rs.next()) {
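        // Apply each column's getter at the same field index, then null out the
        // field if the JDBC driver reported SQL NULL for the value just read.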
        var i = 0
        while (i < getters.length) {
          getters(i).apply(rs, mutableRow, i)
          if (rs.wasNull) mutableRow.setNullAt(i)
          i = i + 1
        }
        mutableRow
      } else {
        finished = true
        null.asInstanceOf[InternalRow] // scalastyle:ignore null
      }
  }

  // A `JDBCValueGetter` reads a single value from a `ResultSet` and writes it
  // into a field of an `InternalRow`. The last `Int` argument is the zero-based
  // index of the field in the row; the same index, shifted by one, addresses
  // the column in the `ResultSet`.
  private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit

  /**
   * Creates one `JDBCValueGetter` per field of the given
   * [[org.apache.spark.sql.types.StructType]], so that each value read from the
   * `ResultSet` is written into the corresponding field of an
   * [[org.apache.spark.sql.catalyst.InternalRow]].
   */
  private def makeGetters(schema: StructType): Array[JDBCValueGetter] =
    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))

  // scalastyle:off null
  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
    case BooleanType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        row.setBoolean(pos, rs.getBoolean(pos + 1))

    case DateType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        // DateTimeUtils.fromJavaDate does not handle null values, so we need to
        // check for null here.
        val dateVal = rs.getDate(pos + 1)
        if (dateVal != null) {
          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
        } else {
          row.update(pos, null)
        }

    case dt: DecimalType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        val decimal = nullSafeConvert[java.math.BigDecimal](
          rs.getBigDecimal(pos + 1),
          d => Decimal(d, dt.precision, dt.scale)
        )
        row.update(pos, decimal)

    case DoubleType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        row.setDouble(pos, rs.getDouble(pos + 1))

    case FloatType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        row.setFloat(pos, rs.getFloat(pos + 1))

    case IntegerType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        row.setInt(pos, rs.getInt(pos + 1))

    case LongType if metadata.contains("binarylong") =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
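        // Columns tagged with the "binarylong" metadata arrive as raw bytes;
        // fold the byte array into a long, treating it as a big-endian value.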
        val bytes = rs.getBytes(pos + 1)
        var ans = 0L
        var j = 0
        while (j < bytes.length) {
          ans = 256 * ans + (255 & bytes(j))
          j = j + 1
        }
        row.setLong(pos, ans)

    case LongType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        row.setLong(pos, rs.getLong(pos + 1))

    case ShortType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        row.setShort(pos, rs.getShort(pos + 1))

    case StringType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))

    case TimestampType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        val t = rs.getTimestamp(pos + 1)
        if (t != null) {
          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
        } else {
          row.update(pos, null)
        }

    case BinaryType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        row.update(pos, rs.getBytes(pos + 1))

    case _ =>
      throw new IllegalArgumentException(
        s"Received an unsupported Spark type ${dt.catalogString}"
      )
  }
  // scalastyle:on null

  // scalastyle:off null
  private def nullSafeConvert[T](input: T, f: T => Any): Any =
    if (input == null) {
      null
    } else {
      f(input)
    }
  // scalastyle:on null

  // A `JDBCValueSetter` writes a single value from a `Row` into a parameter of
  // a `PreparedStatement`. The last `Int` argument is the zero-based index of
  // the field in the row; the same index, shifted by one, addresses the
  // parameter in the SQL statement.
  private[spark] type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

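  /**
   * Creates a `JDBCValueSetter` for the given Spark data type, which copies a
   * value from a [[org.apache.spark.sql.Row]] field into the corresponding
   * [[java.sql.PreparedStatement]] parameter.
   *
   * A minimal usage sketch (the `schema`, `row` and `insertStatement` values
   * below are illustrative assumptions, and `java.sql.Types.NULL` stands in
   * for a per-column JDBC null type):
   *
   * {{{
   * val setters = schema.fields.map(field => Converter.makeSetter(field.dataType))
   * var i = 0
   * while (i < setters.length) {
   *   if (row.isNullAt(i)) {
   *     insertStatement.setNull(i + 1, java.sql.Types.NULL)
   *   } else {
   *     setters(i)(insertStatement, row, i)
   *   }
   *   i += 1
   * }
   * }}}
   */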
  private[spark] def makeSetter(dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))

    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setShort(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setByte(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setString(pos + 1, row.getString(pos))

    case BinaryType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

    case _: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Received an unsupported Spark type ${dataType.catalogString} for field $pos"
        )
  }

}