/*
 * com.enterprisemath.math.statistics.HMM - part of the em-math artifact (available via Maven / Gradle / Ivy).
 * em-math: Advanced mathematical algorithms.
 * This listing is taken from the newest published version of the artifact.
 */
package com.enterprisemath.math.statistics;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.apache.commons.lang3.builder.ToStringBuilder;
import com.enterprisemath.math.probability.ProbabilityDistribution;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.DomainUtils;
import com.enterprisemath.utils.ValidationUtils;
/**
* Basic definition of the hidden markov model.
*
* @author radek.hecl
*
* @param type of the hidden state
* @param type of the observation
*/
public class HMM {
/**
* Builder object.
*
* @param type of the hidden state
* @param type of the observation
*/
public static class Builder {
/**
* Set of hidden states.
*/
private Set hiddenStates = new HashSet();
/**
* Initial probability. Must be 1 over all states.
*/
private ProbabilityDistribution initialProbability;
/**
* Transition probabilities between states.
*/
private Map> transitionProbabilities = new HashMap>();
/**
* Emission probabilities to the observations.
*/
private Map> emissionProbabilities = new HashMap>();
/**
* Sets the hidden states.
*
* @param hiddenStates hidden states
* @return this instance
*/
public Builder setHiddenStates(Set hiddenStates) {
this.hiddenStates = DomainUtils.softCopySet(hiddenStates);
return this;
}
/**
* Adds hidden state.
*
* @param hiddenState hidden state
* @return this instance
*/
public Builder addHiddenState(State hiddenState) {
this.hiddenStates.add(hiddenState);
return this;
}
/**
* Sets initial probability distribution.
*
* @param initialProbability initial probability distribution
* @return this instance
*/
public Builder setInitialProbability(ProbabilityDistribution initialProbability) {
this.initialProbability = initialProbability;
return this;
}
/**
* Sets transition probability distributions.
*
* @param transitionProbabilities transition probability distributions
* @return this instance
*/
public Builder setTransitionProbabilities(Map> transitionProbabilities) {
this.transitionProbabilities = DomainUtils.softCopyMap(transitionProbabilities);
return this;
}
/**
* Add transition probability distribution.
*
* @param state state from which transition starts
* @param probability probability distribution to other transitions
* @return this instance
*/
public Builder addTransitionProbability(State state, ProbabilityDistribution probability) {
this.transitionProbabilities.put(state, probability);
return this;
}
/**
* Sets emission probability distribution.
*
* @param emissionProbabilities emission probability distributions
* @return this instance
*/
public Builder setEmissionProbabilities(Map> emissionProbabilities) {
this.emissionProbabilities = DomainUtils.softCopyMap(emissionProbabilities);
return this;
}
/**
* Adds emission probability distribution.
*
* @param state state of the emission
* @param probability probability distribution of emission
* @return this instance
*/
public Builder addEmissionProbability(State state, ProbabilityDistribution probability) {
this.emissionProbabilities.put(state, probability);
return this;
}
/**
* Builds the result object.
*
* @return created object
*/
public HMM build() {
return new HMM(this);
}
}
/**
* Set of hidden states.
*/
private Set hiddenStates;
/**
* Initial probability. Must be 1 over all states.
*/
private ProbabilityDistribution initialProbability;
/**
* Transition probabilities between states.
*/
private Map> transitionProbabilities;
/**
* Emission probabilities to the observations.
*/
private Map> emissionProbabilities;
/**
* Creates new instance.
*
* @param builder builder object
*/
public HMM(Builder builder) {
hiddenStates = Collections.unmodifiableSet(DomainUtils.softCopySet(builder.hiddenStates));
initialProbability = builder.initialProbability;
transitionProbabilities = Collections.unmodifiableMap(DomainUtils.softCopyMap(builder.transitionProbabilities));
emissionProbabilities = Collections.unmodifiableMap(DomainUtils.softCopyMap(builder.emissionProbabilities));
guardInvariants();
}
/**
* Guards this object to be consistent. Throws exception if this is not the case.
*/
private void guardInvariants() {
//
// first level validation
ValidationUtils.guardNotNullCollection(hiddenStates, "hiddenStates cannot have null element");
ValidationUtils.guardPositiveInt(hiddenStates.size(), "hiddenStates cannot be empty");
ValidationUtils.guardNotNull(initialProbability, "initialProbability canno be null");
ValidationUtils.guardNotNullMap(transitionProbabilities, "transitionProbabilities cannot have null element");
ValidationUtils.guardNotNullMap(emissionProbabilities, "emissionProbabilities cannot have null element");
//
// second level validation
ValidationUtils.guardEquals(hiddenStates, transitionProbabilities.keySet(),
"transitionProbabilities must be defined for all hiddenStates");
ValidationUtils.guardEquals(hiddenStates, emissionProbabilities.keySet(),
"emissionProbabilities must be defined for all hiddenStates");
//
// third level validation
double h = 0;
for (State state : hiddenStates) {
h = h + initialProbability.getValue(state);
}
ValidationUtils.guardGreaterOrEqualDouble(0.0001, Math.abs(1 - h),
"initialProbability must be 1 in sum through all hiddenStates");
for (State state : hiddenStates) {
h = 0;
for (State until : hiddenStates) {
h = h + transitionProbabilities.get(state).getValue(until);
}
ValidationUtils.guardGreaterOrEqualDouble(0.0001, Math.abs(1 - h),
"transitionProbability must be 1 in sum for " + state);
}
}
/**
* Returns the most probable sequence of the hidden states for the given observations.
* If there only one most probable sequence exists (which is expected in the majority of real world cases),
* then there is not problem.
* If there are multiple most probable sequences, then one of them is returned.
* In that case there is not guaranteed which one and even on 2 calls with same observations
* the returned sequence may be different.
*
* @param observations observations
* @return this instance
*/
public List viterbi(ObservationProvider observations) {
ObservationIterator iterator = observations.getIterator();
//
// handling edge case
if (!iterator.isNextAvailable()) {
return Collections.emptyList();
}
//
// initialization
List states = new ArrayList(hiddenStates);
List delta = new ArrayList();
List psi = new ArrayList();
//
// first observation
{
Observation obs = iterator.getNext();
double[] deltaelm = new double[states.size()];
int[] psielm = new int[states.size()];
for (int i = 0; i < states.size(); ++i) {
deltaelm[i] = initialProbability.getLnValue(states.get(i)) + emissionProbabilities.get(states.get(i)).getLnValue(obs);
psielm[i] = i;
}
delta.add(deltaelm);
psi.add(psielm);
}
//
// iterations
while (iterator.isNextAvailable()) {
Observation obs = iterator.getNext();
double[] deltaelm = new double[states.size()];
int[] psielm = new int[states.size()];
for (int i = 0; i < states.size(); ++i) {
double max = Double.NEGATIVE_INFINITY;
int argmax = 0;
for (int j = 0; j < states.size(); ++j) {
double val = delta.get(delta.size() - 1)[j] + transitionProbabilities.get(states.get(j)).getLnValue(states.get(i));
if (val > max) {
max = val;
argmax = j;
}
}
deltaelm[i] = max + emissionProbabilities.get(states.get(i)).getLnValue(obs);
psielm[i] = argmax;
}
delta.add(deltaelm);
psi.add(psielm);
}
//
// taking the last (current) observation
List res = new ArrayList();
{
double max = Double.NEGATIVE_INFINITY;
int argmax = -1;
for (int j = 0; j < states.size(); ++j) {
if (delta.get(delta.size() - 1)[j] > max) {
max = delta.get(delta.size() - 1)[j];
argmax = j;
}
}
res.add(states.get(argmax));
}
//
// back tracking to get the whole sequence
for (int i = delta.size() - 1; i >= 1; --i) {
double max = Double.NEGATIVE_INFINITY;
int argmax = -1;
for (int j = 0; j < states.size(); ++j) {
if (delta.get(i)[j] > max) {
max = delta.get(i)[j];
argmax = j;
}
}
res.add(states.get(psi.get(i)[argmax]));
}
Collections.reverse(res);
return res;
}
@Override
public int hashCode() {
return HashCodeBuilder.reflectionHashCode(this);
}
@Override
public boolean equals(Object obj) {
return EqualsBuilder.reflectionEquals(this, obj);
}
@Override
public String toString() {
return ToStringBuilder.reflectionToString(this);
}
}