All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.topics.tui.HierarchicalLDATUI Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.topics.tui;

import cc.mallet.util.CommandOption;
import cc.mallet.util.Randoms;
import cc.mallet.types.InstanceList;
import cc.mallet.topics.HierarchicalLDA;

import java.io.*;

public class HierarchicalLDATUI {
	
	static CommandOption.String inputFile = new CommandOption.String
		(HierarchicalLDATUI.class, "input", "FILENAME", true, null,
		 "The filename from which to read the list of training instances.  Use - for stdin.  " +
		 "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
	
	static CommandOption.String testingFile = new CommandOption.String
		(HierarchicalLDATUI.class, "testing", "FILENAME", true, null,
		 "The filename from which to read the list of instances for held-out likelihood calculation.  Use - for stdin.  " +
		 "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
	
	static CommandOption.String stateFile = new CommandOption.String
		(HierarchicalLDATUI.class, "output-state", "FILENAME", true, null,
		 "The filename in which to write the Gibbs sampling state after at the end of the iterations.  " +
		 "By default this is null, indicating that no file will be written.", null);
	
	static CommandOption.Integer randomSeed = new CommandOption.Integer
		(HierarchicalLDATUI.class, "random-seed", "INTEGER", true, 0,
		 "The random seed for the Gibbs sampler.  Default is 0, which will use the clock.", null);
	
	static CommandOption.Integer numIterations = new CommandOption.Integer
	  	(Vectors2Topics.class, "num-iterations", "INTEGER", true, 1000,
		 "The number of iterations of Gibbs sampling.", null);

	static CommandOption.Boolean showProgress = new CommandOption.Boolean
		(HierarchicalLDATUI.class, "show-progress", "BOOLEAN", false, true,
		 "If true, print a character to standard output after every sampling iteration.", null);

	static CommandOption.Integer showTopicsInterval = new CommandOption.Integer
		(HierarchicalLDATUI.class, "show-topics-interval", "INTEGER", true, 50,
		 "The number of iterations between printing a brief summary of the topics so far.", null);

	static CommandOption.Integer topWords = new CommandOption.Integer
		(HierarchicalLDATUI.class, "num-top-words", "INTEGER", true, 20,
		 "The number of most probable words to print for each topic after model estimation.", null);

	static CommandOption.Integer numLevels = new CommandOption.Integer
		(HierarchicalLDATUI.class, "num-levels", "INTEGER", true, 3,
		 "The number of levels in the tree.", null);

	static CommandOption.Double alpha = new CommandOption.Double
		(HierarchicalLDATUI.class, "alpha", "DECIMAL", true, 10.0,
		 "Alpha parameter: smoothing over level distributions.", null);

	static CommandOption.Double gamma = new CommandOption.Double
		(HierarchicalLDATUI.class, "gamma", "DECIMAL", true, 1.0,
		 "Gamma parameter: CRP smoothing parameter; number of imaginary customers at next, as yet unused table", null);

	static CommandOption.Double eta = new CommandOption.Double
		(HierarchicalLDATUI.class, "eta", "DECIMAL", true, 0.1,
		 "Eta parameter: smoothing over topic-word distributions", null);
	
	public static void main (String[] args) throws java.io.IOException {

		// Process the command-line options
		CommandOption.setSummary (HierarchicalLDATUI.class,
								  "Hierarchical LDA with a fixed tree depth.");
		CommandOption.process (HierarchicalLDATUI.class, args);
		
		// Load instance lists

		if (inputFile.value() == null) {
			System.err.println("Input instance list is required, use --input option");
			System.exit(1);
		}

		InstanceList instances = InstanceList.load(new File(inputFile.value()));
		InstanceList testing = null;
		if (testingFile.value() != null) {
			testing = InstanceList.load(new File(testingFile.value()));
		}
	
		HierarchicalLDA hlda = new HierarchicalLDA();
		
		// Set hyperparameters

		hlda.setAlpha(alpha.value());
		hlda.setGamma(gamma.value());
		hlda.setEta(eta.value());
		
		// Display preferences

		hlda.setTopicDisplay(showTopicsInterval.value(), topWords.value());
		hlda.setProgressDisplay(showProgress.value());

		// Initialize random number generator

		Randoms random = null;
		if (randomSeed.value() == 0) {
			random = new Randoms();
		}
		else {
			random = new Randoms(randomSeed.value());
		}

		// Initialize and start the sampler

		hlda.initialize(instances, testing, numLevels.value(), random);
		hlda.estimate(numIterations.value());
		
		// Output results

		if (stateFile.value() != null) {
			hlda.printState(new PrintWriter(stateFile.value()));
		}

		if (testing != null) {
			double empiricalLikelihood = hlda.empiricalLikelihood(1000, testing);
			System.out.println("Empirical likelihood: " + empiricalLikelihood);
		}
		
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy