com.johnsnowlabs.storage.HasStorage.scala Maven / Gradle / Ivy
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.storage
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.johnsnowlabs.nlp.HasCaseSensitiveProperties
import com.johnsnowlabs.nlp.annotators.param.ExternalResourceParam
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs}
import com.johnsnowlabs.storage.Database.Name
import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader, FileHelper}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Dataset, SparkSession}
import java.nio.file.{Files, Paths, StandardCopyOption}
import java.util.UUID
trait HasStorage extends HasStorageRef with HasStorageOptions with HasCaseSensitiveProperties {
protected val databases: Array[Database.Name]
/** Path to the external resource.
* @group param
*/
val storagePath = new ExternalResourceParam(this, "storagePath", "path to file")
/** @group setParam */
def setStoragePath(path: String, readAs: String): this.type =
set(storagePath, new ExternalResource(path, readAs, Map.empty[String, String]))
/** @group setParam */
def setStoragePath(path: String, readAs: ReadAs.Value): this.type =
setStoragePath(path, readAs.toString)
/** @group getParam */
def getStoragePath: Option[ExternalResource] = get(storagePath)
protected val missingRefMsg: String = s"Please set storageRef param in $this."
protected def index(
fitDataset: Dataset[_],
storageSourcePath: Option[String],
readAs: Option[ReadAs.Value],
writers: Map[Database.Name, StorageWriter[_]],
readOptions: Option[Map[String, String]] = None): Unit
protected def createWriter(database: Name, connection: RocksDBConnection): StorageWriter[_]
private def indexDatabases(
databases: Array[Database.Name],
resource: Option[ExternalResource],
localFiles: Array[String],
fitDataset: Dataset[_],
spark: SparkContext): Unit = {
require(
databases.length == localFiles.length,
"Storage temp locations must be equal to the amount of databases")
lazy val connections = databases
.zip(localFiles)
.map { case (database, localFile) => (database, RocksDBConnection.getOrCreate(localFile)) }
val writers = connections
.map { case (db, conn) =>
(db, createWriter(db, conn))
}
.toMap[Database.Name, StorageWriter[_]]
val storageSourcePath = resource.map(r => importIfS3(r.path, spark).toUri.toString)
if (resource.isDefined && new Path(resource.get.path)
.getFileSystem(spark.hadoopConfiguration)
.getScheme != "file") {
val uri = new java.net.URI(storageSourcePath.get.replaceAllLiterally("\\", "/"))
val fs = FileSystem.get(uri, spark.hadoopConfiguration)
/** ToDo: What if the file is too large to copy to local? Index directly from hadoop? */
val tmpFile = Files.createTempFile("sparknlp_", ".str").toAbsolutePath.toString
fs.copyToLocalFile(new Path(storageSourcePath.get), new Path(tmpFile))
index(fitDataset, Some(tmpFile), resource.map(_.readAs), writers, resource.map(_.options))
FileHelper.delete(tmpFile)
} else {
index(
fitDataset,
storageSourcePath,
resource.map(_.readAs),
writers,
resource.map(_.options))
}
writers.values.foreach(_.close())
connections.map(_._2).foreach(_.close())
}
private def preload(
fitDataset: Dataset[_],
resource: Option[ExternalResource],
spark: SparkSession,
databases: Array[Database.Name]): Unit = {
val sparkContext = spark.sparkContext
val tmpLocalDestinations = {
databases.map(_ =>
Files
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_idx")
.toAbsolutePath
.toString)
}
indexDatabases(databases, resource, tmpLocalDestinations, fitDataset, sparkContext)
val locators =
databases.map(database => StorageLocator(database.toString, $(storageRef), spark))
tmpLocalDestinations.zip(locators).foreach { case (tmpLocalDestination, locator) =>
/** tmpFiles indexed must be explicitly set to be local files */
val uri =
"file://" + new java.net.URI(tmpLocalDestination.replaceAllLiterally("\\", "/")).getPath
StorageHelper.sendToCluster(
new Path(uri),
locator.clusterFilePath,
locator.clusterFileName,
locator.destinationScheme,
sparkContext)
}
// 3. Create Spark Embeddings
locators.foreach(locator => RocksDBConnection.getOrCreate(locator.clusterFileName))
}
private def importIfS3(path: String, spark: SparkContext): Path = {
val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
var src = new Path(path)
// if the path contains s3a download to local cache if not present
if (uri.getScheme != null) {
if (uri.getScheme.equals("s3a")) {
var accessKeyId = ConfigLoader.getConfigStringValue(ConfigHelper.accessKeyId)
var secretAccessKey = ConfigLoader.getConfigStringValue(ConfigHelper.secretAccessKey)
if (accessKeyId == "" || secretAccessKey == "") {
val defaultCredentials = new DefaultAWSCredentialsProviderChain().getCredentials
accessKeyId = defaultCredentials.getAWSAccessKeyId
secretAccessKey = defaultCredentials.getAWSSecretKey
}
var old_key = ""
var old_secret = ""
if (spark.hadoopConfiguration.get("fs.s3a.access.key") != null) {
old_key = spark.hadoopConfiguration.get("fs.s3a.access.key")
old_secret = spark.hadoopConfiguration.get("fs.s3a.secret.key")
}
try {
val dst = new Path(ResourceDownloader.cacheFolder, src.getName)
if (!Files.exists(Paths.get(dst.toUri.getPath))) {
// download s3 resource locally using config keys
spark.hadoopConfiguration.set("fs.s3a.access.key", accessKeyId)
spark.hadoopConfiguration.set("fs.s3a.secret.key", secretAccessKey)
val s3fs = FileSystem.get(uri, spark.hadoopConfiguration)
val dst_tmp = new Path(ResourceDownloader.cacheFolder, src.getName + "_tmp")
s3fs.copyToLocalFile(src, dst_tmp)
// rename to original file
Files.move(
Paths.get(dst_tmp.toUri.getRawPath),
Paths.get(dst.toUri.getRawPath),
StandardCopyOption.REPLACE_EXISTING)
}
src = new Path(dst.toUri.getPath)
} finally {
// reset the keys
if (!old_key.equals("")) {
spark.hadoopConfiguration.set("fs.s3a.access.key", old_key)
spark.hadoopConfiguration.set("fs.s3a.secret.key", old_secret)
}
}
}
}
src
}
private var preloaded = false
def indexStorage(fitDataset: Dataset[_], resource: Option[ExternalResource]): Unit = {
if (!preloaded) {
preloaded = true
require(isDefined(storageRef), missingRefMsg)
preload(fitDataset, resource, fitDataset.sparkSession, databases)
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy