All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.kotlinnlp.lssencoder.LSSOptimizer.kt Maven / Gradle / Ivy

/* Copyright 2018-present LHRParser Authors. All Rights Reserved.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 * -----------------------------------------------------------------------------*/

package com.kotlinnlp.lssencoder

import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod
import com.kotlinnlp.simplednn.core.optimizer.Optimizer
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer
import com.kotlinnlp.simplednn.deeplearning.birnn.BiRNNParameters
import com.kotlinnlp.simplednn.deeplearning.birnn.deepbirnn.DeepBiRNNParameters
import com.kotlinnlp.tokensencoder.TokensEncoderOptimizer

/**
 * The optimizer of the [LSSModel].
 *
 * It does not update parameters directly: it delegates to three sub-optimizers, one for each neural module of the
 * [model] (the tokens encoder, the context encoder and the heads encoder), all sharing the same [UpdateMethod].
 *
 * @param model the model to optimize
 * @param updateMethod the update method helper (Learning Rate, ADAM, AdaGrad, ...)
 */
class LSSOptimizer(
  private val model: LSSModel<*, *>,
  updateMethod: UpdateMethod<*>
) : Optimizer(updateMethod) {

  /**
   * The optimizer of the tokens encoder.
   * It is built by the tokens encoder model itself, since each tokens encoder type provides its own optimizer.
   */
  private val tokensEncoderOptimizer: TokensEncoderOptimizer =
    this.model.tokensEncoderWrapperModel.model.buildOptimizer(updateMethod)

  /**
   * The optimizer of the context encoder.
   */
  private val contextEncoderOptimizer: ParamsOptimizer =
    ParamsOptimizer(params = this.model.contextEncoderModel.model, updateMethod = updateMethod)

  /**
   * The optimizer of the heads encoder.
   */
  private val headsEncoderOptimizer: ParamsOptimizer =
    ParamsOptimizer(params = this.model.headsEncoderBiRNN.model, updateMethod = updateMethod)

  /**
   * Update the parameters of the neural modules of the [model], delegating to the tokens encoder, context encoder
   * and heads encoder optimizers (in this order).
   */
  override fun update() {
    this.tokensEncoderOptimizer.update()
    this.contextEncoderOptimizer.update()
    this.headsEncoderOptimizer.update()
  }

  /**
   * Accumulate the given [paramsErrors], splitting them among the sub-optimizers: the tokens encoder, context encoder
   * and heads encoder errors are each passed to the optimizer of the corresponding module.
   *
   * @param paramsErrors the parameters errors to accumulate
   * @param copy a Boolean indicating if the [paramsErrors] can be used as reference or must be copied. Set copy = false
   *             to optimize the accumulation when the amount of the errors to accumulate is 1. Any default value is
   *             inherited from the overridden declaration, since Kotlin overrides cannot declare their own defaults.
   */
  override fun accumulate(paramsErrors: LSSParameters, copy: Boolean) {
    this.tokensEncoderOptimizer.accumulate(paramsErrors = paramsErrors.tokensEncoderParams, copy = copy)
    this.contextEncoderOptimizer.accumulate(paramsErrors = paramsErrors.contextEncoderParams, copy = copy)
    this.headsEncoderOptimizer.accumulate(paramsErrors = paramsErrors.headsEncoderParams, copy = copy)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy