
com.expleague.ml.models.hmm.HiddenMarkovModel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jmll Show documentation
Various ML methods implemented by myself and my students
package com.expleague.ml.models.hmm;
import com.expleague.commons.math.vectors.*;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.seq.Seq;
import com.expleague.commons.seq.regexp.Alphabet;
import org.jetbrains.annotations.NotNull;
import java.util.function.Function;
import static java.lang.Math.exp;
public class HiddenMarkovModel implements Function,Vec> {
private final Alphabet alpha;
private final int statesCount;
private final Vec betta;
private final Vec Pi;
private final Mx A;
private final Mx inverseA;
private final Mx B;
public HiddenMarkovModel(Alphabet alpha, int states, Vec betta) {
this.betta = betta;
this.alpha = alpha;
statesCount = states;
this.Pi = betta.sub(0, states);
this.A = new VecBasedMx(states, betta.sub(states, states * states));
this.B = new VecBasedMx(states, betta.sub((states + 1) * states, states * alpha.size()));
this.inverseA = MxTools.transpose(A);
}
@Override
public Vec apply(Seq argument) {
return new SingleValueVec(value(argument));
}
public double value(Seq x) {
final Mx distribs = forward(x);
VecTools.scale(distribs, backward(x));
double ll = 0;
for (int t = 0; t < x.length(); t++) {
final Vec states = distribs.row(t);
VecTools.normalizeL1(states);
ll += Math.log(VecTools.multiply(B.row(alpha.index(x, t)), states));
}
return exp(ll/x.length());
}
@NotNull
public Mx forward(Seq x) {
final Mx forward = new VecBasedMx(statesCount, new ArrayVec(statesCount * x.length()));
{ // forward
Vec prev = forward.row(0);
VecTools.assign(prev, Pi);
VecTools.scale(prev, B.row(alpha.index(x, 0)));
VecTools.normalizeL1(prev);
for (int i = 1; i < x.length(); i++) {
final Vec next = forward.row(i);
MxTools.multiplyTo(A, prev, next);
VecTools.scale(next, B.row(alpha.index(x, i)));
VecTools.normalizeL1(next);
prev = next;
}
}
return forward;
}
@NotNull
public Mx backward(Seq x) {
final Mx backward = new VecBasedMx(statesCount, new ArrayVec(statesCount * x.length()));
{ // backward
Vec prev = new ArrayVec(statesCount);
VecTools.fill(prev, 1. / statesCount);
for (int i = x.length() - 1; i >= 0; i--) {
final Vec next = backward.row(i);
final int index = alpha.index(x, i);
MxTools.multiplyTo(inverseA, prev, next);
VecTools.scale(next, B.row(index));
VecTools.normalizeL1(next);
prev = next;
}
}
return backward;
}
public Vec betta() {
return betta;
}
public int states() {
return statesCount;
}
public Alphabet alpha() {
return alpha;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy