
cc.mallet.fst.semi_supervised.CRFTrainerByLikelihoodAndGE
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
package cc.mallet.fst.semi_supervised;
import java.util.ArrayList;
import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFCacheStaleIndicator;
import cc.mallet.fst.CRFOptimizableByBatchLabelLikelihood;
import cc.mallet.fst.CRFOptimizableByGradientValues;
import cc.mallet.fst.CRFOptimizableByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood;
import cc.mallet.fst.ThreadedOptimizable;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.InstanceList;
/**
 * Trains a CRF by jointly optimizing a label likelihood objective over labeled data
 * and a Generalized Expectation (GE) criteria objective over unlabeled data.
 */
public class CRFTrainerByLikelihoodAndGE extends TransducerTrainer {

  private boolean initSupervised;
  private boolean converged;
  private double geWeight;
  private double gpv;
  private int supIterations;
  private int numThreads;
  private int iteration;
  private CRF crf;
  private ArrayList<GEConstraint> constraints;
  private StateLabelMap map;

  public CRFTrainerByLikelihoodAndGE(CRF crf, ArrayList<GEConstraint> constraints, StateLabelMap map) {
    this.crf = crf;
    this.constraints = constraints;
    this.map = map;
    this.iteration = 0;
    this.converged = false;
    this.geWeight = 1.0;
    this.initSupervised = false;
    this.gpv = 10.0;
    this.numThreads = 1;
    this.supIterations = Integer.MAX_VALUE;
  }
  /** Weight of the GE term relative to the likelihood term (default 1.0). */
  public void setGEWeight(double weight) {
    this.geWeight = weight;
  }

  /** Gaussian prior variance applied to the likelihood term (default 10.0). */
  public void setGaussianPriorVariance(double gpv) {
    this.gpv = gpv;
  }

  /** If true, run a purely supervised training pass before the joint optimization. */
  public void setInitSupervised(boolean flag) {
    this.initSupervised = flag;
  }

  /** Maximum number of iterations for the optional supervised initialization. */
  public void setSupervisedIterations(int iterations) {
    this.supIterations = iterations;
  }

  /** Number of threads used for the likelihood and GE computations. */
  public void setNumThreads(int numThreads) {
    this.numThreads = numThreads;
  }
  @Override
  public Transducer getTransducer() {
    return crf;
  }

  @Override
  public int getIteration() {
    return iteration;
  }

  @Override
  public boolean isFinishedTraining() {
    return converged;
  }
  public boolean train(InstanceList trainingSet, InstanceList unlabeledSet, int numIterations) {
    // Log the sizes of the labeled and unlabeled sets.
    System.err.println(trainingSet.size());
    System.err.println(unlabeledSet.size());

    if (initSupervised) {
      // Optionally initialize the parameters with purely supervised training
      // on the labeled data before the joint likelihood + GE optimization.
      if (numThreads == 1) {
        CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
        trainer.setAddNoFactors(true);
        trainer.setGaussianPriorVariance(gpv);
        trainer.train(trainingSet, supIterations);
      }
      else {
        CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf, numThreads);
        trainer.setAddNoFactors(true);
        trainer.setGaussianPriorVariance(gpv);
        trainer.train(trainingSet, supIterations);
        trainer.shutdown();
      }
      runEvaluators();
    }

    // Train semi-supervised: build the label likelihood objective over the
    // labeled data, threaded if more than one thread was requested.
    Optimizable.ByGradientValue optLikelihood;
    if (numThreads == 1) {
      optLikelihood = new CRFOptimizableByLabelLikelihood(crf, trainingSet);
      ((CRFOptimizableByLabelLikelihood) optLikelihood).setGaussianPriorVariance(gpv);
    }
    else {
      CRFOptimizableByBatchLabelLikelihood likelihood =
        new CRFOptimizableByBatchLabelLikelihood(crf, trainingSet, numThreads);
      optLikelihood = new ThreadedOptimizable(likelihood, trainingSet, crf.getParameters().getNumFactors(),
        new CRFCacheStaleIndicator(crf));
      likelihood.setGaussianPriorVariance(gpv);
    }

    // Build the GE objective over the unlabeled data.
    CRFOptimizableByGE ge = new CRFOptimizableByGE(crf, constraints, unlabeledSet, map, numThreads, geWeight);
    // Turn off the prior here; it is already applied in the likelihood term above.
    ge.setGaussianPriorVariance(Double.POSITIVE_INFINITY);

    // Optimize the sum of the two objectives with L-BFGS.
    CRFOptimizableByGradientValues opt =
      new CRFOptimizableByGradientValues(crf, new Optimizable.ByGradientValue[] { optLikelihood, ge });
    LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(opt);
    try {
      converged = optimizer.optimize(numIterations);
    }
    catch (Exception e) {
      e.printStackTrace();
    }

    // Reset the optimizer and optimize once more; MALLET trainers commonly
    // perform a second L-BFGS pass after a reset.
    optimizer.reset();
    try {
      converged = optimizer.optimize(numIterations);
    }
    catch (Exception e) {
      e.printStackTrace();
    }

    if (numThreads > 1) {
      ((ThreadedOptimizable) optLikelihood).shutdown();
      ge.shutdown();
    }
    return converged;
  }
  @Override
  public boolean train(InstanceList trainingSet, int numIterations) {
    throw new RuntimeException(
      "Must use train(InstanceList trainingSet, InstanceList unlabeledSet, int numIterations) instead");
  }
}
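
Below is a minimal usage sketch showing how this trainer might be wired up. It is illustrative only: the trainer construction, setter calls, and train call come from the class above, while the driver class name, the InstanceList.load calls, the setWeightsDimensionDensely call, the two-argument StateLabelMap constructor, and the mention of OneLabelKLGEConstraints are assumptions about the surrounding MALLET API that should be verified against your MALLET version. The constraints list must be populated with concrete GEConstraint implementations for the GE term to have any effect.

// Hypothetical driver class; file arguments and names are placeholders.
import java.io.File;
import java.util.ArrayList;
import cc.mallet.fst.CRF;
import cc.mallet.fst.semi_supervised.CRFTrainerByLikelihoodAndGE;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.types.InstanceList;

public class TrainWithLikelihoodAndGE {
  public static void main(String[] args) throws Exception {
    // Assumed inputs: serialized InstanceLists built elsewhere with the same pipe.
    InstanceList labeled = InstanceList.load(new File(args[0]));
    InstanceList unlabeled = InstanceList.load(new File(args[1]));

    // Usual MALLET CRF setup over the labeled data's pipe.
    CRF crf = new CRF(labeled.getPipe(), null);
    crf.addFullyConnectedStatesForLabels();
    crf.setWeightsDimensionDensely();

    // Map CRF states to output labels; the two-argument constructor is an
    // assumption based on other MALLET semi-supervised trainers.
    StateLabelMap map = new StateLabelMap(crf.getOutputAlphabet(), true);

    // GE constraints over the unlabeled data; populate with implementations
    // from cc.mallet.fst.semi_supervised.constraints (e.g. OneLabelKLGEConstraints).
    ArrayList<GEConstraint> constraints = new ArrayList<GEConstraint>();

    CRFTrainerByLikelihoodAndGE trainer = new CRFTrainerByLikelihoodAndGE(crf, constraints, map);
    trainer.setGaussianPriorVariance(10.0);
    trainer.setGEWeight(1.0);
    trainer.setInitSupervised(true);
    trainer.setSupervisedIterations(100);
    trainer.setNumThreads(1);
    trainer.train(labeled, unlabeled, 500);
  }
}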