/*
 * File:                BaumWelchAlgorithm.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Jan 19, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.MultiCollection;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;

/**
 * Implements the Baum-Welch algorithm, also known as the "forward-backward
 * algorithm" or the expectation-maximization (EM) algorithm, for
 * Hidden Markov Models (HMMs).  This is the standard learning algorithm
 * for HMMs.  This implementation allows for multiple sequences using the
 * MultiCollection interface.
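 * <p>
 * A minimal usage sketch follows; the initial HMM, the emission-distribution
 * learner, and the training sequences are placeholders created by
 * hypothetical helper methods (not part of this API), so substitute whatever
 * matches your observation type:
 * <pre>{@code
 * // Hypothetical helpers: an initial HMM guess, a learner that fits an
 * // emission distribution to weighted observations, and the training
 * // sequences wrapped in a MultiCollection.
 * HiddenMarkovModel<Vector> initialGuess = createInitialGuess();
 * BatchLearner<Collection<? extends WeightedValue<? extends Vector>>,
 *     ? extends ComputableDistribution<Vector>> emissionLearner = createEmissionLearner();
 * MultiCollection<Vector> sequences = loadObservationSequences();
 *
 * BaumWelchAlgorithm<Vector> baumWelch = new BaumWelchAlgorithm<Vector>(
 *     initialGuess, emissionLearner, true );
 * HiddenMarkovModel<Vector> fit = baumWelch.learn( sequences );
 * }</pre>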
 * @param <ObservationType> Type of observations handled by the HMM.
 */
@PublicationReference(
    author="Lawrence R. Rabiner",
    title="A tutorial on hidden Markov models and selected applications in speech recognition",
    type=PublicationType.Journal,
    year=1989,
    publication="Proceedings of the IEEE",
    pages={257,286},
    url="http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf",
    notes="Rabiner's transition matrix is transposed from mine."
)
public class BaumWelchAlgorithm<ObservationType>
    extends AbstractBaumWelchAlgorithm<ObservationType,Collection<? extends ObservationType>>
{

    /**
     * Creates a new instance of BaumWelchAlgorithm
     */
    public BaumWelchAlgorithm()
    {
        this( null, null, DEFAULT_REESTIMATE_INITIAL_PROBABILITY );
    }

    /**
     * Creates a new instance of BaumWelchAlgorithm
     * @param initialGuess
     * Initial guess for the iterations.
     * @param distributionLearner
     * Learner for the Probability Functions of the HMM.
     * @param reestimateInitialProbabilities
     * Flag to re-estimate the initial probability Vector.
     */
    public BaumWelchAlgorithm(
        HiddenMarkovModel<ObservationType> initialGuess,
        BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>,? extends ComputableDistribution<ObservationType>> distributionLearner,
        boolean reestimateInitialProbabilities )
    {
        super( initialGuess, distributionLearner, reestimateInitialProbabilities );
    }

    @Override
    public BaumWelchAlgorithm<ObservationType> clone()
    {
        return (BaumWelchAlgorithm<ObservationType>) super.clone();
    }

    /**
     * Weighted data for the re-estimation procedure.
     */
    private transient ArrayList<DefaultWeightedValue<ObservationType>> weightedData;

    /**
     * Log likelihoods of the various sequences, with the corresponding weights.
     */
    private transient ArrayList<DefaultWeightedValue<Double>> sequenceLogLikelihoods;

    /**
     * Total number of observations in the sequences of data.
     */
    private transient int totalNum;

    /**
     * The multi-collection of sequences
     */
    protected transient MultiCollection<? extends ObservationType> multicollection;

    /**
     * The list of all gammas from each sequence
     */
    protected transient ArrayList<Vector> sequenceGammas;


    /**
     * Allows the algorithm to learn against multiple sequences of data.
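     * <p>
     * Each sub-collection of the given {@code MultiCollection} is treated as
     * one independent observation sequence, so separate recordings can be
     * wrapped in a {@code MultiCollection} implementation (for example,
     * {@code DefaultMultiCollection}) and passed here directly.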
     * @param data
     * Multiple sequences of data against which to train.
     * @return
     * HMM resulting from the locally maximum likelihood estimate of the
     * Baum-Welch algorithm.
     */
    public HiddenMarkovModel<ObservationType> learn(
        MultiCollection<? extends ObservationType> data)
    {
        return super.learn(data);
    }

    @SuppressWarnings("unchecked")
    @Override
    protected boolean initializeAlgorithm()
    {
        this.multicollection = DatasetUtil.asMultiCollection(this.data);

        // Let's make sure nobody uses the original data!
        this.data = null;


        final int numSequences = this.multicollection.getSubCollectionsCount();
        this.sequenceLogLikelihoods =
            new ArrayList<DefaultWeightedValue<Double>>( numSequences );
        this.totalNum = 0;
        for( Collection<? extends ObservationType> sequence : this.multicollection.subCollections() )
        {
            this.sequenceLogLikelihoods.add( new DefaultWeightedValue<Double>() );
            this.totalNum += sequence.size();
        }
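        // One weighted slot per observation, across all sequences.  During
        // each M-step the weight of observation n is overwritten with its
        // state-membership probability gamma_n(i) before the emission
        // distribution for state i is re-fit (see updateProbabilityFunctions);
        // sequenceGammas holds the corresponding gamma vectors.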
        this.weightedData = new ArrayList<DefaultWeightedValue<ObservationType>>(
            this.totalNum );
        this.sequenceGammas = new ArrayList<Vector>( this.totalNum );
        for( Collection<? extends ObservationType> sequence : this.multicollection.subCollections() )
        {
            for( ObservationType observation : sequence )
            {
                this.weightedData.add( new DefaultWeightedValue<ObservationType>( observation ) );
                this.sequenceGammas.add( null );
            }
        }

        if( this.getInitialGuess() == null )
        {
            return false;
        }

        this.result = this.getInitialGuess().clone();
        this.lastLogLikelihood = this.updateSequenceLogLikelihoods( this.result );

        return true;

    }

    @Override
    protected boolean step()
    {
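        // One expectation-maximization step of Baum-Welch:
        // E-step: for each sequence, run the scaled forward-backward recursions
        //   to get the state-membership probabilities
        //   gamma_t(i) = P( state at time t == i | observation sequence )
        //   along with the expected transition counts for that sequence.
        // M-step: re-estimate the HMM parameters from those statistics:
        //   pi  <- normalized sum of the first gamma of each sequence
        //          (only when reestimateInitialProbabilities is true),
        //   A   <- normalized sum of the per-sequence expected transition counts,
        //   f_i <- emission distribution for state i, re-fit by the distribution
        //          learner with observation n weighted by gamma_n(i).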
        final int numSequences = this.multicollection.getSubCollectionsCount();
        final boolean updatePi = this.getReestimateInitialProbabilities();
        Pair<ArrayList<ArrayList<Vector>>, ArrayList<Matrix>> pair =
            this.computeSequenceParameters();
        ArrayList<ArrayList<Vector>> allGammas = pair.getFirst();
        ArrayList<Matrix> sequenceTransitionMatrices = pair.getSecond();
        ArrayList<Vector> firstGammas = (updatePi)
            ? new ArrayList<Vector>( numSequences ) : null;

        int index = 0;
        for( int i = 0; i < numSequences; i++ )
        {
            ArrayList<Vector> gammas = allGammas.get(i);
            if( updatePi )
            {
                firstGammas.add( gammas.get(0) );
            }
            final int Ni = gammas.size();
            for( int n = 0; n < Ni; n++ )
            {
                this.sequenceGammas.set(index, gammas.get(n));
                index++;
            }
        }

        Vector pi = this.result.getInitialProbability();
        if( this.getReestimateInitialProbabilities() )
        {
            pi = this.updateInitialProbabilities(firstGammas);
        }

        Matrix A = this.updateTransitionMatrix(sequenceTransitionMatrices);
        ArrayList<ProbabilityFunction<ObservationType>> fs =
            this.updateProbabilityFunctions(this.sequenceGammas);

        // See how well we're doing...
        boolean gettingBetter;
        
        // If somebody asks for a single iteration, then just assume they want
        // the re-estimated parameters.
        if( this.getMaxIterations() <= 1 )
        {
            this.result.emissionFunctions = fs;
            this.result.initialProbability = pi;
            this.result.transitionProbability = A;
            gettingBetter = true;
        }

        // If somebody wants multiple steps, then check and see if our
        // current candidate is actually better.  This can happen due to
        // numerical round-off or asymptotic behavior.
        else
        {
            HiddenMarkovModel<ObservationType> candidate = this.result.clone();
            candidate.emissionFunctions = fs;
            candidate.initialProbability = pi;
            candidate.transitionProbability = A;
            double logLikelihood = this.updateSequenceLogLikelihoods( candidate );

            gettingBetter = (logLikelihood > this.lastLogLikelihood) ||
                (this.getIteration() <= 1);
            if( gettingBetter )
            {
                this.result = candidate;
                this.lastLogLikelihood = logLikelihood;
            }
        }

        return gettingBetter;
    }

    @Override
    protected void cleanupAlgorithm()
    {
        this.multicollection = null;
        this.weightedData = null;
        this.sequenceLogLikelihoods = null;
        this.totalNum = 0;
    }

    /**
     * Computes the gammas and A matrices for each sequence.
     * @return
     * Gammas and A matrices for each sequence
     */
    protected Pair<ArrayList<ArrayList<Vector>>,ArrayList<Matrix>> computeSequenceParameters()
    {
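        // E-step statistics, one entry per sequence: the gamma vectors
        // (state-membership probabilities for every observation) and the
        // expected transition-count matrix, each scaled by that sequence's
        // weight from updateSequenceLogLikelihoods.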

        final int numSequences = this.multicollection.getSubCollectionsCount();
        ArrayList<ArrayList<Vector>> allGammas =
            new ArrayList<ArrayList<Vector>>( numSequences );
        ArrayList<Matrix> sequenceTransitionMatrices =
            new ArrayList<Matrix>( numSequences );
        
        final boolean normalize = true;
        int k = 0;
        for( Collection<? extends ObservationType> sequence : this.multicollection.subCollections() )
        {
            double sequenceWeight = this.sequenceLogLikelihoods.get(k).getWeight();
            ArrayList<Vector> b =
                this.result.computeObservationLikelihoods( sequence );
            ArrayList<WeightedValue<Vector>> alphas =
                this.result.computeForwardProbabilities( b, normalize );
            ArrayList<WeightedValue<Vector>> betas =
                this.result.computeBackwardProbabilities( b, alphas );
            ArrayList<Vector> gammas =
                this.result.computeStateObservationLikelihood(alphas, betas, sequenceWeight);
            allGammas.add( gammas );
            Matrix A = this.result.computeTransitions( alphas, betas, b );
            if( sequenceWeight != 1.0 )
            {
                A.scaleEquals(sequenceWeight);
            }
            sequenceTransitionMatrices.add( A );
            k++;
        }

        return DefaultPair.create( allGammas, sequenceTransitionMatrices );
    }

    /**
     * Updates the probability functions from the concatenated gammas from
     * all sequences.
     * @param sequenceGammas
     * Concatenated gammas from all sequences
     * @return
     * Maximum Likelihood probability functions
     */
    protected ArrayList<ProbabilityFunction<ObservationType>> updateProbabilityFunctions(
        ArrayList<Vector> sequenceGammas )
    {
        final int numStates = this.result.getNumStates();
        ArrayList<ProbabilityFunction<ObservationType>> fs =
            new ArrayList<ProbabilityFunction<ObservationType>>( numStates );
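        // Weighted maximum-likelihood re-estimation: for each state i, weight
        // observation n by gamma_n(i) and let the distribution learner fit the
        // emission distribution for that state from the re-weighted data.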
        for( int i = 0; i < numStates; i++ )
        {
            int index = 0;
            for( int n = 0; n < sequenceGammas.size(); n++ )
            {
                final double g = sequenceGammas.get(n).getElement(i);
                this.weightedData.get(index).setWeight(g);
                index++;
            }
            ProbabilityFunction<ObservationType> f =
                this.distributionLearner.learn( this.weightedData ).getProbabilityFunction();
            fs.add( f );
        }
        return fs;
    }

    /**
     * Computes an updated transition matrix from the scaled estimates
     * @param sequenceTransitionMatrices
     * Scaled estimates from each sequence
     * @return
     * Overall Maximum Likelihood estimate of the transition matrix
     */
    protected Matrix updateTransitionMatrix(
        ArrayList<Matrix> sequenceTransitionMatrices )
    {
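        // Sum the (already weight-scaled) expected transition counts over all
        // sequences, then normalize the sum so it is again a valid transition
        // matrix (in the convention noted in the class annotation, transposed
        // from Rabiner's).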
        RingAccumulator<Matrix> As =
            new RingAccumulator<Matrix>( sequenceTransitionMatrices );
        Matrix A = As.getSum();
        this.result.normalizeTransitionMatrix(A);
        return A;
    }

    /**
     * Updates the initial probabilities from the first gamma of each sequence.
     * @param firstGammas
     * The first gamma of each sequence.
     * @return
     * Updated initial probability Vector for the HMM.
     */
    protected Vector updateInitialProbabilities(
        ArrayList<Vector> firstGammas )
    {
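        // pi is proportional to the sum of the first gamma of each sequence,
        // normalized by its 1-norm so that the initial probabilities sum to 1.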
        RingAccumulator<Vector> pi = new RingAccumulator<Vector>();
        for( int k = 0; k < firstGammas.size(); k++ )
        {
            pi.accumulate( firstGammas.get(k) );
        }
        Vector pisum = pi.getSum();
        pisum.scaleEquals( 1.0 / pisum.norm1() );
        return pisum;
    }

    /**
     * Updates the internal sequence likelihoods for the given HMM
     * @param hmm
     * Hidden Markov model to consider
     * @return
     * Log-likelihood of the observation sequences given the HMM.
     */
    protected double updateSequenceLogLikelihoods(
        HiddenMarkovModel<ObservationType> hmm )
    {
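        // Compute each sequence's log-likelihood under the given HMM, record a
        // per-sequence weight of exp( maxLogLikelihood - logLikelihood ) for
        // use by computeSequenceParameters, and return the total log-likelihood
        // over all sequences.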
        int k = 0;
        double maxLogLikelihood = Double.NEGATIVE_INFINITY;
        double totalLogLikelihood = 0.0;
        for( Collection<? extends ObservationType> sequence : this.multicollection.subCollections() )
        {
            double logLikelihood = hmm.computeObservationLogLikelihood(sequence);
            if( maxLogLikelihood < logLikelihood )
            {
                maxLogLikelihood = logLikelihood;
            }
            this.sequenceLogLikelihoods.get(k).setValue(logLikelihood);
            totalLogLikelihood += logLikelihood;
            k++;
        }

        // Subtract off the maximum log-likelihood so that each sequence's
        // relative probability ranges from 1.0 (the best-scoring sequence)
        // down toward zero.
        final int numSequences = this.multicollection.getSubCollectionsCount();
        for( k = 0; k < numSequences; k++ )
        {
            // The weight is the INVERSE of that relative sequence probability,
            // so the best-scoring sequence gets weight 1.0 and poorly
            // explained sequences get proportionally larger weights.
            DefaultWeightedValue<Double> wv = this.sequenceLogLikelihoods.get(k);
            final double logLikelihood = wv.getValue();
            final double weight = 1.0/Math.exp(logLikelihood - maxLogLikelihood);
            wv.setWeight(weight);
        }
        return totalLogLikelihood;
    }

}