All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.kotlinnlp.neuralparser.parsers.transitionbased.models.arceagerspine.ArcEagerSpineEmbeddingsFeaturesExtractor.kt Maven / Gradle / Ivy

Go to download

NeuralParser is a very simple to use dependency parser, based on the SimpleDNN library and the SyntaxDecoder transition systems framework.

There is a newer version: 0.6.5
Show newest version
/* Copyright 2017-present The KotlinNLP Authors. All Rights Reserved.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 * ------------------------------------------------------------------*/

package com.kotlinnlp.neuralparser.parsers.transitionbased.models.arceagerspine

import com.kotlinnlp.neuralparser.parsers.transitionbased.templates.inputcontexts.TokensEmbeddingsContext
import com.kotlinnlp.neuralparser.parsers.transitionbased.templates.featuresextractor.TWFeaturesExtractorTrainable
import com.kotlinnlp.neuralparser.parsers.transitionbased.templates.supportstructure.multiprediction.MPSupportStructure
import com.kotlinnlp.neuralparser.parsers.transitionbased.utils.features.GroupedDenseFeatures
import com.kotlinnlp.neuralparser.parsers.transitionbased.utils.items.DenseItem
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray
import com.kotlinnlp.syntaxdecoder.modules.featuresextractor.FeaturesExtractor
import com.kotlinnlp.syntaxdecoder.transitionsystem.Transition
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arceagerspine.ArcEagerSpineState
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arceagerspine.ArcEagerSpineTransition
import com.kotlinnlp.syntaxdecoder.transitionsystem.models.arceagerspine.ArcEagerSpineStateView
import com.kotlinnlp.syntaxdecoder.utils.DecodingContext
import com.kotlinnlp.syntaxdecoder.utils.toTransitions
import com.kotlinnlp.syntaxdecoder.utils.toTransitionsMap
import com.kotlinnlp.utils.MultiMap

/**
 * The FeaturesExtractor that extracts Embeddings as features for the ArcEagerSpine transition system.
 */
class ArcEagerSpineEmbeddingsFeaturesExtractor
  :
  TWFeaturesExtractorTrainable<
    ArcEagerSpineState,
    ArcEagerSpineTransition,
    TokensEmbeddingsContext,
    GroupedDenseFeatures,
    ArcEagerSpineStateView,
    MPSupportStructure>() {

  /**
   * The group id of this transition.
   */
  private val Transition.groupId: Int get() = Utils.getGroupId(this)

  /**
   * Extract features using the given [decodingContext] amd [supportStructure].
   *
   * @param decodingContext the decoding context
   * @param supportStructure the decoding support structure
   *
   * @return the extracted features
   */
  override fun extract(
    decodingContext: DecodingContext,
    supportStructure: MPSupportStructure): GroupedDenseFeatures {

    val featuresMap = mutableMapOf>()

    decodingContext.actions
      .filter { it.isAllowed }
      .toTransitions()
      .groupBy { it.groupId }
      .forEach { groupId, transitions ->

        featuresMap[groupId] = mutableMapOf()

        transitions.forEach { transition ->

          featuresMap.getValue(groupId)[transition.id] = this.extractViewFeatures(
            stateView = ArcEagerSpineStateView(state = decodingContext.extendedState.state, transition = transition),
            context = decodingContext.extendedState.context
          )
        }
      }

    return GroupedDenseFeatures(featuresMap = MultiMap(featuresMap))

  }

  /**
   * Backward errors through this [FeaturesExtractor], starting from the errors of the features contained in the given
   * [decodingContext].
   * Errors are required to be already set into the given features.
   *
   * @param decodingContext the decoding context that contains extracted features with their errors
   * @param supportStructure the decoding support structure
   */
  override fun backward(
    decodingContext: DecodingContext,
    supportStructure: MPSupportStructure) {

    val transitionsMap: Map = decodingContext.actions.toTransitionsMap()

    decodingContext.features.errors.errorsMap.forEach { _: Any, transitionId: Int, errors: DenseNDArray ->

      val itemsWindow: List = this.getTokensWindow(
        stateView = ArcEagerSpineStateView(
          state = decodingContext.extendedState.state,
          transition = transitionsMap.getValue(transitionId)))

      val tokensErrors: List = errors.splitV(decodingContext.extendedState.context.encodingSize)

      this.accumulateItemsErrors(
        items = decodingContext.extendedState.context.items,
        itemsErrors = itemsWindow.zip(tokensErrors))
    }
  }

  /**
   * Get the tokens window respect to a given state
   *
   * @param stateView a view of the state
   *
   * @return the tokens window as list of Int
   */
  override fun getTokensWindow(stateView: ArcEagerSpineStateView): List = when (stateView.transition.type) {

    Transition.Type.SHIFT -> listOf(
      stateView.stack[0],
      stateView.buffer[0]
    )

    Transition.Type.ROOT -> listOf(
      stateView.stack[0]
    )

    Transition.Type.ARC_LEFT, Transition.Type.ARC_RIGHT -> listOf(
      stateView.stack[0],
      stateView.buffer[0]
    )

    else -> throw RuntimeException()
  }

  /**
   * @param tokenId a token id
   * @param context a tokens dense context
   *
   * @return the token representation as dense array
   */
  override fun getTokenEncoding(tokenId: Int?, context: TokensEmbeddingsContext): DenseNDArray
    = context.getTokenEncoding(tokenId)

  /**
   * Beat the occurrence of a new example.
   */
  override fun newExample() = Unit

  /**
   * Beat the occurrence of a new batch.
   */
  override fun newBatch() = Unit

  /**
   * Beat the occurrence of a new epoch.
   */
  override fun newEpoch() = Unit

  /**
   * Update the trainable components of this FeaturesExtractor.
   */
  override fun update() = Unit

  /**
   * Accumulate the given [itemsErrors] into the related decoding [items].
   *
   * @param items the decoding items
   * @param itemsErrors a list of Pairs 
   */
  private fun accumulateItemsErrors(items: List, itemsErrors: List>) {

    itemsErrors.forEach { (itemIndex, errors) ->
      if (itemIndex != null) {
        items[itemIndex].accumulateErrors(errors)
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy