All Downloads are FREE. Search and download functionalities are using the official Maven repository.

geotrellis.spark.io.s3.S3RDDWriter.scala Maven / Gradle / Ivy

package geotrellis.spark.io.s3

import geotrellis.raster.Tile
import geotrellis.spark._
import geotrellis.spark.io._
import geotrellis.spark.io.avro._
import geotrellis.spark.io.avro.codecs.KeyValueRecordCodec
import geotrellis.spark.io.index.{ZCurveKeyIndexMethod, KeyIndexMethod, KeyIndex}
import geotrellis.spark.util.KryoWrapper

import com.amazonaws.services.s3.model.{AmazonS3Exception, PutObjectResult, ObjectMetadata, PutObjectRequest}
import com.typesafe.scalalogging.slf4j._
import org.apache.spark.rdd.RDD
import scalaz.concurrent.Task
import scalaz.stream.{Process, nondeterminism}
import spray.json._
import spray.json.DefaultJsonProtocol._

import java.io.ByteArrayInputStream
import java.util.concurrent.Executors
import scala.reflect._

/**
 * Writes the records of an RDD to S3 as Avro-encoded objects.
 *
 * Records are grouped by the S3 key computed from their key component; each group is
 * encoded into a single binary payload and uploaded with one `putObject` call.
 */
trait S3RDDWriter {

  /** Factory for the S3 client. It is invoked once per partition inside [[write]]. */
  def getS3Client: () => S3Client

  /**
   * Groups `rdd` by the S3 key produced by `keyPath`, Avro-encodes each group and
   * uploads it to `bucket`.
   *
   * @param rdd               records to write
   * @param bucket            name of the destination S3 bucket
   * @param keyPath           maps a record key to the S3 object key it is stored under;
   *                          records that map to the same object key end up in the same object
   * @param putObjectModifier hook applied to every request before it is sent;
   *                          defaults to the identity function
   * @tparam K                key type, requires an Avro codec
   * @tparam V                value type, requires an Avro codec
   */
  def write[K: AvroRecordCodec: ClassTag, V: AvroRecordCodec: ClassTag](
    rdd: RDD[(K, V)],
    bucket: String,
    keyPath: K => String,
    putObjectModifier: PutObjectRequest => PutObjectRequest = { p => p }
  ): Unit = {
    val codec = KeyValueRecordCodec[K, V]

    implicit val sc = rdd.sparkContext

    // Copy the member into a local so the partition closure below refers to this value
    // rather than to the member of the enclosing instance.
    val _getS3Client = getS3Client

    val pathsToTiles =
      // Call groupBy with numPartitions; if called without that argument or a partitioner,
      // groupBy will reuse the partitioner on the parent RDD if it is set, which could be typed
      // on a key type that may no longer be valid for the key type of the resulting RDD.
      rdd.groupBy({ row => keyPath(row._1) }, numPartitions = rdd.partitions.length)

    pathsToTiles.foreachPartition { partition =>
      import geotrellis.spark.util.TaskUtils._
      val s3client: S3Client = _getS3Client()

      // Lazily turn every (object key, records) group of this partition into a PUT request.
      val requests: Process[Task, PutObjectRequest] =
        Process.unfold(partition) { iter =>
          if (iter.hasNext) {
            val (key, records) = iter.next()
            val bytes = AvroEncoder.toBinary(records.toVector)(codec)
            val metadata = new ObjectMetadata()
            metadata.setContentLength(bytes.length)
            val is = new ByteArrayInputStream(bytes)
            val request = putObjectModifier(new PutObjectRequest(bucket, key, is, metadata))
            Some((request, iter))
          } else {
            None
          }
        }

      val pool = Executors.newFixedThreadPool(8)

      // Always release the pool, including when an upload fails: otherwise its worker
      // threads would outlive the failed task.
      try {
        val write: PutObjectRequest => Process[Task, PutObjectResult] = { request =>
          Process eval Task {
            request.getInputStream.reset() // reset in case of retransmission to avoid 400 error
            s3client.putObject(request)
          }(pool).retryEBO {
            // Only HTTP 503 responses from S3 are retried; any other failure is propagated.
            case e: AmazonS3Exception if e.getStatusCode == 503 => true
            case _ => false
          }
        }

        val results = nondeterminism.njoin(maxOpen = 8, maxQueued = 8) { requests map write }
        results.run.unsafePerformSync
      } finally {
        pool.shutdown()
      }
    }
  }
}

/** Default [[S3RDDWriter]] whose client factory returns `S3Client.default`. */
object S3RDDWriter extends S3RDDWriter {

  /** Client factory; each invocation of the returned function evaluates `S3Client.default`. */
  def getS3Client: () => S3Client = { () =>
    S3Client.default
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy