/*
 * File:                HiddenMarkovModel.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Feb 2, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.Distribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.distribution.DirichletDistribution;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Random;

/**
 * A discrete-state Hidden Markov Model (HMM) with either continuous
 * or discrete observations.
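 *
 * <p>
 * A minimal usage sketch (illustrative only; it assumes the Foundry's
 * {@code UnivariateGaussian} distribution as the emission model):
 * <pre>{@code
 * HiddenMarkovModel<Double> hmm = HiddenMarkovModel.createRandom(
 *     3, new UnivariateGaussian(0.0, 1.0), new Random(1));
 * ArrayList<Double> draws = hmm.sample(new Random(2), 100);
 * double logLikelihood = hmm.computeObservationLogLikelihood(draws);
 * }</pre>
 *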
 * @author Kevin R. Dixon
 * @since 3.0
 * @param <ObservationType> Type of Observations handled by the HMM.
 */
@PublicationReference(
    author="Lawrence R. Rabiner",
    title="A tutorial on hidden Markov models and selected applications in speech recognition",
    type=PublicationType.Journal,
    year=1989,
    publication="Proceedings of the IEEE",
    pages={257,286},
    url="http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf",
    notes="Rabiner's transition matrix is transposed from mine."
)
public class HiddenMarkovModel<ObservationType>
    extends MarkovChain
    implements Distribution<ObservationType>
{

    /**
     * The PDFs that emit symbols from each state.
     */
    protected Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions;

    /**
     * Default constructor.
     */
    public HiddenMarkovModel()
    {
        super();
    }

    /**
     * Creates a new instance of HiddenMarkovModel.
     * @param numStates
     * Number of states in the HMM.
     */
    public HiddenMarkovModel(
        int numStates )
    {
        super( numStates );
    }

    /**
     * Creates a new instance of HiddenMarkovModel.
     * @param initialProbability
     * Initial probability Vector over the states.  Each entry must be
     * nonnegative and the Vector must sum to 1.
     * @param transitionProbability
     * Transition probability matrix.  The entry (i,j) is the probability
     * of transition from state "j" to state "i".  As a corollary, all
     * entries in the Matrix must be nonnegative and the
     * columns of the Matrix must sum to 1.
     * @param emissionFunctions
     * The PDFs that emit symbols from each state.
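     *
     * <p>
     * For example, a hand-built two-state model over {@code Double}
     * observations might look like the following sketch (the
     * {@code UnivariateGaussian} emission distributions are an assumption for
     * illustration; note the column-stochastic transition matrix):
     * <pre>{@code
     * Vector pi = VectorFactory.getDefault().copyValues(0.6, 0.4);
     * Matrix A = MatrixFactory.getDefault().copyArray(new double[][] {
     *     {0.9, 0.2},   // column j holds the outgoing probabilities of state j
     *     {0.1, 0.8}});
     * HiddenMarkovModel<Double> hmm = new HiddenMarkovModel<Double>(
     *     pi, A, Arrays.asList(
     *         new UnivariateGaussian(0.0, 1.0),
     *         new UnivariateGaussian(5.0, 1.0)));
     * }</pre>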
     */
    public HiddenMarkovModel(
        Vector initialProbability,
        Matrix transitionProbability,
        Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions )
    {
        super( initialProbability, transitionProbability );
        final int k = this.getNumStates();
        if( emissionFunctions.size() != k )
        {
            throw new IllegalArgumentException(
                "Number of PDFs must be equal to number of states!" );
        }

        this.setEmissionFunctions(emissionFunctions);

    }

    /**
     * Creates a Hidden Markov Model with the same PMF/PDF for each state,
     * but sampling the columns of the transition matrix and the initial
     * probability distributions from a diffuse Dirichlet.
     * @param <ObservationType>
     * Type of observations to generate.
     * @param numStates
     * Number of states to create
     * @param learner
     * Learner to create the distributions for each state
     * @param data
     * Data from which to make the distribution
     * @param random
     * Random number generator
     * @return
     * HMM with the specified states.
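     *
     * <p>
     * A usage sketch (here {@code learner} and {@code observedData} stand for
     * any caller-supplied batch learner over weighted values and its training
     * data; both are assumptions for illustration):
     * <pre>{@code
     * HiddenMarkovModel<Double> hmm = HiddenMarkovModel.createRandom(
     *     3, learner, observedData, new Random(1));
     * }</pre>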
     */
    public static <ObservationType> HiddenMarkovModel<ObservationType> createRandom(
        int numStates,
        BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>, ? extends ComputableDistribution<ObservationType>> learner,
        Collection<? extends ObservationType> data,
        Random random )
    {
        ArrayList<DefaultWeightedValue<ObservationType>> weightedData =
            new ArrayList<DefaultWeightedValue<ObservationType>>( data.size() );
        for( ObservationType observation : data )
        {
            weightedData.add( new DefaultWeightedValue<ObservationType>( observation, 1.0 ) );
        }
        ComputableDistribution<ObservationType> distribution =
            learner.learn(weightedData);
        return createRandom( numStates, distribution, random );
        
    }

    /**
     * Creates a Hidden Markov Model with the same PMF/PDF for each state,
     * but sampling the columns of the transition matrix and the initial
     * probability distributions from a diffuse Dirichlet.
     * @param <ObservationType>
     * Type of observations to generate.
     * @param numStates
     * Number of states to create
     * @param distribution
     * Distribution from which we will assign to each state.
     * @param random
     * Random number generator
     * @return
     * HMM with the specified states.
     */
    public static <ObservationType> HiddenMarkovModel<ObservationType> createRandom(
        int numStates,
        ComputableDistribution<ObservationType> distribution,
        Random random )
    {
        Collection<ProbabilityFunction<ObservationType>> distributions =
            Collections.nCopies(numStates, distribution.getProbabilityFunction());
        return createRandom(distributions, random);
    }

    /**
     * Creates a Hidden Markov Model with the given probability function for
     * each state, but sampling the columns of the transition matrix and the
     * initial probability distributions from a diffuse Dirichlet.
     * @param <ObservationType>
     *      Type of observations to generate.
     * @param distributions
     *      The distribution for each state. The size of the collection is the
     *      number of states to create.
     * @param random
     * Random number generator
     * @return
     * HMM with the specified states.
     */
    public static <ObservationType> HiddenMarkovModel<ObservationType> createRandom(
        Collection<? extends ProbabilityFunction<ObservationType>> distributions,
        Random random)
    {
        int numStates = distributions.size();

        // We'll sample the multinomial probabilities from a uniform Dirichlet
        DirichletDistribution dirichlet =
            new DirichletDistribution( numStates );

        Vector initialProbability = dirichlet.sample(random);
        Matrix transitionMatrix =
            MatrixFactory.getDefault().createMatrix(numStates, numStates);
        ArrayList<Vector> outbounds = dirichlet.sample(random, numStates);
        for( int i = 0; i < numStates; i++ )
        {
            transitionMatrix.setColumn( i, outbounds.get(i) );
        }

        return new HiddenMarkovModel<ObservationType>(
            initialProbability, transitionMatrix, distributions);
    }

    @Override
    public HiddenMarkovModel clone()
    {
        @SuppressWarnings("unchecked")
        HiddenMarkovModel<ObservationType> clone =
            (HiddenMarkovModel<ObservationType>) super.clone();
        clone.setEmissionFunctions( ObjectUtil.cloneSmartElementsAsArrayList(
            this.getEmissionFunctions() ) );
        return clone;
    }

    /**
     * Computes the log-likelihood of the observation sequence, given the
     * current HMM's parameterization.  This is the answer to Rabiner's
     * "Three Basic Problems for HMMs, Problem 1: Probability Evaluation".
     * @param observations
     * Observations to consider.
     * @return
     * Log-likelihood of the given observation sequence.
     */
    public double computeObservationLogLikelihood(
        Collection<? extends ObservationType> observations )
    {

        final int k = this.getNumStates();
        Vector b = VectorFactory.getDefault().createVector(k);
        Vector alpha = this.getInitialProbability().clone();
        Matrix A = this.getTransitionProbability();
        int index = 0;
        double logLikelihood = 0.0;
        for( ObservationType observation : observations )
        {
            if( index > 0 )
            {
                alpha = A.times( alpha );
            }
            this.computeObservationLikelihoods(observation, b);
            alpha.dotTimesEquals(b);
            final double weight = alpha.norm1();
            alpha.scaleEquals(1.0/weight);
            logLikelihood += Math.log(weight);
            index++;
        }

        return logLikelihood;

    }

    /**
     * Computes the log-likelihood of the observation sequences, given the
     * current HMM's parameterization.  This is the answer to Rabiner's
     * "Three Basic Problems for HMMs, Problem 1: Probability Evaluation".
     * @param sequences
     * Observation sequences to consider.
     * @return
     * Total log-likelihood of the given observation sequences.
     */
    protected double computeMultipleObservationLogLikelihood(
        Collection<? extends Collection<? extends ObservationType>> sequences )
    {
        double logLikelihood = 0.0;
        for( Collection<? extends ObservationType> observations : sequences )
        {
            logLikelihood += this.computeObservationLogLikelihood(observations);
        }
        return logLikelihood;
    }

    /**
     * Computes the log-likelihood that the given observation sequence
     * was generated by the given sequence of state indices.
     * @param observations
     * Observations to consider.
     * @param states
     * Indices of states hypothesized to have generated the observation
     * sequence.
     * @return
     * Log-likelihood of the given observation sequence.
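     *
     * <p>
     * Sketch of a typical combination with {@link #viterbi} (assuming
     * {@code hmm} and {@code observations} exist in the caller's code):
     * <pre>{@code
     * ArrayList<Integer> states = hmm.viterbi(observations);
     * double pathLogLikelihood =
     *     hmm.computeObservationLogLikelihood(observations, states);
     * }</pre>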
     */
    public double computeObservationLogLikelihood(
        Collection<? extends ObservationType> observations,
        Collection<Integer> states )
    {

        final int N = observations.size();
        if( N != states.size() )
        {
            throw new IllegalArgumentException(
                "Observations and states must be the same size" );
        }

        Iterator<Integer> stateIterator = states.iterator();
        double logLikelihood = 0.0;
        ArrayList<ProbabilityFunction<ObservationType>> fs =
            new ArrayList<ProbabilityFunction<ObservationType>>( this.getNumStates() );
        for( ComputableDistribution<ObservationType> f : this.getEmissionFunctions() )
        {
            fs.add( f.getProbabilityFunction() );
        }
        
        int lastState = -1;
        for( ObservationType observation : observations )
        {
            final int state = stateIterator.next();
            double blog = fs.get(state).logEvaluate(observation);
            double ll;
            if( lastState < 0 )
            {
                ll = Math.log(this.initialProbability.getElement(state));
            }
            else
            {
                ll = Math.log(this.transitionProbability.getElement(state, lastState) );
            }

            lastState = state;

            logLikelihood += blog + ll;

        }

        return logLikelihood;

    }

    @Override
    public ObservationType sample(
        Random random)
    {
        return CollectionUtil.getFirst( this.sample(random, 1) );
    }

    @Override
    public ArrayList<ObservationType> sample(
        Random random,
        int numSamples )
    {
        final ArrayList<ObservationType> samples = new ArrayList<ObservationType>(numSamples);
        this.sampleInto(random, numSamples, samples);
        return samples;
    }

    @Override
    public void sampleInto(
        final Random random,
        final int sampleCount,
        final Collection<? super ObservationType> output)
    {
        Vector p = this.getInitialProbability();
        int state = -1;
        for( int n = 0; n < sampleCount; n++ )
        {
            double value = random.nextDouble();
            state = -1;
            while( value > 0.0 )
            {
                state++;
                value -= p.getElement(state);
            }

            ObservationType sample = CollectionUtil.getElement(
                this.getEmissionFunctions(), state ).sample(random);
            output.add( sample );
            p = this.getTransitionProbability().getColumn(state);
        }
    }

    /**
     * Getter for emissionFunctions
     * @return
     * The PDFs that emit symbols from each state.
     */
    public Collection<? extends ComputableDistribution<ObservationType>> getEmissionFunctions()
    {
        return this.emissionFunctions;
    }

    /**
     * Setter for emissionFunctions.
     * @param emissionFunctions
     * The PDFs that emit symbols from each state.
     */
    public void setEmissionFunctions(
        Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions)
    {
        this.emissionFunctions = emissionFunctions;
    }

    /*
    @SuppressWarnings("unchecked")
    public ObservationType getMean()
    {

        Vector p = this.getSteadyStateDistribution();

        ObservationType observation =
            CollectionUtil.getFirst(this.emissionFunctions).getMean();
        if( observation instanceof Ring )
        {
            RingAccumulator weightedAverage = new RingAccumulator();
            int i = 0;
            for( ComputableDistribution<ObservationType> f : this.emissionFunctions )
            {
                Ring mean = (Ring) f.getMean();
                weightedAverage.accumulate( mean.scale( p.getElement(i) ) );
                i++;
            }
            return (ObservationType) weightedAverage.getSum();
        }
        else if( observation instanceof Number )
        {
            double weightedAverage = 0.0;
            int i = 0;
            for( ComputableDistribution<ObservationType> f : this.emissionFunctions )
            {
                Number mean = (Number) f.getMean();
                weightedAverage += mean.doubleValue() * p.getElement(i);
                i++;
            }
            return (ObservationType) (new Double( weightedAverage ));
        }
        else
        {
            throw new UnsupportedOperationException(
                "Mean not supported for type: " + observation.getClass() );
        }
    }
     *
     */

    /**
     * Computes the recursive solution to the forward probabilities of the
     * HMM.
     * @param alpha
     * Previous alpha value.
     * @param b
     * Current observation-emission likelihood.
     * @param normalize
     * True to normalize the alphas, false to leave them unnormalized.
     * @return
     * Alpha with the associated weighting (will be 1 if unnormalized).
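     *
     * <p>
     * Concretely this computes {@code alphaNext = (A * alpha) .* b} (where
     * {@code .*} is the element-wise product); when {@code normalize} is
     * true, {@code alphaNext} is rescaled to sum to one and the reciprocal
     * of its unnormalized 1-norm is returned as the weight.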
     */
    protected WeightedValue<Vector> computeForwardProbabilities(
        Vector alpha,
        Vector b,
        boolean normalize )
    {
        Vector alphaNext = this.getTransitionProbability().times( alpha );
        alphaNext.dotTimesEquals(b);

        double weight;
        if( normalize )
        {
            weight = 1.0/alphaNext.norm1();
            alphaNext.scaleEquals(weight);
        }
        else
        {
            weight = 1.0;
        }

        return new DefaultWeightedValue<Vector>( alphaNext, weight );

    }


    /**
     * Computes the forward probabilities for the given observation likelihood
     * sequence.
     * @param b
     * Observation likelihood sequence.
     * @param normalize
     * True to normalize the alphas, false to leave them unnormalized.
     * @return
     * Forward probability alphas, with their associated weights.
     */
    protected ArrayList<WeightedValue<Vector>> computeForwardProbabilities(
        ArrayList<Vector> b,
        boolean normalize )
    {
        final int N = b.size();
        ArrayList<WeightedValue<Vector>> weightedAlphas =
            new ArrayList<WeightedValue<Vector>>( N );
        Vector alpha = b.get(0).dotTimes( this.getInitialProbability() );
        double weight;
        if( normalize )
        {
            weight = 1.0/alpha.norm1();
            alpha.scaleEquals(weight);
        }
        else
        {
            weight = 1.0;
        }
        WeightedValue<Vector> weightedAlpha =
            new DefaultWeightedValue<Vector>(alpha, weight);
        weightedAlphas.add( weightedAlpha );
        for( int n = 1; n < N; n++ )
        {
            weightedAlpha = this.computeForwardProbabilities(
                weightedAlpha.getValue(), b.get(n), normalize );
            weightedAlphas.add( weightedAlpha );
        }

        return weightedAlphas;
    }

    /**
     * Computes the conditionally independent likelihoods
     * for each state given the observation.
     * @param observation
     * Observation to consider
     * @return
     * Likelihood of each state generating the given observation.
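     *
     * <p>
     * Sketch (assuming {@code hmm} is a {@code HiddenMarkovModel<Double>}):
     * <pre>{@code
     * Vector b = hmm.computeObservationLikelihoods(0.5);
     * // b.getElement(i) is the likelihood of the observation under state i
     * }</pre>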
     */
    public Vector computeObservationLikelihoods(
        ObservationType observation )
    {
        final int k = this.getEmissionFunctions().size();
        Vector b = VectorFactory.getDefault().createVector(k);
        this.computeObservationLikelihoods(observation, b);
        return b;
    }

    /**
     * Computes the conditionally independent likelihoods
     * for each state given the observation.
     * @param observation
     * Observation to consider
     * @param b
     * Likelihood of each state generating the given observation. This is where
     * the result of the computation is stored.
     */
    protected void computeObservationLikelihoods(
        ObservationType observation,
        Vector b )
    {
        int i = 0;
        for( ComputableDistribution<ObservationType> f : this.getEmissionFunctions() )
        {
            b.setElement(i, f.getProbabilityFunction().evaluate(observation) );
            i++;
        }
    }

    /**
     * Computes the conditionally independent likelihoods
     * for each state given the observation sequence.
     * @param observations
     * Observation sequence to consider
     * @return
     * Likelihood of each state generating the given observation sequence.
     */
    protected ArrayList<Vector> computeObservationLikelihoods(
        Collection<? extends ObservationType> observations )
    {
        final int N = observations.size();
        ArrayList<Vector> bs = new ArrayList<Vector>( N );
        for( ObservationType observation : observations )
        {
            bs.add( this.computeObservationLikelihoods(observation) );
        }

        return bs;
    }

    /**
     * Computes the backward probability recursion.
     * @param beta
     * Beta from the "next" time step.
     * @param b
     * Observation likelihood from the "next" time step.
     * @param weight
     * Weight to use for the current time step.
     * @return
     * Beta for the previous time step, weighted by "weight".
     */
    protected WeightedValue<Vector> computeBackwardProbabilities(
        Vector beta,
        Vector b,
        double weight )
    {
        Vector betaPrevious = b.dotTimes(beta);
        betaPrevious = betaPrevious.times( this.getTransitionProbability() );
        if( weight != 1.0 )
        {
            betaPrevious.scaleEquals(weight);
        }
        return new DefaultWeightedValue<Vector>( betaPrevious, weight );
    }

    /**
     * Computes the backward-probabilities for the given observation likelihoods
     * and the weights from the alphas.
     * @param b
     * Observation likelihoods.
     * @param alphas
     * Forward probabilities from which we will use the weights.
     * @return
     * Backward probabilities.
     */
    protected ArrayList<WeightedValue<Vector>> computeBackwardProbabilities(
        ArrayList<Vector> b,
        ArrayList<WeightedValue<Vector>> alphas )
    {
        final int N = b.size();
        final int k = this.getInitialProbability().getDimensionality();
        ArrayList<WeightedValue<Vector>> weightedBetas =
            new ArrayList<WeightedValue<Vector>>( N );
        for( int n = 0; n < N; n++ )
        {
            weightedBetas.add( null );
        }
        Vector beta = VectorFactory.getDefault().createVector(k, 1.0);
        double weight = alphas.get(N-1).getWeight();
        if( weight != 1.0 )
        {
            beta.scaleEquals(weight);
        }
        WeightedValue<Vector> weightedBeta =
            new DefaultWeightedValue<Vector>( beta, weight );
        weightedBetas.set( N-1, weightedBeta );
        for( int n = N-2; n >= 0; n-- )
        {
            weight = alphas.get(n).getWeight();
            weightedBeta = this.computeBackwardProbabilities(
                weightedBeta.getValue(), b.get(n+1), weight );
            weightedBetas.set( n, weightedBeta );
        }

        return weightedBetas;

    }

    /**
     * Computes the probability of the various states at a time instance given
     * the observation sequence.  Rabiner calls this the "gamma".
     * @param alpha
     * Forward probability at time n.
     * @param beta
     * Backward probability at time n.
     * @param scaleFactor
     * Amount to scale the gamma by
     * @return
     * Gamma at time n.
     */
    protected static Vector computeStateObservationLikelihood(
        Vector alpha,
        Vector beta,
        double scaleFactor )
    {
        Vector gamma = alpha.dotTimes(beta);
        gamma.scaleEquals(scaleFactor/gamma.norm1());
        return gamma;
    }

    /**
     * Computes the probabilities of the various states over time given the
     * observation sequence.  Rabiner calls these the "gammas".
     * @param alphas
     * Forward probabilities.
     * @param betas
     * Backward probabilities.
     * @param scaleFactor
     * Amount to scale the gamma by
     * @return
     * Gammas.
     */
    protected ArrayList<Vector> computeStateObservationLikelihood(
        ArrayList<WeightedValue<Vector>> alphas,
        ArrayList<WeightedValue<Vector>> betas,
        double scaleFactor )
    {
        final int N = alphas.size();
        ArrayList<Vector> gammas = new ArrayList<Vector>( N );
        for( int n = 0; n < N; n++ )
        {
            Vector alphan = alphas.get(n).getValue();
            Vector betan = betas.get(n).getValue();
            gammas.add( computeStateObservationLikelihood(
                alphan, betan, scaleFactor ) );

        }

        return gammas;
    }

    /**
     * Computes the stochastic transition-probability matrix from the
     * given probabilities.
     * @param alphan
     * Result of the forward pass through the HMM at time n
     * @param betanp1
     * Result of the backward pass through the HMM at time n+1
     * @param bnp1
     * Conditionally independent likelihoods of each observation at time n+1
     * @return
     * Transition probabilities at time n
     */
    protected static Matrix computeTransitions(
        Vector alphan,
        Vector betanp1,
        Vector bnp1 )
    {
        Vector bnext = bnp1.dotTimes(betanp1);
        return bnext.outerProduct(alphan);
    }

    /**
     * Computes the stochastic transition-probability matrix from the
     * given probabilities.
     * @param alphas
     * Result of the forward pass through the HMM.
     * @param betas
     * Result of the backward pass through the HMM.
     * @param b
     * Conditionally independent likelihoods of each observation.
     * @return
     * ML estimate of the transition probability Matrix over all time steps.
     */
    protected Matrix computeTransitions(
        ArrayList<WeightedValue<Vector>> alphas,
        ArrayList<WeightedValue<Vector>> betas,
        ArrayList<Vector> b )
    {
        final int N = b.size();
        RingAccumulator<Matrix> counts = new RingAccumulator<Matrix>();
        for( int n = 0; n < N-1; n++ )
        {
            Vector alpha = alphas.get(n).getValue();
            Vector beta = betas.get(n+1).getValue();
            counts.accumulate(
                computeTransitions( alpha, beta, b.get(n+1) ) );
        }

        Matrix A = counts.getSum();
        A.dotTimesEquals(this.getTransitionProbability());
        this.normalizeTransitionMatrix(A);

        return A;

    }

    @Override
    public String toString()
    {

        StringBuilder retval = new StringBuilder( super.toString() );
        for( ComputableDistribution<ObservationType> f : this.getEmissionFunctions() )
        {
            retval.append( "F: " );
            retval.append( f.toString() );
        }

        return retval.toString();
    }

    /**
     * Finds the most-likely next state given the previous "delta" in the
     * Viterbi algorithm.
     * @param destinationState
     * Destination state index to consider.
     * @param delta
     * Previous value of the "delta".
     * @return
     * Most-likely previous state, weighted by its likelihood.
     */
    protected WeightedValue<Integer> findMostLikelyState(
        int destinationState, Vector delta )
    {
        double best = Double.NEGATIVE_INFINITY;
        int index = -1;
        double dj;
        final int k = delta.getDimensionality();
        for( int j = 0; j < k; j++ )
        {
            dj = this.transitionProbability.getElement(destinationState,j) * delta.getElement(j);
            if( best < dj )
            {
                best = dj;
                index = j;
            }
        }

        return new DefaultWeightedValue<Integer>( index, best );

    }

    /**
     * Computes the Viterbi recursion for a given "delta" and "b"
     * @param delta
     * Previous value of the Viterbi recursion.
     * @param bn
     * Current observation likelihood.
     * @return
     * Updated "delta" and state backpointers.
     */
    protected Pair<Vector,int[]> computeViterbiRecursion(
        Vector delta,
        Vector bn )
    {

        final int k = delta.getDimensionality();
        final Vector dn = VectorFactory.getDefault().createVector(k);
        final int[] psi = new int[ k ];
        for( int i = 0; i < k; i++ )
        {
            WeightedValue<Integer> transition =
                this.findMostLikelyState(i, delta);
            psi[i] = transition.getValue();
            dn.setElement(i, transition.getWeight());
        }
        dn.dotTimesEquals( bn );
        delta = dn;
        delta.scaleEquals( 1.0/delta.norm1() );

        return DefaultPair.create( delta, psi );
    }

    /**
     * Viterbi algorithm for decoding the most-likely sequence of states
     * from the HMM's underlying Markov chain for a given observation sequence.
     * @param observations
     * Observation sequence to consider
     * @return
     * Indices of the most-likely state sequence that generated the given
     * observations.
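     *
     * <p>
     * Sketch (assuming {@code hmm} and {@code observations} exist in the
     * caller's code):
     * <pre>{@code
     * ArrayList<Integer> path = hmm.viterbi(observations);
     * int firstState = path.get(0);  // index of the most-likely initial state
     * }</pre>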
     */
    @PublicationReference(
        author="Wikipedia",
        title="Viterbi algorithm",
        year=2010,
        type=PublicationType.WebPage,
        url="http://en.wikipedia.org/wiki/Viterbi_algorithm"
    )
    public ArrayList<Integer> viterbi(
        Collection<? extends ObservationType> observations )
    {
        
        final int N = observations.size();
        final int k = this.getNumStates();
        ArrayList<Vector> bs = this.computeObservationLikelihoods(observations);
        Vector delta = this.getInitialProbability().dotTimes( bs.get(0) );
        ArrayList<int[]> psis = new ArrayList<int[]>( N );
        int[] psi = new int[ k ];
        for( int i = 0; i < k; i++ )
        {
            psi[i] = 0;
        }
        psis.add( psi );
        ArrayList<Integer> states = new ArrayList<Integer>( N );
        states.add( null );
        for( int n = 1; n < N; n++ )
        {
            states.add( null );
            Pair<Vector,int[]> pair =
                this.computeViterbiRecursion( delta, bs.get(n) );
            delta = pair.getFirst();
            psis.add( pair.getSecond() );
        }

        // Backchaining
        int finalState = -1;
        double best = Double.NEGATIVE_INFINITY;
        for( int i = 0; i < k; i++ )
        {
            final double v = delta.getElement(i);
            if( best < v )
            {
                best = v;
                finalState = i;
            }
        }
        int state = finalState;
        states.set(N-1, state);
        for( int n = N-2; n >= 0; n-- )
        {
            state = psis.get(n+1)[state];
            states.set(n, state);
        }

        return states;

    }

    /**
     * Computes the probability distribution over all states for each
     * observation.
     * @param observations
     * Observation sequence to consider.
     * @return
     *      The list of state belief probabilities for each observation.
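     *
     * <p>
     * Sketch (assuming {@code hmm} and {@code observations} exist in the
     * caller's code):
     * <pre>{@code
     * ArrayList<Vector> beliefs = hmm.stateBeliefs(observations);
     * Vector lastBelief = beliefs.get(beliefs.size() - 1);  // sums to 1.0
     * }</pre>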
     */
    public ArrayList<Vector> stateBeliefs(
        Collection<? extends ObservationType> observations )
    {

        ArrayList<Vector> bs = this.computeObservationLikelihoods(observations);
        ArrayList<WeightedValue<Vector>> alphas =
            this.computeForwardProbabilities(bs, true);
        ArrayList<Vector> beliefs = new ArrayList<Vector>( alphas.size() );
        for( WeightedValue<Vector> alpha : alphas )
        {
            beliefs.add( alpha.getValue() );
        }
        return beliefs;

    }

}



