All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.Finisher.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.util.FinisherUtil
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row}

/** Converts annotation results into a format that is easier to use. It is useful to extract the
  * results from Spark NLP Pipelines. The Finisher outputs annotation result values as `String`s
  * (collected into an array per row by default, see `outputAsArray`).
  *
  * For more extended examples on document pre-processing see the
  * [[https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/2.Text_Preprocessing_with_SparkNLP_Annotators_Transformers.ipynb Spark NLP Workshop]].
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.pretrained.PretrainedPipeline
  * import com.johnsnowlabs.nlp.Finisher
  *
  * val data = Seq((1, "New York and New Jersey aren't that far apart actually.")).toDF("id", "text")
  *
  * // Extracts Named Entities amongst other things
  * val pipeline = PretrainedPipeline("explain_document_dl")
  *
  * val finisher = new Finisher().setInputCols("entities").setOutputCols("output")
  * val explainResult = pipeline.transform(data)
  *
  * explainResult.select("entities").show(false)
  * +------------------------------------------------------------------------------------------------------------------------------------------------------+
  * |entities                                                                                                                                              |
  * +------------------------------------------------------------------------------------------------------------------------------------------------------+
  * |[[chunk, 0, 7, New York, [entity -> LOC, sentence -> 0, chunk -> 0], []], [chunk, 13, 22, New Jersey, [entity -> LOC, sentence -> 0, chunk -> 1], []]]|
  * +------------------------------------------------------------------------------------------------------------------------------------------------------+
  *
  * val result = finisher.transform(explainResult)
  * result.select("output").show(false)
  * +----------------------+
  * |output                |
  * +----------------------+
  * |[New York, New Jersey]|
  * +----------------------+
  * }}}
  *
  * @see
  *   [[com.johnsnowlabs.nlp.EmbeddingsFinisher EmbeddingsFinisher]] for finishing embeddings
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class Finisher(override val uid: String) extends Transformer with DefaultParamsWritable {

  /** Name of input annotation cols
    *
    * @group param
    */
  val inputCols: StringArrayParam =
    new StringArrayParam(this, "inputCols", "Name of input annotation cols")

  /** Name of finisher output cols. When unset, every input column `c` is mapped to
    * `finished_c` (see [[getOutputCols]]).
    *
    * @group param
    */
  val outputCols: StringArrayParam =
    new StringArrayParam(this, "outputCols", "Name of finisher output cols")

  /** Character separating values (Default: `#`). Only used when `outputAsArray` is `false`,
    * where it is handed to the string flattening functions as the value separator.
    *
    * @group param
    */
  val valueSplitSymbol: Param[String] =
    new Param(this, "valueSplitSymbol", "Character separating values (Default: `#`)")

  /** Character separating annotations (Default: `@`). Only used when `outputAsArray` is
    * `false`.
    *
    * @group param
    */
  val annotationSplitSymbol: Param[String] =
    new Param(this, "annotationSplitSymbol", "Character separating annotations (Default: `@`)")

  /** Whether to remove annotation columns (Default: `true`)
    *
    * @group param
    */
  val cleanAnnotations: BooleanParam =
    new BooleanParam(
      this,
      "cleanAnnotations",
      "Whether to remove annotation columns (Default: `true`)")

  /** Annotation metadata format (Default: `false`). Together with `outputAsArray` it adds a
    * `<outputCol>_metadata` column per output column. Without `outputAsArray` it switches the
    * string output from `Annotation.flatten` to `Annotation.flattenDetail`.
    *
    * @group param
    */
  val includeMetadata: BooleanParam =
    new BooleanParam(this, "includeMetadata", "Annotation metadata format (Default: `false`)")

  /** Finisher generates an Array with the results instead of string (Default: `true`)
    *
    * @group param
    */
  val outputAsArray: BooleanParam =
    new BooleanParam(
      this,
      "outputAsArray",
      "Finisher generates an Array with the results instead of string (Default: `true`)")

  /** Whether to include embeddings vectors in the process (Default: `false`)
    *
    * @group param
    */
  val parseEmbeddingsVectors: BooleanParam =
    new BooleanParam(
      this,
      "parseEmbeddingsVectors",
      "Whether to include embeddings vectors in the process (Default: `false`)")

  /** Name of input annotation cols
    *
    * @group setParam
    */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** Name of input annotation cols
    *
    * @group setParam
    */
  def setInputCols(value: String*): this.type = setInputCols(value.toArray)

  /** Name of finisher output cols
    *
    * @group setParam
    */
  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

  /** Name of finisher output cols
    *
    * @group setParam
    */
  def setOutputCols(value: String*): this.type = setOutputCols(value.toArray)

  /** Character separating values (Default: `#`)
    *
    * @group setParam
    */
  def setValueSplitSymbol(value: String): this.type = set(valueSplitSymbol, value)

  /** Character separating annotations (Default: `@`)
    *
    * @group setParam
    */
  def setAnnotationSplitSymbol(value: String): this.type = set(annotationSplitSymbol, value)

  /** Whether to remove annotation columns (Default: `true`)
    *
    * @group setParam
    */
  def setCleanAnnotations(value: Boolean): this.type = set(cleanAnnotations, value)

  /** Annotation metadata format (Default: `false`)
    *
    * @group setParam
    */
  def setIncludeMetadata(value: Boolean): this.type = set(includeMetadata, value)

  /** Finisher generates an Array with the results instead of string (Default: `true`)
    *
    * @group setParam
    */
  def setOutputAsArray(value: Boolean): this.type = set(outputAsArray, value)

  /** Whether to include embeddings vectors in the process (Default: `false`)
    *
    * @group setParam
    */
  def setParseEmbeddingsVectors(value: Boolean): this.type = set(parseEmbeddingsVectors, value)

  /** Name of finisher output cols. Falls back to `finished_` + each input column name when
    * `outputCols` has not been set.
    *
    * @group getParam
    */
  def getOutputCols: Array[String] = get(outputCols).getOrElse(getInputCols.map("finished_" + _))

  /** Name of input annotation cols
    *
    * @group getParam
    */
  def getInputCols: Array[String] = $(inputCols)

  /** Character separating values (Default: `#`)
    *
    * @group getParam
    */
  def getValueSplitSymbol: String = $(valueSplitSymbol)

  /** Character separating annotations (Default: `@`)
    *
    * @group getParam
    */
  def getAnnotationSplitSymbol: String = $(annotationSplitSymbol)

  /** Whether to remove annotation columns (Default: `true`)
    *
    * @group getParam
    */
  def getCleanAnnotations: Boolean = $(cleanAnnotations)

  /** Annotation metadata format (Default: `false`)
    *
    * @group getParam
    */
  def getIncludeMetadata: Boolean = $(includeMetadata)

  /** Finisher generates an Array with the results instead of string (Default: `true`)
    *
    * @group getParam
    */
  def getOutputAsArray: Boolean = $(outputAsArray)

  /** Whether to include embeddings vectors in the process (Default: `false`)
    *
    * @group getParam
    */
  def getParseEmbeddingsVectors: Boolean = $(parseEmbeddingsVectors)

  setDefault(
    cleanAnnotations -> true,
    includeMetadata -> false,
    outputAsArray -> true,
    parseEmbeddingsVectors -> false,
    valueSplitSymbol -> "#",
    annotationSplitSymbol -> "@")

  def this() = this(Identifiable.randomUID("finisher"))

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  /** Validates the input columns and describes the columns `transform` will produce.
    *
    * @param schema
    *   schema of the incoming dataset
    * @return
    *   the input schema plus one output field per output column, plus `_metadata` fields when
    *   metadata is emitted; annotation columns are dropped when `cleanAnnotations` is set
    * @throws IllegalArgumentException
    *   if `inputCols` and `outputCols` differ in length
    */
  override def transformSchema(schema: StructType): StructType = {
    val inputColNames = getInputCols
    val outputColNames = getOutputCols
    require(
      inputColNames.length == outputColNames.length,
      "inputCols and outputCols length must match")

    // The existence check covers all input columns at once, so it only needs to run once.
    FinisherUtil.checkIfInputColsExist(inputColNames, schema)
    inputColNames.foreach(FinisherUtil.checkIfAnnotationColumnIsSparkNLPAnnotation(schema, _))

    // transform() only adds `<outputCol>_metadata` columns when BOTH outputAsArray and
    // includeMetadata are set; the declared schema must follow the same condition.
    val emitMetadata = $(outputAsArray) && $(includeMetadata)
    val metadataFields = FinisherUtil.getMetadataFields(outputColNames, emitMetadata)
    val outputFields = schema.fields ++
      FinisherUtil.getOutputFields(outputColNames, $(outputAsArray)) ++
      metadataFields

    StructType(FinisherUtil.getCleanFields($(cleanAnnotations), outputFields))
  }

  /** Flattens each input annotation column into its paired output column.
    *
    * With `outputAsArray`, each output is produced by `Annotation.flattenArray`. Metadata
    * columns are also added if `includeMetadata` is set. Otherwise each output is a single
    * string built with the configured split symbols.
    *
    * @param dataset
    *   dataset containing the Spark NLP annotation columns listed in `inputCols`
    * @return
    *   the dataset with the finished columns added (and annotation columns removed when
    *   `cleanAnnotations` is set)
    * @throws IllegalArgumentException
    *   if `inputCols` and `outputCols` differ in length
    */
  override def transform(dataset: Dataset[_]): Dataset[Row] = {
    val inputColNames = getInputCols
    val outputColNames = getOutputCols
    require(
      inputColNames.length == outputColNames.length,
      "inputCols and outputCols length must match")
    val colPairs = inputColNames.zip(outputColNames)

    val asArray = $(outputAsArray)
    val withMetadata = $(includeMetadata)
    val parseEmbeddings = $(parseEmbeddingsVectors)
    val valueSep = $(valueSplitSymbol)
    val annotationSep = $(annotationSplitSymbol)

    // Folding over a DataFrame (instead of Dataset[_]) keeps the accumulator at one concrete
    // type, which removes the need for a mutable var.
    val flattened = colPairs.foldLeft(dataset.toDF()) { case (df, (inputCol, outputCol)) =>
      val source = df.col(inputCol)
      df.withColumn(
        outputCol,
        if (asArray) Annotation.flattenArray(parseEmbeddings)(source)
        else if (!withMetadata) Annotation.flatten(valueSep, annotationSep, parseEmbeddings)(source)
        else Annotation.flattenDetail(valueSep, annotationSep, parseEmbeddings)(source))
    }

    val withMetadataColumns =
      if (asArray && withMetadata)
        colPairs.foldLeft(flattened) { case (df, (inputCol, outputCol)) =>
          df.withColumn(s"${outputCol}_metadata", Annotation.flattenArrayMetadata(df.col(inputCol)))
        }
      else flattened

    FinisherUtil.cleaningAnnotations($(cleanAnnotations), withMetadataColumns)
  }

}

/** Companion object for [[Finisher]].
  *
  * Extending `DefaultParamsReadable` gives it the standard Spark ML `read` / `load` entry
  * points, which restore a previously saved [[Finisher]]. The transformer itself is documented
  * on the [[Finisher]] class.
  */
object Finisher extends DefaultParamsReadable[Finisher] {
  // Intentionally empty: all persistence behavior is inherited from DefaultParamsReadable.
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy