All downloads are free. Search and download functionality uses the official Maven repository.

com.spotify.scio.util.ScioUtil.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.spotify.scio.util

import java.net.URI
import java.util.UUID
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import com.spotify.scio.ScioContext
import com.spotify.scio.coders.{Coder, CoderMaterializer}
import com.spotify.scio.values.SCollection
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions
import org.apache.beam.sdk.extensions.gcp.util.Transport
import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy
import org.apache.beam.sdk.io.FileIO.Write.FileNaming
import org.apache.beam.sdk.io.{DefaultFilenamePolicy, FileBasedSink, FileIO, FileSystems}
import org.apache.beam.sdk.io.fs.ResourceId
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider
import org.apache.beam.sdk.util.common.ElementByteSizeObserver
import org.apache.beam.sdk.values.WindowingStrategy
import org.apache.beam.sdk.{PipelineResult, PipelineRunner}
import org.apache.commons.lang3.StringUtils
import org.slf4j.LoggerFactory

import scala.collection.compat.immutable.ArraySeq
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

/** Internal helpers for URIs, runners, temporary locations, file naming and element sizing. */
private[scio] object ScioUtil {
  @transient private lazy val log = LoggerFactory.getLogger(this.getClass)

  /** JSON factory obtained from Beam's GCP `Transport` utility. */
  @transient lazy val jsonFactory = Transport.getJsonFactory

  // Shared by getTempFile and tempDirOrDefault so both report the same error text.
  private val TempLocationHint: String =
    "Specify a temporary location via --tempLocation or PipelineOptions.setTempLocation."
  private val NoTempLocationMessage: String =
    "No temporary location was specified. " + TempLocationHint

  /**
   * Checks whether a URI refers to the local file system.
   *
   * @return
   *   `true` when the URI has no scheme or its scheme is `file`. The scheme is compared ignoring
   *   case, since URI schemes are case-insensitive.
   */
  def isLocalUri(uri: URI): Boolean =
    Option(uri.getScheme).forall(_.equalsIgnoreCase("file"))

  /** Negation of [[isLocalUri]]. */
  def isRemoteUri(uri: URI): Boolean = !isLocalUri(uri)

  /**
   * Checks whether the given runner class is Beam's `DirectRunner`.
   *
   * @throws IllegalArgumentException
   *   if `runner` is `null`
   */
  def isLocalRunner(runner: Class[_ <: PipelineRunner[_ <: PipelineResult]]): Boolean = {
    require(runner != null, "Pipeline runner not set!")
    // FIXME: cover Flink, Spark, etc. in local mode
    runner.getName == "org.apache.beam.runners.direct.DirectRunner"
  }

  /** Negation of [[isLocalRunner]]. */
  def isRemoteRunner(runner: Class[_ <: PipelineRunner[_ <: PipelineResult]]): Boolean =
    !isLocalRunner(runner)

  /** Runtime class of `T`, recovered from its implicit `ClassTag`. */
  def classOf[T: ClassTag]: Class[T] =
    implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]

  /** Creates a new Jackson `ObjectMapper` with the `DefaultScalaModule` registered. */
  def getScalaJsonMapper: ObjectMapper =
    new ObjectMapper().registerModule(DefaultScalaModule)

  /**
   * GCP temporary location of the context's options, if one can be obtained.
   *
   * Both a failure while reading the option and a `null` value are reported as `None`.
   */
  private def gcpTempLocation(context: ScioContext): Option[String] =
    Try(context.optionsAs[GcpOptions].getGcpTempLocation).toOption.flatMap(Option(_))

  /**
   * Resolves a file name or path used for materializing data.
   *
   * @param context
   *   context whose options supply the temporary location
   * @param fileOrPath
   *   file name or path; when `null`, a random `scio-materialize-<uuid>` name is generated
   * @return
   *   `fileOrPath` unchanged when it starts with `/` or carries a scheme, otherwise the name
   *   appended to the temporary location (falling back to the GCP temporary location)
   * @throws IllegalArgumentException
   *   if the name is relative and no temporary location is available
   */
  def getTempFile(context: ScioContext, fileOrPath: String = null): String = {
    val fop = Option(fileOrPath).getOrElse(s"scio-materialize-${UUID.randomUUID()}")
    val uri = URI.create(fop)
    if ((isLocalUri(uri) && uri.toString.startsWith("/")) || uri.isAbsolute) {
      fop
    } else {
      val tmpDir = Option(context.options.getTempLocation).getOrElse {
        gcpTempLocation(context) match {
          case Some(location) =>
            log.warn(
              "Using GCP temporary location as a temporary location to materialize data. " +
                TempLocationHint
            )
            location
          case None =>
            // Covers both a failing lookup and a null GCP temp location.
            throw new IllegalArgumentException(NoTempLocationMessage)
        }
      }
      val separator = if (tmpDir.endsWith("/")) "" else "/"
      s"$tmpDir$separator$fop"
    }
  }

  /** Removes all trailing `/` characters from `path`. */
  private def stripPath(path: String): String = StringUtils.stripEnd(path, "/")

  /**
   * Normalizes `path` so that it ends with exactly one `/`.
   *
   * @throws IllegalArgumentException
   *   if `path` is `null`
   */
  def strippedPath(path: String): String = {
    require(path != null, "Path must not be null")
    s"${stripPath(path)}/"
  }

  /**
   * Joins `path` and `prefix` with a single `/`.
   *
   * @param prefix
   *   file name prefix; `"part"` is used when `null`
   * @throws IllegalArgumentException
   *   if `path` is `null`
   */
  def pathWithPrefix(path: String, prefix: String): String = {
    require(path != null, "Path must not be null")
    stripPath(path) + "/" + Option(prefix).getOrElse("part")
  }

  /**
   * Builds a file pattern from a path and an optional suffix.
   *
   * @param suffix
   *   when `null`, `path` is returned unchanged; otherwise `path` is treated as a folder and the
   *   pattern `<path>/\*<suffix>` is returned
   * @throws IllegalArgumentException
   *   if `path` is `null`, or if a suffix is given while `path` already contains `*`
   */
  def filePattern(path: String, suffix: String): String = {
    require(path != null, "Path must not be null")
    Option(suffix) match {
      case Some(_) if path.contains("*") =>
        // path is already a pattern
        throw new IllegalArgumentException(s"Suffix must be used with a static path but got: $path")
      case Some(s) =>
        // match all file with suffix in path (must be a folder)
        stripPath(path) + "/*" + s
      case None =>
        path
    }
  }

  /**
   * Hash code of `k` where arrays are hashed through an `ArraySeq` wrapper rather than with the
   * array's own `##`.
   */
  def consistentHashCode[K](k: K): Int = k match {
    case key: Array[_] => ArraySeq.unsafeWrapArray(key).##
    case key           => key.##
  }

  /** Resolves `directory` as a new directory resource via Beam's `FileSystems`. */
  def toResourceId(directory: String): ResourceId =
    FileSystems.matchNewResource(directory, true)

  /**
   * Builds a `DefaultFilenamePolicy` whose base file name is `path` joined with `prefix` (see
   * [[pathWithPrefix]]); the remaining arguments are passed through unchanged.
   */
  def defaultFilenamePolicy(
    path: String,
    prefix: String,
    shardTemplate: String,
    suffix: String,
    isWindowed: Boolean
  ): FilenamePolicy = {
    val prefixedPath = pathWithPrefix(path, prefix)
    val resource = FileBasedSink.convertToFileResourceIfPossible(prefixedPath)
    val baseFileName = StaticValueProvider.of(resource)
    DefaultFilenamePolicy.fromStandardParameters(baseFileName, shardTemplate, suffix, isWindowed)
  }

  /**
   * Builds `FileIO`'s default naming for `destination` joined with `prefix` (see
   * [[pathWithPrefix]]) and the given `suffix`.
   */
  def defaultNaming(
    prefix: String,
    suffix: String
  )(destination: String): FileNaming = {
    val prefixedPath = pathWithPrefix(destination, prefix)
    FileIO.Write.defaultNaming(prefixedPath, suffix)
  }

  /**
   * Resolves the temporary directory to use, trying in order: `tempDirectory`, the context's
   * temporary location, and the context's GCP temporary location.
   *
   * @throws IllegalArgumentException
   *   if none of the candidates is available
   */
  def tempDirOrDefault(tempDirectory: String, sc: ScioContext): ResourceId =
    Option(tempDirectory)
      .orElse(Option(sc.options.getTempLocation))
      .orElse(gcpTempLocation(sc))
      .map(toResourceId)
      .getOrElse(throw new IllegalArgumentException(NoTempLocationMessage))

  /** `true` when the collection's windowing strategy differs from the global default. */
  def isWindowed(coll: SCollection[_]): Boolean =
    coll.internal.getWindowingStrategy != WindowingStrategy.globalDefault()

  /** Observer that sums every element size reported to it. */
  private class ByteSizeObserver extends ElementByteSizeObserver {
    private var total: Long = 0L
    override def reportElementSize(elementByteSize: Long): Unit =
      total += elementByteSize
    def getElementByteSize: Long = total
  }

  /**
   * Returns a function that measures the byte size of an element using the Beam coder for `T`.
   *
   * The coder is materialized once; each invocation of the returned function uses a fresh
   * observer.
   */
  def elementByteSize[T: Coder](context: ScioContext): T => Long = {
    val bCoder = CoderMaterializer.beam(context, Coder[T])

    { (e: T) =>
      val observer = new ByteSizeObserver()
      bCoder.registerByteSizeObserver(e, observer)
      observer.advance()
      observer.getElementByteSize
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy