All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mayabot.mynlp.fasttext.FastText.kt Maven / Gradle / Ivy

The newest version!
package com.mayabot.mynlp.fasttext

import com.carrotsearch.hppc.IntArrayList
import com.google.common.base.Charsets
import com.google.common.base.Stopwatch
import com.google.common.collect.ImmutableList
import com.google.common.collect.Iterables
import com.google.common.collect.Lists
import com.google.common.collect.Sets
import com.google.common.io.Files
import com.google.common.primitives.Floats
import com.mayabot.blas.*
import com.mayabot.blas.Vector
import java.io.File
import java.io.IOException
import java.io.InputStream
import java.text.DecimalFormat
import java.util.*
import java.util.concurrent.TimeUnit
import kotlin.math.exp
import kotlin.system.exitProcess

/**
 * Model file format version. Same value as `FASTTEXT_VERSION` in the official C++ fastText
 * sources; presumably checked by the official-model loader elsewhere in this package.
 */
const val FASTTEXT_VERSION: Int = 12

/**
 * Magic number identifying official fastText binary model files. Same value as the C++
 * constant of the same name; presumably checked by the official-model loader (not visible here).
 */
const val FASTTEXT_FILEFORMAT_MAGIC_INT32: Int = 793712314

/**
 * Mutable (score, index) pair. [first] is the score and [second] the integer index;
 * the top-k prediction lists built by `Model` hold these.
 */
data class FloatIntPair(
    @JvmField var first: Float,
    @JvmField var second: Int
)
/**
 * Mutable (score, text) pair returned by label prediction and nearest-neighbour queries.
 * [first] is the score and [second] the label or word; it prints as `[text,score]`.
 */
data class FloatStringPair(
    @JvmField var first: Float,
    @JvmField var second: String
) {
    override fun toString() = "[$second,$first]"
}

class FastText(internal val args: Args,
               internal val dict: Dictionary,
               internal val model: Model
) {

    /**
     * Whether the model's input matrix is product-quantized (see [Model.quant]).
     * Captured once, when this object is constructed.
     */
    val quant = model.quant

    /** The model's input matrix; rows are indexed by the word/subword ids from [dict]. */
    val input = model.input

    /** The model's output matrix. */
    val output = model.output

    /**
     * Unit-normalized vectors of all dictionary words, one row per word id.
     * Built lazily by the first [nearestNeighbor] or [analogies] call.
     */
    lateinit var wordVectors: FloatMatrix


    /**
     * Predicts the most likely labels for a tokenized text.
     *
     * An end-of-sentence token ([EOS]) is appended to [tokens] before the dictionary lookup.
     *
     * @param tokens the words of the text
     * @param k maximum number of labels to return; must be positive
     * @return (probability, label) pairs, best first; empty when the dictionary lookup yields no input ids
     */
    fun predict(tokens: Iterable<String>, k: Int): List<FloatStringPair> {
        val tokens2 = Iterables.concat(tokens, listOf(EOS))
        val words = IntArrayList()
        val labels = IntArrayList()

        dict.getLine(tokens2, words, labels)

        if (words.isEmpty) {
            return ImmutableList.of()
        }
        val hidden = MutableByteBufferVector(args.dim)
        val output = MutableByteBufferVector(dict.nlabels())

        val modelPredictions = Lists.newArrayListWithCapacity<FloatIntPair>(k)

        model.predict(words, k, modelPredictions, hidden, output)

        // The model reports log-probabilities; convert back to probabilities.
        return modelPredictions.map { x -> FloatStringPair(exp(x.first), dict.getLabel(x.second)) }
    }


    /**
     * Finds the [k] dictionary words whose vectors are most similar to [queryVec].
     *
     * Words contained in [sets] (typically the query words themselves) are skipped while
     * scanning, so they never occupy a slot and up to [k] other words can be returned.
     *
     * @param wordVectors unit-normalized word vectors (see [preComputeWordVectors])
     * @param queryVec query vector; divided by its norm unless that norm is ~0
     * @param k number of neighbours wanted; a non-positive value yields an empty list
     * @param sets words to exclude from the result
     * @return (cosine similarity, word) pairs, most similar first
     */
    private fun findNN(wordVectors: FloatMatrix, queryVec: Vector, k: Int, sets: Set<String>): List<FloatStringPair> {
        if (k <= 0) {
            return emptyList()
        }

        var queryNorm = queryVec.norm2()
        if (Math.abs(queryNorm) < 1e-8) {
            queryNorm = 1f
        }

        // Best k candidates so far, kept sorted by descending similarity. Every slot starts as a
        // (-1, "") placeholder. The rows are unit-length and the query is normalized, so a real
        // word only replaces a placeholder when its similarity is above -1.
        val mostSimilar = Array(k) { FloatStringPair(-1f, "") }
        val lastIndex = mostSimilar.size - 1

        for (i in 0 until dict.nwords()) {
            val word = dict.getWord(i)
            if (word in sets) {
                continue
            }
            val dp = wordVectors[i] * queryVec / queryNorm
            val last = mostSimilar[lastIndex]
            if (dp > last.first) {
                last.first = dp
                last.second = word
                mostSimilar.sortByDescending { it.first }
            }
        }

        // Drop slots that were never filled (fewer candidate words than k).
        return mostSimilar.filter { it.first != -1f }
    }

    /**
     * Returns [wordVectors], building the normalized word-vector matrix on first use.
     */
    private fun ensureWordVectors(): FloatMatrix {
        if (!this::wordVectors.isInitialized) {
            val stopwatch = Stopwatch.createStarted()
            wordVectors = FloatMatrix.floatArrayMatrix(dict.nwords, args.dim).apply {
                preComputeWordVectors(this)
            }
            stopwatch.stop()
            println("Init wordVectors martix use time ${stopwatch.elapsed(TimeUnit.MILLISECONDS)} ms")
        }
        return wordVectors
    }


    /**
     * Nearest-neighbour query: the [k] words most similar to [wordQuery], excluding the query word itself.
     *
     * The first call also builds the normalized word-vector matrix.
     *
     * @return (cosine similarity, word) pairs, most similar first
     */
    fun nearestNeighbor(wordQuery: String, k: Int): List<FloatStringPair> {
        val matrix = ensureWordVectors()
        val queryVec = getWordVector(wordQuery)
        return findNN(matrix, queryVec, k, setOf(wordQuery))
    }

    /**
     * Analogy query: the [k] words closest to `vec(A) - vec(B) + vec(C)`.
     * A, B and C are excluded from the result.
     *
     * The first call also builds the normalized word-vector matrix.
     *
     * @return (cosine similarity, word) pairs, most similar first
     */
    fun analogies(A: String, B: String, C: String, k: Int): List<FloatStringPair> {
        val matrix = ensureWordVectors()

        val buffer = Vector.floatArrayVector(args.dim)
        val query = Vector.floatArrayVector(args.dim)

        getWordVector(buffer, A)
        query += buffer

        getWordVector(buffer, B)
        query += -1f to buffer

        getWordVector(buffer, C)
        query += buffer

        return findNN(matrix, query, k, Sets.newHashSet(A, B, C))
    }


    /**
     * Fills [wordVectors] with every dictionary word's vector divided by its norm (`norm2()`).
     * Rows of words whose vector norm is zero stay zero.
     *
     * With unit-length rows, [findNN] only has to divide the dot product by the query norm to
     * obtain a cosine similarity, which lies in [-1, 1].
     *
     * @param wordVectors destination matrix, one row per word id
     */
    private fun preComputeWordVectors(wordVectors: MutableFloatMatrix) {
        val vec = Vector.floatArrayVector(args.dim)
        wordVectors.fill(0f)
        for (i in 0 until dict.nwords()) {
            getWordVector(vec, dict.getWord(i))
            val norm = vec.norm2()
            if (norm > 0) {
                wordVectors[i] += 1.0f / norm to vec
            }
        }
    }

    /**
     * Writes the vector of [word] into [vec].
     *
     * The vector is the average of the input rows of the word's subword ids, as returned by
     * `dict.getSubwords`. [vec] is left zeroed when the word has no subwords.
     *
     * @param vec destination vector of size `dim`; overwritten
     * @param word the word to look up
     */
    fun getWordVector(vec: MutableVector, word: String) {
        vec.zero()
        val ngrams = dict.getSubwords(word)
        val size = ngrams.size()
        val buffer = ngrams.buffer
        for (i in 0 until size) {
            addInputVector(vec, buffer[i])
        }
        if (size > 0) {
            vec *= 1.0f / size
        }
    }

    /** Returns a newly allocated vector for [word] (see the two-argument overload). */
    fun getWordVector(word: String): Vector {
        val vec = MutableByteBufferVector(args.dim)
        getWordVector(vec, word)
        return vec
    }


    /**
     * Computes a sentence vector for [tokens].
     *
     * @return a newly allocated vector of size `dim`
     */
    fun getSentenceVector(tokens: Iterable<String>): Vector {
        val svec = MutableByteBufferVector(args.dim)
        getSentenceVector(svec, tokens)
        return svec
    }


    /**
     * Writes the sentence vector of [tokens] into [svec].
     *
     * - Supervised models: the average of the input rows of the ids that `dict.getLine` returns.
     * - Other models: the average of the unit-normalized word vectors; words with a zero vector are skipped.
     *
     * NOTE(review): [predict] appends [EOS] before `dict.getLine`, but this path does not.
     * Confirm whether that difference is intended.
     */
    private fun getSentenceVector(svec: MutableVector, tokens: Iterable<String>) {
        svec.zero()
        if (args.model == ModelName.sup) {
            val line = IntArrayList()
            val labels = IntArrayList()
            dict.getLine(tokens, line, labels)

            for (i in 0 until line.size()) {
                addInputVector(svec, line.get(i))
            }

            if (!line.isEmpty) {
                svec *= (1.0f / line.size())
            }
        } else {
            val vec = MutableByteBufferVector(args.dim)
            var count = 0
            for (word in tokens) {
                getWordVector(vec, word)
                val norm = vec.norm2()
                if (norm > 0) {
                    vec *= (1.0f / norm)
                    svec += vec
                    count++
                }
            }
            if (count > 0) {
                svec *= (1.0f / count)
            }
        }
    }

    /** Adds input row [ind] to [vec], using the quantized input matrix when [quant] is set. */
    private fun addInputVector(vec: MutableVector, ind: Int) {
        if (quant) {
            model.qinput.addToVector(vec, ind)
        } else {
            vec += input[ind]
        }
    }


    /**
     * Saves all word vectors in the text `.vec` format.
     *
     * The format is a `"<nwords> <dim>"` header line, then one line per word: the word
     * followed by its values, each value followed by a space.
     *
     * @param fileName destination path; `.vec` is appended unless it already ends with "vec".
     *   An existing file is replaced, and missing parent directories are created.
     */
    @Throws(Exception::class)
    fun saveVectors(fileName: String) {
        val target = if (fileName.endsWith("vec")) fileName else "$fileName.vec"

        val file = File(target)
        if (file.exists()) {
            file.delete()
        }
        file.parentFile?.mkdirs()

        val vec = MutableByteBufferVector(args.dim)
        val df = DecimalFormat("0.#####")

        Files.asByteSink(file).asCharSink(Charsets.UTF_8).openBufferedStream().use { writer ->
            writer.write("${dict.nwords()} ${args.dim}\n")
            for (i in 0 until dict.nwords()) {
                val word = dict.getWord(i)
                getWordVector(vec, word)
                writer.write(word)
                writer.write(" ")
                for (j in 0 until vec.length()) {
                    writer.write(df.format(vec[j].toDouble()))
                    writer.write(" ")
                }
                writer.write("\n")
            }
        }
    }

    /**
     * Saves the model in this library's own multi-file format, which [loadModel] reads back.
     *
     * [path] is deleted recursively if it exists and then recreated as a directory. The
     * directory receives `dict.bin` and `args.bin`, then either `input.matrix` or
     * `qinput.matrix`, and either `output.matrix` or `qoutput.matrix`. The quantized variants
     * are used when the model is quantized.
     */
    @Throws(Exception::class)
    fun saveModel(path: String) {
        val dir = File(path)
        if (dir.exists()) {
            dir.deleteRecursively()
        }
        dir.mkdirs()

        File(dir, "dict.bin").outputStream().channel.use {
            dict.save(it)
        }

        File(dir, "args.bin").outputStream().channel.use {
            args.save(it)
        }

        if (!quant) {
            // Plain input matrix: rows and cols header, then the data.
            File(dir, "input.matrix").outputStream().channel.use {
                it.writeInt(model.input.rows())
                it.writeInt(model.input.cols())
                model.input.write(it)
            }
        } else {
            File(dir, "qinput.matrix").outputStream().channel.use {
                model.qinput.save(it)
            }
        }

        if (quant && model.quantOut) {
            File(dir, "qoutput.matrix").outputStream().channel.use {
                model.qoutput.save(it)
            }
        } else {
            File(dir, "output.matrix").outputStream().channel.use {
                it.writeInt(model.output.rows())
                it.writeInt(model.output.cols())
                model.output.write(it)
            }
        }
    }


    companion object {

        /**
         * Loads a model saved by Facebook's official C++ fastText program.
         * Supports both `.bin` and `.ftz` models.
         *
         * @param modelFilePath path of the model file
         */
        @JvmStatic
        @Throws(Exception::class)
        fun loadFasttextBinModel(modelFilePath: String): FastText {
            return LoadFastTextFromClangModel.loadCModel(modelFilePath)
        }

        /**
         * Loads a model saved by Facebook's official C++ fastText program.
         * Supports both `.bin` and `.ftz` models.
         *
         * @param modelFile the model file
         */
        @JvmStatic
        @Throws(Exception::class)
        fun loadFasttextBinModel(modelFile: File): FastText {
            return LoadFastTextFromClangModel.loadCModel(modelFile)
        }

        /**
         * Loads a model saved by Facebook's official C++ fastText program.
         * Supports both `.bin` and `.ftz` models.
         *
         * @param modelStream stream positioned at the start of the model data
         */
        @JvmStatic
        @Throws(Exception::class)
        fun loadFasttextBinModel(modelStream: InputStream): FastText {
            return LoadFastTextFromClangModel.loadCModel(modelStream)
        }

        private fun File.openAutoDataInput() = AutoDataInput.open(this)


        /**
         * Loads a model written by [saveModel].
         *
         * @param modelPath directory that holds the individual model files
         * @param mmap passed through to `FloatMatrix.loadMatrix` for the plain (non-quantized) matrices
         * @throws IOException if [modelPath] does not exist or is not a directory
         */
        @JvmOverloads
        @JvmStatic
        @Throws(IOException::class)
        fun loadModel(modelPath: String, mmap: Boolean = false): FastText {
            val dir = File(modelPath)

            // Report the problem to the caller instead of terminating the whole JVM.
            if (!dir.exists() || dir.isFile) {
                throw IOException("Model directory does not exist or is not a directory: $dir")
            }

            // NOTE(review): the AutoDataInput instances opened below are never closed here.
            // Confirm whether load()/loadClang() close them.
            val args = Args().loadClang(File(dir, "args.bin").openAutoDataInput())

            val dictionary = Dictionary(args).load(File(dir, "dict.bin").openAutoDataInput())

            fun loadMatrix(file: File): FloatMatrix {
                return FloatMatrix.loadMatrix(file, mmap)
            }

            // A quantized model is recognised by the presence of qinput.matrix.
            val quant = File(dir, "qinput.matrix").exists()

            val qinput: QMatrix? =
                if (quant) QMatrix.load(File(dir, "qinput.matrix").openAutoDataInput()) else null
            val input: FloatMatrix =
                if (quant) FloatMatrix.floatArrayMatrix(0, 0) else loadMatrix(File(dir, "input.matrix"))

            if (!quant && dictionary.isPruned()) {
                throw RuntimeException("Invalid model file.\n"
                        + "Please download the updated model from www.fasttext.cc.\n"
                        + "See issue #332 on Github for more information.\n")
            }

            // qoutput.matrix is only honoured for models whose input is quantized as well.
            val quantOut = quant && File(dir, "qoutput.matrix").exists()
            val qoutput: QMatrix? =
                if (quantOut) QMatrix.load(File(dir, "qoutput.matrix").openAutoDataInput()) else null
            val output: FloatMatrix =
                if (quantOut) FloatMatrix.floatArrayMatrix(0, 0) else loadMatrix(File(dir, "output.matrix"))

            val model = Model(input, output, args, 0)
            if (quant) {
                model.setQuantizePointer(qinput, qoutput)
            }

            if (args.model == ModelName.sup) {
                model.setTargetCounts(dictionary.getCounts(EntryType.label))
            } else {
                model.setTargetCounts(dictionary.getCounts(EntryType.word))
            }

            return FastText(args, dictionary, model)
        }



        /** Trains a model from [trainFile]; see `FastTextTrain.train`. */
        @JvmOverloads
        @Throws(Exception::class)
        @JvmStatic
        fun train(trainFile: File, model_name: ModelName = ModelName.sup, args: TrainArgs = TrainArgs()): FastText {
            return FastTextTrain().train(trainFile, model_name, args)
        }

        /** Trains a model from [source]; see `FastTextTrain.train`. */
        @JvmOverloads
        @Throws(Exception::class)
        @JvmStatic
        fun train(source: TrainExampleSource, model_name: ModelName = ModelName.sup, args: TrainArgs = TrainArgs()): FastText {
            return FastTextTrain().train(source, model_name, args)
        }


        /**
         * Quantizes the input matrix of a supervised (classification) model.
         *
         * The output matrix is kept unquantized. An already-quantized model is returned unchanged.
         *
         * @param fastText the model to quantize; it is not modified (its input is copied first)
         * @param dsub sub-vector size passed to [QMatrix]
         * @param qnorm passed to [QMatrix] (whether norms are quantized separately)
         * @return a new [FastText] that shares [fastText]'s args, dictionary and output matrix
         * @throws RuntimeException if [fastText] is not a supervised model
         */
        @JvmOverloads
        @Throws(Exception::class)
        @JvmStatic
        fun quantize(fastText: FastText,
                     dsub: Int = 2,
                     qnorm: Boolean = false): FastText {

            if (fastText.quant) {
                println("该模型已经被量化过")
                return fastText
            }

            if (fastText.args.model != ModelName.sup) {
                throw RuntimeException("Only for sup model")
            }

            val qMatrix = QMatrix(fastText.input.rows(), fastText.input.cols(), dsub, qnorm)
            qMatrix.quantize(fastText.input.toMutableFloatMatrix())

            val qModel = Model(FloatMatrix.floatArrayMatrix(0, 0), fastText.output, fastText.args, 0)
            qModel.setQuantizePointer(qMatrix, null)
            // loadModel sets the target counts after setQuantizePointer; do the same here.
            // Without this step, models relying on the label-count structures (e.g. the
            // hierarchical-softmax tree walked by Model.dfs) cannot predict after quantization.
            qModel.setTargetCounts(fastText.dict.getCounts(EntryType.label))

            return FastText(fastText.args, fastText.dict, qModel)
        }
    }
}

/**
 * Prediction-side model built around an input matrix and an output matrix.
 *
 * Either matrix can be replaced by a product-quantized [QMatrix] (see [setQuantizePointer]).
 *
 * @param input input matrix; [computeHidden] sums its rows by input id
 * @param output output matrix; its row count is passed to [BaseModel] (presumably as the initial [outputMatrixSize])
 * @param args_ model arguments; `dim` is the hidden size and `loss` selects the prediction path
 * @param seed forwarded to [BaseModel]
 */
class Model(
    val input: FloatMatrix,
    val output: FloatMatrix,
    args_: Args,
    seed: Int
) : BaseModel(args_, seed, output.rows()) {

    /**
     * Whether the input matrix is product-quantized.
     * When set, [qinput] is used instead of [input].
     */
    var quant: Boolean = false

    /**
     * Whether the output matrix is quantized.
     * [qoutput] is used instead of [output] only when [quant] is also set.
     */
    var quantOut = false

    var qinput = QMatrix()
    var qoutput = QMatrix()

    /** Hidden-layer size, i.e. the vector dimension (`args.dim`). */
    private val hsz: Int = args_.dim

    /** Orders pairs by descending score, so the best entry comes first. */
    private val comparePairs = Comparator<FloatIntPair> { o1, o2 -> Floats.compare(o2.first, o1.first) }

    /** Natural log with a 1e-5 epsilon, so a probability of 0 does not yield -Infinity. */
    fun std_log(d: Float) = Math.log(d + 1e-5)


    /**
     * Installs quantized matrices.
     *
     * A non-null [qinput] sets [quant]. A non-null [qoutput] sets [quantOut] and updates
     * [outputMatrixSize] to `qoutput.m`.
     */
    fun setQuantizePointer(qinput: QMatrix?, qoutput: QMatrix?) {
        qinput?.let {
            quant = true
            this.qinput = it
        }
        qoutput?.let {
            quantOut = true
            this.qoutput = it
            this.outputMatrixSize = it.m
        }
    }

    /**
     * Computes the top-[k] (log-probability, output index) pairs for the input ids [input].
     *
     * The result goes into [heap], sorted best first. It walks the hierarchical-softmax tree
     * when `args_.loss` is [LossName.hs] and uses a full softmax otherwise.
     *
     * @param k number of results wanted; must be positive
     * @param heap receives the results
     * @param hidden scratch vector of size `dim`
     * @param output scratch vector sized to the output layer (used by the softmax path)
     */
    fun predict(input: IntArrayList, k: Int,
                heap: MutableList<FloatIntPair>,
                hidden: MutableVector,
                output: MutableVector) {
        checkArgument(k > 0)

        computeHidden(input, hidden)
        if (args_.loss == LossName.hs) {
            // 2 * osz - 2 is the root of the tree (leaves are 0 until osz).
            dfs(k, 2 * outputMatrixSize - 2, 0.0f, heap, hidden)
        } else {
            findKBest(k, heap, hidden, output)
        }
        Collections.sort(heap, comparePairs)
    }

    /**
     * Softmax path: keeps in [heap] the [k] output indices with the highest log-probability.
     */
    fun findKBest(k: Int, heap: MutableList<FloatIntPair>, hidden: Vector, output: MutableVector) {
        computeOutputSoftmax(hidden, output)
        for (i in 0 until outputMatrixSize) {
            val logProb = std_log(output[i]).toFloat()
            // heap is kept sorted best first, so its last entry is the current k-th best.
            if (heap.size == k && logProb < heap[heap.size - 1].first) {
                continue
            }
            pushBounded(heap, FloatIntPair(logProb, i), k)
        }
    }

    /**
     * Hierarchical-softmax path: depth-first walk of the tree from [node].
     *
     * [score] is the log-probability accumulated along the path so far. Branches that can
     * no longer beat the current k-th best are pruned.
     */
    fun dfs(k: Int, node: Int, score: Float, heap: MutableList<FloatIntPair>, hidden: Vector) {
        if (heap.size == k && score < heap[heap.size - 1].first) {
            return
        }

        if (tree[node].left == -1 && tree[node].right == -1) {
            // Leaf: node is an output index.
            pushBounded(heap, FloatIntPair(score, node), k)
            return
        }

        // f = sigmoid(<output row of this internal node, hidden>)
        var f = if (quant && quantOut) {
            qoutput.dotRow(hidden, node - outputMatrixSize)
        } else {
            output[node - outputMatrixSize] * hidden
        }
        f = 1.0f / (1 + exp(-f))

        dfs(k, tree[node].left, score + std_log(1.0f - f).toFloat(), heap, hidden)
        dfs(k, tree[node].right, score + std_log(f).toFloat(), heap, hidden)
    }

    /**
     * Appends [entry] to [heap] and restores best-first order with a stable sort.
     * If [heap] then holds more than [k] entries, the worst one is dropped.
     */
    private fun pushBounded(heap: MutableList<FloatIntPair>, entry: FloatIntPair, k: Int) {
        heap.add(entry)
        Collections.sort(heap, comparePairs)
        if (heap.size > k) {
            heap.removeAt(heap.size - 1)
        }
    }


    /**
     * Sets [hidden] to the average of the input rows for the ids in [input].
     * An empty id list leaves [hidden] zeroed rather than filling it with NaN.
     */
    private fun computeHidden(input: IntArrayList, hidden: MutableVector) {
        checkArgument(hidden.length() == hsz)
        hidden.zero()

        val buffer = input.buffer
        val size = input.size()
        for (i in 0 until size) {
            val id = buffer[i]
            if (quant) {
                qinput.addToVector(hidden, id)
            } else {
                hidden += this.input[id]
            }
        }
        if (size > 0) {
            hidden *= (1.0f / size)
        }
    }

    /**
     * Writes `softmax(outputMatrix * hidden)` into the first [outputMatrixSize] entries of [output].
     * The maximum is subtracted before exponentiating for numerical stability.
     */
    private fun computeOutputSoftmax(hidden: Vector, output: MutableVector) {
        if (quant && quantOut) {
            matrixMulVector(qoutput, hidden, output)
        } else {
            matrixMulVector(this.output, hidden, output)
        }

        var max = output[0]
        var z = 0.0f
        for (i in 1 until outputMatrixSize) {
            max = Math.max(output.get(i), max)
        }
        for (i in 0 until outputMatrixSize) {
            output[i] = Math.exp((output[i] - max).toDouble()).toFloat()
            z += output[i]
        }
        for (i in 0 until outputMatrixSize) {
            output[i] = output[i] / z
        }
    }

    /** target[i] = <row i of [matrix], [v]> for every row of the quantized matrix. */
    private fun matrixMulVector(matrix: QMatrix, v: Vector, target: MutableVector) {
        checkArgument(matrix.m == target.length())
        checkArgument(matrix.n == v.length())

        for (i in 0 until matrix.m) {
            target[i] = matrix.dotRow(v, i)
        }
    }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy