All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mayabot.nlp.fasttext.dictionary.FastWordMap.kt Maven / Gradle / Ivy

package com.mayabot.nlp.fasttext.dictionary

import java.lang.Integer.max
import java.lang.Integer.min

/**
 * 每个词和一个WordId关联。快速更具id查找到word。
 * 也就是为Word进行OneHot编码
 */
@ExperimentalUnsignedTypes
class FastWordMap(
        /**
         * 记录的是hash和id的隐射关系
         */
        var wordHash2WordId: IntArray,

        val wordList: ArrayList,

        val label: String = "__label__"
) {

    /**
     * 构建一个空的
     */
    constructor(label: String,
                vocabSize: Int,
                initWordListSize: Int?) : this(
            IntArray(vocabSize) { -1 },
            ArrayList(initWordListSize ?: max(1000, min(10000, vocabSize / 1000))),
            label
    )

    val size get() = wordList.size

    operator fun get(id: Int) = wordList[id]

    /**
     * 向词典中新增一个词
     */
    fun add(w: String, count: Int = 1) {

        val h = find(w)
        val id = wordHash2WordId[h]

        if (id == -1) {
//            wordList.add(Entry(w, 1, getType(w)))
//            word_hash_2_id[h] = size++
            wordHash2WordId[h] = wordList.size
            val entry = Entry(w, count.toLong(), getType(w))
            wordList.add(entry)
        } else {
            wordList[id].count += count
        }
    }

    fun getWord(id: Int): String {
        return wordList[id].word
    }


    fun getType(id: Int): EntryType {
        return wordList[id].type
    }

    fun getType(w: String): EntryType {
        return if (w.startsWith(label)) EntryType.label else EntryType.word
    }

    /**
     * word 在wordList里面的下标,也就是词ID。
     *
     * @param w
     * @return
     */
    fun getId(w: String): Int {
        val id = find(w)
        return if (id == -1) {
            -1 //词不存在
        } else wordHash2WordId[id]
    }

    /**
     * word 在wordList里面的下标,也就是词ID。
     */
    fun getId(w: String, h: UInt): Int {
        val id = find(w, h)
        return if (id == -1) {
            -1 //词不存在
        } else wordHash2WordId[id]
    }

    /**
     * 返回的是word_hash_2_id的下标。返回的是不冲突的hash值,也是word_hash的下标索引的位置
     * 原来的find
     * @param w
     * @return 返回的是word_hash_2_id的下标 返回-1标识不存在
     */
    fun find(w: String): Int {
        return find(w, w.fnv1aHash())
    }

    /**
     * 找到word,对应的ID,要么还没人占坑。如果有人占坑了,那么要相等
     * word2int  [index -> words_id]
     * 就是为word在wordList中找到一个正确的下标,最终建立word和
     * Entry的关系。为什么不直接用HashMap呢,业务HashMap的代价太高了吧
     * @param w
     * @param hash
     * @return 返回的是word_hash_2_id的下标 返回-1标识不存在
     */
    fun find(w: String, hash: UInt): Int {
        val word2intSize = wordHash2WordId.size
        var id = (hash.toLong() % word2intSize).toInt()
        while (wordHash2WordId[id] != -1
                && wordList[wordHash2WordId[id]].word != w) {
            id = (id + 1) % word2intSize
        }
        return id
    }

    /**
     * 为wordList里面的每个word进行编码
     */
    fun initWordHash2WordId() {
        for (i in 0 until size) {
            wordHash2WordId[find(wordList[i].word)] = i
        }
    }

    fun collapseWordHash2Id() {
        wordHash2WordId = IntArray((size.toFloat() / 0.75).toInt()) { -1 }
        initWordHash2WordId()
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy