
// com.mayabot.mynlp.fasttext.FastTextTrain.kt (Maven / Gradle / Ivy)
package com.mayabot.mynlp.fasttext
import com.carrotsearch.hppc.IntArrayList
import com.google.common.base.CharMatcher
import com.google.common.base.Charsets
import com.google.common.base.Splitter
import com.google.common.collect.Lists
import com.google.common.io.Files
import com.google.common.primitives.Ints
import com.google.common.util.concurrent.AtomicDouble
import com.mayabot.blas.*
import java.io.*
import java.util.concurrent.atomic.AtomicLong
// Lookup-table sizes and limits used by the training code.
// NOTE(review): presumably consumed by BaseModel's sigmoid/log tables and negative sampling — confirm there.

/** Number of entries in the precomputed sigmoid table. */
const val SIGMOID_TABLE_SIZE: Int = 512

/** Sigmoid table range limit (presumably the table covers [-MAX_SIGMOID, MAX_SIGMOID]). */
const val MAX_SIGMOID: Int = 8

/** Number of entries in the precomputed log table. */
const val LOG_TABLE_SIZE: Int = 512

/** Size of the negative-sampling table. */
const val NEGATIVE_TABLE_SIZE: Int = 10_000_000
/**
 * Trains fastText models: supervised text classification ([ModelName.sup]) or
 * unsupervised word vectors ([ModelName.cbow], [ModelName.sg]).
 *
 * Training runs on `args.thread` worker threads ([TrainThread]). Each worker builds its own
 * [TrainModel] over the shared [input] / [output] matrices and updates them concurrently; no
 * locking around those updates is visible in this class.
 *
 * NOTE(review): an instance keeps per-run state ([args], [dict], matrices, counters), so it
 * presumably should not be used for two trainings at the same time.
 */
class FastTextTrain {

    lateinit var args: Args
    lateinit var dict: Dictionary

    /** Tokens consumed so far by all workers; drives progress and the learning-rate decay. */
    private var tokenCount = AtomicLong(0)

    /** Latest loss published by worker 0; -1 until the first report. */
    private var loss = AtomicDouble(-1.0)

    /** Wall-clock start of the current training run, in epoch milliseconds. */
    private var startTime = 0L

    var source: TrainExampleSource? = null
    var input: MutableFloatMatrix? = null
    var output: MutableFloatMatrix? = null

    /**
     * Trains a model from a text [file] whose lines are tokenised with `whitespaceSplitter`.
     * See the [TrainExampleSource] overload for details.
     */
    fun train(file: File, modelName: ModelName, trainArgs: TrainArgs): FastText {
        return this.train(FileTrainExampleSource(whitespaceSplitter, file), modelName, trainArgs)
    }

    /**
     * Builds the dictionary from [file], initialises the matrices and trains a [modelName] model.
     *
     * @param file training examples; closed when training finishes.
     * @param modelName `sup`, `cbow` or `sg`.
     * @param trainArgs overrides; `null` fields keep the [Args] defaults.
     * @return the trained model.
     * @throws RuntimeException if `trainArgs.pretrainedVectors` is set but not readable, or if a
     *   training thread fails.
     */
    fun train(file: TrainExampleSource, modelName: ModelName, trainArgs: TrainArgs): FastText {
        args = Args().apply {
            model = modelName
            if (modelName == ModelName.sup) {
                // Defaults for supervised classification.
                minCount = 1
                loss = LossName.softmax
                minn = 0
                maxn = 0
                lr = 0.1
            }
            thread = trainArgs.thread ?: thread
            dim = trainArgs.dim ?: dim
            epoch = trainArgs.epoch ?: epoch
            loss = trainArgs.loss ?: loss
            lr = trainArgs.lr ?: lr
            lrUpdateRate = trainArgs.lrUpdateRate ?: lrUpdateRate
            neg = trainArgs.neg ?: neg
            ws = trainArgs.ws ?: ws
            // -wordNgrams max length of word ngram [1]
            // -maxn max length of char ngram [0]
            // With neither word n-grams nor char n-grams no hash buckets are needed.
            if (wordNgrams <= 1 && maxn == 0) {
                bucket = 0
            }
        }

        dict = com.mayabot.mynlp.fasttext.Dictionary(args)
        dict.buildFromFile(file)

        val pretrainedVectors: File? = trainArgs.pretrainedVectors
            ?.takeIf { it.isNotBlank() }
            ?.let { path ->
                val candidate = File(path)
                if (!(candidate.exists() && candidate.canRead())) {
                    throw RuntimeException("Not found File $path")
                }
                candidate
            }

        // The input matrix must be created first: loadVectors() adds the pretrained words to
        // the dictionary and re-thresholds it, which changes dict.nwords() (and possibly word
        // ids). The output matrix and target counts must be sized from the final dictionary.
        val input: MutableFloatMatrix
        if (pretrainedVectors != null) {
            input = loadVectors(pretrainedVectors)
        } else {
            input = FloatMatrix.floatArrayMatrix(dict.nwords() + args.bucket, args.dim)
            input.uniform(1.0f / args.dim)
        }

        val output: MutableFloatMatrix = if (ModelName.sup == args.model) {
            FloatMatrix.floatArrayMatrix(dict.nlabels(), args.dim) // classification: one row per label
        } else {
            FloatMatrix.floatArrayMatrix(dict.nwords(), args.dim) // word vectors: one row per word
        }
        output.fill(0f)

        this.source = file
        this.input = input
        this.output = output

        startThreads()

        val fastText = FastText(args, dict, Model(input, output, args, 0).apply {
            if (args.model == ModelName.sup) {
                this.setTargetCounts(dict.getCounts(EntryType.label))
            } else {
                this.setTargetCounts(dict.getCounts(EntryType.word))
            }
        })
        println("Train use time ${System.currentTimeMillis() - startTime} ms")
        return fastText
    }

    /**
     * Runs `args.thread` workers over [source] split into that many parts and prints progress
     * every 100 ms while `args.verbose > 1`. Waits for all workers, then closes [source].
     *
     * @throws RuntimeException wrapping the first failure of any worker thread.
     */
    @Throws(Exception::class)
    private fun startThreads() {
        startTime = System.currentTimeMillis()
        tokenCount = AtomicLong(0)
        loss = AtomicDouble(-1.0)
        val trainSource = checkNotNull(source) { "training source is not set" }
        try {
            val sourceParts = trainSource.split(args.thread)
            val failure = java.util.concurrent.atomic.AtomicReference<Throwable?>(null)
            val threads = (0 until args.thread).map { i ->
                val worker = Thread(TrainThread(i, sourceParts[i]))
                // Remember the first worker failure so it can be rethrown on the calling thread.
                worker.setUncaughtExceptionHandler { _, e -> failure.compareAndSet(null, e) }
                worker
            }
            threads.forEach { it.start() }

            val ntokens = dict.ntokens()
            // Same stop condition as TrainThread, but also stop waiting once no worker is alive:
            // if every worker died, the token budget would never be reached.
            while (tokenCount.toLong() < args.epoch * ntokens && threads.any { it.isAlive }) {
                Thread.sleep(100)
                if (loss.toFloat() >= 0 && args.verbose > 1) {
                    val progress = tokenCount.toFloat() / (args.epoch * ntokens)
                    print("\r")
                    printInfo(progress, loss)
                }
            }
            threads.forEach { it.join() }

            failure.get()?.let { throw RuntimeException("Training thread failed", it) }

            if (args.verbose > 0) {
                print("\r")
                printInfo(1.0f, loss)
                println()
            }
        } finally {
            trainSource.close()
        }
    }

    /**
     * Worker that trains on its own partition [parts] of the source, looping over it, until the
     * shared token budget `args.epoch * dict.ntokens()` has been consumed by all workers together.
     */
    internal inner class TrainThread(
        private val threadId: Int,
        private val parts: TrainExampleSource
    ) : Runnable {

        override fun run() {
            try {
                LoopReader(parts).use { loopReader -> trainLoop(loopReader) }
            } catch (e: Exception) {
                throw RuntimeException(e)
            }
        }

        /** Runs SGD steps until the shared token budget is used up, then publishes the final loss. */
        private fun trainLoop(loopReader: LoopReader) {
            val model = TrainModel(input!!, output!!, args, threadId)
            val rng = model.rng
            // setTargetCounts is quite expensive
            if (args.model == ModelName.sup) {
                model.setTargetCounts(dict.getCounts(EntryType.label))
            } else {
                model.setTargetCounts(dict.getCounts(EntryType.word))
            }
            val ntokens = dict.ntokens() // total (non-deduplicated) number of tokens in the input
            val budget = args.epoch * ntokens
            val line = IntArrayList()
            val labels = IntArrayList()

            // One SGD step on one input line; returns the number of tokens it consumed.
            val step: ((List<String>, Float) -> Long)? = when (args.model) {
                ModelName.sup -> { tokens: List<String>, lr: Float ->
                    val consumed = dict.getLine(tokens, line, labels).toLong()
                    supervised(model, lr, line, labels)
                    consumed
                }
                ModelName.cbow -> { tokens: List<String>, lr: Float ->
                    val consumed = dict.getLine(tokens, line, rng).toLong()
                    cbow(model, lr, line)
                    consumed
                }
                ModelName.sg -> { tokens: List<String>, lr: Float ->
                    val consumed = dict.getLine(tokens, line, rng).toLong()
                    skipgram(model, lr, line)
                    consumed
                }
                else -> null
            }

            if (step != null) {
                var localTokenCount = 0L
                while (tokenCount.toLong() < budget) {
                    val progress = tokenCount.toFloat() / budget // overall progress
                    val lr = args.lr.toFloat() * (1.0f - progress) // learning rate decays linearly
                    localTokenCount += step(loopReader.readLineTokens(), lr)
                    // Publish progress in batches to limit contention on the shared counter.
                    if (localTokenCount > args.lrUpdateRate) {
                        tokenCount.addAndGet(localTokenCount)
                        localTokenCount = 0
                        if (threadId == 0) {
                            loss.set(model.getLoss().toDouble())
                        }
                    }
                }
            }
            if (threadId == 0) {
                loss.set(model.getLoss().toDouble())
            }
        }

        /** Supervised step: predicts one label of the line (a random one if there are several). */
        internal fun supervised(
            model: TrainModel,
            lr: Float,
            line: IntArrayList,
            labels: IntArrayList
        ) {
            if (labels.size() == 0 || line.size() == 0) {
                return
            }
            val i = if (labels.size() == 1) 0 else model.rng.nextInt(labels.size())
            model.update(line, labels.get(i), lr)
        }

        /** CBOW step: for each word, predicts it from the subwords of a random-width context window. */
        private fun cbow(model: TrainModel, lr: Float, line: IntArrayList) {
            val bow = IntArrayList()
            for (w in 0 until line.size()) {
                // Window radius drawn uniformly from [1, args.ws].
                val boundary = model.rng.nextInt(args.ws) + 1
                bow.clear()
                for (c in -boundary..boundary) {
                    if (c != 0 && w + c >= 0 && w + c < line.size()) {
                        val ngrams = dict.getSubwords(line.get(w + c))
                        bow.addAll(ngrams)
                    }
                }
                model.update(bow, line.get(w), lr)
            }
        }

        /** Skip-gram step: each word's subwords predict every word in a random-width window. */
        private fun skipgram(model: TrainModel, lr: Float, line: IntArrayList) {
            for (w in 0 until line.size()) {
                // Window radius drawn uniformly from [1, args.ws].
                val boundary = model.rng.nextInt(args.ws) + 1
                val ngrams = dict.getSubwords(line.get(w))
                for (c in -boundary..boundary) {
                    if (c != 0 && w + c >= 0 && w + c < line.size()) {
                        model.update(ngrams, line.get(w + c), lr)
                    }
                }
            }
        }
    }

    /**
     * Prints one progress line: percent done, words/sec/thread, current learning rate, loss and ETA.
     */
    private fun printInfo(progress: Float, loss: AtomicDouble) {
        // Fractional wall-clock seconds; integer division would report 0 s (infinite speed)
        // during the first second.
        val t = (System.currentTimeMillis() - startTime) / 1000.0
        val lr = args.lr * (1.0 - progress)
        var wst = 0.0
        var eta = (720 * 3600).toLong() // Default to one month
        if (progress > 0 && t > 0) {
            // t is wall-clock time spent by all threads together, so the remaining time is
            // t * (1 - p) / p, while the throughput is divided per thread to match its label.
            eta = (t / progress * (1 - progress)).toLong()
            wst = tokenCount.toDouble() / t / args.thread
        }
        val etah = eta / 3600
        val etam = eta % 3600 / 60
        val etas = eta % 3600 % 60
        val percent = progress * 100
        val sb = StringBuilder()
        sb.append("Progress: " +
                String.format("%2.2f", percent) + "% words/sec/thread: " + String.format("%8.0f", wst))
        sb.append(String.format(" lr: %2.5f", lr))
        sb.append(String.format(" loss: %2.5f", loss.toFloat()))
        sb.append(" ETA: " + etah + "h " + etam + "m " + etas + "s")
        print(sb)
    }

    /**
     * Loads pretrained word vectors from a text file. The header line is `<rows> <dim>`. Each of
     * the following `rows` lines is `<word> <v1> ... <vdim>`, space-separated.
     *
     * Every word is added to [dict], which is then re-thresholded. The returned matrix has
     * `dict.nwords() + args.bucket` rows, randomly initialised. The rows of dictionary words found
     * in the file are overwritten with their pretrained vectors.
     *
     * @throws Exception if the header is malformed or its dimension differs from `args.dim`.
     * @throws RuntimeException if a vector line cannot be parsed or the file ends early.
     */
    @Throws(Exception::class)
    private fun loadVectors(filename: File): MutableFloatMatrix {
        val header = filename.firstLine()
            ?.let { Splitter.on(CharMatcher.whitespace()).splitToList(it) }
            .orEmpty()
        val n: Int = header.getOrNull(0)?.let { Ints.tryParse(it) } ?: 0
        val dim: Int = header.getOrNull(1)?.let { Ints.tryParse(it) } ?: 0
        if (n == 0 || dim == 0) {
            throw Exception("Error format for " + filename.name + ",First line must be rows and dim arg")
        }
        if (dim != args.dim) {
            throw Exception("Dimension of pretrained vectors " + dim + " does not match dimension (" + args.dim + ")")
        }

        val mat = MutableByteBufferMatrix(n, dim)
        val sp = Splitter.on(" ").omitEmptyStrings()
        val words = ArrayList<String>(n)
        Files.asCharSource(filename, Charsets.UTF_8).openBufferedStream().use { reader ->
            reader.readLine() // skip the header line
            for (i in 0 until n) {
                val line: String = reader.readLine()
                    ?: throw RuntimeException("Unexpected end of file ${filename.name}: expected $n vectors, found $i")
                var parts: MutableList<String> = sp.splitToList(line)
                if (parts.size != dim + 1) {
                    if (parts.size == dim) {
                        // The splitter dropped the word itself (presumably it consists of spaces):
                        // recover it as the text before the first vector component, minus one separator.
                        parts = mutableListOf(line.substring(0, line.indexOf(parts[0]) - 1))
                        parts.addAll(sp.splitToList(line))
                    } else {
                        throw RuntimeException("line $line parse error")
                    }
                }
                val word = parts[0]
                dict.add(word)
                words.add(word)
                val row = mat[i]
                for (j in 0 until dim) {
                    row[j] = parts[j + 1].toFloat()
                }
            }
        }

        dict.threshold(1, 0)
        val input = FloatMatrix.floatArrayMatrix(dict.nwords() + args.bucket, args.dim)
        input.uniform(1.0f / args.dim)
        for (i in 0 until n) {
            val idx = dict.getId(words[i])
            // Only word rows [0, nwords) are overwritten. Any other id (unknown word, or presumably
            // a label, whose ids follow the words) keeps its random initialisation.
            if (idx < 0 || idx >= dict.nwords()) {
                continue
            }
            input[idx](mat[i])
        }
        return input
    }
}
/**
 * Per-thread training model over shared weight matrices: [inputMatrix] holds the input vectors
 * and [outputMatrix] the output vectors (one row per target id).
 *
 * Several instances may wrap the same matrices (one per training thread). Each instance has its
 * own hidden/gradient/output buffers and its own loss accumulator.
 */
class TrainModel(
    private val inputMatrix: MutableFloatMatrix, // input
    private val outputMatrix: MutableFloatMatrix, // output
    args_: Args, seed: Int
) : BaseModel(args_, seed, outputMatrix.rows()) {

    private val hidden = Vector.floatArrayVector(args_.dim)
    private val output = Vector.floatArrayVector(outputMatrix.rows())
    private val grad = Vector.floatArrayVector(args_.dim)
    private val hsz: Int = args_.dim // dim

    /** Sum of the per-example losses seen so far. */
    private var loss = 0f

    /** Example counter; starts at 1 so getLoss() never divides by zero. */
    private var nexamples = 1L

    /** Accumulated loss divided by (number of examples + 1). */
    fun getLoss(): Float {
        return loss / nexamples
    }

    /**
     * One SGD step. The [input] rows are averaged into the hidden vector, and the configured loss
     * for [target] is computed (which updates the output rows). The accumulated gradient is then
     * added to every input row, scaled by 1/|input| in supervised mode.
     *
     * @param input row ids of [inputMatrix]; an empty list makes this a no-op.
     * @param target output row id, must be in `[0, outputMatrixSize)`.
     * @param lr learning rate for this step.
     */
    fun update(input: IntArrayList, target: Int, lr: Float) {
        checkArgument(target >= 0)
        checkArgument(target < outputMatrixSize)
        if (input.size() == 0) {
            return
        }
        computeHidden(input, hidden)
        loss += when (args_.loss) {
            LossName.ns -> negativeSampling(target, lr)
            LossName.hs -> hierarchicalSoftmax(target, lr)
            LossName.softmax -> softmax(target, lr)
        }
        nexamples += 1
        if (args_.model == ModelName.sup) {
            grad *= (1.0f / input.size())
        }
        // Index loop over the raw buffer avoids iterator allocation on this hot path.
        val buffer = input.buffer
        var i = 0
        val size = input.size()
        while (i < size) {
            val it = buffer[i]
            inputMatrix[it] += grad
            i++
        }
    }

    /** Sets [hidden] to the average of the [input] rows of [inputMatrix]. */
    private fun computeHidden(input: IntArrayList, hidden: MutableVector) {
        checkArgument(hidden.length() == hsz)
        hidden.zero()
        val buffer = input.buffer
        var i = 0
        val size = input.size()
        while (i < size) {
            val it = buffer[i]
            hidden += inputMatrix[it]
            i++
        }
        hidden *= (1.0f / input.size())
    }

    /** Negative-sampling loss: one positive step on [target] plus `neg` negative steps. */
    private fun negativeSampling(target: Int, lr: Float): Float {
        var loss = 0.0f
        grad.zero()
        for (n in 0..args_.neg) {
            loss += if (n == 0) {
                binaryLogistic(target, true, lr)
            } else {
                binaryLogistic(getNegative(target), false, lr)
            }
        }
        return loss
    }

    /**
     * Logistic step on output row [target]: accumulates its gradient into [grad], updates the
     * row with the hidden vector, and returns the negative log-likelihood of [label].
     */
    private fun binaryLogistic(target: Int, label: Boolean, lr: Float): Float {
        val score = sigmoid(outputMatrix[target] * hidden)
        val alpha = lr * ((if (label) 1.0f else 0.0f) - score)
        grad += alpha to outputMatrix[target]
        outputMatrix[target] += alpha to hidden
        return if (label) {
            -log(score)
        } else {
            -log(1.0f - score)
        }
    }

    /** Next entry of the negatives table (cycling) that differs from [target]. */
    private fun getNegative(target: Int): Int {
        var negative: Int
        do {
            negative = negatives[negpos]
            negpos = (negpos + 1) % negatives.size
        } while (target == negative)
        return negative
    }

    /** Hierarchical-softmax loss: one logistic step per node on [target]'s path to the root. */
    private fun hierarchicalSoftmax(target: Int, lr: Float): Float {
        var loss = 0.0f
        grad.zero()
        val binaryCode = codes[target]
        val pathToRoot = paths[target]
        for (i in 0 until pathToRoot.size) {
            loss += binaryLogistic(pathToRoot[i], binaryCode[i], lr)
        }
        return loss
    }

    /** Full softmax loss: updates every output row and returns -log p([target]). */
    private fun softmax(target: Int, lr: Float): Float {
        grad.zero()
        computeOutputSoftmax()
        for (i in 0 until outputMatrixSize) {
            val label = if (i == target) 1.0f else 0.0f
            val alpha = lr * (label - output[i])
            grad += alpha to outputMatrix[i]
            outputMatrix[i] += alpha to hidden
        }
        return -log(output[target])
    }

    /** Writes softmax(outputMatrix * hidden) into [output]; the max is subtracted for numeric stability. */
    private fun computeOutputSoftmax(hidden: Vector = this.hidden, output: MutableVector = this.output) {
        matrixMulVector(outputMatrix, hidden, output)
        var max = output[0]
        var z = 0.0f
        for (i in 1 until outputMatrixSize) {
            max = Math.max(output[i], max)
        }
        for (i in 0 until outputMatrixSize) {
            output[i] = Math.exp((output[i] - max).toDouble()).toFloat()
            z += output[i]
        }
        for (i in 0 until outputMatrixSize) {
            output[i] = output[i] / z
        }
    }
}
/**
 * Plain mutable record for one tree node; every field starts at zero/false.
 *
 * NOTE(review): `parent`/`left`/`right` are presumably indices into a node array and
 * `binary` the branch bit of the node's code — confirm against the tree builder.
 */
class Node {
    @JvmField var parent = 0
    @JvmField var left = 0
    @JvmField var right = 0
    @JvmField var count = 0L
    @JvmField var binary = false
}
/**
 * Endless reader over a [TrainExampleSource]. When the current iterator is exhausted it is
 * closed and a fresh one is opened from [parts], so training can make several passes (epochs).
 */
class LoopReader @Throws(IOException::class)
constructor(private val parts: TrainExampleSource) : AutoCloseable {

    var reader: ExampleIterator = parts.iteratorAll()

    internal var splitter = Splitter.on(CharMatcher.whitespace()).omitEmptyStrings().trimResults()

    /**
     * Returns the tokens of the next non-empty example with [EOS] appended, wrapping around to
     * the start of the source at end of input.
     *
     * @throws IllegalStateException if a complete pass over the source finds no non-empty
     *   example (previously this looped forever).
     */
    @Throws(IOException::class)
    fun readLineTokens(): List<String> {
        var rewinds = 0
        while (true) {
            val line = loopLine()
            if (line == null) {
                // The first rewind in a call is the normal end of a pass. A second one means an
                // entire pass from the beginning produced only empty examples.
                rewinds++
                check(rewinds < 2) { "Training source contains no non-empty lines" }
            } else if (line.isNotEmpty()) {
                return line2Tokens(line)
            }
        }
    }

    /** Returns the next example, or `null` after reopening the source at end of input. */
    @Throws(IOException::class)
    private fun loopLine(): List<String>? {
        if (reader.hasNext()) {
            return reader.next()
        }
        reader.close()
        reader = parts.iteratorAll()
        return null
    }

    /** Copies [tokens] and appends the end-of-sentence marker. */
    private fun line2Tokens(tokens: List<String>): List<String> {
        val list = ArrayList<String>(tokens.size + 1)
        list.addAll(tokens)
        list.add(EOS)
        return list
    }

    /** Closes the current underlying iterator. */
    @Throws(Exception::class)
    override fun close() {
        reader.close()
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy