/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.	For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.topics;

import java.util.*;
import java.util.zip.*;

import java.io.*;
import java.text.NumberFormat;

import cc.mallet.types.*;
import cc.mallet.util.CommandOption;
import cc.mallet.util.Randoms;

/**
 * Latent Dirichlet Allocation for loosely parallel corpora in arbitrary languages
 * 
 * @author David Mimno, Andrew McCallum
 */

public class PolylingualTopicModel implements Serializable {
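	/* Typical programmatic usage (a minimal sketch mirroring the main() method
	   below; the file names are placeholders, and each file is assumed to have
	   been imported with --keep-sequence so instances hold FeatureSequence data):

	     InstanceList[] training = new InstanceList[2];
	     training[0] = InstanceList.load(new File("english.mallet"));
	     training[1] = InstanceList.load(new File("german.mallet"));

	     PolylingualTopicModel model = new PolylingualTopicModel(50, 50.0);
	     model.addInstances(training);
	     model.setNumIterations(1000);
	     model.estimate();                          // throws IOException
	     model.printTopWords(System.out, 20, false);
	*/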

	static CommandOption.SpacedStrings languageInputFiles = new CommandOption.SpacedStrings
		(PolylingualTopicModel.class, "language-inputs", "FILENAME [FILENAME ...]", true, null,
		  "Filenames for polylingual topic model. Each language should have its own file, " +
		  "with the same number of instances in each file. If a document is missing in " + 
		 "one language, there should be an empty instance.", null);

	static CommandOption.String outputModelFilename = new CommandOption.String
		(PolylingualTopicModel.class, "output-model", "FILENAME", true, null,
		  "The filename in which to write the binary topic model at the end of the iterations.  " +
		 "By default this is null, indicating that no file will be written.", null);

	static CommandOption.String inputModelFilename = new CommandOption.String
		(PolylingualTopicModel.class, "input-model", "FILENAME", true, null,
		  "The filename from which to read the binary topic model to which the --input will be appended, " +
		  "allowing incremental training.  " +
		 "By default this is null, indicating that no file will be read.", null);

	static CommandOption.String inferencerFilename = new CommandOption.String
		(PolylingualTopicModel.class, "inferencer-filename", "FILENAME", true, null,
		  "A topic inferencer applies a previously trained topic model to new documents.  " +
		 "By default this is null, indicating that no file will be written.", null);

	static CommandOption.String evaluatorFilename = new CommandOption.String
		(PolylingualTopicModel.class, "evaluator-filename", "FILENAME", true, null,
		  "A held-out likelihood evaluator for new documents.  " +
		 "By default this is null, indicating that no file will be written.", null);

	static CommandOption.String stateFile = new CommandOption.String
		(PolylingualTopicModel.class, "output-state", "FILENAME", true, null,
		  "The filename in which to write the Gibbs sampling state after at the end of the iterations.  " +
		 "By default this is null, indicating that no file will be written.", null);

	static CommandOption.String topicKeysFile = new CommandOption.String
		(PolylingualTopicModel.class, "output-topic-keys", "FILENAME", true, null,
		  "The filename in which to write the top words for each topic and any Dirichlet parameters.  " +
		 "By default this is null, indicating that no file will be written.", null);

	static CommandOption.String docTopicsFile = new CommandOption.String
		(PolylingualTopicModel.class, "output-doc-topics", "FILENAME", true, null,
		  "The filename in which to write the topic proportions per document, at the end of the iterations.  " +
		 "By default this is null, indicating that no file will be written.", null);

	static CommandOption.Double docTopicsThreshold = new CommandOption.Double
		(PolylingualTopicModel.class, "doc-topics-threshold", "DECIMAL", true, 0.0,
		  "When writing topic proportions per document with --output-doc-topics, " +
		 "do not print topics with proportions less than this threshold value.", null);

	static CommandOption.Integer docTopicsMax = new CommandOption.Integer
		(PolylingualTopicModel.class, "doc-topics-max", "INTEGER", true, -1,
		  "When writing topic proportions per document with --output-doc-topics, " +
		  "do not print more than INTEGER number of topics.  "+
		 "A negative value indicates that all topics should be printed.", null);

	static CommandOption.Integer outputModelIntervalOption = new CommandOption.Integer
		(PolylingualTopicModel.class, "output-model-interval", "INTEGER", true, 0,
		  "The number of iterations between writing the model (and its Gibbs sampling state) to a binary file.  " +
		 "You must also set the --output-model to use this option, whose argument will be the prefix of the filenames.", null);

	static CommandOption.Integer outputStateIntervalOption = new CommandOption.Integer
		(PolylingualTopicModel.class, "output-state-interval", "INTEGER", true, 0,
		  "The number of iterations between writing the sampling state to a text file.  " +
		 "You must also set the --output-state to use this option, whose argument will be the prefix of the filenames.", null);

	static CommandOption.Integer numTopicsOption = new CommandOption.Integer
		(PolylingualTopicModel.class, "num-topics", "INTEGER", true, 10,
		 "The number of topics to fit.", null);

	static CommandOption.Integer numIterationsOption = new CommandOption.Integer
		(PolylingualTopicModel.class, "num-iterations", "INTEGER", true, 1000,
		 "The number of iterations of Gibbs sampling.", null);

	static CommandOption.Integer randomSeedOption = new CommandOption.Integer
		(PolylingualTopicModel.class, "random-seed", "INTEGER", true, 0,
		 "The random seed for the Gibbs sampler.  Default is 0, which will use the clock.", null);

	static CommandOption.Integer topWordsOption = new CommandOption.Integer
		(PolylingualTopicModel.class, "num-top-words", "INTEGER", true, 20,
		 "The number of most probable words to print for each topic after model estimation.", null);

	static CommandOption.Integer showTopicsIntervalOption = new CommandOption.Integer
		(PolylingualTopicModel.class, "show-topics-interval", "INTEGER", true, 50,
		 "The number of iterations between printing a brief summary of the topics so far.", null);

	static CommandOption.Integer optimizeIntervalOption = new CommandOption.Integer
		(PolylingualTopicModel.class, "optimize-interval", "INTEGER", true, 0,
		 "The number of iterations between reestimating dirichlet hyperparameters.", null);

	static CommandOption.Integer optimizeBurnInOption = new CommandOption.Integer
		(PolylingualTopicModel.class, "optimize-burn-in", "INTEGER", true, 200,
		 "The number of iterations to run before first estimating dirichlet hyperparameters.", null);

	static CommandOption.Double alphaOption = new CommandOption.Double
		(PolylingualTopicModel.class, "alpha", "DECIMAL", true, 50.0,
		 "Alpha parameter: smoothing over topic distribution.",null);

	static CommandOption.Double betaOption = new CommandOption.Double
		(PolylingualTopicModel.class, "beta", "DECIMAL", true, 0.01,
		 "Beta parameter: smoothing over unigram distribution.",null);
	
	public class TopicAssignment implements Serializable {
		public Instance[] instances;
		public LabelSequence[] topicSequences;
		public Labeling topicDistribution;
		
		public TopicAssignment (Instance[] instances, LabelSequence[] topicSequences) {
			this.instances = instances;
			this.topicSequences = topicSequences;
		}
	}

	int numLanguages = 1;

	protected ArrayList<TopicAssignment> data;  // the training instances and their topic assignments
	protected LabelAlphabet topicAlphabet;  // the alphabet for the topics

	protected int numStopwords = 0;
	
	protected int numTopics; // Number of topics to be fit

	HashSet<String> testingIDs = null;

	// These values are used to encode type/topic counts as
	//  count/topic pairs in a single int.
	protected int topicMask;
	protected int topicBits;

	protected Alphabet[] alphabets;
	protected int[] vocabularySizes;

	protected double[] alpha;	 // Dirichlet(alpha,alpha,...) is the distribution over topics
	protected double alphaSum;
	protected double[] betas;   // Prior on per-topic multinomial distribution over words
	protected double[] betaSums;

	protected int[] languageMaxTypeCounts;

	public static final double DEFAULT_BETA = 0.01;
	
	protected double[] languageSmoothingOnlyMasses;
	protected double[][] languageCachedCoefficients;
	int topicTermCount = 0;
	int betaTopicCount = 0;
	int smoothingOnlyCount = 0;

	// An array to put the topic counts for the current document. 
	// Initialized locally below.  Defined here to avoid
	// garbage collection overhead.
	protected int[] oneDocTopicCounts; // indexed by <topic index>

	protected int[][][] languageTypeTopicCounts; // indexed by <language index, feature index, topic index>
	protected int[][] languageTokensPerTopic; // indexed by <language index, topic index>

	// for dirichlet estimation
	protected int[] docLengthCounts; // histogram of document sizes, summed over languages
	protected int[][] topicDocCounts; // histogram of document/topic counts, indexed by <topic index, sequence position index>

	protected int iterationsSoFar = 1;
	public int numIterations = 1000;
	public int burninPeriod = 5;
	public int saveSampleInterval = 5; // was 10;	
	public int optimizeInterval = 10;
	public int showTopicsInterval = 10; // was 50;
	public int wordsPerTopic = 7;

	protected int saveModelInterval = 0;
	protected String modelFilename;

	protected int saveStateInterval = 0;
	protected String stateFilename = null;
	
	protected Randoms random;
	protected NumberFormat formatter;
	protected boolean printLogLikelihood = false;
	
	public PolylingualTopicModel (int numberOfTopics) {
		this (numberOfTopics, numberOfTopics);
	}
	
	public PolylingualTopicModel (int numberOfTopics, double alphaSum) {
		this (numberOfTopics, alphaSum, new Randoms());
	}
	
	private static LabelAlphabet newLabelAlphabet (int numTopics) {
		LabelAlphabet ret = new LabelAlphabet();
		for (int i = 0; i < numTopics; i++)
			ret.lookupIndex("topic"+i);
		return ret;
	}
	
	public PolylingualTopicModel (int numberOfTopics, double alphaSum, Randoms random) {
		this (newLabelAlphabet (numberOfTopics), alphaSum, random);
	}
	
	public PolylingualTopicModel (LabelAlphabet topicAlphabet, double alphaSum, Randoms random)
	{
		this.data = new ArrayList<TopicAssignment>();
		this.topicAlphabet = topicAlphabet;
		this.numTopics = topicAlphabet.size();

		if (Integer.bitCount(numTopics) == 1) {
			// exact power of 2
			topicMask = numTopics - 1;
			topicBits = Integer.bitCount(topicMask);
		}
		else {
			// otherwise add an extra bit
			topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
			topicBits = Integer.bitCount(topicMask);
		}
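
		// Worked example (illustrative): with numTopics = 10, the smallest
		// enclosing power of two is 16, so topicMask = 15 (binary 1111) and
		// topicBits = 4.  A type/topic cell then packs a count n and topic t
		// as (n << topicBits) + t; the topic is recovered as (value & topicMask)
		// and the count as (value >> topicBits).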


		this.alphaSum = alphaSum;
		this.alpha = new double[numTopics];
		Arrays.fill(alpha, alphaSum / numTopics);
		this.random = random;
		
		formatter = NumberFormat.getInstance();
		formatter.setMaximumFractionDigits(5);

		System.err.println("Polylingual LDA: " + numTopics + " topics, " + topicBits + " topic bits, " + 
						   Integer.toBinaryString(topicMask) + " topic mask");
	}

    public void loadTestingIDs(File testingIDFile) throws IOException {
        testingIDs = new HashSet<String>();

        BufferedReader in = new BufferedReader(new FileReader(testingIDFile));
        String id = null;
        while ((id = in.readLine()) != null) {
            testingIDs.add(id);
        }
        in.close();
    }
	
	public LabelAlphabet getTopicAlphabet() { return topicAlphabet; }
	public int getNumTopics() { return numTopics; }
	public ArrayList<TopicAssignment> getData() { return data; }
	
	public void setNumIterations (int numIterations) {
		this.numIterations = numIterations;
	}

	public void setBurninPeriod (int burninPeriod) {
		this.burninPeriod = burninPeriod;
	}

	public void setTopicDisplay(int interval, int n) {
		this.showTopicsInterval = interval;
		this.wordsPerTopic = n;
	}

	public void setRandomSeed(int seed) {
		random = new Randoms(seed);
	}

	public void setOptimizeInterval(int interval) {
		this.optimizeInterval = interval;
	}

	public void setModelOutput(int interval, String filename) {
		this.saveModelInterval = interval;
		this.modelFilename = filename;
	}
	
	/** Define how often and where to save the state 
	 *
	 * @param interval Save a copy of the state every interval iterations.
	 * @param filename Save the state to this file, with the iteration number as a suffix
	 */
	public void setSaveState(int interval, String filename) {
		this.saveStateInterval = interval;
		this.stateFilename = filename;
	}
	
	public void addInstances (InstanceList[] training) {

		numLanguages = training.length;

		languageTokensPerTopic = new int[numLanguages][numTopics];
		
		alphabets = new Alphabet[ numLanguages ];
		vocabularySizes = new int[ numLanguages ];
		betas = new double[ numLanguages ];
		betaSums = new double[ numLanguages ];
		languageMaxTypeCounts = new int[ numLanguages ];
		languageTypeTopicCounts = new int[ numLanguages ][][];
		
		int numInstances = training[0].size();

		HashSet[] stoplists = new HashSet[ numLanguages ];

		for (int language = 0; language < numLanguages; language++) {

			if (training[language].size() != numInstances) {
				System.err.println("Warning: language " + language + " has " +
								   training[language].size() + " instances, lang 0 has " +
								   numInstances);
			}

			alphabets[ language ] = training[ language ].getDataAlphabet();
			vocabularySizes[ language ] = alphabets[ language ].size();
			
			betas[language] = DEFAULT_BETA;
			betaSums[language] = betas[language] * vocabularySizes[ language ];
		
			languageTypeTopicCounts[language] = new int[ vocabularySizes[language] ][];

			int[][] typeTopicCounts = languageTypeTopicCounts[language];

			// Get the total number of occurrences of each word type
			int[] typeTotals = new int[ vocabularySizes[language] ];
			
			for (Instance instance : training[language]) {
				if (testingIDs != null &&
					testingIDs.contains(instance.getName())) {
					continue;
				}

				FeatureSequence tokens = (FeatureSequence) instance.getData();
				for (int position = 0; position < tokens.getLength(); position++) {
					int type = tokens.getIndexAtPosition(position);
					typeTotals[ type ]++;
				}
			}

			/* Automatic stoplist creation, currently disabled
			TreeSet sortedWords = new TreeSet();
			for (int type = 0; type < vocabularySizes[language]; type++) {
				sortedWords.add(new IDSorter(type, typeTotals[type]));
			}

			stoplists[language] = new HashSet();
			Iterator typeIterator = sortedWords.iterator();
			int totalStopwords = 0;

			while (typeIterator.hasNext() && totalStopwords < numStopwords) {
				stoplists[language].add(typeIterator.next().getID());
			}
			*/
			
			// Allocate enough space so that we never have to worry about
			//  overflows: either the number of topics or the number of times
			//  the type occurs.
			for (int type = 0; type < vocabularySizes[language]; type++) {
				if (typeTotals[type] > languageMaxTypeCounts[language]) {
					languageMaxTypeCounts[language] = typeTotals[type];
				}
				typeTopicCounts[type] = new int[ Math.min(numTopics, typeTotals[type]) ];
			}
		}
		
		for (int doc = 0; doc < numInstances; doc++) {

			if (testingIDs != null && 
				testingIDs.contains(training[0].get(doc).getName())) {
				continue;
			}

			Instance[] instances = new Instance[ numLanguages ];
			LabelSequence[] topicSequences = new LabelSequence[ numLanguages ];

			for (int language = 0; language < numLanguages; language++) {
				
				int[][] typeTopicCounts = languageTypeTopicCounts[language];
				int[] tokensPerTopic = languageTokensPerTopic[language];

				instances[language] = training[language].get(doc);
				FeatureSequence tokens = (FeatureSequence) instances[language].getData();
				topicSequences[language] = 
					new LabelSequence(topicAlphabet, new int[ tokens.size() ]);
			
				int[] topics = topicSequences[language].getFeatures();
				for (int position = 0; position < tokens.size(); position++) {
					
					int type = tokens.getIndexAtPosition(position);
					int[] currentTypeTopicCounts = typeTopicCounts[ type ];
					
					int topic = random.nextInt(numTopics);

					// If the word is one of the [numStopwords] most
					//  frequent words, put it in a non-sampled topic.
					//if (stoplists[language].contains(type)) {
					//	topic = -1;
					//}

					topics[position] = topic;
					tokensPerTopic[topic]++;
					
					// The format for these arrays is 
					//  the topic in the rightmost bits
					//  the count in the remaining (left) bits.
					// Since the count is in the high bits, sorting (desc)
					//  by the numeric value of the int guarantees that
					//  higher counts will be before the lower counts.
					
					// Start by assuming that the array is either empty
					//  or is in sorted (descending) order.
					
					// Here we are only adding counts, so if we find 
					//  an existing location with the topic, we only need
					//  to ensure that it is not larger than its left neighbor.
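					// Illustrative example of the invariant: with topicBits = 4,
					// counts {topic 3: 5, topic 0: 2, topic 7: 1} would be stored
					// as { (5<<4)+3, (2<<4)+0, (1<<4)+7 } = { 83, 32, 23 }, i.e.
					// in descending numeric order, so larger counts come first
					// and trailing zeros mark unused cells.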
					
					int index = 0;
					int currentTopic = currentTypeTopicCounts[index] & topicMask;
					int currentValue;
					
					while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
						index++;
						
						/*
							// Debugging output...
						if (index >= currentTypeTopicCounts.length) {
							for (int i=0; i < currentTypeTopicCounts.length; i++) {
								System.out.println((currentTypeTopicCounts[i] & topicMask) + ":" +
												   (currentTypeTopicCounts[i] >> topicBits) + " ");
							}
							
							System.out.println(type + " " + typeTotals[type]);
						}
						*/
						currentTopic = currentTypeTopicCounts[index] & topicMask;
					}
					currentValue = currentTypeTopicCounts[index] >> topicBits;
					
					if (currentValue == 0) {
						// new value is 1, so we don't have to worry about sorting
						//  (except by topic suffix, which doesn't matter)
						
						currentTypeTopicCounts[index] =
							(1 << topicBits) + topic;
					}
					else {
						currentTypeTopicCounts[index] =
							((currentValue + 1) << topicBits) + topic;
						
						// Now ensure that the array is still sorted by 
						//  bubbling this value up.
						while (index > 0 &&
							   currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
							int temp = currentTypeTopicCounts[index];
							currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
							currentTypeTopicCounts[index - 1] = temp;
							
							index--;
						}
					}
				}
			}

			TopicAssignment t = new TopicAssignment (instances, topicSequences);
			data.add (t);
		}

		initializeHistograms();

		languageSmoothingOnlyMasses = new double[ numLanguages ];
		languageCachedCoefficients = new double[ numLanguages ][ numTopics ];

		cacheValues();
	}

	/** 
	 *  Gather statistics on the size of documents 
	 *  and create histograms for use in Dirichlet hyperparameter
	 *  optimization.
	 */
	private void initializeHistograms() {

		int maxTokens = 0;
		int totalTokens = 0;

		for (int doc = 0; doc < data.size(); doc++) {
			int length = 0;
			for (LabelSequence sequence : data.get(doc).topicSequences) {
				length += sequence.getLength();
			}

			if (length > maxTokens) {
				maxTokens = length;
			}

			totalTokens += length;
		}

		System.err.println("max tokens: " + maxTokens);
		System.err.println("total tokens: " + totalTokens);

		docLengthCounts = new int[maxTokens + 1];
		topicDocCounts = new int[numTopics][maxTokens + 1];
		
	}

	private void cacheValues() {

		for (int language = 0; language < numLanguages; language++) {
			languageSmoothingOnlyMasses[language] = 0.0;
			
			for (int topic=0; topic < numTopics; topic++) {
				languageSmoothingOnlyMasses[language] +=
					alpha[topic] * betas[language] /
					(languageTokensPerTopic[language][topic] + betaSums[language]);
				languageCachedCoefficients[language][topic] =
					alpha[topic] / (languageTokensPerTopic[language][topic] + betaSums[language]);
			}
			
		}
		
	}
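
	// Note on the cached values above: the per-token sampling weight used in
	// sampleTopicsForOneDoc() factors into three buckets (a SparseLDA-style
	// decomposition), where n_{t|d} is the topic count in the current document,
	// n_{w|t} the type/topic count, n_t the per-language topic total, and
	// betaSum = beta * vocabularySize:
	//
	//   (alpha_t + n_{t|d}) * (beta + n_{w|t}) / (n_t + betaSum)
	//     =  alpha_t * beta / (n_t + betaSum)                    [smoothingOnlyMass]
	//      + n_{t|d} * beta / (n_t + betaSum)                    [topicBetaMass]
	//      + (alpha_t + n_{t|d}) * n_{w|t} / (n_t + betaSum)     [topicTermMass]
	//
	// cacheValues() precomputes the first bucket (summed over topics) and the
	// coefficients alpha_t / (n_t + betaSum); only the other two buckets need
	// to be updated sparsely during sampling.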
	
	private void clearHistograms() {
		Arrays.fill(docLengthCounts, 0);
		for (int topic = 0; topic < topicDocCounts.length; topic++)
			Arrays.fill(topicDocCounts[topic], 0);
	}

	public void estimate () throws IOException {
		estimate (numIterations);
	}
	
	public void estimate (int iterationsThisRound) throws IOException {

		long startTime = System.currentTimeMillis();
		int maxIteration = iterationsSoFar + iterationsThisRound;

		long totalTime = 0;
	
		for ( ; iterationsSoFar <= maxIteration; iterationsSoFar++) {
			long iterationStart = System.currentTimeMillis();
			
			if (showTopicsInterval != 0 && iterationsSoFar != 0 && iterationsSoFar % showTopicsInterval == 0) {
				System.out.println();
				printTopWords (System.out, wordsPerTopic, false);

			}

			if (saveStateInterval != 0 && iterationsSoFar % saveStateInterval == 0) {
				this.printState(new File(stateFilename + '.' + iterationsSoFar));
			}

			/*
			  if (saveModelInterval != 0 && iterations % saveModelInterval == 0) {
			  this.write (new File(modelFilename+'.'+iterations));
			  }
			*/

			// TODO this condition should also check that we have more than one sample to work with here
			// (The number of samples actually obtained is not yet tracked.)
			if (iterationsSoFar > burninPeriod && optimizeInterval != 0 &&
				iterationsSoFar % optimizeInterval == 0) {

				alphaSum = Dirichlet.learnParameters(alpha, topicDocCounts, docLengthCounts);
				optimizeBetas();
				clearHistograms();
				cacheValues();
			}

			// Loop over every document in the corpus
			topicTermCount = betaTopicCount = smoothingOnlyCount = 0;

			for (int doc = 0; doc < data.size(); doc++) {

				sampleTopicsForOneDoc (data.get(doc),
									   (iterationsSoFar >= burninPeriod &&
										iterationsSoFar % saveSampleInterval == 0));
			}
		
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            totalTime += elapsedMillis;

			if ((iterationsSoFar + 1) % 10 == 0) {
				
				double ll = modelLogLikelihood();
				System.out.println(elapsedMillis + "\t" + totalTime + "\t" +
								   ll);
			}
			else {
				System.out.print(elapsedMillis + " ");
			}
		}

		/*
		long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
		long minutes = seconds / 60;	seconds %= 60;
		long hours = minutes / 60;	minutes %= 60;
		long days = hours / 24;	hours %= 24;
		System.out.print ("\nTotal time: ");
		if (days != 0) { System.out.print(days); System.out.print(" days "); }
		if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
		if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
		System.out.print(seconds); System.out.println(" seconds");
		*/
	}
	
	public void optimizeBetas() {
		
		for (int language = 0; language < numLanguages; language++) {
			
			// The histogram starts at count 0, so if all of the
			//  tokens of the most frequent type were assigned to one topic,
			//  we would need to store a maxTypeCount + 1 count.
			int[] countHistogram = new int[languageMaxTypeCounts[language] + 1];
			
			// Now count the number of type/topic pairs that have
			//  each number of tokens.
			
			int[][] typeTopicCounts = languageTypeTopicCounts[language];
			int[] tokensPerTopic = languageTokensPerTopic[language];

			int index;
			for (int type = 0; type < vocabularySizes[language]; type++) {
				int[] counts = typeTopicCounts[type];
				index = 0;
				while (index < counts.length &&
					   counts[index] > 0) {
					int count = counts[index] >> topicBits;
					countHistogram[count]++;
					index++;
				}
			}
			
			// Figure out how large we need to make the "observation lengths"
			//  histogram.
			int maxTopicSize = 0;
			for (int topic = 0; topic < numTopics; topic++) {
				if (tokensPerTopic[topic] > maxTopicSize) {
					maxTopicSize = tokensPerTopic[topic];
				}
			}
			
			// Now allocate it and populate it.
			int[] topicSizeHistogram = new int[maxTopicSize + 1];
			for (int topic = 0; topic < numTopics; topic++) {
				topicSizeHistogram[ tokensPerTopic[topic] ]++;
			}
			
			betaSums[language] = Dirichlet.learnSymmetricConcentration(countHistogram,
																	   topicSizeHistogram,
																	   vocabularySizes[ language ],
																	   betaSums[language]);
			betas[language] = betaSums[language] / vocabularySizes[ language ];
		}
	}

	protected void sampleTopicsForOneDoc (TopicAssignment topicAssignment, 
										  boolean shouldSaveState) {

		int[] currentTypeTopicCounts;
		int type, oldTopic, newTopic;
		double topicWeightsSum;

		int[] localTopicCounts = new int[numTopics];
		int[] localTopicIndex = new int[numTopics];

		for (int language = 0; language < numLanguages; language++) {

			int[] oneDocTopics =
				topicAssignment.topicSequences[language].getFeatures();
			int docLength = 
				topicAssignment.topicSequences[language].getLength();
			
			//		populate topic counts
			for (int position = 0; position < docLength; position++) {
				localTopicCounts[oneDocTopics[position]]++;
			}
		}

		// Build an array that densely lists the topics that
		//  have non-zero counts.
		int denseIndex = 0;
		for (int topic = 0; topic < numTopics; topic++) {
			if (localTopicCounts[topic] != 0) {
				localTopicIndex[denseIndex] = topic;
				denseIndex++;
			}
		}

		// Record the total number of non-zero topics
		int nonZeroTopics = denseIndex;
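
		// Illustrative example: if only topics 2, 5 and 9 have non-zero counts
		// in this document, then localTopicIndex = {2, 5, 9, ...} and
		// nonZeroTopics = 3, so the sparse loops below touch only those entries.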

		for (int language = 0; language < numLanguages; language++) {

            int[] oneDocTopics =
				topicAssignment.topicSequences[language].getFeatures();
            int docLength =
				topicAssignment.topicSequences[language].getLength();
			FeatureSequence tokenSequence =
				(FeatureSequence) topicAssignment.instances[language].getData();

			int[][] typeTopicCounts = languageTypeTopicCounts[language];
			int[] tokensPerTopic = languageTokensPerTopic[language];
			double beta = betas[language];
			double betaSum = betaSums[language];

			// Initialize the smoothing-only sampling bucket
			double smoothingOnlyMass = languageSmoothingOnlyMasses[language];
			//for (int topic = 0; topic < numTopics; topic++) 
			//smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
			
			// Initialize the cached coefficients, using only smoothing.
			//cachedCoefficients = new double[ numTopics ];
			//for (int topic=0; topic < numTopics; topic++)
			//	cachedCoefficients[topic] =  alpha[topic] / (tokensPerTopic[topic] + betaSum);
			
			double[] cachedCoefficients =
				languageCachedCoefficients[language];

			//		Initialize the topic count/beta sampling bucket
			double topicBetaMass = 0.0;
			
			// Initialize cached coefficients and the topic/beta 
			//  normalizing constant.
			
			for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
				int topic = localTopicIndex[denseIndex];
				int n = localTopicCounts[topic];
				
				//	initialize the normalization constant for the (B * n_{t|d}) term
				topicBetaMass += beta * n /	(tokensPerTopic[topic] + betaSum);	
				
				//	update the coefficients for the non-zero topics
				cachedCoefficients[topic] =	(alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
			}

			double topicTermMass = 0.0;

			double[] topicTermScores = new double[numTopics];
			int[] topicTermIndices;
			int[] topicTermValues;
			int i;
			double score;

			//	Iterate over the positions (words) in the document 
			for (int position = 0; position < docLength; position++) {
				type = tokenSequence.getIndexAtPosition(position);
				oldTopic = oneDocTopics[position];
				if (oldTopic == -1) { continue; }

				currentTypeTopicCounts = typeTopicCounts[type];
				
				//	Remove this token from all counts. 
				
				// Remove this topic's contribution to the 
				//  normalizing constants
				smoothingOnlyMass -= alpha[oldTopic] * beta / 
					(tokensPerTopic[oldTopic] + betaSum);
				topicBetaMass -= beta * localTopicCounts[oldTopic] /
					(tokensPerTopic[oldTopic] + betaSum);
				
				// Decrement the local doc/topic counts
				
				localTopicCounts[oldTopic]--;
				
				// Maintain the dense index, if we are deleting
				//  the old topic
				if (localTopicCounts[oldTopic] == 0) {
					
					// First get to the dense location associated with
					//  the old topic.
					
					denseIndex = 0;
					
					// We know it's in there somewhere, so we don't 
					//  need bounds checking.
					while (localTopicIndex[denseIndex] != oldTopic) {
						denseIndex++;
					}
					
					// shift all remaining dense indices to the left.
					while (denseIndex < nonZeroTopics) {
						if (denseIndex < localTopicIndex.length - 1) {
							localTopicIndex[denseIndex] = 
								localTopicIndex[denseIndex + 1];
						}
						denseIndex++;
					}
					
					nonZeroTopics --;
				}
				
				// Decrement the global topic count totals
				tokensPerTopic[oldTopic]--;
				//assert(tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
					
				
				// Add the old topic's contribution back into the
				//  normalizing constants.
				smoothingOnlyMass += alpha[oldTopic] * beta / 
					(tokensPerTopic[oldTopic] + betaSum);
				topicBetaMass += beta * localTopicCounts[oldTopic] /
					(tokensPerTopic[oldTopic] + betaSum);
				
				// Reset the cached coefficient for this topic
				cachedCoefficients[oldTopic] = 
					(alpha[oldTopic] + localTopicCounts[oldTopic]) /
					(tokensPerTopic[oldTopic] + betaSum);
				
				
				// Now go over the type/topic counts, decrementing
				//  where appropriate, and calculating the score
				//  for each topic at the same time.
				
				int index = 0;
				int currentTopic, currentValue;
				
				boolean alreadyDecremented = false;
				
				topicTermMass = 0.0;
				
				while (index < currentTypeTopicCounts.length && 
					   currentTypeTopicCounts[index] > 0) {
					currentTopic = currentTypeTopicCounts[index] & topicMask;
					currentValue = currentTypeTopicCounts[index] >> topicBits;
					
					if (! alreadyDecremented && 
						currentTopic == oldTopic) {
						
						// We're decrementing and adding up the 
						//  sampling weights at the same time, but
						//  decrementing may require us to reorder
						//  the topics, so after we're done here,
						//  look at this cell in the array again.
						
						currentValue --;
						if (currentValue == 0) {
							currentTypeTopicCounts[index] = 0;
						}
						else {
							currentTypeTopicCounts[index] =
								(currentValue << topicBits) + oldTopic;
						}
						
						// Shift the reduced value to the right, if necessary.
						
						int subIndex = index;
						while (subIndex < currentTypeTopicCounts.length - 1 && 
							   currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]) {
							int temp = currentTypeTopicCounts[subIndex];
							currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
							currentTypeTopicCounts[subIndex + 1] = temp;
							
							subIndex++;
						}
						
						alreadyDecremented = true;
					}
					else {
						score = 
							cachedCoefficients[currentTopic] * currentValue;
						topicTermMass += score;
						topicTermScores[index] = score;
						
						index++;
					}
				}
				
				double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
				double origSample = sample;
				
				//	Make sure it actually gets set
				newTopic = -1;
				
				if (sample < topicTermMass) {
					//topicTermCount++;
					
					i = -1;
					while (sample > 0) {
						i++;
						sample -= topicTermScores[i];
					}
					
					newTopic = currentTypeTopicCounts[i] & topicMask;
					currentValue = currentTypeTopicCounts[i] >> topicBits;
					
					currentTypeTopicCounts[i] = ((currentValue + 1) << topicBits) + newTopic;
					
					// Bubble the new value up, if necessary
					
					while (i > 0 &&
						   currentTypeTopicCounts[i] > currentTypeTopicCounts[i - 1]) {
						int temp = currentTypeTopicCounts[i];
						currentTypeTopicCounts[i] = currentTypeTopicCounts[i - 1];
						currentTypeTopicCounts[i - 1] = temp;
						
						i--;
					}
					
				}
				else {
					sample -= topicTermMass;
					
					if (sample < topicBetaMass) {
						//betaTopicCount++;
						
						sample /= beta;
						
						for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
							int topic = localTopicIndex[denseIndex];
							
							sample -= localTopicCounts[topic] /
								(tokensPerTopic[topic] + betaSum);

							if (sample <= 0.0) {
								newTopic = topic;
								break;
							}
						}
						
					}
					else {
						//smoothingOnlyCount++;
						
						sample -= topicBetaMass;
						
						sample /= beta;
						
						newTopic = 0;
						sample -= alpha[newTopic] /
							(tokensPerTopic[newTopic] + betaSum);
						
						while (sample > 0.0) {
							newTopic++;
							sample -= alpha[newTopic] /
								(tokensPerTopic[newTopic] + betaSum);
						}
					
					}
					
					// Move to the position for the new topic,
					//  which may be the first empty position if this
					//  is a new topic for this word.
					
					index = 0;
					while (currentTypeTopicCounts[index] > 0 &&
						   (currentTypeTopicCounts[index] & topicMask) != newTopic) {
						index++;
					}
					
					// index should now be set to the position of the new topic,
					//  which may be an empty cell at the end of the list.
					
					if (currentTypeTopicCounts[index] == 0) {
						// inserting a new topic, guaranteed to be in
						//  order w.r.t. count, if not topic.
						currentTypeTopicCounts[index] = (1 << topicBits) + newTopic;
					}
					else {
						currentValue = currentTypeTopicCounts[index] >> topicBits;
						currentTypeTopicCounts[index] = ((currentValue + 1) << topicBits) + newTopic;
						
						// Bubble the increased value left, if necessary
						while (index > 0 &&
							   currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
							int temp = currentTypeTopicCounts[index];
							currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
							currentTypeTopicCounts[index - 1] = temp;
							
							index--;
						}
					}
					
				}
				
				if (newTopic == -1) {
					System.err.println("PolylingualTopicModel sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " + 
									   topicBetaMass + " " + topicTermMass);
					newTopic = numTopics-1; // TODO is this appropriate
					//throw new IllegalStateException ("PolylingualTopicModel: New topic not sampled.");
				}
				//assert(newTopic != -1);
				
				//			Put that new topic into the counts
				oneDocTopics[position] = newTopic;
				
				smoothingOnlyMass -= alpha[newTopic] * beta / 
					(tokensPerTopic[newTopic] + betaSum);
				topicBetaMass -= beta * localTopicCounts[newTopic] /
					(tokensPerTopic[newTopic] + betaSum);
				
				localTopicCounts[newTopic]++;
				
				// If this is a new topic for this document,
				//  add the topic to the dense index.
				if (localTopicCounts[newTopic] == 1) {
					
					// First find the point where we 
					//  should insert the new topic by going to
					//  the end (which is the only reason we're keeping
					//  track of the number of non-zero
					//  topics) and working backwards
					
					denseIndex = nonZeroTopics;
					
					while (denseIndex > 0 &&
						   localTopicIndex[denseIndex - 1] > newTopic) {
						
						localTopicIndex[denseIndex] =
							localTopicIndex[denseIndex - 1];
						denseIndex--;
					}
					
					localTopicIndex[denseIndex] = newTopic;
					nonZeroTopics++;
				}
				
				tokensPerTopic[newTopic]++;
				
				//	update the coefficients for the non-zero topics
				cachedCoefficients[newTopic] =
					(alpha[newTopic] + localTopicCounts[newTopic]) /
					(tokensPerTopic[newTopic] + betaSum);
				
				smoothingOnlyMass += alpha[newTopic] * beta / 
					(tokensPerTopic[newTopic] + betaSum);
				topicBetaMass += beta * localTopicCounts[newTopic] /
					(tokensPerTopic[newTopic] + betaSum);
				
				// Save the smoothing-only mass to the global cache
				languageSmoothingOnlyMasses[language] = smoothingOnlyMass;

			}
		}

		if (shouldSaveState) {
			// Update the document-topic count histogram,
			//  for dirichlet estimation

			int totalLength = 0;

			for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
				int topic = localTopicIndex[denseIndex];
				
				topicDocCounts[topic][ localTopicCounts[topic] ]++;
				totalLength += localTopicCounts[topic];
			}

			docLengthCounts[ totalLength ]++;

		}

	}

	public void printTopWords (File file, int numWords, boolean useNewLines) throws IOException {
		PrintStream out = new PrintStream (file);
		printTopWords(out, numWords, useNewLines);
		out.close();
	}
	
    public void printTopWords (PrintStream out, int numWords, boolean usingNewLines) {

		TreeSet[][] languageTopicSortedWords = new TreeSet[numLanguages][numTopics];

		for (int language = 0; language < numLanguages; language++) {
			TreeSet[] topicSortedWords = languageTopicSortedWords[language];
			int[][] typeTopicCounts = languageTypeTopicCounts[language];

			for (int topic = 0; topic < numTopics; topic++) {
				topicSortedWords[topic] = new TreeSet<IDSorter>();
			}

			for (int type = 0; type < vocabularySizes[language]; type++) {
				
				int[] topicCounts = typeTopicCounts[type];
				
				int index = 0;
				while (index < topicCounts.length &&
					   topicCounts[index] > 0) {
					
					int topic = topicCounts[index] & topicMask;
					int count = topicCounts[index] >> topicBits;
					
					topicSortedWords[topic].add(new IDSorter(type, count));

					index++;
				}
			}
		}

        for (int topic = 0; topic < numTopics; topic++) {

			out.println (topic + "\t" + formatter.format(alpha[topic]));
				
			for (int language = 0; language < numLanguages; language++) {
				
				out.print(" " + language + "\t" + languageTokensPerTopic[language][topic] + "\t" + betas[language] + "\t");

				TreeSet<IDSorter> sortedWords = languageTopicSortedWords[language][topic];
				Alphabet alphabet = alphabets[language];

				int word = 1;
				Iterator<IDSorter> iterator = sortedWords.iterator();
				while (iterator.hasNext() && word < numWords) {
					IDSorter info = iterator.next();
					
					out.print(alphabet.lookupObject(info.getID()) + " ");
					word++;
				}
				
				out.println();
            }
        }
    }

	public void printDocumentTopics (File f) throws IOException {
		printDocumentTopics (new PrintWriter (f, "UTF-8") );
	}

	public void printDocumentTopics (PrintWriter pw) {
		printDocumentTopics (pw, 0.0, -1);
	}

	/**
	 *  @param pw          A print writer
	 *  @param threshold   Only print topics with proportion greater than this number
	 *  @param max         Print no more than this many topics
	 */
	public void printDocumentTopics (PrintWriter pw, double threshold, int max)	{
		pw.print ("#doc source topic proportion ...\n");
		int docLength;
		int[] topicCounts = new int[ numTopics ];

		IDSorter[] sortedTopics = new IDSorter[ numTopics ];
		for (int topic = 0; topic < numTopics; topic++) {
			// Initialize the sorters with dummy values
			sortedTopics[topic] = new IDSorter(topic, topic);
		}

		if (max < 0 || max > numTopics) {
			max = numTopics;
		}

		for (int di = 0; di < data.size(); di++) {

			pw.print (di); pw.print (' ');

			int totalLength = 0;

			for (int language = 0; language < numLanguages; language++) {
			
				LabelSequence topicSequence = (LabelSequence) data.get(di).topicSequences[language];
				int[] currentDocTopics = topicSequence.getFeatures();
				
				docLength = topicSequence.getLength();
				totalLength += docLength;
				
				// Count up the tokens
				for (int token=0; token < docLength; token++) {
					topicCounts[ currentDocTopics[token] ]++;
				}
			}
				
			// And normalize
			for (int topic = 0; topic < numTopics; topic++) {
				sortedTopics[topic].set(topic, (float) topicCounts[topic] / totalLength);
			}
			
			Arrays.sort(sortedTopics);

			for (int i = 0; i < max; i++) {
				if (sortedTopics[i].getWeight() < threshold) { break; }
				
				pw.print (sortedTopics[i].getID() + " " + 
						  sortedTopics[i].getWeight() + " ");
			}
			pw.print (" \n");

			Arrays.fill(topicCounts, 0);
		}
		
	}
	
	public void printState (File f) throws IOException {
		PrintStream out =
			new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))), 
							false, "UTF-8");
		printState(out);
		out.close();
	}
	
	public void printState (PrintStream out) {

		out.println ("#doc lang pos typeindex type topic");

		for (int doc = 0; doc < data.size(); doc++) {
			for (int language =0; language < numLanguages; language++) {
				FeatureSequence tokenSequence =	(FeatureSequence) data.get(doc).instances[language].getData();
				LabelSequence topicSequence =	(LabelSequence) data.get(doc).topicSequences[language];
				
				for (int pi = 0; pi < topicSequence.getLength(); pi++) {
					int type = tokenSequence.getIndexAtPosition(pi);
					int topic = topicSequence.getIndexAtPosition(pi);
					out.print(doc); out.print(' ');
					out.print(language); out.print(' '); 
					out.print(pi); out.print(' ');
					out.print(type); out.print(' ');
					out.print(alphabets[language].lookupObject(type)); out.print(' ');
					out.print(topic); out.println();
				}
			}
		}
	}

	public double modelLogLikelihood() {
		double logLikelihood = 0.0;
		int nonZeroTopics;

		// The likelihood of the model is a combination of a 
		// Dirichlet-multinomial for the words in each topic
		// and a Dirichlet-multinomial for the topics in each
		// document.

		// The likelihood function of a Dirichlet-multinomial is
		//
		//	   Gamma( sum_i alpha_i )       prod_i Gamma( alpha_i + N_i )
		//	  ------------------------  *  -------------------------------
		//	   prod_i Gamma( alpha_i )      Gamma( sum_i (alpha_i + N_i) )
		//

		// So the log likelihood is 
		//	logGamma ( sum_i alpha_i ) - logGamma ( sum_i (alpha_i + N_i) ) + 
		//	 sum_i [ logGamma( alpha_i + N_i) - logGamma( alpha_i ) ]

		// Do the documents first

		int[] topicCounts = new int[numTopics];
		double[] topicLogGammas = new double[numTopics];
		int[] docTopics;

		for (int topic=0; topic < numTopics; topic++) {
			topicLogGammas[ topic ] = Dirichlet.logGammaStirling( alpha[topic] );
		}
	
		for (int doc=0; doc < data.size(); doc++) {

			int totalLength = 0;

            for (int language = 0; language < numLanguages; language++) {

                LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequences[language];
                int[] currentDocTopics = topicSequence.getFeatures();

				totalLength += topicSequence.getLength();

                // Count up the tokens
                for (int token=0; token < topicSequence.getLength(); token++) {
                    topicCounts[ currentDocTopics[token] ]++;
                }
            }

			for (int topic=0; topic < numTopics; topic++) {
				if (topicCounts[topic] > 0) {
					logLikelihood += (Dirichlet.logGammaStirling(alpha[topic] + topicCounts[topic]) -
									  topicLogGammas[ topic ]);
				}
			}

			// subtract the (count + parameter) sum term
			logLikelihood -= Dirichlet.logGammaStirling(alphaSum + totalLength);

			Arrays.fill(topicCounts, 0);
		}
	
		// add the parameter sum term
		logLikelihood += data.size() * Dirichlet.logGammaStirling(alphaSum);

		// And the topics

		for (int language = 0; language < numLanguages; language++) {
			int[][] typeTopicCounts = languageTypeTopicCounts[language];
			int[] tokensPerTopic = languageTokensPerTopic[language];
			double beta = betas[language];

			// Count the number of type-topic pairs
			int nonZeroTypeTopics = 0;
			
			for (int type=0; type < vocabularySizes[language]; type++) {
				// reuse this array as a pointer
				
				topicCounts = typeTopicCounts[type];
				
				int index = 0;
				while (index < topicCounts.length &&
					   topicCounts[index] > 0) {
					int topic = topicCounts[index] & topicMask;
					int count = topicCounts[index] >> topicBits;
					
					nonZeroTypeTopics++;
					logLikelihood += Dirichlet.logGammaStirling(beta + count);
					
					if (Double.isNaN(logLikelihood)) {
						System.out.println(count);
						System.exit(1);
					}
					
					index++;
				}
			}
			
			for (int topic=0; topic < numTopics; topic++) {
				logLikelihood -= 
					Dirichlet.logGammaStirling( (beta * numTopics) +
												tokensPerTopic[ topic ] );
				if (Double.isNaN(logLikelihood)) {
					System.out.println("after topic " + topic + " " + tokensPerTopic[ topic ]);
					System.exit(1);
				}
				
			}
			
			logLikelihood += 
				(Dirichlet.logGammaStirling(beta * numTopics)) -
				(Dirichlet.logGammaStirling(beta) * nonZeroTypeTopics);
		}

		if (Double.isNaN(logLikelihood)) {
			System.out.println("at the end");
			System.exit(1);
		}


		return logLikelihood;
	}
	
    /** Return a tool for estimating topic distributions for new documents */
    public TopicInferencer getInferencer(int language) {
		return new TopicInferencer(languageTypeTopicCounts[language], languageTokensPerTopic[language],
                                   alphabets[language],
								   alpha, betas[language], betaSums[language]);
    }
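
    /* Example use of an inferencer for one language (a sketch; it assumes
       cc.mallet.topics.TopicInferencer's getSampledDistribution(instance,
       numIterations, thinning, burnIn) method and uses placeholder values):

         TopicInferencer inferencer = topicModel.getInferencer(0);  // language 0
         double[] topicProbs =
             inferencer.getSampledDistribution(newInstance, 100, 10, 10);
    */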

	// Serialization

	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 0;
	private static final int NULL_INTEGER = -1;

	private void writeObject (ObjectOutputStream out) throws IOException {
		out.writeInt (CURRENT_SERIAL_VERSION);
		
		out.writeInt(numLanguages);
		out.writeObject(data);
		out.writeObject(topicAlphabet);

		out.writeInt(numTopics);

		out.writeObject(testingIDs);

		out.writeInt(topicMask);
		out.writeInt(topicBits);

		out.writeObject(alphabets);
		out.writeObject(vocabularySizes);

		out.writeObject(alpha);
		out.writeDouble(alphaSum);
		out.writeObject(betas);
		out.writeObject(betaSums);

		out.writeObject(languageMaxTypeCounts);

		out.writeObject(languageTypeTopicCounts);
		out.writeObject(languageTokensPerTopic);

		out.writeObject(languageSmoothingOnlyMasses);
		out.writeObject(languageCachedCoefficients);

		out.writeObject(docLengthCounts);
		out.writeObject(topicDocCounts);

		out.writeInt(numIterations);
		out.writeInt(burninPeriod);
		out.writeInt(saveSampleInterval);
		out.writeInt(optimizeInterval);
		out.writeInt(showTopicsInterval);
		out.writeInt(wordsPerTopic);

		out.writeInt(saveStateInterval);
		out.writeObject(stateFilename);

		out.writeInt(saveModelInterval);
		out.writeObject(modelFilename);

		out.writeObject(random);
		out.writeObject(formatter);
		out.writeBoolean(printLogLikelihood);

	}

	private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
		
		int version = in.readInt();

		numLanguages = in.readInt();
		data = (ArrayList) in.readObject ();
		topicAlphabet = (LabelAlphabet) in.readObject();
		
		numTopics = in.readInt();
		
		testingIDs = (HashSet) in.readObject();

		topicMask = in.readInt();
		topicBits = in.readInt();
		
		alphabets = (Alphabet[]) in.readObject();
		vocabularySizes = (int[]) in.readObject();
		
		alpha = (double[]) in.readObject();
		alphaSum = in.readDouble();
		betas = (double[]) in.readObject();
		betaSums = (double[]) in.readObject();

		languageMaxTypeCounts = (int[]) in.readObject();
		
		languageTypeTopicCounts = (int[][][]) in.readObject();
		languageTokensPerTopic = (int[][]) in.readObject();
		
		languageSmoothingOnlyMasses = (double[]) in.readObject();
		languageCachedCoefficients = (double[][]) in.readObject();

		docLengthCounts = (int[]) in.readObject();
		topicDocCounts = (int[][]) in.readObject();
		
		numIterations = in.readInt();
		burninPeriod = in.readInt();
		saveSampleInterval = in.readInt();
		optimizeInterval = in.readInt();
		showTopicsInterval = in.readInt();
		wordsPerTopic = in.readInt();

		saveStateInterval = in.readInt();
		stateFilename = (String) in.readObject();
		
		saveModelInterval = in.readInt();
		modelFilename = (String) in.readObject();
		
		random = (Randoms) in.readObject();
		formatter = (NumberFormat) in.readObject();
		printLogLikelihood = in.readBoolean();

	}

	public void write (File serializedModelFile) {
		try {
			ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(serializedModelFile));
			oos.writeObject(this);
			oos.close();
		} catch (IOException e) {
			System.err.println("Problem serializing PolylingualTopicModel to file " +
							   serializedModelFile + ": " + e);
		}
	}

	public static PolylingualTopicModel read (File f) throws Exception {

		PolylingualTopicModel topicModel = null;

		ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f));
		topicModel = (PolylingualTopicModel) ois.readObject();
		ois.close();

		topicModel.initializeHistograms();

		return topicModel;
	}


	public static void main (String[] args) throws IOException {

		CommandOption.setSummary (PolylingualTopicModel.class,
								  "A tool for estimating, saving and printing diagnostics for topic models over comparable corpora.");
		CommandOption.process (PolylingualTopicModel.class, args);

		PolylingualTopicModel topicModel = null;

		if (inputModelFilename.value != null) {

			try {
				topicModel = PolylingualTopicModel.read(new File(inputModelFilename.value));
			} catch (Exception e) {
				System.err.println("Unable to restore saved topic model " + 
								   inputModelFilename.value + ": " + e);
				System.exit(1);
			}
		}
		else {

			int numLanguages = languageInputFiles.value.length;
			
			InstanceList[] training = new InstanceList[ numLanguages ];
			for (int i=0; i < training.length; i++) {
				training[i] = InstanceList.load(new File(languageInputFiles.value[i]));
				if (training[i] != null) { System.out.println(i + " is not null"); }
				else { System.out.println(i + " is null"); }
			}

			System.out.println ("Data loaded.");
		
			// For historical reasons we currently only support FeatureSequence data,
			//  not the FeatureVector, which is the default for the input functions.
			//  Provide a warning to avoid ClassCastExceptions.
			if (training[0].size() > 0 &&
				training[0].get(0) != null) {
				Object data = training[0].get(0).getData();
				if (! (data instanceof FeatureSequence)) {
					System.err.println("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
					System.exit(1);
				}
			}
			
			topicModel = new PolylingualTopicModel (numTopicsOption.value, alphaOption.value);
			if (randomSeedOption.value != 0) {
				topicModel.setRandomSeed(randomSeedOption.value);
			}
			
			topicModel.addInstances(training);
		}

		topicModel.setTopicDisplay(showTopicsIntervalOption.value, topWordsOption.value);

		topicModel.setNumIterations(numIterationsOption.value);
		topicModel.setOptimizeInterval(optimizeIntervalOption.value);
		topicModel.setBurninPeriod(optimizeBurnInOption.value);

		if (outputStateIntervalOption.value != 0) {
			topicModel.setSaveState(outputStateIntervalOption.value, stateFile.value);
		}

		if (outputModelIntervalOption.value != 0) {
			topicModel.setModelOutput(outputModelIntervalOption.value, outputModelFilename.value);
		}

		topicModel.estimate();

		if (topicKeysFile.value != null) {
			topicModel.printTopWords(new File(topicKeysFile.value), topWordsOption.value, false);
		}

		if (stateFile.value != null) {
			topicModel.printState (new File(stateFile.value));
		}

		if (docTopicsFile.value != null) {
			PrintWriter out = new PrintWriter (new FileWriter ((new File(docTopicsFile.value))));
			topicModel.printDocumentTopics(out, docTopicsThreshold.value, docTopicsMax.value);
			out.close();
		}

		if (inferencerFilename.value != null) {
			try {
				for (int language = 0; language < topicModel.numLanguages; language++) {

					ObjectOutputStream oos =
						new ObjectOutputStream(new FileOutputStream(inferencerFilename.value + "." + language));
					oos.writeObject(topicModel.getInferencer(language));
					oos.close();
				}

			} catch (Exception e) {
				System.err.println(e.getMessage());

			}

		}

		if (outputModelFilename.value != null) {
			assert (topicModel != null);
			
			topicModel.write(new File(outputModelFilename.value));
		}

	}
	
}



