All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mayabot.mynlp.fasttext.Dictionary.kt Maven / Gradle / Ivy

The newest version!
package com.mayabot.mynlp.fasttext

import com.carrotsearch.hppc.IntArrayList
import com.carrotsearch.hppc.IntIntHashMap
import com.carrotsearch.hppc.IntIntMap
import com.carrotsearch.hppc.LongArrayList
import com.google.common.base.CharMatcher
import com.google.common.base.Splitter
import java.io.File
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import com.mayabot.blas.*
import java.util.*
import com.google.common.primitives.UnsignedLong



/**
 * Multiplier used to combine consecutive word hashes into a word n-gram hash
 * (the same coefficient `Dictionary.addWordNgrams` uses).
 */
const val HASH_C = 116049371

/**
 * Capacity of the open-addressing word table in `Dictionary`; dictionary building prunes rare
 * entries once the vocabulary grows beyond 75% of this capacity.
 */
const val MAX_VOCAB_SIZE = 30000000

/** Upper bound on in-vocabulary tokens read per line by the unsupervised `Dictionary.getLine`. */
const val MAX_LINE_SIZE = 1024

/**
 * End-of-sentence token, appended after every line while the dictionary is built.
 * It must be `</s>` to stay compatible with the vocabulary stored in fastText model files.
 */
const val EOS = "</s>"

/**
 * Begin-of-word marker, prepended to a word before its character n-grams are extracted.
 */
const val BOW = "<"

/**
 * End-of-word marker, appended to a word before its character n-grams are extracted.
 */
const val EOW = ">"

/**
 * fastText dictionary: the vocabulary of words and labels, plus the hashing scheme that maps
 * character n-grams (subwords) and word n-grams into hash buckets.
 *
 * After [threshold] or [load], [wordList] is laid out as
 * ```
 * [ words (ids 0 until nwords) | labels (ids nwords until size) ]
 * ```
 * and hashed n-gram buckets are addressed as `nwords + bucketId` (see [pushHash]).
 *
 * Original author's note (translated): judging from the current code, labels and buckets are
 * mutually exclusive; only one of them is present.
 *
 * @author jimichan
 */
class Dictionary(private val args: Args) {

    /** Number of entries (words + labels) currently in [wordList]. */
    var size: Int = 0
        private set

    /** All dictionary entries; words precede labels after [threshold] or [load]. */
    var wordList: MutableList<Entry> = ArrayList(50000 * 4)

    /**
     * Open-addressing hash table of capacity [MAX_VOCAB_SIZE]:
     * slot (see [find]) -> index into [wordList], or -1 for an empty slot.
     */
    private var word_hash_2_id: IntArray = IntArray(MAX_VOCAB_SIZE).apply {
        fill(-1)
    }

    /** Number of word entries. */
    var nwords: Int = 0

    /** Number of label entries. */
    var nlabels: Int = 0

    /** Number of tokens seen by [add] (including the per-line [EOS]); restored by [load]. */
    var ntokens: Long = 0

    /**
     * -1 when the model is not pruned; 0 when pruned with no n-grams kept; otherwise the number of
     * entries in [pruneidx].
     */
    var pruneidxSize = -1L

    /** Per-entry subsampling threshold computed by [initTableDiscard] and used by [discard]. */
    var pdiscard: FloatArray = FloatArray(0)

    /** For pruned models: original n-gram bucket id -> id after pruning. */
    var pruneidx: IntIntMap = IntIntHashMap()

    /** Maximum char n-gram length; 0 means no subwords. */
    val maxn = args.maxn

    /** Minimum char n-gram length. */
    val minn = args.minn

    /** Number of hash buckets shared by char n-grams and word n-grams. */
    val bucket = args.bucket
    val bucketLong = args.bucket.toLong()

    /** Word n-gram length used by the supervised [getLine]. */
    val wordNgrams = args.wordNgrams

    /** Prefix that marks a token as a label. */
    val label = args.label
    val model = args.model

    /** True when the dictionary carries a pruned n-gram table (see [pruneidx]). */
    fun isPruned() = pruneidxSize >= 0

    /** Type of the entry with the given [id]. */
    fun getType(id: Int): EntryType {
        checkArgument(id >= 0)
        checkArgument(id < size)
        return wordList[id].type
    }

    /** Type of a raw token: a label when it starts with [label], otherwise a word. */
    fun getType(w: String): EntryType {
        return if (w.startsWith(label)) EntryType.label else EntryType.word
    }

    /**
     * Id of [w] (its index in [wordList]), or -1 when the word is not in the dictionary.
     * [find] always returns a valid slot, and an empty slot holds -1.
     */
    fun getId(w: String): Int = word_hash_2_id[find(w)]

    /** Same as [getId], reusing a precomputed [stringHash] value [h] of [w]. */
    private fun getId(w: String, h: Long): Int = word_hash_2_id[find(w, h)]

    /**
     * Records one occurrence of [w]: creates a new entry with count 1, or increments the count of
     * the existing entry. Always increments [ntokens].
     */
    fun add(w: String) {
        val h = find(w)
        val id = word_hash_2_id[h]

        if (id == -1) {
            wordList.add(Entry(w, 1, getType(w)))
            word_hash_2_id[h] = size++
        } else {
            wordList[id].count++
        }
        ntokens++
    }

    /**
     * Slot of [word_hash_2_id] for [w]: either the slot that already holds [w] or the empty slot
     * where it would be inserted.
     */
    private fun find(w: String): Int {
        return find(w, stringHash(w))
    }

    /**
     * Linear probing from `hash % MAX_VOCAB_SIZE` until a slot is empty (-1) or holds [w].
     * Returns a slot index of [word_hash_2_id], never -1.
     *
     * NOTE: this would loop forever on a completely full table; [buildFromFile] prunes the
     * vocabulary before it exceeds 75% of [MAX_VOCAB_SIZE].
     */
    private fun find(w: String, hash: Long): Int {
        var h = (hash % MAX_VOCAB_SIZE).toInt()
        while (word_hash_2_id[h] != -1 && wordList[word_hash_2_id[h]].word != w) {
            h = (h + 1) % MAX_VOCAB_SIZE
        }
        return h
    }

    /**
     * 32-bit FNV-1a hash over the UTF-8 bytes of [str], returned as a non-negative Long.
     * Each byte is sign-extended before the xor, matching fastText's `int8_t` cast.
     */
    private fun stringHash(str: String): Long {
        var h = 2166136261L.toInt() // FNV offset basis 0x811c9dc5
        for (strByte in str.toByteArray()) {
            h = (h xor strByte.toInt()) * 16777619 // FNV prime
        }
        return h.toLong() and 0xffffffffL
    }

    /** Surface form of the entry with the given [id]. */
    fun getWord(id: Int): String {
        checkArgument(id >= 0)
        checkArgument(id < size)
        return wordList[id].word
    }

    /**
     * Builds the dictionary from a tokenized training corpus.
     *
     * Every token is counted via [add], and [EOS] is added after each line. The dictionary is then
     * thresholded with `args.minCount` / `args.minCountLabel`, and the discard table and subwords
     * are initialized. Exits the process when the resulting vocabulary is empty.
     *
     * @param file training corpus; each element of `iteratorAll()` is one line of tokens
     */
    @Throws(Exception::class)
    fun buildFromFile(file: TrainExampleSource) {
        // Once the vocabulary exceeds 75% of the hash table, rare entries are pruned with an
        // increasing count threshold so the open-addressing table in find() never fills up.
        val pruneTrigger = 0.75 * MAX_VOCAB_SIZE
        var minThreshold: Long = 1

        println("Read file build dictionary ...")

        file.iteratorAll().use { lines ->
            lines.forEach { line ->
                line.forEach { token ->
                    add(token)
                    if (ntokens % 1000000 == 0L && args.verbose > 1) {
                        print("\rRead " + ntokens / 1000000 + "M words")
                    }

                    if (size > pruneTrigger) {
                        minThreshold++
                        threshold(minThreshold, minThreshold)
                    }
                }
                add(EOS)
            }

            threshold(args.minCount.toLong(), args.minCountLabel.toLong())

            initTableDiscard()
            initNgrams()

            if (args.verbose > 0) {
                System.out.printf("\rRead %dM words\n", ntokens / 1000000)
                println("Number of words:  $nwords")
                println("Number of labels: $nlabels")
            }
            if (size == 0) {
                System.err.println("Empty vocabulary. Try a smaller -minCount value.")
                System.exit(1)
            }
        }
    }

    /**
     * Initializes every entry's subwords: its own id, followed by the hashed char n-gram ids of
     * `BOW + word + EOW` (not computed for [EOS]).
     */
    private fun initNgrams() {
        for (id in 0 until size) {
            val e = wordList[id]

            if (maxn == 0) {
                // Shortcut: maxn == 0 means no subwords (the default for classification models).
                e.subwords = IntArrayList.from(id)
            } else {
                e.subwords = IntArrayList(1)
                e.subwords.add(id)

                if (e.word != EOS) {
                    computeSubwords(BOW + e.word + EOW, e.subwords)
                }
            }
        }
    }

    /**
     * Appends to [ngrams] the bucket ids of all char n-grams of [word] with length in
     * `minn..maxn`. Single-character n-grams at the word boundaries (the lone BOW/EOW markers)
     * are skipped.
     *
     * @param substrings when non-null, receives the n-gram text for every id actually appended to
     *   [ngrams], so the two lists stay parallel
     */
    private fun computeSubwords(word: String, ngrams: IntArrayList, substrings: MutableList<String>? = null) {
        if (maxn <= 0) return // no n-gram length can satisfy n <= maxn

        val wordLen = word.length
        for (i in 0 until wordLen) {
            // NOTE(review): fastText's C++ code skips UTF-8 continuation bytes at this point; this
            // port instead treats whitespace as glued to the preceding character. Kept as-is
            // because it determines the hashes of already-trained models.
            if (charMatches(word[i])) {
                continue
            }

            val ngram = StringBuilder()
            var j = i
            var n = 1
            while (j < wordLen && n <= maxn) {
                ngram.append(word[j++])
                while (j < wordLen && charMatches(word[j])) {
                    ngram.append(word[j++])
                }
                if (n >= minn && !(n == 1 && (i == 0 || j == wordLen))) {
                    val gram = ngram.toString()
                    // stringHash is non-negative, so h is a valid bucket index.
                    val h = (stringHash(gram) % bucket).toInt()
                    val before = ngrams.size()
                    pushHash(ngrams, h)
                    if (substrings != null && ngrams.size() > before) {
                        substrings.add(gram)
                    }
                }
                n++
            }
        }
    }

    /**
     * Appends `nwords + id` to [hashes], where [id_] is an n-gram bucket id (remapped through
     * [pruneidx] for pruned models). Pruned-away buckets and negative ids are dropped.
     */
    private fun pushHash(hashes: IntArrayList, id_: Int) {
        var id = id_
        if (pruneidxSize == 0L || id < 0) return

        if (pruneidxSize > 0) {
            if (pruneidx.containsKey(id)) {
                id = pruneidx.get(id)
            } else {
                return
            }
        }

        hashes.add(nwords + id)
    }

    /** True for ASCII space, tab, CR or LF. */
    private fun charMatches(ch: Char) = ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r'

    /**
     * Fills [pdiscard] with `sqrt(t / f) + t / f` for every entry, where `f` is the entry count
     * divided by [ntokens] and `t` is `args.t`.
     */
    fun initTableDiscard() {
        pdiscard = FloatArray(size)
        val t = args.t
        for (i in 0 until size) {
            val f = wordList[i].count * 1.0f / ntokens
            pdiscard[i] = (Math.sqrt(t / f) + t / f).toFloat()
        }
    }

    /**
     * Removes words with count < [t] and labels with count < [tl]. Sorts the rest so that words
     * precede labels, each by descending count. Rebuilds the hash table and the
     * [size]/[nwords]/[nlabels] counters.
     */
    fun threshold(t: Long, tl: Long) {
        val kept = wordList
                .filterNot {
                    (it.type == EntryType.word && it.count < t) ||
                            (it.type == EntryType.label && it.count < tl)
                }
                .sortedWith(compareBy<Entry> { it.type }.thenByDescending { it.count })
        // The ArrayList copy constructor sizes its backing array exactly, so no trimToSize is needed.
        wordList = ArrayList(kept)

        size = 0
        nwords = 0
        nlabels = 0

        word_hash_2_id.fill(-1)

        wordList.forEach {
            val h = find(it.word)
            word_hash_2_id[h] = size++
            if (it.type == EntryType.word) {
                nwords++
            } else if (it.type == EntryType.label) {
                nlabels++
            }
        }
    }

    fun nwords() = nwords
    fun nlabels() = nlabels
    fun ntokens() = ntokens

    /** Counts of all entries of the given [type], in [wordList] order. */
    fun getCounts(type: EntryType): LongArray {
        val counts = LongArray(if (type == EntryType.label) nlabels() else nwords())
        var i = 0
        for (entry in wordList) {
            if (entry.type == type) {
                counts[i++] = entry.count
            }
        }
        return counts
    }

    /**
     * Supervised line reader. Clears [words] and [labels], then fills [words] with the subword ids
     * of every word token (out-of-vocabulary words contribute their char n-grams) followed by the
     * word n-gram ids. Fills [labels] with label ids (0-based among labels) of known labels.
     *
     * @return number of tokens consumed
     */
    fun getLine(tokens: Iterable<String>, words: IntArrayList, labels: IntArrayList): Int {
        val wordHashes = LongArrayList()
        var ntokens = 0

        words.clear()
        labels.clear()

        for (token in tokens) {
            val h = stringHash(token)
            val wid = getId(token, h)
            val type = if (wid < 0) getType(token) else getType(wid)
            ntokens++

            if (type == EntryType.word) {
                addSubwords(words, token, wid)
                wordHashes.add(h)
            } else if (type == EntryType.label && wid >= 0) {
                labels.add(wid - nwords)
            }
        }

        addWordNgrams(words, wordHashes, wordNgrams)

        return ntokens
    }

    /**
     * Appends the bucket ids of every word n-gram (length 2..[n]) of the line to [line].
     *
     * This reproduces fastText's C++ arithmetic: each 32-bit word hash is stored in an `int32_t`
     * and widened to `uint64_t` (sign extension), then combined as
     * `h = h * 116049371 + hash[j]` with wrap-around mod 2^64, and reduced with an unsigned modulo.
     * Signed Long arithmetic wraps exactly like `uint64_t`, so plain Longs plus
     * [java.lang.Long.remainderUnsigned] give bit-identical ids without allocating per n-gram.
     * (The unsigned-arithmetic approach is from
     * https://github.com/linkfluence/fastText4j/blob/b018438e84bebd20f89a701c35f022139418930c/src/main/java/fasttext/BaseDictionary.java)
     */
    private fun addWordNgrams(line: IntArrayList, hashes: LongArrayList, n: Int) {
        val hashSize = hashes.size()
        val multiplier = HASH_C.toLong()

        for (i in 0 until hashSize) {
            var h = hashes.get(i).toInt().toLong() // int32 -> 64-bit with sign extension
            var j = i + 1
            while (j < hashSize && j < i + n) {
                h = h * multiplier + hashes.get(j).toInt().toLong()
                pushHash(line, java.lang.Long.remainderUnsigned(h, bucketLong).toInt())
                j++
            }
        }
    }

    /**
     * Appends the input ids of one word token to [line].
     * - Out of vocabulary ([wid] < 0): its char n-grams (nothing for [EOS]).
     * - In vocabulary without subwords: [wid] itself.
     * - In vocabulary with subwords: its precomputed subwords.
     */
    private fun addSubwords(line: IntArrayList, token: String, wid: Int) {
        if (wid < 0) {
            if (EOS != token) {
                computeSubwords(BOW + token + EOW, line)
            }
        } else {
            if (maxn <= 0) {
                line.add(wid)
            } else {
                line.addAll(getSubwords(wid))
            }
        }
    }

    /** Precomputed subword ids of the word with the given [id]. */
    fun getSubwords(id: Int): IntArrayList {
        checkArgument(id >= 0)
        checkArgument(id < nwords)
        return wordList[id].subwords
    }

    /**
     * Subword ids of [word]. For a known entry this returns its stored list; otherwise a new list
     * of the word's char n-gram ids.
     */
    fun getSubwords(word: String): IntArrayList {
        val i = getId(word)

        if (i >= 0) {
            return wordList[i].subwords
        }

        val ngrams = IntArrayList()
        computeSubwords(BOW + word + EOW, ngrams)

        return ngrams
    }

    /**
     * Fills [ngrams] and [substrings] as parallel lists. The first element is the word itself (its
     * id, or -1 when unknown). It is followed by every char n-gram id and its n-gram text.
     */
    fun getSubwords(word: String, ngrams: IntArrayList, substrings: MutableList<String>) {
        val i = getId(word)
        ngrams.clear()
        substrings.clear()
        if (i >= 0) {
            ngrams.add(i)
            substrings.add(wordList[i].word)
        } else {
            ngrams.add(-1)
            substrings.add(word)
        }

        computeSubwords(BOW + word + EOW, ngrams, substrings)
    }

    /**
     * Unsupervised line reader. Clears [words] and fills it with the ids of in-vocabulary words
     * that survive subsampling ([discard] drawn with [rng]). Unknown tokens are skipped. Reading
     * stops after more than [MAX_LINE_SIZE] known tokens or at [EOS].
     *
     * @return number of in-vocabulary tokens consumed
     */
    fun getLine(tokens: List<String>, words: IntArrayList, rng: Random): Int {
        var ntokens = 0
        words.clear()
        for (token in tokens) {
            val h = find(token)
            val wid = word_hash_2_id[h]
            if (wid < 0) continue

            ntokens++

            if (getType(wid) == EntryType.word && !discard(wid, rng.nextFloat())) {
                words.add(wid)
            }
            if (ntokens > MAX_LINE_SIZE || token == EOS) {
                break
            }
        }

        return ntokens
    }

    /** Subsampling decision: never discards for supervised models, else `rand > pdiscard[id]`. */
    private fun discard(id: Int, rand: Float): Boolean {
        checkArgument(id >= 0)
        checkArgument(id < nwords)
        return if (model == ModelName.sup) false else rand > pdiscard[id]
    }

    /** Text of the label with 0-based label id [lid]. */
    fun getLabel(lid: Int): String {
        checkArgument(lid >= 0)
        checkArgument(lid < nlabels)
        return wordList[lid + nwords].word
    }

    /**
     * Writes the dictionary to [channel] in the layout read back by [load]: the header counts,
     * then every entry (word, count, type byte), then the [pruneidx] pairs.
     */
    @Throws(IOException::class)
    fun save(channel: FileChannel) {
        channel.writeInt(size)
        channel.writeInt(nwords)
        channel.writeInt(nlabels)
        channel.writeLong(ntokens)
        channel.writeLong(pruneidxSize)

        val buffer = ByteBuffer.allocate(1024 * 1024)
        // Flush while at least a quarter of the buffer is still free, leaving room for the next entry.
        val flushThreshold = buffer.capacity() * 0.25f
        for (entry in wordList) {
            buffer.writeUTF(entry.word)
            buffer.putLong(entry.count)
            buffer.put(entry.type.value.toByte())

            if (buffer.remaining() < flushThreshold) {
                flushBuffer(buffer, channel)
            }
        }
        flushBuffer(buffer, channel)

        // Each pruneidx pair is two ints (key, value) = 8 bytes, written with relative puts so the
        // position advances; load() reads them back in the same order.
        val pruneBuffer = ByteBuffer.allocate(pruneidx.size() * 8)
        pruneidx.forEach {
            pruneBuffer.putInt(it.key)
            pruneBuffer.putInt(it.value)
        }
        flushBuffer(pruneBuffer, channel)
    }

    /** Writes everything between 0 and the current position of [buffer] to [channel], then clears it. */
    private fun flushBuffer(buffer: ByteBuffer, channel: FileChannel) {
        buffer.flip()
        while (buffer.hasRemaining()) {
            channel.write(buffer)
        }
        buffer.clear()
    }

    /**
     * Reads a dictionary written by [save] (or a fastText model) from [buffer], replacing the
     * current contents. Then rebuilds the discard table and subwords.
     *
     * @return this dictionary
     */
    @Throws(IOException::class)
    fun load(buffer: AutoDataInput): Dictionary {
        size = buffer.readInt()
        nwords = buffer.readInt()
        nlabels = buffer.readInt()
        ntokens = buffer.readLong()
        pruneidxSize = buffer.readLong()

        wordList = ArrayList(size)
        // Reset the hash table: slots left over from a previous build/load would otherwise point
        // at indices of the old word list while find() probes the new one.
        word_hash_2_id.fill(-1)

        for (i in 0 until size) {
            val word = buffer.readUTF()
            val count = buffer.readLong()
            val type = EntryType.fromValue(buffer.readUnsignedByte().toInt())
            val e = Entry(word, count, type)
            wordList.add(e)
            word_hash_2_id[find(e.word)] = i
        }

        pruneidx.clear()
        for (i in 0 until pruneidxSize) {
            val first = buffer.readInt()
            val second = buffer.readInt()
            pruneidx.put(first, second)
        }

        initTableDiscard()
        initNgrams()
        return this
    }

    companion object {
        /** Word n-gram hash multiplier as an unsigned 64-bit value (kept for API compatibility). */
        val coeff = UnsignedLong.valueOf(116049371L)

        /** 2^64 - 2^32: upper bits of a sign-extended negative int32 seen as uint64 (kept for API compatibility). */
        val U64_START = UnsignedLong.valueOf("18446744069414584320")
    }
}

/**
 * Shared, initially empty list used as the default value of [Entry.subwords] until the dictionary
 * assigns each entry its own list.
 *
 * NOTE(review): this single instance is shared by every [Entry] and is mutable. Never call `add`
 * on it directly; always replace an entry's `subwords` with a fresh list first.
 */
val Empty_IntArrayList: IntArrayList = IntArrayList(0)

/**
 * One dictionary entry (a word or a label) together with its occurrence count.
 *
 * @property word surface form of the token
 * @property count number of occurrences; incremented in place while the dictionary is built
 * @property type whether the entry is a word or a label
 */
data class Entry(val word: String, var count: Long, val type: EntryType) {

    /**
     * Input ids of this entry: its own id followed by hashed char n-gram ids. It holds the shared
     * [Empty_IntArrayList] until the dictionary assigns a fresh list.
     */
    var subwords: IntArrayList = Empty_IntArrayList
}


/**
 * Kind of a dictionary [Entry]. [value] is the byte that `Dictionary.save` persists and that
 * [fromValue] decodes in `Dictionary.load`. Declaration order (words before labels) is also the
 * order used when the dictionary sorts its entries.
 *
 * NOTE(review): `value` is declared `var` only for backward compatibility; it should never be
 * reassigned.
 */
enum class EntryType constructor(var value: Int) {

    word(0), label(1);

    override fun toString(): String = when (value) {
        0 -> "word"
        1 -> "label"
        else -> "unknown"
    }

    companion object {

        internal var types = EntryType.values()

        /**
         * Returns the entry type whose [EntryType.value] equals [value]. The lookup is by value, not
         * by ordinal, so it does not depend on declaration order.
         *
         * @throws IllegalArgumentException when no entry type has that value
         */
        @Throws(IllegalArgumentException::class)
        fun fromValue(value: Int): EntryType =
                types.firstOrNull { it.value == value }
                        ?: throw IllegalArgumentException("Unknown EntryType enum value: $value")
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy