All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.nlp.annotators.common.SentenceWithEmbeddings.scala Maven / Gradle / Ivy

package com.johnsnowlabs.nlp.annotators.common

import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import scala.collection.Map


/**
 * Word-piece embeddings belonging to one sentence.
 *
 * @param tokens             the sentence's word pieces, each carrying its own embedding vector
 * @param sentenceId         identifier of the sentence; `WordpieceEmbeddingsSentence.unpack`
 *                           fills it from the "sentence" metadata of the annotations
 * @param sentenceEmbeddings optional embedding vector for the sentence as a whole
 *                           (`None` when absent)
 */
case class WordpieceEmbeddingsSentence(
  tokens: Array[TokenPieceEmbeddings],
  sentenceId: Int,
  sentenceEmbeddings: Option[Array[Float]] = None
)

/**
 * A single word piece paired with its embedding vector.
 *
 * @param wordpiece   text of the word piece
 * @param token       text of the token the piece belongs to
 * @param pieceId     numeric id of the piece (meaning defined by the producing tokenizer)
 * @param isWordStart whether this piece is flagged as starting a word
 * @param embeddings  embedding vector of the piece
 * @param begin       begin offset, mirrored in the annotation's `begin`
 * @param end         end offset, mirrored in the annotation's `end`
 */
case class TokenPieceEmbeddings(
  wordpiece: String,
  token: String,
  pieceId: Int,
  isWordStart: Boolean,
  embeddings: Array[Float],
  begin: Int,
  end: Int
)

/** Extra factory for [[TokenPieceEmbeddings]] built from an existing [[TokenPiece]]. */
object TokenPieceEmbeddings {

  /**
   * Attaches an embedding vector to a word piece.
   *
   * The word piece text, token, piece id, word-start flag and offsets are all copied from
   * `piece`; only `embeddings` comes from the caller.
   *
   * @param piece      source word piece
   * @param embeddings embedding vector to attach
   * @return a new [[TokenPieceEmbeddings]] combining both
   */
  def apply(piece: TokenPiece, embeddings: Array[Float]): TokenPieceEmbeddings =
    new TokenPieceEmbeddings(
      piece.wordpiece,
      piece.token,
      piece.pieceId,
      piece.isWordStart,
      embeddings,
      piece.begin,
      piece.end
    )
}

/**
 * Converts between flat word-piece embedding [[Annotation]]s and per-sentence
 * [[WordpieceEmbeddingsSentence]] values.
 */
object WordpieceEmbeddingsSentence extends Annotated[WordpieceEmbeddingsSentence] {
  override def annotatorType: String = AnnotatorType.WORD_EMBEDDINGS

  /**
   * Groups annotations of this annotator type into sentences by their "sentence" metadata.
   *
   * Annotations of any other type are ignored. Sentence embeddings are read from the first
   * annotation of each sentence, which is the only one `pack` writes them to. An empty (or
   * null) array there is read back as `None`, so a sentence packed without sentence
   * embeddings round-trips to `None` instead of `Some(empty)`.
   *
   * @param annotations annotations to unpack; every one of this annotator type must carry
   *                    "sentence", "token", "pieceId" and "isWordStart" metadata
   * @return one [[WordpieceEmbeddingsSentence]] per distinct sentence id, sorted by that id
   * @throws NoSuchElementException if a required metadata key is missing
   * @throws NumberFormatException  if "sentence" or "pieceId" is not an integer
   */
  override def unpack(annotations: Seq[Annotation]): Seq[WordpieceEmbeddingsSentence] = {
    annotations
      .filter(_.annotatorType == annotatorType)
      .groupBy(_.metadata("sentence").toInt)
      .map { case (sentenceId, sentenceTokens) =>
        // Only the first piece of a sentence carries sentence embeddings (see pack);
        // an empty array there means none were stored.
        val sentenceEmbeddings = sentenceTokens.headOption
          .flatMap(first => Option(first.sentence_embeddings))
          .filter(_.nonEmpty)

        WordpieceEmbeddingsSentence(
          sentenceTokens.map(toTokenPiece).toArray,
          sentenceId,
          sentenceEmbeddings)
      }
      .toSeq
      .sortBy(_.sentenceId)
  }

  /**
   * Flattens sentences into one annotation per word piece.
   *
   * The "sentence" metadata of each annotation is the sentence's own `sentenceId`, not its
   * position in `sentences`. Ids therefore survive a `pack`/`unpack` round trip and stay
   * aligned with the originating sentences even when some ids are missing. Sentence
   * embeddings, when present, are stored on the first word piece of each sentence only;
   * every other piece gets an empty array.
   *
   * @param sentences sentences to flatten
   * @return annotations of this annotator type, in sentence order then token order
   */
  override def pack(sentences: Seq[WordpieceEmbeddingsSentence]): Seq[Annotation] = {
    sentences.flatMap { sentence =>
      val firstPieceSentenceEmbeddings =
        sentence.sentenceEmbeddings.getOrElse(Array.emptyFloatArray)

      sentence.tokens.zipWithIndex.map { case (token, tokenIndex) =>
        // Store sentence embeddings on a single piece to avoid duplicating them per piece.
        val sentenceEmbeddings =
          if (tokenIndex == 0) firstPieceSentenceEmbeddings else Array.emptyFloatArray

        Annotation(annotatorType, token.begin, token.end, token.wordpiece,
          Map("sentence" -> sentence.sentenceId.toString,
            "token" -> token.token,
            "pieceId" -> token.pieceId.toString,
            "isWordStart" -> token.isWordStart.toString
          ),
          token.embeddings,
          sentenceEmbeddings
        )
      }
    }
  }

  /** Rebuilds one word piece from an annotation in the layout written by `pack`. */
  private def toTokenPiece(annotation: Annotation): TokenPieceEmbeddings =
    new TokenPieceEmbeddings(
      wordpiece = annotation.result,
      token = annotation.metadata("token"),
      pieceId = annotation.metadata("pieceId").toInt,
      isWordStart = annotation.metadata("isWordStart").toBoolean,
      embeddings = annotation.embeddings,
      begin = annotation.begin,
      end = annotation.end
    )
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy