All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.kotlinnlp.neuralparser.helpers.Validator.kt Maven / Gradle / Ivy

Go to download

NeuralParser is a very simple-to-use dependency parser, based on the SimpleDNN library and the SyntaxDecoder transition systems framework.

There is a newer version: 0.6.5
Show newest version
/* Copyright 2017-present The KotlinNLP Authors. All Rights Reserved.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 * ------------------------------------------------------------------*/

package com.kotlinnlp.neuralparser.helpers

import com.kotlinnlp.conllio.Sentence as CoNLLSentence
import com.kotlinnlp.conllio.Token as CoNLLToken
import com.kotlinnlp.dependencytree.DependencyTree
import com.kotlinnlp.neuralparser.NeuralParser
import com.kotlinnlp.neuralparser.helpers.preprocessors.BasePreprocessor
import com.kotlinnlp.neuralparser.helpers.preprocessors.SentencePreprocessor
import com.kotlinnlp.neuralparser.helpers.statistics.MetricsCounter
import com.kotlinnlp.neuralparser.helpers.statistics.SentenceMetrics
import com.kotlinnlp.neuralparser.helpers.statistics.Statistics
import com.kotlinnlp.utils.progressindicator.ProgressIndicatorBar

/**
 * The Validator.
 *
 * It parses the given gold-annotated [sentences] with a neural parser and compares, token by token, the heads, the
 * deprels and the POS tags of each parsed dependency tree with the ones of the related gold tree.
 * The comparison results are accumulated twice: once considering all the tokens and once ignoring the tokens whose
 * form matches the [punctuationRegex].
 *
 * The construction fails with an [IllegalArgumentException] if at least one of the [sentences] has no annotated heads.
 *
 * @param neuralParser the neural parser
 * @property sentences the sentences to parse containing the gold annotation
 * @param sentencePreprocessor the sentence preprocessor (e.g. to perform morphological analysis)
 * @property verbose a Boolean indicating if the verbose mode is enabled (default = true)
 */
class Validator(
  neuralParser: NeuralParser<*>,
  val sentences: List<CoNLLSentence>,
  sentencePreprocessor: SentencePreprocessor = BasePreprocessor(),
  val verbose: Boolean = true
) {

  companion object {

    /**
     * The regular expression to match punctuation forms.
     * It matches only forms that are entirely composed of the listed punctuation characters.
     */
    val punctuationRegex = Regex("^[-!\"#%&'()*,./:;?@\\[\\]_{}]+$")
  }

  init {
    require(sentences.all { it.hasAnnotatedHeads() }) {
      "The gold dependency tree of the sentence cannot be null during the evaluation."
    }
  }

  /**
   * A counter of statistic metrics.
   * It is re-created at each call of [evaluate] (see [initCounters]).
   */
  private lateinit var counter: MetricsCounter

  /**
   * A counter of statistic metrics, without considering punctuation.
   * It is re-created at each call of [evaluate] (see [initCounters]).
   */
  private lateinit var counterNoPunct: MetricsCounter

  /**
   * The metrics of the sentence that is currently being evaluated.
   * It is re-created for each sentence within [evaluate].
   */
  private lateinit var sentenceMetrics: SentenceMetrics

  /**
   * The parser wrapper to parse sentences in CoNLL format.
   */
  private val conllParser = CoNLLDependencyParser(
    neuralParser = neuralParser,
    sentencePreprocessor = sentencePreprocessor)

  /**
   * Get statistics about the evaluation of the parsing accuracy on the given [sentences].
   *
   * All the [sentences] are parsed first, then each parsed sentence is compared with its gold.
   *
   * @throws IllegalArgumentException if a parsed dependency tree has a different size than its gold
   *
   * @return the statistics of the parsing accuracy
   */
  fun evaluate(): Statistics {

    val parsedSentences: List<CoNLLSentence> = this.parseSentences()

    this.initCounters(parsedSentences)

    this.sentences.zip(parsedSentences).forEach { (goldSentence, parsedSentence) ->

      val goldTree: DependencyTree = goldSentence.getDependencyTree()
      val parsedTree: DependencyTree = parsedSentence.getDependencyTree()

      require(parsedTree.size == goldTree.size) { "The dependency tree and its gold haven't the same size" }

      // The sentence-level metrics start fresh for each sentence and are updated by the token-level comparisons.
      this.sentenceMetrics = SentenceMetrics()

      (0 until parsedSentence.tokens.size).forEach { i ->
        this.addTokenMetrics(
          token = goldSentence.tokens[i], // the form used for the punctuation check is taken from the gold token
          tokenIndex = i,
          parsedTree = parsedTree,
          goldTree = goldTree)
      }

      this.updateCorrectSentences()
    }

    return this.buildStats()
  }

  /**
   * Parse the validation CoNLL sentences.
   * If [verbose] is enabled a starting message is printed and a progress bar is advanced once per sentence.
   *
   * @return the list of parsed CoNLL sentences
   */
  private fun parseSentences(): List<CoNLLSentence> {

    val progress: ProgressIndicatorBar? = if (this.verbose) ProgressIndicatorBar(this.sentences.size) else null

    if (this.verbose) println("Start parsing of %d sentences:".format(this.sentences.size))

    return this.sentences.map {

      progress?.tick()

      this.conllParser.parse(it)
    }
  }

  /**
   * Initialize the metrics counters.
   *
   * The total tokens of the [counterNoPunct] are not set here: they are incremented token by token within
   * [addTokenMetrics].
   *
   * @param parsedSentences a list of parsed sentences
   */
  private fun initCounters(parsedSentences: List<CoNLLSentence>) {

    this.counter = MetricsCounter()
    this.counterNoPunct = MetricsCounter()

    this.counter.totalSentences = parsedSentences.size
    this.counterNoPunct.totalSentences = parsedSentences.size
    this.counter.totalTokens = parsedSentences.sumBy { it.tokens.count() }
  }

  /**
   * Add the statistic metrics of a given [token].
   *
   * The attachment is considered correct if the parsed head is equal to the gold head.
   * The labeled attachment is considered correct if, in addition, the parsed deprel is equal to the gold deprel.
   * The POS tag and the deprel are also compared on their own, independently from the head.
   *
   * @param token a token of a sentence
   * @param tokenIndex the index of the given [token]
   * @param parsedTree the dependency tree of the parsed sentence
   * @param goldTree the gold dependency tree of the parsed sentence
   */
  private fun addTokenMetrics(token: CoNLLToken, tokenIndex: Int, parsedTree: DependencyTree, goldTree: DependencyTree) {

    val isNotPunct: Boolean = !punctuationRegex.matches(token.form)
    val parsedDeprel = parsedTree.deprels[tokenIndex]
    val goldDeprel = goldTree.deprels[tokenIndex]

    if (isNotPunct) this.counterNoPunct.totalTokens++

    if (parsedTree.heads[tokenIndex] == goldTree.heads[tokenIndex]) {

      this.addCorrectAttachment(isNotPunct)

      if (parsedDeprel == goldDeprel)
        this.addCorrectLabeledAttachment(isNotPunct)
      else
        this.addUncorrectLabeledAttachment(isNotPunct)

    } else {
      // A wrong head makes both the unlabeled and the labeled attachment wrong.
      this.addUncorrectAttachment(isNotPunct)
      this.addUncorrectLabeledAttachment(isNotPunct)
    }

    if (parsedTree.posTags[tokenIndex] == goldTree.posTags[tokenIndex]) this.addCorrectPOSTag(isNotPunct)

    // The deprels are compared with 'softEquals' when the parsed one is not null, otherwise with the plain equality
    // (that is satisfied only if the gold deprel is null too).
    if (parsedDeprel?.softEquals(goldDeprel) ?: (parsedDeprel == goldDeprel)) this.addCorrectDeprel(isNotPunct)
  }

  /**
   * Add a correct attachment to the current statistic metrics.
   *
   * @param isNotPunct a Boolean indicating if the attachment is related to a non-punctuation token
   */
  private fun addCorrectAttachment(isNotPunct: Boolean) {

    this.counter.unlabeledAttachments++

    if (isNotPunct) this.counterNoPunct.unlabeledAttachments++
  }

  /**
   * Add an uncorrect attachment to the current statistic metrics, marking the current sentence as not correctly
   * attached.
   *
   * @param isNotPunct a Boolean indicating if the attachment is related to a non-punctuation token
   */
  private fun addUncorrectAttachment(isNotPunct: Boolean) {

    this.sentenceMetrics.correctUnlabeled = false

    if (isNotPunct) this.sentenceMetrics.correctUnlabeledNoPunct = false
  }

  /**
   * Add a correct labeled attachment to the current statistic metrics.
   *
   * @param isNotPunct a Boolean indicating if the attachment is related to a non-punctuation token
   */
  private fun addCorrectLabeledAttachment(isNotPunct: Boolean) {

    this.counter.labeledAttachments++

    if (isNotPunct) this.counterNoPunct.labeledAttachments++
  }

  /**
   * Add an uncorrect labeled attachment to the current statistic metrics, marking the current sentence as not
   * correctly labeled.
   *
   * @param isNotPunct a Boolean indicating if the attachment is related to a non-punctuation token
   */
  private fun addUncorrectLabeledAttachment(isNotPunct: Boolean) {

    this.sentenceMetrics.correctLabeled = false

    if (isNotPunct) this.sentenceMetrics.correctLabeledNoPunct = false
  }

  /**
   * Add a correct POS tag to the current statistic metrics.
   *
   * @param isNotPunct a Boolean indicating if the POS tag is related to a non-punctuation token
   */
  private fun addCorrectPOSTag(isNotPunct: Boolean) {

    this.counter.correctPOSTags++

    if (isNotPunct) this.counterNoPunct.correctPOSTags++
  }

  /**
   * Add a correct deprel to the current statistic metrics.
   *
   * @param isNotPunct a Boolean indicating if the deprel is related to a non-punctuation token
   */
  private fun addCorrectDeprel(isNotPunct: Boolean) {

    this.counter.correctDeprels++

    if (isNotPunct) this.counterNoPunct.correctDeprels++
  }

  /**
   * Update the counters of correct sentences with the current [sentenceMetrics].
   * It must be called once per sentence, after all its tokens have been processed by [addTokenMetrics].
   */
  private fun updateCorrectSentences() {

    if (this.sentenceMetrics.correctLabeled) this.counter.correctLabeledSentences++
    if (this.sentenceMetrics.correctUnlabeled) this.counter.correctUnlabeledSentences++
    if (this.sentenceMetrics.correctLabeledNoPunct) this.counterNoPunct.correctLabeledSentences++
    if (this.sentenceMetrics.correctUnlabeledNoPunct) this.counterNoPunct.correctUnlabeledSentences++
  }

  /**
   * Build the statistics related to the current counted metrics.
   *
   * The top-level values come from the [counter] (punctuation included), while the statistics of the
   * [counterNoPunct] are attached as the 'noPunctuation' field.
   *
   * @return the statistics built from the two counters
   */
  private fun buildStats(): Statistics {

    val punctStats = this.counter.toStatistics()
    val noPunctStats = this.counterNoPunct.toStatistics()

    return Statistics(
      las = punctStats.las,
      uas = punctStats.uas,
      ps = punctStats.ps,
      ds = punctStats.ds,
      slas = punctStats.slas,
      suas = punctStats.suas,
      noPunctuation = noPunctStats)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy