All downloads are free. Search and download functionality uses the official Maven repository.

cc.mallet.topics.HierarchicalLDA Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.topics;

import java.util.ArrayList;
import java.util.Arrays;
import java.io.*;

import cc.mallet.types.*;
import cc.mallet.util.Randoms;

import gnu.trove.*;

public class HierarchicalLDA {

    // Training documents; initialize() requires each instance's data to be a FeatureSequence.
    InstanceList instances;
    // Held-out documents, stored by initialize().
    InstanceList testing;

    // Root of the topic tree. `node` is also used as a scratch reference by initialize().
    NCRPNode rootNode, node;

    int numLevels;    // number of levels (nodes) on every root-to-leaf path
    int numDocuments; // number of training documents, instances.size()
    int numTypes;     // vocabulary size, the size of the data alphabet

    double alpha; // smoothing on topic distributions (added to each document's per-level token counts)
    double gamma; // "imaginary" customers at the next, as yet unused table (nested-CRP new-branch weight)
    double eta;   // smoothing on word distributions
    double etaSum; // eta * numTypes, computed in initialize()

    int[][] levels; // indexed < doc, token >: the level each token is currently assigned to
    NCRPNode[] documentLeaves; // currently selected path (ie leaf node) through the NCRP tree

	// Counter that hands out sequential NCRPNode.nodeID values.
	int totalNodes = 0;

	// File written by printState().
	String stateFile = "hlda.state";

    // Source of randomness for initialization and sampling.
    Randoms random;

	// Whether estimate() prints per-iteration progress markers.
	boolean showProgress = true;
	
	// estimate() prints the topic tree every displayTopicsInterval iterations.
	int displayTopicsInterval = 50;
	// Number of words to show per topic when printing (set via setTopicDisplay).
	int numWordsToDisplay = 10;

    /**
     *  Creates a sampler with the default hyperparameters
     *   alpha = 10.0, gamma = 1.0 and eta = 0.1.
     *  initialize() must be called before estimate().
     */
    public HierarchicalLDA () {
        this.eta = 0.1;
        this.gamma = 1.0;
        this.alpha = 10.0;
    }

	/**
	 *  Sets alpha, the smoothing added to each document's per-level token counts
	 *   when token levels are resampled.
	 *
	 *  @param alpha the new smoothing value
	 */
	public void setAlpha(double alpha) {
		final double smoothing = alpha;
		this.alpha = smoothing;
	}

	/**
	 *  Sets gamma, the weight of the "imaginary" customers at a node's next,
	 *   as yet unused table, i.e. the nested-CRP weight of opening a new branch.
	 *
	 *  @param gamma the new concentration value
	 */
	public void setGamma(double gamma) {
		final double concentration = gamma;
		this.gamma = concentration;
	}

	/**
	 *  Sets eta, the smoothing on each topic's word distribution.
	 *
	 *  The cached total etaSum (eta * numTypes) is recomputed here as well. Without
	 *   that, a call made after initialize() would leave sampling with a new eta but
	 *   the old etaSum. Before initialize(), numTypes is still 0; initialize()
	 *   recomputes etaSum once the vocabulary size is known.
	 *
	 *  @param eta the new smoothing value
	 */
	public void setEta(double eta) {
		this.eta = eta;
		this.etaSum = eta * numTypes;
	}

	/**
	 *  Sets the path of the file that printState() writes the sampling state to
	 *   (default "hlda.state").
	 *
	 *  @param stateFile path of the output file
	 */
	public void setStateFile(String stateFile) {
		final String target = stateFile;
		this.stateFile = target;
	}

	/**
	 *  Configures periodic printing of the topic tree.
	 *
	 *  @param interval estimate() calls printNodes() on iterations that are multiples of this value
	 *  @param words    stored in numWordsToDisplay (presumably the number of words shown per topic)
	 */
	public void setTopicDisplay(int interval, int words) {
		this.numWordsToDisplay = words;
		this.displayTopicsInterval = interval;
	}

	/**
	 *  Controls whether estimate() reports progress. When enabled, a "." is printed
	 *   after every iteration and the iteration number after every 50th iteration.
	 *
	 *  @param showProgress true to print progress markers
	 */
	public void setProgressDisplay(boolean showProgress) {
		final boolean enabled = showProgress;
		this.showProgress = enabled;
	}

    /**
     *  Prepares the sampler for a corpus.
     *
     *  Each document gets a path: the root, then for every deeper level the result of
     *   select() on the node above. Every node on that path counts the document as a
     *   customer. Each token is then assigned a random level (random.nextInt(numLevels))
     *   on that path, and the chosen node's word counts are updated.
     *
     *  Arguments are validated before any field is modified.
     *
     *  @param instances training documents; each instance's data must be a FeatureSequence
     *  @param testing   held-out documents (stored; not used by this method)
     *  @param numLevels number of levels on every root-to-leaf path; must be at least 1
     *  @param random    source of randomness used here and by the samplers
     *  @throws IllegalArgumentException if numLevels &lt; 1, instances is null or empty,
     *   or the first instance's data is not a FeatureSequence
     */
    public void initialize(InstanceList instances, InstanceList testing,
                           int numLevels, Randoms random) {
        if (numLevels < 1) {
            throw new IllegalArgumentException("numLevels must be at least 1, got " + numLevels);
        }
        if (instances == null || instances.size() == 0) {
            throw new IllegalArgumentException("instances must contain at least one document");
        }
        if (! (instances.get(0).getData() instanceof FeatureSequence)) {
            throw new IllegalArgumentException("Input must be a FeatureSequence; use the --feature-sequence option when importing data, for example");
        }

        this.instances = instances;
        this.testing = testing;
        this.numLevels = numLevels;
        this.random = random;

        numDocuments = instances.size();
        numTypes = instances.getDataAlphabet().size();

        // Cached denominator term for the word-distribution smoothing.
        etaSum = eta * numTypes;

        // Scratch array reused for every document's path.
        NCRPNode[] path = new NCRPNode[numLevels];

        rootNode = new NCRPNode(numTypes);

        levels = new int[numDocuments][];
        documentLeaves = new NCRPNode[numDocuments];

        for (int doc = 0; doc < numDocuments; doc++) {
            FeatureSequence fs = (FeatureSequence) instances.get(doc).getData();
            int seqLen = fs.getLength();

            // Choose this document's path and register it as a customer at every node on it.
            path[0] = rootNode;
            rootNode.customers++;
            for (int level = 1; level < numLevels; level++) {
                path[level] = path[level - 1].select();
                path[level].customers++;
            }
            node = path[numLevels - 1];

            levels[doc] = new int[seqLen];
            documentLeaves[doc] = node;

            // Put each token at a random level on the path and count it there.
            for (int token = 0; token < seqLen; token++) {
                int type = fs.getIndexAtPosition(token);
                levels[doc][token] = random.nextInt(numLevels);
                node = path[ levels[doc][token] ];
                node.totalTokens++;
                node.typeCounts[type]++;
            }
        }
    }

	/**
	 *  Runs the Gibbs sampler for the given number of iterations.
	 *
	 *  Each iteration resamples every document's path, then every token's level.
	 *   If showProgress is set, a "." is printed per iteration and the iteration
	 *   number every 50 iterations. The topic tree is printed every
	 *   displayTopicsInterval iterations; a non-positive interval disables this
	 *   instead of failing with division by zero.
	 *
	 *  @param numIterations number of sweeps over the corpus
	 */
	public void estimate(int numIterations) {
		for (int iteration = 1; iteration <= numIterations; iteration++) {
			for (int doc=0; doc < numDocuments; doc++) {
				samplePath(doc, iteration);
			}
			for (int doc=0; doc < numDocuments; doc++) {
				sampleTopics(doc);
			}
			
			if (showProgress) {
				System.out.print(".");
				if (iteration % 50 == 0) {
					System.out.println(" " + iteration);
				}
			}

			// Guard the modulo: an interval of 0 would throw ArithmeticException.
			if (displayTopicsInterval > 0 && iteration % displayTopicsInterval == 0) {
				printNodes();
			}
		}
	}

    public void samplePath(int doc, int iteration) {
		NCRPNode[] path = new NCRPNode[numLevels];
		NCRPNode node;
		int level, token, type, topicCount;
		double weight;

		node = documentLeaves[doc];
		for (level = numLevels - 1; level >= 0; level--) {
			path[level] = node;
			node = node.parent;
		}

		documentLeaves[doc].dropPath();

		TObjectDoubleHashMap nodeWeights = 
			new TObjectDoubleHashMap();
	
		// Calculate p(c_m | c_{-m})
		calculateNCRP(nodeWeights, rootNode, 0.0);

		// Add weights for p(w_m | c, w_{-m}, z)
	
		// The path may have no further customers and therefore
		//  be unavailable, but it should still exist since we haven't
		//  reset documentLeaves[doc] yet...
	
		TIntIntHashMap[] typeCounts = new TIntIntHashMap[numLevels];

		int[] docLevels;

		for (level = 0; level < numLevels; level++) {
			typeCounts[level] = new TIntIntHashMap();
		}

		docLevels = levels[doc];
		FeatureSequence fs = (FeatureSequence) instances.get(doc).getData();
	    
		// Save the counts of every word at each level, and remove
		//  counts from the current path

		for (token = 0; token < docLevels.length; token++) {
			level = docLevels[token];
			type = fs.getIndexAtPosition(token);
	    
			if (! typeCounts[level].containsKey(type)) {
				typeCounts[level].put(type, 1);
			}
			else {
				typeCounts[level].increment(type);
			}

			path[level].typeCounts[type]--;
			assert(path[level].typeCounts[type] >= 0);
	    
			path[level].totalTokens--;	    
			assert(path[level].totalTokens >= 0);
		}

		// Calculate the weight for a new path at a given level.
		double[] newTopicWeights = new double[numLevels];
		for (level = 1; level < numLevels; level++) {  // Skip the root...
			int[] types = typeCounts[level].keys();
			int totalTokens = 0;

			for (int t: types) {
				for (int i=0; i 1) { System.out.println(newTopicWeights[level]); }
		}
	
		calculateWordLikelihood(nodeWeights, rootNode, 0.0, typeCounts, newTopicWeights, 0, iteration);

		NCRPNode[] nodes = nodeWeights.keys(new NCRPNode[] {});
		double[] weights = new double[nodes.length];
		double sum = 0.0;
		double max = Double.NEGATIVE_INFINITY;

		// To avoid underflow, we're using log weights and normalizing the node weights so that 
		//  the largest weight is always 1.
		for (int i=0; i max) {
				max = nodeWeights.get(nodes[i]);
			}
		}

		for (int i=0; i 1) {
			  if (nodes[i] == documentLeaves[doc]) {
			  System.out.print("* ");
			  }
			  System.out.println(((NCRPNode) nodes[i]).level + "\t" + weights[i] + 
			  "\t" + nodeWeights.get(nodes[i]));
			  }
			*/

			sum += weights[i];
		}

		//if (iteration > 1) {System.out.println();}

		node = nodes[ random.nextDiscrete(weights, sum) ];

		// If we have picked an internal node, we need to 
		//  add a new path.
		if (! node.isLeaf()) {
			node = node.getNewLeaf();
		}
	
		node.addPath();
		documentLeaves[doc] = node;

		for (level = numLevels - 1; level >= 0; level--) {
			int[] types = typeCounts[level].keys();

			for (int t: types) {
				node.typeCounts[t] += typeCounts[level].get(t);
				node.totalTokens += typeCounts[level].get(t);
			}

			node = node.parent;
		}
    }

    /**
     *  Stores the log nested-CRP prior weight of every node in the subtree rooted at
     *   {@code node} into {@code nodeWeights}.
     *
     *  Moving from {@code node} to a child has probability
     *   customers(child) / (customers(node) + gamma). Each node's stored value is the
     *   accumulated log probability of reaching it plus
     *   log(gamma / (customers(node) + gamma)), the probability of opening a new table
     *   (branch) at that node.
     *
     *  NOTE(review): leaf nodes also receive the new-table term even though a new
     *   branch cannot be opened below a leaf; confirm this is intended.
     *
     *  @param nodeWeights map to fill (node -&gt; log weight)
     *  @param node        root of the subtree to score
     *  @param weight      accumulated log prior of reaching {@code node}
     */
    public void calculateNCRP(TObjectDoubleHashMap<NCRPNode> nodeWeights, 
							  NCRPNode node, double weight) {
		for (NCRPNode child: node.children) {
			calculateNCRP(nodeWeights, child,
						  weight + Math.log((double) child.customers / (node.customers + gamma)));
		}

		nodeWeights.put(node, weight + Math.log(gamma / (node.customers + gamma)));
    }

    public void calculateWordLikelihood(TObjectDoubleHashMap nodeWeights,
										NCRPNode node, double weight, 
										TIntIntHashMap[] typeCounts, double[] newTopicWeights,
										int level, int iteration) {
	
		// First calculate the likelihood of the words at this level, given
		//  this topic.
		double nodeWeight = 0.0;
		int[] types = typeCounts[level].keys();
		int totalTokens = 0;
	
		//if (iteration > 1) { System.out.println(level + " " + nodeWeight); }

		for (int type: types) {
			for (int i=0; i 1) {
				  System.out.println("(" +eta + " + " + node.typeCounts[type] + " + " + i + ") /" + 
				  "(" + etaSum + " + " + node.totalTokens + " + " + totalTokens + ")" + 
				  " : " + nodeWeight);
				  }
				*/

			}
		}

		//if (iteration > 1) { System.out.println(level + " " + nodeWeight); }

		// Propagate that weight to the child nodes

		for (NCRPNode child: node.children) {
            calculateWordLikelihood(nodeWeights, child, weight + nodeWeight,
									typeCounts, newTopicWeights, level + 1, iteration);
        }

		// Finally, if this is an internal node, add the weight of
		//  a new path

		level++;
		while (level < numLevels) {
			nodeWeight += newTopicWeights[level];
			level++;
		}

		nodeWeights.adjustValue(node, nodeWeight);

    }

    /**
     *  Propagates a topic weight to a node and all its children.
     *   The weight is assumed to be a log, and is added to each node's existing entry.
     *
     *  Calculating the NCRP prior proceeds from the root down (ie following child links),
     *   but adding the word-topic weights comes from the bottom up, following parent
     *   links and then child links. The leaf node may have been removed just prior to
     *   this round, so the current node may not have an NCRP weight. If so, it is not
     *   going to be sampled anyway: it and its subtree are skipped.
     *
     *  @param nodeWeights node -&gt; log weight map to adjust
     *  @param node        root of the subtree to adjust
     *  @param weight      log weight to add
     */
    public void propagateTopicWeight(TObjectDoubleHashMap<NCRPNode> nodeWeights,
									 NCRPNode node, double weight) {
		if (! nodeWeights.containsKey(node)) {
			return;
		}
	
		for (NCRPNode child: node.children) {
			propagateTopicWeight(nodeWeights, child, weight);
		}

		nodeWeights.adjustValue(node, weight);
    }

    /**
     *  Resamples the level assignment of every token in document {@code doc},
     *   keeping the document's path fixed.
     *
     *  For each token, its current assignment is removed from the counts. A level is
     *   then drawn with weight proportional to
     *   (alpha + tokens of this doc at the level) * (eta + count of this word type at
     *   the level's node) / (etaSum + total tokens at that node). The counts are
     *   updated for the chosen level.
     *
     *  @param doc index of the document in {@code instances}
     */
    public void sampleTopics(int doc) {
        FeatureSequence tokens = (FeatureSequence) instances.get(doc).getData();
        int length = tokens.getLength();
        int[] assignments = levels[doc];

        // Fill the path bottom-up by following parent links from the leaf.
        NCRPNode[] path = new NCRPNode[numLevels];
        NCRPNode cursor = documentLeaves[doc];
        int depth = numLevels;
        while (depth-- > 0) {
            path[depth] = cursor;
            cursor = cursor.parent;
        }

        // How many of this document's tokens sit at each level.
        int[] perLevel = new int[numLevels];
        for (int position = 0; position < length; position++) {
            perLevel[assignments[position]]++;
        }

        double[] scores = new double[numLevels];
        for (int position = 0; position < length; position++) {
            int wordType = tokens.getIndexAtPosition(position);

            // Take the token out of its current level.
            int previous = assignments[position];
            perLevel[previous]--;
            NCRPNode oldTopic = path[previous];
            oldTopic.typeCounts[wordType]--;
            oldTopic.totalTokens--;

            double total = 0.0;
            for (int l = 0; l < numLevels; l++) {
                NCRPNode topic = path[l];
                scores[l] = (alpha + perLevel[l]) * (eta + topic.typeCounts[wordType]) / (etaSum + topic.totalTokens);
                total += scores[l];
            }
            int chosen = random.nextDiscrete(scores, total);

            // Put the token back at the sampled level.
            assignments[position] = chosen;
            perLevel[chosen]++;
            NCRPNode newTopic = path[chosen];
            newTopic.typeCounts[wordType]++;
            newTopic.totalTokens++;
        }
    }

	/**
	 *  Writes the current sampling state to the file named by stateFile
	 *   (see setStateFile).
	 *
	 *  The writer is always closed, which flushes the BufferedWriter. Write errors,
	 *   which PrintWriter otherwise swallows, are reported as an IOException.
	 *
	 *  @throws IOException if the file cannot be opened or an error occurs while writing
	 */
	public void printState() throws IOException, FileNotFoundException {
		PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(stateFile)));
		try {
			printState(out);
			// checkError() flushes and reports any error PrintWriter suppressed.
			if (out.checkError()) {
				throw new IOException("Error writing sampling state to " + stateFile);
			}
		} finally {
			out.close();
		}
	}

	/**
	 *  Writes a text description of the current sampling state, one line per token of
	 *   every training document.
	 *
	 *  Each line contains:
	 *   - the node IDs of the document's path, listed from the leaf up to the root,
	 *     each followed by a space;
	 *   - the token's type index;
	 *   - the word;
	 *   - the level the token is assigned to.
	 *  The writer is neither flushed nor closed here.
	 *
	 *  @param out destination writer
	 */
	public void printState(PrintWriter out) throws IOException {
		Alphabet alphabet = instances.getDataAlphabet();

		int doc = 0;
		for (Instance instance: instances) {
			FeatureSequence tokens = (FeatureSequence) instance.getData();
			int length = tokens.getLength();
			int[] assignments = levels[doc];

			// Node IDs from the leaf upward, each followed by a space.
			StringBuilder pathIds = new StringBuilder();
			NCRPNode current = documentLeaves[doc];
			for (int remaining = numLevels; remaining > 0; remaining--) {
				pathIds.append(current.nodeID).append(" ");
				current = current.parent;
			}
			String prefix = pathIds.toString();

			for (int position = 0; position < length; position++) {
				int wordType = tokens.getIndexAtPosition(position);
				int assigned = assignments[position];
				out.println(prefix + wordType + " " + alphabet.lookupObject(wordType) + " " + assigned + " ");
			}

			doc++;
		}
	}

    /** Prints the whole topic tree, starting at the root with no indentation, without weights. */
    public void printNodes() {
        final int rootIndent = 0;
        printNode(this.rootNode, rootIndent, false);
    }
    
    /**
     *  Prints the whole topic tree, starting at the root with no indentation.
     *
     *  @param withWeight forwarded to printNode
     */
    public void printNodes(boolean withWeight) {
        NCRPNode root = this.rootNode;
        printNode(root, 0, withWeight);
    }

    // NOTE(review): this span is damaged and does not compile as written.
    //  printNode's body stops partway through its first loop header ("for (int i=0; i").
    //  The text after that ("max) {" ... "return averageLogLikelihood;") belongs to a
    //  different, value-returning method (printNode is void). Everything in between is
    //  missing. Restore this region from the upstream MALLET source rather than patching it.
    public void printNode(NCRPNode node, int indent, boolean withWeight) {
		StringBuffer out = new StringBuffer();
		for (int i=0; i max) {
                    max = likelihoods[doc][sample];
                }
            }

            // Log-sum-exp over samples, shifted by the per-document maximum for stability.
            double sum = 0.0;
            for (sample = 0; sample < numSamples; sample++) {
                sum += Math.exp(likelihoods[doc][sample] - max);
            }

            // Adds log(mean over samples of exp(likelihood)) for this document;
            //  presumably logNumSamples == Math.log(numSamples) -- TODO confirm (definition lost).
            averageLogLikelihood += Math.log(sum) + max - logNumSamples;
        }

		return averageLogLikelihood;
    }

	/** 
	 *  This method is primarily for testing purposes. The {@link cc.mallet.topics.tui.HierarchicalLDATUI}
	 *   class has a more flexible interface for command-line use.
	 *
	 *  Arguments: the training and testing InstanceList files (loaded with
	 *   InstanceList.load). Trains a 5-level model for 250 iterations. With fewer than
	 *   two arguments, a usage message is printed instead of failing with an
	 *   ArrayIndexOutOfBoundsException.
	 */
	public static void main (String[] args) {
		if (args.length < 2) {
			System.err.println("Usage: java " + HierarchicalLDA.class.getName()
							   + " <training instances file> <testing instances file>");
			return;
		}

		try {
			InstanceList instances = InstanceList.load(new File(args[0]));
			InstanceList testing = InstanceList.load(new File(args[1]));

			HierarchicalLDA sampler = new HierarchicalLDA();
			sampler.initialize(instances, testing, 5, new Randoms());
			sampler.estimate(250);
		} catch (Exception e) {
			// Command-line entry point: report any failure and exit normally.
			e.printStackTrace();
		}
	}

    class NCRPNode {
		// Number of documents whose path passes through this node.
		int customers;
		// Child nodes; typed so that for-each loops over them yield NCRPNode.
		ArrayList<NCRPNode> children;
		// Parent node; null for the root.
		NCRPNode parent;
		// Depth in the tree; the root is level 0.
		int level;

		// Number of tokens currently assigned to this node's topic.
		int totalTokens;
		// Token counts per word type for this node's topic.
		int[] typeCounts;

		// Sequential ID taken from the enclosing sampler's totalNodes counter.
		public int nodeID;

		/**
		 *  Creates a node with no customers, no children and zero word counts, and
		 *   assigns it the next sequential ID from the enclosing sampler's totalNodes counter.
		 *
		 *  @param parent     parent node, or null for the root
		 *  @param dimensions length of the per-word-type count array (vocabulary size)
		 *  @param level      depth of this node in the tree
		 */
		public NCRPNode(NCRPNode parent, int dimensions, int level) {
			customers = 0;
			this.parent = parent;
			children = new ArrayList<NCRPNode>();
			this.level = level;

			totalTokens = 0;
			typeCounts = new int[dimensions];

			nodeID = totalNodes;
			totalNodes++;
		}

		/**
		 *  Creates a root node: no parent, level 0.
		 *
		 *  @param dimensions length of the per-word-type count array
		 */
		public NCRPNode(int dimensions) {
			this((NCRPNode) null, dimensions, 0);
		}

		/**
		 *  Creates a child one level below this node, with a count array of the same
		 *   size, appends it to {@code children}, and returns it.
		 */
		public NCRPNode addChild() {
			int childLevel = level + 1;
			NCRPNode child = new NCRPNode(this, typeCounts.length, childLevel);
			children.add(child);
			return child;
		}

		/** Returns true if this node sits at the deepest level of the tree (numLevels - 1). */
		public boolean isLeaf() {
			final int deepestLevel = numLevels - 1;
			return deepestLevel == level;
		}

		public NCRPNode getNewLeaf() {
			NCRPNode node = this;
			for (int l=level; l




© 2015 - 2025 Weber Informatics LLC | Privacy Policy