All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mayabot.nlp.fasttext.dictionary.DictionaryBuilder.kt Maven / Gradle / Ivy

package com.mayabot.nlp.fasttext.dictionary

import com.mayabot.nlp.fasttext.args.Args
import java.util.ArrayList
import kotlin.Comparator

/**
 * Mutable builder that accumulates vocabulary entries and produces a [Dictionary].
 *
 * The dictionary is layered as `[words, labels, bucket]`.
 *
 * Judging from the current code, labels and bucket are mutually exclusive:
 * only one of them can be present.
 *
 * @param label forwarded unchanged to [FastWordMap]
 *        (presumably the prefix that marks a token as a label — TODO confirm).
 * @param vocabSize forwarded to [FastWordMap]; should be roughly two orders of
 *        magnitude larger than [initWordListSize].
 * @param initWordListSize optional value forwarded to [FastWordMap].
 * @author jimichan
 */
@ExperimentalUnsignedTypes
class DictionaryBuilder(
        label: String,
        vocabSize: Int = MAX_VOCAB_SIZE,
        initWordListSize: Int? = null
) {

    /** Backing word table; every entry added through [add] ends up here. */
    val wordIdMap = FastWordMap(label, vocabSize, initWordListSize)

    /**
     * Total number of tokens seen (not deduplicated).
     */
    var ntokens: Long = 0

    /**
     * Number of distinct words; recomputed by [threshold].
     */
    var nwords: Int = 0

    /**
     * Number of labels; recomputed by [threshold].
     */
    var nlabels: Int = 0

    /**
     * Creates a [Dictionary] from the current state of this builder.
     *
     * The [wordIdMap] instance itself is handed over (not copied), together with
     * the current values of [ntokens], [nwords] and [nlabels].
     *
     * @param args training/model arguments passed through to the [Dictionary].
     */
    fun toDictionary(args: Args): Dictionary {
        return Dictionary(
                args = args,
                onehotMap = wordIdMap,
                ntokens = ntokens,
                nwords = nwords,
                nlabels = nlabels
        )
    }

    /** Size reported by the underlying [wordIdMap]. */
    val size get() = wordIdMap.size

    /**
     * Registers [word] in [wordIdMap] with the given [count] and bumps [ntokens].
     *
     * NOTE(review): [ntokens] grows by exactly one per call, independent of
     * [count] — confirm this is intended for callers that pass `count > 1`.
     */
    fun add(word: String, count: Int = 1) {
        wordIdMap.add(word, count)
        ntokens++
    }

    /**
     * Truncates the vocabulary.
     *
     * Word entries with a count below [t] and label entries with a count below
     * [minLabelCount] are dropped. The remaining entries are ordered by type
     * first and, within the same type, by descending count. Afterwards the word
     * list, the hash-to-id table and the [nwords] / [nlabels] counters are
     * rebuilt from the surviving entries. [ntokens] is left unchanged.
     *
     * @param t minimum count a word entry needs to be kept.
     * @param minLabelCount minimum count a label entry needs to be kept.
     */
    fun threshold(t: Long, minLabelCount: Long) {

        // filterNot/sortedWith both return fresh lists, so the backing list can
        // be cleared below without invalidating `kept`.
        val kept = wordIdMap.wordList
                .filterNot {
                    (it.type == EntryType.word && it.count < t)
                            || (it.type == EntryType.label && it.count < minLabelCount)
                }
                .sortedWith(Comparator { a, b ->
                    if (a.type != b.type) {
                        a.type.compareTo(b.type)
                    } else {
                        // Same type: higher counts come first.
                        b.count.compareTo(a.count)
                    }
                })

        nwords = 0
        nlabels = 0

        wordIdMap.wordList.clear()
        wordIdMap.wordList.addAll(kept)

        // Reset every slot before re-inserting; `find` is called only after the
        // word list has been replaced and the table has been cleared.
        val hashToId = wordIdMap.wordHash2WordId
        hashToId.fill(-1)

        for ((index, entry) in kept.withIndex()) {
            val slot = wordIdMap.find(entry.word)
            hashToId[slot] = index
            if (entry.type == EntryType.word) {
                nwords++
            } else if (entry.type == EntryType.label) {
                nlabels++
            }
        }

    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy