cc.mallet.topics.LabeledLDA Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
package cc.mallet.topics;
import java.util.*;
import java.util.logging.*;
import java.util.zip.*;
import java.io.*;
import cc.mallet.types.*;
import cc.mallet.util.*;
import cc.mallet.pipe.iterator.DBInstanceIterator;
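/**
 * LabeledLDA: a topic model with one topic per label, in which each token of a
 * document may only be assigned a topic drawn from that document's own labels.
 *
 * @author David Mimno
 */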
public class LabeledLDA implements Serializable {
// Shared logger for progress (info) and problem (warning) messages.
protected static Logger logger = MalletLogger.getLogger(LabeledLDA.class.getName());
// ---------------------------------------------------------------------
// Command-line options, registered against LabeledLDA.class and parsed
// by CommandOption.process in main.
// ---------------------------------------------------------------------

// Input options
static CommandOption.String inputFile =
    new CommandOption.String(LabeledLDA.class, "input", "FILENAME", true, null,
        "The filename from which to read the list of training instances. Use - for stdin. " +
        "The instances must be FeatureSequence, not FeatureVector", null);

// NOTE(review): not read anywhere in main as shown in this file.
static CommandOption.String outputPrefix =
    new CommandOption.String(LabeledLDA.class, "output-prefix", "STRING", true, null,
        "The prefix for output files (sampling states, parameters, etc.). " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.String inputModelFilename =
    new CommandOption.String(LabeledLDA.class, "input-model", "FILENAME", true, null,
        "The filename from which to read the binary topic model. The --input option is ignored. " +
        "By default this is null, indicating that no file will be read.", null);

static CommandOption.String inputStateFilename =
    new CommandOption.String(LabeledLDA.class, "input-state", "FILENAME", true, null,
        "The filename from which to read the gzipped Gibbs sampling state created by --output-state. " +
        "The original input file must be included, using --input. " +
        "By default this is null, indicating that no file will be read.", null);

// Model output options
static CommandOption.String outputModelFilename =
    new CommandOption.String(LabeledLDA.class, "output-model", "FILENAME", true, null,
        "The filename in which to write the binary topic model at the end of the iterations. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.String stateFile =
    new CommandOption.String(LabeledLDA.class, "output-state", "FILENAME", true, null,
        "The filename in which to write the Gibbs sampling state at the end of the iterations. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.Integer outputModelInterval =
    new CommandOption.Integer(LabeledLDA.class, "output-model-interval", "INTEGER", true, 0,
        "The number of iterations between writing the model (and its Gibbs sampling state) to a binary file. " +
        "You must also set the --output-model to use this option, whose argument will be the prefix of the filenames.", null);

static CommandOption.Integer outputStateInterval =
    new CommandOption.Integer(LabeledLDA.class, "output-state-interval", "INTEGER", true, 0,
        "The number of iterations between writing the sampling state to a text file. " +
        "You must also set the --output-state to use this option, whose argument will be the prefix of the filenames.", null);

static CommandOption.String inferencerFilename =
    new CommandOption.String(LabeledLDA.class, "inferencer-filename", "FILENAME", true, null,
        "A topic inferencer applies a previously trained topic model to new documents. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.String evaluatorFilename =
    new CommandOption.String(LabeledLDA.class, "evaluator-filename", "FILENAME", true, null,
        "A held-out likelihood evaluator for new documents. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.String topicKeysFile =
    new CommandOption.String(LabeledLDA.class, "output-topic-keys", "FILENAME", true, null,
        "The filename in which to write the top words for each topic and any Dirichlet parameters. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.Integer numTopWords =
    new CommandOption.Integer(LabeledLDA.class, "num-top-words", "INTEGER", true, 20,
        "The number of most probable words to print for each topic after model estimation.", null);

static CommandOption.Integer showTopicsIntervalOption =
    new CommandOption.Integer(LabeledLDA.class, "show-topics-interval", "INTEGER", true, 50,
        "The number of iterations between printing a brief summary of the topics so far.", null);

static CommandOption.String topicWordWeightsFile =
    new CommandOption.String(LabeledLDA.class, "topic-word-weights-file", "FILENAME", true, null,
        "The filename in which to write unnormalized weights for every topic and word type. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.String wordTopicCountsFile =
    new CommandOption.String(LabeledLDA.class, "word-topic-counts-file", "FILENAME", true, null,
        "The filename in which to write a sparse representation of topic-word assignments. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.String diagnosticsFile =
    new CommandOption.String(LabeledLDA.class, "diagnostics-file", "FILENAME", true, null,
        "The filename in which to write measures of topic quality, in XML format. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.String topicReportXMLFile =
    new CommandOption.String(LabeledLDA.class, "xml-topic-report", "FILENAME", true, null,
        "The filename in which to write the top words for each topic and any Dirichlet parameters in XML format. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.String topicPhraseReportXMLFile =
    new CommandOption.String(LabeledLDA.class, "xml-topic-phrase-report", "FILENAME", true, null,
        "The filename in which to write the top words and phrases for each topic and any Dirichlet parameters in XML format. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.String topicDocsFile =
    new CommandOption.String(LabeledLDA.class, "output-topic-docs", "FILENAME", true, null,
        "The filename in which to write the most prominent documents for each topic, at the end of the iterations. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.Integer numTopDocs =
    new CommandOption.Integer(LabeledLDA.class, "num-top-docs", "INTEGER", true, 100,
        "When writing topic documents with --output-topic-docs, " +
        "report this number of top documents.", null);

static CommandOption.String docTopicsFile =
    new CommandOption.String(LabeledLDA.class, "output-doc-topics", "FILENAME", true, null,
        "The filename in which to write the topic proportions per document, at the end of the iterations. " +
        "By default this is null, indicating that no file will be written.", null);

static CommandOption.Double docTopicsThreshold =
    new CommandOption.Double(LabeledLDA.class, "doc-topics-threshold", "DECIMAL", true, 0.0,
        "When writing topic proportions per document with --output-doc-topics, " +
        "do not print topics with proportions less than this threshold value.", null);

static CommandOption.Integer docTopicsMax =
    new CommandOption.Integer(LabeledLDA.class, "doc-topics-max", "INTEGER", true, -1,
        "When writing topic proportions per document with --output-doc-topics, " +
        "do not print more than INTEGER number of topics. "+
        "A negative value indicates that all topics should be printed.", null);

// Model parameters
static CommandOption.Integer numIterationsOption =
    new CommandOption.Integer(LabeledLDA.class, "num-iterations", "INTEGER", true, 1000,
        "The number of iterations of Gibbs sampling.", null);

static CommandOption.Boolean noInference =
    new CommandOption.Boolean(LabeledLDA.class, "no-inference", "true|false", false, false,
        "Do not perform inference, just load a saved model and create a report. Equivalent to --num-iterations 0.", null);

static CommandOption.Integer randomSeed =
    new CommandOption.Integer(LabeledLDA.class, "random-seed", "INTEGER", true, 0,
        "The random seed for the Gibbs sampler. Default is 0, which will use the clock.", null);

// Hyperparameters
static CommandOption.Double alphaOption =
    new CommandOption.Double(LabeledLDA.class, "alpha", "DECIMAL", true, 0.1,
        "Alpha parameter: smoothing over doc topic distribution (NOT the sum over topics).",null);

static CommandOption.Double betaOption =
    new CommandOption.Double(LabeledLDA.class, "beta", "DECIMAL", true, 0.01,
        "Beta parameter: smoothing over word distributions.",null);
// The training instances and their topic assignments (one entry per document).
protected ArrayList<TopicAssignment> data;

// The alphabet for the input data (word types).
protected Alphabet alphabet;

// This alphabet stores the string meanings of the labels/topics.
protected Alphabet labelAlphabet;

// The alphabet for the topics.
protected LabelAlphabet topicAlphabet;

// The number of topics; addInstances creates one topic per label.
protected int numTopics;

// The size of the vocabulary.
protected int numTypes;

// Prior parameters
protected double alpha;   // Dirichlet(alpha,alpha,...) is the distribution over topics
protected double beta;    // Prior on per-topic multinomial distribution over words
protected double betaSum; // beta * numTypes, computed in addInstances
public static final double DEFAULT_BETA = 0.01;

// An array to put the topic counts for the current document.
// Initialized locally below. Defined here to avoid
// garbage collection overhead.
// NOTE(review): allocated in addInstances, but the sampler in this file uses a local array.
protected int[] oneDocTopicCounts; // indexed by <topic index>

// Statistics needed for sampling.
protected int[][] typeTopicCounts; // indexed by <feature index, topic index>
protected int[] tokensPerTopic;    // indexed by <topic index>

// Sampling and reporting settings.
public int numIterations = 1000;
public int showTopicsInterval = 50;
public int wordsPerTopic = 10;

protected Randoms random;
protected boolean printLogLikelihood = false;
/**
 * Creates an empty Labeled LDA model; call addInstances to load training data.
 *
 * @param alpha smoothing over each document's topic distribution
 *              (a per-topic value, not the sum over topics)
 * @param beta  smoothing over each topic's word distribution
 */
public LabeledLDA (double alpha, double beta) {
    this.data = new ArrayList<TopicAssignment>();
    this.alpha = alpha;
    this.beta = beta;
    this.random = new Randoms();

    logger.info("Labeled LDA");
}
public Alphabet getAlphabet() { return alphabet; }
public LabelAlphabet getTopicAlphabet() { return topicAlphabet; }
public ArrayList getData() { return data; }
public void setTopicDisplay(int interval, int n) {
this.showTopicsInterval = interval;
this.wordsPerTopic = n;
public void setRandomSeed(int seed) {
random = new Randoms(seed);
public void setNumIterations (int numIterations) {
this.numIterations = numIterations;
public int[][] getTypeTopicCounts() { return typeTopicCounts; }
public int[] getTopicTotals() { return tokensPerTopic; }
public void addInstances (InstanceList training) {
alphabet = training.getDataAlphabet();
numTypes = alphabet.size();
betaSum = beta * numTypes;
// We have one topic for every possible label.
labelAlphabet = training.getTargetAlphabet();
numTopics = labelAlphabet.size();
oneDocTopicCounts = new int[numTopics];
tokensPerTopic = new int[numTopics];
typeTopicCounts = new int[numTypes][numTopics];
topicAlphabet = AlphabetFactory.labelAlphabetOfSize(numTopics);
int doc = 0;
for (Instance instance : training) {
FeatureSequence tokens = (FeatureSequence) instance.getData();
FeatureVector labels = (FeatureVector) instance.getTarget();
LabelSequence topicSequence =
new LabelSequence(topicAlphabet, new int[ tokens.size() ]);
int[] topics = topicSequence.getFeatures();
for (int position = 0; position < tokens.size(); position++) {
int topic = labels.indexAtLocation(random.nextInt(labels.numLocations()));
topics[position] = topic;
int type = tokens.getIndexAtPosition(position);
TopicAssignment t = new TopicAssignment (instance, topicSequence);
data.add (t);
public void initializeFromState(File stateFile) throws IOException {
String line;
String[] fields;
BufferedReader reader = new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(stateFile))));
line = reader.readLine();
// Skip some lines starting with "#" that describe the format and specify hyperparameters
while (line.startsWith("#")) {
line = reader.readLine();
fields = line.split(" ");
for (TopicAssignment document: data) {
FeatureSequence tokens = (FeatureSequence) document.instance.getData();
FeatureSequence topicSequence = (FeatureSequence) document.topicSequence;
int[] topics = topicSequence.getFeatures();
for (int position = 0; position < tokens.size(); position++) {
int type = tokens.getIndexAtPosition(position);
if (type == Integer.parseInt(fields[3])) {
int topic = Integer.parseInt(fields[5]);
topics[position] = topic;
// This is the difference between the dense type-topic representation used here
// and the sparse version used in ParallelTopicModel.
else {
System.err.println("instance list and state do not match: " + line);
throw new IllegalStateException();
line = reader.readLine();
if (line != null) {
fields = line.split(" ");
public void estimate() throws IOException {
for (int iteration = 1; iteration <= numIterations; iteration++) {
long iterationStart = System.currentTimeMillis();
// Loop over every document in the corpus
for (int doc = 0; doc < data.size(); doc++) {
FeatureSequence tokenSequence =
(FeatureSequence) data.get(doc).instance.getData();
FeatureVector labels = (FeatureVector) data.get(doc).instance.getTarget();
LabelSequence topicSequence =
(LabelSequence) data.get(doc).topicSequence;
sampleTopicsForOneDoc (tokenSequence, labels, topicSequence);
long elapsedMillis = System.currentTimeMillis() - iterationStart;
logger.info(iteration + "\t" + elapsedMillis + "ms\t");
// Occasionally print more information
if (showTopicsInterval != 0 && iteration % showTopicsInterval == 0) {
logger.info("<" + iteration + "> Log Likelihood: " + modelLogLikelihood() + "\n" +
topWords (wordsPerTopic));
protected void sampleTopicsForOneDoc (FeatureSequence tokenSequence,
FeatureVector labels,
FeatureSequence topicSequence) {
int[] possibleTopics = labels.getIndices();
int numLabels = labels.numLocations();
int[] oneDocTopics = topicSequence.getFeatures();
int[] currentTypeTopicCounts;
int type, oldTopic, newTopic;
double topicWeightsSum;
int docLength = tokenSequence.getLength();
int[] localTopicCounts = new int[numTopics];
// populate topic counts
for (int position = 0; position < docLength; position++) {
double score, sum;
double[] topicTermScores = new double[numLabels];
// Iterate over the positions (words) in the document
for (int position = 0; position < docLength; position++) {
type = tokenSequence.getIndexAtPosition(position);
oldTopic = oneDocTopics[position];
// Grab the relevant row from our two-dimensional array
currentTypeTopicCounts = typeTopicCounts[type];
// Remove this token from all counts.
assert(tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
// Now calculate and add up the scores for each topic for this word
sum = 0.0;
// Here's where the math happens! Note that overall performance is
// dominated by what you do in this loop.
for (int labelPosition = 0; labelPosition < numLabels; labelPosition++) {
int topic = possibleTopics[labelPosition];
score =
(alpha + localTopicCounts[topic]) *
((beta + currentTypeTopicCounts[topic]) /
(betaSum + tokensPerTopic[topic]));
sum += score;
topicTermScores[labelPosition] = score;
// Choose a random point between 0 and the sum of all topic scores
double sample = random.nextUniform() * sum;
// Figure out which topic contains that point
int labelPosition = -1;
while (sample > 0.0) {
sample -= topicTermScores[labelPosition];
// Make sure we actually sampled a topic
if (labelPosition == -1) {
throw new IllegalStateException ("LabeledLDA: New topic not sampled.");
newTopic = possibleTopics[labelPosition];
// Put that new topic into the counts
oneDocTopics[position] = newTopic;
public double modelLogLikelihood() {
double logLikelihood = 0.0;
int nonZeroTopics;
// The likelihood of the model is a combination of a
// Dirichlet-multinomial for the words in each topic
// and a Dirichlet-multinomial for the topics in each
// document.
// The likelihood function of a dirichlet multinomial is
// Gamma( sum_i alpha_i ) prod_i Gamma( alpha_i + N_i )
// prod_i Gamma( alpha_i ) Gamma( sum_i (alpha_i + N_i) )
// So the log likelihood is
// logGamma ( sum_i alpha_i ) - logGamma ( sum_i (alpha_i + N_i) ) +
// sum_i [ logGamma( alpha_i + N_i) - logGamma( alpha_i ) ]
// Do the documents first
int[] topicCounts = new int[numTopics];
double[] topicLogGammas = new double[numTopics];
int[] docTopics;
for (int topic=0; topic < numTopics; topic++) {
topicLogGammas[ topic ] = Dirichlet.logGamma( alpha );
for (int doc=0; doc < data.size(); doc++) {
LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;
FeatureVector labels = (FeatureVector) data.get(doc).instance.getTarget();
docTopics = topicSequence.getFeatures();
for (int token=0; token < docTopics.length; token++) {
topicCounts[ docTopics[token] ]++;
for (int topic=0; topic < numTopics; topic++) {
if (topicCounts[topic] > 0) {
logLikelihood += (Dirichlet.logGamma(alpha + topicCounts[topic]) -
topicLogGammas[ topic ]);
// add the parameter sum term
logLikelihood += Dirichlet.logGamma(alpha * labels.numLocations());
// subtract the (count + parameter) sum term
logLikelihood -= Dirichlet.logGamma(alpha * labels.numLocations() + docTopics.length);
Arrays.fill(topicCounts, 0);
// And the topics
// Count the number of type-topic pairs
int nonZeroTypeTopics = 0;
for (int type=0; type < numTypes; type++) {
// reuse this array as a pointer
topicCounts = typeTopicCounts[type];
for (int topic = 0; topic < numTopics; topic++) {
if (topicCounts[topic] == 0) { continue; }
logLikelihood += Dirichlet.logGamma(beta + topicCounts[topic]);
if (Double.isNaN(logLikelihood)) {
for (int topic=0; topic < numTopics; topic++) {
logLikelihood -=
Dirichlet.logGamma( (beta * numTopics) +
tokensPerTopic[ topic ] );
if (Double.isNaN(logLikelihood)) {
System.out.println("after topic " + topic + " " + tokensPerTopic[ topic ]);
logLikelihood +=
(Dirichlet.logGamma(beta * numTopics)) -
(Dirichlet.logGamma(beta) * nonZeroTypeTopics);
if (Double.isNaN(logLikelihood)) {
System.out.println("at the end");
return logLikelihood;
// Methods for displaying and saving results
public String topWords (int numWords) {
StringBuilder output = new StringBuilder();
IDSorter[] sortedWords = new IDSorter[numTypes];
for (int topic = 0; topic < numTopics; topic++) {
if (tokensPerTopic[topic] == 0) { continue; }
for (int type = 0; type < numTypes; type++) {
sortedWords[type] = new IDSorter(type, typeTopicCounts[type][topic]);
output.append(topic + "\t" + labelAlphabet.lookupObject(topic) + "\t" + tokensPerTopic[topic] + "\t");
for (int i=0; i < numWords; i++) {
if (sortedWords[i].getWeight() == 0) { break; }
output.append(alphabet.lookupObject(sortedWords[i].getID()) + " ");
return output.toString();
// Serialization
// Serialization constants.
// Sentinel integer value; not referenced in the visible part of this file.
private static final int NULL_INTEGER = -1;
// Version number of the custom serialized form handled by writeObject/readObject.
private static final int CURRENT_SERIAL_VERSION = 0;
// Java serialization class version.
private static final long serialVersionUID = 1L;
public void write (File f) {
try {
ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(f));
catch (IOException e) {
System.err.println("Exception writing file " + f + ": " + e);
public static LabeledLDA read (File f) throws Exception {
LabeledLDA topicModel = null;
ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f));
topicModel = (LabeledLDA) ois.readObject();
return topicModel;
private void writeObject (ObjectOutputStream out) throws IOException {
// Instance lists
out.writeObject (data);
out.writeObject (alphabet);
out.writeObject (topicAlphabet);
out.writeInt (numTopics);
out.writeObject (alpha);
out.writeDouble (beta);
out.writeDouble (betaSum);
out.writeObject (typeTopicCounts);
for (int ti = 0; ti < numTopics; ti++) {
out.writeInt (tokensPerTopic[ti]);
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
int featuresLength;
int version = in.readInt ();
data = (ArrayList) in.readObject ();
alphabet = (Alphabet) in.readObject();
topicAlphabet = (LabelAlphabet) in.readObject();
numTopics = in.readInt();
alpha = in.readDouble();
beta = in.readDouble();
betaSum = in.readDouble();
showTopicsInterval = in.readInt();
wordsPerTopic = in.readInt();
random = (Randoms) in.readObject();
printLogLikelihood = in.readBoolean();
int numDocs = data.size();
this.numTypes = alphabet.size();
typeTopicCounts = (int[][]) in.readObject();
tokensPerTopic = new int[numTopics];
for (int ti = 0; ti < numTopics; ti++) {
tokensPerTopic[ti] = in.readInt();
/**
 * Command-line entry point. The method proceeds in four steps:
 * <ol>
 * <li>It parses the options declared on this class.</li>
 * <li>It either reads a serialized model (--input-model) or creates a new one
 * from --alpha/--beta.</li>
 * <li>It optionally loads training instances (--input; a "db:" prefix selects
 * DBInstanceIterator) and a saved sampling state (--input-state).</li>
 * <li>It writes the requested reports and files through a ParallelTopicModel
 * that shares this model's data.</li>
 * </ol>
 *
 * NOTE(review): several branches in this copy of the method contain no visible
 * statements; see the NOTE comments below and confirm against the full source.
 */
public static void main (String[] args) throws Exception {
CommandOption.setSummary (LabeledLDA.class,
"Sample associations between words and labels");
CommandOption.process (LabeledLDA.class, args);
LabeledLDA labeledLDA;
if (inputModelFilename.value != null) {
labeledLDA = LabeledLDA.read(new File(inputModelFilename.value));
else {
labeledLDA = new LabeledLDA (alphaOption.value, betaOption.value);
// NOTE(review): no statement follows this condition here; confirm the seed is
// actually applied (e.g. labeledLDA.setRandomSeed(randomSeed.value)).
if (randomSeed.value != 0) {
if (inputFile.value != null) {
InstanceList training = null;
try {
if (inputFile.value.startsWith("db:")) {
training = DBInstanceIterator.getInstances(inputFile.value.substring(3));
else {
training = InstanceList.load (new File(inputFile.value));
} catch (Exception e) {
logger.warning("Unable to restore instance list " +
inputFile.value + ": " + e);
// NOTE(review): a failed load only logs a warning, so training can still be
// null when training.size() is called below (NullPointerException).
logger.info("Data loaded.");
if (training.size() > 0 &&
training.get(0) != null) {
Object data = training.get(0).getData();
if (! (data instanceof FeatureSequence)) {
logger.warning("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
// NOTE(review): no labeledLDA.addInstances(training) call is visible in this copy;
// initializeFromState walks labeledLDA.data, which addInstances fills. Confirm.
if (inputStateFilename.value != null) {
logger.info("Initializing from saved state.");
labeledLDA.initializeFromState(new File(inputStateFilename.value));
labeledLDA.setTopicDisplay(showTopicsIntervalOption.value, numTopWords.value);
// NOTE(review): the body of this branch is not visible; presumably it sets the
// iteration count from --num-iterations and calls estimate() -- confirm.
if (! noInference.value()) {
if (topicKeysFile.value != null) {
// NOTE(review): this PrintStream is never written to or closed in the visible code.
PrintStream out = new PrintStream (new File(topicKeysFile.value));
if (outputModelFilename.value != null) {
assert (labeledLDA != null);
try {
// NOTE(review): oos is never closed here, so buffered bytes may not reach the
// file; close it, ideally in a finally block.
ObjectOutputStream oos =
new ObjectOutputStream (new FileOutputStream (outputModelFilename.value));
oos.writeObject (labeledLDA);
} catch (Exception e) {
logger.warning("Couldn't write topic model to filename " + outputModelFilename.value);
// I don't want to directly inherit from ParallelTopicModel
// because the two implementations treat the type-topic counts differently.
// Instead, simulate a standard Parallel Topic Model by copying over
// the appropriate data structures.
// LabeledLDA's alpha is per topic (see --alpha), hence the multiplication by
// numTopics; presumably this constructor argument is the sum over topics -- confirm.
ParallelTopicModel topicModel = new ParallelTopicModel(labeledLDA.topicAlphabet, labeledLDA.alpha * labeledLDA.numTopics, labeledLDA.beta);
topicModel.data = labeledLDA.data;
topicModel.alphabet = labeledLDA.alphabet;
topicModel.numTypes = labeledLDA.numTypes;
topicModel.betaSum = labeledLDA.betaSum;
// NOTE(review): no type-topic counts or per-topic totals are copied or rebuilt in the
// visible code; confirm topicModel's counts are built before the reports below.
// NOTE(review): the PrintWriters opened below are not visibly flushed or closed,
// so their output can be lost.
if (diagnosticsFile.value != null) {
PrintWriter out = new PrintWriter(diagnosticsFile.value);
TopicModelDiagnostics diagnostics = new TopicModelDiagnostics(topicModel, numTopWords.value);
if (topicReportXMLFile.value != null) {
PrintWriter out = new PrintWriter(topicReportXMLFile.value);
topicModel.topicXMLReport(out, numTopWords.value);
if (topicPhraseReportXMLFile.value != null) {
PrintWriter out = new PrintWriter(topicPhraseReportXMLFile.value);
topicModel.topicPhraseXMLReport(out, numTopWords.value);
if (stateFile.value != null && outputStateInterval.value == 0) {
topicModel.printState (new File(stateFile.value));
if (topicDocsFile.value != null) {
PrintWriter out = new PrintWriter (new FileWriter ((new File(topicDocsFile.value))));
topicModel.printTopicDocuments(out, numTopDocs.value);
if (docTopicsFile.value != null) {
PrintWriter out = new PrintWriter (new FileWriter ((new File(docTopicsFile.value))));
// NOTE(review): the threshold == 0.0 branch has no visible statement.
if (docTopicsThreshold.value == 0.0) {
else {
topicModel.printDocumentTopics(out, docTopicsThreshold.value, docTopicsMax.value);
if (topicWordWeightsFile.value != null) {
topicModel.printTopicWordWeights(new File (topicWordWeightsFile.value));
if (wordTopicCountsFile.value != null) {
topicModel.printTypeTopicCounts(new File (wordTopicCountsFile.value));
// NOTE(review): the inferencer and evaluator streams below are opened, but nothing
// is written to them and they are not closed in the visible code.
if (inferencerFilename.value != null) {
try {
ObjectOutputStream oos =
new ObjectOutputStream(new FileOutputStream(inferencerFilename.value));
} catch (Exception e) {
logger.warning("Couldn't create inferencer: " + e.getMessage());
if (evaluatorFilename.value != null) {
try {
ObjectOutputStream oos =
new ObjectOutputStream(new FileOutputStream(evaluatorFilename.value));
} catch (Exception e) {
logger.warning("Couldn't create evaluator: " + e.getMessage());
© 2015 - 2025 Weber Informatics LLC | Privacy Policy