cc.mallet.fst.MEMMTrainer

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

package cc.mallet.fst;

import java.util.BitSet;
import java.util.logging.Logger;

import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;

import cc.mallet.fst.MEMM.State;
import cc.mallet.fst.MEMM.TransitionIterator;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;

import cc.mallet.util.MalletLogger;

/**
 * Trains and evaluates a {@link MEMM}.
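 * <p>
 * A minimal usage sketch ({@code pipe} and {@code training} are assumed to be
 * built elsewhere; as noted in the comments below, the user is expected to set
 * the weights dimension manually before training):
 * <pre>{@code
 * MEMM memm = new MEMM(pipe, null);
 * memm.addFullyConnectedStatesForLabels();
 * memm.setWeightsDimensionAsIn(training, false);
 * MEMMTrainer trainer = new MEMMTrainer(memm);
 * boolean converged = trainer.train(training, 100);
 * }</pre>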
 */
public class MEMMTrainer extends TransducerTrainer 
{
	private static Logger logger = MalletLogger.getLogger(MEMMTrainer.class.getName());

	MEMM memm;
	private boolean gatheringTrainingData = false;
	// After training sets have been gathered in the states, record which
	//   InstanceList we've gathered them for, so we don't double-count instances.
	private InstanceList trainingGatheredFor;
	// gsc: user is supposed to set the weights manually, so this flag is not needed
//	boolean useSparseWeights = true;
	MEMMOptimizableByLabelLikelihood omemm;
	
	public MEMMTrainer (MEMM memm) {
		this.memm = memm;
	}

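	/**
	 * Returns an optimizable view of the label likelihood of this MEMM on
	 * {@code trainingSet}. Note that {@link #train(InstanceList, int)} constructs
	 * its own instance and gathers the MEMM-style constraints before optimizing.
	 */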
	public MEMMOptimizableByLabelLikelihood getOptimizableMEMM (InstanceList trainingSet) {
		return new MEMMOptimizableByLabelLikelihood (memm, trainingSet);
	}

//	public MEMMTrainer setUseSparseWeights (boolean f) { useSparseWeights = f;  return this; }

	/**
	 * Trains a MEMM until convergence.
	 */
	public boolean train (InstanceList training) {
		return train (training, Integer.MAX_VALUE);
	}

	/**
	 * Trains a MEMM for the specified number of iterations or until convergence,
	 * whichever occurs first; returns true if training converged within that limit.
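	 * <p>
	 * For example (a sketch, with {@code trainer} and {@code training} set up as in
	 * the class-level example):
	 * <pre>{@code
	 * boolean converged = trainer.train(training, 100);
	 * if (!converged)
	 *     System.out.println("Hit the iteration limit before converging.");
	 * }</pre>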
	 */
	public boolean train (InstanceList training, int numIterations)
	{
		if (numIterations <= 0)
			return false;
		assert (training.size() > 0);

		// Allocate space for the parameters, and place transition FeatureVectors in
		// per-source-state InstanceLists.
		// Here, gatherTrainingSets will create a new InstanceList in each source state,
		// and the FeatureVectors of that state's outgoing transitions will be added to it
		// as the data field of the Instances.
		if (trainingGatheredFor != training) {
			gatherTrainingSets (training);
		}
		// gsc: the user has to set the weights manually
//		if (useSparseWeights) {
//			memm.setWeightsDimensionAsIn (training, false);
//		} else {
//			memm.setWeightsDimensionDensely ();
//		}


		/*
		if (false) {
			// Expectation-based placement of training data would go here.
			for (int i = 0; i < training.size(); i++) {
				Instance instance = training.get(i);
				FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
				FeatureSequence output = (FeatureSequence) instance.getTarget();
				// Do it for the paths consistent with the labels...
				gatheringConstraints = true;
				new SumLatticeDefault (this, input, output, true);
				// ...and also do it for the paths selected by the current model (so we will get some negative weights)
				gatheringConstraints = false;
				if (this.someTrainingDone)
					// (do this once some training is done)
					new SumLatticeDefault (this, input, null, true);
			}
			gatheringWeightsPresent = false;
			SparseVector[] newWeights = new SparseVector[weights.length];
			for (int i = 0; i < weights.length; i++) {
				int numLocations = weightsPresent[i].cardinality ();
				logger.info ("CRF weights["+weightAlphabet.lookupObject(i)+"] num features = "+numLocations);
				int[] indices = new int[numLocations];
				for (int j = 0; j < numLocations; j++) {
					indices[j] = weightsPresent[i].nextSetBit (j == 0 ? 0 : indices[j-1]+1);
					//System.out.println ("CRF4 has index "+indices[j]);
				}
				newWeights[i] = new IndexedSparseVector (indices, new double[numLocations],
						numLocations, numLocations, false, false, false);
				newWeights[i].plusEqualsSparse (weights[i]);
			}
			weights = newWeights;
		}
		*/

		omemm = new MEMMOptimizableByLabelLikelihood (memm, training);
		// Gather the constraints
		omemm.gatherExpectationsOrConstraints (true);
		Optimizer maximizer = new LimitedMemoryBFGS(omemm);
		int i;
//		boolean continueTraining = true;
		boolean converged = false;
		logger.info ("MEMM about to train with "+numIterations+" iterations");
		for (i = 0; i < numIterations; i++) {
			try {
				converged = maximizer.optimize (1);
				logger.info ("MEMM finished one iteration of maximizer, i="+i);
				runEvaluators();
			} catch (IllegalArgumentException e) {
				e.printStackTrace();
				logger.info ("Catching exception; saying converged.");
				converged = true;
			}
			if (converged) {
				logger.info ("MEMM training has converged, i="+i);
				break;
			}
		}
		logger.info ("MEMM training finished, converged="+converged);
		return converged;
	}


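	/**
	 * For each training instance, runs a lattice over the paths consistent with the
	 * labels and, for every transition with nonzero count, adds the transition's input
	 * FeatureVector (with the transition's output as the target) to the source state's
	 * training set, weighted by that count.
	 */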
	void gatherTrainingSets (InstanceList training)
	{
		if (trainingGatheredFor != null) {
			// It would be easy enough to support this, just go through all the states and set trainingSet to null.
			throw new UnsupportedOperationException ("Training with multiple sets not supported.");
		}

		trainingGatheredFor = training;
		for (int i = 0; i < training.size(); i++) {
			Instance instance = training.get(i);
			FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
			FeatureSequence output = (FeatureSequence) instance.getTarget();
			// Do it for the paths consistent with the labels...
			new SumLatticeDefault (memm, input, output, new Transducer.Incrementor() {
				public void incrementFinalState(Transducer.State s, double count) { }
				public void incrementInitialState(Transducer.State s, double count) { }
				public void incrementTransition(Transducer.TransitionIterator ti, double count) {
					MEMM.State source = (MEMM.State) ti.getSourceState();
					if (count != 0) {
						// Create the source state's trainingSet if it doesn't exist yet.
						if (source.trainingSet == null)
							// New InstanceList with a null pipe, because it doesn't do any processing of input.
							source.trainingSet = new InstanceList (null);
						// TODO We should make sure we don't add duplicates (e.g. through a second call to setWeightsDimension...)!
						// TODO Note that when the training data still allows ambiguous outgoing transitions
						// this will add the same FV more than once to the source state's trainingSet, each
						// with >1.0 weight.  Not incorrect, but inefficient.
//						System.out.println ("From: "+source.getName()+" ---> "+getOutput()+" : "+getInput());
						source.trainingSet.add (new Instance(ti.getInput (), ti.getOutput (), null, null), count);
					}
				}
			});
		}
	}

	/**
	 * Not implemented yet.
	 * 
	 * @throws UnsupportedOperationException
	 */
	public boolean train (InstanceList training, InstanceList validation, InstanceList testing,
			TransducerEvaluator eval, int numIterations,
			int numIterationsPerProportion,
			double[] trainingProportions)
	{
		throw new UnsupportedOperationException();
	}

  /**
   * Not implemented yet.
   * 
   * @throws UnsupportedOperationException
   */
	public boolean trainWithFeatureInduction (InstanceList trainingData,
			InstanceList validationData, InstanceList testingData,
			TransducerEvaluator eval, int numIterations,
			int numIterationsBetweenFeatureInductions,
			int numFeatureInductions,
			int numFeaturesPerFeatureInduction,
			double trueLabelProbThreshold,
			boolean clusteredFeatureInduction,
			double[] trainingProportions,
			String gainName)
	{
		throw new UnsupportedOperationException();
	}


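	/** Prints each state's gathered per-state training set; intended for debugging. */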
	public void printInstanceLists ()
	{
		for (int i = 0; i < memm.numStates(); i++) {
			State state = (State) memm.getState (i);
			InstanceList training = state.trainingSet;
			System.out.println ("State "+i+" : "+state.getName());
			if (training == null) {
				System.out.println ("No data");
				continue;
			}
			for (int j = 0; j < training.size(); j++) {
				Instance inst = training.get (j);
				System.out.println ("From : "+state.getName()+" To : "+inst.getTarget());
				System.out.println ("Instance "+j);
				System.out.println (inst.getTarget());
				System.out.println (inst.getData());
			}
		}
	}

	/**
	 * Represents the terms in the objective function.
	 * <p>
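	 * The objective is the label log-likelihood of a per-state maximum-entropy
	 * classifier over the gathered transitions; in the standard maxent form, the
	 * gradient for each weight is the difference between the observed feature
	 * counts (constraints) and the model expectations, plus any prior term:
	 * <pre>
	 *   dL/dw_k = constraints_k - expectations_k
	 * </pre>
	 * <p>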
	 * The weights are trained by matching the expectations of the model to the
	 * observations gathered from the data.
	 */
	@SuppressWarnings("serial")
	public class MEMMOptimizableByLabelLikelihood extends CRFOptimizableByLabelLikelihood implements Optimizable.ByGradientValue
	{
		BitSet infiniteValues = null;

		protected MEMMOptimizableByLabelLikelihood (MEMM memm, InstanceList trainingData) {
			super (memm, trainingData);
			expectations = new CRF.Factors (memm);
			constraints = new CRF.Factors (memm);
		}

		// if constraints=false, return log probability of the training labels
		protected double gatherExpectationsOrConstraints (boolean gatherConstraints)
		{
			// Instance values must either always or never be included in
			// the total values; we can't just sometimes skip a value
			// because it is infinite, this throws off the total values.
			boolean initializingInfiniteValues = false;
			CRF.Factors factors = gatherConstraints ? constraints : expectations;
			CRF.Factors.Incrementor factorIncrementor = factors.new Incrementor ();
			if (infiniteValues == null) {
				infiniteValues = new BitSet ();
				initializingInfiniteValues = true;
			}

			double labelLogProb = 0;
			for (int i = 0; i < memm.numStates(); i++) {
				MEMM.State s = (State) memm.getState (i);
				if (s.trainingSet == null) {
					System.out.println ("Empty training set for state "+s.name);
					continue;
				}
				for (int j = 0; j < s.trainingSet.size(); j++) {
					Instance instance = s.trainingSet.get (j);
					double instWeight = s.trainingSet.getInstanceWeight (j);
					FeatureVector fv = (FeatureVector) instance.getData ();
					String labelString = (String) instance.getTarget ();
					TransitionIterator iter = new TransitionIterator (s, fv, gatherConstraints ? labelString : null, memm);
					while (iter.hasNext ()) {
						// gsc
						iter.nextState(); // advance the iterator
//						State destination = (MEMM.State) iter.nextState();  // Just to advance the iterator
						double weight = iter.getWeight();
						factorIncrementor.incrementTransition(iter, Math.exp(weight) * instWeight);
						//iter.incrementCount (Math.exp(weight) * instWeight);
						if (!gatherConstraints && iter.getOutput() == labelString) {
							if (!Double.isInfinite (weight))
								labelLogProb += instWeight * weight; // xxx ?????
							else {
								logger.warning ("State "+i+" transition "+j+" has infinite cost; skipping.");
								if (initializingInfiniteValues)
									throw new IllegalStateException ("Infinite-cost transitions not yet supported");
									//infiniteValues.set (j);
								else if (!infiniteValues.get(j))
									throw new IllegalStateException ("Instance i used to have non-infinite value, "
											+"but now it has infinite value.");
							}
						}
					}
				}
			}

			// Force initial & final weight parameters to 0 by making sure that
			// whether factor refers to expectation or constraint, they have the same value.
			for (int i = 0; i < memm.numStates(); i++) {
				factors.initialWeights[i] = 0.0;
				factors.finalWeights[i] = 0.0;
			}

			return labelLogProb;
		}

		// log probability of the training sequence labels, and fill in expectations[]
		protected double getExpectationValue ()
		{
			return gatherExpectationsOrConstraints (false);
		}

	}

	@Override
	public int getIteration() {
		// TODO Auto-generated method stub
		return 0;
	}

	@Override
	public Transducer getTransducer() {
		return memm;
	}

	@Override
	public boolean isFinishedTraining() {
		// TODO Auto-generated method stub
		return false;
	}

}




