cc.mallet.topics.LDAStream

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
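LDAStream ships as part of MALLET; the toolkit's artifacts are published to Maven Central (at the time of writing, under coordinates such as cc.mallet:mallet), so the class can be pulled in as an ordinary library dependency.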
/**
 * Implements several Gibbs-sampling-based inference methods
 */
package cc.mallet.topics;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.zip.GZIPOutputStream;

import cc.mallet.types.FeatureCounter;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.Randoms;
import gnu.trove.TIntIntHashMap;

/**
 * Held-out inference for LDA: extends LDAHyper with several Gibbs-sampling-based
 * methods for inferring topic assignments on test documents.
 *
 * @author Limin Yao, David Mimno
 */
public class LDAStream extends LDAHyper {

	protected ArrayList<Topication> test; // the test instances and their topic assignments
	/**
	 * @param numberOfTopics
	 */
	public LDAStream(int numberOfTopics) {
		super(numberOfTopics);
	}

	/**
	 * @param numberOfTopics
	 * @param alphaSum
	 * @param beta
	 */
	public LDAStream(int numberOfTopics, double alphaSum, double beta) {
		super(numberOfTopics, alphaSum, beta);
	}

	/**
	 * @param numberOfTopics
	 * @param alphaSum
	 * @param beta
	 * @param random
	 */
	public LDAStream(int numberOfTopics, double alphaSum, double beta,
			Randoms random) {
		super(numberOfTopics, alphaSum, beta, random);
	}

	/**
	 * @param topicAlphabet
	 * @param alphaSum
	 * @param beta
	 * @param random
	 */
	public LDAStream(LabelAlphabet topicAlphabet, double alphaSum, double beta,
			Randoms random) {
		super(topicAlphabet, alphaSum, beta, random);
	}

	public ArrayList<Topication> getTest() { return test; }
	
	// First train a topic model on the training data, then initialize topic
	// assignments on the test data (updating typeTopicCounts), and finally
	// re-sample over all documents.
	public void inferenceAll(int maxIteration){
		this.test = new ArrayList<Topication>();  // initialize test
		// initial sampling on test data
		ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
		for (Instance instance : testing) {
			LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[instanceLength(instance)]);
			// Randomly initialize the topic assignment for each token and
			// update the global type-topic and per-topic token counts.
			Randoms r = new Randoms();
			FeatureSequence fs = (FeatureSequence) instance.getData();
			int[] topics = topicSequence.getFeatures();
			for (int i = 0; i < topics.length; i++) {
				int type = fs.getIndexAtPosition(i);
				topics[i] = r.nextInt(numTopics);
				typeTopicCounts[type].adjustOrPutValue(topics[i], 1, 1);
				tokensPerTopic[topics[i]]++;
			}
			topicSequences.add (topicSequence);
		}

		//construct test
		assert (testing.size() == topicSequences.size());
		for (int i = 0; i < testing.size(); i++) {
			Topication t = new Topication (testing.get(i), this, topicSequences.get(i));
			test.add (t);
		}

		long startTime = System.currentTimeMillis();
		//loop
		int iter = 0;
		for ( ; iter <= maxIteration; iter++) {
			if (iter % 100 == 0) {
				System.out.println("Iteration: " + iter);
			}
			int numDocs = test.size();
			for (int di = 0; di < numDocs; di++) {
				FeatureSequence tokenSequence = (FeatureSequence) test.get(di).instance.getData();
				LabelSequence topicSequence = test.get(di).topicSequence;
				sampleTopicsForOneTestDocAll (tokenSequence, topicSequence);
			}
		}

		long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
		long minutes = seconds / 60;	seconds %= 60;
		long hours = minutes / 60;	minutes %= 60;
		long days = hours / 24;	hours %= 24;
		System.out.print ("\nTotal inference time: ");
		if (days != 0) { System.out.print(days); System.out.print(" days "); }
		if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
		if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
		System.out.print(seconds); System.out.println(" seconds");
	}

	// Called by inferenceAll; unlike sampleTopicsForOneTestDoc below, this
	// variant also samples topics for words unseen in the training data.
	private void sampleTopicsForOneTestDocAll(FeatureSequence tokenSequence,
			LabelSequence topicSequence) {
		int[] oneDocTopics = topicSequence.getFeatures();

		TIntIntHashMap currentTypeTopicCounts;
		int type, oldTopic, newTopic;
		double tw;
		double[] topicWeights = new double[numTopics];
		double topicWeightsSum;
		int docLength = tokenSequence.getLength();

		// Populate per-document topic counts (a new int[] is zero-filled already)
		int[] localTopicCounts = new int[numTopics];
		for (int position = 0; position < docLength; position++) {
			localTopicCounts[oneDocTopics[position]] ++;
		}

		// Iterate over the positions (words) in the document
		for (int si = 0; si < docLength; si++) {
			type = tokenSequence.getIndexAtPosition(si);
			oldTopic = oneDocTopics[si];

			// Remove this token from all counts
			localTopicCounts[oldTopic] --;

			currentTypeTopicCounts = typeTopicCounts[type];
			assert(currentTypeTopicCounts.get(oldTopic) >= 0);

			if (currentTypeTopicCounts.get(oldTopic) == 1) {
				currentTypeTopicCounts.remove(oldTopic);
			}
			else {
				currentTypeTopicCounts.adjustValue(oldTopic, -1);
			}
			tokensPerTopic[oldTopic]--;

			// Build a distribution over topics for this token
			Arrays.fill (topicWeights, 0.0);
			topicWeightsSum = 0;

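			// p(z_si = ti | rest) ∝ (n(w|ti) + beta) / (n(ti) + betaSum) * (n(ti|d) + alpha[ti]);
			// the denominator (docLength - 1 + alphaSum) is constant across topics and is omitted.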
			for (int ti = 0; ti < numTopics; ti++) {
				tw = ((currentTypeTopicCounts.get(ti) + beta) / (tokensPerTopic[ti] + betaSum))
				      * (localTopicCounts[ti] + alpha[ti]);
				topicWeightsSum += tw;
				topicWeights[ti] = tw;
			}
			// Sample a topic assignment from this distribution
			newTopic = random.nextDiscrete (topicWeights, topicWeightsSum);

			// Put that new topic into the counts
			oneDocTopics[si] = newTopic;
			currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
			localTopicCounts[newTopic] ++;
			tokensPerTopic[newTopic]++;
		}
	}
	
	// At this point we have: typeTopicCounts, tokensPerTopic, and the topic
	// sequences of both the training and the test data.
	public void estimateAll(int iteration) throws IOException {
		// re-run Gibbs sampling over the combined training and test data
		data.addAll(test);
		initializeHistogramsAndCachedValues();
		estimate(iteration);
	}

	/**
	 * Inference on the test data. One complication is how to handle words unseen
	 * in training: such a word is in the Alphabet, but its typeTopicCounts entry
	 * is empty. Added by Limin Yao.
	 *
	 * @param maxIteration
	 */
	public void inference(int maxIteration){
		this.test = new ArrayList<Topication>();  // initialize test
		// initial sampling on test data
		ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
		for (Instance instance : testing) {
			LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[instanceLength(instance)]);
			// Randomly initialize the topic assignment for each token. (An
			// alternative, commented out in the original, assigned -1 to words
			// unseen in training; the -1 checks below support that path.)
			Randoms r = new Randoms();
			int[] topics = topicSequence.getFeatures();
			for (int i = 0; i < topics.length; i++) {
				topics[i] = r.nextInt(numTopics);
			}
			topicSequences.add (topicSequence);
		}

		//construct test
		assert (testing.size() == topicSequences.size());
		for (int i = 0; i < testing.size(); i++) {
			Topication t = new Topication (testing.get(i), this, topicSequences.get(i));
			test.add (t);
			// Include sufficient statistics for this one doc
			// add count on new data to n[k][w] and n[k][*]
			// pay attention to unseen words
			FeatureSequence tokenSequence = (FeatureSequence) t.instance.getData();
			LabelSequence topicSequence = t.topicSequence;
			for (int pi = 0; pi < topicSequence.getLength(); pi++) {
				int topic = topicSequence.getIndexAtPosition(pi);
				int type = tokenSequence.getIndexAtPosition(pi);
				if (topic != -1) { // type seen in training
					typeTopicCounts[type].adjustOrPutValue(topic, 1, 1);
					tokensPerTopic[topic]++;
				}
			}
		}

		long startTime = System.currentTimeMillis();
		//loop
		int iter = 0;
		for ( ; iter <= maxIteration; iter++) {
			if (iter % 100 == 0) {
				System.out.println("Iteration: " + iter);
			}
			int numDocs = test.size();
			for (int di = 0; di < numDocs; di++) {
				FeatureSequence tokenSequence = (FeatureSequence) test.get(di).instance.getData();
				LabelSequence topicSequence = test.get(di).topicSequence;
				sampleTopicsForOneTestDoc (tokenSequence, topicSequence);
			}
		}

		long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
		long minutes = seconds / 60;	seconds %= 60;
		long hours = minutes / 60;	minutes %= 60;
		long days = hours / 24;	hours %= 24;
		System.out.print ("\nTotal inference time: ");
		if (days != 0) { System.out.print(days); System.out.print(" days "); }
		if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
		if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
		System.out.print(seconds); System.out.println(" seconds");
	}

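	// Called by inference() and inferenceOneByOne(); tokens assigned topic -1
	// (words unseen in training, under the alternative initialization) are skipped.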
	private void sampleTopicsForOneTestDoc(FeatureSequence tokenSequence,
			LabelSequence topicSequence) {
		int[] oneDocTopics = topicSequence.getFeatures();

		TIntIntHashMap currentTypeTopicCounts;
		int type, oldTopic, newTopic;
		double tw;
		double[] topicWeights = new double[numTopics];
		double topicWeightsSum;
		int docLength = tokenSequence.getLength();

		// Populate per-document topic counts (a new int[] is zero-filled already)
		int[] localTopicCounts = new int[numTopics];
		for (int position = 0; position < docLength; position++) {
			if(oneDocTopics[position] != -1) {
				localTopicCounts[oneDocTopics[position]] ++;
			}
		}

		// Iterate over the positions (words) in the document
		for (int si = 0; si < docLength; si++) {
			type = tokenSequence.getIndexAtPosition(si);
			oldTopic = oneDocTopics[si];
			if(oldTopic == -1) {
				continue;
			}

			// Remove this token from all counts
			localTopicCounts[oldTopic]--;
			currentTypeTopicCounts = typeTopicCounts[type];
			assert(currentTypeTopicCounts.get(oldTopic) >= 0);

			if (currentTypeTopicCounts.get(oldTopic) == 1) {
				currentTypeTopicCounts.remove(oldTopic);
			}
			else {
				currentTypeTopicCounts.adjustValue(oldTopic, -1);
			}
			tokensPerTopic[oldTopic]--;

			// Build a distribution over topics for this token
			Arrays.fill (topicWeights, 0.0);
			topicWeightsSum = 0;

			for (int ti = 0; ti < numTopics; ti++) {
				tw = ((currentTypeTopicCounts.get(ti) + beta) / (tokensPerTopic[ti] + betaSum))
				      * (localTopicCounts[ti] + alpha[ti]); // same distribution as in sampleTopicsForOneTestDocAll
				topicWeightsSum += tw;
				topicWeights[ti] = tw;
			}
			// Sample a topic assignment from this distribution
			newTopic = random.nextDiscrete (topicWeights, topicWeightsSum);

			// Put that new topic into the counts
			oneDocTopics[si] = newTopic;
			currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
			localTopicCounts[newTopic] ++;
			tokensPerTopic[newTopic]++;
		}
	}

	// Inference method 3: for each doc, for each iteration, for each word.
	// Compare with inference() (method 2): for each iteration, for each doc, for each word.
	public void inferenceOneByOne(int maxIteration){
		this.test = new ArrayList<Topication>();  // initialize test
		// initial sampling on test data
		ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
		for (Instance instance : testing) {
			LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[instanceLength(instance)]);
			// Randomly initialize the topic assignment for each token and
			// update the global counts. (An alternative, commented out in the
			// original, assigned -1 to words unseen in training.)
			Randoms r = new Randoms();
			FeatureSequence fs = (FeatureSequence) instance.getData();
			int[] topics = topicSequence.getFeatures();
			for (int i = 0; i < topics.length; i++) {
				int type = fs.getIndexAtPosition(i);
				topics[i] = r.nextInt(numTopics);
				typeTopicCounts[type].adjustOrPutValue(topics[i], 1, 1);
				tokensPerTopic[topics[i]]++;
			}
			topicSequences.add (topicSequence);
		}

		//construct test
		assert (testing.size() == topicSequences.size());
		for (int i = 0; i < testing.size(); i++) {
			Topication t = new Topication (testing.get(i), this, topicSequences.get(i));
			test.add (t);
		}

		long startTime = System.currentTimeMillis();
		//loop
		int iter = 0;
		int numDocs = test.size();
		for (int di = 0; di < numDocs; di++) {
			iter = 0;
			FeatureSequence tokenSequence = (FeatureSequence) test.get(di).instance.getData();
			LabelSequence topicSequence = test.get(di).topicSequence;
			for( ; iter <= maxIteration; iter++) {
				sampleTopicsForOneTestDoc (tokenSequence, topicSequence);
			}
			if (di % 100 == 0) {
				System.out.println("Docnum: " + di);
			}
		}

		long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
		long minutes = seconds / 60;	seconds %= 60;
		long hours = minutes / 60;	minutes %= 60;
		long days = hours / 24;	hours %= 24;
		System.out.print ("\nTotal inference time: ");
		if (days != 0) { System.out.print(days); System.out.print(" days "); }
		if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
		if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
		System.out.print(seconds); System.out.println(" seconds");
	}
	
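	// Inference with a fixed per-document topic distribution (theta), e.g. one
	// produced by an external classifier; see sampleTopicsForOneDocWithTheta.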
	public void inferenceWithTheta(int maxIteration, InstanceList theta){
		this.test = new ArrayList<Topication>();  // initialize test
		// initial sampling on test data
		ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
		for (Instance instance : testing) {
			LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[instanceLength(instance)]);
			// Randomly initialize the topic assignment for each token; the
			// global counts are updated below, once the Topication is built.
			Randoms r = new Randoms();
			int[] topics = topicSequence.getFeatures();
			for (int i = 0; i < topics.length; i++) {
				topics[i] = r.nextInt(numTopics);
			}
			topicSequences.add (topicSequence);
		}

		//construct test
		assert (testing.size() == topicSequences.size());
		for (int i = 0; i < testing.size(); i++) {
			Topication t = new Topication (testing.get(i), this, topicSequences.get(i));
			test.add (t);
			// Include sufficient statistics for this one doc
			// add count on new data to n[k][w] and n[k][*]
			// pay attention to unseen words
			FeatureSequence tokenSequence = (FeatureSequence) t.instance.getData();
			LabelSequence topicSequence = t.topicSequence;
			for (int pi = 0; pi < topicSequence.getLength(); pi++) {
				int topic = topicSequence.getIndexAtPosition(pi);
				int type = tokenSequence.getIndexAtPosition(pi);
				if (topic != -1) { // type seen in training
					typeTopicCounts[type].adjustOrPutValue(topic, 1, 1);
					tokensPerTopic[topic]++;
				}
			}
		}

		long startTime = System.currentTimeMillis();
		//loop
		int iter = 0;
		for ( ; iter <= maxIteration; iter++) {
			if (iter % 100 == 0) {
				System.out.println("Iteration: " + iter);
			}
			int numDocs = test.size();
			for (int di = 0; di < numDocs; di++) {
				FeatureVector fvTheta = (FeatureVector) theta.get(di).getData();
				double[] topicDistribution = fvTheta.getValues();
				FeatureSequence tokenSequence = (FeatureSequence) test.get(di).instance.getData();
				LabelSequence topicSequence = test.get(di).topicSequence;
				sampleTopicsForOneDocWithTheta (tokenSequence, topicSequence, topicDistribution);
			}
		}

		long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
		long minutes = seconds / 60;	seconds %= 60;
		long hours = minutes / 60;	minutes %= 60;
		long days = hours / 24;	hours %= 24;
		System.out.print ("\nTotal inference time: ");
		if (days != 0) { System.out.print(days); System.out.print(" days "); }
		if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
		if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
		System.out.print(seconds); System.out.println(" seconds");
	}

	// Sampling with a known theta (e.g. from maxent): the document term of the
	// sampling distribution is the fixed topicDistribution rather than local
	// topic counts, so no per-document counts are maintained here.
	private void sampleTopicsForOneDocWithTheta(FeatureSequence tokenSequence,
			LabelSequence topicSequence, double[] topicDistribution) {
		int[] oneDocTopics = topicSequence.getFeatures();

		TIntIntHashMap currentTypeTopicCounts;
		int type, oldTopic, newTopic;
		double tw;
		double[] topicWeights = new double[numTopics];
		double topicWeightsSum;
		int docLength = tokenSequence.getLength();
		
		// Iterate over the positions (words) in the document
		for (int si = 0; si < docLength; si++) {
			type = tokenSequence.getIndexAtPosition(si);
			oldTopic = oneDocTopics[si];
			if(oldTopic == -1) {
				continue;
			}

			currentTypeTopicCounts = typeTopicCounts[type];
			assert(currentTypeTopicCounts.get(oldTopic) >= 0);

			if (currentTypeTopicCounts.get(oldTopic) == 1) {
				currentTypeTopicCounts.remove(oldTopic);
			}
			else {
				currentTypeTopicCounts.adjustValue(oldTopic, -1);
			}
			tokensPerTopic[oldTopic]--;

			// Build a distribution over topics for this token
			Arrays.fill (topicWeights, 0.0);
			topicWeightsSum = 0;

			for (int ti = 0; ti < numTopics; ti++) {
				tw = ((currentTypeTopicCounts.get(ti) + beta) / (tokensPerTopic[ti] + betaSum))
				      * topicDistribution[ti]; // document term comes from the fixed theta
				topicWeightsSum += tw;
				topicWeights[ti] = tw;
			}
			// Sample a topic assignment from this distribution
			newTopic = random.nextDiscrete (topicWeights, topicWeightsSum);

			// Put that new topic into the counts
			oneDocTopics[si] = newTopic;
			currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
			tokensPerTopic[newTopic]++;
		}
	}

	// Print a human-readable doc-topic matrix, for downstream IR use.
	// (The threshold and max parameters are currently unused.)
	public void printTheta(ArrayList<Topication> dataset, File f, double threshold, int max) throws IOException{
		PrintWriter pw = new PrintWriter(new FileWriter(f));
		int[] topicCounts = new int[ numTopics ];
		int docLen;
		
		for (int di = 0; di < dataset.size(); di++) {
			LabelSequence topicSequence = dataset.get(di).topicSequence;
			int[] currentDocTopics = topicSequence.getFeatures();
			docLen = currentDocTopics.length;
			for (int token=0; token < docLen; token++) {
				topicCounts[ currentDocTopics[token] ]++;
			}
			pw.println(dataset.get(di).instance.getName());
			// theta_t = (n(t|d) + alpha[t]) / (docLen + alphaSum)
			for (int topic = 0; topic < numTopics; topic++) {
				double prob = (topicCounts[topic] + alpha[topic]) / (docLen + alphaSum);
				pw.println("topic"+ topic + "\t" + prob);
			}

			pw.println();
			Arrays.fill(topicCounts, 0);
		}
		pw.close();
	}
	
	// Print the topic-word matrix, for downstream IR use.
	// (The threshold parameter is currently unused.)
	public void printPhi(File f, double threshold) throws IOException{
		PrintWriter pw = new PrintWriter(new FileWriter(f));
		FeatureCounter[] wordCountsPerTopic = new FeatureCounter[numTopics];
		for (int ti = 0; ti < numTopics; ti++) {
			wordCountsPerTopic[ti] = new FeatureCounter(alphabet);
		}

		for (int fi = 0; fi < numTypes; fi++) {
			int[] topics = typeTopicCounts[fi].keys();
			for (int i = 0; i < topics.length; i++) {
				wordCountsPerTopic[topics[i]].increment(fi, typeTopicCounts[fi].get(topics[i]));
			}
		}
		
		for(int ti = 0; ti < numTopics; ti++){
			pw.println("Topic\t" + ti);
			FeatureCounter counter = wordCountsPerTopic[ti];
			FeatureVector fv = counter.toFeatureVector();
			for(int pos = 0; pos < fv.numLocations(); pos++){
				int fi = fv.indexAtLocation(pos);
				String word = (String) alphabet.lookupObject(fi);
				int count = (int) fv.valueAtLocation(pos);
				// phi_w = (n(w|t) + beta) / (n(t) + betaSum)
				double prob = (count + beta) / (tokensPerTopic[ti] + betaSum);
				pw.println(word + "\t" + prob);
			}
			pw.println();
		}
		pw.close();
	}

	public void printDocumentTopics (ArrayList<Topication> dataset, File f) throws IOException {
		printDocumentTopics (dataset, new PrintWriter (new FileWriter (f) ) );
	}

	public void printDocumentTopics (ArrayList<Topication> dataset, PrintWriter pw) {
		printDocumentTopics (dataset, pw, 0.0, -1);
	}

	/**
	 *  @param dataset     The documents, with topic assignments, to report on
	 *  @param pw          A print writer
	 *  @param threshold   Only print topics with proportion greater than this number
	 *  @param max         Print no more than this many topics
	 */
	public void printDocumentTopics (ArrayList<Topication> dataset, PrintWriter pw, double threshold, int max) {
		pw.print ("#doc source topic proportion ...\n");
		int docLen;
		int[] topicCounts = new int[ numTopics ];

		IDSorter[] sortedTopics = new IDSorter[ numTopics ];
		for (int topic = 0; topic < numTopics; topic++) {
			// Initialize the sorters with dummy values
			sortedTopics[topic] = new IDSorter(topic, topic);
		}

		if (max < 0 || max > numTopics) {
			max = numTopics;
		}

		for (int di = 0; di < dataset.size(); di++) {
			LabelSequence topicSequence = dataset.get(di).topicSequence;
			int[] currentDocTopics = topicSequence.getFeatures();

			pw.print (di); pw.print (' ');

			if (dataset.get(di).instance.getSource() != null) {
				pw.print (dataset.get(di).instance.getSource());
			}
			else {
				pw.print ("null-source");
			}

			pw.print (' ');
			docLen = currentDocTopics.length;

			// Count up the tokens, skipping any left unassigned (topic -1)
			int realDocLen = 0;
			for (int token=0; token < docLen; token++) {
				if (currentDocTopics[token] != -1) {
					topicCounts[ currentDocTopics[token] ]++;
					realDocLen++;
				}
			}
			assert(realDocLen == docLen); // holds when no token was assigned -1
			// Recompute alphaSum in case alpha has been re-estimated
			alphaSum = 0.0;
			for (int topic = 0; topic < numTopics; topic++) {
				alphaSum += alpha[topic];
			}
			
			// And normalize and smooth by Dirichlet prior alpha
			for (int topic = 0; topic < numTopics; topic++) {
				sortedTopics[topic].set(topic, (double) (topicCounts[topic]+alpha[topic]) / (docLen + alphaSum));
			}
     
			Arrays.sort(sortedTopics);

			for (int i = 0; i < max; i++) {
				if (sortedTopics[i].getWeight() < threshold) { break; }

				pw.print (sortedTopics[i].getID() + " " +
						  sortedTopics[i].getWeight() + " ");
			}
			pw.print (" \n");

			Arrays.fill(topicCounts, 0);
		}
		pw.close();
	}
	
	public void printState (ArrayList<Topication> dataset, File f) throws IOException {
		PrintStream out =
			new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
		printState(dataset, out);
		out.close();
	}

	public void printState (ArrayList<Topication> dataset, PrintStream out) {

		out.println ("#doc source pos typeindex type topic");

		for (int di = 0; di < dataset.size(); di++) {
			FeatureSequence tokenSequence = (FeatureSequence) dataset.get(di).instance.getData();
			LabelSequence topicSequence = dataset.get(di).topicSequence;

			String source = "NA";
			if (dataset.get(di).instance.getSource() != null) {
				source = dataset.get(di).instance.getSource().toString();
			}

			for (int pi = 0; pi < topicSequence.getLength(); pi++) {
				int type = tokenSequence.getIndexAtPosition(pi);
				int topic = topicSequence.getIndexAtPosition(pi);
				out.print(di); out.print(' ');
				out.print(source); out.print(' ');
				out.print(pi); out.print(' ');
				out.print(type); out.print(' ');
				out.print(alphabet.lookupObject(type)); out.print(' ');
				out.print(topic); out.println();
			}
		}
	}

}
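
A minimal usage sketch follows (not part of the original source). It assumes the LDAHyper superclass's addInstances(InstanceList) and estimate(int) methods from the MALLET 2.0.x API, and it assumes some hook for populating the protected `testing` list that the inference methods read; the setTestingInstances call below is hypothetical.

import java.io.File;

import cc.mallet.topics.LDAStream;
import cc.mallet.types.InstanceList;

public class LDAStreamExample {
	public static void main(String[] args) throws Exception {
		// Serialized InstanceLists, e.g. produced by MALLET's import tools
		InstanceList training = InstanceList.load(new File(args[0]));
		InstanceList testing  = InstanceList.load(new File(args[1]));

		LDAStream lda = new LDAStream(100, 50.0, 0.01); // numTopics, alphaSum, beta
		lda.addInstances(training);  // assumed LDAHyper method: registers training data
		lda.estimate(1000);          // Gibbs sampling on the training data

		// Hypothetical: LDAStream reads a protected `testing` list, so a real
		// driver must populate it (e.g. via a subclass or a setter like this).
		// lda.setTestingInstances(testing);

		lda.inference(1000);         // sample topic assignments for held-out docs
		lda.printTheta(lda.getTest(), new File("doc-topics.txt"), 0.0, -1);
		lda.printPhi(new File("topic-words.txt"), 0.0);
	}
}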