All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.spotify.scio.jdbc.sharded.JdbcShardedSource.scala Maven / Gradle / Ivy

There is a newer version: 0.14.9
Show newest version
/*
 * Copyright 2020 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.spotify.scio.jdbc.sharded

import java.sql.Connection
import java.sql.ResultSet
import org.apache.beam.sdk.coders.{Coder => BCoder}
import org.apache.beam.sdk.io.BoundedSource
import org.apache.beam.sdk.options.PipelineOptions
import org.apache.beam.sdk.transforms.windowing.BoundedWindow
import org.joda.time.Instant
import org.slf4j.LoggerFactory
import java.util.{List => jList}
import scala.jdk.CollectionConverters._

/**
 * A Beam [[BoundedSource]] that reads a JDBC table shard by shard.
 *
 * [[split]] looks up the min/max of the shard column and returns one child source per
 * [[ShardQuery]] produced by `readOptions.shard.partition`. Only a source that carries a `query`
 * can be read through [[createReader]]; a source without one can only be split.
 *
 * @param readOptions
 *   connection, table, shard column, fetch size and row mapping settings
 * @param coder
 *   coder returned by [[getOutputCoder]]
 * @param query
 *   the shard read by this source, or `None` for a source that has not been split yet
 * @tparam T
 *   type of the records produced by `readOptions.rowMapper`
 * @tparam S
 *   type of the shard column values
 */
final private[jdbc] class JdbcShardedSource[T, S](
  private val readOptions: JdbcShardedReadOptions[T, S],
  coder: BCoder[T],
  private val query: Option[ShardQuery] = None
) extends BoundedSource[T] {

  // Arguments, in order: shard column, shard column, table name.
  private val ShardBoundsQueryTemplate = "SELECT min(%s) min, max(%s) max FROM %s"

  // Transient and lazy so that the logger is not part of the serialized state of the source; the
  // slf4j Logger interface itself does not promise serializability.
  @transient private lazy val log = LoggerFactory.getLogger(this.getClass)

  /**
   * Queries the lower and upper bounds of the shard column.
   *
   * @return
   *   the `[min, max]` range of the shard column, or `None` when there is no non-NULL value to
   *   shard on (for instance because the table is empty)
   * @throws UnsupportedOperationException
   *   if the shard column is not among the indexed columns of the table
   */
  private def getShardColumnRange: Option[Range[S]] = {
    val connection = JdbcUtils.createConnection(readOptions.connectionOptions)
    try {
      val shardColumnIsIndexed = JdbcUtils
        .getIndexedColumns(connection, readOptions.tableName)
        .exists(_.equalsIgnoreCase(readOptions.shardColumn))

      if (!shardColumnIsIndexed) {
        throw new UnsupportedOperationException(
          s"Shard column '${readOptions.shardColumn}' isn't indexed. Sharding would be " +
            "inefficient"
        )
      }

      val boundsQuery = ShardBoundsQueryTemplate.format(
        readOptions.shardColumn,
        readOptions.shardColumn,
        readOptions.tableName
      )
      log.info("Executing query = [{}]", boundsQuery)
      val rs = connection.createStatement.executeQuery(boundsQuery)

      // An aggregate query without GROUP BY yields a single row even when the table is empty; in
      // that case min/max are SQL NULL. The bounds are therefore checked for NULL as well,
      // instead of relying on the presence of a row alone.
      if (rs.next() && rs.getObject("min") != null) {
        Some(
          Range(
            readOptions.shard.columnValueDecoder(rs, "min"),
            readOptions.shard.columnValueDecoder(rs, "max")
          )
        )
      } else {
        log.warn(
          "The table is empty or the shard column contains only NULL values. Nothing to read."
        )
        None
      }
    } finally {
      if (connection != null) {
        connection.close()
      }
    }
  }

  /**
   * Splits this source into one source per shard query.
   *
   * `desiredBundleSizeBytes` and `options` are not used: the number of splits is driven by
   * `readOptions.numShards`.
   *
   * @return
   *   the per-shard sources, or an empty list when no shard column range could be determined
   */
  override def split(
    desiredBundleSizeBytes: Long,
    options: PipelineOptions
  ): jList[_ <: BoundedSource[T]] =
    getShardColumnRange match {
      case None =>
        List.empty[BoundedSource[T]].asJava
      case Some(range) =>
        readOptions.shard
          .partition(range, readOptions.numShards)
          .map(shardQuery => new JdbcShardedSource(readOptions, coder, Some(shardQuery)))
          .asJava
    }

  override def getEstimatedSizeBytes(options: PipelineOptions): Long =
    // Returning 0 because there is no good estimate
    0

  override def getOutputCoder: BCoder[T] = coder

  /**
   * Creates a reader for the shard carried by this source.
   *
   * @throws UnsupportedOperationException
   *   if this source has no shard query
   */
  override def createReader(options: PipelineOptions): BoundedSource.BoundedReader[T] =
    query match {
      case Some(q) =>
        new JdbcShardedReader(this, q)
      case None =>
        throw new UnsupportedOperationException(
          "Not possible to create a read for a source with " +
            "the empty query"
        )
    }

  /**
   * Reader that runs the select statement of a single shard and maps each row with
   * `readOptions.rowMapper`.
   *
   * @param source
   *   the source this reader was created from; all settings are read through it
   * @param shardQuery
   *   the shard to read
   */
  private class JdbcShardedReader(source: JdbcShardedSource[T, S], shardQuery: ShardQuery)
      extends BoundedSource.BoundedReader[T] {
    // Both are assigned in start(); they stay null until then.
    private var connection: Connection = _
    private var resultSet: ResultSet = _

    /**
     * Opens the connection, runs the shard query and moves to the first row.
     *
     * @return
     *   `true` if the shard has at least one row
     */
    override def start(): Boolean = {
      val options = source.readOptions
      connection = JdbcUtils.createConnection(options.connectionOptions)
      val statement = connection.createStatement(
        ResultSet.TYPE_FORWARD_ONLY,
        ResultSet.CONCUR_READ_ONLY
      )

      // The fetch size is left untouched unless the user configured one.
      if (options.fetchSize != JdbcShardedReadOptions.UnboundedFetchSize) {
        log.info("Setting a user defined fetch size: [%s]".format(options.fetchSize))
        statement.setFetchSize(options.fetchSize)
      }

      val queryString =
        ShardQuery.toSelectStatement(shardQuery, options.tableName, options.shardColumn)

      log.info(s"Running a query: [$queryString]")
      resultSet = statement.executeQuery(queryString)
      resultSet.next
    }

    override def advance: Boolean = resultSet.next

    /** Maps the row the result set is currently positioned on. */
    override def getCurrent: T = source.readOptions.rowMapper(resultSet)

    override def getCurrentTimestamp: Instant = BoundedWindow.TIMESTAMP_MIN_VALUE

    /** Closes the connection if start() got as far as opening one. */
    override def close(): Unit = {
      if (connection != null) {
        connection.close()
        log.info("JDBC connection closed")
      }
    }

    override def getCurrentSource: BoundedSource[T] = source
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy