com.spotify.scio.bigquery.BigQueryClient.scala
/*
* Copyright 2016 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.spotify.scio.bigquery
import java.io.{File, FileInputStream, StringReader}
import java.util.UUID
import java.util.regex.Pattern
import com.google.api.client.auth.oauth2.Credential
import com.google.api.client.googleapis.auth.oauth2.GoogleCredential
import com.google.api.client.googleapis.json.GoogleJsonResponseException
import com.google.api.client.http.javanet.NetHttpTransport
import com.google.api.client.http.{HttpRequest, HttpRequestInitializer}
import com.google.api.client.json.JsonObjectParser
import com.google.api.client.json.jackson2.JacksonFactory
import com.google.api.services.bigquery.model._
import com.google.api.services.bigquery.{Bigquery, BigqueryScopes}
import com.google.cloud.dataflow.sdk.io.BigQueryIO
import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.CreateDisposition._
import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.WriteDisposition._
import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.{CreateDisposition, WriteDisposition}
import com.google.cloud.dataflow.sdk.options.GcpOptions.DefaultProjectFactory
import com.google.cloud.dataflow.sdk.util.{BigQueryTableInserter, BigQueryTableRowIterator}
import com.google.cloud.hadoop.util.ApiErrorExtractor
import com.google.common.base.Charsets
import com.google.common.hash.Hashing
import com.google.common.io.Files
import org.apache.commons.io.FileUtils
import org.joda.time.format.{DateTimeFormat, PeriodFormatterBuilder}
import org.joda.time.{Instant, Period}
import org.slf4j.{Logger, LoggerFactory}
import scala.collection.mutable.{Map => MMap}
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import scala.util.{Failure, Random, Success, Try}
/** Utility for BigQuery data types. */
object BigQueryUtil {
// Ported from com.google.cloud.dataflow.sdk.io.BigQueryIO
private val PROJECT_ID_REGEXP = "[a-z][-a-z0-9:.]{4,61}[a-z0-9]"
private val DATASET_REGEXP = "[-\\w.]{1,1024}"
private val TABLE_REGEXP = "[-\\w$@]{1,1024}"
private val DATASET_TABLE_REGEXP =
s"((?$PROJECT_ID_REGEXP):)?(?$DATASET_REGEXP)\\.(?$TABLE_REGEXP)"
private val QUERY_TABLE_SPEC = Pattern.compile(s"(?<=\\[)$DATASET_TABLE_REGEXP(?=\\])")
/** Parse a schema string. */
def parseSchema(schemaString: String): TableSchema =
new JsonObjectParser(new JacksonFactory)
.parseAndClose(new StringReader(schemaString), classOf[TableSchema])
}
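// Example: parsing a schema string (a minimal sketch; the JSON below is a placeholder schema):
//   val schema = BigQueryUtil.parseSchema(
//     """{"fields": [{"name": "word", "type": "STRING", "mode": "REQUIRED"}]}""")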
/** A query job that may delay execution. */
private[scio] trait QueryJob {
def waitForResult(): Unit
val jobReference: Option[JobReference]
val query: String
val table: TableReference
}
/** A simple BigQuery client. */
// scalastyle:off number.of.methods
class BigQueryClient private (private val projectId: String,
credential: Credential = null) { self =>
def this(projectId: String, secretFile: File) =
this(
projectId,
GoogleCredential
.fromStream(new FileInputStream(secretFile))
.createScoped(BigQueryClient.SCOPES))
private lazy val bigquery: Bigquery = {
val c = Option(credential).getOrElse(
GoogleCredential.getApplicationDefault.createScoped(BigQueryClient.SCOPES))
val requestInitializer = new HttpRequestInitializer {
override def initialize(request: HttpRequest): Unit = {
BigQueryClient.connectTimeoutMs.foreach(request.setConnectTimeout)
BigQueryClient.readTimeoutMs.foreach(request.setReadTimeout)
// Credential also implements HttpRequestInitializer
c.initialize(request)
}
}
new Bigquery.Builder(new NetHttpTransport, new JacksonFactory, c)
.setHttpRequestInitializer(requestInitializer)
.setApplicationName("scio")
.build()
}
private val logger: Logger = LoggerFactory.getLogger(classOf[BigQueryClient])
private val TABLE_PREFIX = "scio_query"
private val TIME_FORMATTER = DateTimeFormat.forPattern("yyyyMMddHHmmss")
private val PERIOD_FORMATTER = new PeriodFormatterBuilder()
.appendHours().appendSuffix("h")
.appendMinutes().appendSuffix("m")
.appendSecondsWithOptionalMillis().appendSuffix("s")
.toFormatter
private val STAGING_DATASET_PREFIX = "scio_bigquery_staging_"
private val STAGING_DATASET_TABLE_EXPIRATION_MS = 86400000L
private val STAGING_DATASET_DESCRIPTION = "Staging dataset for temporary tables"
private def isInteractive =
Thread
.currentThread()
.getStackTrace
.exists { e =>
e.getClassName.startsWith("scala.tools.nsc.interpreter.") ||
e.getClassName.startsWith("org.scalatest.tools.")
}
private val PRIORITY = if (isInteractive) "INTERACTIVE" else "BATCH"
/** Get schema for a query without executing it. */
def getQuerySchema(sqlQuery: String): TableSchema = withCacheKey(sqlQuery) {
if (isLegacySql(sqlQuery, flattenResults = false)) {
// Dry-run is not supported for legacy SQL queries, so use a view as a workaround
logger.info("Getting legacy query schema with view")
val location = extractLocation(sqlQuery)
prepareStagingDataset(location)
val temp = temporaryTable(location)
// Create temporary table view and get schema
logger.info(s"Creating temporary view ${BigQueryIO.toTableSpec(temp)}")
val view = new ViewDefinition().setQuery(sqlQuery)
val viewTable = new Table().setView(view).setTableReference(temp)
val schema = bigquery
.tables().insert(temp.getProjectId, temp.getDatasetId, viewTable)
.execute().getSchema
// Delete temporary table
logger.info(s"Deleting temporary view ${BigQueryIO.toTableSpec(temp)}")
bigquery.tables().delete(temp.getProjectId, temp.getDatasetId, temp.getTableId).execute()
schema
} else {
// Get query schema via dry-run
logger.info("Getting SQL query schema with dry-run")
runQuery(
sqlQuery, null,
flattenResults = false, useLegacySql = false, dryRun = true)
.get.getStatistics.getQuery.getSchema
}
}
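// Example: getting the schema of a query without running it (a sketch; the table in the
// query is a placeholder and application default credentials are assumed):
//   val bq = BigQueryClient.defaultInstance()
//   val schema = bq.getQuerySchema("SELECT word, word_count FROM [my-project:my_dataset.my_table]")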
/** Get rows from a query. */
def getQueryRows(sqlQuery: String, flattenResults: Boolean = false): Iterator[TableRow] = {
val queryJob = newQueryJob(sqlQuery, flattenResults)
queryJob.waitForResult()
getTableRows(queryJob.table)
}
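// Example: iterating over query results (a sketch; the query references a placeholder table):
//   val bq = BigQueryClient.defaultInstance()
//   bq.getQueryRows("SELECT word FROM [my-project:my_dataset.my_table] LIMIT 10")
//     .foreach(row => println(row.toPrettyString))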
/** Get rows from a table. */
def getTableRows(tableSpec: String): Iterator[TableRow] =
getTableRows(BigQueryIO.parseTableSpec(tableSpec))
/** Get rows from a table. */
def getTableRows(table: TableReference): Iterator[TableRow] = new Iterator[TableRow] {
private val iterator = BigQueryTableRowIterator.fromTable(table, bigquery)
private var _isOpen = false
private var _hasNext = false
private def init(): Unit = if (!_isOpen) {
iterator.open()
_isOpen = true
_hasNext = iterator.advance()
}
override def hasNext: Boolean = {
init()
_hasNext
}
override def next(): TableRow = {
init()
if (_hasNext) {
val r = iterator.getCurrent
_hasNext = iterator.advance()
r
} else {
throw new NoSuchElementException
}
}
}
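// Example: reading rows directly from a table (a sketch with a placeholder table spec):
//   val bq = BigQueryClient.defaultInstance()
//   val firstRow = bq.getTableRows("my-project:my_dataset.my_table").next()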
/** Get schema from a table. */
def getTableSchema(tableSpec: String): TableSchema =
getTableSchema(BigQueryIO.parseTableSpec(tableSpec))
/** Get schema from a table. */
def getTableSchema(table: TableReference): TableSchema =
withCacheKey(BigQueryIO.toTableSpec(table)) {
getTable(table).getSchema
}
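// Example: inspecting a table's schema (a sketch with a placeholder table spec):
//   val bq = BigQueryClient.defaultInstance()
//   val fieldNames = bq.getTableSchema("my-project:my_dataset.my_table").getFields.asScala.map(_.getName)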
/** Get table metadata. */
def getTable(tableSpec: String): Table =
getTable(BigQueryIO.parseTableSpec(tableSpec))
/** Get table metadata. */
def getTable(table: TableReference): Table = {
val p = if (table.getProjectId == null) this.projectId else table.getProjectId
bigquery.tables().get(p, table.getDatasetId, table.getTableId).execute()
}
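// Example: fetching table metadata (a sketch with a placeholder table spec):
//   val bq = BigQueryClient.defaultInstance()
//   val numRows = bq.getTable("my-project:my_dataset.my_table").getNumRows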
/**
* Make a query and save results to a destination table.
*
* A temporary table will be created if `destinationTable` is `null`, and a cached table will be
* reused instead if one already exists for the same query.
*/
def query(sqlQuery: String,
destinationTable: String = null,
flattenResults: Boolean = false): TableReference =
if (destinationTable != null) {
val tableRef = BigQueryIO.parseTableSpec(destinationTable)
val queryJob = delayedQueryJob(sqlQuery, tableRef, flattenResults)
queryJob.waitForResult()
tableRef
} else {
val queryJob = newQueryJob(sqlQuery, flattenResults)
queryJob.waitForResult()
queryJob.table
}
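// Example: running a query (a sketch; the table specs and the query are placeholders):
//   val bq = BigQueryClient.defaultInstance()
//   val tempTable = bq.query("SELECT word, word_count FROM [my-project:my_dataset.my_table]")
//   val namedTable = bq.query(
//     "SELECT word FROM [my-project:my_dataset.my_table]",
//     destinationTable = "my-project:my_dataset.my_results")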
/** Write rows to a table. */
def writeTableRows(table: TableReference, rows: List[TableRow], schema: TableSchema,
writeDisposition: WriteDisposition,
createDisposition: CreateDisposition): Unit = {
val inserter = new BigQueryTableInserter(bigquery)
inserter.getOrCreateTable(table, writeDisposition, createDisposition, schema)
inserter.insertAll(table, rows.asJava)
}
/** Write rows to a table. */
def writeTableRows(tableSpec: String, rows: List[TableRow], schema: TableSchema = null,
writeDisposition: WriteDisposition = WRITE_EMPTY,
createDisposition: CreateDisposition = CREATE_IF_NEEDED): Unit =
writeTableRows(
BigQueryIO.parseTableSpec(tableSpec), rows, schema, writeDisposition, createDisposition)
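// Example: writing rows (a sketch; the table spec, schema, and row values are placeholders):
//   val bq = BigQueryClient.defaultInstance()
//   val schema = BigQueryUtil.parseSchema(
//     """{"fields": [{"name": "word", "type": "STRING"}, {"name": "count", "type": "INTEGER"}]}""")
//   val rows = List(new TableRow().set("word", "scio").set("count", 1))
//   bq.writeTableRows("my-project:my_dataset.my_table", rows, schema)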
/** Wait for all jobs to finish. */
def waitForJobs(jobs: QueryJob*): Unit = {
val numTotal = jobs.size
var pendingJobs = jobs.filter(_.jobReference.isDefined)
while (pendingJobs.nonEmpty) {
val remainingJobs = pendingJobs.filter { j =>
val jobId = j.jobReference.get.getJobId
val poll = bigquery.jobs().get(projectId, jobId).execute()
val error = poll.getStatus.getErrorResult
if (error != null) {
throw new RuntimeException(s"Query job failed: id: $jobId, error: $error")
}
if (poll.getStatus.getState == "DONE") {
logJobStatistics(j.query, poll)
false
} else {
true
}
}
pendingJobs = remainingJobs
val numDone = numTotal - pendingJobs.size
logger.info(s"Query: $numDone out of $numTotal completed")
if (pendingJobs.nonEmpty) {
Thread.sleep(10000)
}
}
}
// =======================================================================
// Job execution
// =======================================================================
private[scio] def newQueryJob(sqlQuery: String, flattenResults: Boolean): QueryJob = {
try {
val sourceTimes = extractTables(sqlQuery).map(t => BigInt(getTable(t).getLastModifiedTime))
val temp = getCacheDestinationTable(sqlQuery).get
val time = BigInt(getTable(temp).getLastModifiedTime)
if (sourceTimes.forall(_ < time)) {
logger.info(s"Cache hit for query: `$sqlQuery`")
logger.info(s"Existing destination table: ${BigQueryIO.toTableSpec(temp)}")
new QueryJob {
override def waitForResult(): Unit = {}
override val jobReference: Option[JobReference] = None
override val query: String = sqlQuery
override val table: TableReference = temp
}
} else {
logger.info(s"Cache invalid for query: `$sqlQuery`")
logger.info(s"New destination table: ${BigQueryIO.toTableSpec(temp)}")
setCacheDestinationTable(sqlQuery, temp)
delayedQueryJob(sqlQuery, temp, flattenResults)
}
} catch {
case NonFatal(_) =>
val temp = temporaryTable(extractLocation(sqlQuery))
logger.info(s"Cache miss for query: `$sqlQuery`")
logger.info(s"New destination table: ${BigQueryIO.toTableSpec(temp)}")
setCacheDestinationTable(sqlQuery, temp)
delayedQueryJob(sqlQuery, temp, flattenResults)
}
}
private def prepareStagingDataset(location: String): Unit = {
val datasetId = STAGING_DATASET_PREFIX + location.toLowerCase
try {
bigquery.datasets().get(projectId, datasetId).execute()
logger.info(s"Staging dataset $projectId:$datasetId already exists")
} catch {
case e: GoogleJsonResponseException if new ApiErrorExtractor().itemNotFound(e) =>
logger.info(s"Creating staging dataset $projectId:$datasetId")
val dsRef = new DatasetReference().setProjectId(projectId).setDatasetId(datasetId)
val ds = new Dataset()
.setDatasetReference(dsRef)
.setDefaultTableExpirationMs(STAGING_DATASET_TABLE_EXPIRATION_MS)
.setDescription(STAGING_DATASET_DESCRIPTION)
.setLocation(location)
bigquery
.datasets()
.insert(projectId, ds)
.execute()
case NonFatal(e) => throw e
}
}
private def temporaryTable(location: String): TableReference = {
val now = Instant.now().toString(TIME_FORMATTER)
val tableId = TABLE_PREFIX + "_" + now + "_" + Random.nextInt(Int.MaxValue)
new TableReference()
.setProjectId(projectId)
.setDatasetId(STAGING_DATASET_PREFIX + location.toLowerCase)
.setTableId(tableId)
}
private def delayedQueryJob(sqlQuery: String,
destinationTable: TableReference,
flattenResults: Boolean): QueryJob = new QueryJob {
override def waitForResult(): Unit = self.waitForJobs(this)
override lazy val jobReference: Option[JobReference] = {
val location = extractLocation(sqlQuery)
prepareStagingDataset(location)
val isLegacy = isLegacySql(sqlQuery, flattenResults)
if (isLegacy) {
logger.info(s"Executing legacy query: `$sqlQuery`")
} else {
logger.info(s"Executing SQL query: `$sqlQuery`")
}
val tryRun = runQuery(sqlQuery, destinationTable, flattenResults, isLegacy, dryRun = false)
Some(tryRun.get.getJobReference)
}
override val query: String = sqlQuery
override val table: TableReference = destinationTable
}
private def logJobStatistics(sqlQuery: String, job: Job): Unit = {
val jobId = job.getJobReference.getJobId
val stats = job.getStatistics
logger.info(s"Query completed: jobId: $jobId")
logger.info(s"Query: `$sqlQuery`")
val elapsed = PERIOD_FORMATTER.print(new Period(stats.getEndTime - stats.getCreationTime))
val pending = PERIOD_FORMATTER.print(new Period(stats.getStartTime - stats.getCreationTime))
val execution = PERIOD_FORMATTER.print(new Period(stats.getEndTime - stats.getStartTime))
logger.info(s"Elapsed: $elapsed, pending: $pending, execution: $execution")
val bytes = FileUtils.byteCountToDisplaySize(stats.getQuery.getTotalBytesProcessed)
val cacheHit = stats.getQuery.getCacheHit
logger.info(s"Total bytes processed: $bytes, cache hit: $cacheHit")
}
// =======================================================================
// Query handling
// =======================================================================
private val dryRunCache: MMap[(String, Boolean, Boolean), Try[Job]] = MMap.empty
private def runQuery(sqlQuery: String,
destinationTable: TableReference,
flattenResults: Boolean,
useLegacySql: Boolean,
dryRun: Boolean): Try[Job] = {
def run = Try {
val queryConfig = new JobConfigurationQuery()
.setQuery(sqlQuery)
.setUseLegacySql(useLegacySql)
.setFlattenResults(flattenResults)
.setPriority(PRIORITY)
.setCreateDisposition("CREATE_IF_NEEDED")
.setWriteDisposition("WRITE_EMPTY")
if (!dryRun) {
queryConfig.setAllowLargeResults(true).setDestinationTable(destinationTable)
}
val jobConfig = new JobConfiguration().setQuery(queryConfig).setDryRun(dryRun)
val fullJobId = projectId + "-" + UUID.randomUUID().toString
val jobReference = new JobReference().setProjectId(projectId).setJobId(fullJobId)
val job = new Job().setConfiguration(jobConfig).setJobReference(jobReference)
bigquery.jobs().insert(projectId, job).execute()
}
if (dryRun) {
dryRunCache.getOrElseUpdate((sqlQuery, flattenResults, useLegacySql), run)
} else {
run
}
}
private def isLegacySql(sqlQuery: String, flattenResults: Boolean): Boolean = {
def isInvalidQuery(e: GoogleJsonResponseException): Boolean =
e.getDetails.getErrors.get(0).getReason == "invalidQuery"
def dryRunQuery(useLegacySql: Boolean): Try[Job] =
runQuery(sqlQuery, null, flattenResults, useLegacySql, dryRun = true)
// dry run with SQL syntax first
dryRunQuery(false) match {
case Success(_) => false
case Failure(e: GoogleJsonResponseException) if isInvalidQuery(e) =>
// dry run with legacy syntax next
dryRunQuery(true) match {
case Success(_) =>
logger.warn("Legacy syntax is deprecated, use SQL syntax instead. " +
"See https://cloud.google.com/bigquery/sql-reference/")
logger.warn(s"Legacy query: `$sqlQuery`")
true
case Failure(f) => throw f
}
case Failure(e) => throw e
}
}
/** Extract tables to be accessed by a query. */
def extractTables(sqlQuery: String): Set[TableReference] = {
val isLegacy = isLegacySql(sqlQuery, flattenResults = false)
val tryJob = runQuery(sqlQuery, null, flattenResults = false, isLegacy, dryRun = true)
tryJob.get.getStatistics.getQuery.getReferencedTables.asScala.toSet
}
/** Extract locations of tables to be accessed by a query. */
def extractLocation(sqlQuery: String): String = {
val locations = extractTables(sqlQuery)
.map(t => (t.getProjectId, t.getDatasetId))
.map { case (pId, dId) =>
val l = bigquery.datasets().get(pId, dId).execute().getLocation
if (l != null) l else "US"
}
require(locations.size == 1, "Tables in the query must be in the same location")
locations.head
}
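// Example: resolving the single location shared by all tables in a query, e.g. "US" or "EU"
// (a sketch; the table is a placeholder):
//   val bq = BigQueryClient.defaultInstance()
//   val location = bq.extractLocation("SELECT a FROM [my-project:my_dataset.my_table]")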
// =======================================================================
// Schema and query caching
// =======================================================================
private def withCacheKey(key: String)
(method: => TableSchema): TableSchema = getCacheSchema(key) match {
case Some(schema) => schema
case None =>
val schema = method
setCacheSchema(key, schema)
schema
}
private def setCacheSchema(key: String, schema: TableSchema): Unit =
Files.write(schema.toPrettyString, schemaCacheFile(key), Charsets.UTF_8)
private def getCacheSchema(key: String): Option[TableSchema] = Try {
BigQueryUtil.parseSchema(scala.io.Source.fromFile(schemaCacheFile(key)).mkString)
}.toOption
private def setCacheDestinationTable(key: String, table: TableReference): Unit =
Files.write(BigQueryIO.toTableSpec(table), tableCacheFile(key), Charsets.UTF_8)
private def getCacheDestinationTable(key: String): Option[TableReference] = Try {
BigQueryIO.parseTableSpec(scala.io.Source.fromFile(tableCacheFile(key)).mkString)
}.toOption
private def cacheFile(key: String, suffix: String): File = {
val cacheDir = BigQueryClient.cacheDirectory
val filename = Hashing.murmur3_128().hashString(key, Charsets.UTF_8).toString + suffix
val cacheFile = new File(s"$cacheDir/$filename")
Files.createParentDirs(cacheFile)
cacheFile
}
private def schemaCacheFile(key: String): File = cacheFile(key, ".schema.json")
private def tableCacheFile(key: String): File = cacheFile(key, ".table.txt")
}
// scalastyle:on number.of.methods
/** Companion object for [[BigQueryClient]]. */
object BigQueryClient {
/** System property key for billing project. */
val PROJECT_KEY: String = "bigquery.project"
/** System property key for JSON secret path. */
val SECRET_KEY: String = "bigquery.secret"
/** System property key for local schema cache directory. */
val CACHE_DIRECTORY_KEY: String = "bigquery.cache.directory"
/** Default cache directory. */
val CACHE_DIRECTORY_DEFAULT: String = sys.props("user.dir") + "/.bigquery"
/**
* System property key for timeout in milliseconds to establish a connection.
* Default is 20000 (20 seconds). 0 for an infinite timeout.
*/
val CONNECT_TIMEOUT_MS_KEY: String = "bigquery.connect_timeout"
/**
* System property key for timeout in milliseconds to read data from an established connection.
* Default is 20000 (20 seconds). 0 for an infinite timeout.
*/
val READ_TIMEOUT_MS_KEY: String = "bigquery.read_timeout"
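// Example: overriding the timeouts (a sketch; the values are illustrative and must be set
// before the client is first created):
//   sys.props("bigquery.connect_timeout") = "30000"
//   sys.props("bigquery.read_timeout") = "60000"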
private val SCOPES = List(BigqueryScopes.BIGQUERY).asJava
private var instance: BigQueryClient = null
/**
* Get the default BigQueryClient instance.
*
* Project must be set via `bigquery.project` system property.
* An optional JSON secret file can be set via `bigquery.secret`.
* For example, by adding the following code at the beginning of a job:
* {{{
* sys.props("bigquery.project") = "my-project"
* sys.props("bigquery.secret") = "/path/to/secret.json"
* }}}
*
* Or by passing them as SBT command line arguments:
* {{{
* sbt -Dbigquery.project=my-project -Dbigquery.secret=/path/to/secret.json
* }}}
*/
def defaultInstance(): BigQueryClient = {
if (instance == null) {
instance = if (sys.props(PROJECT_KEY) != null) {
BigQueryClient(sys.props(PROJECT_KEY))
} else {
val project = new DefaultProjectFactory().create(null)
if (project != null) {
BigQueryClient(project)
} else {
throw new RuntimeException(
s"Property $PROJECT_KEY not set. Use -D$PROJECT_KEY=")
}
}
}
instance
}
/** Create a new BigQueryClient instance with the given project. */
def apply(project: String): BigQueryClient = {
val secret = sys.props(SECRET_KEY)
if (secret == null) {
new BigQueryClient(project)
} else {
BigQueryClient(project, new File(secret))
}
}
/** Create a new BigQueryClient instance with the given project and credential. */
def apply(project: String, credential: Credential): BigQueryClient =
new BigQueryClient(project, credential)
/** Create a new BigQueryClient instance with the given project and secret file. */
def apply(project: String, secretFile: File): BigQueryClient =
new BigQueryClient(project, secretFile)
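// Example: creating a client with an explicit project and JSON secret file
// (a sketch; the project id and secret path are placeholders):
//   val bq = BigQueryClient("my-project", new File("/path/to/secret.json"))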
private def cacheDirectory: String = getPropOrElse(CACHE_DIRECTORY_KEY, CACHE_DIRECTORY_DEFAULT)
private def connectTimeoutMs: Option[Int] = Option(sys.props(CONNECT_TIMEOUT_MS_KEY)).map(_.toInt)
private def readTimeoutMs: Option[Int] = Option(sys.props(READ_TIMEOUT_MS_KEY)).map(_.toInt)
private def getPropOrElse(key: String, default: String): String = {
val value = sys.props(key)
if (value == null) default else value
}
}