
com.alibaba.hologres.spark.source.BaseHoloPartitionReader.scala

package com.alibaba.hologres.spark.source

import com.alibaba.hologres.client.model.TableSchema
import com.alibaba.hologres.org.postgresql.jdbc.PgConnection
import com.alibaba.hologres.spark.config.HologresConfigs
import com.alibaba.hologres.spark.exception.SparkHoloException
import com.alibaba.hologres.spark.utils.JDBCUtil
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory

import java.sql.{PreparedStatement, ResultSet, SQLException}

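/**
 * Reads one shard range of a Hologres table as Spark [[InternalRow]]s over a plain
 * JDBC cursor, filtering on the built-in hg_shard_id column so that multiple
 * partition readers can scan disjoint shard ranges in parallel.
 */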
class BaseHoloPartitionReader(hologresConfigs: HologresConfigs,
                              shardIdRange: (Int, Int),
                              holoSchema: TableSchema,
                              sparkSchema: StructType) {
  private val logger = LoggerFactory.getLogger(getClass)

  private var conn: PgConnection = _
  private var statement: PreparedStatement = _
  private var resultSet: ResultSet = _
  private val recordLength: Int = sparkSchema.fields.length
  // One typed FieldReader per projected column, resolved by name against the Hologres schema.
  private val fieldReaders: Array[FieldReader] = sparkSchema.fields.map { field =>
    val holoColumn = holoSchema.getColumn(holoSchema.getColumnIndex(field.name))
    FieldReaderUtils.createFieldReader(holoColumn.getType, holoColumn.getTypeName)
  }
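  // Eagerly open the JDBC connection and start the shard-range query at construction time.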
  init()

  def init(): Unit = {
    val queryTemplate: String = JDBCUtil.getSimpleSelectFromStatement(holoSchema.getTableNameObj.getFullName, sparkSchema.fields.map(_.name))
    // Restrict the scan to this partition's shard range via the hg_shard_id predicate.
    val query: String = s"$queryTemplate WHERE hg_shard_id >= ${shardIdRange._1} AND hg_shard_id < ${shardIdRange._2}"
    logger.info("bulk read query: {}", query)
    logger.info("sparkSchema: {}", sparkSchema)

    conn = JDBCUtil.createConnection(hologresConfigs).unwrap(classOf[PgConnection])
    // The PostgreSQL driver only honors setFetchSize (cursor-based streaming) when autocommit is off.
    conn.setAutoCommit(false)
    statement = conn.prepareStatement(query, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
    statement.setFetchSize(hologresConfigs.scan_batch_size)
    statement.setQueryTimeout(hologresConfigs.scan_timeout_seconds)
    resultSet = statement.executeQuery()
  }

  def next(): Boolean = {
    resultSet.next()
  }

  def get(): InternalRow = {
    convertHologresRecordToRow(resultSet)
  }

  private def convertHologresRecordToRow(resultSet: ResultSet): InternalRow = {
    val res: Array[Any] = new Array[Any](recordLength)
    for (i <- 0 until recordLength) {
      // JDBC column indexes are 1-based; SQL NULL maps to null, everything else
      // goes through the column's typed FieldReader.
      if (resultSet.getObject(i + 1) == null) {
        res(i) = null
      } else {
        res(i) = fieldReaders(i).readValue(resultSet, i + 1)
      }
    }
    new GenericInternalRow(res)
  }

  def close(): Unit = {
    if (conn != null) {
      try {
        // Closing the connection also releases the prepared statement and result set.
        conn.close()
      } catch {
        case e: SQLException =>
          throw new SparkHoloException(e)
      } finally {
        conn = null
      }
    }
    logger.debug("BaseHoloPartitionReader closed")
  }

}
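For orientation, a minimal sketch of how this reader could be driven by a caller such as a Spark DataSource partition reader; the hologresConfigs, holoSchema, and sparkSchema values and the (0, 16) shard range below are hypothetical placeholders, not values taken from the connector:

// Hypothetical illustration: hologresConfigs, holoSchema, and sparkSchema are
// assumed to be built elsewhere; the (0, 16) shard range is a placeholder.
val reader = new BaseHoloPartitionReader(hologresConfigs, (0, 16), holoSchema, sparkSchema)
try {
  while (reader.next()) {
    val row = reader.get() // one Spark InternalRow per Hologres row
    // ... hand the row to Spark ...
  }
} finally {
  reader.close() // releases the JDBC connection
}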



