com.mayabot.nlp.fasttext.Model.kt

package com.mayabot.nlp.fasttext

import com.mayabot.nlp.blas.DenseVector
import com.mayabot.nlp.blas.Matrix
import com.mayabot.nlp.common.IntArrayList
import com.mayabot.nlp.fasttext.loss.Loss
import kotlin.random.Random

// Note: the generic parameter was stripped by HTML extraction; ScoreIdPair
// (score plus label id), used elsewhere in this package, is the assumed element type.
typealias Predictions = MutableList<ScoreIdPair>

class Model(
        val wi: Matrix,
        val wo: Matrix,
        val loss: Loss,
        val normalizeGradient: Boolean
) {

    companion object {
        /** Pass as k to [predict] to return a score for every label. */
        const val kUnlimitedPredictions: Int = -1

        /** Pass as targetIndex to [update] to treat all labels as targets (one-vs-all). */
        const val kAllLabelsAsTarget: Int = -1
    }

    /**
     * input holds row ids into the input matrix [wi].
     */
    private fun computeHidden(input: IntArrayList, state: State) {
        val hidden = state.hidden
        hidden.zero()

        input.forEach { row ->
            wi.addRowToVector(hidden, row)
            //hidden += wi[row]
        }

        // average: normalize by the number of input rows
        hidden *= (1.0f / input.size())
    }
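
    // Illustration (added; not in the original source): for input = {3, 7},
    // computeHidden leaves hidden = (wi[3] + wi[7]) * 0.5f, the mean of the
    // selected embedding rows (fastText's averaged bag of features).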

    /**
     * Predict the top-k labels for a classification input.
     *
     * @param input row ids of the input features in [wi]
     * @param k number of predictions to return, or [kUnlimitedPredictions] for all labels
     * @param threshold minimum score a prediction must reach to be kept
     * @param heap receives the resulting (score, label id) pairs
     * @param state per-thread computation state
     */
    fun predict(input: IntArrayList,
                k: Int,
                threshold: Float,
                heap: Predictions,
                state: State
    ) {
        val kk = if (k == kUnlimitedPredictions) {
            // score every output row, i.e. every label
            wo.row
        } else {
            k
        }
        if (kk <= 0) {
            throw RuntimeException("k needs to be 1 or higher")
        }

        computeHidden(input, state)

        // use the resolved kk so kUnlimitedPredictions expands to all labels
        loss.predict(kk, threshold, heap, state)
    }
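
    // Usage note (added; hypothetical names): pass kUnlimitedPredictions to
    // score every label, e.g.
    //   model.predict(inputRows, Model.kUnlimitedPredictions, 0.0f, heap, state)
    // where `inputRows` and `heap` are supplied by the caller.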

    /**
     * One SGD step: run the loss forward/backward pass, then add the
     * hidden-layer gradient back onto every input embedding row.
     *
     * @param input row ids of the input features in [wi]
     * @param targets candidate target ids
     * @param targetIndex index into [targets] of the positive target,
     *        or [kAllLabelsAsTarget] for one-vs-all losses
     * @param lr learning rate for this step
     * @param state per-thread computation state
     */
    fun update(input: IntArrayList,
               targets: IntArrayList,
               targetIndex: Int,
               lr: Float,
               state: State) {
        if (input.size() == 0) {
            return
        }

        computeHidden(input, state)

        val grad = state.grad
        grad.zero()

        // forward pass; the trailing flag enables backprop into state.grad
        val lossValue = loss.forward(targets, targetIndex, state, lr, true)

        state.incrementNExamples(lossValue)

        if (normalizeGradient) {
            grad *= (1.0f / input.size())
        }

        input.forEach { i ->
            wi.addVectorToRow(grad, i, 1.0f)
        }
    }

    /**
     * Per-thread computation state: hidden, output, and gradient buffers,
     * a seeded RNG, and a running average of the training loss.
     */
    class State(hiddenSize: Int, outputSize: Int, seed: Int) {
        private var lossValue = 0.0f
        private var nexamples = 0

        val hidden = DenseVector(hiddenSize)
        val output = DenseVector(outputSize)
        val grad = DenseVector(hiddenSize)
        val rng = Random(seed)

        // running average loss per processed example (NaN before any example)
        val loss get() = lossValue / nexamples

        /** Accumulate one example's loss into the running average. */
        fun incrementNExamples(loss: Float) {
            lossValue += loss
            nexamples++
        }
    }

}
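
/*
 * End-to-end sketch (added for illustration; not part of the original source).
 * Assumes a trained `wi`/`wo` pair and a concrete Loss implementation; `wi.col`
 * as the hidden size and `inputRows` (an IntArrayList of feature row ids) are
 * assumptions, not part of this file.
 *
 *   val model = Model(wi, wo, loss, normalizeGradient = true)
 *   val state = Model.State(hiddenSize = wi.col, outputSize = wo.row, seed = 42)
 *   val predictions: Predictions = mutableListOf()
 *   model.predict(inputRows, k = 5, threshold = 0.0f, heap = predictions, state = state)
 *   predictions.forEach { println(it) }
 */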



