All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.hmm.ParallelHiddenMarkovModel Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                ParallelHiddenMarkovModel.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Feb 4, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A Hidden Markov Model with parallelized processing.
 * @param  Type of Observations handled by the HMM.
 * @author Kevin R. Dixon
 * @since 3.0
 */
@PublicationReference(
    author="William Turin",
    title="Unidirectional and Parallel Baum–Welch Algorithms",
    type=PublicationType.Journal,
    publication="IEEE Transactions on Speech and Audio Processing",
    year=1998,
    url="http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=00725318"
)
public class ParallelHiddenMarkovModel
    extends HiddenMarkovModel
    implements ParallelAlgorithm
{

    /**
     * Thread pool used for parallelization.
     */
    private transient ThreadPoolExecutor threadPool;

    /** 
     * Creates a new instance of ParallelHiddenMarkovModel 
     */
    public ParallelHiddenMarkovModel()
    {
        super();
    }

    /**
     * Creates a new instance of ParallelHiddenMarkovModel
     * @param numStates
     * Number of states in the HMM.
     */
    public ParallelHiddenMarkovModel(
        int numStates )
    {
        super( numStates );
    }

    /**
     * Creates a new instance of ParallelHiddenMarkovModel
     * @param initialProbability
     * Initial probability Vector over the states.  Each entry must be
     * nonnegative and the Vector must sum to 1.
     * @param transitionProbability
     * Transition probability matrix.  The entry (i,j) is the probability
     * of transition from state "j" to state "i".  As a corollary, all
     * entries in the Matrix must be nonnegative and the
     * columns of the Matrix must sum to 1.
     * @param emissionFunctions
     * The PDFs that emit symbols from each state.
     */
    public ParallelHiddenMarkovModel(
        Vector initialProbability,
        Matrix transitionProbability,
        Collection> emissionFunctions )
    {
        super( initialProbability, transitionProbability, emissionFunctions );
    }

    /**\
     * Creates a new {@code ParallelHiddenMarkovModel} from another
     * {@code HiddenMarkovModel}.
     *
     * @param   other
     *      The other hidden Markov model to copy.
     */
    public ParallelHiddenMarkovModel(
        final HiddenMarkovModel other)
    {
        this(ObjectUtil.cloneSafe(other.getInitialProbability()),
            ObjectUtil.cloneSafe(other.getTransitionProbability()),
            ObjectUtil.cloneSmartElementsAsArrayList(
                other.getEmissionFunctions()));
    }

    public ThreadPoolExecutor getThreadPool()
    {
        if (this.threadPool == null)
        {
            this.setThreadPool(ParallelUtil.createThreadPool());
        }
        return this.threadPool;
    }

    public void setThreadPool(
        ThreadPoolExecutor threadPool)
    {
        this.threadPool = threadPool;
    }

    public int getNumThreads()
    {
        return ParallelUtil.getNumThreads(this);
    }

    /**
     * Observation likelihood tasks
     */
    transient protected ArrayList> observationLikelihoodTasks;

    @Override
    public double computeMultipleObservationLogLikelihood(
        Collection> sequences)
    {

        ArrayList tasks =
            new ArrayList( sequences.size() );
        for( Collection sequence : sequences )
        {
            tasks.add( new LogLikelihoodTask( sequence ) );
        }

        ArrayList results = null;
        try
        {
            results = ParallelUtil.executeInParallel(tasks, this.getThreadPool());
        }
        catch (Exception e)
        {
            throw new RuntimeException(e);
        }

        double logSum = 0.0;
        for( int i = 0; i < results.size(); i++ )
        {
            logSum += results.get(i);
        }
        return logSum;
    }


    /**
     * ComputeTransitionsTasks.
     */
    transient protected ArrayList computeTransitionTasks;

    @Override
    protected Matrix computeTransitions(
        ArrayList> alphas,
        ArrayList> betas,
        ArrayList b)
    {

        final int N = alphas.size();
        if( this.computeTransitionTasks == null )
        {
            this.computeTransitionTasks =
                new ArrayList( N-1 );
        }

        // Make sure it's N-1
        this.computeTransitionTasks.ensureCapacity(N-1);
        while( this.computeTransitionTasks.size() > N-1 )
        {
            this.computeTransitionTasks.remove(
                this.computeTransitionTasks.size()-1 );
        }
        while( this.computeTransitionTasks.size() < N-1 )
        {
            this.computeTransitionTasks.add( new ComputeTransitionsTask() );
        }

        for( int n = 0; n < N-1; n++ )
        {
            final ComputeTransitionsTask tn = this.computeTransitionTasks.get(n);
            tn.alphan = alphas.get(n).getValue();
            tn.betanp1 = betas.get(n+1).getValue();
            tn.bnp1 = b.get(n+1);
        }

        RingAccumulator counts = new RingAccumulator();
        Matrix A = null;
        try
        {
            Collection> futures =
                this.getThreadPool().invokeAll(this.computeTransitionTasks);
            for( Future f : futures )
            {
                counts.accumulate( f.get() );
            }

            A = counts.getSum();
            A.dotTimesEquals(this.getTransitionProbability());
            normalizeTransitionMatrix(A);
        }
        catch (Exception ex)
        {
            throw new RuntimeException( ex );
        }

        return A;

    }

    /**
     * NormalizeTransitionTasks.
     */
    transient protected ArrayList normalizeTransitionTasks;

    @Override
    protected void normalizeTransitionMatrix(
        Matrix A)
    {
        final int k = A.getNumColumns();
        if( this.normalizeTransitionTasks == null )
        {
            this.normalizeTransitionTasks =
                new ArrayList( k );
        }

        this.normalizeTransitionTasks.ensureCapacity(k);
        while( this.normalizeTransitionTasks.size() > k )
        {
            this.normalizeTransitionTasks.remove(
                this.normalizeTransitionTasks.size() - 1 );
        }
        while( this.normalizeTransitionTasks.size() < k )
        {
            this.normalizeTransitionTasks.add( new NormalizeTransitionTask() );
        }

        for( int j = 0; j < k; j++ )
        {
            NormalizeTransitionTask task = this.normalizeTransitionTasks.get(j);
            task.j = j;
            task.A = A;
        }

        try
        {
            ParallelUtil.executeInParallel(
                this.normalizeTransitionTasks, this.getThreadPool());
        }
        catch (Exception ex)
        {
            throw new RuntimeException( ex );
        }

    }

    /**
     * StateObservationLikelihoodTasks
     */
    transient protected ArrayList stateObservationLikelihoodTasks;

    @Override
    protected ArrayList computeStateObservationLikelihood(
        ArrayList> alphas,
        ArrayList> betas,
        double scaleFactor )
    {

        final int N = alphas.size();
        if( this.stateObservationLikelihoodTasks == null )
        {
            this.stateObservationLikelihoodTasks =
                new ArrayList( N );
        }

        this.stateObservationLikelihoodTasks.ensureCapacity(N);
        while( this.stateObservationLikelihoodTasks.size() > N )
        {
            this.stateObservationLikelihoodTasks.remove(
                this.stateObservationLikelihoodTasks.size()-1 );
        }
        while( this.stateObservationLikelihoodTasks.size() < N )
        {
            this.stateObservationLikelihoodTasks.add(
                new StateObservationLikelihoodTask() );
        }

        for( int n = 0; n < N; n++ )
        {
            StateObservationLikelihoodTask task =
                this.stateObservationLikelihoodTasks.get(n);
            task.alpha = alphas.get(n).getValue();
            task.beta = betas.get(n).getValue();
        }

        ArrayList gammas = null;
        try
        {
            gammas = ParallelUtil.executeInParallel(
                this.stateObservationLikelihoodTasks, this.getThreadPool());
        }
        catch (Exception e)
        {
            throw new RuntimeException( e );
        }

        return gammas;
    }

    /**
     * Viterbi tasks.
     */
    transient protected ArrayList viterbiTasks;

    @Override
    protected Pair computeViterbiRecursion(
        Vector delta,
        Vector bn )
    {

        final int k = this.getNumStates();
        if( this.viterbiTasks == null )
        {
            this.viterbiTasks = new ArrayList( k );
        }

        this.viterbiTasks.ensureCapacity(k);
        while( this.viterbiTasks.size() > k )
        {
            this.viterbiTasks.remove(
                this.viterbiTasks.size()-1 );
        }
        while( this.viterbiTasks.size() < k )
        {
            this.viterbiTasks.add( new ViterbiTask() );
        }

        for( int i = 0; i < k; i++ )
        {
            final ViterbiTask task = this.viterbiTasks.get(i);
            task.destinationState = i;
            task.delta = delta;
        }
        
        ArrayList> results;
        try
        {
            results = ParallelUtil.executeInParallel(
                this.viterbiTasks, this.getThreadPool() );
        }
        catch (Exception e)
        {
            throw new RuntimeException( e );
        }

        int[] psis = new int[ k ];
        Vector nextDelta = VectorFactory.getDefault().createVector(k);
        for( int i = 0; i < k; i++ )
        {
            WeightedValue value = results.get(i);
            psis[i] = value.getValue();
            nextDelta.setElement(i, value.getWeight() );
        }

        nextDelta.dotTimesEquals(bn);
        nextDelta.scaleEquals( 1.0/nextDelta.norm1() );

        return DefaultPair.create( nextDelta, psis );

    }

    /**
     * Calls the computeObservationLikelihoods() method.
     * @param  Observation type
     */
    protected static class ObservationLikelihoodTask
        extends AbstractCloneableSerializable
        implements Callable
    {

        /**
         * Observations
         */
        protected Collection observations;

        /**
         * The PDF.
         */
        protected ProbabilityFunction distributionFunction;

        /**
         * Default constructor.
         */
        public ObservationLikelihoodTask()
        {
        }

        public double[] call()
        {
            final int N = this.observations.size();
            double[] b = new double[ N ];
            int n = 0;
            for( ObservationType observation : this.observations )
            {
                b[n] = this.distributionFunction.evaluate(observation);
                n++;
            }
            return b;
        }

    }

    /**
     * Calls the computeStateObservationLikelihood() method.
     */
    protected static class StateObservationLikelihoodTask
        extends AbstractCloneableSerializable
        implements Callable
    {

        /**
         * Alpha at time n.
         */
        protected Vector alpha;

        /**
         * Beta at time n.
         */
        protected Vector beta;

        /**
         * Default constructor.
         */
        public StateObservationLikelihoodTask()
        {
        }

        public Vector call()
            throws Exception
        {
            return ParallelHiddenMarkovModel.computeStateObservationLikelihood(
                this.alpha, this.beta, 1.0 );
        }

    }

    /**
     * Calls the normalizeTransitionMatrix method.
     */
    protected static class NormalizeTransitionTask
        extends AbstractCloneableSerializable
        implements Callable
    {

        /**
         * Matrix to normalize.
         */
        private Matrix A;

        /**
         * Column to normalize.
         */
        private int j;

        /**
         * Default constructor.
         */
        public NormalizeTransitionTask()
        {
        }

        public Void call()
        {
            ParallelHiddenMarkovModel.normalizeTransitionMatrix(
                this.A, this.j);
            return null;
        }

    }

    /**
     * Calls the computeTransitions method.
     */
    protected static class ComputeTransitionsTask
        extends AbstractCloneableSerializable
        implements Callable
    {

        /**
         * Alpha at time n.
         */
        Vector alphan;

        /**
         * Alpha at time n.
         */
        Vector betanp1;

        /**
         * b at time n+1.
         */
        Vector bnp1;

        /**
         * Default constructor.
         */
        public ComputeTransitionsTask()
        {
        }

        public Matrix call()
        {
            return ParallelHiddenMarkovModel.computeTransitions(
                this.alphan, this.betanp1, this.bnp1 );
        }

    }


    /**
     * Computes the most-likely "from state" for the given "destination state"
     * and the given deltas.
     */
    protected class ViterbiTask
        extends AbstractCloneableSerializable
        implements Callable>
    {

        /**
         * Destination state for the Viterbi Recursion.
         */
        int destinationState;

        /**
         * Previous value of the Viterbi Recursion.
         */
        Vector delta;

        /**
         * Default constructor
         */
        ViterbiTask()
        {
        }

        public WeightedValue call()
            throws Exception
        {
            return ParallelHiddenMarkovModel.this.findMostLikelyState(
                this.destinationState, this.delta);
         }
        
    }

    /**
     * Computes the log-likelihood of a particular data sequence
     */
    protected class LogLikelihoodTask
        extends AbstractCloneableSerializable
        implements Callable
    {

        /**
         * Data to compute the log-likelihood of
         */
        protected Collection data;

        /**
         * Creates a new instance of LogLikelihoodTask
         * @param data
         * Data to compute the log-likelihood of
         */
        public LogLikelihoodTask(
            Collection data)
        {
            this.data = data;
        }

        public Double call()
            throws Exception
        {
            return computeObservationLogLikelihood( this.data );
        }

    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy