
com.mayabot.mynlp.fasttext.FastText.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of fastText4j Show documentation
Show all versions of fastText4j Show documentation
Java (Kotlin) implementation of Facebook's FastText
The newest version!
package com.mayabot.mynlp.fasttext
import com.carrotsearch.hppc.IntArrayList
import com.google.common.base.Charsets
import com.google.common.base.Stopwatch
import com.google.common.collect.ImmutableList
import com.google.common.collect.Iterables
import com.google.common.collect.Lists
import com.google.common.collect.Sets
import com.google.common.io.Files
import com.google.common.primitives.Floats
import com.mayabot.blas.*
import com.mayabot.blas.Vector
import java.io.File
import java.io.IOException
import java.io.InputStream
import java.text.DecimalFormat
import java.util.*
import java.util.concurrent.TimeUnit
import kotlin.math.exp
import kotlin.system.exitProcess
// Compile-time constants describing the fastText binary format.
// NOTE(review): presumably the format version and the magic number expected at the
// start of a model file written by the C++ fastText tool -- confirm against the loader.
const val FASTTEXT_VERSION = 12
const val FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314
/**
 * Mutable pair of a [Float] and an [Int].
 *
 * Both components are annotated with [JvmField], so Java callers see them as plain fields.
 */
data class FloatIntPair(
    @JvmField var first: Float,
    @JvmField var second: Int
)
/**
 * Mutable pair of a [Float] and a [String].
 *
 * Both components are annotated with [JvmField], so Java callers see them as plain fields.
 */
data class FloatStringPair(
    @JvmField var first: Float,
    @JvmField var second: String
) {
    /** Renders the pair as `[second,first]`, i.e. the string component comes first. */
    override fun toString(): String = "[$second,$first]"
}
class FastText(internal val args: Args,
internal val dict: Dictionary,
internal val model: Model
) {
/**
* Whether the model is quantized, i.e. whether the hidden layer / left (input) matrix /
* word vectors are stored in quantized form. Mirrors [Model.quant].
*/
val quant = model.quant
// Shortcut to the model's input matrix.
val input = model.input
// Shortcut to the model's output matrix.
val output = model.output
// Cache of per-word vectors; it is filled on first use by nearestNeighbor / analogies.
lateinit var wordVectors: FloatMatrix
/**
 * Predicts classification labels for a tokenized text.
 *
 * The tokens are extended with the `EOS` marker and converted to ids through the
 * dictionary. When no id is produced an empty list is returned. Otherwise the model
 * is asked for its `k` best candidates and each candidate score is passed through
 * `exp` before being paired with its label string.
 *
 * @param tokens the tokens of the text to classify
 * @param k the number of candidates requested from the model
 * @return `(exp(score), label)` pairs in the order the model returned them
 */
fun predict(tokens: Iterable<String>, k: Int): List<FloatStringPair> {
    val withEos = Iterables.concat(tokens, listOf(EOS))
    val words = IntArrayList()
    val labels = IntArrayList()
    dict.getLine(withEos, words, labels)
    if (words.isEmpty) {
        return ImmutableList.of()
    }
    val hidden = MutableByteBufferVector(args.dim)
    val output = MutableByteBufferVector(dict.nlabels())
    // The element type must be spelled out: nothing else lets the compiler infer it.
    val modelPredictions = Lists.newArrayListWithCapacity<FloatIntPair>(k)
    model.predict(words, k, modelPredictions, hidden, output)
    return modelPredictions.map { FloatStringPair(exp(it.first), dict.getLabel(it.second)) }
}
/**
 * Scans every dictionary word and keeps the `k` words whose row in [wordVectors] has
 * the largest dot product with [queryVec] (divided by the query norm).
 *
 * Words contained in [sets] are skipped *before* they compete for a slot, so they no
 * longer displace valid neighbours (previously they were filtered afterwards, which
 * made callers receive fewer than `k` results).
 *
 * @param wordVectors matrix with one row per dictionary word
 * @param queryVec the query vector
 * @param k maximum number of neighbours to return; values `<= 0` yield an empty list
 * @param sets words that must never be returned (typically the query words themselves)
 * @return up to `k` `(similarity, word)` pairs, best first
 */
private fun findNN(
    wordVectors: FloatMatrix,
    queryVec: Vector,
    k: Int,
    sets: Set<String>
): List<FloatStringPair> {
    // Without this guard the slot array is empty and the index below would be -1.
    if (k <= 0) {
        return emptyList()
    }
    var queryNorm = queryVec.norm2()
    if (Math.abs(queryNorm) < 1e-8) {
        // Avoid dividing by (almost) zero for an all-zero query.
        queryNorm = 1f
    }
    // Fixed-size, descending-sorted candidate slots; -1f marks a slot that was never filled.
    val mostSimilar = Array(k) { FloatStringPair(-1f, "") }
    val lastSlot = k - 1
    for (i in 0 until dict.nwords()) {
        val word = dict.getWord(i)
        if (word in sets) {
            continue
        }
        val dp = wordVectors[i] * queryVec / queryNorm
        val weakest = mostSimilar[lastSlot]
        if (dp > weakest.first) {
            // Overwrite the weakest candidate and restore the descending order.
            weakest.first = dp
            weakest.second = word
            mostSimilar.sortByDescending { it.first }
        }
    }
    return mostSimilar.filter { it.first != -1f }
}
/**
 * Builds the [wordVectors] cache on first use and prints how long that took.
 * Later calls return immediately.
 */
private fun ensureWordVectors() {
    if (this::wordVectors.isInitialized) {
        return
    }
    val stopwatch = Stopwatch.createStarted()
    wordVectors = FloatMatrix.floatArrayMatrix(dict.nwords, args.dim).apply {
        preComputeWordVectors(this)
    }
    stopwatch.stop()
    println("Init wordVectors martix use time ${stopwatch.elapsed(TimeUnit.MILLISECONDS)} ms")
}

/**
 * Nearest-neighbour query: finds the dictionary words most similar to [wordQuery].
 * The query word itself is excluded from the result.
 *
 * @param wordQuery the word to look up
 * @param k the number of neighbours requested
 * @return up to `k` `(similarity, word)` pairs
 */
fun nearestNeighbor(wordQuery: String, k: Int): List<FloatStringPair> {
    ensureWordVectors()
    val queryVec = getWordVector(wordQuery)
    val sets = hashSetOf(wordQuery)
    return findNN(wordVectors, queryVec, k, sets)
}

/**
 * Analogy query for the triplet `A - B + C`.
 * The three input words are excluded from the result.
 *
 * @param A word whose vector is added
 * @param B word whose vector is subtracted
 * @param C word whose vector is added
 * @param k the number of neighbours requested
 * @return up to `k` `(similarity, word)` pairs
 */
fun analogies(A: String, B: String, C: String, k: Int): List<FloatStringPair> {
    ensureWordVectors()
    val buffer = Vector.floatArrayVector(args.dim)
    val query = Vector.floatArrayVector(args.dim)
    getWordVector(buffer, A)
    query += buffer
    getWordVector(buffer, B)
    query += -1f to buffer
    getWordVector(buffer, C)
    query += buffer
    val sets = Sets.newHashSet(A, B, C)
    return findNN(wordVectors, query, k, sets)
}
/**
 * Fills [wordVectors] with one unit-length vector per dictionary word.
 *
 * Each word vector is divided by its own norm before being stored. Rows whose
 * word vector has a norm of zero keep the zero fill value.
 *
 * @param wordVectors destination matrix, overwritten by this call
 */
private fun preComputeWordVectors(wordVectors: MutableFloatMatrix) {
    val scratch = Vector.floatArrayVector(args.dim)
    wordVectors.fill(0f)
    repeat(dict.nwords()) { row ->
        getWordVector(scratch, dict.getWord(row))
        val length = scratch.norm2()
        if (length > 0) {
            wordVectors[row] += 1.0f / length to scratch
        }
    }
}
/**
 * Writes the vector of [word] into [vec].
 *
 * The vector is the average of the input vectors of all subword ids the dictionary
 * returns for the word. [vec] is reset to zero first and stays zero when there are
 * no subwords.
 *
 * @param vec destination vector, overwritten by this call
 * @param word the word to look up
 */
fun getWordVector(vec: MutableVector, word: String) {
    vec.zero()
    val subwords = dict.getSubwords(word)
    val ids = subwords.buffer
    val count = subwords.size()
    for (pos in 0 until count) {
        addInputVector(vec, ids[pos])
    }
    if (count > 0) {
        vec *= 1.0f / count
    }
}
/**
 * Returns the vector of [word] in a freshly allocated vector of size `args.dim`.
 */
fun getWordVector(word: String): Vector =
    MutableByteBufferVector(args.dim).also { getWordVector(it, word) }
/**
 * Computes the vector of a tokenized sentence.
 *
 * @param tokens the tokens of the sentence
 * @return a freshly allocated vector of size `args.dim` holding the sentence vector
 */
fun getSentenceVector(tokens: Iterable<String>): Vector {
    val svec = MutableByteBufferVector(args.dim)
    getSentenceVector(svec, tokens)
    return svec
}
/**
 * Writes the vector of a tokenized sentence into [svec].
 *
 * For supervised models the result is the average of the input vectors of all ids
 * the dictionary produces for the tokens. For every other model type it is the
 * average of the L2-normalized word vectors; words whose vector has a norm of zero
 * are ignored. [svec] is reset to zero first.
 *
 * @param svec destination vector, overwritten by this call
 * @param tokens the tokens of the sentence
 */
private fun getSentenceVector(svec: MutableVector, tokens: Iterable<String>) {
    svec.zero()
    if (args.model == ModelName.sup) {
        val line = IntArrayList()
        val labels = IntArrayList()
        dict.getLine(tokens, line, labels)
        for (i in 0 until line.size()) {
            addInputVector(svec, line.get(i))
        }
        if (!line.isEmpty) {
            svec *= (1.0f / line.size())
        }
    } else {
        val vec = MutableByteBufferVector(args.dim)
        var count = 0
        for (word in tokens) {
            getWordVector(vec, word)
            val norm = vec.norm2()
            if (norm > 0) {
                // Normalize each word vector before adding it to the running sum.
                vec *= (1.0f / norm)
                svec += vec
                count++
            }
        }
        if (count > 0) {
            svec *= (1.0f / count)
        }
    }
}
/**
 * Adds the input vector with index [ind] to [vec], reading from the quantized input
 * matrix when the model is quantized and from the plain input matrix otherwise.
 */
private fun addInputVector(vec: MutableVector, ind: Int) {
    if (!quant) {
        vec += input[ind]
        return
    }
    model.qinput.addToVector(vec, ind)
}
/**
 * Saves all word vectors as a UTF-8 text file.
 *
 * The first line holds the number of words and the vector dimension. Every following
 * line holds a word followed by its vector components, each followed by one space.
 * Numbers are always written with `.` as the decimal separator, independent of the
 * JVM default locale, so the file can be read back on any machine.
 *
 * @param fileName target path; `.vec` is appended unless the name already ends with `vec`.
 *                 An existing file is replaced and missing parent directories are created.
 */
@Throws(Exception::class)
fun saveVectors(fileName: String) {
    val target = File(if (fileName.endsWith("vec")) fileName else "$fileName.vec")
    if (target.exists()) {
        target.delete()
    }
    // parentFile is null for a bare file name without a directory part.
    target.parentFile?.mkdirs()
    val vec = MutableByteBufferVector(args.dim)
    // A locale-neutral symbol set prevents "0,123"-style output on e.g. German or French JVMs.
    val df = DecimalFormat("0.#####", java.text.DecimalFormatSymbols(Locale.ROOT))
    Files.asByteSink(target).asCharSink(Charsets.UTF_8).openBufferedStream().use { writer ->
        writer.write("${dict.nwords()} ${args.dim}\n")
        for (i in 0 until dict.nwords()) {
            val word = dict.getWord(i)
            getWordVector(vec, word)
            writer.write(word)
            writer.write(" ")
            for (j in 0 until vec.length()) {
                writer.write(df.format(vec[j].toDouble()))
                writer.write(" ")
            }
            writer.write("\n")
        }
    }
}
/**
 * Saves the model in this library's own multi-file format.
 *
 * [path] is treated as a directory: anything already there is deleted recursively,
 * then the directory is recreated and filled with `dict.bin`, `args.bin`, one input
 * matrix file (`qinput.matrix` or `input.matrix`) and one output matrix file
 * (`qoutput.matrix` or `output.matrix`).
 *
 * @param path the target directory
 */
@Throws(Exception::class)
fun saveModel(path: String) {
    val dir = File(path)
    if (dir.exists()) {
        dir.deleteRecursively()
    }
    dir.mkdirs()

    // Opens a write channel for a file inside the target directory.
    fun channelFor(name: String) = File(dir, name).outputStream().channel

    // dictionary
    channelFor("dict.bin").use { channel -> dict.save(channel) }
    // arguments
    channelFor("args.bin").use { channel -> args.save(channel) }

    if (quant) {
        channelFor("qinput.matrix").use { channel -> model.qinput.save(channel) }
    } else {
        // plain float input matrix, prefixed with its dimensions
        channelFor("input.matrix").use { channel ->
            channel.writeInt(model.input.rows())
            channel.writeInt(model.input.cols())
            model.input.write(channel)
        }
    }

    if (quant && model.quantOut) {
        channelFor("qoutput.matrix").use { channel -> model.qoutput!!.save(channel) }
    } else {
        // plain float output matrix, prefixed with its dimensions
        channelFor("output.matrix").use { channel ->
            channel.writeInt(model.output.rows())
            channel.writeInt(model.output.cols())
            model.output.write(channel)
        }
    }
}
companion object {
/**
 * Loads a model file saved by Facebook's official C program; both bin and ftz
 * models are supported.
 *
 * @param modelFilePath path of the model file
 */
@JvmStatic
@Throws(Exception::class)
fun loadFasttextBinModel(modelFilePath: String): FastText =
    LoadFastTextFromClangModel.loadCModel(modelFilePath)
/**
 * Loads a model file saved by Facebook's official C program; both bin and ftz
 * models are supported.
 *
 * @param modelFile the model file
 */
@JvmStatic
@Throws(Exception::class)
fun loadFasttextBinModel(modelFile: File): FastText =
    LoadFastTextFromClangModel.loadCModel(modelFile)
/**
 * Loads a model saved by Facebook's official C program from a stream; both bin and
 * ftz models are supported.
 *
 * @param modelStream stream delivering the bytes of the model file
 */
@JvmStatic
@Throws(Exception::class)
fun loadFasttextBinModel(modelStream: InputStream): FastText =
    LoadFastTextFromClangModel.loadCModel(modelStream)
// Shorthand: opens this file through AutoDataInput.open.
// NOTE(review): nothing in this file closes the returned object -- confirm whether it needs closing.
private fun File.openAutoDataInput() = AutoDataInput.open(this)
/**
 * Loads a model that was saved by this Java/Kotlin implementation.
 *
 * [modelPath] must be a directory containing `args.bin`, `dict.bin`, an input matrix
 * (`qinput.matrix` for a quantized model, otherwise `input.matrix`) and an output
 * matrix (`qoutput.matrix` when present in a quantized model, otherwise `output.matrix`).
 *
 * @param modelPath the model directory
 * @param mmap forwarded to `FloatMatrix.loadMatrix` for the non-quantized matrices
 * @return the loaded model
 * @throws IllegalArgumentException when [modelPath] does not exist or is a regular file
 *         (earlier versions terminated the whole JVM with exit status 0 instead)
 * @throws RuntimeException when the dictionary is pruned but the input is not quantized
 */
@JvmOverloads
@JvmStatic
fun loadModel(modelPath: String, mmap: Boolean = false): FastText {
    val dir = File(modelPath)
    if (!dir.exists() || dir.isFile) {
        // A library must never shut down the host process; report the problem to the caller.
        throw IllegalArgumentException("error file $dir")
    }
    val args = Args().loadClang(File(dir, "args.bin").openAutoDataInput())
    val dictionary = Dictionary(args).load(File(dir, "dict.bin").openAutoDataInput())

    fun loadMatrix(file: File): FloatMatrix = FloatMatrix.loadMatrix(file, mmap)

    // The presence of qinput.matrix is what marks a saved model as quantized.
    val quantInput = File(dir, "qinput.matrix").exists()
    var input: FloatMatrix = FloatMatrix.floatArrayMatrix(0, 0)
    var qinput: QMatrix? = null
    if (quantInput) {
        qinput = QMatrix.load(File(dir, "qinput.matrix").openAutoDataInput())
    } else {
        input = loadMatrix(File(dir, "input.matrix"))
    }

    if (!quantInput && dictionary.isPruned()) {
        throw RuntimeException("Invalid model file.\n"
                + "Please download the updated model from www.fasttext.cc.\n"
                + "See issue #332 on Github for more information.\n")
    }

    var output: FloatMatrix = FloatMatrix.floatArrayMatrix(0, 0)
    var qoutput: QMatrix? = null
    val qout = File(dir, "qoutput.matrix").exists()
    if (quantInput && qout) {
        qoutput = QMatrix.load(File(dir, "qoutput.matrix").openAutoDataInput())
    } else {
        output = loadMatrix(File(dir, "output.matrix"))
    }

    val model = Model(input, output, args, 0)
    if (quantInput) {
        model.setQuantizePointer(qinput, qoutput)
    }
    // Supervised models use the label counts as targets, all other models the word counts.
    if (args.model == ModelName.sup) {
        model.setTargetCounts(dictionary.getCounts(EntryType.label))
    } else {
        model.setTargetCounts(dictionary.getCounts(EntryType.word))
    }
    return FastText(args, dictionary, model)
}
/**
 * Trains a model from a file by delegating to `FastTextTrain.train`.
 *
 * @param trainFile the training data file
 * @param model_name the model type; supervised by default
 * @param args training arguments; defaults to a fresh `TrainArgs()`
 */
@JvmOverloads
@Throws(Exception::class)
@JvmStatic
fun train(
    trainFile: File,
    model_name: ModelName = ModelName.sup,
    args: TrainArgs = TrainArgs()
): FastText = FastTextTrain().train(trainFile, model_name, args)
/**
 * Trains a model from an example source by delegating to `FastTextTrain.train`.
 *
 * @param source provider of the training examples
 * @param model_name the model type; supervised by default
 * @param args training arguments; defaults to a fresh `TrainArgs()`
 */
@JvmOverloads
@Throws(Exception::class)
@JvmStatic
fun train(
    source: TrainExampleSource,
    model_name: ModelName = ModelName.sup,
    args: TrainArgs = TrainArgs()
): FastText = FastTextTrain().train(source, model_name, args)
/**
 * Quantizes the input matrix of a supervised (classification) model.
 *
 * A model that is already quantized is returned unchanged after printing a notice.
 * Otherwise a new [FastText] is built that shares the arguments, dictionary and
 * output matrix of [fastText] but reads its input from a newly quantized matrix.
 * The output matrix is not quantized.
 *
 * NOTE(review): unlike `loadModel`, no `setTargetCounts` call is made on the new
 * model here -- confirm that this is not required for prediction.
 *
 * @param fastText the model to quantize
 * @param dsub forwarded to the `QMatrix` constructor
 * @param qnorm forwarded to the `QMatrix` constructor
 * @throws RuntimeException when [fastText] is not a supervised model
 */
@JvmOverloads
@Throws(Exception::class)
@JvmStatic
fun quantize(
    fastText: FastText,
    dsub: Int = 2,
    qnorm: Boolean = false
): FastText {
    if (fastText.quant) {
        println("该模型已经被量化过")
        return fastText
    }
    if (fastText.args.model != ModelName.sup) {
        throw RuntimeException("Only for sup model")
    }
    val source = fastText.input
    val quantized = QMatrix(source.rows(), source.cols(), dsub, qnorm)
    quantized.quantize(source.toMutableFloatMatrix())
    // The plain input matrix is replaced by an empty placeholder in the new model.
    val quantizedModel = Model(FloatMatrix.floatArrayMatrix(0, 0), fastText.output, fastText.args, 0)
    quantizedModel.setQuantizePointer(quantized, null)
    return FastText(fastText.args, fastText.dict, quantizedModel)
}
}
}
class Model(val input: FloatMatrix
, val output: FloatMatrix,
args_: Args,
seed: Int) : BaseModel(args_, seed, output.rows()) {
/**
* Whether this is a product-quantized model (input side).
*/
var quant: Boolean = false
/**
* Whether the right (output) side is quantized.
*/
var quantOut = false
// Quantized input matrix; replaced by setQuantizePointer when a non-null matrix is supplied.
var qinput = QMatrix()
// Quantized output matrix; replaced by setQuantizePointer when a non-null matrix is supplied.
var qoutput = QMatrix()
/**
* Hidden size, i.e. the dimension of the vectors.
*/
private val hsz: Int = args_.dim // dim
// Orders pairs by their float component, largest first (note the swapped arguments).
private val comparePairs = { o1: FloatIntPair, o2: FloatIntPair -> Floats.compare(o2.first, o1.first) }
/**
 * Natural logarithm of `d + 1e-5`; the small offset keeps the result finite for `d == 0`.
 *
 * @return the logarithm as a [Double]
 */
fun std_log(d: Float): Double {
    return Math.log(d + 1e-5)
}
/**
 * Installs quantized matrices on this model.
 *
 * A non-null [qinput] marks the input as quantized and replaces the quantized input
 * matrix. A non-null [qoutput] marks the output as quantized, replaces the quantized
 * output matrix and sets `outputMatrixSize` to its `m`. A null argument leaves the
 * corresponding state untouched.
 */
fun setQuantizePointer(qinput: QMatrix?, qoutput: QMatrix?) {
    if (qinput != null) {
        quant = true
        this.qinput = qinput
    }
    // A non-null qoutput means the output side is quantized as well.
    if (qoutput != null) {
        quantOut = true
        this.qoutput = qoutput
        this.outputMatrixSize = qoutput.m
    }
}
/**
 * Computes the `k` best predictions for the given input ids.
 *
 * The hidden vector is computed from [input]; depending on the configured loss the
 * candidates are collected either by walking the tree ([dfs], hierarchical softmax)
 * or by scoring every output ([findKBest]). On return [heap] is sorted by score,
 * best first.
 *
 * @param input ids of the input tokens
 * @param k number of predictions requested; must be greater than zero
 * @param heap receives the `(score, index)` results
 * @param hidden work vector for the hidden layer; overwritten
 * @param output work vector for the output layer; overwritten by [findKBest]
 */
fun predict(
    input: IntArrayList,
    k: Int,
    heap: MutableList<FloatIntPair>,
    hidden: MutableVector,
    output: MutableVector
) {
    checkArgument(k > 0)
    computeHidden(input, hidden)
    if (args_.loss == LossName.hs) {
        dfs(k, 2 * outputMatrixSize - 2, 0.0f, heap, hidden)
    } else {
        findKBest(k, heap, hidden, output)
    }
    Collections.sort(heap, comparePairs)
}
/**
 * Scores every output with a softmax and keeps the `k` best in [heap].
 *
 * [heap] is maintained as a list sorted by score, best first; each stored score is
 * `std_log` of the softmax value.
 *
 * @param k maximum number of entries to keep in [heap]
 * @param heap receives the `(score, output index)` results
 * @param hidden the hidden-layer vector
 * @param output work vector for the softmax values; overwritten
 */
fun findKBest(k: Int, heap: MutableList<FloatIntPair>, hidden: Vector, output: MutableVector) {
    computeOutputSoftmax(hidden, output)
    for (i in 0 until outputMatrixSize) {
        val logProb = std_log(output[i]).toFloat()
        // A full list whose weakest entry already beats this candidate cannot change.
        if (heap.size == k && logProb < heap[heap.size - 1].first) {
            continue
        }
        heap.add(FloatIntPair(logProb, i))
        Collections.sort(heap, comparePairs)
        if (heap.size > k) {
            // The list was sorted just above, so the last entry is the weakest one.
            heap.removeAt(heap.size - 1)
        }
    }
}
/**
 * Depth-first walk over the tree that collects the `k` best leaves in [heap].
 *
 * The walk stops at a node as soon as [heap] is full and [score] is already below the
 * weakest stored score. At a leaf (a node whose `left` and `right` are both -1) the
 * pair `(score, node)` is inserted. At an inner node the sigmoid of the node's dot
 * product with [hidden] is computed, and the walk continues into the left child with
 * `score + std_log(1 - f)` and into the right child with `score + std_log(f)`.
 *
 * @param k maximum number of entries to keep in [heap]
 * @param node index of the current tree node
 * @param score accumulated score of the path to [node]
 * @param heap receives the `(score, leaf index)` results, sorted best first
 * @param hidden the hidden-layer vector
 */
fun dfs(k: Int, node: Int, score: Float, heap: MutableList<FloatIntPair>, hidden: Vector) {
    if (heap.size == k && score < heap[heap.size - 1].first) {
        return
    }
    val current = tree[node]
    if (current.left == -1 && current.right == -1) {
        heap.add(FloatIntPair(score, node))
        Collections.sort(heap, comparePairs)
        if (heap.size > k) {
            // The list was sorted just above, so the last entry is the weakest one.
            heap.removeAt(heap.size - 1)
        }
        return
    }
    // Inner nodes are stored after the leaves, hence the index shift by outputMatrixSize.
    val activation = if (quant && quantOut) {
        qoutput.dotRow(hidden, node - outputMatrixSize)
    } else {
        output[node - outputMatrixSize] * hidden
    }
    val f = 1.0f / (1 + exp(-activation))
    dfs(k, current.left, score + std_log(1.0f - f).toFloat(), heap, hidden)
    dfs(k, current.right, score + std_log(f).toFloat(), heap, hidden)
}
/**
 * Computes the hidden vector as the average of the input vectors of all ids in [input].
 *
 * [hidden] must have length `hsz`; it is reset to zero before the vectors are summed.
 * The sum is scaled by `1 / input.size()`, so [input] is expected to be non-empty.
 */
private fun computeHidden(input: IntArrayList, hidden: MutableVector) {
    checkArgument(hidden.length() == hsz)
    hidden.zero()
    val ids = input.buffer
    val count = input.size()
    for (pos in 0 until count) {
        val id = ids[pos]
        if (quant) {
            qinput.addToVector(hidden, id)
        } else {
            hidden += this.input[id]
        }
    }
    hidden *= (1.0f / input.size())
}
/**
 * Fills [output] with the softmax over the output layer for the given [hidden] vector.
 *
 * The raw activations come from the quantized output matrix when both input and
 * output are quantized, otherwise from the plain output matrix. The maximum
 * activation is subtracted before exponentiation.
 */
private fun computeOutputSoftmax(hidden: Vector, output: MutableVector) {
    if (quant && quantOut) {
        matrixMulVector(qoutput, hidden, output)
    } else {
        matrixMulVector(this.output, hidden, output)
    }
    val size = outputMatrixSize

    // Largest activation, used as the shift for the exponentials.
    var peak = output[0]
    for (i in 1 until size) {
        peak = Math.max(output.get(i), peak)
    }

    // Exponentiate in place and accumulate the normalizer.
    var total = 0.0f
    for (i in 0 until size) {
        output[i] = Math.exp((output[i] - peak).toDouble()).toFloat()
        total += output[i]
    }

    // Normalize so that the entries sum to one.
    for (i in 0 until size) {
        output[i] = output[i] / total
    }
}
/**
 * Multiplies a quantized matrix with a vector: `target[i]` becomes the dot product of
 * row `i` of [matrix] with [v].
 *
 * [target] must have length `matrix.m` and [v] must have length `matrix.n`.
 */
private fun matrixMulVector(matrix: QMatrix, v: Vector, target: MutableVector) {
    checkArgument(matrix.m == target.length())
    checkArgument(matrix.n == v.length())
    repeat(matrix.m) { row ->
        target[row] = matrix.dotRow(v, row)
    }
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy