All downloads are free. Search and download functionality uses the official Maven repository.

com.spotify.scio.extra.sparkey.SparkeyIO.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2023 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.spotify.scio.extra.sparkey

import java.lang.Math.floorMod
import java.util.UUID
import com.spotify.scio.extra.sparkey.instances._
import com.spotify.scio.io.{TapOf, TapT, TestIO}
import com.spotify.scio.util.RemoteFileUtil
import com.spotify.scio.values.SCollection
import com.spotify.sparkey.{CompressionType, SparkeyReader}
import org.apache.beam.sdk.io.FileSystems
import org.apache.beam.sdk.io.fs.MoveOptions.StandardMoveOptions
import org.apache.beam.sdk.io.fs.{EmptyMatchTreatment, ResourceId}
import org.slf4j.LoggerFactory

import scala.collection.mutable
import scala.jdk.CollectionConverters._

/**
 * Special version of [[com.spotify.scio.io.ScioIO]] for use with sparkey methods.
 *
 * @param path
 *   identifier for this IO; [[SparkeyIO.apply]] and [[SparkeyIO.output]] pass their `path`
 *   argument through unchanged
 * @tparam T
 *   element type; both type arguments of `tapT` are this same `T`
 */
private[sparkey] case class SparkeyTestIO[T](path: String) extends TestIO[T] {
  override val tapT: TapT.Aux[T, T] = TapOf[T]
}

/**
 * Factories for sparkey test IOs plus the write path that materialises an
 * `SCollection[(K, V)]` as one or more sparkey shards.
 */
object SparkeyIO {
  @transient private lazy val logger = LoggerFactory.getLogger(this.getClass)

  val DefaultNumShards: Short = 1
  val DefaultSideInputNumShards: Short = 64
  val DefaultCompressionType: CompressionType = CompressionType.NONE
  val DefaultCompressionBlockSize: Int = 0

  /** Test IO identified by `path` whose element type is [[SparkeyReader]]. */
  def apply(path: String): SparkeyTestIO[SparkeyReader] = SparkeyTestIO[SparkeyReader](path)

  /** Test IO identified by `path` whose element type is a key/value pair. */
  def output[K, V](path: String): SparkeyTestIO[(K, V)] = SparkeyTestIO[(K, V)](path)

  /**
   * Writes every element of `elements` through `w` into a new [[SparkeyWriter]] for `uri`, then
   * closes the writer.
   *
   * The writer is closed only after all elements were written; if a `put` throws, the exception
   * propagates and `close()` is not called. NOTE(review): presumably `close()` finalises the
   * output, so it is deliberately not invoked on failure -- confirm against `SparkeyWriter`.
   *
   * @return
   *   the `uri` that was passed in
   */
  private def writeToSparkey[K, V](
    uri: SparkeyUri,
    rfu: RemoteFileUtil,
    maxMemoryUsage: Long,
    compressionType: CompressionType,
    compressionBlockSize: Int,
    elements: Iterable[(K, V)],
    w: SparkeyWritable[K, V]
  ): SparkeyUri = {
    val writer =
      new SparkeyWriter(uri, rfu, compressionType, compressionBlockSize, maxMemoryUsage)
    val it = elements.iterator
    while (it.hasNext) {
      val kv = it.next()
      w.put(writer, kv._1, kv._2)
    }
    writer.close()
    uri
  }

  /**
   * Shards `data` by key hash, writes each shard to a temporary sparkey location and finally
   * renames the temporary files to their destination under `baseUri`.
   *
   * @param baseUri
   *   The final destination for sparkey files; must not be a sharded (`*`) uri and must not
   *   already exist
   * @param writable
   *   provides the shard hash for keys and writes key/value pairs to the writer
   * @param data
   *   key/value pairs to write
   * @param maxMemoryUsage
   *   passed through to each [[SparkeyWriter]]
   * @param numShards
   *   number of shards to write, must be greater than 0; every shard is written even when no
   *   key hashes to it
   * @param compressionType
   *   passed through to each [[SparkeyWriter]]
   * @param compressionBlockSize
   *   must be greater than 0 unless `compressionType` is `NONE`
   * @return
   *   a collection holding `baseUri` when `numShards` is 1, otherwise the `baseUri.path/\*` uri
   */
  private[sparkey] def writeSparkey[K, V](
    baseUri: SparkeyUri,
    writable: SparkeyWritable[K, V],
    data: SCollection[(K, V)],
    maxMemoryUsage: Long,
    numShards: Short,
    compressionType: CompressionType,
    compressionBlockSize: Int
  ): SCollection[SparkeyUri] = {
    require(
      !baseUri.isSharded,
      "path to which sparkey will be saved must not include a `*` wildcard."
    )
    require(numShards > 0, s"numShards must be greater than 0, found $numShards")
    if (compressionType != CompressionType.NONE) {
      require(
        compressionBlockSize > 0,
        s"Compression block size must be > 0 for $compressionType"
      )
    }
    val sc = data.context
    val isUnsharded = numShards == 1
    val rfu = RemoteFileUtil.create(sc.options)
    val tempLocation = sc.options.getTempLocation

    // verify that we're not writing to a previously-used output dir
    List(baseUri, SparkeyUri(s"${baseUri.path}/*")).foreach { uri =>
      require(!uri.exists(rfu), s"Sparkey URI ${uri.path} already exists")
    }
    // root destination to which all _interim_ results are written,
    // deleted upon successful completion of the write
    val tempPath = s"$tempLocation/sparkey-temp-${UUID.randomUUID}"

    val outputUri = if (isUnsharded) baseUri else SparkeyUri(s"${baseUri.path}/*")
    logger.info(s"Saving as Sparkey with $numShards shards: ${baseUri.path}")

    // resolves a file pattern to the resources currently matching it; an empty match is allowed
    def resourcesForPattern(pattern: String): mutable.Buffer[ResourceId] =
      FileSystems
        .`match`(pattern, EmptyMatchTreatment.ALLOW)
        .metadata()
        .asScala
        .map(_.resourceId())

    data.transform { collection =>
      // shard by key hash
      val shards = collection
        .groupBy { case (k, _) => floorMod(writable.shardHash(k), numShards.toInt).toShort }

      // all shards, joined below so that shards without any keys are still written
      val allShards = sc
        .parallelize(0 until numShards.toInt)
        .map(_.toShort -> ())

      // write files to temporary locations
      val tempShardUris = shards
        .hashFullOuterJoin(allShards)
        .map { case (shard, (xs, _)) =>
          // use a temp uri so that if a bundle fails retries will not fail
          val tempUri = SparkeyUri(s"$tempPath/${UUID.randomUUID}")
          // perform the write to the temp uri
          shard -> writeToSparkey(
            tempUri.sparkeyUriForShard(shard, numShards),
            rfu,
            maxMemoryUsage,
            compressionType,
            compressionBlockSize,
            xs.getOrElse(Iterable.empty),
            writable
          )
        }

      // TODO WriteFiles inserts a reshuffle here for unclear reasons

      tempShardUris.reifyAsListInGlobalWindow
        .map { seq =>
          val items = seq.toList

          // pair up the source and destination resources of every shard;
          // assumes paths always returns things in the same order
          val moves = items.map { case (shard, uri) =>
            if (isUnsharded && shard != 0)
              throw new IllegalArgumentException(s"numShards=1 but got shard=$shard")
            val dstUri =
              if (isUnsharded) baseUri else baseUri.sparkeyUriForShard(shard, numShards)
            (uri.paths, dstUri.paths)
          }
          // flatten in shard order so both lists stay index-aligned (linear, unlike repeated `++`)
          val srcPaths: List[ResourceId] = moves.flatMap(_._1)
          val dstPaths: List[ResourceId] = moves.flatMap(_._2)

          // rename source files to dest files
          logger.info(s"Moving ${srcPaths.size} files from temp to final destination.")
          // per FileBasedSink.java#783 ignore errors as files may have previously been deleted
          FileSystems.rename(
            srcPaths.asJava,
            dstPaths.asJava,
            StandardMoveOptions.IGNORE_MISSING_FILES,
            StandardMoveOptions.SKIP_IF_DESTINATION_EXISTS
          )

          // cleanup orphan files per FileBasedSink.removeTemporaryFiles
          val orphanTempFiles = resourcesForPattern(s"$tempPath/*")
          orphanTempFiles.foreach { r =>
            logger.warn("Will also remove unknown temporary file {}.", r)
          }
          FileSystems.delete(orphanTempFiles.asJava, StandardMoveOptions.IGNORE_MISSING_FILES)
          // clean up temp dir, can fail, but failure is to be ignored per FileBasedSink
          val tempPathResource = resourcesForPattern(tempPath)
          try {
            FileSystems.delete(tempPathResource.asJava, StandardMoveOptions.IGNORE_MISSING_FILES)
          } catch {
            // best effort only: keep the cause in the log, let fatal errors and interrupts through
            case scala.util.control.NonFatal(e) =>
              logger.warn("Failed to remove temporary directory: [{}].", tempPath, e)
          }

          outputUri
        }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy