com.kotlinnlp.simplednn.deeplearning.transformers.BERTModel.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of simplednn Show documentation
Show all versions of simplednn Show documentation
SimpleDNN is a machine learning lightweight open-source library written in Kotlin whose purpose is to
support the development of feed-forward and recurrent Artificial Neural Networks.
The newest version!
/* Copyright 2020-present Simone Cangialosi. All Rights Reserved.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, you can obtain one at http://mozilla.org/MPL/2.0/.
* ------------------------------------------------------------------*/
package com.kotlinnlp.simplednn.deeplearning.transformers
import com.kotlinnlp.simplednn.core.embeddings.EmbeddingsMap
import com.kotlinnlp.simplednn.core.functionalities.activations.GeLU
import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax
import com.kotlinnlp.simplednn.core.functionalities.initializers.GlorotInitializer
import com.kotlinnlp.simplednn.core.functionalities.initializers.Initializer
import com.kotlinnlp.simplednn.core.layers.LayerInterface
import com.kotlinnlp.simplednn.core.layers.LayerType
import com.kotlinnlp.simplednn.core.layers.StackedLayersParameters
import com.kotlinnlp.utils.DictionarySet
import com.kotlinnlp.utils.Serializer
import com.kotlinnlp.utils.removeFrom
import java.io.InputStream
import java.io.OutputStream
import java.io.Serializable
/**
* The BERT model.
*
* @property inputSize the size of the input arrays
* @property attentionSize the size of the attention arrays
* @property attentionOutputSize the size of the attention outputs
* @property outputHiddenSize the number of the hidden nodes of the output feed-forward
* @property numOfHeads the number of self-attention heads
* @property vocabulary the vocabulary with all the well-known forms of the model (forms not present in it are treated
* as unknown)
* @param wordEmbeddings pre-trained word embeddings or null to generate them randomly using the [vocabulary]
* @param numOfLayers the number of stacked layers
* @param weightsInitializer the initializer of the weights (zeros if null, default: Glorot)
* @param biasesInitializer the initializer of the biases (zeros if null, default: Glorot)
*/
class BERTModel(
val inputSize: Int,
val attentionSize: Int,
val attentionOutputSize: Int,
val outputHiddenSize: Int,
val numOfHeads: Int,
val vocabulary: DictionarySet,
wordEmbeddings: EmbeddingsMap? = null,
numOfLayers: Int,
weightsInitializer: Initializer? = GlorotInitializer(),
biasesInitializer: Initializer? = GlorotInitializer()
) : Serializable {
/**
* Functional token.
*/
enum class FuncToken(val form: String) {
CLS("[CLS]"),
SEP("[SEP]"),
PAD("[PAD]"),
UNK("[UNK]"),
MASK("[MASK]");
companion object {
/**
* The [FuncToken] associated by form.
*/
private val tokensByForm: Map = values().associateBy { it.form }
/**
* @param form a token form
*
* @return the [FuncToken] with the given form
*/
fun byForm(form: String) = tokensByForm.getValue(form)
}
}
companion object {
/**
* Private val used to serialize the class (needed by Serializable).
*/
@Suppress("unused")
private const val serialVersionUID: Long = 1L
/**
* Read [BERTModel] (serialized) from an input stream and decode it.
*
* @param inputStream the [InputStream] from which to read the serialized [BERTModel]
*
* @return the [BERTModel] read from [inputStream] and decoded
*/
fun load(inputStream: InputStream): BERTModel = Serializer.deserialize(inputStream)
}
/**
* The size of the output arrays (equal to the input).
*/
val outputSize: Int = this.inputSize
/**
* The initial parameters of the stacked BERT layers.
* They can be reduced with the method [reduceLayersTo].
*/
private val initLayers: MutableList = MutableList(
size = numOfLayers,
init = {
BERTParameters(
inputSize = this.inputSize,
attentionSize = this.attentionSize,
attentionOutputSize = this.attentionOutputSize,
outputHiddenSize = this.outputHiddenSize,
numOfHeads = this.numOfHeads,
weightsInitializer = weightsInitializer,
biasesInitializer = biasesInitializer)
}
)
/**
* The parameters of the stacked BERT layers.
*/
var layers: List = this.initLayers.toList()
private set
/**
* The parameters of the embeddings norm layer.
*/
val embNorm = StackedLayersParameters(
LayerInterface(size = this.inputSize),
LayerInterface(size = this.inputSize, connectionType = LayerType.Connection.Norm),
weightsInitializer = weightsInitializer,
biasesInitializer = biasesInitializer)
/**
* The word embeddings.
* If not trained, they can be set to null before a model serialization and re-set after deserialization, in order to
* make the model lighter.
*/
var wordEmb: EmbeddingsMap? = wordEmbeddings ?: EmbeddingsMap(this.inputSize).apply {
vocabulary.getElements().forEach { set(it) }
}
/**
* The functional embeddings associated to the [FuncToken].
*/
var funcEmb: EmbeddingsMap = EmbeddingsMap(this.inputSize).apply {
FuncToken.values().forEach {
set(key = it, embedding = wordEmb!!.getOrNull(it.form))
}
}
/**
* The positional embeddings.
*/
val positionalEmb: EmbeddingsMap = EmbeddingsMap(this.inputSize)
/**
* The token type embeddings.
*/
val tokenTypeEmb: EmbeddingsMap = EmbeddingsMap(this.inputSize).apply {
set(0)
set(1)
}
/**
* The model of the classifier used to train the model.
*/
var classifier: StackedLayersParameters = StackedLayersParameters(
LayerInterface(size = inputSize),
LayerInterface(size = inputSize, connectionType = LayerType.Connection.Feedforward, activationFunction = GeLU),
LayerInterface(size = inputSize, connectionType = LayerType.Connection.Norm),
LayerInterface(
size = vocabulary.size, connectionType = LayerType.Connection.Feedforward, activationFunction = Softmax())
)
/**
* Serialize this [BERTModel] and write it to an output stream.
*
* @param outputStream the [OutputStream] in which to write this serialized [BERTModel]
*/
fun dump(outputStream: OutputStream) = Serializer.serialize(this, outputStream)
/**
* Reduce the layers of this model to a given size, starting from the last.
*
* @param size the new number of [BERT] layers
*/
fun reduceLayersTo(size: Int) {
require(size < this.layers.size) {
"The reducing size ($size) must be lower than the current layer size (${this.layers.size})"
}
this.initLayers.removeFrom(size)
this.layers = this.initLayers.toList()
}
}