
com.mayabot.mynlp.fasttext.Model.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of fastText4j Show documentation
Show all versions of fastText4j Show documentation
Java (Kotlin) implementation of Facebook's FastText
The newest version!
package com.mayabot.mynlp.fasttext
import com.carrotsearch.hppc.IntArrayList
import java.util.*
/**
 * Common base for the training model and the inference model: both need [setTargetCounts].
 *
 * The main purpose of this class is to build the lookup structures used by
 * negative sampling (the [negatives] table) or by hierarchical softmax
 * ([paths], [codes] and [tree]).
 *
 * @param args_ model arguments; only `args_.loss` is read by this class.
 * @param randomSeed seed for [rng]; converted with `toLong()`.
 * @param outputMatrixSize number of output targets; [setTargetCounts] requires
 * exactly this many counts.
 * @author jimichan
 * @see Model
 * @see TrainModel
 */
open class BaseModel(
    @JvmField val args_: Args,
    randomSeed: Number,
    @JvmField var outputMatrixSize: Int
) {
    // used for negative sampling:

    /** Shuffled table of target indices; each index is repeated in proportion to sqrt(count). */
    @JvmField
    protected var negatives: IntArray = IntArray(0)

    /** Cursor into [negatives]; not advanced in this class (presumably by subclasses). */
    @JvmField
    protected var negpos: Int = 0

    // used for hierarchical softmax:

    /** Per target: internal-node indices (offset by -outputMatrixSize) from the leaf up to the root. */
    @JvmField
    protected var paths: MutableList<IntArray> = ArrayList()

    /** Per target: the `binary` flag of each node visited on the way from the leaf up to the root. */
    @JvmField
    protected var codes: MutableList<BooleanArray> = ArrayList()

    /** All tree nodes: leaves occupy `[0, outputMatrixSize)`, internal nodes follow. */
    @JvmField
    protected var tree: MutableList<Node> = ArrayList()

    /** Random generator seeded from `randomSeed`; used here to shuffle the negatives table. */
    @Transient
    @JvmField
    val rng: Random = Random(randomSeed.toLong())

    /**
     * Prepares the loss-specific structures from the per-target occurrence counts.
     *
     * Builds the negatives table for `LossName.ns`, the binary tree for `LossName.hs`,
     * and does nothing for any other loss.
     *
     * @param counts one count per output target; its size must equal [outputMatrixSize]
     * (enforced through `checkArgument`).
     */
    fun setTargetCounts(counts: LongArray) {
        checkArgument(counts.size == outputMatrixSize)
        if (args_.loss == LossName.ns) {
            initTableNegatives(counts)
        } else if (args_.loss == LossName.hs) {
            buildTree(counts)
        }
    }

    /**
     * Fills [negatives] so that target `i` appears roughly
     * `sqrt(counts[i]) * NEGATIVE_TABLE_SIZE / sum(sqrt(counts))` times, then shuffles it with [rng].
     */
    private fun initTableNegatives(counts: LongArray) {
        val negatives_ = IntArrayList(counts.size)
        // Normalizer: sum of the square roots of all counts.
        val z = counts.map { sqrt(it) }.sum()
        val size = counts.size
        val xxn = NEGATIVE_TABLE_SIZE / z
        for (i in 0 until size) {
            val c = sqrt(counts[i])
            var j = 0
            while (j < c * xxn) {
                negatives_.add(i)
                j++
            }
        }
        negatives = negatives_.toArray()
        shuffle(negatives, rng)
    }

    /**
     * Builds the binary tree for hierarchical softmax and derives, for every target,
     * its path of internal nodes and the matching binary codes.
     *
     * NOTE(review): the two-cursor merge below walks the leaves from the last index
     * downwards, so it presumably expects `counts` sorted in descending order — confirm
     * against the caller.
     */
    private fun buildTree(counts: LongArray) {
        val nodeCount = 2 * outputMatrixSize - 1
        val pathsLocal = ArrayList<IntArray>(outputMatrixSize)
        val codesLocal = ArrayList<BooleanArray>(outputMatrixSize)
        val treeLocal = ArrayList<Node>(nodeCount)

        // Every node starts detached, with a huge sentinel count so that
        // not-yet-built internal nodes never win the "smaller count" comparison.
        for (i in 0 until nodeCount) {
            treeLocal.add(Node().apply {
                this.parent = -1
                this.left = -1
                this.right = -1
                this.count = 1000000000000000L // 1e15f;
                this.binary = false
            })
        }
        for (i in 0 until outputMatrixSize) {
            treeLocal[i].count = counts[i]
        }

        // leaf: next unmerged leaf (moving down); node: next unmerged internal node (moving up).
        var leaf = outputMatrixSize - 1
        var node = outputMatrixSize
        for (i in outputMatrixSize until nodeCount) {
            // Pick the two nodes with the smallest counts among the two cursors.
            val mini = IntArray(2)
            for (j in 0..1) {
                if (leaf >= 0 && treeLocal[leaf].count < treeLocal[node].count) {
                    mini[j] = leaf--
                } else {
                    mini[j] = node++
                }
            }
            treeLocal[i].apply {
                this.left = mini[0]
                this.right = mini[1]
                this.count = treeLocal[mini[0]].count + treeLocal[mini[1]].count
            }
            treeLocal[mini[0]].parent = i
            treeLocal[mini[1]].parent = i
            // The right child carries code `true`; the left child keeps the default `false`.
            treeLocal[mini[1]].binary = true
        }

        // Walk from each leaf up to the root, recording the internal nodes and codes met on the way.
        for (i in 0 until outputMatrixSize) {
            val path = ArrayList<Int>()
            val code = ArrayList<Boolean>()
            var j = i
            while (treeLocal[j].parent != -1) {
                // Internal nodes are re-based to start at 0.
                path.add(treeLocal[j].parent - outputMatrixSize)
                code.add(treeLocal[j].binary)
                j = treeLocal[j].parent
            }
            pathsLocal.add(path.toIntArray())
            codesLocal.add(code.toBooleanArray())
        }

        this.paths = pathsLocal
        this.codes = codesLocal
        this.tree = treeLocal
    }

    companion object {
        /** Precomputed sigmoid values sampled uniformly over `[-MAX_SIGMOID, MAX_SIGMOID]`. */
        private val tSigmoid: FloatArray = FloatArray(SIGMOID_TABLE_SIZE + 1) { i ->
            val x = (i * 2 * MAX_SIGMOID).toFloat() / SIGMOID_TABLE_SIZE - MAX_SIGMOID
            (1.0f / (1.0f + Math.exp((-x).toDouble()))).toFloat()
        }

        /** Precomputed natural logarithms sampled over `(0, 1]`; the 1e-5 offset avoids `log(0)`. */
        private val tLog: FloatArray = FloatArray(LOG_TABLE_SIZE + 1) { i ->
            val x = (i.toFloat() + 1e-5f) / LOG_TABLE_SIZE
            Math.log(x.toDouble()).toFloat()
        }

        /**
         * Table-based approximation of the natural logarithm.
         *
         * @param x input value; anything greater than 1 yields `0.0f`. The value is not
         * checked against a lower bound, so callers are expected to pass `x >= 0`.
         * @return the tabulated logarithm at index `(x * LOG_TABLE_SIZE).toInt()`.
         */
        fun log(x: Float): Float {
            if (x > 1.0f) {
                return 0.0f
            }
            val i = (x * LOG_TABLE_SIZE).toInt()
            return tLog[i]
        }

        /**
         * Table-based approximation of the logistic sigmoid.
         *
         * @param x input value.
         * @return `0.0f` below `-MAX_SIGMOID`, `1.0f` above `MAX_SIGMOID`,
         * otherwise the tabulated value for `x`.
         */
        fun sigmoid(x: Float): Float {
            return when {
                x < -MAX_SIGMOID -> 0.0f
                x > MAX_SIGMOID -> 1.0f
                else -> {
                    val i = ((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID.toFloat() / 2f).toInt()
                    tSigmoid[i]
                }
            }
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy