All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.common.SentenceWithEmbeddings.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.common

import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

import scala.collection.Map

/** A sentence represented as an array of token pieces that each carry an embedding vector.
  *
  * @param tokens
  *   the token pieces belonging to this sentence
  * @param sentenceId
  *   the sentence index; the companion object's `unpack` parses it from the "sentence" metadata
  *   entry and `pack` writes it back to that entry
  */
case class WordpieceEmbeddingsSentence(tokens: Array[TokenPieceEmbeddings], sentenceId: Int)

/** A single token piece together with its embedding vector.
  *
  * @param wordpiece
  *   text of the piece
  * @param token
  *   text of the token (presumably the full token this piece was split from -- confirm against
  *   `TokenPiece`)
  * @param pieceId
  *   numeric id of the piece (presumably a vocabulary id -- confirm against `TokenPiece`)
  * @param isWordStart
  *   flag copied from the piece; by its name it marks the first piece of a word -- confirm
  * @param isOOV
  *   out-of-vocabulary flag; the `Option`-based factory in the companion object sets it to `true`
  *   when no embedding vector was supplied
  * @param embeddings
  *   the embedding vector attached to this piece
  * @param begin
  *   begin offset; used as the `begin` of the annotation produced by
  *   `WordpieceEmbeddingsSentence.pack`
  * @param end
  *   end offset; used as the `end` of the annotation produced by
  *   `WordpieceEmbeddingsSentence.pack`
  */
case class TokenPieceEmbeddings(
    wordpiece: String,
    token: String,
    pieceId: Int,
    isWordStart: Boolean,
    isOOV: Boolean,
    embeddings: Array[Float],
    begin: Int,
    end: Int)

/** Additional factory methods for [[TokenPieceEmbeddings]]. */
object TokenPieceEmbeddings {

  /** Attaches an embedding vector to an existing [[TokenPiece]].
    *
    * Every field except `embeddings` and `isOOV` is copied from `piece`.
    *
    * @param piece
    *   the token piece supplying text, id, word-start flag and offsets
    * @param embeddings
    *   the embedding vector to attach
    * @return
    *   a [[TokenPieceEmbeddings]] whose `isOOV` is always `false`
    */
  def apply(piece: TokenPiece, embeddings: Array[Float]): TokenPieceEmbeddings =
    new TokenPieceEmbeddings(
      wordpiece = piece.wordpiece,
      token = piece.token,
      pieceId = piece.pieceId,
      isWordStart = piece.isWordStart,
      // FIXME: I think BERT won't have OOV, this "constructor" is called from TensorFlowBert
      isOOV = false,
      embeddings = embeddings,
      begin = piece.begin,
      end = piece.end)

  /** Builds a [[TokenPieceEmbeddings]] from an optional embedding vector.
    *
    * @param wordpiece
    *   text of the piece
    * @param token
    *   text of the token
    * @param pieceId
    *   numeric id of the piece
    * @param isWordStart
    *   word-start flag stored on the result as-is
    * @param embeddings
    *   the embedding vector, or `None` when no vector is available
    * @param zeroArray
    *   fallback vector stored on the result when `embeddings` is `None`; the same array instance
    *   is stored, it is not copied
    * @param begin
    *   begin offset
    * @param end
    *   end offset
    * @return
    *   a [[TokenPieceEmbeddings]] whose `isOOV` is `true` exactly when `embeddings` is `None`
    */
  def apply(
      wordpiece: String,
      token: String,
      pieceId: Int,
      isWordStart: Boolean,
      embeddings: Option[Array[Float]],
      zeroArray: Array[Float],
      begin: Int,
      end: Int): TokenPieceEmbeddings =
    new TokenPieceEmbeddings(
      wordpiece = wordpiece,
      token = token,
      pieceId = pieceId,
      isWordStart = isWordStart,
      // A missing vector is what marks the piece as out-of-vocabulary.
      isOOV = embeddings.isEmpty,
      embeddings = embeddings.getOrElse(zeroArray),
      begin = begin,
      end = end)
}

/** Converts between [[WordpieceEmbeddingsSentence]] values and flat sequences of [[Annotation]]s
  * whose annotator type is `AnnotatorType.WORD_EMBEDDINGS`.
  */
object WordpieceEmbeddingsSentence extends Annotated[WordpieceEmbeddingsSentence] {

  /** The annotator type this object reads in `unpack` and writes in `pack`. */
  override def annotatorType: String = AnnotatorType.WORD_EMBEDDINGS

  /** Rebuilds sentences from annotations.
    *
    * Annotations whose type differs from [[annotatorType]] are ignored. The remaining ones are
    * grouped by their "sentence" metadata entry parsed as an `Int`. The "sentence", "token",
    * "pieceId" and "isWordStart" metadata entries are read without a default; "isOOV" falls back
    * to "false" when absent.
    *
    * NOTE(review): `wordpiece` is taken from the annotation's `result`, while `pack` passes
    * `token.token` in what is presumably the `result` position -- so after a pack/unpack round
    * trip `wordpiece` would equal `token`. Confirm against `Annotation` before relying on it.
    *
    * @param annotations
    *   annotations of any type
    * @return
    *   one sentence per distinct sentence id, ordered by ascending `sentenceId`
    */
  override def unpack(annotations: Seq[Annotation]): Seq[WordpieceEmbeddingsSentence] = {
    val bySentence = annotations
      .filter(_.annotatorType == annotatorType)
      .groupBy(_.metadata("sentence").toInt)

    bySentence
      .map { case (sentenceId, sentenceAnnotations) =>
        val pieces = sentenceAnnotations.map { annotation =>
          new TokenPieceEmbeddings(
            wordpiece = annotation.result,
            token = annotation.metadata("token"),
            pieceId = annotation.metadata("pieceId").toInt,
            isWordStart = annotation.metadata("isWordStart").toBoolean,
            isOOV = annotation.metadata.getOrElse("isOOV", "false").toBoolean,
            embeddings = annotation.embeddings,
            begin = annotation.begin,
            end = annotation.end)
        }.toArray

        WordpieceEmbeddingsSentence(pieces, sentenceId)
      }
      .toSeq
      .sortBy(_.sentenceId)
  }

  /** Flattens sentences into one annotation per token piece.
    *
    * Each annotation carries the piece's offsets, `token.token` as its text, the piece's embedding
    * vector, and the metadata entries "sentence", "token", "pieceId", "isWordStart" and "isOOV"
    * (all stored as strings).
    *
    * @param sentences
    *   the sentences to flatten
    * @return
    *   the annotations of all sentences, in the order of `sentences` and of each sentence's
    *   `tokens`
    */
  override def pack(sentences: Seq[WordpieceEmbeddingsSentence]): Seq[Annotation] =
    sentences.flatMap { sentence =>
      // Identical for every piece of the sentence, so render it once.
      val sentenceId = sentence.sentenceId.toString

      sentence.tokens.map { token =>
        Annotation(
          annotatorType,
          token.begin,
          token.end,
          token.token,
          Map(
            "sentence" -> sentenceId,
            "token" -> token.token,
            "pieceId" -> token.pieceId.toString,
            "isWordStart" -> token.isWordStart.toString,
            "isOOV" -> token.isOOV.toString),
          token.embeddings)
      }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy