package com.expleague.ml.methods.hmm;

import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.ml.methods.Optimization;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.seq.Seq;
import com.expleague.commons.seq.regexp.Alphabet;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.loss.LLLogit;
import com.expleague.ml.models.hmm.HiddenMarkovModel;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;
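
/**
 * Baum-Welch (EM) training of a {@link HiddenMarkovModel} over sequences drawn from the given
 * {@link Alphabet}. Each iteration runs the forward-backward algorithm on every training sequence
 * (in parallel on a background executor), accumulates expected transition and emission counts,
 * and re-estimates the model parameters from those counts. Sequences whose {@link LLLogit} label
 * is negative are skipped, so the target can be used to mark which sequences participate.
 *
 * <p>Usage sketch (assumes an {@code Alphabet} implementation and a {@code DataSet} of observation
 * sequences are already available; the variable names below are placeholders):
 * <pre>{@code
 *   BaumWelch<Character> bw = new BaumWelch<>(alphabet, 5, 100); // 5 hidden states, 100 EM iterations
 *   HiddenMarkovModel<Character> hmm = bw.fit(sequences, target);
 * }</pre>
 */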
public class BaumWelch<T> implements Optimization<LLLogit, DataSet<Seq<T>>, Seq<T>> {
  private static final Logger log = Logger.getLogger(BaumWelch.class.getName());

  private final Alphabet<T> alphabet;
  private final int states;
  private int iterations;
  private FastRandom rng = new FastRandom(0);

  public BaumWelch(Alphabet<T> alphabet, int states, int iterations) {
    this.alphabet = alphabet;
    this.states = states;
    this.iterations = iterations;
  }
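
  /**
   * Runs {@code iterations} rounds of Baum-Welch re-estimation and returns the resulting model.
   *
   * <p>Model parameters are packed into a single vector ("betta") of length
   * {@code (states + 1) * states + states * alphabet.size()}: the first {@code states} entries
   * hold the initial state distribution, the next {@code states * states} entries hold the
   * transition matrix A (row u is the distribution over successors of state u), and the remaining
   * {@code states * alphabet.size()} entries hold the emission matrix B indexed as
   * {@code B.get(symbol, state)}.
   */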
  @Override
  public HiddenMarkovModel<T> fit(DataSet<Seq<T>> learn, LLLogit llLogit) {
    final Vec[] betta = {
        new ArrayVec((states + 1) * states + states * alphabet.size()),
        new ArrayVec((states + 1) * states + states * alphabet.size())
    };
    // Initialize all parameters to 1; the emission block is randomized to break symmetry between states.
    VecTools.fill(betta[0], 1);
    VecTools.fillUniformPlus(betta[0].sub(states * (states + 1), states * alphabet.size()), rng, 1);

    final ThreadPoolExecutor bwCalcer = ThreadTools.createBGExecutor("BWCalcer", learn.length());
    final ThreadLocal<Mx> accBCache = ThreadLocal.withInitial(() -> new VecBasedMx(alphabet.size(), states));
    final ThreadLocal<Mx> ksiCache = ThreadLocal.withInitial(() -> new VecBasedMx(states, states));

    for (int t = 0; t < iterations; t++) {
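      // Double-buffer the packed parameters: read this iteration's estimate from 'current'
      // and accumulate the re-estimated statistics into 'next'.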
      final Vec current = betta[t % 2];
      final Vec next = betta[(t + 1) % 2];
      normalizeBetta(current);
      final Mx A = new VecBasedMx(states, current.sub(states, states * states));
      final Mx B = new VecBasedMx(states, current.sub(states * (states + 1), states * alphabet.size()));
      VecTools.fill(next, 0);
      final HiddenMarkovModel<T> hmm = new HiddenMarkovModel<>(alphabet, states, current);
      double[] ll = {0};
      double totalLength = 0;
      final CountDownLatch latch = new CountDownLatch(learn.length());
      for (int i = 0; i < learn.length(); i++) {
        final Seq<T> seq = learn.at(i);
        // Skip empty sequences and sequences excluded by the target (negative label).
        if (seq.length() == 0 || llLogit.label(i) < 0) {
          latch.countDown();
          continue;
        }
        totalLength += seq.length();
        bwCalcer.execute(() -> {
          final Mx ksi = ksiCache.get();
          VecTools.fill(ksi, 0);
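          // E-step for this sequence: forward (alpha) and backward (beta) passes under the
          // current model; their entrywise product is proportional to the state posterior.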
          final Mx forward = hmm.forward(seq);
          final Mx backward = hmm.backward(seq);
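          // Accumulate expected transition counts: ksi(u, v) at position k is proportional to
          // alpha_k(u) * A(u, v) * beta_{k+1}(v) * B(o_{k+1}, v), normalized over all (u, v).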
          for (int k = 0; k < seq.length() - 1; k++) {
            final int nextIdx = alphabet.index(seq, k + 1);
            double sum = 0;
            for (int u = 0; u < states; u++) {
              for (int v = 0; v < states; v++) {
                sum += forward.get(k, u) * A.get(u, v) * backward.get(k + 1, v) * B.get(nextIdx, v);
              }
            }
            // The joint mass can underflow for long sequences; log it so degenerate positions are visible.
            if (sum < MathTools.EPSILON)
              log.warning("Nearly zero forward-backward mass at position " + k);
            for (int u = 0; u < states; u++) {
              for (int v = 0; v < states; v++) {
                final double increment = forward.get(k, u) * A.get(u, v) * backward.get(k + 1, v) * B.get(nextIdx, v);
                ksi.adjust(u, v, increment / sum);
              }
            }
          }
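          // State occupancy posteriors: gamma_k = alpha_k * beta_k (elementwise), L1-normalized per
          // position. Also accumulate a per-position log-score (expected emission probability of the
          // observed symbol under the state posterior), used for the perplexity reported per iteration.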
          //noinspection UnnecessaryLocalVariable
          final Mx distrib = forward;
          VecTools.scale(distrib, backward);
          final Vec sum = new ArrayVec(states);
          double llLocal = 0;
          for (int k = 0; k < seq.length(); k++) {
            final Vec gamma = distrib.row(k);
            VecTools.normalizeL1(gamma);
            llLocal += Math.log(VecTools.multiply(gamma, B.row(alphabet.index(seq, k))));
            VecTools.append(sum, gamma);
          }
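          // Normalize each row of ksi into a transition distribution and invert the per-state
          // occupancy totals (with a small epsilon) for weighting the emission counts below.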
          for (int u = 0; u < states; u++) {
            VecTools.normalizeL1(ksi.row(u));
            sum.set(u, 1. / (sum.get(u) + 1e-6));
          }
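          // Expected emission counts: for every position, add the state posterior, scaled by the
          // inverse per-state occupancy, to the row of accB that corresponds to the observed symbol.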
          final Mx accB = accBCache.get();
          VecTools.fill(accB, 0);
          for (int k = 0; k < seq.length(); k++) {
            final int symIdx = alphabet.index(seq, k);
            final Vec bRow = accB.row(symIdx);
            final Vec gamma = distrib.row(k);
            VecTools.scale(gamma, sum);
            VecTools.append(bRow, gamma);
          }
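          // Merge this sequence's statistics into the shared 'next' buffer: initial-state
          // posteriors, transition counts and emission counts, each averaged over the data set.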
          synchronized (this) {
            ll[0] += llLocal;
            VecTools.incscale(next.sub(0, states), distrib.row(0), 1. / learn.length());
            VecTools.incscale(next.sub(states, states * states), ksi, 1. / learn.length());
            VecTools.incscale(next.sub((states + 1) * states, states * alphabet.size()), accB, 1. / learn.length());
          }
          latch.countDown();
        });
      }
      try {
        latch.await();
      }
      catch (InterruptedException e) {
        throw new RuntimeException(e);
      }
      log.fine("It: " + t + " unit perplexity: " + Math.exp(ll[0] / totalLength));
      System.out.println("It: " + t + " unit perplexity: " + Math.exp(ll[0] / totalLength));
    }
    return new HiddenMarkovModel<>(alphabet, states, betta[iterations % 2]);
  }
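
  /**
   * Projects the packed parameter vector back onto the probability simplex: L1-normalizes the
   * initial distribution and each row of the transition matrix, then smooths each state's
   * emission column with a small additive constant (1e-4) before L1-normalizing it, so no
   * emission probability collapses to exactly zero.
   */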
  private void normalizeBetta(Vec betta) {
    for (int i = 0; i < (states + 1) * states; i += states) {
      final Vec vec = betta.sub(i, states);
      VecTools.normalizeL1(vec);
    }
    final Mx B = new VecBasedMx(states, betta.sub(states * (states + 1), states * alphabet.size()));
    final ArrayVec unit = new ArrayVec(alphabet.size());
    VecTools.fill(unit, 1);
    for (int j = 0; j < states; j++) {
      final Vec vec = B.col(j);
      VecTools.incscale(vec, unit, 1e-4);
      VecTools.normalizeL1(vec);
    }
  }
}