
/*
 * cc.mallet.topics.TopicModelDiagnostics (from the jcore-mallet-2.0.9 artifact)
 *
 * MALLET is a Java-based package for statistical natural language processing,
 * document classification, clustering, topic modeling, information extraction,
 * and other machine learning applications to text.
 */
package cc.mallet.topics;
import java.io.*;
import java.util.*;
import cc.mallet.types.*;
import gnu.trove.*;
public class TopicModelDiagnostics {
int numTopics;
int numTopWords;
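// TWO_PERCENT_INDEX and FIFTY_PERCENT_INDEX are positions within DEFAULT_DOC_PROPORTIONS
// (0.02 and 0.5, respectively), used by the default allocation-ratio diagnostic below.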
public static final int TWO_PERCENT_INDEX = 1;
public static final int FIFTY_PERCENT_INDEX = 6;
public static final double[] DEFAULT_DOC_PROPORTIONS = { 0.01, 0.02, 0.05, 0.1, 0.2, 0.3, 0.5 };
/** All words in sorted order, with counts */
ArrayList<TreeSet<IDSorter>> topicSortedWords;
/** The top N words in each topic in an array for easy access */
String[][] topicTopWords;
ArrayList<TopicScores> diagnostics;
ParallelTopicModel model;
Alphabet alphabet;
int[][][] topicCodocumentMatrices;
int[] numRank1Documents;
int[] numNonZeroDocuments;
int[][] numDocumentsAtProportions;
// This quantity is used in entropy calculation
double[] sumCountTimesLogCount;
int[] wordTypeCounts;
int numTokens = 0;
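// Build the full set of diagnostics for a trained model, evaluated over the
// top numTopWords words of each topic.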
public TopicModelDiagnostics (ParallelTopicModel model, int numTopWords) {
numTopics = model.getNumTopics();
this.numTopWords = numTopWords;
this.model = model;
alphabet = model.getAlphabet();
topicSortedWords = model.getSortedWords();
topicTopWords = new String[numTopics][numTopWords];
numRank1Documents = new int[numTopics];
numNonZeroDocuments = new int[numTopics];
numDocumentsAtProportions = new int[numTopics][ DEFAULT_DOC_PROPORTIONS.length ];
sumCountTimesLogCount = new double[numTopics];
diagnostics = new ArrayList<TopicScores>();
for (int topic = 0; topic < numTopics; topic++) {
int position = 0;
TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);
// How many words should we report? Some topics may have fewer than
// the default number of words with non-zero weight.
int limit = numTopWords;
if (sortedWords.size() < numTopWords) { limit = sortedWords.size(); }
Iterator<IDSorter> iterator = sortedWords.iterator();
for (int i=0; i < limit; i++) {
IDSorter info = iterator.next();
topicTopWords[topic][i] = (String) alphabet.lookupObject(info.getID());
}
}
collectDocumentStatistics();
diagnostics.add(getTokensPerTopic(model.tokensPerTopic));
diagnostics.add(getDocumentEntropy(model.tokensPerTopic));
diagnostics.add(getWordLengthScores());
diagnostics.add(getCoherence());
diagnostics.add(getDistanceFromUniform());
diagnostics.add(getDistanceFromCorpus());
diagnostics.add(getEffectiveNumberOfWords());
diagnostics.add(getTokenDocumentDiscrepancies());
diagnostics.add(getRank1Percent());
diagnostics.add(getDocumentPercentRatio(FIFTY_PERCENT_INDEX, TWO_PERCENT_INDEX));
diagnostics.add(getDocumentPercent(5));
diagnostics.add(getExclusivity());
}
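// One pass over the training data to gather the per-topic statistics the diagnostics
// need: co-document matrices for the top words, counts of documents containing each
// topic, counts of documents where the topic is the most frequent ("rank 1") topic,
// counts of documents at fixed proportion thresholds, and the sum of n*log(n) terms
// used by the document-entropy score.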
public void collectDocumentStatistics () {
topicCodocumentMatrices = new int[numTopics][numTopWords][numTopWords];
wordTypeCounts = new int[alphabet.size()];
numTokens = 0;
// This is an array of hash sets containing the words-of-interest for each topic,
// used for checking if the word at some position is one of those words.
TIntHashSet[] topicTopWordIndices = new TIntHashSet[numTopics];
// The same as the topic top words, but with int indices instead of strings,
// used for iterating over positions.
int[][] topicWordIndicesInOrder = new int[numTopics][numTopWords];
// This is an array of hash sets that will hold the words-of-interest present in a document,
// which will be cleared after every document.
TIntHashSet[] docTopicWordIndices = new TIntHashSet[numTopics];
int numDocs = model.getData().size();
// The count of each topic, again cleared after every document.
int[] topicCounts = new int[numTopics];
for (int topic = 0; topic < numTopics; topic++) {
TIntHashSet wordIndices = new TIntHashSet();
for (int i = 0; i < numTopWords; i++) {
if (topicTopWords[topic][i] != null) {
int type = alphabet.lookupIndex(topicTopWords[topic][i]);
topicWordIndicesInOrder[topic][i] = type;
wordIndices.add(type);
}
}
topicTopWordIndices[topic] = wordIndices;
docTopicWordIndices[topic] = new TIntHashSet();
}
int doc = 0;
for (TopicAssignment document: model.getData()) {
FeatureSequence tokens = (FeatureSequence) document.instance.getData();
FeatureSequence topics = (FeatureSequence) document.topicSequence;
for (int position = 0; position < tokens.size(); position++) {
int type = tokens.getIndexAtPosition(position);
int topic = topics.getIndexAtPosition(position);
numTokens++;
wordTypeCounts[type]++;
topicCounts[topic]++;
if (topicTopWordIndices[topic].contains(type)) {
docTopicWordIndices[topic].add(type);
}
}
int docLength = tokens.size();
if (docLength > 0) {
int maxTopic = -1;
int maxCount = -1;
for (int topic = 0; topic < numTopics; topic++) {
if (topicCounts[topic] > 0) {
numNonZeroDocuments[topic]++;
if (topicCounts[topic] > maxCount) {
maxTopic = topic;
maxCount = topicCounts[topic];
}
sumCountTimesLogCount[topic] += topicCounts[topic] * Math.log(topicCounts[topic]);
double proportion = (model.alpha[topic] + topicCounts[topic]) / (model.alphaSum + docLength);
for (int i = 0; i < DEFAULT_DOC_PROPORTIONS.length; i++) {
if (proportion < DEFAULT_DOC_PROPORTIONS[i]) { break; }
numDocumentsAtProportions[topic][i]++;
}
TIntHashSet supportedWords = docTopicWordIndices[topic];
int[] indices = topicWordIndicesInOrder[topic];
for (int i = 0; i < numTopWords; i++) {
if (supportedWords.contains(indices[i])) {
for (int j = i; j < numTopWords; j++) {
if (i == j) {
// Diagonals are total number of documents with word W in topic T
topicCodocumentMatrices[topic][i][i]++;
}
else if (supportedWords.contains(indices[j])) {
topicCodocumentMatrices[topic][i][j]++;
topicCodocumentMatrices[topic][j][i]++;
}
}
}
}
docTopicWordIndices[topic].clear();
topicCounts[topic] = 0;
}
}
if (maxTopic > -1) {
numRank1Documents[maxTopic]++;
}
}
doc++;
}
}
public int[][] getCodocumentMatrix(int topic) {
return topicCodocumentMatrices[topic];
}
public TopicScores getTokensPerTopic(int[] tokensPerTopic) {
TopicScores scores = new TopicScores("tokens", numTopics, numTopWords);
for (int topic = 0; topic < numTopics; topic++) {
scores.setTopicScore(topic, tokensPerTopic[topic]);
}
return scores;
}
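// Document entropy for topic t: H_t = -(1/N_t) * sum_d n_{t,d} log n_{t,d} + log N_t,
// where n_{t,d} is the number of tokens of topic t in document d and N_t is the topic's
// total token count. This is the entropy (in nats) of the topic's tokens over documents;
// low values indicate a topic concentrated in a few documents.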
public TopicScores getDocumentEntropy(int[] tokensPerTopic) {
TopicScores scores = new TopicScores("document_entropy", numTopics, numTopWords);
for (int topic = 0; topic < numTopics; topic++) {
scores.setTopicScore(topic, -sumCountTimesLogCount[topic] / tokensPerTopic[topic] + Math.log(tokensPerTopic[topic]));
}
return scores;
}
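// KL divergence from the uniform distribution over word types:
// KL(p_t || u) = sum_w p_t(w) * log(p_t(w) * V), where p_t(w) = count_{t,w} / N_t and
// V is the vocabulary size. Low values flag vague, near-uniform topics.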
public TopicScores getDistanceFromUniform() {
int[] tokensPerTopic = model.tokensPerTopic;
TopicScores scores = new TopicScores("uniform_dist", numTopics, numTopWords);
scores.wordScoresDefined = true;
int numTypes = alphabet.size();
for (int topic = 0; topic < numTopics; topic++) {
double topicScore = 0.0;
int position = 0;
TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);
for (IDSorter info: sortedWords) {
int type = info.getID();
double count = info.getWeight();
double score = (count / tokensPerTopic[topic]) *
Math.log( (count * numTypes) / tokensPerTopic[topic] );
if (position < numTopWords) {
scores.setTopicWordScore(topic, position, score);
}
topicScore += score;
position++;
}
scores.setTopicScore(topic, topicScore);
}
return scores;
}
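// Effective number of words: 1 / sum_w p_t(w)^2, the inverse Simpson index of the
// topic's word distribution. It ranges from 1 (a single dominant word) up to the
// vocabulary size (a uniform topic).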
public TopicScores getEffectiveNumberOfWords() {
int[] tokensPerTopic = model.tokensPerTopic;
TopicScores scores = new TopicScores("eff_num_words", numTopics, numTopWords);
int numTypes = alphabet.size();
for (int topic = 0; topic < numTopics; topic++) {
double sumSquaredProbabilities = 0.0;
TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);
for (IDSorter info: sortedWords) {
int type = info.getID();
double probability = info.getWeight() / tokensPerTopic[topic];
sumSquaredProbabilities += probability * probability;
}
scores.setTopicScore(topic, 1.0 / sumSquaredProbabilities);
}
return scores;
}
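// KL divergence from the empirical corpus distribution:
// KL(p_t || p_corpus) = sum_w p_t(w) * log( p_t(w) / (c_w / N) ), where c_w is the
// corpus count of word w and N the total number of tokens.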
/** Low-quality topics may be very similar to the global distribution. */
public TopicScores getDistanceFromCorpus() {
int[] tokensPerTopic = model.tokensPerTopic;
TopicScores scores = new TopicScores("corpus_dist", numTopics, numTopWords);
scores.wordScoresDefined = true;
for (int topic = 0; topic < numTopics; topic++) {
double coefficient = (double) numTokens / tokensPerTopic[topic];
double topicScore = 0.0;
int position = 0;
TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);
for (IDSorter info: sortedWords) {
int type = info.getID();
double count = info.getWeight();
double score = (count / tokensPerTopic[topic]) *
Math.log( coefficient * count / wordTypeCounts[type] );
if (position < numTopWords) {
//System.out.println(alphabet.lookupObject(type) + ": " + count + " * " + numTokens + " / " + wordTypeCounts[type] + " * " + tokensPerTopic[topic] + " = " + (coefficient * count / wordTypeCounts[type]));
scores.setTopicWordScore(topic, position, score);
}
topicScore += score;
position++;
}
scores.setTopicScore(topic, topicScore);
}
return scores;
}
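// Jensen-Shannon divergence between two distributions over the topic's top words:
// the token-count distribution (how often each word is assigned to the topic) and the
// document-count distribution (how many documents contain the word assigned to the
// topic). Large values indicate words whose counts come from only a few documents.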
public TopicScores getTokenDocumentDiscrepancies() {
TopicScores scores = new TopicScores("token-doc-diff", numTopics, numTopWords);
scores.wordScoresDefined = true;
for (int topic = 0; topic < numTopics; topic++) {
int[][] matrix = topicCodocumentMatrices[topic];
TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);
double topicScore = 0.0;
double[] wordDistribution = new double[numTopWords];
double[] docDistribution = new double[numTopWords];
double wordSum = 0.0;
double docSum = 0.0;
int position = 0;
Iterator<IDSorter> iterator = sortedWords.iterator();
while (iterator.hasNext() && position < numTopWords) {
IDSorter info = iterator.next();
wordDistribution[position] = info.getWeight();
docDistribution[position] = matrix[position][position];
wordSum += wordDistribution[position];
docSum += docDistribution[position];
position++;
}
for (position = 0; position < numTopWords; position++) {
double p = wordDistribution[position] / wordSum;
double q = docDistribution[position] / docSum;
double meanProb = 0.5 * (p + q);
double score = 0.0;
if (p > 0) {
score += 0.5 * p * Math.log(p / meanProb);
}
if (q > 0) {
score += 0.5 * q * Math.log(q / meanProb);
}
scores.setTopicWordScore(topic, position, score);
topicScore += score;
}
scores.setTopicScore(topic, topicScore);
}
return scores;
}
/** Low-quality topics often have lots of unusually short words. */
public TopicScores getWordLengthScores() {
TopicScores scores = new TopicScores("word-length", numTopics, numTopWords);
scores.wordScoresDefined = true;
for (int topic = 0; topic < numTopics; topic++) {
int total = 0;
for (int position = 0; position < topicTopWords[topic].length; position++) {
if (topicTopWords[topic][position] == null) { break; }
int length = topicTopWords[topic][position].length();
total += length;
scores.setTopicWordScore(topic, position, length);
}
scores.setTopicScore(topic, (double) total / topicTopWords[topic].length);
}
return scores;
}
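// Reports each top word's length as a z-score against the mean length of all topics'
// top words; the topic score accumulates these z-scores. Note that this diagnostic is
// not added to the default list built by the constructor.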
/** Low-quality topics often have lots of unusually short words. */
public TopicScores getWordLengthStandardDeviation() {
TopicScores scores = new TopicScores("word-length-sd", numTopics, numTopWords);
scores.wordScoresDefined = true;
// Get the mean length
double meanLength = 0.0;
int totalWords = 0;
for (int topic = 0; topic < numTopics; topic++) {
for (int position = 0; position < topicTopWords[topic].length; position++) {
// Some topics may not have all N words
if (topicTopWords[topic][position] == null) { break; }
meanLength += topicTopWords[topic][position].length();
totalWords ++;
}
}
meanLength /= totalWords;
// Now calculate the standard deviation
double lengthVariance = 0.0;
for (int topic = 0; topic < numTopics; topic++) {
for (int position = 0; position < topicTopWords[topic].length; position++) {
if (topicTopWords[topic][position] == null) { break; }
int length = topicTopWords[topic][position].length();
lengthVariance += (length - meanLength) * (length - meanLength);
}
}
lengthVariance /= (totalWords - 1);
// Finally produce an overall topic score
double lengthSD = Math.sqrt(lengthVariance);
for (int topic = 0; topic < numTopics; topic++) {
for (int position = 0; position < topicTopWords[topic].length; position++) {
if (topicTopWords[topic][position] == null) { break; }
int length = topicTopWords[topic][position].length();
scores.addToTopicScore(topic, (length - meanLength) / lengthSD);
scores.setTopicWordScore(topic, position, (length - meanLength) / lengthSD);
}
}
return scores;
}
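// Topic coherence in the style of Mimno et al. (2011), "Optimizing Semantic Coherence
// in Topic Models": for each ordered pair of top words (w_m, w_l) with l ranked above m,
// add log( (D(w_m, w_l) + beta) / (D(w_l) + beta) ), where D(.) counts documents in the
// co-document matrix. The per-word score records the minimum pairwise value for that word.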
public TopicScores getCoherence() {
TopicScores scores = new TopicScores("coherence", numTopics, numTopWords);
scores.wordScoresDefined = true;
for (int topic = 0; topic < numTopics; topic++) {
int[][] matrix = topicCodocumentMatrices[topic];
double topicScore = 0.0;
for (int row = 0; row < numTopWords; row++) {
double rowScore = 0.0;
double minScore = 0.0;
for (int col = 0; col < row; col++) {
double score = Math.log( (matrix[row][col] + model.beta) / (matrix[col][col] + model.beta) );
rowScore += score;
if (score < minScore) { minScore = score; }
}
topicScore += rowScore;
scores.setTopicWordScore(topic, row, minScore);
}
scores.setTopicScore(topic, topicScore);
}
return scores;
}
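// Fraction of the documents that contain the topic at all in which it is also the
// single most frequent ("rank 1") topic.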
public TopicScores getRank1Percent() {
TopicScores scores = new TopicScores("rank_1_docs", numTopics, numTopWords);
for (int topic = 0; topic < numTopics; topic++) {
scores.setTopicScore(topic, (double) numRank1Documents[topic] / numNonZeroDocuments[topic]);
}
return scores;
}
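// Ratio of the number of documents whose smoothed topic proportion is at least
// DEFAULT_DOC_PROPORTIONS[numeratorIndex] to the number at least
// DEFAULT_DOC_PROPORTIONS[denominatorIndex]; the constructor uses the 50%-to-2% ratio.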
public TopicScores getDocumentPercentRatio(int numeratorIndex, int denominatorIndex) {
TopicScores scores = new TopicScores("allocation_ratio", numTopics, numTopWords);
if (numeratorIndex >= numDocumentsAtProportions[0].length || denominatorIndex >= numDocumentsAtProportions[0].length) {
System.err.println("Invalid proportion indices (max " + (numDocumentsAtProportions[0].length - 1) + ") : " +
numeratorIndex + ", " + denominatorIndex);
return scores;
}
for (int topic = 0; topic < numTopics; topic++) {
scores.setTopicScore(topic, (double) numDocumentsAtProportions[topic][numeratorIndex] /
numDocumentsAtProportions[topic][denominatorIndex]);
}
return scores;
}
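// Fraction of the documents containing the topic whose smoothed topic proportion is at
// least DEFAULT_DOC_PROPORTIONS[i]; the constructor uses index 5, i.e. a 0.3 threshold.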
public TopicScores getDocumentPercent(int i) {
TopicScores scores = new TopicScores("allocation_count", numTopics, numTopWords);
if (i >= numDocumentsAtProportions[0].length) {
System.err.println("Invalid proportion index (max " + (numDocumentsAtProportions[0].length - 1) + ") : " + i);
return scores;
}
for (int topic = 0; topic < numTopics; topic++) {
scores.setTopicScore(topic, (double) numDocumentsAtProportions[topic][i] / numNonZeroDocuments[topic]);
}
return scores;
}
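// For each top word, the topic's smoothed probability of the word divided by the sum of
// that word's smoothed probabilities across all topics; the topic score is the mean over
// the top numTopWords words. Values near 1 mean the word is nearly exclusive to this
// topic. The inner loop decodes MALLET's packed typeTopicCounts entries, in which the
// low topicBits bits hold the topic id and the remaining high bits hold the count.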
/** Low-quality topics may have words that are also prominent in other topics. */
public TopicScores getExclusivity() {
int[] tokensPerTopic = model.tokensPerTopic;
TopicScores scores = new TopicScores("exclusivity", numTopics, numTopWords);
scores.wordScoresDefined = true;
double sumDefaultProbs = 0.0;
for (int topic = 0; topic < numTopics; topic++) {
sumDefaultProbs += model.beta / (model.betaSum + tokensPerTopic[topic]);
}
for (int topic = 0; topic < numTopics; topic++) {
double topicScore = 0.0;
int position = 0;
TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);
for (IDSorter info: sortedWords) {
int type = info.getID();
double count = info.getWeight();
double sumTypeProbs = sumDefaultProbs;
int[] topicCounts = model.typeTopicCounts[type];
int index = 0;
while (index < topicCounts.length &&
topicCounts[index] > 0) {
int otherTopic = topicCounts[index] & model.topicMask;
int otherCount = topicCounts[index] >> model.topicBits;
// We've already accounted for the smoothing parameter,
// now we need to add the actual count for the non-zero
// topics.
sumTypeProbs += ((double) otherCount) / (model.betaSum + tokensPerTopic[otherTopic]);
index++;
}
double score = ((model.beta + count) / (model.betaSum + tokensPerTopic[topic])) / sumTypeProbs;
scores.setTopicWordScore(topic, position, score);
topicScore += score;
position++;
if (position == numTopWords) {
break;
}
}
scores.setTopicScore(topic, topicScore / numTopWords);
}
return scores;
}
public String toString() {
StringBuilder out = new StringBuilder();
Formatter formatter = new Formatter(out, Locale.US);
for (int topic = 0; topic < numTopics; topic++) {
formatter.format("Topic %d", topic);
for (TopicScores scores: diagnostics) {
formatter.format("\t%s=%.4f", scores.name, scores.scores[topic]);
}
formatter.format("\n");
for (int position = 0; position < topicTopWords[topic].length; position++) {
if (topicTopWords[topic][position] == null) { break; }
formatter.format(" %s", topicTopWords[topic][position]);
for(TopicScores scores: diagnostics) {
if (scores.wordScoresDefined) {
formatter.format("\t%s=%.4f", scores.name, scores.topicWordScores[topic][position]);
}
}
out.append("\n");
}
}
return out.toString();
}
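// Serialize the diagnostics as XML: one <topic> element per topic carrying the
// topic-level scores as attributes, containing one <word> element per top word with
// the per-word scores as attributes.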
public String toXML() {
int[] tokensPerTopic = model.tokensPerTopic;
StringBuilder out = new StringBuilder();
Formatter formatter = new Formatter(out, Locale.US);
out.append("<?xml version='1.0' ?>\n");
out.append("<model>\n");
for (int topic = 0; topic < numTopics; topic++) {
int[][] matrix = topicCodocumentMatrices[topic];
formatter.format("<topic id='%d'", topic);
for (TopicScores scores: diagnostics) {
formatter.format(" %s='%.4f'", scores.name, scores.scores[topic]);
}
formatter.format(">\n");
TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);
// How many words should we report? Some topics may have fewer than
// the default number of words with non-zero weight.
int limit = numTopWords;
if (sortedWords.size() < numTopWords) { limit = sortedWords.size(); }
double cumulativeProbability = 0.0;
Iterator<IDSorter> iterator = sortedWords.iterator();
for (int position=0; position < limit; position++) {
IDSorter info = iterator.next();
double probability = info.getWeight() / tokensPerTopic[topic];
cumulativeProbability += probability;
formatter.format("<word rank='%d' count='%.0f' prob='%.5f' cumulative='%.5f' docs='%d'",
position + 1, info.getWeight(), probability, cumulativeProbability, matrix[position][position]);
for (TopicScores scores: diagnostics) {
if (scores.wordScoresDefined) {
formatter.format(" %s='%.4f'", scores.name, scores.topicWordScores[topic][position]);
}
}
formatter.format(">%s</word>\n", topicTopWords[topic][position].replaceAll("&", "&amp;").replaceAll("<", "&lt;"));
}
out.append("</topic>\n");
}
out.append("</model>\n");
return out.toString();
}
public class TopicScores {
public String name;
public double[] scores;
public double[][] topicWordScores;
/** Some diagnostics have meaningful values for each word, others do not */
public boolean wordScoresDefined = false;
public TopicScores (String name, int numTopics, int numWords) {
this.name = name;
scores = new double[numTopics];
topicWordScores = new double[numTopics][numWords];
}
public void setTopicScore(int topic, double score) {
scores[topic] = score;
}
public void addToTopicScore(int topic, double score) {
scores[topic] += score;
}
public void setTopicWordScore(int topic, int wordPosition, double score) {
topicWordScores[topic][wordPosition] = score;
wordScoresDefined = true;
}
}
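// Command-line entry point: args[0] is a serialized MALLET InstanceList, args[1] the
// number of topics, and (optionally) args[2] a file to which the XML report is written.
// The model is trained with alphaSum 5.0, beta 0.01, and 1000 sampling iterations, and
// diagnostics are computed over the top 20 words of each topic.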
public static void main (String[] args) throws Exception {
InstanceList instances = InstanceList.load(new File(args[0]));
int numTopics = Integer.parseInt(args[1]);
ParallelTopicModel model = new ParallelTopicModel(numTopics, 5.0, 0.01);
model.addInstances(instances);
model.setNumIterations(1000);
model.estimate();
TopicModelDiagnostics diagnostics = new TopicModelDiagnostics(model, 20);
if (args.length == 3) {
PrintWriter out = new PrintWriter(args[2]);
out.println(diagnostics.toXML());
out.close();
}
}
}