All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.kotlinnlp.simplednn.deeplearning.attentiverecurrentnetwork.ForwardHelper.kt Maven / Gradle / Ivy

Go to download

SimpleDNN is a machine learning lightweight open-source library written in Kotlin whose purpose is to support the development of feed-forward and recurrent Artificial Neural Networks.

There is a newer version: 0.14.0
Show newest version
/* Copyright 2016-present The KotlinNLP Authors. All Rights Reserved.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, you can obtain one at http://mozilla.org/MPL/2.0/.
 * ------------------------------------------------------------------*/

package com.kotlinnlp.simplednn.deeplearning.attentiverecurrentnetwork

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray
import com.kotlinnlp.simplednn.core.layers.feedforward.FeedforwardLayerStructure
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.AttentionNetwork
import com.kotlinnlp.simplednn.simplemath.concatVectorsV
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory

/**
 * The forward helper of the [AttentiveRecurrentNetwork].
 *
 * @property network the attentive recurrent network of this helper
 */
class ForwardHelper(private val network: AttentiveRecurrentNetwork) {

  /**
   * The recurrent context vector used in the first state: a zeros vector of the model recurrent-context size.
   */
  private val initRecurrentContext: DenseNDArray =
    DenseNDArrayFactory.zeros(Shape(this.network.model.recurrentContextSize))

  /**
   * A boolean indicating if the current state is the first state of the recursion.
   */
  private var firstRecurrentState: Boolean = true

  /**
   * Execute the forward of a state of the network.
   *
   * @param inputSequence the input sequence
   * @param lastPredictionLabel the context label vector used to encode the memory of the last prediction (can be null
   *                            if it is the first state)
   * @param firstState a boolean indicating if this is the first state
   * @param initHidden the initial hidden array (null by default)
   *
   * @return the output array of the current state
   */
  fun forward(inputSequence: List<DenseNDArray>,
              lastPredictionLabel: DenseNDArray?,
              firstState: Boolean,
              initHidden: DenseNDArray? = null): DenseNDArray {

    require(firstState || lastPredictionLabel != null) {
      "The last prediction label cannot be null for states after the first."
    }

    if (firstState) this.resetHistory()

    // In the first state there is no previous prediction: use the given initial hidden array or a zeros vector.
    val recurrentContext: DenseNDArray = if (firstState)
      initHidden ?: this.initRecurrentContext
    else
      this.forwardRecurrentContext(lastPredictionLabel = lastPredictionLabel!!, initHidden = initHidden)

    val stateEncoding: DenseNDArray = this.encodeState(sequence = inputSequence, recurrentContext = recurrentContext)

    return this.getOutputProcessor().forward(concatVectorsV(stateEncoding, recurrentContext))
  }

  /**
   * Reset the recurrent history of the network, releasing all the pooled items used in the previous recursion.
   */
  private fun resetHistory() {

    this.firstRecurrentState = true

    this.network.transformLayersPool.releaseAll()
    this.network.usedTransformLayers.clear()

    this.network.attentionNetworksPool.releaseAll()
    this.network.usedAttentionNetworks.clear()

    this.network.outputNetworkPool.releaseAll()
    this.network.usedOutputProcessors.clear()
  }

  /**
   * Forward the recurrent context processor, combining the last state encoding with the last prediction label.
   *
   * @param lastPredictionLabel the context label vector used to encode the memory of the last prediction
   * @param initHidden the initial hidden array (can be null)
   *
   * @return the recurrent context for the current state
   */
  private fun forwardRecurrentContext(lastPredictionLabel: DenseNDArray, initHidden: DenseNDArray?): DenseNDArray {

    val output: DenseNDArray = this.network.recurrentContextProcessor.forward(
      featuresArray = concatVectorsV(this.getLastStateEncoding(), lastPredictionLabel),
      initHiddenArrays = if (this.firstRecurrentState) listOf(initHidden) else null,
      firstState = this.firstRecurrentState)

    this.firstRecurrentState = false

    return output
  }

  /**
   * Encode the current state.
   *
   * @param sequence the sequence to decode
   * @param recurrentContext the recurrent context
   *
   * @return the state encoding as result of the [AttentionNetwork]
   */
  private fun encodeState(sequence: List<DenseNDArray>, recurrentContext: DenseNDArray): DenseNDArray {

    val attentionNetwork = this.getAttentionNetwork()

    return attentionNetwork.forward(
      inputSequence = ArrayList(sequence.map { AugmentedArray(values = it) }),
      attentionSequence = this.buildAttentionSequence(sequence = sequence, recurrentContext = recurrentContext))
  }

  /**
   * Build the attention arrays, transforming each input concatenated with the recurrent context.
   *
   * @param sequence the input sequence
   * @param recurrentContext the recurrent context
   *
   * @return the sequence of attention arrays
   */
  private fun buildAttentionSequence(sequence: List<DenseNDArray>,
                                     recurrentContext: DenseNDArray): ArrayList<DenseNDArray> {

    val transformLayers = this.getTransformLayers(size = sequence.size)

    return ArrayList(transformLayers.zip(sequence).map { (layer, inputArray) ->

      layer.setInput(concatVectorsV(inputArray, recurrentContext))
      layer.forward()

      layer.outputArray.values
    })
  }

  /**
   * Get an available group of transform layers, adding it into the usedTransformLayers list.
   *
   * @param size the number of transform layers to build
   *
   * @return a list of available transform layers
   */
  private fun getTransformLayers(size: Int): List<FeedforwardLayerStructure<DenseNDArray>> {

    this.network.usedTransformLayers.add(
      List(size = size, init = { this.network.transformLayersPool.getItem() })
    )

    return this.network.usedTransformLayers.last()
  }

  /**
   * Get an available Attention Network, adding it into the usedAttentionNetworks list.
   *
   * @return an available Attention Network
   */
  private fun getAttentionNetwork(): AttentionNetwork<DenseNDArray> {

    this.network.usedAttentionNetworks.add(this.network.attentionNetworksPool.getItem())

    return this.network.usedAttentionNetworks.last()
  }

  /**
   * Get an available output processor, adding it into the usedOutputProcessors list.
   *
   * @return an available output processor
   */
  private fun getOutputProcessor(): FeedforwardNeuralProcessor<DenseNDArray> {

    this.network.usedOutputProcessors.add(this.network.outputNetworkPool.getItem())

    return this.network.usedOutputProcessors.last()
  }

  /**
   * @return the last state encoding (the output of the last used [AttentionNetwork])
   */
  private fun getLastStateEncoding(): DenseNDArray = this.network.usedAttentionNetworks.last().getOutput()
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy