All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.text.topic.ParallelLatentDirichletAllocationVectorGibbsSampler Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                ParallelLatentDirichletAllocationVectorGibbsSampler.java
 * Authors:             Justin Basilico, Jason Shepherd
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright October 22, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 * See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.text.topic;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import java.util.ArrayList;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A parallel implementation of {@link LatentDirichletAllocationVectorGibbsSampler}.
 * It runs the sampling for the different documents using a thread pool: each
 * iteration creates one {@link DocumentSampleTask} per document and executes
 * all of them through {@link ParallelUtil}.
 *
 * Each task only writes to the document-topic counters of its own document.
 * Writes to the topic-term counters, which every task touches, are guarded by
 * {@code synchronized} blocks on the respective arrays.
 *
 * @author  Jason Shepherd
 * @since   3.3.2
 */
public class ParallelLatentDirichletAllocationVectorGibbsSampler
    extends LatentDirichletAllocationVectorGibbsSampler
    implements ParallelAlgorithm
{
    /**
     * Thread pool used for parallelization. Created lazily by
     * {@link #getThreadPool()} and released in {@link #cleanupAlgorithm()}.
     */
    private transient ThreadPoolExecutor threadPool;

    /**
     * Creates a new {@code ParallelLatentDirichletAllocationVectorGibbsSampler} with
     * default parameters.
     */
    public ParallelLatentDirichletAllocationVectorGibbsSampler()
    {
        super();
    }

    /**
     * Creates a new {@code ParallelLatentDirichletAllocationVectorGibbsSampler} with
     * the given parameters.
     *
     * @param   topicCount
     *      The number of topics for the algorithm to create. Must be positive.
     * @param   alpha
     *      The alpha parameter controlling the Dirichlet distribution for the
     *      document-topic probabilities. It acts as a prior weight assigned to
     *      the document-topic counts. Must be positive.
     * @param   beta
     *      The beta parameter controlling the Dirichlet distribution for the
     *      topic-term probabilities. It acts as a prior weight assigned to
     *      the topic-term counts.
     * @param   maxIterations
     *      The maximum number of iterations to run for. Must be positive.
     * @param   burnInIterations
     *      The number of burn-in iterations for the Markov Chain Monte Carlo
     *      algorithm to run before sampling begins.
     * @param   iterationsPerSample
     *      The number of iterations to the Markov Chain Monte Carlo algorithm
     *      between samples (after the burn-in iterations).
     * @param   random
     *      The random number generator to use.
     */
    public ParallelLatentDirichletAllocationVectorGibbsSampler(
        final int topicCount,
        final double alpha,
        final double beta,
        final int maxIterations,
        final int burnInIterations,
        final int iterationsPerSample,
        final Random random)
    {
        super(topicCount, alpha, beta, maxIterations, burnInIterations, iterationsPerSample, random);
    }

    /**
     * Runs one sampling iteration: builds one task per document, executes all
     * tasks on the thread pool and, once past the burn-in phase, reads the
     * parameters every {@code iterationsPerSample} iterations.
     *
     * @return
     *      Always true.
     * @throws RuntimeException
     *      If executing the tasks fails; the original exception is the cause.
     */
    @Override
    protected boolean step()
    {
        int document = 0;
        int occurrence = 0;

        // Create the task list, one task per document. Each task is told
        // which document it handles and at which position its occurrences
        // start in the occurrence-topic assignment array.
        final ArrayList<Callable<Boolean>> samplingTaskList =
            new ArrayList<Callable<Boolean>>(this.documentCount);
        for (Vectorizable m : this.data)
        {
            final Vector av = m.convertToVector();
            samplingTaskList.add(new DocumentSampleTask(av, document, occurrence));
            document++;

            // Advance by exactly the number of slots the task will consume.
            // Using the 1-norm here would disagree with the task whenever an
            // entry is fractional or negative, which would shift the start
            // offset of every following document.
            occurrence += countOccurrences(av);
        }

        try
        {
            ParallelUtil.executeInParallel(samplingTaskList, this.getThreadPool());
        }
        catch (Exception ex)
        {
            if (ex instanceof InterruptedException)
            {
                // Keep the interrupt visible to callers further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(ex);
        }

        if (this.iteration >= this.burnInIterations
            && (this.iteration - this.burnInIterations)
                % this.iterationsPerSample == 0)
        {
            this.readParameters();
        }

        return true;
    }

    /**
     * Counts how many occurrence slots {@link DocumentSampleTask#call()}
     * consumes for the given document vector: each entry contributes its
     * value truncated to an int, and entries whose truncated value is not
     * positive contribute nothing.
     *
     * @param   vector
     *      The term vector of a single document.
     * @return
     *      The number of occurrence slots used by that document.
     */
    private static int countOccurrences(
        final Vector vector)
    {
        int total = 0;
        for (VectorEntry entry : vector)
        {
            final int count = (int) entry.getValue();
            if (count > 0)
            {
                total += count;
            }
        }
        return total;
    }

    /**
     * Shuts down the thread pool, if one exists, and then performs the
     * cleanup of the superclass. The reference to the pool is cleared because
     * a {@link ThreadPoolExecutor} that has been shut down rejects new tasks;
     * a later call to {@link #getThreadPool()} creates a fresh pool instead.
     */
    @Override
    protected void cleanupAlgorithm()
    {
        // Read the field directly so that no pool is created just to be
        // shut down again.
        final ThreadPoolExecutor pool = this.threadPool;
        if (pool != null)
        {
            pool.shutdown();
            this.setThreadPool(null);
        }
        super.cleanupAlgorithm();
    }

    /**
     * A document sampling task. It resamples the topic assignment of every
     * term occurrence of one document and updates the counters of the
     * enclosing sampler accordingly.
     */
    protected class DocumentSampleTask
        extends AbstractCloneableSerializable
        implements Callable<Boolean>
    {
        /**
         * The term vector for a document.
         */
        Vector vector;

        /**
         * The document address
         */
        int document;

        /**
         * The occurrence address. It is advanced by one for every occurrence
         * processed in {@link #call()}.
         */
        int occurrence;

        /**
         * Creates a new instance of DocumentSampleTask
         * @param v - term frequency vector for a single document
         * @param doc - document address in sample arrays
         * @param occ - occurrence address in sample occurrence array
         */
        public DocumentSampleTask(Vector v, int doc, int occ )
        {
            super();
            this.vector = v;
            this.document = doc;
            this.occurrence = occ;
        }

        /**
         * Resamples the topic of each term occurrence in this task's document.
         * For every occurrence the old assignment is removed from the
         * counters, a new topic is drawn through the sampler's
         * {@code sampleTopic} method and the new assignment is added back.
         *
         * @return
         *      Always true.
         * @throws Exception
         *      Declared by {@link Callable}; not thrown explicitly here.
         */
        @Override
        public Boolean call() throws Exception
        {
            final ParallelLatentDirichletAllocationVectorGibbsSampler sampler =
                ParallelLatentDirichletAllocationVectorGibbsSampler.this;

            // Workspace handed to sampleTopic; allocated once per task so it
            // is not recreated for every occurrence.
            final double[] topicCumulativeProportions =
                new double[sampler.topicCount];

            for (VectorEntry v : this.vector)
            {
                final int term = v.getIndex();
                final int count = (int) v.getValue();

                for (int i = 1; i <= count; i++)
                {
                    // Get the old topic assignment.
                    final int oldTopic =
                        sampler.occurrenceTopicAssignments[this.occurrence];

                    // Remove the topic assignment. The document counters are
                    // indexed by this task's document and updated without a
                    // lock.
                    sampler.documentTopicCount[this.document][oldTopic] -= 1;
                    sampler.documentTopicSum[this.document] -= 1;

                    // The topic-term counters are not indexed by document,
                    // so their updates are guarded.
                    synchronized (sampler.topicTermCount)
                    {
                        sampler.topicTermCount[oldTopic][term] -= 1;
                    }
                    synchronized (sampler.topicTermSum)
                    {
                        sampler.topicTermSum[oldTopic] -= 1;
                    }

                    // Sample a new topic.
                    // NOTE(review): this call is made outside the locks
                    // above; presumably sampleTopic reads the topic-term
                    // counters - confirm that unguarded reads are intended.
                    final int newTopic = sampler.sampleTopic(
                        this.document, term, topicCumulativeProportions);

                    // Add the new topic assignment.
                    sampler.occurrenceTopicAssignments[this.occurrence] = newTopic;
                    sampler.documentTopicCount[this.document][newTopic] += 1;
                    sampler.documentTopicSum[this.document] += 1;

                    synchronized (sampler.topicTermCount)
                    {
                        sampler.topicTermCount[newTopic][term] += 1;
                    }
                    synchronized (sampler.topicTermSum)
                    {
                        sampler.topicTermSum[newTopic] += 1;
                    }

                    this.occurrence++;
                }
            }

            return true;
        }
    }

    /**
     * Gets the thread pool used for parallelization, creating one through
     * {@link ParallelUtil#createThreadPool()} if none is set.
     *
     * @return
     *      The thread pool; never null.
     */
    @Override
    public ThreadPoolExecutor getThreadPool()
    {
        if (this.threadPool == null)
        {
            this.setThreadPool(ParallelUtil.createThreadPool());
        }

        return this.threadPool;
    }

    /**
     * Sets the thread pool used for parallelization.
     *
     * @param   threadPool
     *      The thread pool to use. If null, a new pool is created the next
     *      time {@link #getThreadPool()} is called.
     */
    @Override
    public void setThreadPool(final ThreadPoolExecutor threadPool)
    {
        this.threadPool = threadPool;
    }

    /**
     * Gets the number of threads as reported by
     * {@link ParallelUtil#getNumThreads(ParallelAlgorithm)} for this algorithm.
     *
     * @return
     *      The number of threads.
     */
    @Override
    public int getNumThreads()
    {
        return ParallelUtil.getNumThreads(this);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy