All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.kotlinnlp.neuralparser.parsers.transitionbased.templates.inputcontexts.TokensEncodingContext.kt Maven / Gradle / Ivy

/* Copyright 2017-present The KotlinNLP Authors. All Rights Reserved.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 * ------------------------------------------------------------------*/

package com.kotlinnlp.neuralparser.parsers.transitionbased.templates.inputcontexts

import com.kotlinnlp.neuralparser.language.ParsingToken
import com.kotlinnlp.neuralparser.parsers.transitionbased.utils.items.DenseItem
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray
import com.kotlinnlp.syntaxdecoder.context.InputContext

/**
 * The tokens context with an encoding representation for each token.
 *
 * NOTE(review): the generic arguments of this declaration were restored from the imports and the usages
 * in the class body (`items[...]` elements must expose `accumulateErrors`) — confirm against the
 * `InputContext` declaration of the SyntaxDecoder library.
 *
 * @param SelfType the concrete type of the context (self-referencing bound forwarded to [InputContext])
 * @property tokens a list of tokens
 * @property encodingSize the size of each encoding
 * @property nullItemVector used to represent the encoding of a null item of the decoding window
 * @property trainingMode a flag that is only stored by this class (default false); presumably it tells
 *                        whether the context is used during training — verify against the subclasses
 */
abstract class TokensEncodingContext<SelfType : TokensEncodingContext<SelfType>>(
  val tokens: List<ParsingToken>,
  val encodingSize: Int,
  val nullItemVector: DenseNDArray,
  val trainingMode: Boolean = false
) : InputContext<SelfType, DenseItem> {

  /**
   * The length of the sentence (the number of [tokens]).
   */
  override val length: Int = this.tokens.size

  /**
   * The errors associated to the [nullItemVector].
   * It is null until errors are accumulated for the null item the first time.
   */
  var nullItemErrors: DenseNDArray? = null
    private set

  /**
   * Check that the [nullItemVector] has the declared [encodingSize].
   *
   * @throws IllegalArgumentException if the two sizes differ
   */
  init {
    require(nullItemVector.length == this.encodingSize) {
      "nullItemVector size not compatible with declared encodingSize."
    }
  }

  /**
   * Accumulate the given [errors] into the item at the given [itemIndex], or into the [nullItemErrors] if the
   * index is null.
   *
   * @param itemIndex the index of the item affected by the errors (null to address the null item)
   * @param errors the errors array to accumulate
   */
  fun accumulateItemErrors(itemIndex: Int?, errors: DenseNDArray) =
    if (itemIndex == null) {
      this.accumulateNullItemErrors(errors)
    } else {
      this.items[itemIndex].accumulateErrors(errors)
    }

  /**
   * Get the encoding vector of the item with the given id.
   *
   * @param itemId the id of an item (nullable)
   *
   * @return the encoding vector of the item with the given id
   */
  abstract fun getTokenEncoding(itemId: Int?): DenseNDArray

  /**
   * Accumulate the given [errors] into the [nullItemErrors].
   *
   * @param errors the errors of the [nullItemVector] to accumulate
   */
  private fun accumulateNullItemErrors(errors: DenseNDArray) {

    // Snapshot the mutable property once, so that the compiler can smart-cast it without using `!!`.
    val accumulatedErrors: DenseNDArray? = this.nullItemErrors

    if (accumulatedErrors != null) {
      accumulatedErrors.assignSum(errors)
    } else {
      // First accumulation: a copy of the given array is stored instead of the array itself.
      this.nullItemErrors = errors.copy()
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy