All Downloads are FREE. Search and download functionality uses the official Maven repository.

cc.mallet.fst.CRFTrainerByStochasticGradient Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

There is a newer version: 2.0.12
Show newest version
package cc.mallet.fst;

import java.util.ArrayList;
import java.util.Collections;

import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;

import cc.mallet.fst.TransducerTrainer.ByInstanceIncrements;

/**
 * Trains CRF by stochastic gradient. Most effective on large training sets.
 * 
 * @author kedarb
 */
public class CRFTrainerByStochasticGradient extends ByInstanceIncrements {
	protected CRF crf;

	// t is the decaying factor. lambda is some regularization depending on the
	// training set size and the gaussian prior.
	protected double learningRate, t, lambda;

	protected int iterationCount = 0;

	protected boolean converged = false;

	protected CRF.Factors expectations, constraints;

	public CRFTrainerByStochasticGradient(CRF crf, InstanceList trainingSample) {
		this.crf = crf;
		this.expectations = new CRF.Factors(crf);
		this.constraints = new CRF.Factors(crf);
		this.setLearningRateByLikelihood(trainingSample);
	}

	public CRFTrainerByStochasticGradient(CRF crf, double learningRate) {
		this.crf = crf;
		this.learningRate = learningRate;
		this.expectations = new CRF.Factors(crf);
		this.constraints = new CRF.Factors(crf);
	}

	public int getIteration() {
		return iterationCount;
	}

	public Transducer getTransducer() {
		return crf;
	}

	public boolean isFinishedTraining() {
		return converged;
	}

	// Best way to choose learning rate is to run training on a sample and set
	// it to the rate that produces maximum increase in likelihood or accuracy.
	// Then, to be conservative just halve the learning rate.
	// In general, eta = 1/(lambda*t) where
	// lambda=priorVariance*numTrainingInstances
	// After an initial eta_0 is set, t_0 = 1/(lambda*eta_0)
	// After each training step eta = 1/(lambda*(t+t_0)), t=0,1,2,..,Infinity
	/** Automatically sets the learning rate to one that would be good */
	public void setLearningRateByLikelihood(InstanceList trainingSample) {
		int numIterations = 5; // was 10 -akm 1/25/08
		double bestLearningRate = Double.NEGATIVE_INFINITY;
		double bestLikelihoodChange = Double.NEGATIVE_INFINITY;

		double currLearningRate = 5e-11;
		while (currLearningRate < 1) {
			currLearningRate *= 2;
			crf.parameters.zero();
			double beforeLikelihood = computeLikelihood(trainingSample);
			double likelihoodChange = trainSample(trainingSample,
					numIterations, currLearningRate)
					- beforeLikelihood;
			System.out.println("likelihood change = " + likelihoodChange
					+ " for learningrate=" + currLearningRate);

			if (likelihoodChange > bestLikelihoodChange) {
				bestLikelihoodChange = likelihoodChange;
				bestLearningRate = currLearningRate;
			}
		}

		// reset the parameters
		crf.parameters.zero();
		// conservative estimate for learning rate
		bestLearningRate /= 2;
		System.out.println("Setting learning rate to " + bestLearningRate);
		setLearningRate(bestLearningRate);
	}

	private double trainSample(InstanceList trainingSample, int numIterations,
			double rate) {
		double lambda = trainingSample.size();
		double t = 1 / (lambda * rate);

		double loglik = Double.NEGATIVE_INFINITY;
		for (int i = 0; i < numIterations; i++) {
			loglik = 0.0;
			for (int j = 0; j < trainingSample.size(); j++) {
				rate = 1 / (lambda * t);
				loglik += trainIncrementalLikelihood(trainingSample.get(j),
						rate);
				t += 1.0;
			}
		}

		return loglik;
	}

	private double computeLikelihood(InstanceList trainingSample) {
		double loglik = 0.0;
		for (int i = 0; i < trainingSample.size(); i++) {
			Instance trainingInstance = trainingSample.get(i);
			FeatureVectorSequence fvs = (FeatureVectorSequence) trainingInstance
					.getData();
			Sequence labelSequence = (Sequence) trainingInstance.getTarget();
			loglik += new SumLatticeDefault(crf, fvs, labelSequence, null)
					.getTotalWeight();
			loglik -= new SumLatticeDefault(crf, fvs, null, null)
					.getTotalWeight();
		}
		constraints.zero();
		expectations.zero();
		return loglik;
	}

	public void setLearningRate(double r) {
		this.learningRate = r;
	}

	public double getLearningRate() {
		return this.learningRate;
	}

	public boolean train(InstanceList trainingSet, int numIterations) {
		return train(trainingSet, numIterations, 1);
	}

	public boolean train(InstanceList trainingSet, int numIterations,
			int numIterationsBetweenEvaluation) {
		assert (expectations.structureMatches(crf.parameters));
		assert (constraints.structureMatches(crf.parameters));
		lambda = 1.0 / trainingSet.size();
		t = 1.0 / (lambda * learningRate);
		converged = false;

		ArrayList trainingIndices = new ArrayList();
		for (int i = 0; i < trainingSet.size(); i++)
			trainingIndices.add(i);

		double oldLoglik = Double.NEGATIVE_INFINITY;
		while (numIterations-- > 0) {
			iterationCount++;

			// shuffle the indices
			Collections.shuffle(trainingIndices);

			double loglik = 0.0;
			for (int i = 0; i < trainingSet.size(); i++) {
				learningRate = 1.0 / (lambda * t);
				loglik += trainIncrementalLikelihood(trainingSet
						.get(trainingIndices.get(i)));
				t += 1.0;
			}

			System.out.println("loglikelihood[" + numIterations + "] = "
					+ loglik);

			if (Math.abs(loglik - oldLoglik) < 1e-3) {
				converged = true;
				break;
			}
			oldLoglik = loglik;

			Runtime.getRuntime().gc();

			if (iterationCount % numIterationsBetweenEvaluation == 0)
				runEvaluators();
		}

		return converged;
	}

	// TODO Add some way to train by batches of instances, where the batch
	// memberships are determined externally? Or provide some easy interface for
	// creating batches.
	public boolean trainIncremental(InstanceList trainingSet) {
		this.train(trainingSet, 1);
		return false;
	}

	public boolean trainIncremental(Instance trainingInstance) {
		assert (expectations.structureMatches(crf.parameters));
		trainIncrementalLikelihood(trainingInstance);
		return false;
	}

	/**
	 * Adjust the parameters by default learning rate according to the gradient
	 * of this single Instance, and return the true label sequence likelihood.
	 */
	public double trainIncrementalLikelihood(Instance trainingInstance) {
		return trainIncrementalLikelihood(trainingInstance, learningRate);
	}

	/**
	 * Adjust the parameters by learning rate according to the gradient of this
	 * single Instance, and return the true label sequence likelihood.
	 */
	public double trainIncrementalLikelihood(Instance trainingInstance,
			double rate) {
		double singleLoglik;
		constraints.zero();
		expectations.zero();
		FeatureVectorSequence fvs = (FeatureVectorSequence) trainingInstance
				.getData();
		Sequence labelSequence = (Sequence) trainingInstance.getTarget();
		singleLoglik = new SumLatticeDefault(crf, fvs, labelSequence,
				constraints.new Incrementor()).getTotalWeight();
		singleLoglik -= new SumLatticeDefault(crf, fvs, null,
				expectations.new Incrementor()).getTotalWeight();
		// Calculate parameter gradient given these instances: (constraints -
		// expectations)
		constraints.plusEquals(expectations, -1);
		// Change the parameters a little by this difference, obeying
		// weightsFrozen
		crf.parameters.plusEquals(constraints, rate, true);

		return singleLoglik;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy