All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.pretrained.PretrainedPipeline.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.pretrained

import com.johnsnowlabs.nlp.{IAnnotation, LightPipeline}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.DataFrame

/** Represents a fully constructed and trained Spark NLP pipeline, ready to be used. This way, a
  * whole pipeline can be defined in 1 line. Additionally, the [[LightPipeline]] version of the
  * model can be retrieved with member `lightModel`.
  *
  * For more extended examples see the [[https://sparknlp.org/docs/en/pipelines Pipelines page]]
  * and our [[https://github.com/JohnSnowLabs/spark-nlp-models Github Model Repository]] for
  * available pipeline models.
  *
  * ==Example==
  * {{{
  * import com.johnsnowlabs.nlp.pretrained.PretrainedPipeline
  * import com.johnsnowlabs.nlp.SparkNLP
  * val testData = spark.createDataFrame(Seq(
  * (1, "Google has announced the release of a beta version of the popular TensorFlow machine learning library"),
  * (2, "Donald John Trump (born June 14, 1946) is the 45th and current president of the United States")
  * )).toDF("id", "text")
  *
  * val pipeline = PretrainedPipeline("explain_document_dl", lang="en")
  *
  * val annotation = pipeline.transform(testData)
  *
  * annotation.select("entities.result").show(false)
  *
  * /*
  * +----------------------------------+
  * |result                            |
  * +----------------------------------+
  * |[Google, TensorFlow]              |
  * |[Donald John Trump, United States]|
  * +----------------------------------+
  * */
  * }}}
  *
  * @param downloadName
  *   Name of the Pipeline Model
  * @param lang
  *   Language of the defined pipeline (Default: "en")
  * @param source
  *   Source where to get the Pipeline Model
  * @param parseEmbeddingsVectors
  * @param diskLocation
  */
case class PretrainedPipeline(
    downloadName: String,
    lang: String = "en",
    source: String = ResourceDownloader.publicLoc,
    parseEmbeddingsVectors: Boolean = false,
    diskLocation: Option[String] = None) {

  /** Support for java default argument interoperability */
  def this(downloadName: String) {
    this(downloadName, "en", ResourceDownloader.publicLoc)
  }

  def this(downloadName: String, lang: String) {
    this(downloadName, lang, ResourceDownloader.publicLoc)
  }

  val model: PipelineModel = if (diskLocation.isEmpty) {
    ResourceDownloader
      .downloadPipeline(downloadName, Option(lang), source)
  } else {
    PipelineModel.load(diskLocation.get)
  }

  lazy val lightModel = new LightPipeline(model, parseEmbeddingsVectors)

  def annotate(target: String): Map[String, Seq[String]] = lightModel.annotate(target)

  def annotate(target: Array[String]): Array[Map[String, Seq[String]]] =
    lightModel.annotate(target)

  def annotateJava(target: String): java.util.Map[String, java.util.List[String]] =
    lightModel.annotateJava(target)

  def annotateJava(targets: java.util.ArrayList[String])
      : java.util.List[java.util.Map[String, java.util.List[String]]] = {
    lightModel.annotateJava(targets)
  }

  def fullAnnotate(target: String, optionalTarget: String = ""): Map[String, Seq[IAnnotation]] = {
    lightModel.fullAnnotate(target, optionalTarget)
  }

  def fullAnnotate(targets: Array[String]): Array[Map[String, Seq[IAnnotation]]] = {
    lightModel.fullAnnotate(targets)
  }

  def fullAnnotate(
      targets: Array[String],
      optionalTargets: Array[String]): Array[Map[String, Seq[IAnnotation]]] = {
    lightModel.fullAnnotate(targets, optionalTargets)
  }

  def fullAnnotateImage(pathToImage: String): Map[String, Seq[IAnnotation]] = {
    lightModel.fullAnnotateImage(pathToImage)
  }

  def fullAnnotateImage(pathToImages: Array[String]): Array[Map[String, Seq[IAnnotation]]] = {
    lightModel.fullAnnotateImage(pathToImages)
  }

  def fullAnnotate(audio: Array[Float]): Map[String, Seq[IAnnotation]] = {
    lightModel.fullAnnotate(audio)
  }

  def fullAnnotate(audios: Array[Array[Float]]): Array[Map[String, Seq[IAnnotation]]] = {
    lightModel.fullAnnotate(audios)
  }

  def fullAnnotateJava(target: String): java.util.Map[String, java.util.List[IAnnotation]] = {
    lightModel.fullAnnotateJava(target)
  }

  def fullAnnotateJava(
      target: String,
      optionalTarget: String): java.util.Map[String, java.util.List[IAnnotation]] = {
    lightModel.fullAnnotateJava(target, optionalTarget)
  }

  def fullAnnotateJava(targets: java.util.ArrayList[String])
      : java.util.List[java.util.Map[String, java.util.List[IAnnotation]]] = {
    lightModel.fullAnnotateJava(targets)
  }

  def fullAnnotateJava(
      targets: java.util.ArrayList[String],
      optionalTargets: java.util.ArrayList[String])
      : java.util.List[java.util.Map[String, java.util.List[IAnnotation]]] = {
    lightModel.fullAnnotateJava(targets, optionalTargets)
  }

  def fullAnnotateImageJava(
      pathToImage: String): java.util.Map[String, java.util.List[IAnnotation]] = {
    lightModel.fullAnnotateImageJava(pathToImage)
  }

  def fullAnnotateImageJava(pathToImages: java.util.ArrayList[String])
      : java.util.List[java.util.Map[String, java.util.List[IAnnotation]]] = {
    lightModel.fullAnnotateJava(pathToImages)
  }

  def fullAnnotateSingleAudioJava(
      audio: java.util.ArrayList[Double]): java.util.Map[String, java.util.List[IAnnotation]] = {
    lightModel.fullAnnotateSingleAudioJava(audio)
  }

  def fullAnnotateAudiosJava(audios: java.util.ArrayList[java.util.ArrayList[Double]])
      : java.util.List[java.util.Map[String, java.util.List[IAnnotation]]] = {
    lightModel.fullAnnotateAudiosJava(audios)
  }

  def transform(dataFrame: DataFrame): DataFrame = model.transform(dataFrame)

}

object PretrainedPipeline {
  def fromDisk(path: String, parseEmbeddings: Boolean = false): PretrainedPipeline = {
    PretrainedPipeline(null, null, null, parseEmbeddings, Some(path))
  }
  def fromDisk(path: String): PretrainedPipeline = {
    fromDisk(path, false)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy