// gov.sandia.cognition.learning.algorithm.hmm.HiddenMarkovModel
// Distributed in cognitive-foundry: a single jar with all the Cognitive Foundry components.
/*
* File: HiddenMarkovModel.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Feb 2, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.hmm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.Distribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.distribution.DirichletDistribution;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Random;
/**
 * A discrete-state Hidden Markov Model (HMM) with either continuous
 * or discrete observations.
 * @author Kevin R. Dixon
 * @since 3.0
 * @param <ObservationType> Type of Observations handled by the HMM.
 */
@PublicationReference(
    author="Lawrence R. Rabiner",
    title="A tutorial on hidden Markov models and selected applications in speech recognition",
    type=PublicationType.Journal,
    year=1989,
    publication="Proceedings of the IEEE",
    pages={257,286},
    url="http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf",
    notes="Rabiner's transition matrix is transposed from mine."
)
public class HiddenMarkovModel<ObservationType>
    extends MarkovChain
    implements Distribution<ObservationType>
{
/**
* The PDFs that emit symbols from each state.
*/
protected Collection extends ComputableDistribution> emissionFunctions;
/**
* Default constructor.
*/
public HiddenMarkovModel()
{
super();
}
/**
 * Creates a new instance of HiddenMarkovModel with the given number of
 * states. The emission functions are left unassigned.
 * @param numStates
 * Number of states in the HMM.
 */
public HiddenMarkovModel( int numStates )
{
    super( numStates );
}
/**
 * Creates a new instance of HiddenMarkovModel.
 * @param initialProbability
 * Initial probability Vector over the states. Each entry must be
 * nonnegative and the Vector must sum to 1.
 * @param transitionProbability
 * Transition probability matrix. The entry (i,j) is the probability
 * of transition from state "j" to state "i". As a corollary, all
 * entries in the Matrix must be nonnegative and the
 * columns of the Matrix must sum to 1.
 * @param emissionFunctions
 * The PDFs that emit symbols from each state.
 * @throws IllegalArgumentException
 * If the number of emission functions differs from the number of states.
 */
public HiddenMarkovModel(
    Vector initialProbability,
    Matrix transitionProbability,
    Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions )
{
    super( initialProbability, transitionProbability );

    // Exactly one emission function is required per state.
    final int k = this.getNumStates();
    if( emissionFunctions.size() != k )
    {
        throw new IllegalArgumentException(
            "Number of PDFs must be equal to number of states!" );
    }

    this.setEmissionFunctions(emissionFunctions);
}
/**
* Creates a Hidden Markov Model with the same PMF/PDF for each state,
* but sampling the columns of the transition matrix and the initial
* probability distributions from a diffuse Dirichlet.
* @param
* Type of observations to generate.
* @param numStates
* Number of states to create
* @param learner
* Learner to create the distributions for each state
* @param data
* Data from which to make the distribution
* @param random
* Random number generator
* @return
* HMM with the specified states.
*/
public static HiddenMarkovModel createRandom(
int numStates,
BatchLearner>,? extends ComputableDistribution> learner,
Collection extends ObservationType> data,
Random random )
{
ArrayList> weightedData =
new ArrayList>( data.size() );
for( ObservationType observation : data )
{
weightedData.add( new DefaultWeightedValue( observation, 1.0 ) );
}
ComputableDistribution distribution =
learner.learn(weightedData);
return createRandom( numStates, distribution, random );
}
/**
 * Creates a Hidden Markov Model with the same PMF/PDF for each state,
 * but sampling the columns of the transition matrix and the initial
 * probability distributions from a diffuse Dirichlet.
 * @param <ObservationType>
 * Type of observations to generate.
 * @param numStates
 * Number of states to create
 * @param distribution
 * Distribution from which we will assign to each state.
 * @param random
 * Random number generator
 * @return
 * HMM with the specified states.
 */
public static <ObservationType> HiddenMarkovModel<ObservationType> createRandom(
    int numStates,
    ComputableDistribution<ObservationType> distribution,
    Random random )
{
    // nCopies repeats one reference, so every state initially shares the
    // same probability-function object.
    Collection<ProbabilityFunction<ObservationType>> distributions =
        Collections.nCopies(numStates, distribution.getProbabilityFunction());
    return createRandom(distributions, random);
}
/**
* Creates a Hidden Markov Model with the given probability function for
* each state, but sampling the columns of the transition matrix and the
* initial probability distributions from a diffuse Dirichlet.
* @param
* Type of observations to generate.
* @param distributions
* The distribution for each state. The size of the collection is the
* number of states to create.
* @param random
* Random number generator
* @return
* HMM with the specified states.
*/
public static HiddenMarkovModel createRandom(
Collection extends ProbabilityFunction> distributions,
Random random)
{
int numStates = distributions.size();
// We'll sample the multinomial probabilities from a uniform Dirichlet
DirichletDistribution dirichlet =
new DirichletDistribution( numStates );
Vector initialProbability = dirichlet.sample(random);
Matrix transitionMatrix =
MatrixFactory.getDefault().createMatrix(numStates, numStates);
ArrayList outbounds = dirichlet.sample(random, numStates);
for( int i = 0; i < numStates; i++ )
{
transitionMatrix.setColumn( i, outbounds.get(i) );
}
return new HiddenMarkovModel(
initialProbability,transitionMatrix, distributions);
}
/**
 * Clones this HMM. On top of the superclass clone, the emission functions
 * are copied into a new ArrayList via
 * {@code ObjectUtil.cloneSmartElementsAsArrayList}.
 * @return
 * Clone of this HMM.
 */
@Override
public HiddenMarkovModel<ObservationType> clone()
{
    // super.clone() is not declared with this parameterized type, hence
    // the unchecked cast.
    @SuppressWarnings("unchecked")
    HiddenMarkovModel<ObservationType> clone =
        (HiddenMarkovModel<ObservationType>) super.clone();
    clone.setEmissionFunctions( ObjectUtil.cloneSmartElementsAsArrayList(
        this.getEmissionFunctions() ) );
    return clone;
}
/**
* Computes the log-likelihood of the observation sequence, given the
* current HMM's parameterization. This is the answer to Rabiner's
* "Three Basic Problems for HMMs, Problem 1: Probability Evaluation".
* @param observations
* Observations to consider.
* @return
* Log-likelihood of the given observation sequence.
*/
public double computeObservationLogLikelihood(
Collection extends ObservationType> observations )
{
final int k = this.getNumStates();
Vector b = VectorFactory.getDefault().createVector(k);
Vector alpha = this.getInitialProbability().clone();
Matrix A = this.getTransitionProbability();
int index = 0;
double logLikelihood = 0.0;
for( ObservationType observation : observations )
{
if( index > 0 )
{
alpha = A.times( alpha );
}
this.computeObservationLikelihoods(observation, b);
alpha.dotTimesEquals(b);
final double weight = alpha.norm1();
alpha.scaleEquals(1.0/weight);
logLikelihood += Math.log(weight);
index++;
}
return logLikelihood;
}
/**
* Computes the log-likelihood of the observation sequences, given the
* current HMM's parameterization. This is the answer to Rabiner's
* "Three Basic Problems for HMMs, Problem 1: Probability Evaluation".
* @param sequences
* Observations sequences to consider
* @return
* Log-likelihood of the given observation sequence.
*/
protected double computeMultipleObservationLogLikelihood(
Collection extends Collection extends ObservationType>> sequences )
{
double logLikelihood = 0.0;
for( Collection extends ObservationType> observations : sequences )
{
logLikelihood += this.computeObservationLogLikelihood(observations);
}
return logLikelihood;
}
/**
* Computes the log-likelihood that the given observation sequence
* was generated by the given sequence of state indices.
* @param observations
* Observations to consider.
* @param states
* Indices of states hypothesized to have generated the observation
* sequence.
* @return
* Log-likelihood of the given observation sequence.
*/
public double computeObservationLogLikelihood(
Collection extends ObservationType> observations,
Collection states )
{
final int N = observations.size();
if( N != states.size() )
{
throw new IllegalArgumentException(
"Observations and states must be the same size" );
}
Iterator stateIterator = states.iterator();
double logLikelihood = 0.0;
ArrayList> fs =
new ArrayList>( this.getNumStates() );
for( ComputableDistribution f : this.getEmissionFunctions() )
{
fs.add( f.getProbabilityFunction() );
}
int lastState = -1;
for( ObservationType observation : observations )
{
final int state = stateIterator.next();
double blog = fs.get(state).logEvaluate(observation);
double ll;
if( lastState < 0 )
{
ll = Math.log(this.initialProbability.getElement(state));
}
else
{
ll = Math.log(this.transitionProbability.getElement(state, lastState) );
}
lastState = state;
logLikelihood += blog + ll;
}
return logLikelihood;
}
/**
 * Draws a single observation by sampling a sequence of length one and
 * returning its first element.
 * @param random
 * Random number generator.
 * @return
 * The sampled observation.
 */
@Override
public ObservationType sample(
    Random random)
{
    final ArrayList<ObservationType> singleDraw = this.sample(random, 1);
    return CollectionUtil.getFirst( singleDraw );
}
/**
 * Samples a sequence of observations into a new list by delegating to
 * {@code sampleInto}.
 * @param random
 * Random number generator.
 * @param numSamples
 * Number of observations to generate.
 * @return
 * The generated observations, in generation order.
 */
@Override
public ArrayList<ObservationType> sample(
    Random random,
    int numSamples )
{
    final ArrayList<ObservationType> samples =
        new ArrayList<ObservationType>( numSamples );
    this.sampleInto(random, numSamples, samples);
    return samples;
}
@Override
public void sampleInto(
final Random random,
final int sampleCount,
final Collection super ObservationType> output)
{
Vector p = this.getInitialProbability();
int state = -1;
for( int n = 0; n < sampleCount; n++ )
{
double value = random.nextDouble();
state = -1;
while( value > 0.0 )
{
state++;
value -= p.getElement(state);
}
ObservationType sample = CollectionUtil.getElement(
this.getEmissionFunctions(), state ).sample(random);
output.add( sample );
p = this.getTransitionProbability().getColumn(state);
}
}
/**
* Getter for emissionFunctions
* @return
* The PDFs that emit symbols from each state.
*/
public Collection extends ComputableDistribution> getEmissionFunctions()
{
return this.emissionFunctions;
}
/**
* Setter for emissionFunctions.
* @param emissionFunctions
* The PDFs that emit symbols from each state.
*/
public void setEmissionFunctions(
Collection extends ComputableDistribution> emissionFunctions)
{
this.emissionFunctions = emissionFunctions;
}
/*
@SuppressWarnings("unchecked")
public ObservationType getMean()
{
Vector p = this.getSteadyStateDistribution();
ObservationType observation =
CollectionUtil.getFirst(this.emissionFunctions).getMean();
if( observation instanceof Ring )
{
RingAccumulator weightedAverage = new RingAccumulator();
int i = 0;
for( ComputableDistribution f : this.emissionFunctions )
{
Ring mean = (Ring) f.getMean();
weightedAverage.accumulate( mean.scale( p.getElement(i) ) );
i++;
}
return (ObservationType) weightedAverage.getSum();
}
else if( observation instanceof Number )
{
double weightedAverage = 0.0;
int i = 0;
for( ComputableDistribution f : this.emissionFunctions )
{
Number mean = (Number) f.getMean();
weightedAverage += mean.doubleValue() * p.getElement(i);
i++;
}
return (ObservationType) (new Double( weightedAverage ));
}
else
{
throw new UnsupportedOperationException(
"Mean not supported for type: " + observation.getClass() );
}
}
*
*/
/**
 * Computes the recursive solution to the forward probabilities of the
 * HMM.
 * @param alpha
 * Previous alpha value.
 * @param b
 * Current observation-emission likelihood.
 * @param normalize
 * True to normalize the alphas, false to leave them unnormalized.
 * @return
 * Alpha with the associated weighting (will be 1 if unnormalized). When
 * normalizing, the weight is the reciprocal of the 1-norm of the
 * unnormalized alpha.
 */
protected WeightedValue<Vector> computeForwardProbabilities(
    Vector alpha,
    Vector b,
    boolean normalize )
{
    // Propagate through the transition matrix, then weight each state by
    // the likelihood of the current observation.
    Vector alphaNext = this.getTransitionProbability().times( alpha );
    alphaNext.dotTimesEquals(b);

    double weight;
    if( normalize )
    {
        weight = 1.0/alphaNext.norm1();
        alphaNext.scaleEquals(weight);
    }
    else
    {
        weight = 1.0;
    }

    return new DefaultWeightedValue<Vector>( alphaNext, weight );
}
/**
* Computes the forward probabilities for the given observation likelihood
* sequence.
* @param b
* Observation likelihood sequence.
* @param normalize
* True to normalize the alphas, false to leave them unnormalized.
* @return
* Forward probability alphas, with their associated weights.
*/
protected ArrayList> computeForwardProbabilities(
ArrayList b,
boolean normalize )
{
final int N = b.size();
ArrayList> weightedAlphas =
new ArrayList>( N );
Vector alpha = b.get(0).dotTimes( this.getInitialProbability() );
double weight;
if( normalize )
{
weight = 1.0/alpha.norm1();
alpha.scaleEquals(weight);
}
else
{
weight = 1.0;
}
WeightedValue weightedAlpha =
new DefaultWeightedValue(alpha,weight);
weightedAlphas.add( weightedAlpha );
for( int n = 1; n < N; n++ )
{
weightedAlpha = this.computeForwardProbabilities(
weightedAlpha.getValue(), b.get(n), normalize );
weightedAlphas.add( weightedAlpha );
}
return weightedAlphas;
}
/**
 * Evaluates every state's emission function at the given observation and
 * returns the results in a newly created Vector.
 * @param observation
 * Observation to consider
 * @return
 * Likelihood of each state generating the given observation.
 */
public Vector computeObservationLikelihoods(
    ObservationType observation )
{
    // One entry per emission function.
    final Vector likelihoods = VectorFactory.getDefault().createVector(
        this.getEmissionFunctions().size() );
    this.computeObservationLikelihoods(observation, likelihoods);
    return likelihoods;
}
/**
 * Computes the conditionally independent likelihoods
 * for each state given the observation.
 * @param observation
 * Observation to consider
 * @param b
 * Likelihood of each state generating the given observation. This is where
 * the result of the computation is stored.
 */
protected void computeObservationLikelihoods(
    ObservationType observation,
    Vector b )
{
    // Entry i receives the value of the i-th emission function, in the
    // collection's iteration order.
    int i = 0;
    for( ComputableDistribution<ObservationType> f : this.getEmissionFunctions() )
    {
        b.setElement( i, f.getProbabilityFunction().evaluate(observation) );
        i++;
    }
}
/**
* Computes the conditionally independent likelihoods
* for each state given the observation sequence.
* @param observations
* Observation sequence to consider
* @return
* Likelihood of each state generating the given observation sequence.
*/
protected ArrayList computeObservationLikelihoods(
Collection extends ObservationType> observations )
{
final int N = observations.size();
ArrayList bs = new ArrayList( N );
for( ObservationType observation : observations )
{
bs.add( this.computeObservationLikelihoods(observation) );
}
return bs;
}
/**
 * Computes the backward probability recursion.
 * @param beta
 * Beta from the "next" time step.
 * @param b
 * Observation likelihood from the "next" time step.
 * @param weight
 * Weight to use for the current time step.
 * @return
 * Beta for the previous time step, weighted by "weight".
 */
protected WeightedValue<Vector> computeBackwardProbabilities(
    Vector beta,
    Vector b,
    double weight )
{
    // Vector-times-Matrix here, whereas the forward pass uses
    // Matrix-times-Vector.
    Vector betaPrevious = b.dotTimes(beta);
    betaPrevious = betaPrevious.times( this.getTransitionProbability() );

    // Skip the scaling pass when it would be a no-op.
    if( weight != 1.0 )
    {
        betaPrevious.scaleEquals(weight);
    }

    return new DefaultWeightedValue<Vector>( betaPrevious, weight );
}
/**
* Computes the backward-probabilities for the given observation likelihoods
* and the weights from the alphas.
* @param b
* Observation likelihoods.
* @param alphas
* Forward probabilities from which we will use the weights.
* @return
* Backward probabilities.
*/
protected ArrayList> computeBackwardProbabilities(
ArrayList b,
ArrayList> alphas )
{
final int N = b.size();
final int k = this.getInitialProbability().getDimensionality();
ArrayList> weightedBetas =
new ArrayList>( N );
for( int n = 0; n < N; n++ )
{
weightedBetas.add( null );
}
Vector beta = VectorFactory.getDefault().createVector(k, 1.0);
double weight = alphas.get(N-1).getWeight();
if( weight != 1.0 )
{
beta.scaleEquals(weight);
}
WeightedValue weightedBeta =
new DefaultWeightedValue( beta, weight );
weightedBetas.set( N-1, weightedBeta );
for( int n = N-2; n >= 0; n-- )
{
weight = alphas.get(n).getWeight();
weightedBeta = this.computeBackwardProbabilities(
weightedBeta.getValue(), b.get(n+1), weight );
weightedBetas.set( n, weightedBeta );
}
return weightedBetas;
}
/**
 * Computes the probability of the various states at a time instance given
 * the observation sequence. Rabiner calls this the "gamma". The
 * element-wise product of alpha and beta is rescaled so that its 1-norm
 * equals the scale factor.
 * @param alpha
 * Forward probability at time n.
 * @param beta
 * Backward probability at time n.
 * @param scaleFactor
 * Amount to scale the gamma by
 * @return
 * Gamma at time n.
 */
protected static Vector computeStateObservationLikelihood(
    Vector alpha,
    Vector beta,
    double scaleFactor )
{
    final Vector gamma = alpha.dotTimes(beta);
    final double norm = gamma.norm1();
    gamma.scaleEquals( scaleFactor/norm );
    return gamma;
}
/**
* Computes the probabilities of the various states over time given the
* observation sequence. Rabiner calls these the "gammas".
* @param alphas
* Forward probabilities.
* @param betas
* Backward probabilities.
* @param scaleFactor
* Amount to scale the gamma by
* @return
* Gammas.
*/
protected ArrayList computeStateObservationLikelihood(
ArrayList> alphas,
ArrayList> betas,
double scaleFactor )
{
final int N = alphas.size();
ArrayList gammas = new ArrayList( N );
for( int n = 0; n < N; n++ )
{
Vector alphan = alphas.get(n).getValue();
Vector betan = betas.get(n).getValue();
gammas.add( computeStateObservationLikelihood(
alphan, betan, scaleFactor ) );
}
return gammas;
}
/**
 * Computes the outer product of (bnp1 element-wise-times betanp1) with
 * alphan, which is the per-time-step term accumulated when estimating
 * the transition-probability matrix.
 * @param alphan
 * Result of the forward pass through the HMM at time n
 * @param betanp1
 * Result of the backward pass through the HMM at time n+1
 * @param bnp1
 * Conditionally independent likelihoods of each observation at time n+1
 * @return
 * Transition probabilities at time n
 */
protected static Matrix computeTransitions(
    Vector alphan,
    Vector betanp1,
    Vector bnp1 )
{
    return bnp1.dotTimes(betanp1).outerProduct(alphan);
}
/**
* Computes the stochastic transition-probability matrix from the
* given probabilities.
* @param alphas
* Result of the forward pass through the HMM.
* @param betas
* Result of the backward pass through the HMM.
* @param b
* Conditionally independent likelihoods of each observation.
* @return
* ML estimate of the transition probability Matrix over all time steps.
*/
protected Matrix computeTransitions(
ArrayList> alphas,
ArrayList> betas,
ArrayList b )
{
final int N = b.size();
RingAccumulator counts = new RingAccumulator();
for( int n = 0; n < N-1; n++ )
{
Vector alpha = alphas.get(n).getValue();
Vector beta = betas.get(n+1).getValue();
counts.accumulate(
computeTransitions( alpha, beta, b.get(n+1) ) );
}
Matrix A = counts.getSum();
A.dotTimesEquals(this.getTransitionProbability());
this.normalizeTransitionMatrix(A);
return A;
}
/**
 * Returns the superclass description followed by "F: " and the string
 * form of each emission function, in order.
 * @return
 * String description of this HMM.
 */
@Override
public String toString()
{
    StringBuilder retval = new StringBuilder( super.toString() );
    for( ComputableDistribution<ObservationType> f : this.getEmissionFunctions() )
    {
        retval.append( "F: " );
        retval.append( f.toString() );
    }
    return retval.toString();
}
/**
 * Finds the most-likely previous state for the given destination state,
 * given the previous "delta" in the Viterbi algorithm.
 * @param destinationState
 * Destination state index to consider.
 * @param delta
 * Previous value of the "delta".
 * @return
 * Most-likely previous state, weighted by its likelihood. The index is -1
 * if no candidate compares greater than negative infinity.
 */
protected WeightedValue<Integer> findMostLikelyState(
    int destinationState, Vector delta )
{
    double best = Double.NEGATIVE_INFINITY;
    int index = -1;
    final int k = delta.getDimensionality();
    for( int j = 0; j < k; j++ )
    {
        // Probability of being in state j and then moving to the
        // destination state.
        final double dj =
            this.transitionProbability.getElement(destinationState,j) * delta.getElement(j);
        // Strict comparison: the lowest index wins ties.
        if( best < dj )
        {
            best = dj;
            index = j;
        }
    }

    return new DefaultWeightedValue<Integer>( index, best );
}
/**
 * Computes the Viterbi recursion for a given "delta" and "b"
 * @param delta
 * Previous value of the Viterbi recursion. It is not modified.
 * @param bn
 * Current observation likelihood.
 * @return
 * Updated "delta" (scaled to unit 1-norm) and state backpointers.
 */
protected Pair<Vector,int[]> computeViterbiRecursion(
    Vector delta,
    Vector bn )
{
    final int k = delta.getDimensionality();
    final Vector dn = VectorFactory.getDefault().createVector(k);
    final int[] psi = new int[ k ];
    for( int i = 0; i < k; i++ )
    {
        // For destination state i, record the best predecessor and the
        // likelihood of arriving through it.
        WeightedValue<Integer> transition =
            this.findMostLikelyState(i, delta);
        psi[i] = transition.getValue();
        dn.setElement(i, transition.getWeight());
    }

    // Fold in the observation likelihood and rescale to unit 1-norm.
    dn.dotTimesEquals( bn );
    dn.scaleEquals( 1.0/dn.norm1() );

    return DefaultPair.create( dn, psi );
}
/**
* Viterbi algorithm for decoding the most-likely sequence of states
* from the HMMs underlying Markov chain for a given observation sequence.
* @param observations
* Observation sequence to consider
* @return
* Indices of the most-likely state sequence that generated the given
* observations.
*/
@PublicationReference(
author="Wikipedia",
title="Viterbi algorithm",
year=2010,
type=PublicationType.WebPage,
url="http://en.wikipedia.org/wiki/Viterbi_algorithm"
)
public ArrayList viterbi(
Collection extends ObservationType> observations )
{
final int N = observations.size();
final int k = this.getNumStates();
ArrayList bs = this.computeObservationLikelihoods(observations);
Vector delta = this.getInitialProbability().dotTimes( bs.get(0) );
ArrayList psis = new ArrayList( N );
int[] psi = new int[ k ];
for( int i = 0; i < k; i++ )
{
psi[i] = 0;
}
psis.add( psi );
ArrayList states = new ArrayList( N );
states.add( null );
for( int n = 1; n < N; n++ )
{
states.add( null );
Pair pair =
this.computeViterbiRecursion( delta, bs.get(n) );
delta = pair.getFirst();
psis.add( pair.getSecond() );
}
// Backchaining
int finalState = -1;
double best = Double.NEGATIVE_INFINITY;
for( int i = 0; i < k; i++ )
{
final double v = delta.getElement(i);
if( best < v )
{
best = v;
finalState = i;
}
}
int state = finalState;
states.set(N-1, state);
for( int n = N-2; n >= 0; n-- )
{
state = psis.get(n+1)[state];
states.set(n, state);
}
return states;
}
/**
* Computes the probability distribution over all states for each
* observation.
* @param observations
* @return
* The list of state belief probabilities for each observation.
*/
public ArrayList stateBeliefs(
Collection extends ObservationType> observations )
{
ArrayList bs = this.computeObservationLikelihoods(observations);
ArrayList> alphas =
this.computeForwardProbabilities(bs, true);
ArrayList beliefs = new ArrayList( alphas.size() );
for( WeightedValue alpha : alphas )
{
beliefs.add( alpha.getValue() );
}
return beliefs;
}
}