All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.mayabot.mynlp.fasttext.Model.kt Maven / Gradle / Ivy

package com.mayabot.mynlp.fasttext

import com.carrotsearch.hppc.IntArrayList
import java.util.*


/**
 * Common base of the training model and the inference model; both need [setTargetCounts].
 *
 * Its main purpose is to build the auxiliary structure required by the configured loss:
 * the negative-sampling table or the hierarchical-softmax tree.
 *
 * @property args_ model configuration; within this class only `args_.loss` is read
 * @param randomSeed seed for [rng], converted with `toLong()`
 * @property outputMatrixSize number of output targets; [setTargetCounts] requires exactly
 * this many counts
 * @author jimichan
 * @see Model
 * @see TrainModel
 */
open class BaseModel(
        @JvmField val args_: Args,
        randomSeed: Number,
        @JvmField var outputMatrixSize: Int) {

    // Negative sampling: target indices, each repeated in proportion to the square root
    // of its count, then shuffled with [rng]. Filled by initTableNegatives.
    @JvmField
    protected var negatives: IntArray = IntArray(0)

    // Position within [negatives]; never modified in this class
    // (presumably advanced by subclasses while drawing samples -- confirm there).
    @JvmField
    protected var negpos: Int = 0

    // Hierarchical softmax, all filled by buildTree:
    // paths[t] -- for target t, the parent indices met while walking up to the root,
    //             each offset by -outputMatrixSize (i.e. relative to the first internal node).
    @JvmField
    protected var paths: MutableList<IntArray> = ArrayList()
    // codes[t] -- for target t, the `binary` flag of each node visited on that same walk.
    @JvmField
    protected var codes: MutableList<BooleanArray> = ArrayList()
    // tree -- leaves at [0, outputMatrixSize), internal nodes after them.
    @JvmField
    protected var tree: MutableList<Node> = ArrayList()


    @Transient
    @JvmField
    val rng: Random = Random(randomSeed.toLong())

    /**
     * Prepares the loss-specific structure from per-target occurrence counts.
     *
     * Builds the negative-sampling table when the loss is `LossName.ns`, the
     * hierarchical-softmax tree when it is `LossName.hs`, and does nothing otherwise.
     *
     * @param counts one count per output target; its size is checked against
     * [outputMatrixSize] via `checkArgument`
     */
    fun setTargetCounts(counts: LongArray) {
        checkArgument(counts.size == outputMatrixSize)
        when (args_.loss) {
            LossName.ns -> initTableNegatives(counts)
            LossName.hs -> buildTree(counts)
            else -> Unit // any other loss needs no auxiliary structure here
        }
    }

    /**
     * Fills [negatives] so that target `i` appears roughly
     * `sqrt(counts[i]) / sum(sqrt(counts)) * NEGATIVE_TABLE_SIZE` times, then shuffles it.
     */
    private fun initTableNegatives(counts: LongArray) {
        val table = IntArrayList(counts.size)

        // Normaliser: sum of the square roots of all counts.
        val z = counts.map { sqrt(it) }.sum()
        val scale = NEGATIVE_TABLE_SIZE / z

        for (i in counts.indices) {
            // Loop-invariant upper bound on the number of copies of target i.
            val limit = sqrt(counts[i]) * scale
            var j = 0
            while (j < limit) {
                table.add(i)
                j++
            }
        }
        negatives = table.toArray()
        shuffle(negatives, rng)
    }

    /**
     * Builds the binary tree used by hierarchical softmax and derives, for every target,
     * its path of internal nodes ([paths]) and branch flags ([codes]).
     *
     * Leaves occupy indices `[0, outputMatrixSize)` and carry the given counts; each internal
     * node is created by merging the two lowest-count candidates available at that point.
     *
     * NOTE(review): the two-cursor merge below walks the leaves from the last index downwards,
     * which appears to assume `counts` is sorted in descending order -- confirm with callers.
     */
    private fun buildTree(counts: LongArray) {
        // n leaves + (n - 1) internal nodes; clamped so that n == 0 yields empty structures
        // instead of a negative ArrayList capacity.
        val nodeCount = (2 * outputMatrixSize - 1).coerceAtLeast(0)

        val pathsLocal = ArrayList<IntArray>(outputMatrixSize)
        val codesLocal = ArrayList<BooleanArray>(outputMatrixSize)
        val treeLocal = ArrayList<Node>(nodeCount)

        for (i in 0 until nodeCount) {
            treeLocal.add(Node().apply {
                this.parent = -1
                this.left = -1
                this.right = -1
                // Very large placeholder so an internal node that has not been built yet
                // never compares lower than a real leaf.
                this.count = 1000000000000000L// 1e15f;
                this.binary = false
            })
        }

        for (i in 0 until outputMatrixSize) {
            treeLocal[i].count = counts[i]
        }

        // `leaf` scans the leaves downwards, `node` scans already-built internal nodes upwards.
        var leaf = outputMatrixSize - 1
        var node = outputMatrixSize
        for (i in outputMatrixSize until nodeCount) {
            // Pick the two children of internal node i.
            val mini = IntArray(2)
            for (j in 0..1) {
                if (leaf >= 0 && treeLocal[leaf].count < treeLocal[node].count) {
                    mini[j] = leaf--
                } else {
                    mini[j] = node++
                }
            }
            treeLocal[i].apply {
                this.left = mini[0]
                this.right = mini[1]
                this.count = treeLocal[mini[0]].count + treeLocal[mini[1]].count
            }
            treeLocal[mini[0]].parent = i
            treeLocal[mini[1]].parent = i
            // The right child is the one flagged `binary = true`.
            treeLocal[mini[1]].binary = true
        }

        // Walk from every leaf up to the root, recording the parents and branch flags.
        for (i in 0 until outputMatrixSize) {
            val path = ArrayList<Int>()
            val code = ArrayList<Boolean>()

            var j = i
            while (treeLocal[j].parent != -1) {
                path.add(treeLocal[j].parent - outputMatrixSize)
                code.add(treeLocal[j].binary)
                j = treeLocal[j].parent
            }
            pathsLocal.add(path.toIntArray())
            codesLocal.add(code.toBooleanArray())
        }

        this.paths = pathsLocal
        this.codes = codesLocal
        this.tree = treeLocal
    }

    companion object {
        // Sigmoid sampled at SIGMOID_TABLE_SIZE + 1 evenly spaced points
        // covering [-MAX_SIGMOID, MAX_SIGMOID].
        private val tSigmoid: FloatArray = FloatArray(SIGMOID_TABLE_SIZE + 1) { i ->
            val x = (i * 2 * MAX_SIGMOID).toFloat() / SIGMOID_TABLE_SIZE - MAX_SIGMOID
            (1.0f / (1.0f + Math.exp((-x).toDouble()))).toFloat()
        }

        // Natural log sampled at (i + 1e-5) / LOG_TABLE_SIZE for i in 0..LOG_TABLE_SIZE;
        // the small offset keeps entry 0 finite.
        private val tLog: FloatArray = FloatArray(LOG_TABLE_SIZE + 1) { i ->
            val x = (i.toFloat() + 1e-5f) / LOG_TABLE_SIZE
            Math.log(x.toDouble()).toFloat()
        }

        /**
         * Table-based approximation of the natural logarithm.
         *
         * @param x input value; anything above `1.0f` returns `0.0f`. No lower-bound check
         * is made, so a value that maps to a negative table index will fail on the array access.
         * @return the precomputed entry at index `(x * LOG_TABLE_SIZE).toInt()`
         */
        fun log(x: Float): Float {
            if (x > 1.0f) {
                return 0.0f
            }
            val i = (x * LOG_TABLE_SIZE).toInt()
            return tLog[i]
        }

        /**
         * Table-based approximation of the logistic sigmoid.
         *
         * @param x input value
         * @return `0.0f` below `-MAX_SIGMOID`, `1.0f` above `MAX_SIGMOID`, otherwise the
         * precomputed table entry for `x`
         */
        fun sigmoid(x: Float): Float {
            return when {
                x < -MAX_SIGMOID -> 0.0f
                x > MAX_SIGMOID -> 1.0f
                else -> {
                    val i = ((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID.toFloat() / 2f).toInt()
                    tSigmoid[i]
                }
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy