![JAR search and dependency download from the Maven repository](/logo.png)
gov.sandia.cognition.learning.algorithm.hmm.ParallelHiddenMarkovModel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: ParallelHiddenMarkovModel.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Feb 4, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.hmm;
import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
/**
* A Hidden Markov Model with parallelized processing.
* @param Type of Observations handled by the HMM.
* @author Kevin R. Dixon
* @since 3.0
*/
@PublicationReference(
author="William Turin",
title="Unidirectional and Parallel Baum–Welch Algorithms",
type=PublicationType.Journal,
publication="IEEE Transactions on Speech and Audio Processing",
year=1998,
url="http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=00725318"
)
public class ParallelHiddenMarkovModel
extends HiddenMarkovModel
implements ParallelAlgorithm
{
/**
* Thread pool used for parallelization.
*/
private transient ThreadPoolExecutor threadPool;
/**
* Creates a new instance of ParallelHiddenMarkovModel
*/
public ParallelHiddenMarkovModel()
{
super();
}
/**
* Creates a new instance of ParallelHiddenMarkovModel
* @param numStates
* Number of states in the HMM.
*/
public ParallelHiddenMarkovModel(
int numStates )
{
super( numStates );
}
/**
* Creates a new instance of ParallelHiddenMarkovModel
* @param initialProbability
* Initial probability Vector over the states. Each entry must be
* nonnegative and the Vector must sum to 1.
* @param transitionProbability
* Transition probability matrix. The entry (i,j) is the probability
* of transition from state "j" to state "i". As a corollary, all
* entries in the Matrix must be nonnegative and the
* columns of the Matrix must sum to 1.
* @param emissionFunctions
* The PDFs that emit symbols from each state.
*/
public ParallelHiddenMarkovModel(
Vector initialProbability,
Matrix transitionProbability,
Collection extends ComputableDistribution> emissionFunctions )
{
super( initialProbability, transitionProbability, emissionFunctions );
}
/**\
* Creates a new {@code ParallelHiddenMarkovModel} from another
* {@code HiddenMarkovModel}.
*
* @param other
* The other hidden Markov model to copy.
*/
public ParallelHiddenMarkovModel(
final HiddenMarkovModel other)
{
this(ObjectUtil.cloneSafe(other.getInitialProbability()),
ObjectUtil.cloneSafe(other.getTransitionProbability()),
ObjectUtil.cloneSmartElementsAsArrayList(
other.getEmissionFunctions()));
}
public ThreadPoolExecutor getThreadPool()
{
if (this.threadPool == null)
{
this.setThreadPool(ParallelUtil.createThreadPool());
}
return this.threadPool;
}
public void setThreadPool(
ThreadPoolExecutor threadPool)
{
this.threadPool = threadPool;
}
public int getNumThreads()
{
return ParallelUtil.getNumThreads(this);
}
/**
* Observation likelihood tasks
*/
transient protected ArrayList> observationLikelihoodTasks;
@Override
public double computeMultipleObservationLogLikelihood(
Collection extends Collection extends ObservationType>> sequences)
{
ArrayList tasks =
new ArrayList( sequences.size() );
for( Collection extends ObservationType> sequence : sequences )
{
tasks.add( new LogLikelihoodTask( sequence ) );
}
ArrayList results = null;
try
{
results = ParallelUtil.executeInParallel(tasks, this.getThreadPool());
}
catch (Exception e)
{
throw new RuntimeException(e);
}
double logSum = 0.0;
for( int i = 0; i < results.size(); i++ )
{
logSum += results.get(i);
}
return logSum;
}
/**
* ComputeTransitionsTasks.
*/
transient protected ArrayList computeTransitionTasks;
@Override
protected Matrix computeTransitions(
ArrayList> alphas,
ArrayList> betas,
ArrayList b)
{
final int N = alphas.size();
if( this.computeTransitionTasks == null )
{
this.computeTransitionTasks =
new ArrayList( N-1 );
}
// Make sure it's N-1
this.computeTransitionTasks.ensureCapacity(N-1);
while( this.computeTransitionTasks.size() > N-1 )
{
this.computeTransitionTasks.remove(
this.computeTransitionTasks.size()-1 );
}
while( this.computeTransitionTasks.size() < N-1 )
{
this.computeTransitionTasks.add( new ComputeTransitionsTask() );
}
for( int n = 0; n < N-1; n++ )
{
final ComputeTransitionsTask tn = this.computeTransitionTasks.get(n);
tn.alphan = alphas.get(n).getValue();
tn.betanp1 = betas.get(n+1).getValue();
tn.bnp1 = b.get(n+1);
}
RingAccumulator counts = new RingAccumulator();
Matrix A = null;
try
{
Collection> futures =
this.getThreadPool().invokeAll(this.computeTransitionTasks);
for( Future f : futures )
{
counts.accumulate( f.get() );
}
A = counts.getSum();
A.dotTimesEquals(this.getTransitionProbability());
normalizeTransitionMatrix(A);
}
catch (Exception ex)
{
throw new RuntimeException( ex );
}
return A;
}
/**
* NormalizeTransitionTasks.
*/
transient protected ArrayList normalizeTransitionTasks;
@Override
protected void normalizeTransitionMatrix(
Matrix A)
{
final int k = A.getNumColumns();
if( this.normalizeTransitionTasks == null )
{
this.normalizeTransitionTasks =
new ArrayList( k );
}
this.normalizeTransitionTasks.ensureCapacity(k);
while( this.normalizeTransitionTasks.size() > k )
{
this.normalizeTransitionTasks.remove(
this.normalizeTransitionTasks.size() - 1 );
}
while( this.normalizeTransitionTasks.size() < k )
{
this.normalizeTransitionTasks.add( new NormalizeTransitionTask() );
}
for( int j = 0; j < k; j++ )
{
NormalizeTransitionTask task = this.normalizeTransitionTasks.get(j);
task.j = j;
task.A = A;
}
try
{
ParallelUtil.executeInParallel(
this.normalizeTransitionTasks, this.getThreadPool());
}
catch (Exception ex)
{
throw new RuntimeException( ex );
}
}
/**
* StateObservationLikelihoodTasks
*/
transient protected ArrayList stateObservationLikelihoodTasks;
@Override
protected ArrayList computeStateObservationLikelihood(
ArrayList> alphas,
ArrayList> betas,
double scaleFactor )
{
final int N = alphas.size();
if( this.stateObservationLikelihoodTasks == null )
{
this.stateObservationLikelihoodTasks =
new ArrayList( N );
}
this.stateObservationLikelihoodTasks.ensureCapacity(N);
while( this.stateObservationLikelihoodTasks.size() > N )
{
this.stateObservationLikelihoodTasks.remove(
this.stateObservationLikelihoodTasks.size()-1 );
}
while( this.stateObservationLikelihoodTasks.size() < N )
{
this.stateObservationLikelihoodTasks.add(
new StateObservationLikelihoodTask() );
}
for( int n = 0; n < N; n++ )
{
StateObservationLikelihoodTask task =
this.stateObservationLikelihoodTasks.get(n);
task.alpha = alphas.get(n).getValue();
task.beta = betas.get(n).getValue();
}
ArrayList gammas = null;
try
{
gammas = ParallelUtil.executeInParallel(
this.stateObservationLikelihoodTasks, this.getThreadPool());
}
catch (Exception e)
{
throw new RuntimeException( e );
}
return gammas;
}
/**
* Viterbi tasks.
*/
transient protected ArrayList viterbiTasks;
@Override
protected Pair computeViterbiRecursion(
Vector delta,
Vector bn )
{
final int k = this.getNumStates();
if( this.viterbiTasks == null )
{
this.viterbiTasks = new ArrayList( k );
}
this.viterbiTasks.ensureCapacity(k);
while( this.viterbiTasks.size() > k )
{
this.viterbiTasks.remove(
this.viterbiTasks.size()-1 );
}
while( this.viterbiTasks.size() < k )
{
this.viterbiTasks.add( new ViterbiTask() );
}
for( int i = 0; i < k; i++ )
{
final ViterbiTask task = this.viterbiTasks.get(i);
task.destinationState = i;
task.delta = delta;
}
ArrayList> results;
try
{
results = ParallelUtil.executeInParallel(
this.viterbiTasks, this.getThreadPool() );
}
catch (Exception e)
{
throw new RuntimeException( e );
}
int[] psis = new int[ k ];
Vector nextDelta = VectorFactory.getDefault().createVector(k);
for( int i = 0; i < k; i++ )
{
WeightedValue value = results.get(i);
psis[i] = value.getValue();
nextDelta.setElement(i, value.getWeight() );
}
nextDelta.dotTimesEquals(bn);
nextDelta.scaleEquals( 1.0/nextDelta.norm1() );
return DefaultPair.create( nextDelta, psis );
}
/**
* Calls the computeObservationLikelihoods() method.
* @param Observation type
*/
protected static class ObservationLikelihoodTask
extends AbstractCloneableSerializable
implements Callable
{
/**
* Observations
*/
protected Collection extends ObservationType> observations;
/**
* The PDF.
*/
protected ProbabilityFunction distributionFunction;
/**
* Default constructor.
*/
public ObservationLikelihoodTask()
{
}
public double[] call()
{
final int N = this.observations.size();
double[] b = new double[ N ];
int n = 0;
for( ObservationType observation : this.observations )
{
b[n] = this.distributionFunction.evaluate(observation);
n++;
}
return b;
}
}
/**
* Calls the computeStateObservationLikelihood() method.
*/
protected static class StateObservationLikelihoodTask
extends AbstractCloneableSerializable
implements Callable
{
/**
* Alpha at time n.
*/
protected Vector alpha;
/**
* Beta at time n.
*/
protected Vector beta;
/**
* Default constructor.
*/
public StateObservationLikelihoodTask()
{
}
public Vector call()
throws Exception
{
return ParallelHiddenMarkovModel.computeStateObservationLikelihood(
this.alpha, this.beta, 1.0 );
}
}
/**
* Calls the normalizeTransitionMatrix method.
*/
protected static class NormalizeTransitionTask
extends AbstractCloneableSerializable
implements Callable
{
/**
* Matrix to normalize.
*/
private Matrix A;
/**
* Column to normalize.
*/
private int j;
/**
* Default constructor.
*/
public NormalizeTransitionTask()
{
}
public Void call()
{
ParallelHiddenMarkovModel.normalizeTransitionMatrix(
this.A, this.j);
return null;
}
}
/**
* Calls the computeTransitions method.
*/
protected static class ComputeTransitionsTask
extends AbstractCloneableSerializable
implements Callable
{
/**
* Alpha at time n.
*/
Vector alphan;
/**
* Alpha at time n.
*/
Vector betanp1;
/**
* b at time n+1.
*/
Vector bnp1;
/**
* Default constructor.
*/
public ComputeTransitionsTask()
{
}
public Matrix call()
{
return ParallelHiddenMarkovModel.computeTransitions(
this.alphan, this.betanp1, this.bnp1 );
}
}
/**
* Computes the most-likely "from state" for the given "destination state"
* and the given deltas.
*/
protected class ViterbiTask
extends AbstractCloneableSerializable
implements Callable>
{
/**
* Destination state for the Viterbi Recursion.
*/
int destinationState;
/**
* Previous value of the Viterbi Recursion.
*/
Vector delta;
/**
* Default constructor
*/
ViterbiTask()
{
}
public WeightedValue call()
throws Exception
{
return ParallelHiddenMarkovModel.this.findMostLikelyState(
this.destinationState, this.delta);
}
}
/**
* Computes the log-likelihood of a particular data sequence
*/
protected class LogLikelihoodTask
extends AbstractCloneableSerializable
implements Callable
{
/**
* Data to compute the log-likelihood of
*/
protected Collection extends ObservationType> data;
/**
* Creates a new instance of LogLikelihoodTask
* @param data
* Data to compute the log-likelihood of
*/
public LogLikelihoodTask(
Collection extends ObservationType> data)
{
this.data = data;
}
public Double call()
throws Exception
{
return computeObservationLogLikelihood( this.data );
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy