All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mayabot.mynlp.fasttext.FastTextTrain.kt Maven / Gradle / Ivy

The newest version!
package com.mayabot.mynlp.fasttext

import com.carrotsearch.hppc.IntArrayList
import com.google.common.base.CharMatcher
import com.google.common.base.Charsets
import com.google.common.base.Splitter
import com.google.common.collect.Lists
import com.google.common.io.Files
import com.google.common.primitives.Ints
import com.google.common.util.concurrent.AtomicDouble
import com.mayabot.blas.*
import java.io.*
import java.util.concurrent.atomic.AtomicLong

const val SIGMOID_TABLE_SIZE = 512
const val MAX_SIGMOID = 8
const val LOG_TABLE_SIZE = 512
const val NEGATIVE_TABLE_SIZE = 10000000



class FastTextTrain {

    lateinit var args:Args
    lateinit var dict: Dictionary

    private var tokenCount = AtomicLong(0)
    private var loss = AtomicDouble(-1.0)

    private var startTime = 0L

    var source: TrainExampleSource? = null
    var input: MutableFloatMatrix? = null
    var output: MutableFloatMatrix? = null



    fun train(file: File,modelName: ModelName, trainArgs: TrainArgs): FastText {
        return this.train(FileTrainExampleSource(whitespaceSplitter, file), modelName, trainArgs)
    }

    fun train(file: TrainExampleSource,modelName: ModelName, trainArgs: TrainArgs): FastText {
        args = Args().apply {

            bucket = trainArgs.bucket ?: bucket
            minCount  = trainArgs.minCount ?: minCount
            minCountLabel = trainArgs.minCountLabel ?: minCountLabel
            wordNgrams = trainArgs.wordNgrams ?: wordNgrams
            minn = trainArgs.minn ?: minn
            maxn  = trainArgs.maxn ?: maxn
            t  = trainArgs.t ?: t

            model = modelName
            if (modelName == ModelName.sup) {
                minCount = 1
                loss = LossName.softmax
                minCount = 1
                minn = 0
                maxn = 0
                lr = 0.1
            }

            thread = trainArgs.thread ?: thread
            dim = trainArgs.dim ?: dim
            epoch = trainArgs.epoch ?: epoch
            loss = trainArgs.loss ?: loss
            lr = trainArgs.lr ?: lr
            lrUpdateRate = trainArgs.lrUpdateRate ?: lrUpdateRate
            neg = trainArgs.neg ?: neg
            ws = trainArgs.ws ?: ws

            // -wordNgrams         max length of word ngram [1]
            // -maxn               max length of char ngram [0]

            if (wordNgrams <= 1 && maxn == 0) {
                bucket = 0
            }
        }

        dict = com.mayabot.mynlp.fasttext.Dictionary(args)
        dict.buildFromFile(file)

        val output: MutableFloatMatrix = if (ModelName.sup == args.model) {//分类模型
            FloatMatrix.floatArrayMatrix(dict.nlabels(), args.dim)
        } else {
            FloatMatrix.floatArrayMatrix(dict.nwords(), args.dim)
        }

        val input: MutableFloatMatrix

        val pretrainedVectors: File? = if (!trainArgs.pretrainedVectors.isNullOrBlank()) {
            val filePre = File(trainArgs.pretrainedVectors)
            if (filePre.exists() && filePre.canRead()) {
                filePre
            } else {
                throw RuntimeException("Not found File " + trainArgs.pretrainedVectors)
            }
        } else {
            null
        }

        if (pretrainedVectors != null) {
            input = loadVectors(pretrainedVectors)
        } else {
            input = FloatMatrix.floatArrayMatrix(dict.nwords() + args.bucket, args.dim)
            input.uniform(1.0f / args.dim)
        }

        output.fill(0f)

        this.source = file
        this.input = input
        this.output = output

        startThreads()

        val fastText = FastText(args,dict,  Model(input, output, args, 0).apply {
            if (args.model == ModelName.sup) {
                this.setTargetCounts(dict.getCounts(EntryType.label))
            } else {
                this.setTargetCounts(dict.getCounts(EntryType.word))
            }
        })


        println("Train use time ${System.currentTimeMillis() - startTime} ms")
        return fastText
    }

    @Throws(Exception::class)
    private fun startThreads() {
        startTime = System.currentTimeMillis()

        tokenCount = AtomicLong(0)
        loss = AtomicDouble(-1.0)

        val sourceParts = source!!.split(args.thread)

        val threads = Lists.newArrayList()
        for (i in 0 until args.thread) {
            threads.add(Thread(TrainThread(i,sourceParts[i])))
        }

        for (i in 0 until args.thread) {
            threads[i].start()
        }

        val ntokens = dict.ntokens()
        // Same condition as trainThread
        while (tokenCount.toLong() < args.epoch * ntokens) {

            Thread.sleep(100)
            if (loss.toFloat() >= 0 && args.verbose > 1) {
                val progress = tokenCount.toFloat() / (args.epoch * ntokens)

                print("\r")
                printInfo(progress, loss)
            }
        }

        for (i in 0 until args.thread) {
            threads[i].join()
        }

        if (args.verbose > 0) {
            print("\r")
            printInfo(1.0f, loss)
            println()
        }

        source?.close()
    }

    internal inner class TrainThread(
            private val threadId: Int,
            private val parts: TrainExampleSource

    ) : Runnable {

        override fun run() {
            try {
                LoopReader(parts).use { loopReader ->

                    val model = TrainModel(input!!, output!!, args, threadId)
                    val rng = model.rng
                    // setTargetCounts 相当耗时

                    if (args.model == ModelName.sup) {
                        model.setTargetCounts(dict.getCounts(EntryType.label))
                    } else {
                        model.setTargetCounts(dict.getCounts(EntryType.word))
                    }

                    val ntokens = dict.ntokens() //文件中词语的总数量(非排重)
                    var localTokenCount: Long = 0
                    val up_ = args.epoch * ntokens

                    val line = IntArrayList()
                    val labels = IntArrayList()

                    if (args.model == ModelName.sup) {
                        while (tokenCount.toLong() < up_) {
                            val progress = tokenCount.toFloat() / up_ //总的进度
                            val lr = args.lr.toFloat() * (1.0f - progress) //学习率自动放缓

                            val tokens = loopReader.readLineTokens()

                            localTokenCount += dict.getLine(tokens, line, labels).toLong()
                            supervised(model, lr, line, labels)

                            if (localTokenCount > args.lrUpdateRate) {
                                tokenCount.addAndGet(localTokenCount)
                                localTokenCount = 0
                                if (threadId == 0) {
                                    loss.set(model.getLoss().toDouble())
                                }
                            }
                        }
                    }
                    if (args.model == ModelName.cbow) {
                        while (tokenCount.toLong() < up_) {
                            val progress = tokenCount.toFloat() / up_ //总的进度
                            val lr = args.lr.toFloat() * (1.0f - progress) //学习率自动放缓

                            val tokens = loopReader.readLineTokens()

                            localTokenCount += dict.getLine(tokens, line, rng).toLong()
                            cbow(model, lr, line)

                            if (localTokenCount > args.lrUpdateRate) {
                                tokenCount.addAndGet(localTokenCount)
                                localTokenCount = 0
                                if (threadId == 0) {
                                    loss.set(model.getLoss().toDouble())
                                }
                            }
                        }
                    }
                    if (args.model == ModelName.sg) {
                        while (tokenCount.toLong() < up_) {
                            val progress = tokenCount.toFloat() / up_ //总的进度
                            val lr = args.lr.toFloat() * (1.0f - progress) //学习率自动放缓

                            val tokens = loopReader.readLineTokens()

                            localTokenCount += dict.getLine(tokens, line, rng).toLong()
                            skipgram(model, lr, line)

                            if (localTokenCount > args.lrUpdateRate) {
                                tokenCount.addAndGet(localTokenCount)
                                localTokenCount = 0
                                if (threadId == 0) {
                                    loss.set(model.getLoss().toDouble())
                                }
                            }
                        }
                    }

                    if (threadId == 0) {
                        loss.set(model.getLoss().toDouble())
                    }

                }
            } catch (e: Exception) {
                throw RuntimeException(e)
            }

        }


        internal fun supervised(
                model: TrainModel,
                lr: Float,
                line: IntArrayList,
                labels: IntArrayList) {
            if (labels.size() == 0 || line.size() == 0) {
                return
            }
            val i = if (labels.size() == 1) 0 else model.rng.nextInt(labels.size())
            model.update(line, labels.get(i), lr)
        }


        private fun cbow(model: TrainModel, lr: Float,
                         line: IntArrayList) {
            val bow = IntArrayList()

            // std::uniform_int_distribution<> uniform(1, args_->ws);
            for (w in 0 until line.size()) {
                val boundary = model.rng.nextInt(args.ws) + 1 // 1~5
                bow.clear()
                for (c in -boundary..boundary) {
                    if (c != 0 && w + c >= 0 && w + c < line.size()) {
                        val ngrams = dict.getSubwords(line.get(w + c))
                        bow.addAll(ngrams)
                    }
                }
                model.update(bow, line.get(w), lr)
            }
        }

        private fun skipgram(model: TrainModel, lr: Float,
                             line: IntArrayList) {
            for (w in 0 until line.size()) {
                val boundary = model.rng.nextInt(args.ws) + 1 // 1~5

                val ngrams = dict.getSubwords(line.get(w))
                for (c in -boundary..boundary) {
                    if (c != 0 && w + c >= 0 && w + c < line.size()) {
                        model.update(ngrams, line.get(w + c), lr)
                    }
                }
            }
        }

    }

    private fun printInfo(progress: Float, loss: AtomicDouble) {
        var progress = progress
        // clock_t might also only be 32bits wide on some systems
        val t = ((System.currentTimeMillis() - startTime) / 1000).toDouble()
        val lr = args.lr * (1.0 - progress)
        var wst = 0.0
        var eta = (720 * 3600).toLong() // Default to one month
        if (progress > 0 && t >= 0) {
            eta = (t / progress * (1 - progress) / args.thread).toInt().toLong()
            wst = tokenCount.toFloat() / t
        }
        val etah = eta / 3600
        val etam = eta % 3600 / 60
        val etas = eta % 3600 % 60
        progress *= 100
        val sb = StringBuilder()
        sb.append("Progress: " +
                String.format("%2.2f", progress) + "% words/sec/thread: " + String.format("%8.0f", wst))
        sb.append(String.format(" lr: %2.5f", lr))
        sb.append(String.format(" loss: %2.5f", loss.toFloat()))
        sb.append(" ETA: " + etah + "h " + etam + "m " + etas + "s")

        print(sb)
    }


    @Throws(Exception::class)
    private fun loadVectors(filename: File): MutableFloatMatrix {

        var n: Int = 0
        var dim: Int = 0

        val firstLine = filename.firstLine()!!
        run {
            val strings = Splitter.on(CharMatcher.whitespace()).splitToList(firstLine)
            n = Ints.tryParse(strings[0])!!
            dim = Ints.tryParse(strings[1])!!
        }
        if (n == 0 || dim == 0) {
            throw Exception("Error format for " + filename.name + ",First line must be rows and dim arg")
        }
        if (dim != args.dim) {
            throw Exception("Dimension of pretrained vectors " + dim + " does not match dimension (" + args.dim + ")")
        }

        val mat = MutableByteBufferMatrix(n, dim)

        val sp = Splitter.on(" ").omitEmptyStrings()

        val words = Lists.newArrayListWithExpectedSize(n)
        val charSource = Files.asCharSource(filename, Charsets.UTF_8)
        charSource.openBufferedStream().use { reader ->
            reader.readLine()//first line
            for (i in 0 until n) {
                val line = reader.readLine()
                var parts: MutableList = sp.splitToList(line)
                if (parts.size != dim + 1) {
                    if (parts.size == dim) {
                        parts = Lists.newArrayList(line.substring(0, line.indexOf(parts[0]) - 1))
                        parts.addAll(sp.splitToList(line))
                    } else {
                        throw RuntimeException("line $line parse error")
                    }

                }

                val word = parts[0]
                dict.add(word)
                words.add(word)
                val row = mat[i]
                var x = 0
                for (j in 1..dim) {
                    row[x++] = parts[j].toFloat()
                }
            }
        }

        dict.threshold(1, 0)
        val input = FloatMatrix.floatArrayMatrix(dict.nwords() + args.bucket, args.dim)
        input.uniform(1.0f / args.dim)

        for (i in 0 until n) {
            val idx = dict.getId(words[i])
            if (idx < 0 || idx > dict.nwords()) {
                continue
            }

            input[idx](mat[i])

//            System.arraycopy(matrixData, i * dim, input.getData(), idx * dim, dim)
            //            for (int j = 0; j < dim; j++) {
            //                input.set(idx, j, mat.get(i, j));
            //            }

        }

        return input

    }

}


class TrainModel(
        private val inputMatrix: MutableFloatMatrix // input
        , private val outputMatrix: MutableFloatMatrix // output
        , args_: Args, seed: Int
) : BaseModel(args_, seed, outputMatrix.rows()) {

    private val hidden = Vector.floatArrayVector(args_.dim)
    private val output = Vector.floatArrayVector(outputMatrix.rows())
    private val grad = Vector.floatArrayVector(args_.dim)

    private val hsz: Int = args_.dim // dim
    //    private val isz: Int = input.rows()// input vocabSize
    private var loss = 0f
    private var nexamples = 1L

    fun getLoss(): Float {
        return loss / nexamples
    }


    private fun LongArray.sqrtSum() : Long{
        var t = 0L
        for (i in this) {
            t += i*i
        }
        return t
    }


    fun update(input: IntArrayList, target: Int, lr: Float) {
        checkArgument(target >= 0)
        checkArgument(target < outputMatrixSize)
        if (input.size() == 0) {
            return
        }
        computeHidden(input, hidden)

        loss += when (args_.loss) {
            LossName.ns -> negativeSampling(target, lr)
            LossName.hs -> hierarchicalSoftmax(target, lr)
            LossName.softmax -> softmax(target, lr)
        }
        nexamples += 1

        if (args_.model == ModelName.sup) {
            grad *= (1.0f / input.size())
        }

        val buffer = input.buffer
        var i = 0
        val size = input.size()
        while (i < size) {
            val it = buffer[i]
            inputMatrix[it] += grad
            i++
        }
    }

    private fun computeHidden(input: IntArrayList, hidden: MutableVector) {
        checkArgument(hidden.length() == hsz)
        hidden.zero()

        val buffer = input.buffer
        var i = 0
        val size = input.size()
        while (i < size) {
            val it = buffer[i]
            hidden += inputMatrix[it]
            i++
        }
        hidden *= (1.0f / input.size())
    }

    private fun negativeSampling(target: Int, lr: Float): Float {
        var loss = 0.0f
        grad.zero()
        for (n in 0..args_.neg) {
            loss += if (n == 0) {
                binaryLogistic(target, true, lr)
            } else {
                binaryLogistic(getNegative(target), false, lr)
            }
        }
        return loss
    }

    private fun binaryLogistic(target: Int, label: Boolean, lr: Float): Float {
        val score = sigmoid(outputMatrix[target] * hidden)
        val alpha = lr * ((if (label) 1.0f else 0.0f) - score)
        grad += alpha to outputMatrix[target]

        outputMatrix[target] += alpha to hidden

        return if (label) {
            -log(score)
        } else {
            -log(1.0f - score)
        }
    }

    private fun getNegative(target: Int): Int {
        var negative: Int
        do {
            negative = negatives[negpos]
            negpos = (negpos + 1) % negatives.size
        } while (target == negative)
        return negative
    }


    private fun hierarchicalSoftmax(target: Int, lr: Float): Float {
        var loss = 0.0f
        grad.zero()
        val binaryCode = codes[target]
        val pathToRoot = paths[target]
        for (i in 0 until pathToRoot.size) {
            loss += binaryLogistic(pathToRoot[i], binaryCode[i], lr)
        }
        return loss
    }

    private fun softmax(target: Int, lr: Float): Float {
        grad.zero()
        computeOutputSoftmax()
        for (i in 0 until outputMatrixSize) {
            val label = if (i == target) 1.0f else 0.0f
            val alpha = lr * (label - output[i])
            grad += alpha to outputMatrix[i]
            outputMatrix[i] += alpha to hidden
        }
        return -log(output[target])
    }

    @JvmOverloads
    private fun computeOutputSoftmax(hidden: Vector = this.hidden, output: MutableVector = this.output) {
        matrixMulVector(outputMatrix, hidden, output)
        var max = output[0]
        var z = 0.0f
        for (i in 1 until outputMatrixSize) {
            max = Math.max(output[i], max)
        }
        for (i in 0 until outputMatrixSize) {
            output[i] = Math.exp((output[i] - max).toDouble()).toFloat()
            z += output[i]
        }
        for (i in 0 until outputMatrixSize) {
            output[i] = output[i] / z
        }
    }

}

class Node {
    @JvmField
    var parent: Int = 0
    @JvmField
    var left: Int = 0
    @JvmField
    var right: Int = 0
    @JvmField
    var count: Long = 0
    @JvmField
    var binary: Boolean = false
}

class LoopReader @Throws(IOException::class)
constructor(private val parts: TrainExampleSource) : AutoCloseable {

    var reader: ExampleIterator

    internal var splitter = Splitter.on(CharMatcher.whitespace()).omitEmptyStrings().trimResults()

    init {
        reader = parts.iteratorAll()
    }

    @Throws(IOException::class)
    fun readLineTokens(): List {
        var line = loopLine()
        //skip empty line
        //此处容易导致死循环
        while (line.isEmpty()) {
            line = loopLine()
        }

        return line2Tokens(line)
    }

    @Throws(IOException::class)
    private fun loopLine(): List {
        if (reader.hasNext()) {
            return reader.next()
        }else{
            reader.close()
            reader = parts.iteratorAll()
            return Lists.newArrayList()
        }
    }

    private fun line2Tokens(tokens: List): List {
        val list = Lists.newArrayList(tokens)
        list.add(EOS)
        return list
    }

    @Throws(Exception::class)
    override fun close() {
        if (reader != null) {
            reader!!.close()
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy