All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.topics.NPTopicModel Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.	For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.topics;

import java.util.*;
import java.util.logging.*;
import java.util.zip.*;

import java.io.*;
import java.text.NumberFormat;

import cc.mallet.types.*;
import cc.mallet.util.*;

import gnu.trove.*;

/**
 * A non-parametric topic model that uses the "minimal path" assumption
 *  to reduce bookkeeping.
 * <p>
 * The number of topics is not fixed in advance. During Gibbs sampling a token
 *  may be assigned to a brand-new topic, and a topic is discarded as soon as
 *  its last token is reassigned elsewhere. The global weight of a topic is the
 *  number of documents that contain at least one token assigned to it
 *  ({@link #docsPerTopic}); the global weight of a brand-new topic is
 *  proportional to {@code gamma}.
 * <p>
 * Topic IDs are never reused: each new topic receives {@code maxTopic + 1},
 *  so the set of live topic IDs may contain gaps.
 * 
 * @author David Mimno
 */

public class NPTopicModel implements Serializable {

	private static final Logger logger = MalletLogger.getLogger(NPTopicModel.class.getName());
	
	// the training instances and their topic assignments
	protected ArrayList<TopicAssignment> data;

	// the alphabet for the input data
	protected Alphabet alphabet;

	// the alphabet for the topics
	protected LabelAlphabet topicAlphabet;
	
	// The largest topic ID handed out so far; a new topic gets maxTopic + 1
	protected int maxTopic;
	// The current number of live topics (kept equal to docsPerTopic.size())
	protected int numTopics;

	// The size of the vocabulary
	protected int numTypes;

	// Prior parameters
	protected double alpha;
	protected double gamma;
	protected double beta;   // Prior on per-topic multinomial distribution over words
	protected double betaSum; // beta * numTypes, set by addInstances
	public static final double DEFAULT_BETA = 0.01;
	
	// Statistics needed for sampling.
	protected TIntIntHashMap[] typeTopicCounts; // indexed by <feature index, topic ID>
	protected TIntIntHashMap tokensPerTopic; // indexed by <topic ID>

	// The number of documents that contain at least one
	//  token with a given topic.
	protected TIntIntHashMap docsPerTopic;
	// Sum of docsPerTopic over all topics
	protected int totalDocTopics = 0;
	
	public int showTopicsInterval = 50;
	public int wordsPerTopic = 10;
	
	protected Randoms random;
	// Not currently used within this class
	protected NumberFormat formatter;
	// Not currently used within this class
	protected boolean printLogLikelihood = false;
	
	/**
	 * Creates an empty model. Call {@link #addInstances} before {@link #sample}.
	 *
	 *  @param alpha this parameter balances the local document topic counts with 
	 *                the global distribution over topics.
	 *  @param gamma this parameter is the weight on a completely new, never-before-seen topic
	 *                in the global distribution.
	 *  @param beta  this parameter controls the variability of the topic-word distributions
	 */
	public NPTopicModel (double alpha, double gamma, double beta) {

		this.data = new ArrayList<TopicAssignment>();
		this.topicAlphabet = AlphabetFactory.labelAlphabetOfSize(1);

		this.alpha = alpha;
		this.gamma = gamma;
		this.beta = beta;
		this.random = new Randoms();
		
		tokensPerTopic = new TIntIntHashMap();
		docsPerTopic = new TIntIntHashMap();
		
		formatter = NumberFormat.getInstance();
		formatter.setMaximumFractionDigits(5);

		logger.info("Non-Parametric LDA");
	}
	
	/**
	 * Configures the periodic topic report logged by {@link #sample}.
	 *
	 * @param interval log the top words every {@code interval} iterations (0 disables)
	 * @param n        number of words to show per topic
	 */
	public void setTopicDisplay(int interval, int n) {
		this.showTopicsInterval = interval;
		this.wordsPerTopic = n;
	}

	/** Replaces the random number generator with one seeded by {@code seed}. */
	public void setRandomSeed(int seed) {
		random = new Randoms(seed);
	}

	/**
	 * Adds training documents and assigns every token a topic drawn uniformly
	 *  from {@code [0, initialTopics)}, updating all sampling statistics.
	 * <p>
	 * Word-topic counts of documents added by earlier calls are preserved.
	 *  NOTE(review): this assumes every call uses the same (possibly grown)
	 *  data alphabet, so that feature indices keep their meaning.
	 *
	 * @param training      documents whose data are {@link FeatureSequence}s
	 * @param initialTopics number of topic IDs used for the random initialization
	 * @throws IllegalArgumentException if {@code initialTopics < 1}
	 */
	public void addInstances (InstanceList training, int initialTopics) {

		if (initialTopics < 1) {
			throw new IllegalArgumentException("initialTopics must be at least 1, got " + initialTopics);
		}

		alphabet = training.getDataAlphabet();
		numTypes = alphabet.size();
		
		betaSum = beta * numTypes;
		
		// Reuse the rows built by an earlier call; otherwise the word-topic
		//  counts of earlier documents would be lost while tokensPerTopic and
		//  docsPerTopic still include them.
		TIntIntHashMap[] previousCounts = typeTopicCounts;
		typeTopicCounts = new TIntIntHashMap[numTypes];
		for (int type = 0; type < numTypes; type++) {
			if (previousCounts != null && type < previousCounts.length) {
				typeTopicCounts[type] = previousCounts[type];
			}
			else {
				typeTopicCounts[type] = new TIntIntHashMap();
			}
		}

		for (Instance instance : training) {

			// Per-document topic counts; used to detect the first token of each
			//  topic in this document.
			TIntIntHashMap topicCounts = new TIntIntHashMap();

			FeatureSequence tokens = (FeatureSequence) instance.getData();
			LabelSequence topicSequence =
				new LabelSequence(topicAlphabet, new int[ tokens.size() ]);
			
			int[] topics = topicSequence.getFeatures();
			for (int position = 0; position < tokens.size(); position++) {

				int topic = random.nextInt(initialTopics);
				tokensPerTopic.adjustOrPutValue(topic, 1, 1);
				topics[position] = topic;

				// Keep track of the number of docs with at least one token
				//  in a given topic.
				if (! topicCounts.containsKey(topic)) {
					docsPerTopic.adjustOrPutValue(topic, 1, 1);
					totalDocTopics++;
					topicCounts.put(topic, 1);
				}
				else {
					topicCounts.adjustValue(topic, 1);
				}
				
				int type = tokens.getIndexAtPosition(position);
				typeTopicCounts[type].adjustOrPutValue(topic, 1, 1);
			}

			data.add (new TopicAssignment (instance, topicSequence));
		}

		// Only topics that actually received a token are live. With many
		//  initial topics and a small corpus some IDs in [0, initialTopics)
		//  may be unused, so count the live topics instead of assuming
		//  initialTopics of them.
		numTopics = docsPerTopic.size();
		// IDs up to initialTopics - 1 may have been handed out; never let a
		//  later "new" topic collide with one of them.
		maxTopic = Math.max(maxTopic, initialTopics - 1);
	}

	/**
	 * Runs Gibbs sampling over every document, logging the iteration time and
	 *  topic count after each sweep and, every {@link #showTopicsInterval}
	 *  iterations (if non-zero), the top words of every topic.
	 *
	 * @param iterations number of full sweeps over the corpus
	 * @throws IOException declared for compatibility; nothing in this method throws it
	 */
	public void sample (int iterations) throws IOException {

		for (int iteration = 1; iteration <= iterations; iteration++) {

			long iterationStart = System.currentTimeMillis();

			// Loop over every document in the corpus
			for (TopicAssignment assignment : data) {
				FeatureSequence tokenSequence =
					(FeatureSequence) assignment.instance.getData();
				LabelSequence topicSequence =
					(LabelSequence) assignment.topicSequence;

				sampleTopicsForOneDoc (tokenSequence, topicSequence);
			}
		
			long elapsedMillis = System.currentTimeMillis() - iterationStart;
			logger.info(iteration + "\t" + elapsedMillis + "ms\t" + numTopics);

			// Occasionally print more information
			if (showTopicsInterval != 0 && iteration % showTopicsInterval == 0) {
				logger.info("<" + iteration + "> #Topics: " + numTopics + "\n" +
							topWords (wordsPerTopic));
			}

		}
	}
	
	/**
	 * Resamples the topic of every token in one document, in place.
	 * <p>
	 * Each token is removed from all counts, destroying its topic if this was
	 *  the topic's last token. A live topic or a brand-new one is then drawn in
	 *  proportion to its score, and the counts are restored.
	 *
	 * @param tokenSequence the document's word indices
	 * @param topicSequence the document's current topic assignments; its
	 *                      backing array is updated in place
	 */
	protected void sampleTopicsForOneDoc (FeatureSequence tokenSequence,
										  FeatureSequence topicSequence) {

		int[] topics = topicSequence.getFeatures();
		int docLength = tokenSequence.getLength();

		// Number of tokens in this document assigned to each topic
		TIntIntHashMap localTopicCounts = new TIntIntHashMap();
		for (int position = 0; position < docLength; position++) {
			localTopicCounts.adjustOrPutValue(topics[position], 1, 1);
		}

		// Snapshot of the live topic IDs, refreshed whenever a topic is
		//  created or destroyed. topicTermScores has one slot per live topic
		//  plus a final slot for "open a brand-new topic".
		int[] allTopics = docsPerTopic.keys();
		double[] topicTermScores = new double[allTopics.length + 1];
			
		//	Iterate over the positions (words) in the document 
		for (int position = 0; position < docLength; position++) {
			int type = tokenSequence.getIndexAtPosition(position);
			int oldTopic = topics[position];

			// Grab the relevant row from our two-dimensional array
			TIntIntHashMap currentTypeTopicCounts = typeTopicCounts[type];

			//	Remove this token from all counts. 
			if (localTopicCounts.get(oldTopic) == 1) {
				// This was the only token of this topic in the doc, so the doc
				//  no longer counts toward the topic's global weight.
				localTopicCounts.remove(oldTopic);
				totalDocTopics--;

				if (docsPerTopic.get(oldTopic) == 1) {
					// ...and this was the only doc with the topic, so this must
					//  be the topic's very last token: get rid of the topic.
					assert(tokensPerTopic.get(oldTopic) == 1);

					docsPerTopic.remove(oldTopic);
					tokensPerTopic.remove(oldTopic);
					numTopics--;

					allTopics = docsPerTopic.keys();
					topicTermScores = new double[allTopics.length + 1];
				}
				else {
					// The topic still exists in other documents
					docsPerTopic.adjustValue(oldTopic, -1);
					tokensPerTopic.adjustValue(oldTopic, -1);
				}
			}
			else {
				// At least one other token in this doc has this topic
				localTopicCounts.adjustValue(oldTopic, -1);
				tokensPerTopic.adjustValue(oldTopic, -1);
			}

			if (currentTypeTopicCounts.get(oldTopic) == 1) {
				currentTypeTopicCounts.remove(oldTopic);
			}
			else {
				currentTypeTopicCounts.adjustValue(oldTopic, -1);
			}

			// Score every live topic for this word...
			int numLive = allTopics.length;
			double sum = 0.0;
			for (int i = 0; i < numLive; i++) {
				int topic = allTopics[i];

				topicTermScores[i] =
					(localTopicCounts.get(topic) + 
					 alpha * (docsPerTopic.get(topic) / 
							  (totalDocTopics + gamma))) *
					(currentTypeTopicCounts.get(topic) + beta) /
					(tokensPerTopic.get(topic) + betaSum);

				sum += topicTermScores[i];
			}

			// ...and the option of a brand-new topic, whose word distribution is
			//  uniform over the numTypes vocabulary entries.
			topicTermScores[numLive] =
				alpha * gamma / ( numTypes * (totalDocTopics + gamma) );
			sum += topicTermScores[numLive];

			// Pick the slot whose cumulative-score interval contains a uniform
			//  draw. The bound on `slot` keeps the index inside the array when the
			//  draw is exactly 0 or when floating-point round-off leaves a small
			//  positive residue after the last slot.
			double sample = random.nextUniform() * sum;
			int slot = -1;
			do {
				slot++;
				sample -= topicTermScores[slot];
			} while (sample > 0.0 && slot < numLive);

			int newTopic;
			if (slot < numLive) {
				newTopic = allTopics[slot];

				currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
				tokensPerTopic.adjustValue(newTopic, 1);
				
				if (localTopicCounts.containsKey(newTopic)) {
					localTopicCounts.adjustValue(newTopic, 1);
				}
				else {
					// This is not a new topic, but it is new for this doc.
					localTopicCounts.put(newTopic, 1);
					docsPerTopic.adjustValue(newTopic, 1);
					totalDocTopics++;
				}
			}
			else {
				// Completely new topic: IDs are never reused.
				newTopic = ++maxTopic;
				numTopics++;
				
				localTopicCounts.put(newTopic, 1);
				docsPerTopic.put(newTopic, 1);
				totalDocTopics++;
				currentTypeTopicCounts.put(newTopic, 1);
				tokensPerTopic.put(newTopic, 1);
				
				allTopics = docsPerTopic.keys();
				topicTermScores = new double[allTopics.length + 1];
			}

			topics[position] = newTopic;
		}
	}
	
	// 
	// Methods for displaying and saving results
	//

	/**
	 * Formats one line per live topic: the topic ID, its token count, and up to
	 *  {@code numWords} of its words with non-zero counts.
	 *
	 * @param numWords maximum number of words per topic; values larger than
	 *                 the vocabulary are clamped to the vocabulary size
	 * @return the tab-separated report, one topic per line
	 */
	public String topWords (int numWords) {

		StringBuilder output = new StringBuilder();

		IDSorter[] sortedWords = new IDSorter[numTypes];
		// Never read past the end of sortedWords
		int wordLimit = Math.min(numWords, numTypes);

		for (int topic: docsPerTopic.keys()) {
			for (int type = 0; type < numTypes; type++) {
				sortedWords[type] = new IDSorter(type, typeTopicCounts[type].get(topic));
			}

			// NOTE(review): the early break below relies on IDSorter ordering by
			//  descending weight.
			Arrays.sort(sortedWords);
			
			output.append(topic + "\t" + tokensPerTopic.get(topic) + "\t");
			for (int i = 0; i < wordLimit; i++) {
				if (sortedWords[i].getWeight() < 1.0) {
					break;
				}
				output.append(alphabet.lookupObject(sortedWords[i].getID()) + " ");
			}
			output.append("\n");
		}

		return output.toString();
	}

	/**
	 * Writes the current token-level topic assignments to {@code f} as
	 *  gzip-compressed text (see {@link #printState(PrintStream)}).
	 *
	 * @throws IOException if the file cannot be opened or any write fails
	 */
	public void printState (File f) throws IOException {
		PrintStream out =
			new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
		try {
			printState(out);
		}
		finally {
			out.close();
		}
		// PrintStream swallows IOExceptions, including any raised while the
		//  gzip trailer is written on close, so surface them explicitly.
		if (out.checkError()) {
			throw new IOException("Error writing topic state to " + f);
		}
	}
	
	/**
	 * Writes one line per token: document index, source (or "NA"), position,
	 *  type index, type, and topic ID, preceded by a header line.
	 */
	public void printState (PrintStream out) {

		out.println ("#doc source pos typeindex type topic");

		for (int doc = 0; doc < data.size(); doc++) {
			TopicAssignment assignment = data.get(doc);
			FeatureSequence tokenSequence =	(FeatureSequence) assignment.instance.getData();
			LabelSequence topicSequence =	(LabelSequence) assignment.topicSequence;

			String source = "NA";
			if (assignment.instance.getSource() != null) {
				source = assignment.instance.getSource().toString();
			}

			for (int position = 0; position < topicSequence.getLength(); position++) {
				int type = tokenSequence.getIndexAtPosition(position);
				int topic = topicSequence.getIndexAtPosition(position);
				out.print(doc); out.print(' ');
				out.print(source); out.print(' '); 
				out.print(position); out.print(' ');
				out.print(type); out.print(' ');
				out.print(alphabet.lookupObject(type)); out.print(' ');
				out.print(topic); out.println();
			}
		}
	}
	
	/**
	 * Command-line entry point:
	 *  {@code NPTopicModel <instance-list-file> [initial-topics]}.
	 *  Trains with alpha = 5.0, gamma = 10.0, beta = 0.1 for 1000 iterations,
	 *  starting from 200 topics unless another count is given.
	 */
	public static void main (String[] args) throws IOException {

		if (args.length < 1) {
			System.err.println("Usage: NPTopicModel <instance-list-file> [initial-topics]");
			System.exit(1);
		}

		InstanceList training = InstanceList.load (new File(args[0]));

		int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;

		NPTopicModel lda = new NPTopicModel (5.0, 10.0, 0.1);
		lda.addInstances(training, numTopics);
		lda.sample(1000);
	}
	
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy