com.kotlinnlp.simplednn.helpers.training.TrainingHelper.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of simplednn Show documentation
Show all versions of simplednn Show documentation
SimpleDNN is a lightweight, open-source machine learning library written in Kotlin, whose purpose is to
support the development of feed-forward and recurrent artificial neural networks.
/* Copyright 2016-present The KotlinNLP Authors. All Rights Reserved.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, you can obtain one at http://mozilla.org/MPL/2.0/.
* ------------------------------------------------------------------*/
package com.kotlinnlp.simplednn.helpers.training
import com.kotlinnlp.simplednn.core.functionalities.losses.LossCalculator
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer
import com.kotlinnlp.simplednn.helpers.training.utils.TrainingStatistics
import com.kotlinnlp.simplednn.helpers.validation.ValidationHelper
import com.kotlinnlp.simplednn.dataset.Example
import com.kotlinnlp.simplednn.dataset.Shuffler
import com.kotlinnlp.simplednn.helpers.training.utils.ExamplesIndices
import com.kotlinnlp.progressindicator.ProgressIndicatorBar
/**
 * Base helper to train the neural network of a [NeuralProcessor] with mini-batch updates, optionally validating it
 * after each epoch.
 *
 * Subclasses implement [learnFromExample] (forward + backward on a single example). This class handles:
 * - epochs and shuffling;
 * - batching and accumulating the parameter errors into the [optimizer];
 * - triggering the parameter updates;
 * - verbose logging.
 *
 * @property neuralProcessor the neural processor whose network is trained
 * @property optimizer the optimizer that accumulates the parameter errors and updates the parameters
 * @property lossCalculator the loss calculator (not used by this base class directly — presumably used by subclasses
 *                          in [learnFromExample])
 * @property verbose whether to print timing, loss and accuracy information during training
 */
abstract class TrainingHelper<ExampleType : Example>(
  open val neuralProcessor: NeuralProcessor,
  val optimizer: ParamsOptimizer,
  val lossCalculator: LossCalculator,
  val verbose: Boolean = false) {

  /**
   * The statistics of training (accuracy, loss, etc..).
   */
  val statistics = TrainingStatistics()

  /**
   * When timing started, in milliseconds as returned by [System.currentTimeMillis].
   */
  private var startTime: Long = 0

  /**
   * Train the NeuralNetwork of the [neuralProcessor] over the specified number of [epochs].
   *
   * Examples are grouped in batches of the given [batchSize] and shuffled with the given [shuffler] before each
   * epoch. If [validationHelper] is not null, the NeuralNetwork is tested over the given [validationExamples] after
   * each epoch.
   *
   * @param trainingExamples training examples
   * @param epochs number of epochs
   * @param batchSize the size of each batch (default 1)
   * @param validationExamples validation examples (default null)
   * @param validationHelper the helper for the validation (default null)
   * @param shuffler the [Shuffler] to shuffle [trainingExamples] before each epoch (default null)
   *
   * @throws IllegalArgumentException if [batchSize] is not positive, or if [validationHelper] is given without
   *                                  [validationExamples] (checked before any training is done)
   */
  fun train(trainingExamples: ArrayList<ExampleType>,
            epochs: Int,
            batchSize: Int = 1,
            validationExamples: ArrayList<ExampleType>? = null,
            validationHelper: ValidationHelper? = null,
            shuffler: Shuffler? = null) {

    require(batchSize > 0) { "The batch size must be greater than 0 (got $batchSize)." }

    // Check the validation configuration up front, so that a misconfigured call fails immediately instead of after
    // the first (possibly long) training epoch.
    if (validationHelper != null) {
      requireNotNull(validationExamples) { "Validation examples are required when a validation helper is given." }
    }

    this.statistics.reset()

    for (i in 0 until epochs) {

      this.logTrainingStart(epochIndex = i)
      this.trainEpoch(trainingExamples = trainingExamples, batchSize = batchSize, shuffler = shuffler)
      this.logTrainingEnd()

      // validationExamples is guaranteed non-null whenever validationHelper is (checked above); testing both here
      // only lets the compiler smart-cast them.
      if (validationHelper != null && validationExamples != null) {
        this.logValidationStart(validationExamples.size)
        this.statistics.lastAccuracy = validationHelper.validate(validationExamples)
        this.logValidationEnd()
      }
    }
  }

  /**
   * Train the NeuralNetwork of the [neuralProcessor] for one epoch.
   *
   * Examples are grouped in batches of the given [batchSize] and shuffled with the given [shuffler] before training.
   * A trailing incomplete batch (when the number of examples is not a multiple of [batchSize]) is applied at the end
   * of the epoch.
   *
   * @param trainingExamples training examples
   * @param batchSize the size of each batch
   * @param shuffler the [Shuffler] to shuffle [trainingExamples] before training (default null)
   */
  private fun trainEpoch(trainingExamples: ArrayList<ExampleType>, batchSize: Int, shuffler: Shuffler? = null) {

    this.newEpoch()

    // NOTE(review): the progress bar is created and ticked regardless of [verbose] — confirm this is intended.
    val progress = ProgressIndicatorBar(trainingExamples.size)

    for (exampleIndex in ExamplesIndices(trainingExamples.size, shuffler = shuffler)) {
      progress.tick()
      this.statistics.lastLoss = this.trainExample(example = trainingExamples[exampleIndex], batchSize = batchSize)
    }

    // trainExample() updates the parameters only when a batch is full. If the last batch of the epoch is incomplete,
    // its accumulated errors would otherwise stay pending in the optimizer across the epoch boundary: apply them now.
    if (this.statistics.exampleCount in 1 until batchSize) {
      this.optimizer.update()
    }
  }

  /**
   * Train the network with the given [example] and accumulate the errors of the parameters into the [optimizer].
   *
   * A new batch is started when the previous one is complete (or at the start of an epoch). The parameters are
   * updated when the current batch reaches [batchSize] examples.
   *
   * @param example the example used to train the network
   * @param batchSize the size of each batch (default 1)
   *
   * @return the loss of the output compared to the expected gold
   */
  private fun trainExample(example: ExampleType, batchSize: Int = 1): Double {

    if (this.statistics.exampleCount % batchSize == 0) { // A new batch starts
      this.newBatch()
    }

    this.newExample() // !! must be called after this.newBatch() !!

    val loss = this.learnFromExample(example)

    // Errors are copied only when more than one example is accumulated per batch — presumably because the processor
    // reuses its error buffers between examples (TODO confirm against NeuralProcessor.getParamsErrors).
    val copyErrors = batchSize > 1
    this.optimizer.accumulate(this.neuralProcessor.getParamsErrors(copy = copyErrors), copy = copyErrors)

    if (this.statistics.exampleCount == batchSize) { // A batch has just ended
      this.optimizer.update()
    }

    return loss
  }

  /**
   * Learn from an example (forward + backward).
   *
   * @param example the example used to train the network
   *
   * @return the loss of the output with respect to the gold
   */
  abstract protected fun learnFromExample(example: ExampleType): Double

  /**
   * Method to call every new epoch.
   * It increments the epochCount and sets the batchCount and the exampleCount to zero.
   *
   * In turn it calls the same method into the `optimizer`.
   */
  private fun newEpoch() {
    this.statistics.newEpoch()
    this.optimizer.newEpoch()
  }

  /**
   * Method to call every new batch.
   * It increments the batchCount and sets the exampleCount to zero.
   *
   * In turn it calls the same method into the `optimizer`.
   */
  private fun newBatch() {
    this.statistics.newBatch()
    this.optimizer.newBatch()
  }

  /**
   * Method to call every new example.
   * It increments the exampleCount.
   *
   * In turn it calls the same method into the `optimizer`.
   */
  private fun newExample() {
    this.statistics.newExample()
    this.optimizer.newExample()
  }

  /**
   * Log when training starts (only if [verbose]): start timing and print the 1-based epoch number.
   *
   * @param epochIndex the 0-based index of the epoch that is starting
   */
  private fun logTrainingStart(epochIndex: Int) {

    if (this.verbose) {
      this.startTiming()
      println("\nEpoch ${epochIndex + 1}")
    }
  }

  /**
   * Log when training ends (only if [verbose]): print the elapsed time and the loss of the last trained example.
   */
  private fun logTrainingEnd() {

    if (this.verbose) { // TODO: replace lastLoss with another more valuable value
      println("Elapsed time: %s".format(this.formatElapsedTime()))
      // The loss is not a ratio: print it as is (it was previously scaled by 100 like the accuracy).
      println("Loss: %.5f".format(this.statistics.lastLoss))
    }
  }

  /**
   * Log when validation starts (only if [verbose]): start timing and print the number of validation examples.
   *
   * @param validationExamplesSize the number of validation examples
   */
  private fun logValidationStart(validationExamplesSize: Int) {

    if (this.verbose) {
      this.startTiming()
      println("Validate on $validationExamplesSize examples")
    }
  }

  /**
   * Log when validation ends (only if [verbose]): print the elapsed time and the accuracy as a percentage.
   */
  private fun logValidationEnd() {

    if (this.verbose) {
      println("Elapsed time: %s".format(this.formatElapsedTime()))
      println("Accuracy: %.2f%%".format(100.0 * this.statistics.lastAccuracy))
    }
  }

  /**
   * Start registering time.
   */
  private fun startTiming() {
    this.startTime = System.currentTimeMillis()
  }

  /**
   * @return the formatted string with the time elapsed since [startTiming], in seconds and minutes
   */
  private fun formatElapsedTime(): String {

    val elapsedTime = System.currentTimeMillis() - this.startTime
    val elapsedSecs = elapsedTime / 1000.0

    return "%.3f s (%.1f min)".format(elapsedSecs, elapsedSecs / 60.0)
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy