com.kotlinnlp.simplednn.core.layers.RecurrentLayerUnit.kt

/* Copyright 2016-present The KotlinNLP Authors. All Rights Reserved.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, you can obtain one at http://mozilla.org/MPL/2.0/.
 * ------------------------------------------------------------------*/

package com.kotlinnlp.simplednn.core.layers

import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArrayMask
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape
import com.kotlinnlp.simplednn.simplemath.ndarray.sparse.SparseNDArray

/**
 * The basic unit of the recurrent layer, which extends the [LayerUnit] with the contribution coming from the
 * previous state.
 */
class RecurrentLayerUnit<InputNDArrayType : NDArray<InputNDArrayType>>(size: Int) : LayerUnit<InputNDArrayType>(size) {

  init {
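    // Initialize the unit values with an empty dense array of the given size.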
    this.assignValues(DenseNDArrayFactory.emptyArray(Shape(size)))
  }

  /**
   * Add the recurrent contribution to the unit.
   *
   * @param parameters the parameters associated with this unit
   * @param prevContribution the output array to add as contribution from the previous state
   *
   * g += wRec (dot) prevContribution
   */
  fun addRecurrentContribution(parameters: RecurrentParametersUnit, prevContribution: DenseNDArray) {

    val wRec = parameters.recurrentWeights.values

    this.values.assignSum(wRec.dot(prevContribution))
  }
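
  /* Usage sketch (illustrative only; `unit`, `params` and `yPrev` are assumed to come from the caller):
   *
   *   unit.addRecurrentContribution(parameters = params, prevContribution = yPrev)
   *   // after the call, `values` also includes the term: wRec (dot) yPrev
   */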

  /**
   * Assign errors to the parameters associated with this unit. The errors of the output must already be set.
   *
   * gb = errors * 1
   * gw = errors (dot) x
   * gwRec = errors (dot) yPrev
   *
   * @param paramsErrors the errors of the parameters associated with this unit
   * @param x the input of the unit
   * @param yPrev the output array as contribution from the previous state
   * @param mePropMask the mask of the top k output nodes, in order to execute the 'meProp' algorithm
   */
  fun assignParamsGradients(paramsErrors: RecurrentParametersUnit,
                            x: InputNDArrayType,
                            yPrev: DenseNDArray? = null,
                            mePropMask: NDArrayMask? = null) {

    super.assignParamsGradients(paramsErrors = paramsErrors, x = x, mePropMask = mePropMask)

    val gwRec: NDArray<*> = paramsErrors.recurrentWeights.values

    if (yPrev != null) {

      if (mePropMask != null) {
        require(x is DenseNDArray) { "Cannot apply the 'meProp' method if the input is not dense" }
        require(gwRec is SparseNDArray) {
          "Cannot apply the 'meProp' method if the errors are not sparse. Make sure 'meProp' is enabled in the params."
        }

        x as DenseNDArray; gwRec as SparseNDArray

        gwRec.assignDot(this.errors.maskBy(mePropMask), yPrev.t)

      } else {
        gwRec.assignDot(this.errors, yPrev.t)
      }

    } else {
      gwRec.zeros()
    }
  }
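
  /* Gradient shapes (a sketch, assuming a unit of size n and an input of size m):
   *
   *   gb    : (n)     — the output errors themselves
   *   gw    : (n × m) — errors (dot) x.t (assigned by the superclass)
   *   gwRec : (n × n) — errors (dot) yPrev.t, or zeros when there is no previous state
   */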

  /**
   * Get the errors of the output in the previous state. The errors of the output in the current state must already
   * be set.
   *
   * @param parameters the parameters associated with this unit
   * @param mePropMask the mask of the top k output nodes, in order to execute the 'meProp' algorithm
   *
   * @return the errors of the output in the previous state
   */
  fun getRecurrentErrors(parameters: RecurrentParametersUnit, mePropMask: NDArrayMask? = null): DenseNDArray {

    val wRec: DenseNDArray = parameters.recurrentWeights.values as DenseNDArray

    return if (mePropMask != null)
      this.errors.t.dot(wRec, mask = mePropMask)
    else
      this.errors.t.dot(wRec)
  }
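
  /* Note: since `errors` is a column vector of size n and wRec is (n × n), the product
   *   errors.t (dot) wRec
   * is a row vector of size n: the gradient to propagate to the output of the previous state.
   * With 'meProp' active, only the top-k components selected by the mask contribute to the product.
   */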

  /**
   * Get the relevance of the input of the unit. The relevance of the output must already be set.
   *
   * @param x the input of the unit
   * @param contributions the contributions of this unit saved during the last forward
   * @param prevStateExists whether a previous state exists
   *
   * @return the relevance of the input of the unit
   */
  fun getInputRelevance(x: InputNDArrayType,
                        contributions: RecurrentParametersUnit,
                        prevStateExists: Boolean): NDArray<*> {

    return if (prevStateExists)
      this.getInputRelevancePartitioned(x = x, contributions = contributions)
    else
      super.getInputRelevance(x = x, contributions = contributions)
  }

  /**
   * @param x the input of the unit
   * @param contributions the contributions of this unit in the last forward
   *
   * @return the relevance of the input of the unit, calculated from the input partition of the output relevance
   */
  private fun getInputRelevancePartitioned(x: InputNDArrayType, contributions: RecurrentParametersUnit): NDArray<*> {

    val y: DenseNDArray = this.valuesNotActivated
    val yRec: DenseNDArray = contributions.biases.values as DenseNDArray
    val yInput: DenseNDArray = y.sub(yRec)
    val yInputRelevance: DenseNDArray = RelevanceUtils.getRelevancePartition1(
      yRelevance = this.relevance as DenseNDArray,
      y = y,
      yContribute1 = yInput,
      yContribute2 = yRec)  // the recurrent contrib. to y is saved into the biases contrib.

    return RelevanceUtils.calculateRelevanceOfArray(
      x = x,
      y = yInput,
      yRelevance = yInputRelevance,
      contributions = contributions.weights.values
    )
  }
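
  /* The split follows a contribution-based relevance rule: y = yInput + yRec, and the relevance of y is
   * divided between the two addends proportionally to their (stabilized) share of y, so that
   * R(yInput) + R(yRec) ≈ R(y). The exact stabilization terms live in RelevanceUtils.
   */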

  /**
   * Get the relevance of the output in the previous state. The relevance of the output in the current state must
   * already be set.
   *
   * @param contributions the contributions of this unit in the last forward
   * @param yPrev the output array as contribution from the previous state
   *
   * @return the relevance of the output in the previous state with respect to the current one
   */
  fun getRecurrentRelevance(contributions: RecurrentParametersUnit, yPrev: DenseNDArray): DenseNDArray {

    val yRec: DenseNDArray = contributions.biases.values as DenseNDArray
    val yRecRelevance: DenseNDArray = RelevanceUtils.getRelevancePartition2(
      yRelevance = this.relevance as DenseNDArray,
      y = this.valuesNotActivated,
      yContribute2 = yRec)

    return RelevanceUtils.calculateRelevanceOfDenseArray(
      x = yPrev,
      y = yRec,
      yRelevance = yRecRelevance,
      contributions = contributions.recurrentWeights.values as DenseNDArray
    )
  }
}
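
/*
 * End-to-end sketch (hypothetical, for illustration only; `forward` on the superclass and the
 * surrounding objects are assumed, not defined in this file):
 *
 *   val unit = RecurrentLayerUnit<DenseNDArray>(size = 5)
 *
 *   // forward: input contribution first (inherited from LayerUnit), then the recurrence
 *   unit.forward(parameters = params, x = x)
 *   unit.addRecurrentContribution(parameters = params, prevContribution = yPrev)
 *
 *   // backward: with the output errors already set on the unit
 *   unit.assignParamsGradients(paramsErrors = gParams, x = x, yPrev = yPrev)
 *   val gyPrev: DenseNDArray = unit.getRecurrentErrors(parameters = params)
 */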



