// Source listing of gov.sandia.cognition.statistics.bayesian.ParallelDirichletProcessMixtureModel
// from the cognitive-foundry artifact (Maven / Gradle / Ivy):
// a single jar with all the Cognitive Foundry components.
/*
* File: ParallelDirichletProcessMixtureModel.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright May 3, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.bayesian;
import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel.DPMMLogConditional;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
/**
* A Parallelized version of vanilla Dirichlet Process Mixture Model learning.
* In particular, this class parallelizes the assignment of observations to
* clusters and the Gibbs sampling updating of clusters from their constituent
* observations.
* @param
* Type of observations handled by the algorithm
* @author Kevin R. Dixon
* @since 3.0
*/
public class ParallelDirichletProcessMixtureModel
extends DirichletProcessMixtureModel
implements ParallelAlgorithm
{
/**
* Thread pool used for parallelization.
*/
private transient ThreadPoolExecutor threadPool;
/**
* Creates a new instance of ParallelDirichletProcessMixtureModel
*/
public ParallelDirichletProcessMixtureModel()
{
super();
}
public int getNumThreads()
{
return ParallelUtil.getNumThreads(this);
}
public ThreadPoolExecutor getThreadPool()
{
if (this.threadPool == null)
{
this.setThreadPool(ParallelUtil.createThreadPool());
}
return this.threadPool;
}
public void setThreadPool(
final ThreadPoolExecutor threadPool)
{
this.threadPool = threadPool;
}
/**
* Tasks that assign observations to clusters
*/
transient protected ArrayList assignmentTasks;
@Override
protected ArrayList> assignObservationsToClusters(
int K,
DPMMLogConditional logConditional )
{
if( this.assignmentTasks == null )
{
ArrayList extends ObservationType> dataArray =
CollectionUtil.asArrayList(this.data );
final int N = dataArray.size();
final int numThreads = this.getNumThreads();
this.assignmentTasks = new ArrayList( numThreads );
int numPerTask = N/numThreads;
int endIndex = 0;
for( int n = 0; n < numThreads-1; n++ )
{
int startIndex = endIndex;
endIndex += numPerTask;
this.assignmentTasks.add( new ObservationAssignmentTask(
dataArray.subList(startIndex, endIndex) ) );
}
this.assignmentTasks.add( new ObservationAssignmentTask(
dataArray.subList(endIndex,N) ) );
}
ArrayList results;
try
{
results = ParallelUtil.executeInParallel(
this.assignmentTasks, this.getThreadPool() );
}
catch( Exception ex )
{
throw new RuntimeException( ex );
}
// This assigns observations to each of the K clusters, plus the
// as-yet-uncreated new cluster
ArrayList> clusterAssignments =
new ArrayList>( K+1 );
for( int k = 0; k < K+1; k++ )
{
clusterAssignments.add( new LinkedList() );
}
for( int n = 0; n < results.size(); n++ )
{
logConditional.logConditional +=
results.get(n).logConditional.logConditional;
ArrayList assignments = results.get(n).assignments;
int index = 0;
for( ObservationType observation : this.assignmentTasks.get(n).observations )
{
int assignment = assignments.get(index);
clusterAssignments.get(assignment).add( observation );
index++;
}
}
return clusterAssignments;
}
/**
* Assignments from the DPMM
*/
public static class DPMMAssignments
{
/**
* List of assignment indices
*/
protected ArrayList assignments;
/**
* Log conditional likelihood of the previous sample
*/
protected DPMMLogConditional logConditional;
/**
* Constructor
* @param assignments
* List of assignment indices
* @param logConditional
* Log conditional likelihood of the previous sample
*/
public DPMMAssignments(
ArrayList assignments,
DPMMLogConditional logConditional)
{
this.assignments = assignments;
this.logConditional = logConditional;
}
}
/**
* Task that assign observations to cluster indices
*/
protected class ObservationAssignmentTask
extends AbstractCloneableSerializable
implements Callable
{
/**
* Observations to assign
*/
private Collection extends ObservationType> observations;
/**
* Weights that are re-used
*/
private double[] weights;
/**
* Resulting assignments
*/
private ArrayList assignments;
/**
* Log conditional of the previous sample
*/
private DPMMLogConditional logConditional;
/**
* Creates a new instance of ObservationAssignmentTask
* @param observations
* Observations to assign
*/
public ObservationAssignmentTask(
Collection extends ObservationType> observations )
{
this.weights = null;
this.observations = observations;
}
public DPMMAssignments call()
throws Exception
{
final int K = currentParameter.getNumClusters();
if( (this.weights == null) ||
(this.weights.length != K+1) )
{
this.weights = new double[ K+1 ];
}
if( this.assignments == null )
{
this.assignments = new ArrayList(
this.observations.size() );
for( int n = 0; n < this.observations.size(); n++ )
{
this.assignments.add( null );
}
}
this.logConditional = new DPMMLogConditional();
int index = 0;
for( ObservationType observation : this.observations )
{
int clusterAssignment = assignObservationToCluster(
observation, this.weights, this.logConditional );
this.assignments.set( index, clusterAssignment );
index++;
}
return new DPMMAssignments(this.assignments, this.logConditional);
}
}
/**
* Tasks that update the values of the clusters for Gibbs sampling
*/
transient protected ArrayList clusterUpdaterTasks;
@Override
protected ArrayList> updateClusters(
ArrayList> clusterAssignments)
{
final int Kp1 = clusterAssignments.size();
if( (this.clusterUpdaterTasks == null) ||
(this.clusterUpdaterTasks.size() != Kp1) )
{
this.clusterUpdaterTasks = new ArrayList( Kp1 );
for( int k = 0; k < Kp1; k++ )
{
this.clusterUpdaterTasks.add( new ClusterUpdaterTask() );
}
}
for( int k = 0; k < Kp1; k++ )
{
Collection observations = clusterAssignments.get(k);
if( observations.size() <= 1 )
{
observations = null;
}
this.clusterUpdaterTasks.get(k).observations = observations;
}
ArrayList> clusters = null;
try
{
clusters = ParallelUtil.executeInParallel(
this.clusterUpdaterTasks, this.getThreadPool() );
}
catch (Exception e)
{
throw new RuntimeException(e);
}
ArrayList> results =
new ArrayList>( Kp1 );
for( int k = 0; k < Kp1; k++ )
{
DPMMCluster cluster = clusters.get(k);
if( cluster != null )
{
results.add( cluster );
}
}
return results;
}
/**
* Tasks that update the values of the clusters for Gibbs sampling
*/
protected class ClusterUpdaterTask
extends AbstractCloneableSerializable
implements Callable>
{
/**
* Observations that comprise the cluster
*/
Collection observations;
/**
* Local clone of the updater, needed to ensure thread safety
*/
Updater localUpdater;
/**
* Creates a new instance of ClusterUpdaterTask
*/
public ClusterUpdaterTask()
{
this.localUpdater = ObjectUtil.cloneSafe( updater );
}
public DPMMCluster call()
{
return createCluster(this.observations, this.localUpdater );
}
}
}