All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.ai.ZeroShotNerClassification.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2017-2023 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.ai

import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.TensorflowWrapper
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

/** RoBERTa span-extraction model whose input is a question followed by context.
  *
  * [[predictSpan]] treats the first document as the question and all remaining documents as the
  * context, encodes them into one sequence, and returns a single `CHUNK` annotation holding the
  * highest-scoring span. [[tagSpan]] masks logits that fall inside the question segment so the
  * span is taken from the context. All constructor arguments other than `handleImpossibleAnswer`
  * are forwarded unchanged to [[RoBertaClassification]].
  *
  * @param handleImpossibleAnswer
  *   when `true`, position 0 (the start token) is also a span candidate; selecting it produces an
  *   empty answer. When `false`, position 0 is never selected.
  */
private[johnsnowlabs] class ZeroShotNerClassification(
    override val tensorflowWrapper: Option[TensorflowWrapper],
    override val onnxWrapper: Option[OnnxWrapper],
    override val sentenceStartTokenId: Int,
    override val sentenceEndTokenId: Int,
    override val sentencePadTokenId: Int,
    val handleImpossibleAnswer: Boolean = false,
    configProtoBytes: Option[Array[Byte]] = None,
    tags: Map[String, Int],
    signatures: Option[Map[String, String]] = None,
    merges: Map[(String, String), Int],
    vocabulary: Map[String, Int])
    extends RoBertaClassification(
      tensorflowWrapper,
      onnxWrapper,
      sentenceStartTokenId,
      sentenceEndTokenId,
      sentencePadTokenId,
      configProtoBytes,
      tags,
      signatures,
      merges,
      vocabulary) {

  /** Runs the parent span head, then zeroes the logits that point into the question.
    *
    * For each encoded sequence, the index of the first `sentenceEndTokenId` marks the end of the
    * question. Logits at positions `1 .. thatIndex` are set to `0f`; position 0 and every position
    * after that separator keep their value. If the separator is absent (`indexOf` returns -1),
    * nothing is masked.
    *
    * @param batch
    *   encoded token-id sequences, one per batch entry
    * @return
    *   masked (start, end) logits, one row per batch entry
    */
  override def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
    val (startLogits, endLogits) = super.tagSpan(batch)
    val contextStartOffsets = batch.map(_.indexOf(sentenceEndTokenId))

    // NOTE(review): this compares a token position with the *batch* size minus one rather than
    // the sequence length, and it is reused for the end logits as well. It is kept exactly as-is
    // to preserve current behaviour -- confirm the intended condition before changing it.
    val lastBatchIndex = startLogits.length - 1

    // Zero every position in 1..questionEnd, or the whole row when the condition above holds.
    def maskQuestion(logits: Array[Array[Float]]): Array[Array[Float]] =
      logits.zip(contextStartOffsets).map { case (row, questionEnd) =>
        val maskWholeRow = questionEnd == lastBatchIndex
        row.zipWithIndex.map { case (logit, position) =>
          if ((position > 0 && position <= questionEnd) || maskWholeRow) 0f else logit
        }
      }

    (maskQuestion(startLogits), maskQuestion(endLogits))
  }

  /** Predicts one answer span for the question in `documents.head` over the context in the
    * remaining documents.
    *
    * Scores are the masked logits from [[tagSpan]] divided by their row sum. The best start and
    * end positions are chosen independently. If the chosen start is position 0, or the merged
    * answer text is empty, a zero-offset `CHUNK` with zeroed metadata is returned.
    *
    * @param documents
    *   the question annotation followed by at least one context annotation (the first context
    *   annotation's `sentence` metadata is always read)
    * @param maxSentenceLength
    *   maximum length used when tokenizing and encoding
    * @param caseSensitive
    *   forwarded to `tokenizeDocument`
    * @param mergeTokenStrategy
    *   `MergeTokenStrategy.vocab` or `MergeTokenStrategy.sentencePiece`; any other value makes the
    *   match below fail with a `MatchError`
    * @param engine
    *   not used by this implementation
    * @return
    *   a one-element sequence containing the answer `CHUNK`
    */
  override def predictSpan(
      documents: Seq[Annotation],
      maxSentenceLength: Int,
      caseSensitive: Boolean,
      mergeTokenStrategy: String,
      engine: String): Seq[Annotation] = {
    val questionAnnot = Seq(documents.head)
    val contextAnnot = documents.drop(1)

    val wordPieceTokenizedQuestion =
      tokenizeDocument(questionAnnot, maxSentenceLength, caseSensitive)
    val wordPieceTokenizedContext =
      tokenizeDocument(contextAnnot, maxSentenceLength, caseSensitive)

    val encodedInput =
      encodeSequence(wordPieceTokenizedQuestion, wordPieceTokenizedContext, maxSentenceLength)
    val (startLogits, endLogits) = tagSpan(encodedInput)

    // Divide every value by the row total. The total is computed once per row instead of once
    // per element; the resulting floats are the same.
    def normalizeByRowSum(row: Array[Float]): Array[Float] = {
      val total = row.sum
      row.map(_ / total)
    }

    // Only the first row of the batch is scored.
    val startScores = normalizeByRowSum(startLogits.head)
    val endScores = normalizeByRowSum(endLogits.head)

    // Position 0 yields an empty answer below, so it is a candidate only when impossible answers
    // are allowed.
    val firstCandidate = if (handleImpossibleAnswer) 0 else 1
    val startIndex = startScores.zipWithIndex.drop(firstCandidate).maxBy(_._1)
    val endIndex = endScores.zipWithIndex.drop(firstCandidate).maxBy(_._1)

    val allTokenPieces =
      wordPieceTokenizedQuestion.head.tokens ++ wordPieceTokenizedContext.flatMap(x => x.tokens)
    // Map encoded positions back to indices in `allTokenPieces`. The fixed -3 / -2 offsets
    // presumably account for the start token and the separators inserted by `encodeSequence`
    // (TODO confirm); the end bound of `slice` is exclusive.
    val decodedAnswer = allTokenPieces.slice(startIndex._2 - 3, endIndex._2 - 2)

    // Check if the answer span starts at the CLS symbol 0 - if so return empty string
    val content =
      if (startIndex._2 > 0)
        mergeTokenStrategy match {
          case MergeTokenStrategy.vocab =>
            decodedAnswer.filter(_.isWordStart).map(x => x.token).mkString(" ")
          case MergeTokenStrategy.sentencePiece =>
            decodedAnswer
              .map(piece => if (piece.isWordStart) " " + piece.token else piece.token)
              .mkString("")
              .trim
        }
      else ""

    // 1 when the character at `charOffset` is a space in the context annotation whose
    // [begin, end] range contains it, otherwise 0. The context may span several annotations,
    // so the offset is resolved against the matching one rather than always the first, and
    // the index is bounds-checked instead of throwing.
    def leadingSpaceAdjustment(charOffset: Int): Int =
      contextAnnot.find(sentence => sentence.begin <= charOffset && charOffset <= sentence.end) match {
        case Some(sentence) =>
          val index = charOffset - sentence.begin
          if (index < sentence.result.length && sentence.result.charAt(index) == ' ') 1 else 0
        case None => 0
      }

    if (content.isEmpty) {
      Seq(
        Annotation(
          annotatorType = AnnotatorType.CHUNK,
          begin = 0,
          end = 0,
          result = content,
          metadata = Map(
            "sentence" -> contextAnnot.head.metadata.getOrElse("sentence", "0"),
            "chunk" -> "0",
            "start" -> "0",
            "start_score" -> "0",
            "end" -> "0",
            "end_score" -> "0",
            "score" -> "0",
            "start_char" -> "0",
            "end_char" -> "0")))
    } else {
      // `content` is non-empty, so `decodedAnswer` has at least one piece here.
      val firstPiece = decodedAnswer.head
      val startPos = firstPiece.begin + leadingSpaceAdjustment(firstPiece.begin)
      val endPos = decodedAnswer.last.end
      Seq(
        Annotation(
          annotatorType = AnnotatorType.CHUNK,
          begin = startPos,
          end = endPos,
          result = content,
          metadata = Map(
            "sentence" -> contextAnnot.head.metadata.getOrElse("sentence", "0"),
            "score" -> ((startIndex._1 + endIndex._1) / 2).toString)))
    }

  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy