All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.kotlinnlp.neuralparser.utils.actionsembeddings.ActionsVectorsMap.kt Maven / Gradle / Ivy

/* Copyright 2017-present The KotlinNLP Authors. All Rights Reserved.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 * ------------------------------------------------------------------*/

package com.kotlinnlp.neuralparser.utils.actionsembeddings

import com.kotlinnlp.simplednn.deeplearning.embeddings.EmbeddingsMap
import com.kotlinnlp.simplednn.simplemath.concatVectorsV
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray
import java.io.Serializable

/**
 * The container of vectors that encode actions, mapped to their TPD ids.
 *
 * Each action is addressed by a transition id, a POS tag id and a deprel id (see [get]). Every id kind has its own
 * [EmbeddingsMap]; the vector of an action is built by passing the three looked-up embeddings to [concatVectorsV].
 *
 * @property size the size of each embedding vector (the same for transitions, POS tags and deprels), must be > 0
 * @property transitionsSize the number of transitions, must be > 0
 * @property posTagsSize the number of POS tags, must be > 0
 * @property deprelsSize the number of deprels, must be > 0
 *
 * @throws IllegalArgumentException if any of the given sizes is not > 0
 */
class ActionsVectorsMap(
  val size: Int,
  val transitionsSize: Int,
  val posTagsSize: Int,
  val deprelsSize: Int
) : Serializable {

  companion object {

    /**
     * Private val used to serialize the class (needed from Serializable).
     */
    @Suppress("unused")
    private const val serialVersionUID: Long = 1L
  }

  /**
   * Check the sizes.
   *
   * This block is intentionally declared BEFORE the embeddings maps: initializer blocks and property initializers
   * run in the order in which they appear in the class body, so the arguments are validated before any
   * [EmbeddingsMap] is built with them.
   */
  init {
    require(this.size > 0) { "The embeddings vectors size must be > 0" }
    require(this.transitionsSize > 0) { "The transitions size must be > 0" }
    require(this.posTagsSize > 0) { "The POS tags size must be > 0" }
    require(this.deprelsSize > 0) { "The deprels size must be > 0" }
  }

  /**
   * An Embeddings map that maps transition ids to embeddings vectors.
   */
  val transitionEmbeddings = EmbeddingsMap(size = this.size)

  /**
   * An Embeddings map that maps POS tag ids to embeddings vectors.
   */
  val posTagEmbeddings = EmbeddingsMap(size = this.size)

  /**
   * An Embeddings map that maps deprel ids to embeddings vectors.
   */
  val deprelEmbeddings = EmbeddingsMap(size = this.size)

  /**
   * Populate the three embeddings maps, setting one embedding for each id in the range `0 until` the related size.
   *
   * The maps are filled in the order transitions, deprels, POS tags; the order is kept stable on purpose, in case
   * the initialization of the embeddings depends on it.
   */
  init {
    (0 until this.transitionsSize).forEach { tId -> this.transitionEmbeddings.set(tId) }
    (0 until this.deprelsSize).forEach { dId -> this.deprelEmbeddings.set(dId) }
    (0 until this.posTagsSize).forEach { pId -> this.posTagEmbeddings.set(pId) }
  }

  /**
   * Build the vector of an action from the embeddings of its transition, POS tag and deprel (in this order).
   *
   * @param tId the transition id
   * @param pId the POS tag id
   * @param dId the deprel id
   *
   * @throws KotlinNullPointerException if [tId] is not within the range `0 until` [transitionsSize]
   * @throws KotlinNullPointerException if [pId] is not within the range `0 until` [posTagsSize]
   * @throws KotlinNullPointerException if [dId] is not within the range `0 until` [deprelsSize]
   *
   * @return the embedding vector related to the given ids
   */
  operator fun get(tId: Int, pId: Int, dId: Int): DenseNDArray = concatVectorsV(
    this.transitionEmbeddings.get(tId).array.values,
    this.posTagEmbeddings.get(pId).array.values,
    this.deprelEmbeddings.get(dId).array.values
  )

  /**
   * Build the vector obtained by looking up the `null` id in each embeddings map (transition, POS tag and deprel,
   * in this order).
   *
   * @return the 'nullEmbedding'
   */
  fun getNullVector(): DenseNDArray = concatVectorsV(
    this.transitionEmbeddings.get(null).array.values,
    this.posTagEmbeddings.get(null).array.values,
    this.deprelEmbeddings.get(null).array.values
  )
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy