All downloads are free. Search and download functionality uses the official Maven repository.

com.kotlinnlp.simplednn.deeplearning.attentionnetwork.attentionlayer.AttentionLayerOptimizer.kt Maven / Gradle / Ivy

Go to download

SimpleDNN is a lightweight, open-source machine learning library written in Kotlin, designed to support the development of feed-forward and recurrent artificial neural networks.

There is a newer version: 0.14.0
Show newest version
/* Copyright 2016-present The KotlinNLP Authors. All Rights Reserved.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, you can obtain one at http://mozilla.org/MPL/2.0/.
 * ------------------------------------------------------------------*/

package com.kotlinnlp.simplednn.deeplearning.attentionnetwork.attentionlayer

import com.kotlinnlp.simplednn.core.optimizer.Optimizer
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray

/**
 * The optimizer of the Attention Layer.
 *
 * Errors of the context vector are summed by [accumulateErrors]. They are then averaged over the number of
 * accumulations and applied to [params] by [update], which also resets the accumulator.
 *
 * @property params the attention layer parameters to optimize
 * @property updateMethod the [UpdateMethod] for the optimization (e.g. ADAM, AdaGrad, ...)
 */
class AttentionLayerOptimizer(
  val params: AttentionLayerParameters,
  updateMethod: UpdateMethod<*>
) : Optimizer(updateMethod) {

  /**
   * A support structure to store the errors of the context vector (same shape as the context vector values).
   */
  private val contextVectorErrors: DenseNDArray = this.params.contextVector.values.zerosLike()

  /**
   * The counter of the amount of errors accumulated since the last [update].
   */
  private var count: Int = 0

  /**
   * Accumulate the parameters errors contained into the [errors].
   *
   * @param errors the errors of the Attention Layer parameters
   */
  fun accumulateErrors(errors: AttentionLayerParameters) {

    this.contextVectorErrors.assignSum(errors.contextVector.values)
    this.count += 1
  }

  /**
   * Update the parameters with the average of the errors accumulated since the last update, then reset the
   * accumulator.
   *
   * If no errors have been accumulated, this is a no-op. Averaging would otherwise divide the (zero) accumulator by a
   * zero count, and the resulting NaN values would be propagated into the parameters.
   */
  override fun update() {

    if (this.count == 0) return // nothing accumulated: avoid 0 / 0 = NaN corrupting the parameters

    this.contextVectorErrors.assignDiv(this.count.toDouble()) // average errors
    this.updateMethod.update(this.params.contextVector, this.contextVectorErrors)

    this.contextVectorErrors.zeros()
    this.count = 0
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy