All downloads are free. Search and download functionality uses the official Maven repository.

com.mayabot.nlp.fasttext.loss.NegativeSamplingLoss.kt Maven / Gradle / Ivy

package com.mayabot.nlp.fasttext.loss

import com.mayabot.nlp.blas.Matrix
import com.mayabot.nlp.common.IntArrayList
import com.mayabot.nlp.fasttext.Model
import kotlin.random.Random


/**
 * Negative-sampling loss: each [forward] pass scores the true target as a positive example and
 * [neg] labels drawn from a precomputed sampling table as negative examples.
 *
 * The sampling table [negatives] holds label ids. Label `i` is written
 * `(sqrt(targetCounts[i]) * NEGATIVE_TABLE_SIZE / sum_j sqrt(targetCounts[j])).toInt()` times.
 * Drawing a uniformly random slot therefore picks labels with probability roughly proportional
 * to the square root of their counts. Because of the truncation, the table may hold slightly
 * fewer than [NEGATIVE_TABLE_SIZE] entries.
 *
 * @param wo output matrix, passed straight through to [BinaryLogisticLoss].
 * @param neg number of negative samples drawn per target in [forward].
 * @param targetCounts per-label occurrence counts; index `i` corresponds to label id `i`.
 */
class NegativeSamplingLoss(wo: Matrix, val neg: Int, targetCounts: LongArray) : BinaryLogisticLoss(wo) {

    /** Sampling table of label ids; see the class documentation for how it is filled. */
    val negatives = IntArrayList()

    /**
     * Returns a uniformly random slot index in `[0, negatives.size())`.
     *
     * This is the counterpart of C++ `std::uniform_int_distribution(0, size - 1)`, whose bounds
     * are inclusive.
     */
    val uniform: (random: Random) -> Int

    /**
     * Number of distinct label ids that received at least one slot in [negatives].
     *
     * When this is 1, no negative can ever be drawn for that single label. [getNegative] uses
     * this value to fail fast instead of looping forever.
     */
    private val distinctLabels: Int

    init {
        // Compute each weight once. Math.pow(x, 0.5) is kept (rather than Math.sqrt) and the
        // weights are summed in index order, so the table sizes match the previous code exactly.
        val weights = DoubleArray(targetCounts.size) { Math.pow(targetCounts[it].toDouble(), 0.5) }
        val z = weights.sum()

        var labelsInTable = 0
        for (label in weights.indices) {
            // A NaN or non-positive slot count (e.g. z == 0) truncates to <= 0 and adds nothing.
            val slots = (weights[label] * NEGATIVE_TABLE_SIZE / z).toInt()
            if (slots > 0) labelsInTable++
            repeat(slots) { negatives.add(label) }
        }
        distinctLabels = labelsInTable

        val ns = negatives.size()
        uniform = { random -> random.nextInt(ns) }
    }

    /**
     * Computes the loss for `targets[targetIndex]`.
     *
     * The result is the [binaryLogistic] loss of the target scored as a positive example, plus
     * the [binaryLogistic] losses of [neg] sampled negatives. All calls receive the same [lr] and
     * [backprop] arguments.
     */
    override fun forward(targets: IntArrayList, targetIndex: Int, state: Model.State, lr: Float, backprop: Boolean): Float {
        val target = targets[targetIndex]
        var loss = binaryLogistic(target, state, true, lr, backprop)
        repeat(neg) {
            val negativeTarget = getNegative(target, state.rng)
            loss += binaryLogistic(negativeTarget, state, false, lr, backprop)
        }
        return loss
    }

    /**
     * Draws a label id from [negatives] that differs from [target], using rejection sampling.
     *
     * @throws IllegalStateException if the sampling table is empty, or if [target] is the only
     *   label in it. In the second case the rejection loop could never terminate.
     */
    private fun getNegative(target: Int, rng: Random): Int {
        check(negatives.size() > 0) {
            "negative sampling table is empty: no label in targetCounts received a slot"
        }
        check(distinctLabels != 1 || negatives[0] != target) {
            "cannot draw a negative for label $target: it is the only label in the sampling table"
        }
        var negative = negatives[uniform(rng)]
        while (negative == target) {
            negative = negatives[uniform(rng)]
        }
        return negative
    }

    companion object {
        /** Target number of entries in the negative sampling table (before truncation). */
        const val NEGATIVE_TABLE_SIZE = 10000000
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy