All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.tensorflow.NerDatasetEncoder.scala Maven / Gradle / Ivy

package com.johnsnowlabs.ml.tensorflow

import com.johnsnowlabs.nlp.annotators.common.WordpieceEmbeddingsSentence

/**
  * Encodes NER sentences and tags into the padded array layouts fed to the Tensorflow model,
  * and decodes model output ids back into tag names.
  *
  * @param params tag set, char vocabulary, padding embedding and default tag used for encoding.
  */
class NerDatasetEncoder(val params: DatasetEncoderParams) extends Serializable {

  // Non-default tags get ids 1..n in the order they appear in params.tags; id 0 is reserved for
  // the default tag. NOTE(review): if params.tags contains duplicates, toMap keeps the last id,
  // ids become non-contiguous and `tags(id)` below no longer maps back correctly — presumably
  // callers always pass distinct tags; confirm.
  private val nonDefaultTags = params.tags
    .filter(_ != params.defaultTag)
    .zipWithIndex
    .map(p => (p._1, p._2 + 1))
    .toMap

  /** Tag name -> tag id; the default tag is always id 0. */
  val tag2Id: Map[String, Int] = Map(params.defaultTag -> 0) ++ nonDefaultTags

  /** Tag names sorted by their id (used to decode model output ids). */
  val tags: Array[String] = tag2Id.toArray.sortBy(_._2).map(_._1)

  /** Character vocabulary as an array. */
  val chars: Array[Char] = params.chars.toArray

  /** Char -> char id, ids starting at 1; lookups fall back to 0 for characters not in the vocabulary. */
  val char2Id: Map[Char, Int] = params.chars.zip(1 to params.chars.length).toMap

  /**
    * Safe array lookup.
    *
    * @param source array to read from
    * @param i      index to read
    * @param value  fallback, evaluated only when `i` is outside `[0, source.length)`
    * @return `source(i)` if `i` is a valid index, otherwise `value`
    */
  def getOrElse[T](source: Array[T], i: Int, value: => T): T = {
    if (i >= 0 && i < source.length)
      source(i)
    else
      value
  }

  /**
    * Converts a batch of wordpiece sentences into a [[NerBatch]]. Every sentence is padded to the
    * longest sentence in the batch and every wordpiece to the longest wordpiece.
    *
    * @param sentences sentences to encode
    * @return the encoded batch, or `NerBatch.empty` when there are no sentences or no tokens at all
    */
  def encodeInputData(sentences: Array[WordpieceEmbeddingsSentence]): NerBatch = {
    val sentenceLengths = sentences.map(_.tokens.length)
    val maxSentenceLength = if (sentenceLengths.isEmpty) 0 else sentenceLengths.max

    if (maxSentenceLength == 0) {
      // Either the batch is empty or no sentence has any tokens: nothing to feed the model.
      NerBatch.empty
    } else {
      // Character length of each wordpiece, padded with 0. Shape: batch x maxSentenceLength.
      val wordLengths = sentences.map { sentence =>
        val lengths = sentence.tokens.map(_.wordpiece.length)
        Array.tabulate(maxSentenceLength)(j => getOrElse(lengths, j, 0))
      }

      // Batch and rows are both non-empty here, so neither `max` can fail.
      val maxWordLength = wordLengths.map(_.max).max

      // Padding positions reuse params.emptyEmbeddings.
      val wordEmbeddings = sentences.map { sentence =>
        Array.tabulate(maxSentenceLength) { j =>
          if (j < sentence.tokens.length) sentence.tokens(j).embeddings
          else params.emptyEmbeddings
        }
      }

      val charIds = sentences.map { sentence =>
        Array.tabulate(maxSentenceLength) { j =>
          val word =
            if (j < sentence.tokens.length) sentence.tokens(j).wordpiece.toCharArray
            else Array.empty[Char]

          Array.tabulate(maxWordLength) { k =>
            val char = getOrElse(word, k, Char.MinValue)
            char2Id.getOrElse(char, 0)
          }
        }
      }

      val isWordStart = sentences.map { sentence =>
        Array.tabulate(maxSentenceLength) { j =>
          j < sentence.tokens.length && sentence.tokens(j).isWordStart
        }
      }

      new NerBatch(
        wordEmbeddings,
        charIds,
        wordLengths,
        sentenceLengths,
        maxSentenceLength,
        isWordStart
      )
    }
  }

  /**
    * Converts tag names to tag ids, padding every sentence to the longest one with the default tag.
    * Unknown tag names are encoded as 0 (the default tag's id).
    *
    * @param tags tag names per sentence
    * @return tag ids per sentence; empty when `tags` is empty
    */
  def encodeTags(tags: Array[Array[String]]): Array[Array[Int]] = {
    val maxSentence = if (tags.isEmpty) 0 else tags.map(_.length).max

    tags.map { sentenceTags =>
      Array.tabulate(maxSentence) { j =>
        val tag = getOrElse(sentenceTags, j, params.defaultTag)
        tag2Id.getOrElse(tag, 0)
      }
    }
  }

  /**
    * Converts Tag Identifiers to Source Names
    *
    * @param tagIds Tag Ids encoded for Tensorflow Model.
    * @return Tag names; ids outside the known range decode to the default tag.
    */
  def decodeOutputData(tagIds: Array[Int]): Array[String] = {
    tagIds.map(id => getOrElse(tags, id, params.defaultTag))
  }

  /**
    * Converts Tensorflow tags output to 2-dimensional Array with shape: (Batch, Sentence Length).
    *
    * @param tags 2-dimensional tensor in plain array (row-major, each row padded to the same length)
    * @param sentenceLength Every sentence length (number of words).
    * @return List of tags for each sentence; empty when `sentenceLength` is empty
    */
  def convertBatchTags(tags: Array[String], sentenceLength: Array[Int]): Array[Array[String]] = {
    val sentences = sentenceLength.length

    if (sentences == 0) {
      // Avoid dividing by zero below.
      Array.empty[Array[String]]
    } else {
      // Row width of the flattened (batch x maxSentenceLength) output.
      val maxSentenceLength = tags.length / sentences

      Array.tabulate(sentences) { i =>
        Array.tabulate(sentenceLength(i))(j => tags(i * maxSentenceLength + j))
      }
    }
  }
}

/**
  * Batch that contains data in Tensorflow input format.
  */
class NerBatch (
  // Word vector representation. Shape: Batch x Max Sentence Length x Embeddings Dim
  val wordEmbeddings: Array[Array[Array[Float]]],

  // Char ids for every word in every sentence. Shape: Batch x Max Sentence Length x Max Word length
  val charIds: Array[Array[Array[Int]]],

  // Length (number of characters) of every word in every sentence. Shape: Batch x Max Sentence Length
  val wordLengths: Array[Array[Int]],

  // Length of every batch sentence. Shape: Batch
  val sentenceLengths: Array[Int],

  // Max length of sentence
  val maxLength: Int,

  // Whether the current wordpiece is the start of a token. Shape: Batch x Max Sentence Length
  val isWordStart: Array[Array[Boolean]]
)
{
  /** Number of sentences in the batch. */
  def batchSize: Int = wordEmbeddings.length
}

object NerBatch {
  /** A batch with no sentences. A new instance is built on every call, so no arrays are shared. */
  def empty: NerBatch = new NerBatch(Array.empty, Array.empty, Array.empty, Array.empty, 0, Array.empty)
}


/**
  * Parameters used to build a [[NerDatasetEncoder]].
  *
  * @param tags          tag names known to the model
  * @param chars         character vocabulary
  * @param emptyVector   embedding vector used for padded (non-existent) token positions
  * @param embeddingsDim embeddings dimension; not read by this class — presumably equals
  *                      `emptyVector.length` (not checked here, confirm with callers)
  * @param defaultTag    tag assigned id 0 and used as the fallback tag
  */
case class DatasetEncoderParams
(
  tags: List[String],
  chars: List[Char],
  emptyVector: List[Float],
  embeddingsDim: Int,
  defaultTag: String = "O"
) {
  // Array form of emptyVector, computed once per instance. Not part of case-class equality.
  val emptyEmbeddings: Array[Float] = emptyVector.toArray
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy