gov.sandia.cognition.learning.algorithm.hmm.ParallelBaumWelchAlgorithm Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: ParallelBaumWelchAlgorithm.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Feb 3, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.hmm;
import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
/**
* A Parallelized implementation of some of the methods of the
* Baum-Welch Algorithm.
* @param Type of Observations handled by the HMM.
* @author Kevin R. Dixon
* @since 3.0
*/
@PublicationReference(
author="William Turin",
title="Unidirectional and Parallel Baum–Welch Algorithms",
type=PublicationType.Journal,
publication="IEEE Transactions on Speech and Audio Processing",
year=1998,
url="http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=00725318"
)
public class ParallelBaumWelchAlgorithm
extends BaumWelchAlgorithm
implements ParallelAlgorithm
{
/**
* Thread pool used for parallelization.
*/
transient private ThreadPoolExecutor threadPool;
/**
* Tasks for re-estimating the PDFs.
*/
transient protected ArrayList> distributionEstimatorTasks;
/**
* Default constructor
*/
public ParallelBaumWelchAlgorithm()
{
super();
}
/**
* Creates a new instance of ParallelBaumWelchAlgorithm
* @param initialGuess
* Initial guess for the iterations.
* @param distributionLearner
* Learner for the Distribution Functions of the HMM.
* @param reestimateInitialProbabilities
* Flag to re-estimate the initial probability Vector.
*/
public ParallelBaumWelchAlgorithm(
HiddenMarkovModel initialGuess,
BatchLearner>,? extends ComputableDistribution> distributionLearner,
boolean reestimateInitialProbabilities )
{
super( initialGuess, distributionLearner, reestimateInitialProbabilities );
}
public ThreadPoolExecutor getThreadPool()
{
if( this.threadPool == null )
{
this.setThreadPool( ParallelUtil.createThreadPool() );
}
return this.threadPool;
}
public void setThreadPool(
ThreadPoolExecutor threadPool)
{
this.threadPool = threadPool;
}
public int getNumThreads()
{
return ParallelUtil.getNumThreads(this);
}
@Override
protected boolean initializeAlgorithm()
{
this.distributionEstimatorTasks = this.createDistributionEstimatorTasks();
return super.initializeAlgorithm();
}
@Override
protected ArrayList> updateProbabilityFunctions(
ArrayList sequenceGammas)
{
final int N = this.getResult().getNumStates();
for( int i = 0; i < N; i++ )
{
this.distributionEstimatorTasks.get(i).setGammas( sequenceGammas );
}
ArrayList> fs = null;
try
{
fs = ParallelUtil.executeInParallel(
this.distributionEstimatorTasks, this.getThreadPool() );
}
catch (Exception e)
{
throw new RuntimeException( e );
}
return fs;
}
/**
* Creates the DistributionEstimatorTask
* @return
* DistributionEstimatorTask.
*/
protected ArrayList> createDistributionEstimatorTasks()
{
final int N = this.initialGuess.getNumStates();
ArrayList> tasks =
new ArrayList>( N );
for( int i = 0; i < N; i++ )
{
tasks.add( new DistributionEstimatorTask(
this.data, this.distributionLearner, i ) );
}
return tasks;
}
/**
* Re-estimates the PDF from the gammas.
* @param Type of Observations.
*/
protected static class DistributionEstimatorTask
extends AbstractCloneableSerializable
implements Callable>
{
/**
* Weighted values for the PDF estimator.
*/
protected ArrayList> weightedValues;
/**
* My copy of the PDF estimator.
*/
protected BatchLearner>,? extends ComputableDistribution> distributionLearner;
/**
* Gammas used to weight the learner samples.
*/
private ArrayList gammas;
/**
* Index into the gammas to pull the weights.
*/
protected int index;
/**
* Creates an instance of DistributionEstimatorTask
* @param data
* Data to stuff into the weightedValues
* @param distributionLearner
* Distribution Learner
* @param index
* Index into the gammas
*/
public DistributionEstimatorTask(
Collection extends ObservationType> data,
BatchLearner>,? extends ComputableDistribution> distributionLearner,
int index )
{
this.index = index;
this.distributionLearner = distributionLearner;
this.weightedValues =
new ArrayList>( data.size() );
for( ObservationType v : data )
{
this.weightedValues.add( new DefaultWeightedValue( v ) );
}
}
/**
* Sets the gamma samples pointer.
* @param gammas
* Gammas
*/
public void setGammas(
ArrayList gammas )
{
this.gammas = gammas;
}
public ProbabilityFunction call()
{
final int N = this.gammas.size();
for( int n = 0; n < N; n++ )
{
this.weightedValues.get(n).setWeight(
this.gammas.get(n).getElement( this.index ) );
}
return this.distributionLearner.learn(this.weightedValues).getProbabilityFunction();
}
}
}