All downloads are free. Search and download functionality uses the official Maven repository.

cc.mallet.fst.HMMTrainerByLikelihood Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.fst;

import java.util.logging.Logger;

import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;

/**
 * Trains an {@link HMM} by repeated likelihood-based re-estimation over labeled
 * and, optionally, unlabeled data.
 * <p>
 * Each iteration does the following:
 * <ol>
 * <li>Builds a {@link SumLatticeDefault} for every labeled instance, passing the
 * instance's target sequence as the output.</li>
 * <li>Builds a lattice for every unlabeled instance, passing a {@code null}
 * output.</li>
 * <li>Sums the lattices' total weights into a log-likelihood.</li>
 * <li>Calls {@code hmm.estimate()}.</li>
 * </ol>
 * Each lattice is given a fresh {@code hmm.new Incrementor()}. Presumably this
 * accumulates the statistics that {@code estimate()} consumes; confirm against
 * {@code HMM}.
 * <p>
 * Training stops after the requested number of iterations, or earlier once the
 * absolute change in total log-likelihood between consecutive iterations drops
 * below {@link #getConvergenceThreshold()}.
 */
public class HMMTrainerByLikelihood extends TransducerTrainer {
	private static final Logger logger = MalletLogger
			.getLogger(HMMTrainerByLikelihood.class.getName());

	/**
	 * Default absolute log-likelihood change below which training is considered
	 * converged.
	 */
	public static final double DEFAULT_CONVERGENCE_THRESHOLD = 0.001;

	HMM hmm;
	/** Labeled and unlabeled data passed to the most recent {@code train} call (either may be null). */
	InstanceList trainingSet, unlabeledSet;
	/** Total number of completed iterations across all {@code train} calls. */
	int iterationCount = 0;
	/** Whether the most recent {@code train} call stopped because it converged. */
	boolean converged = false;
	/** Absolute log-likelihood change below which training stops early. */
	private double convergenceThreshold = DEFAULT_CONVERGENCE_THRESHOLD;

	/**
	 * Creates a trainer for the given HMM.
	 *
	 * @param hmm the model whose parameters will be re-estimated in place
	 */
	public HMMTrainerByLikelihood(HMM hmm) {
		this.hmm = hmm;
	}

	/** Returns the HMM being trained. */
	@Override
	public Transducer getTransducer() {
		return hmm;
	}

	/** Returns the number of iterations completed so far, summed over all calls to train. */
	@Override
	public int getIteration() {
		return iterationCount;
	}

	/** Returns true if the most recent call to train stopped because it converged. */
	@Override
	public boolean isFinishedTraining() {
		return converged;
	}

	/**
	 * Returns the absolute change in total log-likelihood between consecutive
	 * iterations below which training is considered converged.
	 */
	public double getConvergenceThreshold() {
		return convergenceThreshold;
	}

	/**
	 * Sets the absolute log-likelihood change below which training stops early.
	 * Larger datasets have larger-magnitude log-likelihoods, so a looser
	 * threshold may be appropriate for them.
	 *
	 * @param threshold a non-negative tolerance
	 * @throws IllegalArgumentException if {@code threshold} is negative or NaN
	 */
	public void setConvergenceThreshold(double threshold) {
		// Written as !(x >= 0) so that NaN is rejected as well.
		if (!(threshold >= 0))
			throw new IllegalArgumentException(
					"convergence threshold must be non-negative: " + threshold);
		this.convergenceThreshold = threshold;
	}

	/**
	 * Trains on labeled data only.
	 *
	 * @param trainingSet   labeled instances (data and target are FeatureSequences)
	 * @param numIterations maximum number of iterations to run
	 * @return true if training converged within {@code numIterations}
	 */
	@Override
	public boolean train(InstanceList trainingSet, int numIterations) {
		return train(trainingSet, null, numIterations);
	}

	/**
	 * Trains on labeled and, optionally, unlabeled data.
	 * <p>
	 * If the HMM's emission estimator has not been created yet, the HMM is
	 * {@code reset()} first. The log-likelihood compared for convergence is
	 * computed before each iteration's {@code hmm.estimate()} call. It therefore
	 * reflects the parameters as they were at the start of that iteration.
	 *
	 * @param trainingSet   labeled instances whose data and target are
	 *                      FeatureSequences, or null for none
	 * @param unlabeledSet  instances whose data is a FeatureSequence (targets
	 *                      ignored), or null for none
	 * @param numIterations maximum number of iterations to run
	 * @return true if training converged within {@code numIterations}
	 * @throws IllegalArgumentException if both {@code trainingSet} and
	 *                                  {@code unlabeledSet} are null
	 */
	public boolean train(InstanceList trainingSet, InstanceList unlabeledSet,
			int numIterations) {
		if (trainingSet == null && unlabeledSet == null)
			throw new IllegalArgumentException(
					"At least one of trainingSet and unlabeledSet must be non-null");
		// Record the data so these fields reflect the most recent call.
		// Previously the parameters shadowed the fields, which stayed null.
		this.trainingSet = trainingSet;
		this.unlabeledSet = unlabeledSet;

		if (hmm.emissionEstimator == null)
			hmm.reset();

		converged = false;
		// Start at -infinity so that the first iteration can never satisfy the
		// convergence test.
		double logLikelihood = Double.NEGATIVE_INFINITY;
		for (int iter = 0; iter < numIterations; iter++) {
			double prevLogLikelihood = logLikelihood;

			logLikelihood = accumulateLabeledLogLikelihood(trainingSet, 0);
			logger.info("getValue() (observed log-likelihood) = "
					+ logLikelihood);

			if (unlabeledSet != null)
				logLikelihood = accumulateUnlabeledLogLikelihood(unlabeledSet,
						logLikelihood);
			logger.info("getValue() (log-likelihood) = " + logLikelihood);

			hmm.estimate();
			iterationCount++;
			logger.info("HMM finished one iteration of maximizer, i=" + iter);

			runEvaluators();

			if (Math.abs(logLikelihood - prevLogLikelihood) < convergenceThreshold) {
				converged = true;
				logger.info("HMM training has converged, i=" + iter);
				break;
			}
		}

		return converged;
	}

	/**
	 * Adds the lattice total weight of each labeled instance to {@code total}.
	 * Each lattice uses the instance's target as the output sequence. A null
	 * list contributes nothing.
	 */
	private double accumulateLabeledLogLikelihood(InstanceList labeled,
			double total) {
		if (labeled == null)
			return total;
		for (Instance inst : labeled) {
			FeatureSequence input = (FeatureSequence) inst.getData();
			FeatureSequence output = (FeatureSequence) inst.getTarget();
			total += new SumLatticeDefault(hmm, input, output,
					hmm.new Incrementor()).getTotalWeight();
		}
		return total;
	}

	/**
	 * Adds the lattice total weight of each unlabeled instance to {@code total}.
	 * Each lattice is built with a null output sequence. A progress count is
	 * printed to stderr every 100 instances.
	 */
	private double accumulateUnlabeledLogLikelihood(InstanceList unlabeled,
			double total) {
		int numEx = 0;
		for (Instance inst : unlabeled) {
			numEx++;
			if (numEx % 100 == 0) {
				System.err.print(numEx + ". ");
				System.err.flush();
			}
			FeatureSequence input = (FeatureSequence) inst.getData();
			total += new SumLatticeDefault(hmm, input, null,
					hmm.new Incrementor()).getTotalWeight();
		}
		System.err.println();
		return total;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy