cc.mallet.fst.CRFTrainerByValueGradients Maven / Gradle / Ivy

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

package cc.mallet.fst;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

import java.util.BitSet;
import java.util.Random;
import java.util.logging.Logger;

import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;

import cc.mallet.util.MalletLogger;


/**
 * A CRF trainer that can combine multiple objective functions, each represented
 * by an Optimizable.ByGradientValue.
 */
public class CRFTrainerByValueGradients extends TransducerTrainer implements TransducerTrainer.ByOptimization {

	private static Logger logger = MalletLogger.getLogger(CRFTrainerByValueGradients.class.getName());

	CRF crf;
	// gsc: keep objects instead of class names; this gives the user more flexibility
	// to set up new CRFOptimizable* objects and pass them directly to the constructor,
	// so the OptimizableCRF inner class no longer creates CRFOptimizable* objects itself
	Optimizable.ByGradientValue[] optimizableByValueGradientObjects;
//	Class[] optimizableByValueGradientClasses;
	OptimizableCRF ocrf;
	Optimizer opt;
	int iterationCount = 0;
	boolean converged;
	// gsc: removing these options, the user ought to set the weights before 
	// creating the trainer object
//	boolean useSparseWeights = true;
//	// gsc
//	boolean useUnsupportedTrick = false;
	
	// Various values from CRF acting as indicators of when we need to ...
	private int cachedValueWeightsStamp = -1;  // ... re-calculate expectations and values to getValue() because weights' values changed
	private int cachedGradientWeightsStamp = -1; // ... re-calculate to getValueGradient() because weights' values changed
	
	// gsc: removing this because the user will call setWeightsDimensionAsIn
//	private int cachedWeightsStructureStamp = -1; // ... re-allocate crf.weights, expectations & constraints because of new states, transitions
	// Use ocrf.trainingSet to see when we need to re-allocate crf.weights, expectations & constraints because we are using a different TrainingList than last time

	// gsc: number of times to reset the optimizer and continue training when the
	// "could not step in current direction" exception occurs
	public static final int DEFAULT_MAX_RESETS = 3;
	int maxResets = DEFAULT_MAX_RESETS;
	
	public CRFTrainerByValueGradients (CRF crf, Optimizable.ByGradientValue[] optimizableByValueGradientObjects) {
		this.crf = crf;
		this.optimizableByValueGradientObjects = optimizableByValueGradientObjects;
	}
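
	// A minimal construction sketch (not from the original source). It assumes a CRF whose
	// weight dimensions have already been set, and an InstanceList named trainingData.
	// CRFOptimizableByLabelLikelihood is MALLET's standard log-likelihood objective; any
	// additional Optimizable.ByGradientValue terms appended to the array have their values
	// and gradients summed by the OptimizableCRF inner class below.
	//
	//   crf.setWeightsDimensionAsIn (trainingData, false); // caller's responsibility, per the comments above
	//   Optimizable.ByGradientValue[] opts = new Optimizable.ByGradientValue[] {
	//       new CRFOptimizableByLabelLikelihood (crf, trainingData) };
	//   CRFTrainerByValueGradients trainer = new CRFTrainerByValueGradients (crf, opts);
	//   trainer.train (trainingData, Integer.MAX_VALUE);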
	
	public Transducer getTransducer() { return crf; }
	public CRF getCRF () { return crf; }
	public Optimizer getOptimizer() { return opt; }
	/** Returns true if training converged, false otherwise. */
	public boolean isConverged() { return converged; }
  /** Returns true if training converged, false otherwise. */
	public boolean isFinishedTraining() { return converged; }
	public int getIteration () { return iterationCount; }
	
	// gsc
	public Optimizable.ByGradientValue[] getOptimizableByGradientValueObjects() {
		return optimizableByValueGradientObjects;
	}

	/**
	 * Returns an optimizable CRF that contains a collection of objective functions.
	 * <p>
	 * If one doesn't exist then creates one and sets the optimizer to null.
	 */
	public OptimizableCRF getOptimizableCRF (InstanceList trainingSet) {
		// gsc: user should call setWeightsDimensionAsIn before the optimizable and
		// trainer objects are created
//		if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) {
//			if (useSparseWeights)
//				crf.setWeightsDimensionAsIn (trainingSet, useUnsupportedTrick);
//			else
//				crf.setWeightsDimensionDensely ();
//			ocrf = null;
//			cachedWeightsStructureStamp = crf.weightsStructureChangeStamp;
//		}
		if (ocrf == null || ocrf.trainingSet != trainingSet) {
			ocrf = new OptimizableCRF (crf, trainingSet);
			opt = null;
		}
		return ocrf;
	}

	/**
	 * Returns an L-BFGS optimizer, creating one if it doesn't exist.
	 * <p>
	 * Also creates an optimizable CRF if required.
	 */
	public Optimizer getOptimizer (InstanceList trainingSet) {
		getOptimizableCRF(trainingSet); // this will set this.ocrf if necessary
		if (opt == null || ocrf != opt.getOptimizable())
			opt = new LimitedMemoryBFGS(ocrf); // Alternative: opt = new ConjugateGradient (0.001);
		return opt;
	}

	/** Trains a CRF until convergence. */
	public boolean trainIncremental (InstanceList training) {
		return train (training, Integer.MAX_VALUE);
	}

	/**
	 * Trains a CRF until convergence or the specified number of iterations,
	 * whichever comes first.
	 * <p>
	 * Also creates an optimizable CRF and an optimizer if required.
	 */
	public boolean train (InstanceList trainingSet, int numIterations) {
		if (numIterations <= 0)
			return false;
		assert (trainingSet.size() > 0);

		getOptimizableCRF(trainingSet); // This will set this.ocrf if necessary
		getOptimizer(trainingSet); // This will set this.opt if necessary

		int numResets = 0;
		boolean converged = false;
		logger.info ("CRF about to train with "+numIterations+" iterations");
		for (int i = 0; i < numIterations; i++) {
			try {
				// gsc: timing each iteration
				long startTime = System.currentTimeMillis();
				converged = opt.optimize (1);
				logger.info ("CRF finished one iteration of maximizer, i="+i+", "+
						(System.currentTimeMillis()-startTime)/1000 + " secs.");
				iterationCount++;
				runEvaluators();
			} catch (OptimizationException e) {
				// gsc: reset the optimizer the specified number of times before giving up
				e.printStackTrace();
				logger.info ("Catching exception.");
				if (numResets < maxResets) {
					// reset the optimizer and get a new one
					logger.info ("Resetting optimizer.");
					++numResets;
					opt = null;
					getOptimizer(trainingSet);
//					logger.info ("Catching exception; saying converged.");
//					converged = true;
				} else {
					logger.info ("Saying converged.");
					converged = true;
				}
			}
			if (converged) {
				logger.info ("CRF training has converged, i="+i);
				break;
			}
		}
		return converged;
	}

	/**
	 * Train a CRF on various-sized subsets of the data. This method is
	 * typically used to accelerate training by quickly getting to reasonable
	 * parameters on only a subset of the data first, then on progressively
	 * more data.
	 * @param training The training Instances.
	 * @param numIterationsPerProportion Maximum number of Maximizer iterations per training proportion.
	 * @param trainingProportions If non-null, train on increasingly
	 * larger portions of the data, e.g. new double[] {0.2, 0.5, 1.0}. This can
	 * sometimes speed up convergence.
	 * Be sure to end in 1.0 if you want to train on all the data in the end.
	 * @return True if training has converged.
	 */
	public boolean train (InstanceList training, int numIterationsPerProportion, double[] trainingProportions) {
		int trainingIteration = 0;
		assert (trainingProportions.length > 0);
		boolean converged = false;
		for (int i = 0; i < trainingProportions.length; i++) {
			assert (trainingProportions[i] <= 1.0);
			logger.info ("Training on proportion "+trainingProportions[i]+" of the data this round.");
			if (trainingProportions[i] == 1.0)
				converged = this.train (training, numIterationsPerProportion);
			else
				converged = this.train (training.split (new Random(1),
						new double[] {trainingProportions[i], 1-trainingProportions[i]})[0], numIterationsPerProportion);
			trainingIteration += numIterationsPerProportion;
		}
		return converged;
	}

	// gsc: see comment in getOptimizableCRF
//	public void setUseSparseWeights (boolean b) { useSparseWeights = b; }
//	public boolean getUseSparseWeights () { return useSparseWeights; }
//
//	// gsc
//	public void setUseUnsupportedTrick (boolean b) { useUnsupportedTrick = b; }
//	public boolean getUseUnsupportedTrick () { return useUnsupportedTrick; }

	// gsc: change the max. number of times the optimizer can be reset before
	// the "could not step in current direction" exception is propagated
	/**
	 * Sets the max. number of times the optimizer can be reset before throwing
	 * an exception.
	 * <p>
	 * Default value: DEFAULT_MAX_RESETS.
	 */
	public void setMaxResets (int maxResets) {
		this.maxResets = maxResets;
	}

	/** An optimizable CRF that contains a collection of objective functions. */
	public class OptimizableCRF implements Optimizable.ByGradientValue, Serializable {

		InstanceList trainingSet;
		double cachedValue = -123456789;
		double[] cachedGradie;
		BitSet infiniteValues = null;
		CRF crf;
		Optimizable.ByGradientValue[] opts;

		protected OptimizableCRF (CRF crf, InstanceList ilist) {
			// Set up
			this.crf = crf;
			this.trainingSet = ilist;
			this.opts = optimizableByValueGradientObjects;
			cachedGradie = new double[crf.parameters.getNumFactors()];
			cachedValueWeightsStamp = -1;
			cachedGradientWeightsStamp = -1;
		}

//		protected OptimizableCRF (CRF crf, InstanceList ilist)
//		{
//			// Set up
//			this.crf = crf;
//			this.trainingSet = ilist;
//			cachedGradie = new double[crf.parameters.getNumFactors()];
//			Class[] parameterTypes = new Class[] {CRF.class, InstanceList.class};
//			for (int i = 0; i < optimizableByValueGradientClasses.length; i++) {
//				try {
//					Constructor c = optimizableByValueGradientClasses[i].getConstructor(parameterTypes);
//					opts[i] = (Optimizable.ByGradientValue) c.newInstance(crf, ilist);
//				} catch (Exception e) { throw new IllegalStateException ("Couldn't construct Optimizable.ByGradientValue"); }
//			}
//			cachedValueWeightsStamp = -1;
//			cachedGradientWeightsStamp = -1;
//		}

		// TODO Move these implementations into CRF.java, and put here stubs that call them!
		public int getNumParameters () {
			return crf.parameters.getNumFactors();
		}

		public void getParameters (double[] buffer) {
			crf.parameters.getParameters(buffer);
		}

		public double getParameter (int index) {
			return crf.parameters.getParameter(index);
		}

		public void setParameters (double[] buff) {
			crf.parameters.setParameters(buff);
			crf.weightsValueChanged();
		}

		public void setParameter (int index, double value) {
			crf.parameters.setParameter(index, value);
			crf.weightsValueChanged();
		}

		/** Returns the log probability of the training sequence labels and the prior over parameters. */
		public double getValue () {
			if (crf.weightsValueChangeStamp != cachedValueWeightsStamp) {
				// The cached value is not up to date; it was calculated for a different set of CRF weights.
				long startingTime = System.currentTimeMillis();
				cachedValue = 0;
				for (int i = 0; i < opts.length; i++)
					cachedValue += opts[i].getValue();
				cachedValueWeightsStamp = crf.weightsValueChangeStamp; // cachedValue is now no longer stale
				logger.info ("getValue() (loglikelihood) = "+cachedValue);
				logger.fine ("Inference milliseconds = "+(System.currentTimeMillis() - startingTime));
			}
			return cachedValue;
		}

		public void getValueGradient (double[] buffer) {
			// PriorGradient is -parameter/gaussianPriorVariance
			// Gradient is (constraint - expectation + PriorGradient)
			// == -(expectation - constraint - PriorGradient).
			// Gradient points "up-hill", i.e. in the direction of higher value.
			if (cachedGradientWeightsStamp != crf.weightsValueChangeStamp) {
				getValue (); // This will fill in the expectations, updating them if necessary
				MatrixOps.setAll(cachedGradie, 0);
				double[] b2 = new double[buffer.length];
				for (int i = 0; i < opts.length; i++) {
					MatrixOps.setAll(b2, 0);
					opts[i].getValueGradient(b2);
					MatrixOps.plusEquals(cachedGradie, b2);
				}
				cachedGradientWeightsStamp = crf.weightsValueChangeStamp;
			}
			System.arraycopy(cachedGradie, 0, buffer, 0, cachedGradie.length);
		}

		// Serialization of OptimizableCRF

		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 0;

		private void writeObject (ObjectOutputStream out) throws IOException {
			out.writeInt (CURRENT_SERIAL_VERSION);
			out.writeObject(trainingSet);
			out.writeDouble(cachedValue);
			out.writeObject(cachedGradie);
			out.writeObject(infiniteValues);
			out.writeObject(crf);
		}

		private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
			in.readInt ();
			trainingSet = (InstanceList) in.readObject();
			cachedValue = in.readDouble();
			cachedGradie = (double[]) in.readObject();
			infiniteValues = (BitSet) in.readObject();
			crf = (CRF) in.readObject();
		}

	}

	// Serialization for CRFTrainerByValueGradients

	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 1;
	static final int NULL_INTEGER = -1;

	/* Need to check for null pointers. */
	private void writeObject (ObjectOutputStream out) throws IOException {
		out.writeInt (CURRENT_SERIAL_VERSION);
		//out.writeInt(defaultFeatureIndex);
		out.writeInt(cachedGradientWeightsStamp);
		out.writeInt(cachedValueWeightsStamp);
//		out.writeInt(cachedWeightsStructureStamp);
//		out.writeBoolean (useSparseWeights);
		throw new IllegalStateException ("Implementation not yet complete.");
	}

	private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
		in.readInt ();
		//defaultFeatureIndex = in.readInt();
//		useSparseWeights = in.readBoolean();
		throw new IllegalStateException ("Implementation not yet complete.");
	}

}
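
A typical end-to-end use of this trainer, shown as a minimal sketch rather than code taken from the MALLET sources: it assumes the training data has already been serialized as a MALLET InstanceList (the TrainCRFExample class name and the args[0] path are placeholders), and it uses CRFOptimizableByLabelLikelihood as the single objective function.

    import java.io.File;

    import cc.mallet.fst.CRF;
    import cc.mallet.fst.CRFOptimizableByLabelLikelihood;
    import cc.mallet.fst.CRFTrainerByValueGradients;
    import cc.mallet.optimize.Optimizable;
    import cc.mallet.types.InstanceList;

    public class TrainCRFExample {
        public static void main (String[] args) throws Exception {
            // Placeholder: a serialized InstanceList produced by your own feature pipe.
            InstanceList trainingData = InstanceList.load (new File (args[0]));

            CRF crf = new CRF (trainingData.getPipe(), null);
            crf.addFullyConnectedStatesForLabels ();
            // The caller sets the weight dimensions; this trainer no longer does it (see the comments above).
            crf.setWeightsDimensionAsIn (trainingData, false);

            Optimizable.ByGradientValue[] objectives = new Optimizable.ByGradientValue[] {
                new CRFOptimizableByLabelLikelihood (crf, trainingData) };
            CRFTrainerByValueGradients trainer = new CRFTrainerByValueGradients (crf, objectives);

            // Train on 20%, then 50%, then all of the data, with at most 100 L-BFGS iterations per stage.
            trainer.train (trainingData, 100, new double[] {0.2, 0.5, 1.0});
        }
    }

The staged call above exercises the proportion-based train overload; calling trainer.train (trainingData, Integer.MAX_VALUE) instead trains on all the data until convergence.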




