All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.topics.WeightedTopicModel Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.topics;

import java.util.*;
import java.util.logging.*;
import java.util.zip.*;
import java.util.regex.*;

import java.io.*;
import java.text.NumberFormat;

import cc.mallet.types.*;
import cc.mallet.util.*;

import gnu.trove.*;

public class WeightedTopicModel implements Serializable {

	private static Logger logger = MalletLogger.getLogger(WeightedTopicModel.class.getName());
	
	static CommandOption.String inputFile = new CommandOption.String
		(WeightedTopicModel.class, "input", "FILENAME", true, null,
		 "The filename from which to read the list of training instances.  Use - for stdin.  " +
		 "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
	
	static CommandOption.String weightsFile = new CommandOption.String
		(WeightedTopicModel.class, "weights-filename", "FILENAME", true, null,
		 "The filename for the word-word weights file.", null);
	
	static CommandOption.String evaluatorFilename = new CommandOption.String
		(WeightedTopicModel.class, "evaluator-filename", "FILENAME", true, null,
		 "A held-out likelihood evaluator for new documents.  " +
		 "By default this is null, indicating that no file will be written.", null);
	
	static CommandOption.String stateFile = new CommandOption.String
		(WeightedTopicModel.class, "state-filename", "FILENAME", true, null,
		 "The filename in which to write the Gibbs sampling state after at the end of the iterations.  " +
		 "By default this is null, indicating that no file will be written.", null);
	
	static CommandOption.Integer numTopicsOption = new CommandOption.Integer
		(WeightedTopicModel.class, "num-topics", "INTEGER", true, 10,
		 "The number of topics to fit.", null);
	
	static CommandOption.Integer numEpochsOption = new CommandOption.Integer
		(WeightedTopicModel.class, "num-epochs", "INTEGER", true, 1,
		 "The number of cycles of training. Evaluators and state files will be saved after each epoch.", null);
	
	static CommandOption.Integer numIterationsOption = new CommandOption.Integer
		(WeightedTopicModel.class, "num-iterations", "INTEGER", true, 1000,
		 "The number of iterations of Gibbs sampling PER EPOCH.", null);
	
	static CommandOption.Integer randomSeedOption = new CommandOption.Integer
		(WeightedTopicModel.class, "random-seed", "INTEGER", true, 0,
		 "The random seed for the Gibbs sampler.  Default is 0, which will use the clock.", null);

	static CommandOption.Double alphaOption = new CommandOption.Double
		(WeightedTopicModel.class, "alpha", "DECIMAL", true, 50.0,
		 "Alpha parameter: smoothing over topic distribution.",null);
	
	static CommandOption.Double betaOption = new CommandOption.Double
		(WeightedTopicModel.class, "beta", "DECIMAL", true, 0.01,
		 "Beta parameter: smoothing over topic distribution.",null);
	

	public static Pattern sourceWordPattern = Pattern.compile("(.*) \\((\\d+)\\)");
	public static Pattern targetWordPattern = Pattern.compile("  (\\d+)\t(\\d+)\t([\\d\\.]+)\t(.*)");

	// the training instances and their topic assignments
	protected ArrayList data;  

	// the alphabet for the input data
	protected Alphabet alphabet; 

	// the alphabet for the topics
	protected LabelAlphabet topicAlphabet; 
	
	// The number of topics requested
	protected int numTopics;

	// The size of the vocabulary
	protected int numTypes;

	// Prior parameters
	protected double alpha;	 // Dirichlet(alpha,alpha,...) is the distribution over topics
	protected double alphaSum;
	
	protected double beta;
	protected double betaSum;

	// An array to put the topic counts for the current document. 
	// Initialized locally below.  Defined here to avoid
	// garbage collection overhead.
	protected int[] oneDocTopicCounts; // indexed by 

	// Statistics needed for sampling.
	protected int[][] typeTopicCounts; // indexed by 
	protected int[] tokensPerTopic; // indexed by 

	// Weights on type-type interactions
	protected TIntDoubleHashMap[] typeTypeWeights;

	protected double[][] logTypeTopicWeights;
	protected double[][] typeTopicWeights;
	protected double[] totalTopicWeights;

	public int showTopicsInterval = 50;
	public int wordsPerTopic = 10;
	
	protected Randoms random;
	protected NumberFormat formatter;
	protected boolean printLogLikelihood = false;

	protected double[] logCountRatioCache;
	
	public WeightedTopicModel (int numberOfTopics, double alphaSum, double beta, Randoms random) {

		this.data = new ArrayList();
		this.topicAlphabet = AlphabetFactory.labelAlphabetOfSize(numberOfTopics);
		this.numTopics = topicAlphabet.size();

		this.alphaSum = alphaSum;
		this.alpha = alphaSum / numTopics;
		this.beta = beta;
		this.random = random;
		
		oneDocTopicCounts = new int[numTopics];
		tokensPerTopic = new int[numTopics];
		
		formatter = NumberFormat.getInstance();
		formatter.setMaximumFractionDigits(5);

		logger.info("Weighted LDA: " + numTopics + " topics");
	}
	
	public Alphabet getAlphabet() { return alphabet; }
	public LabelAlphabet getTopicAlphabet() { return topicAlphabet; }
	public int getNumTopics() { return numTopics; }
	public ArrayList getData() { return data; }
	
	public void setTopicDisplay(int interval, int n) {
		this.showTopicsInterval = interval;
		this.wordsPerTopic = n;
	}

	public void setRandomSeed(int seed) {
		random = new Randoms(seed);
	}
	
	public int[][] getTypeTopicCounts() { return typeTopicCounts; }
	public int[] getTopicTotals() { return tokensPerTopic; }

	public void addInstances (InstanceList training) {

		alphabet = training.getDataAlphabet();
		numTypes = alphabet.size();
		betaSum = beta * numTypes;
		
		typeTopicCounts = new int[numTypes][numTopics];

		typeTopicWeights = new double[numTypes][numTopics];
		totalTopicWeights = new double[numTopics];

		for (int type = 0; type < numTypes; type++) {
			Arrays.fill(typeTopicWeights[type], beta);
		}
		Arrays.fill(totalTopicWeights, betaSum);

		int doc = 0;

		for (Instance instance : training) {
			doc++;

			FeatureSequence tokenSequence = (FeatureSequence) instance.getData();
			LabelSequence topicSequence =
				new LabelSequence(topicAlphabet, new int[ tokenSequence.size() ]);

			TopicAssignment t = new TopicAssignment (instance, topicSequence);
			data.add (t);
		}

	}

	public void readTypeTypeWeights (File weightsFile) throws Exception {
		
		typeTypeWeights = new TIntDoubleHashMap[numTypes];

		logger.info("num types: " + numTypes);

		for (int type = 0; type < numTypes; type++) {
			typeTypeWeights[type] = new TIntDoubleHashMap();
			typeTypeWeights[type].put(type, 1.0);
		}

		int sourceType = 0; // java complains if we don't initialize
		boolean sourceWordValid = true;
		
		BufferedReader reader = new BufferedReader(new FileReader(weightsFile));
		String line;
		while ((line = reader.readLine()) != null) {

			String[] fields = line.split("\t");

			double sum = 0.0;
			for (int i=1; i < fields.length; i += 2) {
				sum += Double.parseDouble(fields[i]);
			}

			sourceType = alphabet.lookupIndex( fields[0] );
			typeTypeWeights[sourceType].put(sourceType, Double.parseDouble(fields[1]) / sum);

			int i = 2;
			while (i < fields.length) {
				int targetType = alphabet.lookupIndex(fields[i]);
				typeTypeWeights[sourceType].put(targetType, Double.parseDouble(fields[i+1]) / sum);
				i += 2;
			}
		}
	}

	public void sample (int iterations, boolean shouldInitialize, int docCycleCount) throws IOException {

		for (int iteration = 1; iteration <= iterations; iteration++) {

			long iterationStart = System.currentTimeMillis();

			// Loop over every document in the corpus
			for (int doc = 0; doc < data.size(); doc++) {
				//			for (int doc = 0; doc < 5000; doc++) {
				FeatureSequence tokenSequence =
					(FeatureSequence) data.get(doc).instance.getData();
				LabelSequence topicSequence =
					(LabelSequence) data.get(doc).topicSequence;
				
				// Run the sampler in initialization mode for 
				//  the first iteration, and show debugging info 
				//  for the first document.
				sampleTopicsForOneDoc (tokenSequence, topicSequence, shouldInitialize && iteration == 1, false);
				
				for (int i = 1; i < docCycleCount; i++) {
					sampleTopicsForOneDoc (tokenSequence, topicSequence, false, false);
				}

				/*
				if ((doc+1) % 1000 == 0) {
					System.out.println(doc + 1);
				}
				*/
			}
		
			long elapsedMillis = System.currentTimeMillis() - iterationStart;
			logger.info(iteration + "\t" + elapsedMillis + "ms\t");

			// Occasionally print more information
			if (showTopicsInterval != 0 && iteration % showTopicsInterval == 0) {
				logger.info("<" + iteration + ">\n" +
							topWords (wordsPerTopic));
			}

		}
	}
	
	protected void sampleTopicsForOneDoc (FeatureSequence tokenSequence,
										  FeatureSequence topicSequence, 
										  boolean initializing, boolean debugging) {

		int[] oneDocTopics = topicSequence.getFeatures();

		int[] currentTypeTopicCounts;
		double[] currentTypeTopicWeights;

		int type, oldTopic, newTopic;
		double topicWeightsSum;
		int docLength = tokenSequence.getLength();

		int[] localTopicCounts = new int[numTopics];

		if (! initializing) {
			//		populate topic counts
			for (int position = 0; position < docLength; position++) {
				localTopicCounts[oneDocTopics[position]]++;
			}
		}

		double score, sum;
		double[] topicTermScores = new double[numTopics];

		//	Iterate over the positions (words) in the document 
		for (int position = 0; position < docLength; position++) {
			type = tokenSequence.getIndexAtPosition(position);
			oldTopic = oneDocTopics[position];

			TIntDoubleHashMap typeFactors = typeTypeWeights[type];
			int[] connectedTypes = typeFactors.keys();

			// Grab the relevant row from our two-dimensional array
			currentTypeTopicCounts = typeTopicCounts[type];
			currentTypeTopicWeights = typeTopicWeights[type];

			if (! initializing) {

				//	Remove this token from all counts. 
				localTopicCounts[oldTopic]--;
				tokensPerTopic[oldTopic]--;
				assert(tokensPerTopic[oldTopic] >= 0);
				currentTypeTopicCounts[oldTopic]--;

				int typeCount, otherTypeCount;
				typeCount = currentTypeTopicCounts[oldTopic]; // already incremented
				
				for (int otherType: connectedTypes) {
					double factor = typeFactors.get(otherType);

					typeTopicWeights[otherType][oldTopic] -= factor;
					totalTopicWeights[oldTopic] -= factor;
				}
			}

			// Now calculate and add up the scores for each topic for this word
			sum = 0.0;
				
			// Here's where the math happens! Note that overall performance is 
			//  dominated by what you do in this loop.
			for (int topic = 0; topic < numTopics; topic++) {
				score =
					(alpha + localTopicCounts[topic]) *
					(currentTypeTopicWeights[topic] / totalTopicWeights[topic]);
				sum += score;
				topicTermScores[topic] = score;
				

				if (debugging && type == 68) {System.out.println(type + "\t" + topic + "\t" + localTopicCounts[topic] + "\t" + currentTypeTopicCounts[topic] + "\t" + currentTypeTopicWeights[topic] + "\t" + tokensPerTopic[topic] + "\t" + sum);}
			}
			
			// Choose a random point between 0 and the sum of all topic scores
			double sample = random.nextUniform() * sum;

			if (debugging) {
				System.out.println("sample " + sample + " / " + sum);
			}

			// Figure out which topic contains that point
			newTopic = -1;
			while (sample > 0.0) {
				newTopic++;
				sample -= topicTermScores[newTopic];
			}

			// Make sure we actually sampled a topic
			if (debugging || newTopic == -1) {

				/*
				System.out.println(alphabet.lookupObject(type));
				for (int topic = 0; topic < numTopics; topic++) {
					System.out.println("(" + alpha + " + " + localTopicCounts[topic] + ") * " +
									   "(" + currentTypeTopicWeights[topic] + " / " + totalTopicWeights[topic] + ") = " +
									   topicTermScores[topic]);
				}
				*/
				
				//throw new IllegalStateException ("WeightedTopicModel: New topic not sampled.");
			}

			// Put that new topic into the counts
			oneDocTopics[position] = newTopic;
			localTopicCounts[newTopic]++;
			tokensPerTopic[newTopic]++;
			currentTypeTopicCounts[newTopic]++;

			//System.out.println(newTopic + "\t" + alphabet.lookupObject(type));

			int typeCount, otherTypeCount;
			typeCount = currentTypeTopicCounts[newTopic]; // already incremented

			for (int otherType: connectedTypes) {
				double factor = typeFactors.get(otherType);
				
				typeTopicWeights[otherType][newTopic] += factor;
				totalTopicWeights[newTopic] += factor;
			}


		}
	}
	
	/*
	public double modelLogLikelihood() {
		double logLikelihood = 0.0;
		int nonZeroTopics;

		// The likelihood of the model is a combination of a 
		// Dirichlet-multinomial for the words in each topic
		// and a Dirichlet-multinomial for the topics in each
		// document.

		// The likelihood function of a dirichlet multinomial is
		//	 Gamma( sum_i alpha_i )	 prod_i Gamma( alpha_i + N_i )
		//	prod_i Gamma( alpha_i )	  Gamma( sum_i (alpha_i + N_i) )

		// So the log likelihood is 
		//	logGamma ( sum_i alpha_i ) - logGamma ( sum_i (alpha_i + N_i) ) + 
		//	 sum_i [ logGamma( alpha_i + N_i) - logGamma( alpha_i ) ]

		// Do the documents first

		int[] topicCounts = new int[numTopics];
		double[] topicLogGammas = new double[numTopics];
		int[] docTopics;

		for (int topic=0; topic < numTopics; topic++) {
			topicLogGammas[ topic ] = Dirichlet.logGamma( alpha );
		}
	
		for (int doc=0; doc < data.size(); doc++) {
			LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;

			docTopics = topicSequence.getFeatures();

			for (int token=0; token < docTopics.length; token++) {
				topicCounts[ docTopics[token] ]++;
			}

			for (int topic=0; topic < numTopics; topic++) {
				if (topicCounts[topic] > 0) {
					logLikelihood += (Dirichlet.logGamma(alpha + topicCounts[topic]) -
									  topicLogGammas[ topic ]);
				}
			}

			// subtract the (count + parameter) sum term
			logLikelihood -= Dirichlet.logGamma(alphaSum + docTopics.length);

			Arrays.fill(topicCounts, 0);
		}
	
		// add the parameter sum term
		logLikelihood += data.size() * Dirichlet.logGamma(alphaSum);

		// And the topics

		// Count the number of type-topic pairs
		int nonZeroTypeTopics = 0;

		for (int type=0; type < numTypes; type++) {
			// reuse this array as a pointer

			topicCounts = typeTopicCounts[type];

			for (int topic = 0; topic < numTopics; topic++) {
				if (topicCounts[topic] == 0) { continue; }
				
				nonZeroTypeTopics++;
				logLikelihood += Dirichlet.logGamma(beta + topicCounts[topic]);

				if (Double.isNaN(logLikelihood)) {
					System.out.println(topicCounts[topic]);
					System.exit(1);
				}
			}
		}
	
		for (int topic=0; topic < numTopics; topic++) {
			logLikelihood -= 
				Dirichlet.logGamma( (beta * numTopics) +
											tokensPerTopic[ topic ] );
			if (Double.isNaN(logLikelihood)) {
				System.out.println("after topic " + topic + " " + tokensPerTopic[ topic ]);
				System.exit(1);
			}

		}
	
		logLikelihood += 
			(Dirichlet.logGamma(beta * numTopics)) -
			(Dirichlet.logGamma(beta) * nonZeroTypeTopics);

		if (Double.isNaN(logLikelihood)) {
			System.out.println("at the end");
			System.exit(1);
		}


		return logLikelihood;
	}
	*/

	// 
	// Methods for displaying and saving results
	//

	public String topWords (int numWords) {

		StringBuilder output = new StringBuilder();

		IDSorter[] sortedWords = new IDSorter[numTypes];

		for (int topic = 0; topic < numTopics; topic++) {
			for (int type = 0; type < numTypes; type++) {
				sortedWords[type] = new IDSorter(type, typeTopicCounts[type][topic]);
			}

			Arrays.sort(sortedWords);
			
			output.append(topic + "\t" + tokensPerTopic[topic] + "\t" + formatter.format(totalTopicWeights[topic]));
			for (int i=0; i < numWords; i++) {
				output.append(alphabet.lookupObject(sortedWords[i].getID()) + " ");
			}
			output.append("\n");
		}

		return output.toString();
	}

	public MarginalProbEstimator getEstimator() {
		// The type-topic counts are "dense", meaning that the index of 
		//  the array element determines its topic. The marginal estimator
		//  uses the sparse "bit encoded" arrays used in ParallelTopicModel,
		//  so we need to convert to that format.

		int topicMask, topicBits;

		if (Integer.bitCount(numTopics) == 1) {
			// exact power of 2
			topicMask = numTopics - 1;
			topicBits = Integer.bitCount(topicMask);
		}
		else {
			// otherwise add an extra bit
			topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
			topicBits = Integer.bitCount(topicMask);
		}

		int[][] sparseTypeTopicCounts = new int[numTypes][];

		for (int type = 0; type < numTypes; type++) {
			
			int[] currentTypeTopicCounts = typeTopicCounts[type];

			// First figure out how many entries we have
			int numNonZeros = 0;
			for (int topic = 0; topic < numTopics; topic++) {
				if (currentTypeTopicCounts[topic] > 0) {
					numNonZeros ++;
				}
			}

			// Allocate the sparse array
			int[] sparseCounts = new int[numNonZeros];

			// And fill it, keeping the array in descending order

			for (int topic = 0; topic < numTopics; topic++) {
				if (currentTypeTopicCounts[topic] > 0) {
					int value = (currentTypeTopicCounts[topic] << topicBits) + topic;
					int i = 0;
					
					// Move values along. Note that java arrays are 
					//  all zeros at initialization.
					while (sparseCounts[i] > value) {
						i++;
					}
					// We've now found where to insert, push along any other values
					while (i < sparseCounts.length && value > sparseCounts[i]) {
						int temp = sparseCounts[i];
						sparseCounts[i] = value;
						value = temp;
						i++;
					}
				}
			}

			// Now add it to the array of arrays
			sparseTypeTopicCounts[type] = sparseCounts;
			
		}

		double[] alphas = new double[ numTopics ];
		Arrays.fill(alphas, alpha);
		return new MarginalProbEstimator(numTopics, alphas, alphaSum, beta,
										 sparseTypeTopicCounts, tokensPerTopic);
	}

	public void printState (File f) throws IOException {
		PrintStream out =
			new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
		printState(out);
		out.close();
	}
	
	public void printState (PrintStream stream) {

		stream.println ("#doc source pos typeindex type topic");

		for (int doc = 0; doc < data.size(); doc++) {
		//for (int doc = 0; doc < 5000; doc++) {
			FeatureSequence tokenSequence =	(FeatureSequence) data.get(doc).instance.getData();
			LabelSequence topicSequence =	(LabelSequence) data.get(doc).topicSequence;

			String source = "NA";

			StringBuilder out = new StringBuilder();
				
			for (int position = 0; position < topicSequence.getLength(); position++) {
				int type = tokenSequence.getIndexAtPosition(position);
				int topic = topicSequence.getIndexAtPosition(position);
				
				out.append(doc); out.append(' ');
				out.append(source); out.append(' '); 
				out.append(position); out.append(' ');
				out.append(type); out.append(' ');
				out.append(alphabet.lookupObject(type)); out.append(' ');
				out.append(topic);
				out.append("\n");
			}

			stream.print(out.toString());
		}
	}
	
	public static void main (String[] args) throws Exception {

		CommandOption.setSummary (WeightedTopicModel.class,
								  "Train topics with weights between word types encoded in the prior");
		CommandOption.process (WeightedTopicModel.class, args);

		InstanceList training = InstanceList.load (new File(inputFile.value));

		Randoms random = null;
		if (randomSeedOption.value != 0) {
			random = new Randoms(randomSeedOption.value);
		}
		else {
			random = new Randoms();
		}

		WeightedTopicModel lda =
			new WeightedTopicModel (numTopicsOption.value, alphaOption.value, betaOption.value, random);
		lda.addInstances(training);
		lda.readTypeTypeWeights(new File(weightsFile.value));

		int docCycleCount = 1;

		for (int epoch = 1; epoch <= numEpochsOption.value; epoch++) {
			lda.sample(numIterationsOption.value, epoch == 1, docCycleCount);
			
			if (stateFile.wasInvoked()) {
				lda.printState(new File(stateFile.value + "." + epoch));
			}
			if (evaluatorFilename.wasInvoked()) {
				try {
					ObjectOutputStream oos = 
						new ObjectOutputStream(new FileOutputStream(evaluatorFilename.value + "." + epoch));
					oos.writeObject(lda.getEstimator());
					oos.close();
				} catch (Exception e) {
					e.printStackTrace();
				}
			}
		}

	}
	
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy