All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.pretrained.S3ResourceDownloader.scala Maven / Gradle / Ivy

package com.johnsnowlabs.nlp.pretrained

import java.io.File
import java.nio.file.Files
import java.sql.Timestamp
import java.util.Calendar
import java.util.zip.ZipInputStream

import com.amazonaws.auth.AWSCredentials
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.s3.AmazonS3Client
import com.amazonaws.services.s3.model.GetObjectRequest
import com.amazonaws.{AmazonServiceException, ClientConfiguration}
import com.johnsnowlabs.util.ConfigHelper
import org.apache.hadoop.fs.Path

import scala.collection.mutable


/**
  * [[ResourceDownloader]] that resolves resources against a `metadata.json` file stored in an
  * S3 bucket and keeps downloaded files in a cache folder.
  *
  * @param bucket      S3 bucket that holds the resources
  * @param s3Path      key prefix inside the bucket; the request folder and file name are appended to it
  * @param cacheFolder folder where downloaded (and unzipped) resources are stored; created if missing
  * @param credentials optional AWS credentials; when `None` the client is built without explicit credentials
  * @param region      name of the AWS region the client is bound to
  */
class S3ResourceDownloader(bucket: => String,
                           s3Path: => String,
                           cacheFolder: => String,
                           credentials: => Option[AWSCredentials] = None,
                           region: String = "us-east-1"
                          )
  extends ResourceDownloader with AutoCloseable {

  // repository Folder -> repository Metadata
  val repoFolder2Metadata: mutable.Map[String, RepositoryMetadata] = mutable.Map[String, RepositoryMetadata]()
  val cachePath: Path = new Path(cacheFolder)

  if (!fs.exists(cachePath)) {
    fs.mkdirs(cachePath)
  }

  /** S3 client, created on first use. The socket timeout is read from the configuration (0 when not set). */
  lazy val client: AmazonS3Client = {
    val regionObj = RegionUtils.getRegion(region)

    val config = new ClientConfiguration()
    val timeout = ConfigHelper.getConfigValue(ConfigHelper.s3SocketTimeout).map(_.toInt).getOrElse(0)
    config.setSocketTimeout(timeout)

    // `credentials` is a by-name parameter: match on it so it is evaluated exactly once.
    val s3Client = credentials match {
      case Some(awsCredentials) => new AmazonS3Client(awsCredentials, config)
      case None => new AmazonS3Client(config)
    }

    s3Client.setRegion(regionObj)
    s3Client
  }

  /**
    * Returns the resource list of a repository folder, downloading `metadata.json` from S3
    * only when there is no cached copy or the cached copy is older than five minutes.
    *
    * @param folder repository folder whose metadata is requested
    * @return resources described by the folder's metadata file
    */
  private def downloadMetadataIfNeed(folder: String): List[ResourceMetadata] = {
    val fiveMinsBefore = getTimestamp(-5)

    repoFolder2Metadata.get(folder) match {
      case Some(cached) if !cached.lastMetadataDownloaded.before(fiveMinsBefore) =>
        cached.metadata

      case _ =>
        val metaFile = getS3File(s3Path, folder, "metadata.json")
        val obj = client.getObject(bucket, metaFile)
        val content = obj.getObjectContent
        try {
          val metadata = ResourceMetadata.readResources(content)
          val version = obj.getObjectMetadata.getVersionId

          // Remember what was downloaded; without this every call would hit S3 again.
          repoFolder2Metadata(folder) = RepositoryMetadata(metaFile, folder, version, getTimestamp(), metadata)

          metadata
        } finally {
          // Release the connection even when reading fails.
          // NOTE(review): readResources may already close the stream; closing again is expected to be harmless.
          content.close()
        }
    }
  }

  /**
    * Finds the resource matching the request in the repository metadata.
    *
    * @param request Resource request
    * @return matching resource description, or None if nothing matches
    */
  def resolveLink(request: ResourceRequest): Option[ResourceMetadata] = {
    val metadata = downloadMetadataIfNeed(request.folder)
    ResourceMetadata.resolveResource(metadata, request)
  }

  /**
    * Download resource to local file
    *
    * NOTE(review): for zipped resources the result is the full path of the unzipped folder,
    * for plain resources it is only the file name. Kept as is because callers may rely on it.
    *
    * @param request        Resource request
    * @return               Downloaded file or None if resource is not found
    */
  override def download(request: ResourceRequest): Option[String] = {
    resolveLink(request).flatMap { resource =>
      val s3FilePath = getS3File(s3Path, request.folder, resource.fileName)
      val dstFile = new Path(cachePath.toString, resource.fileName)

      if (!client.doesObjectExist(bucket, s3FilePath)) {
        None
      } else {
        if (!fs.exists(dstFile)) {
          downloadToCache(s3FilePath, resource.fileName, dstFile)
        }

        if (resource.isZipped) Some(unzip(dstFile))
        else Some(dstFile.getName)
      }
    }
  }

  /** Downloads an S3 object into a local temp file and moves it to `dstFile`. */
  private def downloadToCache(s3FilePath: String, fileName: String, dstFile: Path): Unit = {
    // 1. Create tmp file
    val tmpFile = Files.createTempFile(fileName, "").toFile
    try {
      // 2. Download content to tmp file
      client.getObject(new GetObjectRequest(bucket, s3FilePath), tmpFile)

      // 3. Move tmp file to destination
      fs.moveFromLocalFile(new Path(tmpFile.toString), dstFile)
    } finally {
      // After a successful move the tmp file is gone; this only cleans up after a failure.
      if (tmpFile.exists()) tmpFile.delete()
    }
  }

  /** Returns the path of a `.zip` file without its extension; fails if the file is not a `.zip`. */
  private def unzippedPath(zipFile: Path): String = {
    val zipPath = zipFile.toString
    require(zipPath.endsWith(".zip"), s"Expected a .zip file but got: $zipPath")
    zipPath.substring(0, zipPath.length - 4)
  }

  /**
    * Extracts every file entry of the archive into a folder named after the archive
    * (same path without the `.zip` extension).
    *
    * @param zipFile archive located on `fs`
    * @return path of the folder the archive was extracted to
    */
  private def unzip(zipFile: Path): String = {
    // Validate the name before opening any stream so a failed check cannot leak one.
    val targetFolder = unzippedPath(zipFile)
    val buf = Array.ofDim[Byte](1024)

    val zis = new ZipInputStream(fs.open(zipFile))
    try {
      Iterator.continually(zis.getNextEntry).takeWhile(_ != null).foreach { entry =>
        if (!entry.isDirectory) {
          // NOTE(review): entry names are used as-is; an archive with ".." in its entry names
          // could write outside targetFolder. Consider validating if archives can be untrusted.
          val outputStream = fs.create(new Path(targetFolder, entry.getName))
          try {
            Iterator.continually(zis.read(buf, 0, buf.length))
              .takeWhile(_ > -1)
              .foreach(bytesRead => outputStream.write(buf, 0, bytesRead))
          } finally {
            outputStream.close()
          }
        }
        zis.closeEntry()
      }
    } finally {
      zis.close()
    }

    targetFolder
  }

  /** Shuts down the S3 client. */
  override def close(): Unit = {
    client.shutdown()
  }

  /**
    * Removes the cached file of the requested resource and, for zipped resources,
    * the folder it was unzipped to.
    *
    * @param request Resource request
    */
  override def clearCache(request: ResourceRequest): Unit = {
    resolveLink(request).foreach { resource =>
      val cachedFile = new Path(cachePath.toString, resource.fileName)
      if (fs.exists(cachedFile))
        fs.delete(cachedFile, true)

      if (resource.isZipped) {
        val unzippedFile = new Path(unzippedPath(cachedFile))
        if (fs.exists(unzippedFile))
          fs.delete(unzippedFile, true)
      }
    }
  }

  /** Current time shifted by `addMinutes` minutes (negative values point to the past). */
  private def getTimestamp(addMinutes: Int = 0): Timestamp = {
    val cal = Calendar.getInstance()
    cal.add(Calendar.MINUTE, addMinutes)
    new Timestamp(cal.getTimeInMillis)
  }

  /** Joins the non-empty parts into an S3 key, dropping a trailing "/" from each part. */
  private def getS3File(parts: String*): String = {
    parts
      .map(_.stripSuffix("/"))
      .filter(_.nonEmpty)
      .mkString("/")
  }

  implicit class S3ClientWrapper(client: AmazonS3Client) {

    /**
      * Checks whether an object exists by requesting its metadata.
      *
      * @return true if the metadata request succeeds, false if the service answers 404;
      *         any other service error is propagated
      */
    def doesObjectExist(bucket: String, key: String): Boolean = {
      try {
        client.getObjectMetadata(bucket, key)
        true
      } catch {
        case e: AmazonServiceException if e.getStatusCode == 404 => false
      }
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy