All Downloads are FREE. Search and download functionalities are using the official Maven repository.

geotrellis.spark.io.s3.S3InputFormat.scala Maven / Gradle / Ivy

package geotrellis.spark.io.s3

import com.amazonaws.services.s3.model.{ListObjectsRequest, ObjectListing}
import com.amazonaws.auth._
import com.amazonaws.regions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{InputFormat, Job, JobContext}
import com.typesafe.scalalogging.slf4j._

import scala.util.matching.Regex

/** Reads keys from s3n URL using AWS Java SDK.
  * The number of keys per InputSplits are controlled by S3 pagination.
  * If AWS credentials are not part of the URL they will be discovered using [DefaultAWSCredentialsProviderChain]:
  *   - EnvironmentVariableCredentialsProvider
  *   - SystemPropertiesCredentialsProvider
  *   - ProfileCredentialsProvider
  *   - InstanceProfileCredentialsProvider
  */
abstract class S3InputFormat[K, V] extends InputFormat[K,V] with LazyLogging {
  import S3InputFormat._

  def getS3Client(credentials: AWSCredentials): S3Client =
    new geotrellis.spark.io.s3.AmazonS3Client(credentials, S3Client.defaultConfiguration)

  override def getSplits(context: JobContext) = {
    import scala.collection.JavaConversions._

    val conf = context.getConfiguration
    val anon = conf.get(ANONYMOUS)
    val id = conf.get(AWS_ID)
    val key = conf.get(AWS_KEY)
    val bucket = conf.get(BUCKET)
    val prefix = conf.get(PREFIX)
    val region: Option[Region] = {
      val r = conf.get(REGION, null)
      if(r != null) Some(Region.getRegion(Regions.fromName(r)))
      else None
    }

    val partitionCountConf = conf.get(PARTITION_COUNT)
    val partitionSizeConf = conf.get(PARTITION_BYTES)
    require(null == partitionCountConf || null == partitionSizeConf,
      "Either PARTITION_COUNT or PARTITION_SIZE option may be set")

    val credentials =
      if (anon != null)
        new AnonymousAWSCredentials()
      else if (id != null && key != null)
        new BasicAWSCredentials(id, key)
      else
        new DefaultAWSCredentialsProviderChain().getCredentials

    val s3client: S3Client = getS3Client(credentials)
    for (r <- region) s3client.setRegion(r)

    logger.info(s"Listing Splits: bucket=$bucket prefix=$prefix")
    logger.debug(s"Authenticating with ID=${credentials.getAWSAccessKeyId}")

    val request = new ListObjectsRequest()
      .withBucketName(bucket)
      .withPrefix(prefix)

    def makeNewSplit =  {
      val split = new S3InputSplit
      split.setCredentials(credentials)
      split.bucket = bucket
      split
    }

    var splits: Vector[S3InputSplit] = Vector(makeNewSplit)

    if (null == partitionCountConf) {
      // By default attempt to make partitions the same size
      val maxSplitBytes = if (null == partitionSizeConf) S3InputFormat.DEFAULT_PARTITION_BYTES else partitionSizeConf.toLong
      logger.info(s"Building partitions, attempting to create them with size at most $maxSplitBytes bytes")
      s3client
        .listObjectsIterator(request)
        .filter(!_.getKey.endsWith("/"))
        .foreach { obj =>
          val curSplit = splits.last
          val objSize = obj.getSize
          if (curSplit.getLength == 0)
            curSplit.addKey(obj)
          else if (curSplit.size + objSize <= maxSplitBytes)
            curSplit.addKey(obj)
          else {
            val newSplit = makeNewSplit
            newSplit.addKey(obj)
            splits = splits :+ newSplit
          }
        }
    } else {
      val partitionCount = partitionCountConf.toInt
      logger.info(s"Building partitions of at most $partitionCount objects")
      val keys =
        s3client
          .listObjectsIterator(request)
          .filter(! _.getKey.endsWith("/"))
          .toVector

      val groupCount = math.max(1, keys.length / partitionCount)

      splits = keys
        .grouped(groupCount)
        .map { keys =>
          val split = makeNewSplit
          keys.foreach(split.addKey)
          split
        }.toVector
    }

    splits
  }
}

object S3InputFormat {
  final val DEFAULT_PARTITION_BYTES =  256 * 1024 * 1024
  final val ANONYMOUS = "s3.anonymous"
  final val AWS_ID = "s3.awsId"
  final val AWS_KEY = "s3.awsKey"
  final val BUCKET = "s3.bucket"
  final val PREFIX = "s3.prefix"
  final val REGION = "s3.region"
  final val PARTITION_COUNT = "s3.partitionCount"
  final val PARTITION_BYTES = "S3.partitionBytes"

  private val idRx = "[A-Z0-9]{20}"
  private val keyRx = "[a-zA-Z0-9+/]+={0,2}"
  private val slug = "[a-zA-Z0-9-]+"
  val S3UrlRx = new Regex(s"""s3n://(?:($idRx):($keyRx)@)?($slug)/{0,1}(.*)""", "aws_id", "aws_key", "bucket", "prefix")

  /** Set S3N url to use, may include AWS Id and Key */
  def setUrl(job: Job, url: String): Unit =
    setUrl(job.getConfiguration, url)

  def setUrl(conf: Configuration, url: String): Unit = {
    val S3UrlRx(id, key, bucket, prefix) = url

    if (id != null && key != null) {
      conf.set(AWS_ID, id)
      conf.set(AWS_KEY, key)
    }
    conf.set(BUCKET, bucket)
    conf.set(PREFIX, prefix)
  }

  def setBucket(job: Job, bucket: String): Unit =
    setBucket(job.getConfiguration, bucket)

  def setBucket(conf: Configuration, bucket: String): Unit =
    conf.set(BUCKET, bucket)

  def setPrefix(job: Job, prefix: String): Unit =
    setPrefix(job.getConfiguration, prefix)

  def setPrefix(conf: Configuration, prefix: String): Unit =
    conf.set(PREFIX, prefix)

  /** Set desired partition count */
  def setPartitionCount(job: Job, limit: Int): Unit =
    setPartitionCount(job.getConfiguration, limit)

  /** Set desired partition count */
  def setPartitionCount(conf: Configuration, limit: Int): Unit =
    conf.set(PARTITION_COUNT, limit.toString)

  def setRegion(job: Job, region: String): Unit =
    setRegion(job.getConfiguration, region)

  def setRegion(conf: Configuration, region: String): Unit =
    conf.set(REGION, region)

  /** Force anonymous access, bypass all key discovery */
  def setAnonymous(job: Job): Unit =
    setAnonymous(job.getConfiguration)

  /** Force anonymous access, bypass all key discovery */
  def setAnonymous(conf: Configuration): Unit =
    conf.set(ANONYMOUS, "true")

  /** Set desired partition size in bytes, at least one item per partition will be assigned */
  def setPartitionBytes(job: Job, bytes: Long): Unit =
    setPartitionBytes(job.getConfiguration, bytes)

  /** Set desired partition size in bytes, at least one item per partition will be assigned */
  def setPartitionBytes(conf: Configuration, bytes: Long): Unit =
    conf.set(PARTITION_BYTES, bytes.toString)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy