All Downloads are FREE. Search and download functionalities are using the official Maven repository.

tech.sourced.engine.iterator.MetadataIterator.scala Maven / Gradle / Ivy

The newest version!
package tech.sourced.engine.iterator

import java.sql.{Connection, DriverManager, ResultSet}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types.{ArrayType, BooleanType, TimestampType}

/**
  * Iterator that runs `sql` through a [[JDBCQueryIterator]] and exposes every
  * row as a map from attribute name to a value converted according to the
  * attribute's Spark data type.
  *
  * @param fields attributes describing the result columns, in column order
  * @param dbPath database path handed to the underlying [[JDBCQueryIterator]]
  * @param sql    query to execute
  */
class MetadataIterator(fields: Seq[Attribute],
                       dbPath: String,
                       sql: String)
  extends Iterator[Map[String, Any]] with Logging {

  private val iter = new JDBCQueryIterator(fields, dbPath, sql)

  override def hasNext: Boolean = iter.hasNext

  /**
    * Fetches the next raw row and converts each column.
    *
    * @return map keyed by attribute name; values may be `null` for SQL NULLs
    */
  override def next(): Map[String, Any] = {
    val values = iter.next()
    fields.zipWithIndex.map { case (attr, idx) =>
      attr.name -> convertValue(attr, values(idx))
    }.toMap
  }

  /** Releases the resources held by the underlying query iterator. */
  def close(): Unit = iter.close()

  /**
    * Converts a raw JDBC value into the representation used for the given
    * attribute's data type.
    *
    *  - booleans are stored as 0/1; anything else (including NULL) maps to `null`
    *  - arrays are stored as a single `|`-separated string
    *  - timestamps are stored as an integral number
    *  - every other type is passed through unchanged
    */
  private def convertValue(attr: Attribute, raw: Any): Any = attr.dataType match {
    case BooleanType =>
      raw match {
        case 0 => false
        case 1 => true
        case _ => null
      }

    case _: ArrayType =>
      // A NULL column used to blow up with a NullPointerException on toString.
      if (raw == null) null else raw.toString.split("\\|")

    case TimestampType =>
      raw match {
        // Keep NULL as NULL instead of silently turning it into the epoch
        // (unboxing null to Long yields 0).
        case null => null
        // A JDBC driver may hand integral values back as Integer or Long
        // depending on magnitude, so accept any Number rather than casting
        // straight to Long.
        case n: Number => new java.sql.Timestamp(n.longValue)
        case other =>
          throw new ClassCastException(
            s"cannot convert ${other.getClass.getName} to a timestamp " +
              s"for column ${attr.name}")
      }

    case _ => raw
  }

}

/**
  * Iterator that lazily executes `sql` against the SQLite database at
  * `dbPath` and yields every row as an array holding one raw JDBC value per
  * entry of `fields` (column `i + 1` of the result set for `fields(i)`).
  *
  * The connection is opened on the first call to `hasNext` and everything is
  * closed automatically once the result set is exhausted or an error occurs
  * while advancing the cursor.
  *
  * @param fields attributes describing the result columns, in column order
  * @param dbPath path of the SQLite database file
  * @param sql    query to execute
  */
class JDBCQueryIterator(fields: Seq[Attribute],
                        dbPath: String,
                        sql: String)
  extends Iterator[Array[Any]] with Logging {

  import scala.util.control.NonFatal

  private var rs: ResultSet = _
  private var stmt: java.sql.PreparedStatement = _
  private var conn: Connection = _
  // True once the query has been attempted, so a failed query is not re-run
  // on every subsequent hasNext call.
  private var initialized = false
  // True while the cursor sits on a row that next() has not returned yet.
  private var rowPending = false

  /**
    * Closes the result set, the statement and the connection. Failures while
    * closing the statement or the connection are only logged; a failure while
    * closing the result set is propagated after the other resources have been
    * released.
    */
  private[iterator] def close(): Unit = {
    // Nothing can be read from a closed result set any more.
    rowPending = false
    try {
      if (rs != null && !rs.isClosed) {
        rs.close()
      }
    } finally {
      if (stmt != null) {
        try {
          stmt.close()
        } catch {
          case NonFatal(e) => log.warn("could not close statement", e)
        }
      }
      if (conn != null && !conn.isClosed) {
        try {
          conn.close()
        } catch {
          case NonFatal(e) => log.warn(s"could not close connection", e)
        }
      }
    }
  }

  /**
    * Reports whether another row is available, advancing the cursor only when
    * the previously fetched row has already been handed out by `next()`.
    *
    * @return `true` if a row is ready to be read; `false` when the result set
    *         is exhausted, the query failed, or advancing the cursor failed
    */
  override def hasNext: Boolean = {
    if (rowPending) {
      // FIXME: RDD groupBy somehow calls hasNext twice, so we can't
      // advance the cursor until the next row has been collected to make sure
      // we don't skip rows.
      true
    } else {
      if (!initialized) {
        initializeResultSet()
        // Only reached when opening the connection and preparing the
        // statement did not throw; those failures may be retried.
        initialized = true
      }

      try {
        // rs stays null when the query itself failed to execute.
        if (rs != null && !rs.isClosed && rs.next) {
          rowPending = true
          true
        } else {
          close()
          false
        }
      } catch {
        case NonFatal(e) =>
          log.warn(s"caught an exception in JDBCIterator.hasNext", e)
          close()
          false
      }
    }
  }

  /**
    * Opens the connection and runs the query. Errors while connecting or
    * preparing the statement are propagated to the caller; an error while
    * executing the query is logged and leaves the iterator empty.
    */
  private def initializeResultSet(): Unit = {
    conn = DriverManager.getConnection(s"jdbc:sqlite:$dbPath")
    try {
      stmt = conn.prepareStatement(sql)
    } catch {
      case NonFatal(e) =>
        // Do not leak the freshly opened connection.
        close()
        throw e
    }
    try {
      rs = stmt.executeQuery()
    } catch {
      case NonFatal(e) =>
        log.warn(s"could not execute query", e)
        close()
    }
  }

  /**
    * Returns the raw values of the current row.
    *
    * @return one value per entry of `fields`, read with `ResultSet.getObject`
    * @throws NoSuchElementException if no further row is available
    */
  override def next(): Array[Any] = {
    // Also covers callers that invoke next() without a preceding hasNext.
    if (!hasNext) {
      throw new NoSuchElementException("next on empty iterator")
    }
    rowPending = false
    // JDBC columns are 1-based.
    Array.tabulate[Any](fields.length)(i => rs.getObject(i + 1))
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy