
/**
 * MALLET is a Java-based package for statistical natural language processing,
 * document classification, clustering, topic modeling, information extraction,
 * and other machine learning applications to text.
 */
package cc.mallet.topics;
import cc.mallet.types.*;
import cc.mallet.util.*;
import java.util.Arrays;
import java.io.*;
public class TopicInferencer implements Serializable {
protected int numTopics;
// These values are used to encode type/topic counts as
// count/topic pairs in a single int.
protected int topicMask;
protected int topicBits;
protected int numTypes;
protected double[] alpha;
protected double beta;
protected double betaSum;
protected int[][] typeTopicCounts;
protected int[] tokensPerTopic;
Alphabet alphabet;
protected Randoms random = null;
double smoothingOnlyMass = 0.0;
double[] cachedCoefficients;
public TopicInferencer (int[][] typeTopicCounts, int[] tokensPerTopic, Alphabet alphabet,
double[] alpha, double beta, double betaSum) {
this.tokensPerTopic = tokensPerTopic;
this.typeTopicCounts = typeTopicCounts;
this.alphabet = alphabet;
numTopics = tokensPerTopic.length;
numTypes = typeTopicCounts.length;
if (Integer.bitCount(numTopics) == 1) {
// exact power of 2
topicMask = numTopics - 1;
topicBits = Integer.bitCount(topicMask);
}
else {
// otherwise add an extra bit
topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
topicBits = Integer.bitCount(topicMask);
}
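// For example, with numTopics = 100, Integer.highestOneBit(100) = 64,
// so topicMask = 127 and topicBits = 7. A packed type/topic entry
// stores the count in the high bits and the topic in the low bits:
// entry = (count << topicBits) | topic.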
this.alpha = alpha;
this.beta = beta;
this.betaSum = betaSum;
cachedCoefficients = new double[numTopics];
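// Precompute the document-independent parts of the SparseLDA-style
// bucket decomposition (Yao, Mimno & McCallum, 2009):
//   p(z=t | w)  is proportional to  (alpha_t + n_{t|d}) * (beta + n_{w|t}) / (betaSum + N_t),
// which splits into a smoothing-only mass (alpha * beta terms),
// a document-topic/beta mass, and a topic-term mass.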
for (int topic=0; topic < numTopics; topic++) {
smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);
}
random = new Randoms();
}
public void setRandomSeed(int seed) {
random = new Randoms(seed);
}
/**
* Use Gibbs sampling to infer a topic distribution.
* Topics are initialized to the (or a) most probable topic
* for each token. Using zero iterations returns exactly this
* initial topic distribution.
* This code does not adjust type-topic counts: P(w|t) is clamped.
*/
public double[] getSampledDistribution(Instance instance, int numIterations,
int thinning, int burnIn) {
FeatureSequence tokens = (FeatureSequence) instance.getData();
int docLength = tokens.size();
int[] topics = new int[docLength];
int[] localTopicCounts = new int[numTopics];
int[] localTopicIndex = new int[numTopics];
int type;
int[] currentTypeTopicCounts;
// Initialize all positions to the most common topic
// for that type.
for (int position = 0; position < docLength; position++) {
type = tokens.getIndexAtPosition(position);
// Ignore out of vocabulary terms
if (type < numTypes && typeTopicCounts[type].length != 0) {
currentTypeTopicCounts = typeTopicCounts[type];
// This value should be a topic such that
// no other topic has more tokens of this type
// assigned to it. If for some reason there were
// no tokens of this type in the training data, it
// will default to topic 0, which is no worse than
// random initialization.
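// (Entries of typeTopicCounts[type] are packed as
// (count << topicBits) | topic and kept sorted by count in
// descending order, so masking entry 0 recovers that topic.)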
topics[position] =
currentTypeTopicCounts[0] & topicMask;
localTopicCounts[topics[position]]++;
}
}
// Build an array that densely lists the topics that
// have non-zero counts.
int denseIndex = 0;
for (int topic = 0; topic < numTopics; topic++) {
if (localTopicCounts[topic] != 0) {
localTopicIndex[denseIndex] = topic;
denseIndex++;
}
}
// Record the total number of non-zero topics
int nonZeroTopics = denseIndex;
// Initialize the topic count/beta sampling bucket
double topicBetaMass = 0.0;
// Initialize cached coefficients and the topic/beta
// normalizing constant.
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
int n = localTopicCounts[topic];
// initialize the normalization constant for the (B * n_{t|d}) term
topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);
// update the coefficients for the non-zero topics
cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
}
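// At this point the full sampling normalizer for a token is
// smoothingOnlyMass + topicBetaMass + topicTermMass, where the
// topic-term mass is recomputed per token in the loop below.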
double topicTermMass = 0.0;
double[] topicTermScores = new double[numTopics];
int[] topicTermIndices;
int[] topicTermValues;
int i;
double score;
int oldTopic, newTopic;
double[] result = new double[numTopics];
double sum = 0.0;
for (int iteration = 1; iteration <= numIterations; iteration++) {
// Iterate over the positions (words) in the document
for (int position = 0; position < docLength; position++) {
type = tokens.getIndexAtPosition(position);
// ignore out-of-vocabulary terms
if (type >= numTypes || typeTopicCounts[type].length == 0) { continue; }
oldTopic = topics[position];
currentTypeTopicCounts = typeTopicCounts[type];
// Prepare to sample by adjusting existing counts.
// Note that we do not need to change the smoothing-only
// mass since the denominator is clamped.
topicBetaMass -= beta * localTopicCounts[oldTopic] /
(tokensPerTopic[oldTopic] + betaSum);
// Decrement the local doc/topic counts
localTopicCounts[oldTopic]--;
//assert(localTopicCounts[oldTopic] >= 0);
// Maintain the dense index, if we are deleting
// the old topic
if (localTopicCounts[oldTopic] == 0) {
// First get to the dense location associated with
// the old topic.
denseIndex = 0;
// We know it's in there somewhere, so we don't
// need bounds checking.
while (localTopicIndex[denseIndex] != oldTopic) {
denseIndex++;
}
// shift all remaining dense indices to the left.
while (denseIndex < nonZeroTopics) {
if (denseIndex < localTopicIndex.length - 1) {
localTopicIndex[denseIndex] =
localTopicIndex[denseIndex + 1];
}
denseIndex++;
}
nonZeroTopics --;
} // finished maintaining local topic index
topicBetaMass += beta * localTopicCounts[oldTopic] /
(tokensPerTopic[oldTopic] + betaSum);
// Reset the cached coefficient for this topic
cachedCoefficients[oldTopic] =
(alpha[oldTopic] + localTopicCounts[oldTopic]) /
(tokensPerTopic[oldTopic] + betaSum);
if (cachedCoefficients[oldTopic] <= 0) {
System.out.println("zero or less coefficient: " + oldTopic + " = (" + alpha[oldTopic] + " + " + localTopicCounts[oldTopic] + ") / ( " + tokensPerTopic[oldTopic] + " + " + betaSum + " );");
}
int index = 0;
int currentTopic, currentValue;
boolean alreadyDecremented = false;
topicTermMass = 0.0;
while (index < currentTypeTopicCounts.length &&
currentTypeTopicCounts[index] > 0) {
currentTopic = currentTypeTopicCounts[index] & topicMask;
currentValue = currentTypeTopicCounts[index] >> topicBits;
score =
cachedCoefficients[currentTopic] * currentValue;
topicTermMass += score;
topicTermScores[index] = score;
index++;
}
double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
double origSample = sample;
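// The uniform draw falls into one of three buckets, checked in
// order: topic-term mass, then document-topic/beta mass, then
// the smoothing-only (alpha * beta) mass.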
// Make sure it actually gets set
newTopic = -1;
if (sample < topicTermMass) {
//topicTermCount++;
i = -1;
while (sample > 0) {
i++;
sample -= topicTermScores[i];
}
newTopic = currentTypeTopicCounts[i] & topicMask;
}
else {
sample -= topicTermMass;
if (sample < topicBetaMass) {
//betaTopicCount++;
sample /= beta;
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
sample -= localTopicCounts[topic] /
(tokensPerTopic[topic] + betaSum);
if (sample <= 0.0) {
newTopic = topic;
break;
}
}
}
else {
sample -= topicBetaMass;
sample /= beta;
newTopic = 0;
sample -= alpha[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
while (sample > 0.0) {
newTopic++;
if (newTopic >= numTopics) {
index = 0;
while (index < currentTypeTopicCounts.length &&
currentTypeTopicCounts[index] > 0) {
currentTopic = currentTypeTopicCounts[index] & topicMask;
currentValue = currentTypeTopicCounts[index] >> topicBits;
System.out.println(currentTopic + "\t" + currentValue + "\t" + topicTermScores[index] +
"\t" + cachedCoefficients[currentTopic]);
index++;
}
}
sample -= alpha[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
}
}
}
topics[position] = newTopic;
topicBetaMass -= beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
localTopicCounts[newTopic]++;
// If this is a new topic for this document,
// add the topic to the dense index.
if (localTopicCounts[newTopic] == 1) {
// First find the point where we
// should insert the new topic by going to
// the end (which is the only reason we're keeping
// track of the number of non-zero
// topics) and working backwards
denseIndex = nonZeroTopics;
while (denseIndex > 0 &&
localTopicIndex[denseIndex - 1] > newTopic) {
localTopicIndex[denseIndex] =
localTopicIndex[denseIndex - 1];
denseIndex--;
}
localTopicIndex[denseIndex] = newTopic;
nonZeroTopics++;
}
// update the coefficients for the non-zero topics
cachedCoefficients[newTopic] =
(alpha[newTopic] + localTopicCounts[newTopic]) /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass += beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
}
if (iteration > burnIn &&
(iteration - burnIn) % thinning == 0) {
// Save a sample
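// by accumulating the smoothed counts alpha_t + n_{t|d};
// the totals are normalized after the sampling loop.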
for (int topic=0; topic < numTopics; topic++) {
result[topic] += alpha[topic] + localTopicCounts[topic];
sum += alpha[topic] + localTopicCounts[topic];
}
}
}
// Clean up our mess: reset the coefficients to values with only
// smoothing. The next doc will update its own non-zero topics...
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
cachedCoefficients[topic] =
alpha[topic] / (tokensPerTopic[topic] + betaSum);
}
if (sum == 0.0) {
// Save at least one sample
for (int topic=0; topic < numTopics; topic++) {
result[topic] = alpha[topic] + localTopicCounts[topic];
sum += result[topic];
}
}
// Normalize
for (int topic=0; topic < numTopics; topic++) {
result[topic] /= sum;
}
return result;
}
/**
* Infer topics for the provided instances and
* write distributions to the provided file.
*
* @param instances The instances for which topic distributions will be inferred
* @param distributionsFile The file to which per-document topic proportions are written
* @param numIterations The total number of iterations of sampling per document
* @param thinning The number of iterations between saved samples
* @param burnIn The number of iterations before the first saved sample
* @param threshold The minimum proportion of a given topic that will be written
* @param max The total number of topics to report per document
*/
public void writeInferredDistributions(InstanceList instances,
File distributionsFile,
int numIterations, int thinning, int burnIn,
double threshold, int max) throws IOException {
PrintWriter out = new PrintWriter(distributionsFile);
out.print ("#doc name topic proportion ...\n");
IDSorter[] sortedTopics = new IDSorter[ numTopics ];
for (int topic = 0; topic < numTopics; topic++) {
// Initialize the sorters with dummy values
sortedTopics[topic] = new IDSorter(topic, topic);
}
if (max < 0 || max > numTopics) {
max = numTopics;
}
int doc = 0;
for (Instance instance: instances) {
StringBuilder builder = new StringBuilder();
double[] topicDistribution =
getSampledDistribution(instance, numIterations,
thinning, burnIn);
builder.append(doc);
builder.append("\t");
if (instance.getName() != null) {
builder.append(instance.getName());
}
else {
builder.append("no-name");
}
if (threshold > 0.0) {
for (int topic = 0; topic < numTopics; topic++) {
sortedTopics[topic].set(topic, topicDistribution[topic]);
}
Arrays.sort(sortedTopics);
for (int i = 0; i < max; i++) {
if (sortedTopics[i].getWeight() < threshold) { break; }
builder.append("\t" + sortedTopics[i].getID() +
"\t" + sortedTopics[i].getWeight());
}
}
else {
for (int topic = 0; topic < numTopics; topic++) {
builder.append("\t" + topicDistribution[topic]);
}
}
out.println(builder);
doc++;
}
out.close();
}
// Serialization
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
private static final int NULL_INTEGER = -1;
private void writeObject (ObjectOutputStream out) throws IOException {
out.writeInt (CURRENT_SERIAL_VERSION);
out.writeObject(alphabet);
out.writeInt(numTopics);
out.writeInt(topicMask);
out.writeInt(topicBits);
out.writeInt(numTypes);
out.writeObject(alpha);
out.writeDouble(beta);
out.writeDouble(betaSum);
out.writeObject(typeTopicCounts);
out.writeObject(tokensPerTopic);
out.writeObject(random);
out.writeDouble(smoothingOnlyMass);
out.writeObject(cachedCoefficients);
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
int version = in.readInt ();
alphabet = (Alphabet) in.readObject();
numTopics = in.readInt();
topicMask = in.readInt();
topicBits = in.readInt();
numTypes = in.readInt();
alpha = (double[]) in.readObject();
beta = in.readDouble();
betaSum = in.readDouble();
typeTopicCounts = (int[][]) in.readObject();
tokensPerTopic = (int[]) in.readObject();
random = (Randoms) in.readObject();
smoothingOnlyMass = in.readDouble();
cachedCoefficients = (double[]) in.readObject();
}
public static TopicInferencer read (File f) throws Exception {
// Use try-with-resources so the stream is closed even if deserialization fails
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f))) {
return (TopicInferencer) ois.readObject();
}
}
}
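Below is a minimal usage sketch, not part of the MALLET source above. It assumes a serialized inferencer (e.g. obtained from a trained topic model) and an InstanceList imported with the same pipe as the training data; the file names and sampling parameters here are hypothetical. TopicInferencer.read, getSampledDistribution, setRandomSeed, and writeInferredDistributions are defined in the class above; InstanceList.load is MALLET's helper for reading a serialized instance list.

import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.InstanceList;
import java.io.File;

public class InferTopicsExample {
    public static void main(String[] args) throws Exception {
        // Hypothetical paths to a trained inferencer and to test instances.
        TopicInferencer inferencer = TopicInferencer.read(new File("topic-model.inferencer"));
        InstanceList instances = InstanceList.load(new File("test-instances.mallet"));

        // Optional: fix the random seed for reproducible sampling.
        inferencer.setRandomSeed(42);

        // Sample a topic distribution for the first document:
        // 100 Gibbs iterations, keeping every 10th sample after a burn-in of 10.
        double[] distribution = inferencer.getSampledDistribution(instances.get(0), 100, 10, 10);
        for (int topic = 0; topic < distribution.length; topic++) {
            System.out.printf("topic %d: %.4f%n", topic, distribution[topic]);
        }

        // Or write per-document distributions to a file, reporting topics
        // with proportion >= 0.05 and at most 10 topics per document.
        inferencer.writeInferredDistributions(instances, new File("doc-topics.txt"),
                100, 10, 10, 0.05, 10);
    }
}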