// All Downloads are FREE. Search and download functionalities are using the official Maven repository.

// com.johnsnowlabs.nlp.training.POS.scala Maven / Gradle / Ivy

package com.johnsnowlabs.nlp.training

import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, concat_ws, lit, split, udf}
import org.apache.spark.sql.types.MetadataBuilder

import scala.collection.mutable.ArrayBuffer

// A single token paired with the tag parsed from a delimited `token|tag` entry.
private case class TaggedToken(token: String, tag: String)
// A sentence string together with the tagged tokens it was assembled from.
private case class TaggedDocument(sentence: String, taggedTokens: Array[TaggedToken])
// Intermediate row shape used by readDataset; the field names `text`, `document` and `pos`
// are the column names that readDataset later renames / wraps with metadata.
private case class Annotations(text: String, document: Array[Annotation], pos: Array[Annotation])

/**
  * Helpers that build part-of-speech training DataFrames, either from a text file in
  * `token|tag` style or from a DataFrame holding parallel arrays of tokens and labels.
  */
case class POS() {

  /**
    * Aliases a column and attaches an `annotatorType` entry to its metadata.
    * NOTE: This should be replaced by an existing function when it's accessible in next release
    *
    * @param col           column to alias
    * @param annotatorType value stored under the "annotatorType" metadata key
    * @param outPutColName alias given to the returned column
    * @return `col` aliased as `outPutColName`, carrying the metadata
    */
  def wrapColumnMetadata(col: Column, annotatorType: String, outPutColName: String): Column = {
    val metadata = new MetadataBuilder()
      .putString("annotatorType", annotatorType)
      .build()
    col.as(outPutColName, metadata)
  }

  /*
   * This section is to help users to convert text files in token|tag style into DataFrame
   * with POS Annotation for training PerceptronApproach
   * */

  /**
    * Wraps a whole sentence in a single DOCUMENT annotation spanning offsets
    * 0 to `sentence.length - 1` (inclusive end).
    */
  private def createDocumentAnnotation(sentence: String) = {
    Array(Annotation(
      AnnotatorType.DOCUMENT,
      0,
      sentence.length - 1,
      sentence,
      Map.empty[String, String]
    ))
  }

  /**
    * Builds one POS annotation per tagged token. Each annotation carries the tag as its
    * result, the token under the "word" metadata key, and the token's inclusive
    * begin/end character offsets inside `sentence`.
    *
    * @param sentence     text the offsets refer to
    * @param taggedTokens tokens in the order they occur in `sentence`
    */
  private def createPosAnnotation(sentence: String, taggedTokens: Array[TaggedToken]) = {
    // Position in `sentence` from which the next token is searched.
    var searchFrom = 0
    taggedTokens.map { case TaggedToken(token, tag) =>
      val tokenBegin = sentence.indexOf(token, searchFrom)
      // Move the cursor past the token just located (not merely by its length), so the
      // separators between tokens are accounted for and a repeated token resolves to its
      // next occurrence rather than an earlier one.
      if (tokenBegin >= 0) searchFrom = tokenBegin + token.length
      Annotation(
        AnnotatorType.POS,
        tokenBegin,
        tokenBegin + token.length - 1,
        tag,
        Map("word" -> token)
      )
    }
  }

  /**
    * Parses one line of delimited `token|tag` entries into a [[TaggedDocument]] whose
    * sentence is the accepted tokens joined by single spaces.
    *
    * Only the first character of `delimiter` is used. Entries that do not split into
    * exactly two non-empty parts are dropped.
    */
  private def lineToTaggedDocument(line: String, delimiter: String) = {

    // Insert a "##" marker at entry boundaries, then split on it.
    // NOTE(review): the delimiter character is interpolated into the pattern without
    // regex quoting; kept as-is because changing it would alter how lines are split.
    val entries = line.replaceAll(s"(?:${delimiter.head}\\w+)+(\\s)", "$0##$1").split("##").map(_.trim)

    val tokenTags = entries.flatMap { entry =>
      entry.split(delimiter.head).map(_.trim) match {
        case Array(word, tag) if word.nonEmpty && tag.nonEmpty =>
          Some(TaggedToken(word, tag))
        case _ =>
          // Ignore broken pairs or pairs with delimiter char
          None
      }
    }
    TaggedDocument(tokenTags.map(_.token).mkString(" "), tokenTags)
  }

  /**
    * Reads a text file of delimited `token|tag` lines into a DataFrame with a text column,
    * a DOCUMENT annotation column and a POS annotation column. Empty lines are skipped.
    *
    * @param sparkSession      session used to read the file
    * @param path              location of the text file
    * @param delimiter         single-character separator between token and tag
    * @param outputPosCol      name of the POS annotation column
    * @param outputDocumentCol name of the DOCUMENT annotation column
    * @param outputTextCol     name of the plain text column
    * @return DataFrame with exactly the three output columns
    * @throws IllegalArgumentException if `delimiter` is not exactly one character long
    */
  def readDataset(
      sparkSession: SparkSession,
      path: String,
      delimiter: String = "|",
      outputPosCol: String = "tags",
      outputDocumentCol: String = "document",
      outputTextCol: String = "text"
  ): DataFrame = {
    import sparkSession.implicits._

    require(delimiter.length == 1, s"Delimiter must be one character long. Received $delimiter")

    val dataset = sparkSession.read.textFile(path)
      .filter(_.nonEmpty)
      .map(line => lineToTaggedDocument(line, delimiter))
      .map { case TaggedDocument(sentence, taggedTokens) =>
        Annotations(
          sentence,
          createDocumentAnnotation(sentence),
          createPosAnnotation(sentence, taggedTokens)
        )
      }

    dataset
      .withColumnRenamed("text", outputTextCol)
      .withColumn(
        outputDocumentCol,
        wrapColumnMetadata(dataset("document"), AnnotatorType.DOCUMENT, outputDocumentCol)
      )
      .withColumn(
        outputPosCol,
        wrapColumnMetadata(dataset("pos"), AnnotatorType.POS, outputPosCol)
      )
      .select(outputTextCol, outputDocumentCol, outputPosCol)
  }

  /**
    * For testing purposes when there is an array of tokens and an array of labels.
    *
    * Adds a text column made of the tokens joined by single spaces and a POS annotation
    * column built from the token/label pairs, then drops the two input columns.
    *
    * @param posDataframe     input DataFrame
    * @param tokensCol        name of the column holding the token array
    * @param labelsCol        name of the column holding the label array
    * @param outPutDocColName name of the generated text column
    * @param outPutPosColName name of the generated POS annotation column
    * @return the transformed DataFrame; evaluating a row whose token and label arrays
    *         differ in length fails with an IllegalArgumentException
    */
  def readFromDataframe(posDataframe: DataFrame, tokensCol: String = "tokens", labelsCol: String = "labels",
                        outPutDocColName: String = "text", outPutPosColName: String = "tags"): DataFrame = {

    def annotateTokensTags: UserDefinedFunction = udf { (tokens: Seq[String], tags: Seq[String], text: String) =>
      // Lazy: only rendered when the requirement below fails.
      lazy val strTokens = tokens.mkString("#")
      lazy val strPosTags = tags.mkString("#")

      require(tokens.length == tags.length, s"Cannot train from DataFrame since there" +
        s" is a row with different amount of tags and tokens:\n$strTokens\n$strPosTags")

      val tokenTagAnnotation: ArrayBuffer[Annotation] = ArrayBuffer()
      // Position in `text` from which the next token is searched.
      var searchFrom = 0

      tokens.zip(tags).foreach { case (token, tag) =>
        val beginOfToken = text.indexOfSlice(token, searchFrom)
        val endOfToken = (beginOfToken + token.length) - 1

        tokenTagAnnotation += new Annotation(
          annotatorType = AnnotatorType.POS,
          begin = beginOfToken,
          end = endOfToken,
          result = tag,
          metadata = Map("word" -> token)
        )
        // Continue after the END of the token just located; resuming from its start would
        // match the same position again for a repeated token or one that is a prefix of it.
        if (beginOfToken >= 0) searchFrom = beginOfToken + token.length
      }
      tokenTagAnnotation
    }

    val tempDataFrame = posDataframe
      .withColumn(outPutDocColName, concat_ws(" ", col(tokensCol)))
      .withColumn(outPutPosColName, annotateTokensTags(col(tokensCol), col(labelsCol), col(outPutDocColName)))
      .drop(tokensCol, labelsCol)

    tempDataFrame.withColumn(
      outPutPosColName,
      wrapColumnMetadata(tempDataFrame(outPutPosColName), AnnotatorType.POS, outPutPosColName)
    )
  }

}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy