// com.mayabot.nlp.fasttext.dictionary.Dictionary.kt
package com.mayabot.nlp.fasttext.dictionary
import com.mayabot.nlp.common.IntArrayList
import com.mayabot.nlp.fasttext.args.Args
import com.mayabot.nlp.fasttext.args.ModelName
import com.mayabot.nlp.fasttext.utils.*
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.util.*
import kotlin.collections.HashMap
import kotlin.math.min
import kotlin.random.Random
const val HASH_C = 116049371
const val MAX_VOCAB_SIZE = 30000000
const val MAX_LINE_SIZE = 1024
const val coeff: ULong = 116049371u
const val U64_START: ULong = 18446744069414584320u
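// coeff is the multiplier fastText uses to fold successive token hashes into
// a word n-gram hash (see addWordNgrams below); HASH_C duplicates it as a
// signed Int. U64_START equals 0xFFFFFFFF00000000, the high bits of a
// negative 32-bit hash sign-extended to 64 bits; it is not used in this file.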
/**
* End-of-sentence token.
*/
const val EOS = "</s>"
/**
* Begin-of-word marker.
*/
const val BOW = "<"
/**
* End-of-word marker.
*/
const val EOW = ">"
/**
* Dictionary.
*
* The id space is layered as
* [
* words,
* labels,
* bucket
* ]
*
* As the code currently stands, labels and bucket are mutually exclusive: only one of the two is populated.
*
* @author jimichan
*/
@ExperimentalUnsignedTypes
class Dictionary(
val args: Args,
val onehotMap: FastWordMap,
val ntokens: Long,
val nwords: Int,
val nlabels: Int
) {
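// Id layout used throughout: word ids occupy [0, nwords); label ids occupy
// [nwords, nwords + nlabels) (see getLabel); subword and word-ngram bucket
// ids occupy [nwords, nwords + bucket) (see pushHash).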
private var pdiscard: FloatArray = FloatArray(0)
var pruneidxSize = -1L
private var pruneidx: HashMap<Int, Int> = HashMap()
private val maxn = args.maxn
private val minn = args.minn
private val bucket = args.bucket
private val bucketULong = bucket.toULong()
private val wordNgrams = args.wordNgrams
val size get() = onehotMap.size
fun isPruned() = pruneidxSize >= 0
private fun isWhiteSpaceChar(ch: Char) = ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r'
/**
* Collects the count of every label or word into a one-dimensional vector.
*/
fun getCounts(type: EntryType): LongArray {
val counts = if (EntryType.label == type)
LongArray(nlabels)
else
LongArray(nwords)
var i = 0
for ((_, count, type1) in onehotMap.wordList) {
if (type1 == type)
counts[i++] = count
}
return counts
}
/**
* tokens is a sequence of words and/or labels.
* labels receives the label ids, starting from 0.
* words receives the word ids together with their subword and word-ngram bucket ids.
*/
fun getLine(tokens: Iterable<String>, words: IntArrayList, labels: IntArrayList): Int {
val word_hashes = IntArrayList()
var ntokens = 0
words.clear()
labels.clear()
for (token in tokens) {
val h = token.fnv1aHash()
val wid = onehotMap.getId(token, h)
val type = if (wid < 0) onehotMap.getType(token) else onehotMap.getType(wid)
ntokens++
if (type == EntryType.word) {
addSubwords(words, token, wid)
word_hashes.add(h.toInt())
} else if (type == EntryType.label && wid >= 0) {
labels.add(wid - nwords)
}
}
addWordNgrams(words, word_hashes, wordNgrams)
return ntokens
}
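// Illustrative example (hypothetical input): for tokens ["the", "cat",
// "__label__pet"], `words` receives the ids (and, if maxn > 0, subword ids)
// of "the" and "cat" plus their 2-gram bucket id when wordNgrams >= 2,
// `labels` receives the 0-based id of "__label__pet", and 3 is returned.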
fun getLine(tokens: List<String>, words: IntArrayList, rng: Random): Int {
var ntokens = 0
words.clear()
for (token in tokens) {
val h = onehotMap.find(token)
val wid = onehotMap.wordHash2WordId[h]
if (wid < 0) continue
ntokens++
if (onehotMap.getType(wid) == EntryType.word && !discard(wid, rng.nextFloat())) {
words.add(wid)
}
if (ntokens > MAX_LINE_SIZE || token == EOS) {
break
}
}
return ntokens
}
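// This overload serves unsupervised training: a known word is kept only if
// the subsampling test passes (see discard below), and the line is cut off
// after MAX_LINE_SIZE tokens or at the EOS token.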
private fun addWordNgrams(line: IntArrayList,
hashes: IntArrayList,
n: Int) {
val hashSize = hashes.size()
for (i in 0 until hashSize) {
var h = hashes.get(i).toULong()
for (j in i + 1 until min(hashSize, i + n)) {
h = h * coeff + hashes.get(j).toULong()
pushHash(line, (h % bucketULong).toInt())
}
}
}
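// Worked example (illustrative): with n = 2 and token hashes [h0, h1], the
// 2-gram gets the rolling hash h0 * coeff + h1, reduced modulo bucket and
// pushed as a bucket id (offset by nwords inside pushHash).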
private fun addSubwords(line: IntArrayList,
token: String,
wid: Int) {
if (wid < 0) { // out of vocab
if (EOS != token) {
computeSubwords(BOW + token + EOW, line)
}
} else {
if (maxn <= 0) { // in vocab w/o subwords
line.add(wid)
} else { // in vocab w/ subwords
val ngrams = getSubwords(wid)
line.addAll(ngrams)
}
}
}
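// Out-of-vocabulary tokens contribute only character n-grams computed on the
// fly from "<token>"; in-vocabulary words contribute their precomputed
// subword list, or just their own id when maxn <= 0 (no subwords).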
fun getSubwords(id: Int): IntArrayList {
check(id >= 0)
check(id < nwords)
return onehotMap[id].subwords
}
fun getSubwords(word: String): IntArrayList {
val i = onehotMap.getId(word)
if (i >= 0) {
return onehotMap[i].subwords
}
val ngrams = IntArrayList()
if (word != EOS) {
computeSubwords(BOW + word + EOW, ngrams)
}
return ngrams
}
fun getSubwords(word: String, ngrams: IntArrayList, substrings: MutableList<String>) {
val i = onehotMap.getId(word)
ngrams.clear()
substrings.clear()
if (i >= 0) {
ngrams.add(i)
substrings.add(onehotMap[i].word)
} else {
ngrams.add(-1)
substrings.add(word)
}
if (word != EOS) {
computeSubwords(BOW + word + EOW, ngrams)
}
}
private fun discard(id: Int, rand: Float): Boolean {
check(id >= 0)
check(id < nwords)
return if (args.model == ModelName.sup) false else rand > pdiscard[id]
}
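// Subsampling: a word is dropped when the random draw exceeds pdiscard[id];
// supervised models (ModelName.sup) never discard tokens.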
fun getLabel(lid: Int): String {
check(lid >= 0)
check(lid < nlabels)
return onehotMap[lid + nwords].word
}
operator fun get(word: String) = onehotMap.getId(word)
fun getWordId(word: String) = onehotMap.getId(word)
fun init() {
initTableDiscard()
initNgrams()
}
fun initTableDiscard() {
val pdiscard = FloatArray(onehotMap.size)
val t = args.t
val wordList = onehotMap.wordList
for (i in 0 until onehotMap.size) {
val f = wordList[i].count * 1.0f / ntokens
pdiscard[i] = (kotlin.math.sqrt(t / f) + t / f).toFloat()
}
this.pdiscard = pdiscard
}
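// Worked example (illustrative numbers): with t = 1e-4 and a word whose
// frequency is f = count / ntokens = 1e-2, pdiscard = sqrt(t / f) + t / f
// = 0.1 + 0.01 = 0.11, so this frequent word survives subsampling only
// about 11% of the time.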
/**
* Initializes the character n-grams, i.e. the subwords.
*/
fun initNgrams() {
val wordList = onehotMap.wordList
for (id in 0 until onehotMap.size) {
val e = wordList[id]
val word = BOW + e.word + EOW
if (maxn == 0) {
// Optimization: with maxn == 0 a word never has subwords; this is the default for the classification model.
e.subwords = IntArrayList.from(id)
} else {
e.subwords = IntArrayList(1)
e.subwords.add(id)
if (e.word != EOS) {
computeSubwords(word, e.subwords)
}
}
}
}
private fun computeSubwords(word: String, ngrams: IntArrayList) {
val word_len = word.length
for (i in 0 until word_len) {
if (isWhiteSpaceChar(word[i])) {
continue
}
var ngram: StringBuilder? = null
var j = i
var n = 1
while (j < word_len && n <= maxn) {
if (ngram == null) {
ngram = StringBuilder()
}
ngram.append(word[j++])
while (j < word.length && isWhiteSpaceChar(word[j])) {
ngram.append(word[j++])
}
if (n >= minn && !(n == 1 && (i == 0 || j == word.length))) {
val h = (ngram.toString().fnv1aHash().toLong() % bucket).toInt()
if (h < 0) {
loggerln("computeSubwords h<0: $h on word: $word")
}
pushHash(ngrams, h)
}
n++
}
}
}
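// Example (the standard fastText illustration): with minn = 3 and maxn = 3,
// "where" is wrapped as "<where>" and yields the character 3-grams
// "<wh", "whe", "her", "ere", "re>". The (n == 1 && ...) guard keeps the
// bare "<" and ">" markers from being added when minn == 1.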
private fun pushHash(hashes: IntArrayList, id_: Int) {
var id = id_
if (pruneidxSize == 0L || id < 0) return
if (pruneidxSize > 0) {
if (pruneidx.containsKey(id)) {
id = pruneidx.getValue(id)
} else {
return
}
}
hashes.add(nwords + id)
}
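// With an unpruned dictionary (pruneidxSize < 0) every bucket id is kept;
// after pruning, only ids present in pruneidx survive, remapped to their
// compacted index (none at all when pruneidxSize == 0). The stored id is
// offset by nwords so buckets sit after the words in the embedding matrix.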
@Throws(IOException::class)
fun save(channel: FileChannel) {
channel.writeInt(this.size)
channel.writeInt(nwords)
channel.writeInt(nlabels)
channel.writeLong(ntokens)
channel.writeLong(pruneidxSize)
val buffer = ByteBuffer.allocate(1024 * 1024)
val em = buffer.capacity() * 0.25f
val wordList = onehotMap.wordList
for (entry in wordList) {
buffer.writeUTF(entry.word)
buffer.putLong(entry.count)
buffer.put(entry.type.value.toByte())
if (buffer.remaining() < em) {
buffer.flip()
while (buffer.hasRemaining()) {
channel.write(buffer)
}
buffer.clear()
}
}
buffer.flip()
while (buffer.hasRemaining()) {
channel.write(buffer)
}
// One (key, value) pair of Ints per pruneidx entry (8 bytes), written
// sequentially so it matches the pair-wise read loop in loadModel.
val buffer2 = ByteBuffer.allocate(pruneidx.size * 8)
pruneidx.forEach {
buffer2.putInt(it.key)
buffer2.putInt(it.value)
}
buffer2.flip()
channel.write(buffer2)
}
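// On-disk layout written above: size, nwords, nlabels as Ints; ntokens and
// pruneidxSize as Longs; one (UTF word, Long count, Byte type) record per
// entry; then pruneidxSize (key, value) Int pairs. loadModel below reads
// the same layout back.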
fun getWord(wid: Int) = onehotMap.getWord(wid)
fun getWordEntity(wid: Int) = onehotMap.get(wid)
companion object {
@Throws(IOException::class)
fun loadModel(args: Args, buffer: AutoDataInput): Dictionary {
// wordList.clear();
// word2int_.clear();
val size = buffer.readInt()
val nwords = buffer.readInt()
val nlabels = buffer.readInt()
val ntokens = buffer.readLong()
val pruneidxSize = buffer.readLong()
// word_hash_2_id = new LongIntScatterMap(size_);
val wordList = ArrayList<Entry>(size)
for (i in 0 until size) {
val e = Entry(buffer.readUTF(), buffer.readLong(), EntryType.fromValue(buffer.readUnsignedByte().toInt()))
wordList.add(e)
}
val pruneidx = HashMap<Int, Int>()
for (i in 0 until pruneidxSize) {
val first = buffer.readInt()
val second = buffer.readInt()
pruneidx.put(first, second)
}
// The wordHash2WordId table is sized at size / 0.75, i.e. entries fill at most 75% of it (load factor 0.75).
val dict = Dictionary(args,
FastWordMap(
IntArray((size.toFloat() / 0.75).toInt()) { -1 },
wordList),
ntokens,
nwords,
nlabels
)
// saber: moved these two assignments ahead of the two init calls below; an earlier fix for a model-loading bug had missed them.
dict.pruneidxSize = pruneidxSize
dict.pruneidx = pruneidx
dict.initTableDiscard()
dict.initNgrams()
dict.onehotMap.initWordHash2WordId()
return dict
}
}
}
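// Usage sketch (hypothetical caller; `args` and `input` are assumed to come
// from the surrounding model-loading code):
//
// val dict = Dictionary.loadModel(args, input)
// val wid = dict.getWordId("hello")
// val subwords = dict.getSubwords("hello")
//
// For an in-vocabulary word, subwords holds the word id followed by its
// character n-gram bucket ids (see initNgrams).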