package cc.mallet.topics;

import cc.mallet.types.*;
import cc.mallet.util.*;

import java.util.Arrays;
import java.io.*;
import java.text.NumberFormat;

import java.util.logging.*;

/**
 * Hierarchical PAM, where each node in the DAG has a distribution over all topics on the 
 *  next level and one additional "node-specific" topic.
 * @author David Mimno
 */

public class HierarchicalPAM {
	
	protected static Logger logger = MalletLogger.getLogger(HierarchicalPAM.class.getName());

	static CommandOption.String inputFile = new CommandOption.String
		(HierarchicalPAM.class, "input", "FILENAME", true, null,
		 "The filename from which to read the list of training instances.  Use - for stdin.  " +
		 "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
	
	static CommandOption.String stateFile = new CommandOption.String
		(HierarchicalPAM.class, "output-state", "FILENAME", true, null,
		 "The filename in which to write the Gibbs sampling state at the end of the iterations.  " +
		 "By default this is null, indicating that no file will be written.", null);

	static CommandOption.Double superTopicBalanceOption = new CommandOption.Double
		(HierarchicalPAM.class, "super-topic-balance", "DECIMAL", true, 1.0,
		 "Weight (in \"words\") of the shared distribution over super-topics, relative to the document-specific distribution", null);
	
	static CommandOption.Double subTopicBalanceOption = new CommandOption.Double
		(HierarchicalPAM.class, "sub-topic-balance", "DECIMAL", true, 1.0,
		 "Weight (in \"words\") of the shared distribution over sub-topics for each super-topic, relative to the document-specific distribution", null);
	
	static CommandOption.Integer numSuperTopicsOption = new CommandOption.Integer
		(HierarchicalPAM.class, "num-super-topics", "INTEGER", true, 10,
		 "The number of super-topics", null);
	
	static CommandOption.Integer numSubTopicsOption = new CommandOption.Integer
		(HierarchicalPAM.class, "num-sub-topics", "INTEGER", true, 20,
		 "The number of sub-topics", null);
	
    public static final int NUM_LEVELS = 3;

    // Constants to determine the level of the output multinomial
    public static final int ROOT_TOPIC = 0;
    public static final int SUPER_TOPIC = 1;
    public static final int SUB_TOPIC = 2;    
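
    // Word-topic counts are stored in a single flat array of size
    //  1 + numSuperTopics + numSubTopics:
    //    index 0                          -> the root's node-specific topic
    //    indices 1 .. numSuperTopics      -> each super-topic's node-specific topic
    //    indices 1 + numSuperTopics ..    -> the shared sub-topics
    //  A token's assignment is stored as a (super-topic, sub-topic) pair, where
    //  super-topic == numSuperTopics marks the root level and
    //  sub-topic == numSubTopics marks a super-topic's own node-specific topic.
    //  For example, with 10 super-topics and 20 sub-topics, topic index 0 is the
    //  root, indices 1-10 are the super-topic-specific topics, and 11-30 are the
    //  shared sub-topics.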

    // Parameters
    int numSuperTopics; // Number of super-topics to be fit
    int numSubTopics;   // Number of sub-topics to be fit

	// This parameter controls the balance between
	//  the local document counts and the global distribution
	//   over super-topics.
	double superTopicBalance = 1.0;

	// This parameter is the smoothing on that global distribution
	double superTopicSmoothing = 1.0;

	// ... and the same for sub-topics.
	double subTopicBalance = 1.0;
	double subTopicSmoothing = 1.0;
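
	// Roughly speaking, during sampling the document-level weight of
	//  super-topic s is proportional to
	//      n_s + superTopicBalance * priorWeight_s
	//  where n_s is the number of tokens in the current document assigned to s
	//  and priorWeight_s is the smoothed, normalized document frequency of s
	//  computed in cacheSuperTopicPrior(). Sub-topics are treated analogously,
	//  using subTopicBalance and the per-super-topic prior weights.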

	// Prior on per-topic multinomial distribution over words
    double beta;
    double betaSum;

    // Data
    InstanceList instances;  // the data field of the instances is expected to hold a FeatureSequence
    int numTypes;
    int numTokens;

    // Gibbs sampling state
    //  (these could be shorts, or we could encode both in one int)
    int[][] superTopics; // indexed by <document index, sequence position>
    int[][] subTopics;   // indexed by <document index, sequence position>
    
    // Per-document state variables
    int[][] superSubCounts; // # of words per <super-topic, sub-topic> pair in the current document
    int[] superCounts; // # of words per <super-topic> in the current document
    double[] superWeights; // the component of the Gibbs update that depends on super-topics
    double[] subWeights;   // the component of the Gibbs update that depends on sub-topics
    double[][] superSubWeights; // unnormalized sampling distribution
    double[] cumulativeSuperWeights; // a cache of the cumulative weight for each super-topic

	// Document frequencies used for "minimal path" hierarchical Dirichlets
	int[] superTopicDocumentFrequencies;
	int[][] superSubTopicDocumentFrequencies;
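	// superTopicDocumentFrequencies[s] is the number of documents that currently
	//  have at least one token assigned to super-topic s (index numSuperTopics is
	//  the root); superSubTopicDocumentFrequencies[s][t] is the analogous count
	//  for the <super-topic, sub-topic> pair.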

	// ... and their sums
	int sumDocumentFrequencies;
	int[] sumSuperTopicDocumentFrequencies;
	// [Note: sumSuperTopicDocumentFrequencies[s] sums superSubTopicDocumentFrequencies[s][*];
	//  it is not the same as superTopicDocumentFrequencies[s]]

	// Cached priors
	double[] superTopicPriorWeights;
	double[][] superSubTopicPriorWeights;
    
    // Per-word type state variables
    int[][] typeTopicCounts; // indexed by <feature index, topic index>
    int[] tokensPerTopic; // indexed by <topic index>

    int[] tokensPerSuperTopic; // indexed by <super-topic index>
    int[][] tokensPerSuperSubTopic;

    Runtime runtime;
    NumberFormat formatter;
    
    public HierarchicalPAM (int superTopics, int subTopics, double superTopicBalance, double subTopicBalance) {
		formatter = NumberFormat.getInstance();
		formatter.setMaximumFractionDigits(5);
	
		this.superTopicBalance = superTopicBalance;
		this.subTopicBalance = subTopicBalance;

		this.numSuperTopics = superTopics;
		this.numSubTopics = subTopics;

		superTopicDocumentFrequencies = new int[numSuperTopics + 1];
		superSubTopicDocumentFrequencies = new int[numSuperTopics + 1][numSubTopics + 1];
		sumSuperTopicDocumentFrequencies = new int[numSuperTopics];

		this.beta = 0.01; // We can't calculate betaSum until we know how many word types...
	
		runtime = Runtime.getRuntime();
    }
    
    public void estimate (InstanceList documents, InstanceList testing,
						  int numIterations, int showTopicsInterval,
						  int outputModelInterval, int optimizeInterval, String outputModelFilename,
						  Randoms r) {

		instances = documents;
		numTypes = instances.getDataAlphabet().size ();
		int numDocs = instances.size();
		superTopics = new int[numDocs][];
		subTopics = new int[numDocs][];

		// Allocate several arrays for use within each document
		//  to cut down memory allocation and garbage collection time

		superSubCounts = new int[numSuperTopics + 1][numSubTopics + 1];
		superCounts = new int[numSuperTopics + 1];
		superWeights = new double[numSuperTopics + 1];
		subWeights = new double[numSubTopics];
		superSubWeights = new double[numSuperTopics + 1][numSubTopics + 1];
		cumulativeSuperWeights = new double[numSuperTopics];

		typeTopicCounts = new int[numTypes][1 + numSuperTopics + numSubTopics];
		tokensPerTopic = new int[1 + numSuperTopics + numSubTopics];

		tokensPerSuperTopic = new int[numSuperTopics + 1];
		tokensPerSuperSubTopic = new int[numSuperTopics + 1][numSubTopics + 1];

		betaSum = beta * numTypes;

		long startTime = System.currentTimeMillis();
	
		int maxTokens = 0;

		// Initialize with random assignments of tokens to topics
		//  and finish allocating this.superTopics and this.subTopics

		int superTopic, subTopic, seqLen;

		for (int doc = 0; doc < numDocs; doc++) {

			int[] localTokensPerSuperTopic = new int[numSuperTopics + 1];
			int[][] localTokensPerSuperSubTopic = new int[numSuperTopics + 1][numSubTopics + 1];

			FeatureSequence fs = (FeatureSequence) instances.get(doc).getData();

			seqLen = fs.getLength();
			if (seqLen > maxTokens) { 
				maxTokens = seqLen;
			}

			numTokens += seqLen;
			superTopics[doc] = new int[seqLen];
			subTopics[doc] = new int[seqLen];

			// Randomly assign tokens to topics
			for (int position = 0; position < seqLen; position++) {

				// Random super-topic
				superTopic = r.nextInt(numSuperTopics);

				// Random sub-topic
				subTopic = r.nextInt(numSubTopics);

				int level = r.nextInt(NUM_LEVELS);
		
				if (level == ROOT_TOPIC) {
					superTopics[doc][position] = numSuperTopics;
					subTopics[doc][position] = numSubTopics;
					typeTopicCounts[ fs.getIndexAtPosition(position) ][0]++;
					tokensPerTopic[0]++;
					tokensPerSuperTopic[numSuperTopics]++;
					tokensPerSuperSubTopic[numSuperTopics][numSubTopics]++;

					if (localTokensPerSuperTopic[numSuperTopics] == 0) {
						superTopicDocumentFrequencies[numSuperTopics]++;
						sumDocumentFrequencies++;
					}
					localTokensPerSuperTopic[numSuperTopics]++;

				}
				else if (level == SUPER_TOPIC) {
					superTopics[doc][position] = superTopic;
					subTopics[doc][position] = numSubTopics;
					typeTopicCounts[ fs.getIndexAtPosition(position) ][1 + superTopic]++;
					tokensPerTopic[1 + superTopic]++;
					tokensPerSuperTopic[superTopic]++;
					tokensPerSuperSubTopic[superTopic][numSubTopics]++;

					if (localTokensPerSuperTopic[superTopic] == 0) {
						superTopicDocumentFrequencies[superTopic]++;
						sumDocumentFrequencies++;
					}
					localTokensPerSuperTopic[superTopic]++;

					if (localTokensPerSuperSubTopic[superTopic][numSubTopics] == 0) {
                        superSubTopicDocumentFrequencies[superTopic][numSubTopics]++;
						sumSuperTopicDocumentFrequencies[superTopic]++;
                    }
                    localTokensPerSuperSubTopic[superTopic][numSubTopics]++;
				}
				else {
					superTopics[doc][position] = superTopic;
					subTopics[doc][position] = subTopic;
					typeTopicCounts[ fs.getIndexAtPosition(position) ][ 1 + numSuperTopics + subTopic]++;
					tokensPerTopic[1 + numSuperTopics + subTopic]++;
					tokensPerSuperTopic[superTopic]++;
					tokensPerSuperSubTopic[superTopic][subTopic]++;

					if (localTokensPerSuperTopic[superTopic] == 0) {
						superTopicDocumentFrequencies[superTopic]++;
						sumDocumentFrequencies++;
					}
					localTokensPerSuperTopic[superTopic]++;

					if (localTokensPerSuperSubTopic[superTopic][subTopic] == 0) {
						superSubTopicDocumentFrequencies[superTopic][subTopic]++;
						sumSuperTopicDocumentFrequencies[superTopic]++;
					}
					localTokensPerSuperSubTopic[superTopic][subTopic]++;
				}
			}
		}

		// Initialize cached priors

		superTopicPriorWeights = new double[ numSuperTopics + 1 ];
		superSubTopicPriorWeights = new double[ numSuperTopics ][ numSubTopics + 1 ];
		
		cacheSuperTopicPrior();
		for (superTopic = 0; superTopic < numSuperTopics; superTopic++) {
			cacheSuperSubTopicPrior(superTopic);
		}

		// Finally, start the sampler!

		for (int iterations = 1; iterations < numIterations; iterations++) {
			long iterationStart = System.currentTimeMillis();

			// Loop over every word in the corpus
			for (int doc = 0; doc < superTopics.length; doc++) {
				sampleTopicsForOneDoc ((FeatureSequence)instances.get(doc).getData(),
									   superTopics[doc], subTopics[doc], r);
			}
			
			if (showTopicsInterval != 0 && iterations % showTopicsInterval == 0) {
				logger.info( printTopWords(8, false) );
			}

			logger.fine((System.currentTimeMillis() - iterationStart) + " ");
			if (iterations % 10 == 0) {
				logger.info ("<" + iterations + "> LL: " + formatter.format(modelLogLikelihood() / numTokens));
			}
		}
	
    }
    
	private void cacheSuperTopicPrior() {
		for (int superTopic = 0; superTopic < numSuperTopics; superTopic++) {
			superTopicPriorWeights[superTopic] = 
				(superTopicDocumentFrequencies[superTopic] + superTopicSmoothing) /
				(sumDocumentFrequencies + (numSuperTopics + 1) * superTopicSmoothing);
		}
		
		superTopicPriorWeights[numSuperTopics] = 
			(superTopicDocumentFrequencies[numSuperTopics] + superTopicSmoothing) /
			(sumDocumentFrequencies + (numSuperTopics + 1) * superTopicSmoothing);
	}

	private void cacheSuperSubTopicPrior(int superTopic) {
		int[] documentFrequencies = superSubTopicDocumentFrequencies[superTopic];

		for (int subTopic = 0; subTopic < numSubTopics; subTopic++) {
			superSubTopicPriorWeights[superTopic][subTopic] = 
				(documentFrequencies[subTopic] + subTopicSmoothing ) /
				(sumSuperTopicDocumentFrequencies[superTopic] + (numSubTopics + 1) * subTopicSmoothing);
		}
		
		superSubTopicPriorWeights[superTopic][numSubTopics] = 
			(documentFrequencies[numSubTopics] + subTopicSmoothing ) /
			(sumSuperTopicDocumentFrequencies[superTopic] + (numSubTopics + 1) * subTopicSmoothing);
	}

    private void sampleTopicsForOneDoc (FeatureSequence oneDocTokens,
										int[] superTopics, // indexed by seq position
										int[] subTopics,
										Randoms r) {

		//long startTime = System.currentTimeMillis();
	
		int[] currentTypeTopicCounts;
		int[] currentSuperSubCounts;
		double[] currentSuperSubWeights;

		double[] wordWeights = new double[ 1 + numSuperTopics + numSubTopics ];

		int type, subTopic, superTopic;
		double rootWeight, currentSuperWeight, cumulativeWeight, sample;
	    
		int docLen = oneDocTokens.getLength();

		Arrays.fill(superCounts, 0);
		for (int t = 0; t < numSuperTopics; t++) {
			Arrays.fill(superSubCounts[t], 0);
		}
	
		// populate topic counts
		for (int position = 0; position < docLen; position++) {
			superSubCounts[ superTopics[position] ][ subTopics[position] ]++;
			superCounts[ superTopics[position] ]++;
		}

		for (superTopic = 0; superTopic < numSuperTopics; superTopic++) {
			superWeights[superTopic] =
				((double) superCounts[superTopic] + 
				 (superTopicBalance * superTopicPriorWeights[superTopic])) /
				((double) superCounts[superTopic] + subTopicBalance);
			assert(superWeights[superTopic] != 0.0);
		}

		// Iterate over the positions (words) in the document

		for (int position = 0; position < docLen; position++) {

			type = oneDocTokens.getIndexAtPosition(position);
			currentTypeTopicCounts = typeTopicCounts[type];

			superTopic = superTopics[position];
			subTopic = subTopics[position];

			if (superTopic == numSuperTopics) {
				currentTypeTopicCounts[ 0 ]--;
				tokensPerTopic[ 0 ]--;
			}
			else if (subTopic == numSubTopics) {
				currentTypeTopicCounts[ 1 + superTopic ]--;
				tokensPerTopic[ 1 + superTopic ]--;
			}
			else {
				currentTypeTopicCounts[ 1 + numSuperTopics + subTopic ]--;
				tokensPerTopic[ 1 + numSuperTopics + subTopic ]--;
			}

			// Remove this token from all counts
			superCounts[superTopic]--;
			superSubCounts[superTopic][subTopic]--;

			if (superCounts[superTopic] == 0) {
				// The document frequencies have changed.
				//  Decrement and recalculate the prior weights
				superTopicDocumentFrequencies[superTopic]--;
				sumDocumentFrequencies--;
				cacheSuperTopicPrior();
			}
			if (superTopic != numSuperTopics && 
				superSubCounts[superTopic][subTopic] == 0) {
				superSubTopicDocumentFrequencies[superTopic][subTopic]--;
				sumSuperTopicDocumentFrequencies[superTopic]--;
				cacheSuperSubTopicPrior(superTopic);
			}

			tokensPerSuperTopic[superTopic]--;
			tokensPerSuperSubTopic[superTopic][subTopic]--;

			// Update the super-topic weight for the old topic.
			superWeights[superTopic] =
                ((double) superCounts[superTopic] +
                 (superTopicBalance * superTopicPriorWeights[superTopic])) /
                ((double) superCounts[superTopic] + subTopicBalance);

			// Build a distribution over super-sub topic pairs 
			//   for this token

			for (int i = 0; i < wordWeights.length; i++) {
				wordWeights[i] = (beta + currentTypeTopicCounts[i]) / (betaSum + tokensPerTopic[i]);
			}

			// Weight for the root ("node-specific") topic: the smoothed count of
			//  root-level tokens in this document times the word weight for topic 0.
			rootWeight =
				(superCounts[numSuperTopics] +
				 superTopicBalance * superTopicPriorWeights[numSuperTopics]) *
				wordWeights[0];

			// Each super-sub pair gets weight
			//   [super-topic weight] x [word weight for the sub-topic] x [smoothed sub-topic count].
			// Keep a running total so that cumulativeSuperWeights can be used
			//  to locate the sampled super-topic below.
			cumulativeWeight = 0.0;

			for (superTopic = 0; superTopic < numSuperTopics; superTopic++) {
				currentSuperSubWeights = superSubWeights[superTopic];
				currentSuperSubCounts = superSubCounts[superTopic];
				currentSuperWeight = superWeights[superTopic];

				for (subTopic = 0; subTopic < numSubTopics; subTopic++) {
					currentSuperSubWeights[subTopic] =
						currentSuperWeight *
						wordWeights[1 + numSuperTopics + subTopic] *
						(currentSuperSubCounts[subTopic] +
						 subTopicBalance * superSubTopicPriorWeights[superTopic][subTopic]);
					cumulativeWeight += currentSuperSubWeights[subTopic];
				}

				// The super-topic's own "node-specific" topic
				currentSuperSubWeights[numSubTopics] =
					currentSuperWeight *
					wordWeights[1 + superTopic] *
					(currentSuperSubCounts[numSubTopics] +
					 subTopicBalance * superSubTopicPriorWeights[superTopic][numSubTopics]);
				cumulativeWeight += currentSuperSubWeights[numSubTopics];

				cumulativeSuperWeights[superTopic] = cumulativeWeight;
			}

			// Sample from the combined mass of all super-sub pairs plus the root topic
			sample = r.nextUniform() * (cumulativeWeight + rootWeight);

			if (sample > cumulativeWeight) {
				// We picked the root topic
		
				currentTypeTopicCounts[ 0 ]++;
				tokensPerTopic[ 0 ] ++;
		
				superTopic = numSuperTopics;
				subTopic = numSubTopics;
			}
			else {

				// Go over the row sums to find the super-topic...
				superTopic = 0;
				while (sample > cumulativeSuperWeights[superTopic]) {
					superTopic++;
				}

				// Now read across to find the sub-topic
				currentSuperSubWeights = superSubWeights[superTopic];
				cumulativeWeight = cumulativeSuperWeights[superTopic];

				// Go over each sub-topic until the weight is LESS than
				//  the sample. Note that we're subtracting weights
				//  in the same order we added them...
				subTopic = 0;
				cumulativeWeight -=	currentSuperSubWeights[0];

				while (sample < cumulativeWeight) {
					subTopic++;
					cumulativeWeight -= currentSuperSubWeights[subTopic];
				}
		
				if (subTopic == numSubTopics) {
					currentTypeTopicCounts[ 1 + superTopic ]++;
					tokensPerTopic[ 1 + superTopic ]++;
				}
				else {
					currentTypeTopicCounts[ 1 + numSuperTopics + subTopic ]++;
					tokensPerTopic[ 1 + numSuperTopics + subTopic ]++;
				}
			}

			// Save the choice into the Gibbs state
	    
			superTopics[position] = superTopic;
			subTopics[position] = subTopic;

			// Put the new super/sub topics into the counts
	    
			superSubCounts[superTopic][subTopic]++;
			superCounts[superTopic]++;

			if (superCounts[superTopic] == 1) {
				superTopicDocumentFrequencies[superTopic]++;
				sumDocumentFrequencies++;
				cacheSuperTopicPrior();
			}
			if (superTopic != numSuperTopics && 
				superSubCounts[superTopic][subTopic] == 1) {
				superSubTopicDocumentFrequencies[superTopic][subTopic]++;
				sumSuperTopicDocumentFrequencies[superTopic]++;
				cacheSuperSubTopicPrior(superTopic);
			}

			tokensPerSuperTopic[superTopic]++;
			tokensPerSuperSubTopic[superTopic][subTopic]++;

			// Update the weight for the new super topic
			superWeights[superTopic] =
                ((double) superCounts[superTopic] +
                 (superTopicBalance * superTopicPriorWeights[superTopic])) /
                ((double) superCounts[superTopic] + subTopicBalance);

		}

    }

    public String printTopWords (int numWords, boolean useNewLines) {

		StringBuilder output = new StringBuilder();

		IDSorter[] sortedTypes = new IDSorter[numTypes];
		IDSorter[] sortedSubTopics = new IDSorter[numSubTopics];
		String[] topicTerms = new String[1 + numSuperTopics + numSubTopics];

		int subTopic, superTopic;

		for (int topic = 0; topic < topicTerms.length; topic++) {
			for (int type = 0; type < numTypes; type++)
				sortedTypes[type] = new IDSorter (type, 
									   (((double) typeTopicCounts[type][topic]) /
										tokensPerTopic[topic]));
			Arrays.sort (sortedTypes);

			StringBuilder terms = new StringBuilder();
			for (int i = 0; i < numWords; i++) {
				terms.append(instances.getDataAlphabet().lookupObject(sortedTypes[i].getID()));
				terms.append(" ");
			}
			topicTerms[topic] = terms.toString();
		}

		int maxSubTopics = 10;
		if (numSubTopics < 10) { maxSubTopics = numSubTopics; }

		output.append("Root: " + "[" + tokensPerSuperTopic[numSuperTopics] + "/" + 
					  superTopicDocumentFrequencies[numSuperTopics] + "]" +
					  topicTerms[0] + "\n");

		for (superTopic = 0; superTopic < numSuperTopics; superTopic++) {
			for (subTopic = 0; subTopic < numSubTopics; subTopic++) {
				sortedSubTopics[subTopic] =
					new IDSorter(subTopic, tokensPerSuperSubTopic[superTopic][subTopic]);
			}

			Arrays.sort(sortedSubTopics);

			output.append("\nSuper-topic " + superTopic + 
						  " [" + tokensPerSuperTopic[superTopic] + "/" + 
						  superTopicDocumentFrequencies[superTopic] + " " + 
						  tokensPerSuperSubTopic[superTopic][numSubTopics] + "/" + 
						  superSubTopicDocumentFrequencies[superTopic][numSubTopics] + "]\t" + 
						  topicTerms[1 + superTopic] + "\n");

			for (int i = 0; i < maxSubTopics; i++) {
				subTopic = sortedSubTopics[i].getID();
				output.append(subTopic + ":\t" +
							  tokensPerSuperSubTopic[superTopic][subTopic] + "/" +
							  formatter.format(superSubTopicDocumentFrequencies[superTopic][subTopic]) + "\t" +
							  topicTerms[1 + numSuperTopics + subTopic] + "\n");
			}
		}

		return output.toString();
    }
    
    public void printState (File f) throws IOException {
		PrintWriter out = new PrintWriter (new BufferedWriter (new FileWriter(f)));
		printState (out);
		out.close();
    }
    
    public void printState (PrintWriter out) {

		Alphabet alphabet = instances.getDataAlphabet();
		out.println ("#doc pos typeindex type super-topic sub-topic");

		for (int doc = 0; doc < superTopics.length; doc++) {
			StringBuilder output = new StringBuilder();
			
			FeatureSequence fs = (FeatureSequence) instances.get(doc).getData();
			for (int position = 0; position < superTopics[doc].length; position++) {
				int type = fs.getIndexAtPosition(position);
				output.append(doc); output.append(' ');
				output.append(position); output.append(' ');
				output.append(type); output.append(' ');
				output.append(alphabet.lookupObject(type)); output.append(' ');
				output.append(superTopics[doc][position]); output.append(' ');
				output.append(subTopics[doc][position]); output.append("\n");
			}

			out.print(output);
		}
    }

    public double modelLogLikelihood() {
        double logLikelihood = 0.0;
        int nonZeroTopics;

        // The likelihood of the model is a combination of a 
        // Dirichlet-multinomial for the words in each topic
        // and a Dirichlet-multinomial for the topics in each
        // document.

        // The likelihood function of a Dirichlet-multinomial is
        //
        //     Gamma( sum_i alpha_i )      prod_i Gamma( alpha_i + N_i )
        //   ------------------------- * ---------------------------------
        //     prod_i Gamma( alpha_i )     Gamma( sum_i (alpha_i + N_i) )

        // So the log likelihood is 
        //  logGamma ( sum_i alpha_i ) - logGamma ( sum_i (alpha_i + N_i) ) + 
        //   sum_i [ logGamma( alpha_i + N_i) - logGamma( alpha_i ) ]
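        //
        // Here, for the document-level part, alpha_i = superTopicBalance *
        //  superTopicPriorWeights[i] over the super-topics plus the root
        //  (the prior weights sum to one, so sum_i alpha_i = superTopicBalance),
        //  and for each super-topic s the sub-topic alphas are
        //  subTopicBalance * superSubTopicPriorWeights[s][i]. The word-level part
        //  uses a symmetric prior with parameter beta over the vocabulary.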

        // Do the documents first

        int superTopic, subTopic;

        double[] superTopicLogGammas = new double[numSuperTopics + 1];
        double[][] superSubTopicLogGammas = new double[numSuperTopics][numSubTopics + 1];

        for (superTopic=0; superTopic < numSuperTopics; superTopic++) {
            superTopicLogGammas[ superTopic ] = Dirichlet.logGamma(superTopicBalance * superTopicPriorWeights[superTopic]);

            for (subTopic=0; subTopic < numSubTopics; subTopic++) {
                superSubTopicLogGammas[ superTopic ][ subTopic ] =
                    Dirichlet.logGamma(subTopicBalance * superSubTopicPriorWeights[superTopic][subTopic]);
            }
			superSubTopicLogGammas[ superTopic ][ numSubTopics ] =
				Dirichlet.logGamma(subTopicBalance * superSubTopicPriorWeights[superTopic][numSubTopics]);
        }
		superTopicLogGammas[ numSuperTopics ] = Dirichlet.logGamma(superTopicBalance * superTopicPriorWeights[numSuperTopics]);

        int[] superTopicCounts = new int[ numSuperTopics + 1];
        int[][] superSubTopicCounts = new int[ numSuperTopics ][ numSubTopics + 1];

        int[] docSuperTopics;
        int[] docSubTopics;

        for (int doc=0; doc < superTopics.length; doc++) {
            
            docSuperTopics = superTopics[doc];
            docSubTopics = subTopics[doc];

            for (int token=0; token < docSuperTopics.length; token++) {
                superTopic = docSuperTopics[ token ];
                subTopic =   docSubTopics[ token ];

                superTopicCounts[ superTopic ]++;
				if (superTopic != numSuperTopics) {
					superSubTopicCounts[ superTopic ][ subTopic ]++;
				}
            }

            for (superTopic=0; superTopic < numSuperTopics; superTopic++) {
                if (superTopicCounts[superTopic] > 0) {
                    logLikelihood += (Dirichlet.logGamma(superTopicBalance * superTopicPriorWeights[superTopic] +
														 superTopicCounts[superTopic]) -
                                      superTopicLogGammas[ superTopic ]);
                    
                    for (subTopic=0; subTopic < numSubTopics; subTopic++) {
                        if (superSubTopicCounts[superTopic][subTopic] > 0) {
                            logLikelihood += (Dirichlet.logGamma(subTopicBalance * superSubTopicPriorWeights[superTopic][subTopic] +
																 superSubTopicCounts[superTopic][subTopic]) -
                                              superSubTopicLogGammas[ superTopic ][ subTopic ]);
                        }
                    }
                    
					// Account for words assigned to super-topic
		    
					logLikelihood += (Dirichlet.logGamma(subTopicBalance * superSubTopicPriorWeights[superTopic][numSubTopics] +
														 superSubTopicCounts[superTopic][numSubTopics]) -
									  superSubTopicLogGammas[ superTopic ][ numSubTopics ]);

					// The term for the sums
                    logLikelihood += 
                        Dirichlet.logGamma(subTopicBalance) -
                        Dirichlet.logGamma(subTopicBalance + superTopicCounts[superTopic]);

                    Arrays.fill(superSubTopicCounts[superTopic], 0);
                }
            }

			// Account for words assigned to the root topic
			logLikelihood += (Dirichlet.logGamma(superTopicBalance * superTopicPriorWeights[numSuperTopics] +
												 superTopicCounts[numSuperTopics]) -
							  superTopicLogGammas[ numSuperTopics ]);

            // subtract the (count + parameter) sum term
            logLikelihood -= Dirichlet.logGamma(superTopicBalance + docSuperTopics.length);
            Arrays.fill(superTopicCounts, 0);
            
        }

        // add the parameter sum term for every document all at once.
        logLikelihood += superTopics.length * Dirichlet.logGamma(superTopicBalance);

        // And the topics

        // Count the number of non-zero type-topic pairs
        int nonZeroTypeTopics = 0;

        for (int type=0; type < numTypes; type++) {
            // reuse this array as a pointer
            int[] topicCounts = typeTopicCounts[type];

            for (int topic=0; topic < numSuperTopics + numSubTopics + 1; topic++) {
                if (topicCounts[topic] > 0) {
                    nonZeroTypeTopics++;
                    logLikelihood += Dirichlet.logGamma(beta + topicCounts[topic]);
                }
            }
        }

		for (int topic=0; topic < numSuperTopics + numSubTopics + 1; topic++) {
            logLikelihood -= 
				Dirichlet.logGamma( betaSum + tokensPerTopic[ topic ] );
		}
	
        logLikelihood += 
			((numSuperTopics + numSubTopics + 1) * Dirichlet.logGamma(betaSum)) -
			(Dirichlet.logGamma(beta) * nonZeroTypeTopics);
	
		return logLikelihood;
    }

    public static void main (String[] args) throws IOException {
		CommandOption.setSummary(HierarchicalPAM.class, "Train a three level hierarchy of topics");
		CommandOption.process(HierarchicalPAM.class, args);

        InstanceList instances = InstanceList.load (new File(inputFile.value));
        InstanceList testing = null;

        HierarchicalPAM pam = new HierarchicalPAM (numSuperTopicsOption.value, numSubTopicsOption.value,
												   superTopicBalanceOption.value, subTopicBalanceOption.value);
        pam.estimate (instances, testing, 1000, 100, 0, 250, null, new Randoms());
		if (stateFile.wasInvoked()) {
			pam.printState(new File(stateFile.value));
		}
    }
}
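
/*
 * Example command-line invocation (file names and classpath are hypothetical;
 * the input must be a serialized InstanceList whose instances hold
 * FeatureSequence data, e.g. one built with MALLET's importers using
 * --keep-sequence):
 *
 *   java -cp mallet.jar cc.mallet.topics.HierarchicalPAM \
 *     --input topic-input.mallet \
 *     --num-super-topics 10 --num-sub-topics 20 \
 *     --super-topic-balance 1.0 --sub-topic-balance 1.0 \
 *     --output-state hpam-state.txt
 */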