All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.algorithm.hmm.ParallelBaumWelchAlgorithm Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                ParallelBaumWelchAlgorithm.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Feb 3, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A Parallelized implementation of some of the methods of the
 * Baum-Welch Algorithm.
 * @param  Type of Observations handled by the HMM.
 * @author Kevin R. Dixon
 * @since 3.0
 */
@PublicationReference(
    author="William Turin",
    title="Unidirectional and Parallel Baum–Welch Algorithms",
    type=PublicationType.Journal,
    publication="IEEE Transactions on Speech and Audio Processing",
    year=1998,
    url="http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=00725318"
)
public class ParallelBaumWelchAlgorithm
    extends BaumWelchAlgorithm
    implements ParallelAlgorithm
{

    /**
     * Thread pool used for parallelization.
     */
    transient private ThreadPoolExecutor threadPool;

    /**
     * Tasks for re-estimating the PDFs.
     */
    transient protected ArrayList> distributionEstimatorTasks;

    /**
     * Default constructor
     */
    public ParallelBaumWelchAlgorithm()
    {
        super();
    }

    /**
     * Creates a new instance of ParallelBaumWelchAlgorithm
     * @param initialGuess
     * Initial guess for the iterations.
     * @param distributionLearner
     * Learner for the Distribution Functions of the HMM.
     * @param reestimateInitialProbabilities
     * Flag to re-estimate the initial probability Vector.
     */
    public ParallelBaumWelchAlgorithm(
        HiddenMarkovModel initialGuess,
        BatchLearner>,? extends ComputableDistribution> distributionLearner,
        boolean reestimateInitialProbabilities )
    {
        super( initialGuess, distributionLearner, reestimateInitialProbabilities );
    }

    public ThreadPoolExecutor getThreadPool()
    {
        if( this.threadPool == null )
        {
            this.setThreadPool( ParallelUtil.createThreadPool() );
        }
        return this.threadPool;
    }

    public void setThreadPool(
        ThreadPoolExecutor threadPool)
    {
        this.threadPool = threadPool;
    }

    public int getNumThreads()
    {
        return ParallelUtil.getNumThreads(this);
    }

    @Override
    protected boolean initializeAlgorithm()
    {
        this.distributionEstimatorTasks = this.createDistributionEstimatorTasks();
        return super.initializeAlgorithm();
    }

    @Override
    protected ArrayList> updateProbabilityFunctions(
        ArrayList sequenceGammas)
    {
        final int N = this.getResult().getNumStates();
        for( int i = 0; i < N; i++ )
        {
            this.distributionEstimatorTasks.get(i).setGammas( sequenceGammas );
        }

        ArrayList> fs = null;
        try
        {
            fs = ParallelUtil.executeInParallel(
                this.distributionEstimatorTasks, this.getThreadPool() );
        }
        catch (Exception e)
        {
            throw new RuntimeException( e );
        }

        return fs;
    }

    /**
     * Creates the DistributionEstimatorTask
     * @return
     * DistributionEstimatorTask.
     */
    protected ArrayList> createDistributionEstimatorTasks()
    {
        final int N = this.initialGuess.getNumStates();
        ArrayList> tasks =
            new ArrayList>( N );
        for( int i = 0; i < N; i++ )
        {
            tasks.add( new DistributionEstimatorTask(
                this.data, this.distributionLearner, i ) );
        }
        return tasks;
    }

    /**
     * Re-estimates the PDF from the gammas.
     * @param  Type of Observations.
     */
    protected static class DistributionEstimatorTask
        extends AbstractCloneableSerializable
        implements Callable>
    {

        /**
         * Weighted values for the PDF estimator.
         */
        protected ArrayList> weightedValues;

        /**
         * My copy of the PDF estimator.
         */
        protected BatchLearner>,? extends ComputableDistribution> distributionLearner;

        /**
         * Gammas used to weight the learner samples.
         */
        private ArrayList gammas;

        /**
         * Index into the gammas to pull the weights.
         */
        protected int index;

        /**
         * Creates an instance of DistributionEstimatorTask
         * @param data
         * Data to stuff into the weightedValues
         * @param distributionLearner
         * Distribution Learner
         * @param index
         * Index into the gammas
         */
        public DistributionEstimatorTask(
            Collection data,
            BatchLearner>,? extends ComputableDistribution> distributionLearner,
            int index )
        {
            this.index = index;
            this.distributionLearner = distributionLearner;
            this.weightedValues =
                new ArrayList>( data.size() );
            for( ObservationType v : data )
            {
                this.weightedValues.add( new DefaultWeightedValue( v ) );
            }
        }

        /**
         * Sets the gamma samples pointer.
         * @param gammas
         * Gammas
         */
        public void setGammas(
            ArrayList gammas )
        {
            this.gammas = gammas;
        }

        public ProbabilityFunction call()
        {
            final int N = this.gammas.size();
            for( int n = 0; n < N; n++ )
            {
                this.weightedValues.get(n).setWeight(
                    this.gammas.get(n).getElement( this.index ) );
            }
            return this.distributionLearner.learn(this.weightedValues).getProbabilityFunction();
        }

    }



}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy