gov.sandia.cognition.text.topic.ParallelLatentDirichletAllocationVectorGibbsSampler Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: ParallelLatentDirichletAllocationVectorGibbsSampler.java
* Authors: Justin Basilico, Jason Shepherd
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright October 22, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.text.topic;
import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import java.util.ArrayList;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
/**
 * A parallel implementation of {@link LatentDirichletAllocationVectorGibbsSampler}.
 * It runs the sampling for the different documents using a thread pool: each
 * call to {@link #step()} creates one {@link DocumentSampleTask} per document
 * and executes all of them through {@link ParallelUtil#executeInParallel}.
 * <p>
 * Thread pool life cycle: the pool is created lazily by
 * {@link #getThreadPool()} when none has been set. It is shut down when the
 * algorithm is cleaned up and the reference is then cleared, so that a later
 * run on the same instance lazily obtains a fresh, usable pool.
 *
 * @author Jason Shepherd
 * @since 3.3.2
 */
public class ParallelLatentDirichletAllocationVectorGibbsSampler
    extends LatentDirichletAllocationVectorGibbsSampler
    implements ParallelAlgorithm
{

    /**
     * Thread pool used for parallelization. It is transient, and may be
     * {@code null} until {@link #getThreadPool()} is first called.
     */
    private transient ThreadPoolExecutor threadPool;

    /**
     * Creates a new {@code ParallelLatentDirichletAllocationVectorGibbsSampler} with
     * default parameters.
     */
    public ParallelLatentDirichletAllocationVectorGibbsSampler()
    {
        super();
    }

    /**
     * Creates a new {@code ParallelLatentDirichletAllocationVectorGibbsSampler} with
     * the given parameters.
     *
     * @param   topicCount
     *      The number of topics for the algorithm to create. Must be positive.
     * @param   alpha
     *      The alpha parameter controlling the Dirichlet distribution for the
     *      document-topic probabilities. It acts as a prior weight assigned to
     *      the document-topic counts. Must be positive.
     * @param   beta
     *      The beta parameter controlling the Dirichlet distribution for the
     *      topic-term probabilities. It acts as a prior weight assigned to
     *      the topic-term counts.
     * @param   maxIterations
     *      The maximum number of iterations to run for. Must be positive.
     * @param   burnInIterations
     *      The number of burn-in iterations for the Markov Chain Monte Carlo
     *      algorithm to run before sampling begins.
     * @param   iterationsPerSample
     *      The number of iterations to the Markov Chain Monte Carlo algorithm
     *      between samples (after the burn-in iterations).
     * @param   random
     *      The random number generator to use.
     */
    public ParallelLatentDirichletAllocationVectorGibbsSampler(
        final int topicCount,
        final double alpha,
        final double beta,
        final int maxIterations,
        final int burnInIterations,
        final int iterationsPerSample,
        final Random random)
    {
        super(topicCount, alpha, beta, maxIterations, burnInIterations,
            iterationsPerSample, random);
    }

    /**
     * Runs one sampling iteration. One task is built per document, all tasks
     * are executed on the thread pool, and the parameters are read whenever
     * the current iteration is a sampling iteration (at or past burn-in and a
     * multiple of {@code iterationsPerSample} after it).
     *
     * @return
     *      Always true.
     * @throws RuntimeException
     *      If the parallel execution fails or is interrupted. The original
     *      exception is kept as the cause, and the interrupt status of the
     *      calling thread is restored when the cause is an interruption.
     */
    @Override
    protected boolean step()
    {
        // Running positions handed to each task: the index of the document
        // and the offset of its first entry in the occurrence assignments.
        int document = 0;
        int occurrence = 0;

        // Create the task list, one task per document.
        final ArrayList<DocumentSampleTask> samplingTaskList =
            new ArrayList<DocumentSampleTask>(this.documentCount);
        for (Vectorizable m : this.data)
        {
            final Vector av = m.convertToVector();
            samplingTaskList.add(
                new DocumentSampleTask(av, document, occurrence));
            document++;

            // The compound assignment narrows the 1-norm to an int.
            occurrence += av.norm1();
        }

        try
        {
            ParallelUtil.executeInParallel(
                samplingTaskList, this.getThreadPool());
        }
        catch (Exception ex)
        {
            if (ex instanceof InterruptedException)
            {
                // Do not swallow the interruption: restore the flag so that
                // code further up the stack can still observe it.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(ex);
        }

        final boolean isSampleIteration =
            this.iteration >= this.burnInIterations
            && (this.iteration - this.burnInIterations)
                % this.iterationsPerSample == 0;
        if (isSampleIteration)
        {
            this.readParameters();
        }

        return true;
    }

    /**
     * Shuts down the thread pool, if one exists, and then delegates to the
     * superclass cleanup. The pool reference is cleared after the shutdown
     * because a shut-down executor rejects new tasks; clearing it lets
     * {@link #getThreadPool()} create a new pool if this algorithm is run
     * again. No pool is created just for the purpose of shutting it down.
     */
    @Override
    protected void cleanupAlgorithm()
    {
        final ThreadPoolExecutor pool = this.threadPool;
        if (pool != null)
        {
            pool.shutdown();
            this.setThreadPool(null);
        }
        super.cleanupAlgorithm();
    }

    /**
     * A document sampling task. It resamples the topic assignment of every
     * term occurrence of a single document.
     * <p>
     * Updates to the enclosing sampler's {@code topicTermCount} and
     * {@code topicTermSum} arrays are made while holding the monitor of the
     * respective array. The per-document counts and the occurrence
     * assignments are written without locking, and {@code sampleTopic} is
     * called outside of the locks.
     */
    protected class DocumentSampleTask
        extends AbstractCloneableSerializable
        implements Callable<Boolean>
    {

        /**
         * The term vector for a document.
         */
        Vector vector;

        /**
         * The document address
         */
        int document;

        /**
         * The occurrence address. It is advanced by {@link #call()} as the
         * occurrences of the document are processed.
         */
        int occurrence;

        /**
         * Creates a new instance of DocumentSampleTask
         *
         * @param v - term frequency vector for a single document
         * @param doc - document address in sample arrays
         * @param occ - occurrence address in sample occurrence array
         */
        public DocumentSampleTask(Vector v, int doc, int occ)
        {
            super();
            this.vector = v;
            this.document = doc;
            this.occurrence = occ;
        }

        /**
         * Resamples the topic of each term occurrence in this task's document.
         * For every occurrence the old assignment is removed from the counts,
         * a new topic is drawn with {@code sampleTopic}, and the new
         * assignment is added back to the counts.
         *
         * @return
         *      Always true.
         * @throws Exception
         *      Declared by {@link Callable}.
         */
        @Override
        public Boolean call() throws Exception
        {
            final ParallelLatentDirichletAllocationVectorGibbsSampler sampler =
                ParallelLatentDirichletAllocationVectorGibbsSampler.this;

            // Workspace passed to sampleTopic; allocated once per task rather
            // than once per occurrence.
            final double[] topicCumulativeProportions =
                new double[sampler.topicCount];

            for (VectorEntry v : this.vector)
            {
                final int term = v.getIndex();
                final int count = (int) v.getValue();

                for (int i = 1; i <= count; i++)
                {
                    // Get the old topic assignment.
                    final int oldTopic =
                        sampler.occurrenceTopicAssignments[this.occurrence];

                    // Remove the topic assignment.
                    sampler.documentTopicCount[this.document][oldTopic] -= 1;
                    sampler.documentTopicSum[this.document] -= 1;
                    synchronized (sampler.topicTermCount)
                    {
                        sampler.topicTermCount[oldTopic][term] -= 1;
                    }
                    synchronized (sampler.topicTermSum)
                    {
                        sampler.topicTermSum[oldTopic] -= 1;
                    }

                    // Sample a new topic.
                    final int newTopic = sampler.sampleTopic(
                        this.document, term, topicCumulativeProportions);

                    // Add the new topic assignment.
                    sampler.occurrenceTopicAssignments[this.occurrence] =
                        newTopic;
                    sampler.documentTopicCount[this.document][newTopic] += 1;
                    sampler.documentTopicSum[this.document] += 1;
                    synchronized (sampler.topicTermCount)
                    {
                        sampler.topicTermCount[newTopic][term] += 1;
                    }
                    synchronized (sampler.topicTermSum)
                    {
                        sampler.topicTermSum[newTopic] += 1;
                    }

                    this.occurrence++;
                }
            }

            return true;
        }

    }

    /**
     * Gets the thread pool used for parallelization, creating one with
     * {@link ParallelUtil#createThreadPool()} if none is currently set.
     *
     * @return
     *      The thread pool; not null.
     */
    @Override
    public ThreadPoolExecutor getThreadPool()
    {
        if (this.threadPool == null)
        {
            this.setThreadPool(ParallelUtil.createThreadPool());
        }

        return this.threadPool;
    }

    /**
     * Sets the thread pool used for parallelization.
     *
     * @param   threadPool
     *      The thread pool to use. May be null, in which case one is created
     *      on the next call to {@link #getThreadPool()}.
     */
    @Override
    public void setThreadPool(final ThreadPoolExecutor threadPool)
    {
        this.threadPool = threadPool;
    }

    /**
     * Gets the number of threads as reported by
     * {@link ParallelUtil#getNumThreads(ParallelAlgorithm)} for this algorithm.
     *
     * @return
     *      The number of threads.
     */
    @Override
    public int getNumThreads()
    {
        return ParallelUtil.getNumThreads(this);
    }

}