All downloads are free. Search and download functionality uses the official Maven repository.

cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.fst;

import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;

/**
 * @author Gregory Druck [email protected]
 *
 * Multi-threaded version of CRF trainer.  Note that multi-threaded feature induction
 * and hyperbolic prior are not supported by this code.  
 */
public class CRFTrainerByThreadedLabelLikelihood extends TransducerTrainer implements TransducerTrainer.ByOptimization {
	private static Logger logger = MalletLogger.getLogger(CRFTrainerByThreadedLabelLikelihood.class.getName());

	static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;

	private boolean useSparseWeights;
	private boolean useNoWeights;
	private transient boolean useSomeUnsupportedTrick;
	private boolean converged;
	private int numThreads;
	private int iterationCount;
	private double gaussianPriorVariance;
	private CRF crf;
	private CRFOptimizableByBatchLabelLikelihood optimizable;
	private ThreadedOptimizable threadedOptimizable;
	private Optimizer optimizer;
	private int cachedWeightsStructureStamp; 

	public CRFTrainerByThreadedLabelLikelihood (CRF crf, int numThreads) {
		this.crf = crf;
		this.useSparseWeights = true;
		this.useNoWeights = false;
		this.useSomeUnsupportedTrick = true;
		this.converged = false;
		this.numThreads = numThreads;
		this.iterationCount = 0;
		this.gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
		this.cachedWeightsStructureStamp = -1;
	}
	
	public Transducer getTransducer() { return crf; }
	public CRF getCRF () { return crf; }
	public Optimizer getOptimizer() { return optimizer; }
	public boolean isConverged() { return converged; }
	public boolean isFinishedTraining() { return converged; }
	public int getIteration () { return iterationCount; }
	public void setGaussianPriorVariance (double p) { gaussianPriorVariance = p; }
	public double getGaussianPriorVariance () { return gaussianPriorVariance; }
	public void setUseSparseWeights (boolean b) { useSparseWeights = b; }
	public boolean getUseSparseWeights () { return useSparseWeights; }

	/** Sets whether to use the 'some unsupported trick.' This trick is, if training a CRF
	 * where some training has been done and sparse weights are used, to add a few weights
	 * for feaures that do not occur in the tainig data.
	 * 

* This generally leads to better accuracy at only a small memory cost. * * @param b Whether to use the trick */ public void setUseSomeUnsupportedTrick (boolean b) { useSomeUnsupportedTrick = b; } /** * Use this method to specify whether or not factors * are added to the CRF by this trainer. If you have * already setup the factors in your CRF, you may * not want the trainer to add additional factors. * * @param flag If true, this trainer adds no factors to the CRF. */ public void setAddNoFactors(boolean flag) { this.useNoWeights = flag; } public void shutdown() { threadedOptimizable.shutdown(); } public CRFOptimizableByBatchLabelLikelihood getOptimizableCRF (InstanceList trainingSet) { if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) { if (!useNoWeights) { if (useSparseWeights) { crf.setWeightsDimensionAsIn (trainingSet, useSomeUnsupportedTrick); } else { crf.setWeightsDimensionDensely (); } } optimizable = null; cachedWeightsStructureStamp = crf.weightsStructureChangeStamp; } if (optimizable == null || optimizable.trainingSet != trainingSet) { optimizable = new CRFOptimizableByBatchLabelLikelihood(crf, trainingSet, numThreads); optimizable.setGaussianPriorVariance(gaussianPriorVariance); threadedOptimizable = new ThreadedOptimizable(optimizable, trainingSet, crf.getParameters().getNumFactors(), new CRFCacheStaleIndicator(crf)); optimizer = null; } return optimizable; } public Optimizer getOptimizer (InstanceList trainingSet) { getOptimizableCRF(trainingSet); if (optimizer == null || optimizable != optimizer.getOptimizable()) { optimizer = new LimitedMemoryBFGS(threadedOptimizable); } return optimizer; } public boolean trainIncremental (InstanceList training) { return train (training, Integer.MAX_VALUE); } public boolean train (InstanceList trainingSet, int numIterations) { if (numIterations <= 0) { return false; } assert (trainingSet.size() > 0); getOptimizableCRF(trainingSet); // This will set this.mcrf if necessary getOptimizer(trainingSet); // 
This will set this.opt if necessary boolean converged = false; logger.info ("CRF about to train with "+numIterations+" iterations"); for (int i = 0; i < numIterations; i++) { try { converged = optimizer.optimize (1); iterationCount++; logger.info ("CRF finished one iteration of maximizer, i="+i); runEvaluators(); } catch (IllegalArgumentException e) { e.printStackTrace(); logger.info ("Catching exception; saying converged."); converged = true; } catch (Exception e) { e.printStackTrace(); logger.info("Catching exception; saying converged."); converged = true; } if (converged) { logger.info ("CRF training has converged, i="+i); break; } } return converged; } /** * Train a CRF on various-sized subsets of the data. This method is typically used to accelerate training by * quickly getting to reasonable parameters on only a subset of the parameters first, then on progressively more data. * @param training The training Instances. * @param numIterationsPerProportion Maximum number of Maximizer iterations per training proportion. * @param trainingProportions If non-null, train on increasingly * larger portions of the data, e.g. new double[] {0.2, 0.5, 1.0}. This can sometimes speedup convergence. * Be sure to end in 1.0 if you want to train on all the data in the end. * @return True if training has converged. 
*/ public boolean train (InstanceList training, int numIterationsPerProportion, double[] trainingProportions) { int trainingIteration = 0; assert (trainingProportions.length > 0); boolean converged = false; for (int i = 0; i < trainingProportions.length; i++) { assert (trainingProportions[i] <= 1.0); logger.info ("Training on "+trainingProportions[i]+"% of the data this round."); if (trainingProportions[i] == 1.0) { converged = this.train (training, numIterationsPerProportion); } else { converged = this.train (training.split (new Random(1), new double[] {trainingProportions[i], 1-trainingProportions[i]})[0], numIterationsPerProportion); } trainingIteration += numIterationsPerProportion; } return converged; } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy