All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.tensorflow.TensorflowTapas.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.tensorflow

import com.johnsnowlabs.ml.tensorflow.sign.ModelSignatureConstants
import com.johnsnowlabs.nlp.annotators.common.TableData
import com.johnsnowlabs.nlp.annotators.tapas.{TapasEncoder, TapasInputData}
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.WordpieceEncoder
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import org.tensorflow.ndarray.buffer.IntDataBuffer

import scala.collection.JavaConverters._

class TensorflowTapas(
    override val tensorflowWrapper: TensorflowWrapper,
    override val sentenceStartTokenId: Int,
    override val sentenceEndTokenId: Int,
    configProtoBytes: Option[Array[Byte]] = None,
    tags: Map[String, Int],
    signatures: Option[Map[String, String]] = None,
    vocabulary: Map[String, Int])
    extends TensorflowBertClassification(
      tensorflowWrapper = tensorflowWrapper,
      sentenceStartTokenId = sentenceStartTokenId,
      sentenceEndTokenId = sentenceEndTokenId,
      configProtoBytes = configProtoBytes,
      tags = tags,
      signatures = signatures,
      vocabulary = vocabulary) {

  def tagTapasSpan(batch: Seq[TapasInputData]): (Array[Array[Float]], Array[Int]) = {

    val tensors = new TensorResources()

    val maxSentenceLength = batch.head.inputIds.length
    val batchLength = batch.length

    val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
    val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
    val segmentBuffers: IntDataBuffer =
      tensors.createIntBuffer(batchLength * maxSentenceLength * 7)

    batch.zipWithIndex
      .foreach { case (input, idx) =>
        val offset = idx * maxSentenceLength
        tokenBuffers.offset(offset).write(input.inputIds)
        maskBuffers.offset(offset).write(input.attentionMask)
        val segmentOffset = idx * maxSentenceLength * 7

        (0 until maxSentenceLength).foreach(pos => {
          val tokenTypeOffset = segmentOffset + pos * 7
          segmentBuffers
            .offset(tokenTypeOffset)
            .write(Array(
              input.segmentIds(pos),
              input.columnIds(pos),
              input.rowIds(pos),
              input.prevLabels(pos),
              input.columnRanks(pos),
              input.invertedColumnRanks(pos),
              input.numericRelations(pos)))
        })
      }

    val session = tensorflowWrapper.getTFSessionWithSignature(
      configProtoBytes = configProtoBytes,
      savedSignatures = signatures,
      initAllTables = false)
    val runner = session.runner

    val tokenTensors =
      tensors.createIntBufferTensor(Array(batchLength.toLong, maxSentenceLength), tokenBuffers)
    val maskTensors =
      tensors.createIntBufferTensor(Array(batchLength.toLong, maxSentenceLength), maskBuffers)
    val segmentTensors = tensors.createIntBufferTensor(
      Array(batchLength.toLong, maxSentenceLength, 7),
      segmentBuffers)

    runner
      .feed(
        _tfBertSignatures.getOrElse(ModelSignatureConstants.InputIds.key, "missing_input_id_key"),
        tokenTensors)
      .feed(
        _tfBertSignatures.getOrElse(
          ModelSignatureConstants.AttentionMask.key,
          "missing_input_mask_key"),
        maskTensors)
      .feed(
        _tfBertSignatures.getOrElse(
          ModelSignatureConstants.TokenTypeIds.key,
          "missing_segment_ids_key"),
        segmentTensors)
      .fetch(_tfBertSignatures
        .getOrElse(ModelSignatureConstants.TapasLogitsOutput.key, "missing_end_logits_key"))
      .fetch(
        _tfBertSignatures
          .getOrElse(
            ModelSignatureConstants.TapasLogitsAggregationOutput.key,
            "missing_start_logits_key"))

    val outs = runner.run().asScala
    val logitsRaw = TensorResources.extractFloats(outs.head)
    val aggregationRaw = TensorResources.extractFloats(outs.last)

    outs.foreach(_.close())
    tensors.clearSession(outs)
    tensors.clearTensors()

    val probabilitiesDim = logitsRaw.length / batchLength
    val flatMask = batch.flatMap(_.attentionMask)
    val probabilities: Array[Array[Float]] = logitsRaw
      .map(x => if (x < -88.7f) -88.7f else x)
      .zipWithIndex
      .map { case (logit, logitIdx) =>
        (1 / (1 + math.exp(-logit).toFloat)) * flatMask(logitIdx)
      }
      .grouped(probabilitiesDim)
      .toArray

    val aggregationDim = aggregationRaw.length / batchLength
    val aggregations: Array[Int] = aggregationRaw
      .grouped(aggregationDim)
      .map(batchLogitAggregations => {
        batchLogitAggregations.zipWithIndex.maxBy(_._1)._2
      })
      .toArray

    (probabilities, aggregations)
  }

  def predictTapasSpan(
      questions: Seq[Annotation],
      tableAnnotation: Annotation,
      maxSentenceLength: Int,
      caseSensitive: Boolean,
      minCellProbability: Float): Seq[Annotation] = {

    val encoder = new WordpieceEncoder(vocabulary)
    val tapasEncoder = new TapasEncoder(sentenceStartTokenId, sentenceEndTokenId, encoder)

    val table = TableData.fromJson(tableAnnotation.result)

    val tapasData = tapasEncoder.encodeTapasData(
      questions = questions.map(_.result),
      table = table,
      caseSensitive = caseSensitive,
      maxSentenceLength = maxSentenceLength)

    val (probabilities, aggregations) = tagTapasSpan(batch = tapasData)

    val cellPredictions = tapasData.zipWithIndex
      .flatMap { case (input, idx) =>
        val maxWidth = input.columnIds.max
        val maxHeight = input.rowIds.max
        if (maxWidth > 0 || maxHeight > 0) {
          val coordToProbabilities = collection.mutable.Map[(Int, Int), Array[Float]]()
          probabilities(idx).zipWithIndex.foreach { case (prob, probIdx) =>
            if (input.segmentIds(probIdx) == 1 && input.columnIds(probIdx) > 0 && input.rowIds(
                probIdx) > 0) {
              val coord = (input.columnIds(probIdx) - 1, input.rowIds(probIdx) - 1)
              coordToProbabilities(coord) =
                coordToProbabilities.getOrElse(coord, Array()) ++ Array(prob)
            }
          }
          val meanCoordProbs = coordToProbabilities.map(x => (x._1, x._2.sum / x._2.length))
          val answerCoordinates = collection.mutable.ArrayBuffer[(Int, Int)]()
          val answerScores = collection.mutable.ArrayBuffer[Float]()
          input.columnIds.indices.foreach { col =>
            input.rowIds.indices.foreach { row =>
              val score = meanCoordProbs.getOrElse((col, row), -1f)
              if (meanCoordProbs.getOrElse((col, row), -1f) > minCellProbability) {
                answerCoordinates.append((col, row))
                answerScores.append(score)
              }
            }
          }
          val results = answerCoordinates.zip(answerScores).sortBy(_._1).toArray
          Seq(results)
        } else {
          Seq()
        }
      }
    cellPredictions.zipWithIndex
      .map { case (result, queryId) =>
        val cellPositions = result.map(_._1)
        val scores = result.map(_._2)
        val answers = cellPositions.map(x => table.rows(x._2)(x._1)).mkString(", ")
        val aggrString = tapasEncoder.getAggregationString(aggregations(queryId))
        val resultString = if (aggrString != "NONE") s"$aggrString($answers)" else answers
        new Annotation(
          annotatorType = AnnotatorType.CHUNK,
          begin = 0,
          end = resultString.length,
          result = resultString,
          metadata = Map(
            "question" -> questions(queryId).result,
            "aggregation" -> aggrString,
            "cell_positions" ->
              cellPositions.map(x => "[%d, %d]".format(x._1, x._2)).mkString(", "),
            "cell_scores" -> scores.map(_.toString).mkString(", ")))
      }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy