All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.GraphFinisher.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.AnnotatorType.NODE
import com.johnsnowlabs.nlp.util.FinisherUtil
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, Row}

/** Helper class to convert the knowledge graph from GraphExtraction into a generic format, such
  * as RDF.
  *
  * ==Example==
  * This is a continuation of the example of
  * [[com.johnsnowlabs.nlp.annotators.GraphExtraction GraphExtraction]]. To see how the graph is
  * extracted, see the documentation of that class.
  * {{{
  * import com.johnsnowlabs.nlp.GraphFinisher
  *
  * val graphFinisher = new GraphFinisher()
  *   .setInputCol("graph")
  *   .setOutputCol("graph_finished")
  *   .setOutputAsArray(false)
  *
  * val finishedResult = graphFinisher.transform(result)
  * finishedResult.select("text", "graph_finished").show(false)
  * +-----------------------------------------------------+-----------------------------------------------------------------------+
  * |text                                                 |graph_finished                                                         |
  * +-----------------------------------------------------+-----------------------------------------------------------------------+
  * |You and John prefer the morning flight through Denver|[[(prefer,nsubj,morning), (morning,flat,flight), (flight,flat,Denver)]]|
  * +-----------------------------------------------------+-----------------------------------------------------------------------+
  * }}}
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.GraphExtraction GraphExtraction]] to extract the graph.
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class GraphFinisher(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("graph_finisher"))

  /** Name of input annotation cols
    *
    * @group param
    */
  val inputCol = new Param[String](this, "inputCol", "Name of input annotation col")

  /** Name of finisher output cols
    *
    * @group param
    */
  val outputCol =
    new Param[String](this, "outputCol", "Name of finisher output col")

  /** Finisher generates an Array with the results instead of string (Default: `true`)
    *
    * @group param
    */
  val outputAsArray: BooleanParam =
    new BooleanParam(this, "outputAsArray", "Finisher generates an Array with the results")

  /** Whether to remove annotation columns (Default: `true`)
    *
    * @group param
    */
  val cleanAnnotations: BooleanParam =
    new BooleanParam(
      this,
      "cleanAnnotations",
      "Whether to remove annotation columns (Default: `true`)")

  /** Annotation metadata format (Default: `false`)
    *
    * @group param
    */
  val includeMetadata: BooleanParam =
    new BooleanParam(this, "includeMetadata", "Annotation metadata format (Default: `false`)")

  /** Name of input annotation col
    *
    * @group setParam
    */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** Name of finisher output col
    *
    * @group setParam
    */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** Finisher generates an Array with the results instead of string (Default: `true`)
    *
    * @group setParam
    */
  def setOutputAsArray(value: Boolean): this.type = set(outputAsArray, value)

  /** Whether to remove annotation columns (Default: `true`)
    *
    * @group setParam
    */
  def setCleanAnnotations(value: Boolean): this.type = set(cleanAnnotations, value)

  /** Annotation metadata format (Default: `false`)
    *
    * @group setParam
    */
  def setIncludeMetadata(value: Boolean): this.type = set(includeMetadata, value)

  /** Name of input annotation col
    *
    * @group getParam
    */
  def getOutputCol: String = get(outputCol).getOrElse("finished_" + getInputCol)

  /** Name of EmbeddingsFinisher output cols
    *
    * @group getParam
    */
  def getInputCol: String = $(inputCol)

  setDefault(cleanAnnotations -> true, outputAsArray -> true, includeMetadata -> false)

  private val METADATA_STRUCT_INDEX = 4

  override def transform(dataset: Dataset[_]): DataFrame = {
    var flattenedDataSet = dataset.withColumn(
      $(outputCol), {
        if ($(outputAsArray)) flattenPathsAsArray(dataset.col($(inputCol)))
        else flattenPaths(dataset.col($(inputCol)))
      })

    if ($(includeMetadata)) {
      flattenedDataSet = flattenedDataSet.withColumn(
        $(outputCol) + "_metadata",
        flattenMetadata(dataset.col($(inputCol))))
    }

    FinisherUtil.cleaningAnnotations($(cleanAnnotations), flattenedDataSet)
  }

  private def buildPathsAsArray(metadata: Map[String, String]): List[List[List[String]]] = {
    val paths = extractPaths(metadata)

    val pathsInRDFFormat = paths.map { path =>
      val evenPathIndices = path.indices.toList.filter(index => index % 2 == 0)
      val sliceIndices = evenPathIndices zip evenPathIndices.tail
      sliceIndices.map(sliceIndex => path.slice(sliceIndex._1, sliceIndex._2 + 1).toList)
    }
    pathsInRDFFormat
  }

  private def buildPaths(metadata: Map[String, String]): List[List[String]] = {
    val paths = extractPaths(metadata)

    val pathsInRDFFormat = paths.map { path =>
      val evenPathIndices = path.indices.toList.filter(index => index % 2 == 0)
      val sliceIndices = evenPathIndices zip evenPathIndices.tail
      sliceIndices.map { sliceIndex =>
        val node = path.slice(sliceIndex._1, sliceIndex._2 + 1)
        "(" + node.mkString(",") + ")"
      }
    }
    pathsInRDFFormat
  }

  private def extractPaths(metadata: Map[String, String]): List[Array[String]] = {
    val paths = metadata.flatMap { case (key, value) =>
      if (key.contains("path")) Some(value.split(",")) else None
    }.toList

    paths
  }

  private def flattenPathsAsArray: UserDefinedFunction = udf { annotations: Seq[Row] =>
    annotations.flatMap { row =>
      val metadata = row.getMap[String, String](METADATA_STRUCT_INDEX)
      val pathsInRDFFormat = buildPathsAsArray(metadata.toMap)
      pathsInRDFFormat
    }
  }

  private def flattenPaths: UserDefinedFunction = udf { annotations: Seq[Row] =>
    annotations.flatMap { row =>
      val metadata = row.getMap[String, String](METADATA_STRUCT_INDEX)
      val pathsInRDFFormat = buildPaths(metadata.toMap)
      pathsInRDFFormat
    }
  }

  private def flattenMetadata: UserDefinedFunction = udf { annotations: Seq[Row] =>
    annotations.flatMap { row =>
      val metadata = row.getMap[String, String](4)
      val relationships = metadata.flatMap { case (key, value) =>
        if (key.contains("relationship") || key.contains("entities")) Some("(" + value + ")")
        else None
      }.toList
      relationships
    }
  }

  def annotate(metadata: Map[String, String]): Seq[Annotation] = {
    val pathsInRDFFormat = buildPaths(metadata)
    val result = pathsInRDFFormat.map(pathRDF => "[" + pathRDF.mkString(",") + "]").mkString(",")
    Seq(Annotation(NODE, 0, 0, result, Map()))
  }

  override def copy(extra: ParamMap): Transformer = super.defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {

    FinisherUtil.checkIfInputColsExist(Array(getInputCol), schema)
    FinisherUtil.checkIfAnnotationColumnIsSparkNLPAnnotation(schema, getInputCol)

    require(
      Seq(AnnotatorType.NODE).contains(schema(getInputCol).metadata.getString("annotatorType")),
      s"column [$getInputCol] must be a ${AnnotatorType.NODE} type")

    val metadataFields = FinisherUtil.getMetadataFields(Array(getOutputCol), $(outputAsArray))
    val outputFields = schema.fields ++
      FinisherUtil.getOutputFields(Array(getOutputCol), $(outputAsArray)) ++ metadataFields
    val cleanFields = FinisherUtil.getCleanFields($(cleanAnnotations), outputFields)

    StructType(cleanFields)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy