
// com.kotlinnlp.neuralparser.helpers.Trainer.kt (Maven / Gradle / Ivy)
/* Copyright 2017-present The KotlinNLP Authors. All Rights Reserved.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
* ------------------------------------------------------------------*/
package com.kotlinnlp.neuralparser.helpers
import com.kotlinnlp.conllio.Sentence as CoNLLSentence
import com.kotlinnlp.dependencytree.DependencyTree
import com.kotlinnlp.neuralparser.NeuralParser
import com.kotlinnlp.neuralparser.helpers.preprocessors.BasePreprocessor
import com.kotlinnlp.neuralparser.helpers.preprocessors.SentencePreprocessor
import com.kotlinnlp.neuralparser.helpers.statistics.Statistics
import com.kotlinnlp.neuralparser.helpers.validator.Validator
import com.kotlinnlp.neuralparser.language.BaseSentence
import com.kotlinnlp.neuralparser.language.ParsingSentence
import com.kotlinnlp.utils.ExamplesIndices
import com.kotlinnlp.utils.Shuffler
import com.kotlinnlp.utils.Timer
import com.kotlinnlp.utils.progressindicator.ProgressIndicatorBar
import java.io.File
import java.io.FileOutputStream
/**
* The training helper of the [NeuralParser].
*
* @param neuralParser a neural parser
* @param batchSize the number of sentences that compose a batch
* @param epochs the number of training epochs
* @param validator the validation helper (if it is null no validation is done after each epoch)
* @param modelFilename the name of the file in which to save the best trained model
* @param minRelevantErrorsCountToUpdate the min count of relevant errors needed to update the neural parser (default 1)
* @param sentencePreprocessor the sentence preprocessor (e.g. to perform morphological analysis)
* @param verbose a Boolean indicating if the verbose mode is enabled (default = true)
*/
abstract class Trainer(
  private val neuralParser: NeuralParser<*>,
  private val batchSize: Int,
  private val epochs: Int,
  private val validator: Validator?,
  private val modelFilename: String,
  private val minRelevantErrorsCountToUpdate: Int = 1,
  private val sentencePreprocessor: SentencePreprocessor = BasePreprocessor(),
  private val verbose: Boolean = true
) {

  /**
   * A timer to track the elapsed time (it is reset at the start of each training and validation phase, in verbose
   * mode only).
   */
  private val timer = Timer()

  /**
   * The best accuracy reached during the training.
   */
  private var bestAccuracy: Double = -1.0 // -1 used as init value (all accuracy values are in the range [0.0, 1.0])

  /**
   * Check requirements.
   */
  init {
    require(this.epochs > 0) { "The number of epochs must be > 0" }
    require(this.batchSize > 0) { "The size of the batch must be > 0" }
    require(this.minRelevantErrorsCountToUpdate > 0) { "minRelevantErrorsCountToUpdate must be > 0" }
  }

  /**
   * Train the [neuralParser] with the given sentences.
   *
   * For each epoch the parser is trained on all the sentences and then, if a [validator] has been given, it is
   * validated and the model is saved when it improves the best accuracy.
   * Note that the model is saved only by the validation step: without a [validator] this method never writes
   * [modelFilename].
   *
   * @param trainingSentences the sentences used to train the parser
   * @param shuffler a shuffler used to shuffle the sentences at each epoch (can be null)
   */
  fun train(trainingSentences: List<CoNLLSentence>,
            shuffler: Shuffler? = Shuffler(enablePseudoRandom = true, seed = 743)) {

    for (epochIndex in 0 until this.epochs) {

      this.logTrainingStart(epochIndex = epochIndex)

      this.newEpoch()
      this.trainEpoch(trainingSentences = trainingSentences, shuffler = shuffler)

      this.logTrainingEnd()

      // The validator is passed explicitly, so the trainer methods are never resolved against another receiver.
      this.validator?.let { validator ->
        this.logValidationStart()
        this.validateAndSaveModel(validator)
        this.logValidationEnd()
      }
    }
  }

  /**
   * Train the parser for an epoch.
   *
   * The sentences are visited following the order given by [ExamplesIndices]. A batch ends every [batchSize]
   * sentences and at the last sentence of the epoch.
   * At the end of a batch the parser is updated and a new batch is started only if the count of relevant errors is
   * at least [minRelevantErrorsCountToUpdate], otherwise both the [update] and the [newBatch] calls are skipped.
   *
   * @param trainingSentences the training sentences
   * @param shuffler a shuffler used to shuffle the sentences at each epoch (can be null)
   *
   * @throws IllegalArgumentException if a sentence has no annotated heads
   */
  private fun trainEpoch(trainingSentences: List<CoNLLSentence>, shuffler: Shuffler?) {

    val progress = ProgressIndicatorBar(trainingSentences.size)

    this.newBatch()

    ExamplesIndices(trainingSentences.size, shuffler = shuffler).forEachIndexed { i, sentenceIndex ->

      // 'i' is the position in the visiting order, 'sentenceIndex' is the position in the original list.
      val endOfBatch: Boolean = (i + 1) % this.batchSize == 0 || i == trainingSentences.lastIndex

      progress.tick()

      val sentence: CoNLLSentence = trainingSentences[sentenceIndex]

      require(sentence.hasAnnotatedHeads()) {
        "The gold dependency tree of a sentence cannot be null during the training."
      }

      this.trainSentence(
        sentence = this.sentencePreprocessor.convert(BaseSentence.fromCoNLL(sentence, index = sentenceIndex)),
        goldTree = DependencyTree.Labeled(sentence))

      if (endOfBatch && this.getRelevantErrorsCount() >= this.minRelevantErrorsCountToUpdate) {
        this.update()
        this.newBatch()
      }
    }
  }

  /**
   * Validate the [neuralParser] with the validation helper and save the best model.
   *
   * The statistics are always printed. The model is saved only if the `noPunctuation.uas.perc` value of the
   * statistics is strictly greater than the best accuracy reached so far.
   *
   * @param validator the validation helper
   */
  private fun validateAndSaveModel(validator: Validator) {

    val stats: Statistics = validator.evaluate()
    val accuracy: Double = stats.noPunctuation.uas.perc

    println("\n$stats")

    if (accuracy > this.bestAccuracy) {
      this.saveModel()
      this.bestAccuracy = accuracy
    }
  }

  /**
   * Save the model to [modelFilename].
   */
  private fun saveModel() {

    // 'use' guarantees that the file stream is closed, even if the dump fails.
    FileOutputStream(File(this.modelFilename)).use { outputStream ->
      this.neuralParser.model.dump(outputStream)
    }

    println("\nNEW BEST ACCURACY! Model saved to \"${this.modelFilename}\"")
  }

  /**
   * Log when training starts (in verbose mode only), resetting the timer.
   *
   * @param epochIndex the current epoch index (zero-based, it is printed starting from 1)
   */
  private fun logTrainingStart(epochIndex: Int) {

    if (this.verbose) {

      this.timer.reset()

      println("\nEpoch ${epochIndex + 1} of ${this.epochs}")
      println("\nStart training...")
    }
  }

  /**
   * Log when training ends (in verbose mode only), printing the elapsed time.
   */
  private fun logTrainingEnd() {

    if (this.verbose) {
      println("Elapsed time: %s".format(this.timer.formatElapsedTime()))
    }
  }

  /**
   * Log when validation starts (in verbose mode only), resetting the timer.
   */
  private fun logValidationStart() {

    if (this.verbose) {

      this.timer.reset()

      println() // new line
    }
  }

  /**
   * Log when validation ends (in verbose mode only), printing the elapsed time.
   */
  private fun logValidationEnd() {

    if (this.verbose) {
      println("Elapsed time: %s".format(this.timer.formatElapsedTime()))
    }
  }

  /**
   * Beat the occurrence of a new batch.
   * It is called at the start of each epoch and after each update of the parser. It does nothing by default.
   */
  protected open fun newBatch() = Unit

  /**
   * Beat the occurrence of a new epoch.
   * It is called once per epoch, before the training of its sentences. It does nothing by default.
   */
  protected open fun newEpoch() = Unit

  /**
   * Update the [neuralParser].
   * It is called at the end of a batch, only if the relevant errors are at least [minRelevantErrorsCountToUpdate].
   */
  protected abstract fun update()

  /**
   * Train the parser with the given [sentence] and [goldTree].
   *
   * @param sentence a sentence
   * @param goldTree the gold dependency tree
   */
  protected abstract fun trainSentence(sentence: ParsingSentence, goldTree: DependencyTree.Labeled)

  /**
   * @return the count of the relevant errors
   */
  protected abstract fun getRelevantErrorsCount(): Int
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy