All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.tapas.TapasEncoder.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.tapas

import com.johnsnowlabs.nlp.annotators.common.{IndexedToken, Sentence, TableData}
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}

import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoField
import scala.util.Try
import scala.util.matching.Regex

/** A (possibly partial) calendar date parsed from text: any of day, month and year may be
  * absent.
  */
case class TapasCellDate(day: Option[Int], month: Option[Int], year: Option[Int]) {

  /** Whether at least one of day, month or year is present. */
  def isDefined: Boolean =
    Seq(day, month, year).exists(_.nonEmpty)

  /** Shorthand for `TapasCellDate.compare(this, other)`. */
  def compareTo(other: TapasCellDate): Boolean =
    TapasCellDate.compare(this, other)

  /** Shorthand for `TapasCellDate.areEqual(this, other)`. */
  def isEqualTo(other: TapasCellDate): Boolean =
    TapasCellDate.areEqual(this, other)
}

object TapasCellDate {

  /** Creates a date with no day, month or year component. */
  def Empty(): TapasCellDate = {
    TapasCellDate(None, None, None)
  }

  /** Relation of `v1` to `v2` as a [[TapasNumericRelation]] code.
    *
    * @return
    *   `NONE` when either date has no component; otherwise `EQ` when [[areEqual]] holds,
    *   `LT` when [[compare]] says `v1` orders first, and `GT` in every other case
    */
  def computeNumericRelation(v1: TapasCellDate, v2: TapasCellDate): Int = {
    if (!v1.isDefined || !v2.isDefined) TapasNumericRelation.NONE
    else if (areEqual(v1, v2)) TapasNumericRelation.EQ
    else if (compare(v1, v2)) TapasNumericRelation.LT
    else TapasNumericRelation.GT
  }

  /** True when day, month and year are pairwise equal; a missing component only matches
    * another missing component.
    */
  def areEqual(v1: TapasCellDate, v2: TapasCellDate): Boolean = {
    val dayMatch = v1.day.getOrElse(Int.MinValue) == v2.day.getOrElse(Int.MinValue)
    val monthMatch = v1.month.getOrElse(Int.MinValue) == v2.month.getOrElse(Int.MinValue)
    val yearMatch = v1.year.getOrElse(Int.MinValue) == v2.year.getOrElse(Int.MinValue)
    dayMatch && monthMatch && yearMatch
  }

  /** True when `v1` orders strictly before `v2`.
    *
    * Dates are compared lexicographically by (year, month, day), most significant first,
    * using only the components that both dates define. The first differing shared component
    * decides. When every shared component is equal, or none is shared, the result is `false`.
    * For example, 20/4/2001 does not order before 10/5/2000, because the years decide.
    */
  def compare(v1: TapasCellDate, v2: TapasCellDate): Boolean = {
    val sharedComponents = Seq(v1.year -> v2.year, v1.month -> v2.month, v1.day -> v2.day)
      .collect { case (Some(a), Some(b)) => (a, b) }
    sharedComponents
      .find { case (a, b) => a != b }
      .exists { case (a, b) => a < b }
  }
}

/** A value parsed from text: a number, a date, or (in principle) both. */
case class TapasCellValue(number: Option[Float], date: TapasCellDate) {

  /** Shorthand for `TapasCellValue.compare(this, other)`. */
  def compareTo(other: TapasCellValue): Boolean =
    TapasCellValue.compare(this, other)

  /** Relation code of `this` relative to `other`, as computed by
    * `TapasCellValue.computeNumericRelation(this, other)`.
    */
  def getNumericRelationTo(other: TapasCellValue): Int = {
    val relation = TapasCellValue.computeNumericRelation(this, other)
    relation
  }
}

object TapasCellValue {

  /** True when `v1`'s number orders strictly before `v2`'s.
    *
    * A missing number is treated as `Float.MinValue`. `java.lang.Float.compare` is used
    * instead of `<` so the ordering stays total even for NaN (which sorts last). This keeps
    * the comparator consistent for `sortWith` callers.
    */
  def compareNumbers(v1: TapasCellValue, v2: TapasCellValue): Boolean = {
    val n1 = v1.number.getOrElse(Float.MinValue)
    val n2 = v2.number.getOrElse(Float.MinValue)
    java.lang.Float.compare(n1, n2) < 0
  }

  /** True when `v1`'s date orders before `v2`'s; delegates to `TapasCellDate.compare`. */
  def compareDates(v1: TapasCellValue, v2: TapasCellValue): Boolean = v1.date.compareTo(v2.date)

  /** Orders two values.
    *
    * Numbers are compared when both values carry one. Otherwise dates are compared when both
    * values carry one. In all other cases the result is `false`.
    */
  def compare(v1: TapasCellValue, v2: TapasCellValue): Boolean = {
    if (v1.number.isDefined && v2.number.isDefined) {
      compareNumbers(v1, v2)
    } else if (v1.date.isDefined && v2.date.isDefined) {
      // Both sides must carry a date before dates are compared.
      compareDates(v1, v2)
    } else {
      false
    }
  }

  /** Relation of `v1` to `v2` as a [[TapasNumericRelation]] code (`EQ`, `LT`, `GT` or
    * `NONE`).
    *
    * Numbers are compared when both values carry one. Otherwise dates are compared via
    * `TapasCellDate.computeNumericRelation` when both carry one. Anything else is `NONE`. A
    * NaN on either side has no meaningful order and also yields `NONE`; previously it
    * raised a `MatchError`.
    */
  def computeNumericRelation(v1: TapasCellValue, v2: TapasCellValue): Int =
    (v1.number, v2.number) match {
      case (Some(n1), Some(n2)) =>
        if (n1.isNaN || n2.isNaN) TapasNumericRelation.NONE
        else if (n1 == n2) TapasNumericRelation.EQ
        else if (n1 > n2) TapasNumericRelation.GT
        else TapasNumericRelation.LT
      case _ if v1.date.isDefined && v2.date.isDefined =>
        TapasCellDate.computeNumericRelation(v1.date, v2.date)
      case _ =>
        TapasNumericRelation.NONE
    }
}

/** A value found in text together with the character span it was parsed from. */
case class TapasNumericValueSpan(begin: Int, end: Int, value: TapasCellValue) {

  /** String key identifying the parsed value.
    *
    * For numbers this is the `toString` of the `Option[Float]`. For dates it is
    * `day@month@year`, with `NA` for each missing component. When neither a number nor any
    * date component is present it is `TapasNumericValueSpan.emptyValueId`.
    */
  def valueId: String = {
    val dateParts = Seq(value.date.day, value.date.month, value.date.year)
    if (value.number.isDefined) {
      value.number.toString
    } else if (dateParts.exists(_.isDefined)) {
      dateParts.map(_.fold("NA")(_.toString)).mkString("@")
    } else {
      TapasNumericValueSpan.emptyValueId
    }
  }

  /** Orders this span's value against `other`'s via `TapasCellValue.compare`. */
  def compareTo(other: TapasNumericValueSpan): Boolean = {
    val (mine, theirs) = (this.value, other.value)
    TapasCellValue.compare(mine, theirs)
  }
}

object TapasNumericValueSpan {

  /** `valueId` used for a span whose value has neither a number nor any date component. */
  val emptyValueId: String = "NA"
}

/** Integer codes for relation types used by the TAPAS encoder.
  *
  * The comparison codes `EQ`, `LT` and `GT` describe the question-side value relative to a
  * cell value.
  */
object TapasNumericRelation {

  /** No relation. */
  val NONE: Int = 0

  // Structural relations.

  /** Header to cell. */
  val HEADER_TO_CELL: Int = 1

  /** Cell to header. */
  val CELL_TO_HEADER: Int = 2

  /** Query to headers. */
  val QUERY_TO_HEADER: Int = 3

  /** Query to cells. */
  val QUERY_TO_CELL: Int = 4

  /** Row to cells. */
  val ROW_TO_CELL: Int = 5

  /** Cells to row. */
  val CELL_TO_ROW: Int = 6

  // Value comparisons.

  /** Question value equals the cell value. */
  val EQ: Int = 7

  /** Question value is less than the cell value. */
  val LT: Int = 8

  /** Question value is greater than the cell value. */
  val GT: Int = 9
}

/** Per-token feature vectors for one question/table pair.
  *
  * When built by `TapasEncoder.encodeTapasData`, every array holds one entry per token
  * position and all arrays have the same length.
  *
  * @param inputIds token ids
  * @param attentionMask 1 for real tokens, 0 for padding
  * @param segmentIds 0 for the question part, 1 for the table part
  * @param columnIds 1-based column of a table token, 0 elsewhere
  * @param rowIds 1-based row of a table cell token, 0 for headers and the question
  * @param prevLabels previous-answer labels per token
  * @param columnRanks rank of the token's cell value within its column, 0 when unranked
  * @param invertedColumnRanks rank counted from the largest value, 0 when unranked
  * @param numericRelations bit mask of question/cell value relations per token
  */
case class TapasInputData(
    inputIds: Array[Int],
    attentionMask: Array[Int],
    segmentIds: Array[Int],
    columnIds: Array[Int],
    rowIds: Array[Int],
    prevLabels: Array[Int],
    columnRanks: Array[Int],
    invertedColumnRanks: Array[Int],
    numericRelations: Array[Int]
)

/** Turns questions plus a table into the per-token feature arrays consumed by a TAPAS model.
  *
  * For every question the output token sequence is:
  *   - the start id, the question pieces and the end id;
  *   - then the flattened table: all header cells, then every row left to right;
  *   - then zero padding up to the longest question's length.
  * Everything is truncated to `maxSentenceLength`.
  *
  * @param sentenceStartTokenId
  *   id emitted before the question tokens
  * @param sentenceEndTokenId
  *   id emitted between the question tokens and the table tokens
  * @param encoder
  *   wordpiece encoder applied to every basic token
  */
class TapasEncoder(
    val sentenceStartTokenId: Int,
    val sentenceEndTokenId: Int,
    encoder: WordpieceEncoder) {

  /** Optionally signed integers or decimals, with optional `,ddd` thousands groups. */
  protected val NUMBER_PATTERN: Regex =
    "((^|\\s)[+-])?((\\.\\d+)|(\\d+(,\\d\\d\\d)*(\\.\\d*)?))".r

  /** Date formats tried in order by [[parseDate]].
    *
    * A formatter is only attempted when its regex matches the whole text. Month and weekday
    * names are matched as English words, so the formatters are built with `Locale.ENGLISH`
    * rather than the JVM default locale. Otherwise "January" would not parse on, e.g., a
    * French-locale JVM.
    */
  protected val DT_FORMATTERS: Array[(DateTimeFormatter, Regex)] = Array(
    ("MMMM", "\\w+".r),
    ("yyyy", "\\d{4}".r),
    ("yyyy's'", "\\d{4}s".r),
    ("MMM yyyy", "\\w{3}\\s+\\d{4}".r),
    ("MMMM yyyy", "\\w+\\s+\\d{4}".r),
    ("MMMM d", "\\w+\\s+\\d{1,2}".r),
    ("MMM d", "\\w{3}\\s+\\d{1,2}".r),
    ("d MMMM", "\\d{1,2}\\s+\\w+".r),
    ("d MMM", "\\d{1,2}\\s+\\w{3}".r),
    ("MMMM d, yyyy", "\\w+\\s+\\d{1,2},\\s*\\d{4}".r),
    ("d MMMM yyyy", "\\d{1,2}\\s+\\w+\\s+\\d{4}".r),
    ("M-d-yyyy", "\\d{1,2}-\\d{1,2}-\\d{4}".r),
    // NOTE(review): this pattern has no separator between year and month, so it does not
    // look able to parse the "yyyy-M-d" text its regex admits ("yyyy-M-d" below covers that
    // shape). Kept unchanged; confirm the intended pattern.
    ("yyyyM-d", "\\d{4}-\\d{1,2}-\\d{1,2}".r),
    ("yyyy MMMM", "\\d{4}\\s+\\w+".r),
    ("d MMM yyyy", "\\d{1,2}\\s+\\w{3}\\s+\\d{4}".r),
    ("yyyy-M-d", "\\d{4}-\\d{1,2}-\\d{1,2}".r),
    ("MMM d, yyyy", "\\w{3}\\s+\\d{1,2},\\s*\\d{4}".r),
    ("d.M.yyyy", "\\d{1,2}\\.\\d{1,2}\\.\\d{4}".r),
    ("E, MMM d", "\\w{3},\\s+\\w{3}\\s+\\d{1,2}".r),
    ("EEEE, MMM d", "\\w+,\\s+\\w{3}\\s+\\d{1,2}".r),
    ("E, MMMM d", "\\w{3},\\s+\\w+\\s+\\d{1,2}".r),
    ("EEEE, MMMM d", "\\w+,\\s+\\w+\\s+\\d{1,2}".r)).map { case (pattern, regex) =>
    (DateTimeFormatter.ofPattern(pattern, java.util.Locale.ENGLISH), regex)
  }

  /** Parsed years outside [MIN_YEAR, MAX_YEAR] are dropped from dates. */
  protected val MIN_YEAR = 1700
  protected val MAX_YEAR = 2120

  /** A column is ranked only if at least this fraction of its rows hold values of its
    * dominant type (numbers or dates).
    */
  protected val MIN_NUMBER_OF_ROWS_WITH_VALUES_PROPORTION = 0.5f

  protected val ORDINAL_SUFFIXES: Array[String] = Array("st", "nd", "rd", "th")

  /** Number words; the array index is the value. */
  protected val NUMBER_WORDS: Array[String] = Array(
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
    "ten",
    "eleven",
    "twelve")

  /** Ordinal words; the array index is the value.
    *
    * NOTE(review): "fith" looks like a misspelling of "fifth", so "fifth" is not recognised.
    * It is left unchanged because the features a pretrained model expects may depend on
    * it. Confirm before changing.
    */
  protected val ORDINAL_WORDS: Array[String] = Array(
    "zeroth",
    "first",
    "second",
    "third",
    "fourth",
    "fith",
    "sixth",
    "seventh",
    "eighth",
    "ninth",
    "tenth",
    "eleventh",
    "twelfth")

  protected val AGGREGATIONS = Map(0 -> "NONE", 1 -> "SUM", 2 -> "AVERAGE", 3 -> "COUNT")

  /** Name of an aggregation id (0-3); throws `NoSuchElementException` for any other id. */
  def getAggregationString(aggregationId: Int): String = {
    AGGREGATIONS(aggregationId)
  }

  /** All spans of up to `maxNgramLength` consecutive letter/digit runs in `text`, as
    * `(begin, endExclusive)` pairs, emitted in order of their end position.
    */
  protected def getAllSpans(text: String, maxNgramLength: Int): Array[(Int, Int)] = {
    val startIndices = collection.mutable.ArrayBuffer[Int]()
    val spans = collection.mutable.ArrayBuffer[(Int, Int)]()
    text.indices.foreach { i =>
      if (text(i).isLetterOrDigit) {
        if (i == 0 || !text(i - 1).isLetterOrDigit) startIndices += i
        if (i + 1 == text.length || !text(i + 1).isLetterOrDigit) {
          // A word ends here: emit spans starting at the last `maxNgramLength` word starts.
          startIndices.takeRight(maxNgramLength).foreach(start => spans += ((start, i + 1)))
        }
      }
    }
    spans.toArray
  }

  /** Parses `text` as a number after stripping ordinal suffixes and thousands separators.
    *
    * Non-finite results ("NaN", "Infinity") are rejected. They carry no usable order, and
    * Java's float parser would otherwise accept those words.
    */
  protected def parseNumber(text: String): Option[Float] = {
    val withoutSuffix = ORDINAL_SUFFIXES.foldLeft(text) { (current, suffix) =>
      if (current.endsWith(suffix)) current.dropRight(suffix.length) else current
    }
    Try(withoutSuffix.replace(",", "").toFloat).toOption
      .filter(n => java.lang.Float.isFinite(n))
  }

  /** Parses `text` as a (possibly partial) date using the first formatter in
    * [[DT_FORMATTERS]] whose regex matches and whose parse yields at least one component.
    */
  protected def parseDate(text: String): Option[TapasCellDate] =
    DT_FORMATTERS.iterator
      .filter { case (_, regex) => regex.pattern.matcher(text).matches() }
      .flatMap { case (formatter, _) => Try(formatter.parse(text)).toOption }
      .map { accessor =>
        def field(f: ChronoField): Option[Int] = Try(accessor.get(f)).toOption

        val day = field(ChronoField.DAY_OF_MONTH)
        val month = field(ChronoField.MONTH_OF_YEAR)
        val year = field(ChronoField.YEAR)
          .orElse(field(ChronoField.YEAR_OF_ERA))
          .filter(y => y >= MIN_YEAR && y <= MAX_YEAR)

        if (day.isDefined || month.isDefined || year.isDefined)
          Some(TapasCellDate(day, month, year))
        else
          None
      }
      .collectFirst { case Some(date) => date }

  /** Finds numbers and dates in `text`.
    *
    * Candidates come from three sources:
    *   - [[NUMBER_PATTERN]] matches;
    *   - single words that parse as numbers or are number/ordinal words;
    *   - spans of up to five words that parse as dates.
    * Candidate spans contained in an already selected, longer span are discarded. The
    * result is ordered by span start.
    */
  def parseText(text: String): Array[TapasNumericValueSpan] = {

    val spans = collection.mutable.Map[(Int, Int), Array[TapasCellValue]]()

    def addValue(span: (Int, Int), value: TapasCellValue): Unit =
      spans(span) = spans.getOrElse(span, Array.empty[TapasCellValue]) :+ value

    def addNumberSpan(span: (Int, Int), number: Float): Unit =
      addValue(span, TapasCellValue(Some(number), TapasCellDate.Empty()))

    def addDateSpan(span: (Int, Int), date: TapasCellDate): Unit =
      addValue(span, TapasCellValue(None, date))

    // Numbers found by the pattern.
    NUMBER_PATTERN
      .findAllMatchIn(text)
      .foreach { m =>
        parseNumber(text.slice(m.start, m.end)).foreach(n => addNumberSpan((m.start, m.end), n))
      }

    // Single words: numerals, number words and ordinal words.
    getAllSpans(text, 1)
      .filter(span => !spans.contains(span))
      .foreach { span =>
        val spanText = text.slice(span._1, span._2)
        parseNumber(spanText).foreach(n => addNumberSpan(span, n))
        val numberWordIndex = NUMBER_WORDS.indexOf(spanText)
        if (numberWordIndex >= 0) addNumberSpan(span, numberWordIndex.toFloat)
        val ordinalWordIndex = ORDINAL_WORDS.indexOf(spanText)
        if (ordinalWordIndex >= 0) addNumberSpan(span, ordinalWordIndex.toFloat)
      }

    // Dates spanning up to five words.
    getAllSpans(text, 5).foreach { span =>
      parseDate(text.slice(span._1, span._2)).foreach(date => addDateSpan(span, date))
    }

    // Longest spans first (ties: earliest start first); drop spans inside a selected one.
    val selectedSpans = collection.mutable.ArrayBuffer[((Int, Int), Array[TapasCellValue])]()
    spans.toArray
      .sortBy { case ((begin, end), _) => (end - begin, -begin) }
      .reverse
      .foreach { case entry @ ((begin, end), _) =>
        val covered = selectedSpans.exists { case ((selBegin, selEnd), _) =>
          selBegin <= begin && end <= selEnd
        }
        if (!covered) selectedSpans += entry
      }

    selectedSpans
      .sortBy(_._1._1)
      .flatMap { case ((begin, end), values) =>
        values.map(value => TapasNumericValueSpan(begin, end, value))
      }
      .toArray
  }

  /** Encodes each question together with `table` into TAPAS model inputs.
    *
    * @param questions
    *   questions to encode; one [[TapasInputData]] is returned per question (an empty input
    *   yields an empty result)
    * @param table
    *   table whose header and rows are flattened after each question
    * @param caseSensitive
    *   when false, basic tokens are lowercased before wordpiece encoding
    * @param maxSentenceLength
    *   every output vector is truncated to this many entries
    */
  def encodeTapasData(
      questions: Seq[String],
      table: TableData,
      caseSensitive: Boolean,
      maxSentenceLength: Int): Seq[TapasInputData] = {

    val basicTokenizer = new BasicTokenizer(caseSensitive = true, hasBeginEnd = false)

    // Basic-tokenizes `text`, lowercases the tokens unless case sensitive, then splits them
    // into wordpieces.
    def encodeText(text: String, sentenceIndex: Int) = {
      val sentence =
        new Sentence(start = 0, end = text.length, content = text, index = sentenceIndex)
      val tokens = basicTokenizer.tokenize(sentence)
      val normalized =
        if (caseSensitive) tokens
        else tokens.map(x => IndexedToken(x.token.toLowerCase(), x.begin, x.end))
      normalized.flatMap(token => encoder.encode(token))
    }

    val questionInputIds = questions.map(question => encodeText(question, 0).map(_.pieceId))
    // `max` throws on an empty collection, so no questions means no padding.
    val maxQuestionLength =
      if (questionInputIds.isEmpty) 0 else questionInputIds.map(_.length).max

    val inputIds = collection.mutable.ArrayBuffer[Int]()
    val attentionMask = collection.mutable.ArrayBuffer[Int]()
    val segmentIds = collection.mutable.ArrayBuffer[Int]()
    val columnIds = collection.mutable.ArrayBuffer[Int]()
    val rowIds = collection.mutable.ArrayBuffer[Int]()
    val prevLabels = collection.mutable.ArrayBuffer[Int]()
    val columnRanks = collection.mutable.ArrayBuffer[Int]()
    val invertedColumnRanks = collection.mutable.ArrayBuffer[Int]()

    // Appends one table token. Table tokens are in segment 1, attended, and have no
    // previous label; their ranks start at 0 and are filled in below.
    def appendTableToken(pieceId: Int, columnId: Int, rowId: Int): Unit = {
      inputIds.append(pieceId)
      attentionMask.append(1)
      segmentIds.append(1)
      columnIds.append(columnId)
      rowIds.append(rowId)
      prevLabels.append(0)
      columnRanks.append(0)
      invertedColumnRanks.append(0)
    }

    // Header tokens: row id 0, 1-based column ids.
    table.header.indices.foreach { colIndex =>
      encodeText(table.header(colIndex), colIndex)
        .foreach(piece => appendTableToken(piece.pieceId, colIndex + 1, 0))
    }

    // Per column: (parsed value, row index, input positions of the cell's tokens).
    val tableCellValues =
      collection.mutable.Map[Int, Array[(TapasNumericValueSpan, Int, Array[Int])]]()

    table.rows.indices.foreach { rowIndex =>
      table.header.indices.foreach { colIndex =>
        val cellText = table.rows(rowIndex)(colIndex)
        val firstTokenIndex = inputIds.length
        encodeText(cellText, 0)
          .foreach(piece => appendTableToken(piece.pieceId, colIndex + 1, rowIndex + 1))

        // Every value parsed from the cell is attributed to all of the cell's tokens.
        val cellTokenIndices = (firstTokenIndex until inputIds.length).toArray
        val cellValues =
          parseText(cellText).map(numValue => (numValue, rowIndex, cellTokenIndices))
        tableCellValues(colIndex) = tableCellValues.getOrElse(
          colIndex,
          Array.empty[(TapasNumericValueSpan, Int, Array[Int])]) ++ cellValues
      }
    }

    // Column ranks: only columns dominated by numbers or by dates are ranked, and only the
    // values of the dominant type take part.
    val minRowsWithValues = MIN_NUMBER_OF_ROWS_WITH_VALUES_PROPORTION * table.rows.length
    tableCellValues.foreach { case (_, values) =>
      val (numberValues, dateValues) = values.partition(_._1.value.number.isDefined)
      val rowsWithNumberValues = numberValues.map(_._2).distinct.length
      val rowsWithDateValues = dateValues.map(_._2).distinct.length

      val sortedValues =
        if (rowsWithNumberValues >= math.max(rowsWithDateValues, minRowsWithValues)) {
          numberValues.sortWith((v1, v2) =>
            TapasCellValue.compareNumbers(v1._1.value, v2._1.value))
        } else if (rowsWithDateValues >= math.max(rowsWithNumberValues, minRowsWithValues)) {
          dateValues.sortWith((v1, v2) => TapasCellValue.compareDates(v1._1.value, v2._1.value))
        } else Array.empty[(TapasNumericValueSpan, Int, Array[Int])]

      // Ranks start at 1 and advance only when the value id changes, so equal values share
      // a rank. The inverted rank counts down from the largest rank. A token touched by
      // several values keeps the rank of the last one in sorted order.
      var rank = 0
      var previousValueId: Option[String] = None
      val rankedTokenIndices = sortedValues.map { entry =>
        val valueId = entry._1.valueId
        if (!previousValueId.contains(valueId)) {
          rank += 1
          previousValueId = Some(valueId)
        }
        (entry._3, rank)
      }
      val maxRank = rank
      rankedTokenIndices.foreach { case (tokenIndices, valueRank) =>
        tokenIndices.foreach { tokenIndex =>
          columnRanks(tokenIndex) = valueRank
          invertedColumnRanks(tokenIndex) = maxRank - valueRank + 1
        }
      }
    }

    questions.zip(questionInputIds).map { case (question, qIds) =>
      // Distinct relations between any value in the question and any value of each
      // (column, row) cell.
      val cellRelations = collection.mutable.Map[(Int, Int), Set[Int]]()
      parseText(question).foreach { questionValue =>
        tableCellValues.foreach { case (columnIdx, cellValues) =>
          cellValues.foreach { case (cellValue, rowIdx, _) =>
            val rel = questionValue.value.getNumericRelationTo(cellValue.value)
            if (rel != TapasNumericRelation.NONE) {
              val key = (columnIdx, rowIdx)
              cellRelations(key) = cellRelations.getOrElse(key, Set.empty[Int]) + rel
            }
          }
        }
      }

      // Each distinct relation sets one bit on all of the cell's tokens: EQ -> 1, LT -> 2,
      // GT -> 4. Relations are a Set so repeated hits cannot carry into another bit.
      val numericRelations = collection.mutable.ArrayBuffer.fill(inputIds.length)(0)
      cellRelations.foreach { case ((columnIdx, rowIdx), relations) =>
        val relationMask =
          relations.iterator.map(rel => math.pow(2, rel - TapasNumericRelation.EQ)).sum.toInt
        tableCellValues(columnIdx)
          .filter(_._2 == rowIdx)
          .flatMap(_._3)
          .foreach(tokenIndex => numericRelations(tokenIndex) = relationMask)
      }

      // Question part: start id + question pieces + end id.
      val questionZeros = Array.fill(qIds.length + 2)(0)
      val padding = Array.fill(maxQuestionLength - qIds.length)(0)

      // Truncates a feature vector to at most `maxSentenceLength` entries.
      def setMaxSentenceLimit(vector: Array[Int]): Array[Int] = vector.take(maxSentenceLength)

      TapasInputData(
        inputIds = setMaxSentenceLimit(
          Array(sentenceStartTokenId) ++ qIds ++ Array(
            sentenceEndTokenId) ++ inputIds ++ padding),
        attentionMask =
          setMaxSentenceLimit(Array.fill(qIds.length + 2)(1) ++ attentionMask ++ padding),
        segmentIds = setMaxSentenceLimit(questionZeros ++ segmentIds ++ padding),
        columnIds = setMaxSentenceLimit(questionZeros ++ columnIds ++ padding),
        rowIds = setMaxSentenceLimit(questionZeros ++ rowIds ++ padding),
        prevLabels = setMaxSentenceLimit(questionZeros ++ prevLabels ++ padding),
        columnRanks = setMaxSentenceLimit(questionZeros ++ columnRanks ++ padding),
        invertedColumnRanks =
          setMaxSentenceLimit(questionZeros ++ invertedColumnRanks ++ padding),
        numericRelations = setMaxSentenceLimit(questionZeros ++ numericRelations ++ padding))
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy