All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mayabot.nlp.fasttext.dictionary.BuildDictFromSource.kt Maven / Gradle / Ivy

package com.mayabot.nlp.fasttext.dictionary

import com.mayabot.nlp.fasttext.args.Args
import com.mayabot.nlp.fasttext.train.SampleLine
import com.mayabot.nlp.fasttext.utils.logger
import com.mayabot.nlp.fasttext.utils.loggerln


/**
 * Builds a [Dictionary] by scanning the training samples and, optionally, the word list
 * of a pre-trained vector file.
 *
 * Samples arrive already tokenised: every entry of `sample.words` is added to the
 * vocabulary as-is, followed by one `EOS` per sample.
 *
 * While reading, whenever the number of distinct entries exceeds 75% of [maxVocabSize]
 * the builder is thresholded with an increasing minimum count until its size is at most
 * 75% of that trigger. After all input has been read, the builder is thresholded once
 * more with `args.minCount` / `args.minCountLabel`.
 *
 * @param args training arguments; this function reads `label`, `minCount`, `minCountLabel`,
 *   `keepPreTrainedVector` and `preTrainedVectors` from it
 * @param sources training corpora; each element is iterated once, sample by sample
 * @param maxVocabSize passed to `DictionaryBuilder`; 75% of this value is the size at
 *   which intermediate pruning kicks in
 * @param initWordListSize passed to `DictionaryBuilder` (presumably the initial capacity of
 *   its word list — confirm against `DictionaryBuilder`)
 * @return the dictionary produced by `builder.toDictionary(args)`
 * @throws RuntimeException if the vocabulary is empty after the final thresholding
 */
@ExperimentalUnsignedTypes
@Throws(Exception::class)
fun buildFromFile(args: Args,
                  sources: List<Iterable<SampleLine>>,
                  maxVocabSize: Int = 500000,
                  initWordListSize: Int = 5000
): Dictionary {

    val builder = DictionaryBuilder(args.label, maxVocabSize, initWordListSize)

    // Intermediate pruning starts once the builder holds more than pruneTrigger entries
    // and continues until it holds no more than pruneTarget entries.
    val pruneTrigger = 0.75 * maxVocabSize
    val pruneTarget = pruneTrigger * 0.75f

    // Highest threshold applied so far by intermediate pruning; the next pruning round
    // starts from this value again rather than from 1.
    var lastMinThreshold = 1L

    loggerln("Read file build dictionary ...")

    // Number of intermediate pruning rounds that were triggered.
    var thresholdCount = 0

    // Raises the threshold one step at a time until the builder is small enough.
    // The same value is used for the word and the label threshold.
    fun thresholdBuilder() {
        var minThreshold: Long = lastMinThreshold
        thresholdCount++
        while (builder.size > pruneTarget) {
            val before = builder.size
            builder.threshold(minThreshold, minThreshold)
            lastMinThreshold = minThreshold
            loggerln("word size from ${before} to ${builder.size} , threshold min $minThreshold")
            minThreshold++
        }
    }

    // Read once into a local so the null check smart-casts and no `!!` is needed.
    val preTrainedFile = args.preTrainedVectors
    if (args.keepPreTrainedVector && preTrainedFile != null) {
        var loadedWords = 0
        preTrainedFile.useLines { lines ->
            // Pre-trained words are seeded with ten times minCount
            // (presumably so they survive the later thresholding — confirm intent).
            val seedCount = args.minCount * 10
            // The first line is skipped (presumably the "<rows> <dim>" header of a .vec file).
            lines.drop(1).forEach { line ->
                if (line.isNotEmpty()) {
                    val firstSpace = line.indexOf(' ')
                    // Require a non-empty token before the first space.
                    if (firstSpace != -1 && firstSpace != 0) {
                        val word = line.substring(0, firstSpace)
                        if (word != EOS && word != BOW && word != EOW) {
                            builder.add(word, seedCount)
                            loadedWords++
                            if (builder.size > pruneTrigger) {
                                thresholdBuilder()
                            }
                        }
                    }
                }
            }
        }

        builder.add(EOS)

        println("Load ${loadedWords} words from preTrainedVectors")
    }

    for (source in sources) {
        for (sample in source) {
            for (token in sample.words) {
                builder.add(token)
                // Progress report every million tokens.
                if (builder.ntokens % 1000000 == 0L) {
                    logger("\rRead ${builder.ntokens / 1000000}M words")
                }

                if (builder.size > pruneTrigger) {
                    thresholdBuilder()
                }
            }
            // One end-of-sentence marker per sample.
            builder.add(EOS)
        }
    }

    // Final, configuration-driven pruning.
    builder.threshold(args.minCount.toLong(), args.minCountLabel.toLong())

    // Collapse the word-hash -> id map (see `collapseWordHash2Id` for details).
    builder.wordIdMap.collapseWordHash2Id()

    loggerln("\rRead %d words\n".format(builder.ntokens))
    loggerln("Number of words:  ${builder.nwords}")
    loggerln("Number of labels: ${builder.nlabels}")

    if (thresholdCount > 0) loggerln("Max threshold count: $lastMinThreshold")

    if (builder.wordIdMap.size == 0) {
        throw RuntimeException("Empty vocabulary. Try a smaller -minCount value.")
    }

    val dict = builder.toDictionary(args)
    loggerln("Number of wordHash2Id: ${dict.onehotMap.wordHash2WordId.size}")
    return dict

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy