All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.nlp.pretrained.ResourceDownloader.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2017-2023 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.pretrained

import com.johnsnowlabs.nlp.annotators._
import com.johnsnowlabs.nlp.annotators.audio.{HubertForCTC, Wav2Vec2ForCTC, WhisperForCTC}
import com.johnsnowlabs.nlp.annotators.classifier.dl._
import com.johnsnowlabs.nlp.annotators.coref.SpanBertCorefModel
import com.johnsnowlabs.nlp.annotators.cv._
import com.johnsnowlabs.nlp.annotators.er.EntityRulerModel
import com.johnsnowlabs.nlp.annotators.ld.dl.LanguageDetectorDL
import com.johnsnowlabs.nlp.annotators.ner.crf.NerCrfModel
import com.johnsnowlabs.nlp.annotators.ner.dl.{NerDLModel, ZeroShotNerModel}
import com.johnsnowlabs.nlp.annotators.parser.dep.DependencyParserModel
import com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParserModel
import com.johnsnowlabs.nlp.annotators.pos.perceptron.PerceptronModel
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
import com.johnsnowlabs.nlp.annotators.sda.pragmatic.SentimentDetectorModel
import com.johnsnowlabs.nlp.annotators.sda.vivekn.ViveknSentimentModel
import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel
import com.johnsnowlabs.nlp.annotators.seq2seq._
import com.johnsnowlabs.nlp.annotators.spell.context.ContextSpellCheckerModel
import com.johnsnowlabs.nlp.annotators.spell.norvig.NorvigSweetingModel
import com.johnsnowlabs.nlp.annotators.spell.symmetric.SymmetricDeleteModel
import com.johnsnowlabs.nlp.annotators.ws.WordSegmenterModel
import com.johnsnowlabs.nlp.embeddings._
import com.johnsnowlabs.nlp.pretrained.ResourceType.ResourceType
import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
import com.johnsnowlabs.nlp.{DocumentAssembler, PromptAssembler, TableAssembler, pretrained}
import com.johnsnowlabs.util._
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.ml.util.DefaultParamsReadable
import org.apache.spark.ml.{PipelineModel, PipelineStage}
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future}
import scala.util.{Failure, Success, Try}

/** Backend able to fetch pretrained resources and their metadata.
  *
  * The companion object [[ResourceDownloader]] holds one instance per storage location (private,
  * public and community) and routes requests to them by folder.
  */
trait ResourceDownloader {

  /** Download resource to local file
    *
    * @param request
    *   Resource request
    * @return
    *   downloaded file or None if resource is not found
    */
  def download(request: ResourceRequest): Option[String]

  /** Size of the resource that `request` points to.
    *
    * NOTE(review): presumably a byte count; the companion object formats it with
    * `FileHelper.getHumanReadableFileSize` and treats `None` as "resource not found".
    *
    * @param request
    *   Resource request
    * @return
    *   the download size, or None if it cannot be determined
    */
  def getDownloadSize(request: ResourceRequest): Option[Long]

  /** Clears whatever this downloader keeps cached for `request` (implementation specific).
    *
    * @param request
    *   Resource request
    */
  def clearCache(request: ResourceRequest): Unit

  /** Returns the metadata of all resources in `folder`, downloading it first when needed.
    *
    * NOTE(review): when metadata is considered stale is up to the implementation — confirm there.
    *
    * @param folder
    *   Remote folder whose metadata is requested
    * @return
    *   the metadata entries of the folder
    */
  def downloadMetadataIfNeed(folder: String): List[ResourceMetadata]

  /** Downloads a single file, optionally unzipping it.
    *
    * @param s3FilePath
    *   key in the S3 bucket or an s3 URI
    * @param unzip
    *   whether to unzip the downloaded file (default true)
    * @return
    *   presumably the local path of the result, or None on failure — TODO confirm in implementations
    */
  def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean = true): Option[String]

  // Every downloader shares the file system created once by the companion object.
  val fileSystem: FileSystem = ResourceDownloader.fileSystem

}

object ResourceDownloader {

  private val logger: Logger = LoggerFactory.getLogger(this.getClass.toString)

  /** File system from [[OutputHelper.getFileSystem]], shared by every downloader. */
  val fileSystem: FileSystem = OutputHelper.getFileSystem

  /** S3 bucket of the official resources, looked up via [[ConfigLoader]] on every call. */
  def s3Bucket: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedS3BucketKey)

  /** S3 bucket of the community resources, looked up via [[ConfigLoader]] on every call. */
  def s3BucketCommunity: String =
    ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedCommunityS3BucketKey)

  /** Path inside the S3 bucket, looked up via [[ConfigLoader]] on every call. */
  def s3Path: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedS3PathKey)

  /** Local folder for downloaded resources, looked up via [[ConfigLoader]] on every call. */
  def cacheFolder: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedCacheFolder)

  /** Folder of the public resources; requests for it are served by [[publicDownloader]]. */
  val publicLoc = "public/models"

  /** Models and pipelines already loaded in this JVM, keyed by the request that produced them. */
  private val cache: mutable.Map[ResourceRequest, PipelineStage] =
    mutable.Map[ResourceRequest, PipelineStage]()

  /** Version of the running Spark session. */
  lazy val sparkVersion: Version = {
    val spark_version = ResourceHelper.spark.version
    Version.parse(spark_version)
  }

  /** Version of this library, as reported by [[Build.version]]. */
  lazy val libVersion: Version = {
    Version.parse(Build.version)
  }

  // The downloaders are vars so they can be replaced at runtime (see resetResourceDownloader).
  var privateDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "private")
  var publicDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "public")
  var communityDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3BucketCommunity, s3Path, cacheFolder, "community")

  /** Selects the downloader responsible for a folder.
    *
    * [[publicLoc]] maps to the public downloader, folders starting with "@" to the community
    * downloader, and everything else to the private downloader.
    *
    * @param folder
    *   Remote folder of the request
    * @return
    *   the downloader serving that folder
    */
  def getResourceDownloader(folder: String): ResourceDownloader = {
    folder match {
      case this.publicLoc => publicDownloader
      case loc if loc.startsWith("@") => communityDownloader
      case _ => privateDownloader
    }
  }

  /** Reset the cache and recreate ResourceDownloader S3 credentials.
    *
    * Empties the in-memory model/pipeline cache and rebuilds the private downloader from the
    * current configuration values.
    */
  def resetResourceDownloader(): Unit = {
    // `cache.empty` only built a new empty map and discarded it; `clear()` really empties it.
    cache.clear()
    this.privateDownloader = new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "private")
  }

  /** List all pretrained models in public name_lang */
  def listPublicModels(): List[String] = {
    listPretrainedResources(folder = publicLoc, ResourceType.MODEL)
  }

  /** Prints all pretrained models for a particular annotator model, that are compatible with a
    * version of Spark NLP. If any of the optional arguments are not set, the filter is not
    * considered.
    *
    * @param annotator
    *   Name of the model class, for example "NerDLModel"
    * @param lang
    *   Language of the pretrained models to display, for example "en"
    * @param version
    *   Version of Spark NLP that the model should be compatible with, for example "3.2.3"
    */
  def showPublicModels(
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version)): Unit = {
    println(
      publicResourceString(
        annotator = annotator,
        lang = lang,
        version = version,
        resourceType = ResourceType.MODEL))
  }

  /** Prints all pretrained models for a particular annotator model, that are compatible with this
    * version of Spark NLP.
    *
    * @param annotator
    *   Name of the annotator class
    */
  def showPublicModels(annotator: String): Unit = showPublicModels(Some(annotator))

  /** Prints all pretrained models for a particular annotator model, that are compatible with this
    * version of Spark NLP.
    *
    * @param annotator
    *   Name of the annotator class
    * @param lang
    *   Language of the pretrained models to display
    */
  def showPublicModels(annotator: String, lang: String): Unit =
    showPublicModels(Some(annotator), Some(lang))

  /** Prints all pretrained models for a particular annotator, that are compatible with a version
    * of Spark NLP.
    *
    * @param annotator
    *   Name of the model class, for example "NerDLModel"
    * @param lang
    *   Language of the pretrained models to display, for example "en"
    * @param version
    *   Version of Spark NLP that the model should be compatible with, for example "3.2.3"
    */
  def showPublicModels(annotator: String, lang: String, version: String): Unit =
    showPublicModels(Some(annotator), Some(lang), Some(version))

  /** List all pretrained pipelines in public */
  def listPublicPipelines(): List[String] = {
    listPretrainedResources(folder = publicLoc, ResourceType.PIPELINE)
  }

  /** Prints all Pipelines available for a language and a version of Spark NLP. By default shows
    * all languages and uses the current version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    * @param version
    *   Version of Spark NLP
    */
  def showPublicPipelines(
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version)): Unit = {
    println(
      publicResourceString(
        annotator = None,
        lang = lang,
        version = version,
        resourceType = ResourceType.PIPELINE))
  }

  /** Prints all Pipelines available for a language and this version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    */
  def showPublicPipelines(lang: String): Unit = showPublicPipelines(Some(lang))

  /** Prints all Pipelines available for a language and a version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    * @param version
    *   Version of Spark NLP
    */
  def showPublicPipelines(lang: String, version: String): Unit =
    showPublicPipelines(Some(lang), Some(version))

  /** Returns models or pipelines in metadata json which has not been categorized yet.
    *
    * @return
    *   list of models or pipelines which are not categorized in metadata json
    */
  def listUnCategorizedResources(): List[String] = {
    listPretrainedResources(folder = publicLoc, ResourceType.NOT_DEFINED)
  }

  /** Prints the uncategorized public resources for a language, for any library version. */
  def showUnCategorizedResources(lang: String): Unit = {
    println(publicResourceString(None, Some(lang), None, resourceType = ResourceType.NOT_DEFINED))
  }

  /** Prints the uncategorized public resources for a language and a library version. */
  def showUnCategorizedResources(lang: String, version: String): Unit = {
    println(
      publicResourceString(
        None,
        Some(lang),
        Some(version),
        resourceType = ResourceType.NOT_DEFINED))

  }

  /** Renders `name:lang:version` entries as an ASCII table.
    *
    * Column widths grow to fit the longest value, so names, language codes of any length and
    * versions stay aligned.
    *
    * @param list
    *   entries in the `name:lang:version` format produced by [[listPretrainedResources]]
    * @param resourceType
    *   selects the title of the first column ("Pipeline", "Model" or "Pipeline/Model")
    * @return
    *   the formatted table, each line terminated by a newline
    */
  def showString(list: List[String], resourceType: ResourceType): String = {
    val entries = list.map(splitResourceEntry)

    // Every column is at least as wide as its header title.
    val nameWidth = entries.foldLeft(14)((width, entry) => math.max(width, entry._1.length))
    val versionWidth = entries.foldLeft(7)((width, entry) => math.max(width, entry._3.length))
    // Language codes keep two spaces of padding on each side ("  en  "). For two-letter codes
    // this is the fixed six-character column used before; other lengths no longer misalign.
    val langWidth = entries.foldLeft(6)((width, entry) => math.max(width, entry._2.length + 4))

    val title =
      if (resourceType.equals(ResourceType.PIPELINE)) "Pipeline"
      else if (resourceType.equals(ResourceType.MODEL)) "Model"
      else "Pipeline/Model"

    def pad(text: String, width: Int): String = text + " " * (width - text.length)

    val separator =
      "+" + "-" * (nameWidth + 2) + "+" + "-" * langWidth + "+" + "-" * (versionWidth + 2) + "+\n"

    val sb = new StringBuilder
    sb.append(separator)
    sb.append(
      "| " + pad(title, nameWidth) + " | " + pad("lang", langWidth - 2) + " | " +
        pad("version", versionWidth) + " |\n")
    sb.append(separator)
    entries.foreach { case (name, lang, version) =>
      sb.append(
        "| " + pad(name, nameWidth) + " |  " + pad(lang, langWidth - 4) + "  | " +
          pad(version, versionWidth) + " |\n")
    }
    sb.append(separator)
    sb.toString()
  }

  /** Lists public resources matching the filters and renders them with [[showString]].
    *
    * @param annotator
    *   Name of the model class, or None for any
    * @param lang
    *   Language, or None for any
    * @param version
    *   Library version string the resources must be compatible with, or None for any
    * @param resourceType
    *   Type of the resources to list
    * @return
    *   the formatted table
    */
  def publicResourceString(
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version),
      resourceType: ResourceType): String = {
    showString(
      listPretrainedResources(
        folder = publicLoc,
        resourceType,
        annotator = annotator,
        lang = lang,
        version = version.map(ver => Version.parse(ver))),
      resourceType)
  }

  /** Lists pretrained resource from metadata.json, depending on the set filters. The folder in
    * the S3 location and the resourceType is necessary. The other filters are optional and will
    * be ignored if not set.
    *
    * @param folder
    *   Folder in the S3 location
    * @param resourceType
    *   Type of the Resource. Can Either `ResourceType.MODEL`, `ResourceType.PIPELINE` or
    *   `ResourceType.NOT_DEFINED`
    * @param annotator
    *   Name of the model class
    * @param lang
    *   Language of the model
    * @param version
    *   Version that the model should be compatible with
    * @return
    *   A list of the available resources, each formatted as `name:lang:version` ("-" for a
    *   missing language or version)
    */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[Version] = None): List[String] = {

    // An unset filter (None) accepts every resource.
    def matchesFilters(meta: ResourceMetadata): Boolean = {
      val isSameResourceType =
        meta.category.getOrElse(ResourceType.NOT_DEFINED).toString.equals(resourceType.toString)
      val isCompatibleWithVersion =
        version.forall(ver => Version.isCompatible(ver, meta.libVersion))
      val isSameAnnotator =
        annotator.forall(cls => meta.annotator.getOrElse("").equalsIgnoreCase(cls))
      val isSameLanguage =
        lang.forall(l => meta.language.getOrElse("").equalsIgnoreCase(l))
      isSameResourceType && isCompatibleWithVersion && isSameAnnotator && isSameLanguage
    }

    getResourceMetadata(folder)
      .filter(matchesFilters)
      .map(meta =>
        meta.name + ":" + meta.language.getOrElse("-") + ":" + meta.libVersion.getOrElse("-"))
  }

  /** [[listPretrainedResources]] filtered by language only. */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      lang: String): List[String] =
    listPretrainedResources(folder, resourceType, lang = Some(lang))

  /** [[listPretrainedResources]] filtered by compatible version only. */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      version: Version): List[String] =
    listPretrainedResources(folder, resourceType, version = Some(version))

  /** [[listPretrainedResources]] filtered by language and compatible version. */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      lang: String,
      version: Version): List[String] =
    listPretrainedResources(folder, resourceType, lang = Some(lang), version = Some(version))

  /** Sorted, de-duplicated names of the annotators that have resources in `folder`.
    *
    * @param folder
    *   Remote folder, the public folder by default
    * @return
    *   annotator class names, without empty or missing entries
    */
  def listAvailableAnnotators(folder: String = publicLoc): List[String] = {
    getResourceMetadata(folder)
      .flatMap(_.annotator)
      .filter(_.nonEmpty)
      .distinct
      .sorted
  }

  private def getResourceMetadata(location: String): List[ResourceMetadata] = {
    getResourceDownloader(location).downloadMetadataIfNeed(location)
  }

  /** Prints [[listAvailableAnnotators]], one name per line. */
  def showAvailableAnnotators(folder: String = publicLoc): Unit = {
    println(listAvailableAnnotators(folder).mkString("\n"))
  }

  /** Loads resource to path
    *
    * @param name
    *   Name of Resource
    * @param folder
    *   Subfolder in s3 where to search model (e.g. medicine)
    * @param language
    *   Desired language of Resource
    * @return
    *   path of downloaded resource
    */
  def downloadResource(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): String = {
    downloadResource(ResourceRequest(name, language, folder))
  }

  /** Loads resource to path
    *
    * Blocks until the download has finished.
    *
    * @param request
    *   Request for resource
    * @return
    *   path of downloaded resource
    * @throws IllegalArgumentException
    *   if the resource cannot be found or the download fails
    */
  def downloadResource(request: ResourceRequest): String = {
    val downloader = getResourceDownloader(request.folder)

    // Check existence before starting the download, so a missing resource fails fast instead of
    // leaving a download running in the background.
    val fileSize = getDownloadSize(request)
    require(
      !fileSize.equals("-1"),
      s"Can not find ${request.name} inside ${request.folder} to download. Please make sure the name and location are correct!")

    val future = Future {
      downloader.download(stripCommunityPrefix(request))
    }
    println(request.name + " download started this may take some time.")
    println("Approximate size to download " + fileSize)

    // Wait for the result once. The previous polling loop re-registered a callback every second
    // and shared unsynchronized vars with the callback thread; it could even observe the
    // "finished" flag before the path had been written.
    val path: Option[String] = Try(Await.result(future, Duration.Inf)) match {
      case Success(value) => value
      case Failure(exception) =>
        println(s"Error: ${exception.getMessage}")
        logger.error(exception.getMessage)
        None
    }

    require(
      path.isDefined,
      s"Was not found appropriate resource to download for request: $request with downloader: $downloader")
    println("Download done! Loading the resource.")
    path.get
  }

  /** Downloads a model from the default S3 bucket to the cache pretrained folder.
    * @param model
    *   the name of the key in the S3 bucket or s3 URI
    * @param folder
    *   the folder of the model
    * @param unzip
    *   used to unzip the model, by default true
    */
  def downloadModel Directly(
      model: String,
      folder: String = publicLoc,
      unzip: Boolean = true): Unit = {
    getResourceDownloader(folder).downloadAndUnzipFile(model, unzip)
  }

  /** Downloads (or takes from the in-memory cache) a model and loads it with `reader`. */
  def downloadModel[TModel <: PipelineStage](
      reader: DefaultParamsReadable[TModel],
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): TModel = {
    downloadModel(reader, ResourceRequest(name, language, folder))
  }

  /** Downloads (or takes from the in-memory cache) a model and loads it with `reader`.
    *
    * Loaded models are cached per request, so repeated calls return the same instance.
    */
  def downloadModel[TModel <: PipelineStage](
      reader: DefaultParamsReadable[TModel],
      request: ResourceRequest): TModel = {
    cache.get(request) match {
      case Some(cached) => cached.asInstanceOf[TModel]
      case None =>
        val model = reader.read.load(downloadResource(request))
        cache(request) = model
        model
    }
  }

  /** Downloads (or takes from the in-memory cache) a pretrained pipeline. */
  def downloadPipeline(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): PipelineModel = {
    downloadPipeline(ResourceRequest(name, language, folder))
  }

  /** Downloads (or takes from the in-memory cache) a pretrained pipeline.
    *
    * Loaded pipelines are cached per request, so repeated calls return the same instance.
    */
  def downloadPipeline(request: ResourceRequest): PipelineModel = {
    cache.get(request) match {
      case Some(cached) => cached.asInstanceOf[PipelineModel]
      case None =>
        val model = PipelineModel.read.load(downloadResource(request))
        cache(request) = model
        model
    }
  }

  /** Clears every downloader's cache and the in-memory cache for one resource. */
  def clearCache(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): Unit = {
    clearCache(ResourceRequest(name, language, folder))
  }

  /** Clears every downloader's cache and the in-memory cache for one resource. */
  def clearCache(request: ResourceRequest): Unit = {
    privateDownloader.clearCache(request)
    publicDownloader.clearCache(request)
    communityDownloader.clearCache(request)
    cache.remove(request)
  }

  /** Human-readable download size of a resource, or "-1" when the resource cannot be found. */
  def getDownloadSize(resourceRequest: ResourceRequest): String = {
    getResourceDownloader(resourceRequest.folder)
      .getDownloadSize(stripCommunityPrefix(resourceRequest))
      .map(downloadBytes => FileHelper.getHumanReadableFileSize(downloadBytes))
      .getOrElse("-1")
  }

  /** Community folders are written as "@folder". The '@' only selects the community downloader
    * (see [[getResourceDownloader]]) and is removed from the request handed to it.
    */
  private def stripCommunityPrefix(request: ResourceRequest): ResourceRequest =
    if (request.folder.startsWith("@")) request.copy(folder = request.folder.replace("@", ""))
    else request

  /** Splits a `name:lang:version` entry from the right, so a ':' inside the name is kept. Missing
    * trailing fields are rendered as "-".
    */
  private def splitResourceEntry(entry: String): (String, String, String) = {
    val parts = entry.split(":", -1)
    parts.length match {
      case 1 => (parts(0), "-", "-")
      case 2 => (parts(0), parts(1), "-")
      case n => (parts.take(n - 2).mkString(":"), parts(n - 2), parts(n - 1))
    }
  }

}

/** Category of a pretrained resource.
  *
  * The string passed to each `Value` is its name (`toString`). [[ResourceDownloader]] matches a
  * resource's metadata category against the requested type by comparing these strings. Resources
  * without a category are treated as `NOT_DEFINED`. Enumeration ids follow declaration order, so
  * do not reorder the values.
  */
object ResourceType extends Enumeration {
  type ResourceType = Value
  // Single models.
  val MODEL: pretrained.ResourceType.Value = Value("ml")
  // Pretrained pipelines.
  val PIPELINE: pretrained.ResourceType.Value = Value("pl")
  // Resources whose category is missing from the metadata.
  val NOT_DEFINED: pretrained.ResourceType.Value = Value("nd")
}

/** Identifies a pretrained resource to download.
  *
  * Also serves as the key of [[ResourceDownloader]]'s in-memory model/pipeline cache, so two
  * requests hit the same cache entry only if all fields are equal.
  *
  * @param name
  *   Name of the resource
  * @param language
  *   Optional language code of the resource
  * @param folder
  *   Remote folder; the public folder by default. A leading "@" routes the request to the
  *   community downloader.
  * @param libVersion
  *   Library version, the running one by default. NOTE(review): presumably used by the downloader
  *   to pick a compatible build — confirm in S3ResourceDownloader.
  * @param sparkVersion
  *   Spark version, the running one by default. NOTE(review): presumably used the same way as
  *   libVersion — confirm.
  */
case class ResourceRequest(
    name: String,
    language: Option[String] = None,
    folder: String = ResourceDownloader.publicLoc,
    libVersion: Version = ResourceDownloader.libVersion,
    sparkVersion: Version = ResourceDownloader.sparkVersion)

/** Convenience accessors for Py4J calls.
  *
  * Python callers pass plain strings, with `null` meaning "not set". These methods turn them into
  * `Option`s (or defaults) and delegate to [[ResourceDownloader]].
  */
object PythonResourceDownloader {

  /** Maps the class name sent from Python to the reader that loads that annotator.
    *
    * Kept mutable and public, as before. Each key appears exactly once: the earlier duplicated
    * DeBerta entries were redundant.
    */
  val keyToReader: mutable.Map[String, DefaultParamsReadable[_]] = mutable.Map(
    "DocumentAssembler" -> DocumentAssembler,
    "SentenceDetector" -> SentenceDetector,
    "TokenizerModel" -> TokenizerModel,
    "PerceptronModel" -> PerceptronModel,
    "NerCrfModel" -> NerCrfModel,
    "Stemmer" -> Stemmer,
    "NormalizerModel" -> NormalizerModel,
    "RegexMatcherModel" -> RegexMatcherModel,
    "LemmatizerModel" -> LemmatizerModel,
    "DateMatcher" -> DateMatcher,
    "TextMatcherModel" -> TextMatcherModel,
    "SentimentDetectorModel" -> SentimentDetectorModel,
    "ViveknSentimentModel" -> ViveknSentimentModel,
    "NorvigSweetingModel" -> NorvigSweetingModel,
    "SymmetricDeleteModel" -> SymmetricDeleteModel,
    "NerDLModel" -> NerDLModel,
    "WordEmbeddingsModel" -> WordEmbeddingsModel,
    "BertEmbeddings" -> BertEmbeddings,
    "DependencyParserModel" -> DependencyParserModel,
    "TypedDependencyParserModel" -> TypedDependencyParserModel,
    "UniversalSentenceEncoder" -> UniversalSentenceEncoder,
    "ElmoEmbeddings" -> ElmoEmbeddings,
    "ClassifierDLModel" -> ClassifierDLModel,
    "ContextSpellCheckerModel" -> ContextSpellCheckerModel,
    "AlbertEmbeddings" -> AlbertEmbeddings,
    "XlnetEmbeddings" -> XlnetEmbeddings,
    "SentimentDLModel" -> SentimentDLModel,
    "LanguageDetectorDL" -> LanguageDetectorDL,
    "StopWordsCleaner" -> StopWordsCleaner,
    "BertSentenceEmbeddings" -> BertSentenceEmbeddings,
    "MultiClassifierDLModel" -> MultiClassifierDLModel,
    "SentenceDetectorDLModel" -> SentenceDetectorDLModel,
    "T5Transformer" -> T5Transformer,
    "MarianTransformer" -> MarianTransformer,
    "WordSegmenterModel" -> WordSegmenterModel,
    "DistilBertEmbeddings" -> DistilBertEmbeddings,
    "RoBertaEmbeddings" -> RoBertaEmbeddings,
    "XlmRoBertaEmbeddings" -> XlmRoBertaEmbeddings,
    "LongformerEmbeddings" -> LongformerEmbeddings,
    "RoBertaSentenceEmbeddings" -> RoBertaSentenceEmbeddings,
    "XlmRoBertaSentenceEmbeddings" -> XlmRoBertaSentenceEmbeddings,
    "AlbertForTokenClassification" -> AlbertForTokenClassification,
    "BertForTokenClassification" -> BertForTokenClassification,
    "DeBertaForTokenClassification" -> DeBertaForTokenClassification,
    "DistilBertForTokenClassification" -> DistilBertForTokenClassification,
    "LongformerForTokenClassification" -> LongformerForTokenClassification,
    "RoBertaForTokenClassification" -> RoBertaForTokenClassification,
    "XlmRoBertaForTokenClassification" -> XlmRoBertaForTokenClassification,
    "XlnetForTokenClassification" -> XlnetForTokenClassification,
    "AlbertForSequenceClassification" -> AlbertForSequenceClassification,
    "BertForSequenceClassification" -> BertForSequenceClassification,
    "DeBertaForSequenceClassification" -> DeBertaForSequenceClassification,
    "DistilBertForSequenceClassification" -> DistilBertForSequenceClassification,
    "LongformerForSequenceClassification" -> LongformerForSequenceClassification,
    "RoBertaForSequenceClassification" -> RoBertaForSequenceClassification,
    "XlmRoBertaForSequenceClassification" -> XlmRoBertaForSequenceClassification,
    "XlnetForSequenceClassification" -> XlnetForSequenceClassification,
    "GPT2Transformer" -> GPT2Transformer,
    "EntityRulerModel" -> EntityRulerModel,
    "Doc2VecModel" -> Doc2VecModel,
    "Word2VecModel" -> Word2VecModel,
    "DeBertaEmbeddings" -> DeBertaEmbeddings,
    "CamemBertEmbeddings" -> CamemBertEmbeddings,
    "AlbertForQuestionAnswering" -> AlbertForQuestionAnswering,
    "BertForQuestionAnswering" -> BertForQuestionAnswering,
    "DeBertaForQuestionAnswering" -> DeBertaForQuestionAnswering,
    "DistilBertForQuestionAnswering" -> DistilBertForQuestionAnswering,
    "LongformerForQuestionAnswering" -> LongformerForQuestionAnswering,
    "RoBertaForQuestionAnswering" -> RoBertaForQuestionAnswering,
    "XlmRoBertaForQuestionAnswering" -> XlmRoBertaForQuestionAnswering,
    "SpanBertCorefModel" -> SpanBertCorefModel,
    "ViTForImageClassification" -> ViTForImageClassification,
    "VisionEncoderDecoderForImageCaptioning" -> VisionEncoderDecoderForImageCaptioning,
    "SwinForImageClassification" -> SwinForImageClassification,
    "ConvNextForImageClassification" -> ConvNextForImageClassification,
    "Wav2Vec2ForCTC" -> Wav2Vec2ForCTC,
    "HubertForCTC" -> HubertForCTC,
    "WhisperForCTC" -> WhisperForCTC,
    "CamemBertForTokenClassification" -> CamemBertForTokenClassification,
    "TableAssembler" -> TableAssembler,
    "TapasForQuestionAnswering" -> TapasForQuestionAnswering,
    "CamemBertForSequenceClassification" -> CamemBertForSequenceClassification,
    "CamemBertForQuestionAnswering" -> CamemBertForQuestionAnswering,
    "ZeroShotNerModel" -> ZeroShotNerModel,
    "BartTransformer" -> BartTransformer,
    "BertForZeroShotClassification" -> BertForZeroShotClassification,
    "DistilBertForZeroShotClassification" -> DistilBertForZeroShotClassification,
    "RoBertaForZeroShotClassification" -> RoBertaForZeroShotClassification,
    "XlmRoBertaForZeroShotClassification" -> XlmRoBertaForZeroShotClassification,
    "BartForZeroShotClassification" -> BartForZeroShotClassification,
    "InstructorEmbeddings" -> InstructorEmbeddings,
    "E5Embeddings" -> E5Embeddings,
    "MPNetEmbeddings" -> MPNetEmbeddings,
    "CLIPForZeroShotClassification" -> CLIPForZeroShotClassification,
    "DeBertaForZeroShotClassification" -> DeBertaForZeroShotClassification,
    "BGEEmbeddings" -> BGEEmbeddings,
    "MPNetForSequenceClassification" -> MPNetForSequenceClassification,
    "MPNetForQuestionAnswering" -> MPNetForQuestionAnswering,
    "LLAMA2Transformer" -> LLAMA2Transformer,
    "M2M100Transformer" -> M2M100Transformer,
    "UAEEmbeddings" -> UAEEmbeddings,
    "AutoGGUFModel" -> AutoGGUFModel,
    "AlbertForZeroShotClassification" -> AlbertForZeroShotClassification,
    "MxbaiEmbeddings" -> MxbaiEmbeddings,
    "SnowFlakeEmbeddings" -> SnowFlakeEmbeddings,
    "CamemBertForZeroShotClassification" -> CamemBertForZeroShotClassification,
    "BertForMultipleChoice" -> BertForMultipleChoice,
    "PromptAssembler" -> PromptAssembler)

  // List pairs of types such as the one with key type can load a pretrained model from the value type
  val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering")

  /** Downloads a model by the name of its annotator class.
    *
    * @param readerStr
    *   Annotator class name. Entries of `typeMapper` are loaded with the mapped class's reader.
    * @param name
    *   Name of the pretrained model
    * @param language
    *   Language, or null for none
    * @param remoteLoc
    *   Remote folder, or null for the public folder
    * @return
    *   the loaded model
    * @throws RuntimeException
    *   if no reader is registered for `readerStr`
    */
  def downloadModel(
      readerStr: String,
      name: String,
      language: String = null,
      remoteLoc: String = null): PipelineStage = {

    val reader = keyToReader.getOrElse(
      typeMapper.getOrElse(readerStr, readerStr),
      throw new RuntimeException(s"Unsupported Model: $readerStr"))

    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)

    val model = ResourceDownloader.downloadModel(
      reader.asInstanceOf[DefaultParamsReadable[PipelineStage]],
      name,
      Option(language),
      correctedFolder)

    // Cast the model to the required type. This has to be done for each entry in the typeMapper map
    if (typeMapper.contains(readerStr) && readerStr == "ZeroShotNerModel")
      ZeroShotNerModel(model)
    else
      model
  }

  /** Downloads a pretrained pipeline; null `language`/`remoteLoc` mean none/public folder. */
  def downloadPipeline(
      name: String,
      language: String = null,
      remoteLoc: String = null): PipelineModel = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.downloadPipeline(name, Option(language), correctedFolder)
  }

  /** Clears the caches for a resource; null `language`/`remoteLoc` mean none/public folder. */
  def clearCache(name: String, language: String = null, remoteLoc: String = null): Unit = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.clearCache(name, Option(language), correctedFolder)
  }

  /** Downloads a file by S3 key or URI; a null `remoteLoc` means the public folder. */
  def downloadModelDirectly(
      model: String,
      remoteLoc: String = null,
      unzip: Boolean = true): Unit = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.downloadModelDirectly(model, correctedFolder, unzip)
  }

  /** Table of all uncategorized public resources, for any language and version. */
  def showUnCategorizedResources(): String = {
    ResourceDownloader.publicResourceString(
      annotator = None,
      lang = None,
      version = None,
      resourceType = ResourceType.NOT_DEFINED)
  }

  /** Table of public pipelines; a null `version` means the current library version. */
  def showPublicPipelines(lang: String, version: String): String = {
    ResourceDownloader.publicResourceString(
      annotator = None,
      lang = Option(lang),
      version = Some(Option(version).getOrElse(Build.version)),
      resourceType = ResourceType.PIPELINE)
  }

  /** Table of public models; a null `version` means the current library version. */
  def showPublicModels(annotator: String, lang: String, version: String): String = {
    ResourceDownloader.publicResourceString(
      annotator = Option(annotator),
      lang = Option(lang),
      version = Some(Option(version).getOrElse(Build.version)),
      resourceType = ResourceType.MODEL)
  }

  /** Names of the annotators with public resources, one per line. */
  def showAvailableAnnotators(): String = {
    ResourceDownloader.listAvailableAnnotators().mkString("\n")
  }

  /** Human-readable download size, or "-1" if not found.
    *
    * Unlike the other methods, `language` defaults to "en" here.
    */
  def getDownloadSize(name: String, language: String = "en", remoteLoc: String = null): String = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.getDownloadSize(ResourceRequest(name, Option(language), correctedFolder))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy