
cc.mallet.fst.CRFOptimizableByGradientValues Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
package cc.mallet.fst;
import java.util.logging.Logger;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
/**
 * A CRF objective function that is the sum of multiple
 * objective functions that implement Optimizable.ByGradientValue.
 * <p>
 * The value reported by {@link #getValue()} is the sum of the component
 * objectives' values. The gradient reported by {@link #getValueGradient(double[])}
 * is the element-wise sum of their gradients. Both are cached; each is
 * recomputed only when the CRF's {@code weightsValueChangeStamp} differs from
 * the stamp recorded when it was last computed.
 *
 * @author Gregory Druck
 * @author Gaurav Chandalia
 */
public class CRFOptimizableByGradientValues implements Optimizable.ByGradientValue {
  private static final Logger logger = MalletLogger.getLogger(CRFOptimizableByGradientValues.class.getName());

  /** CRF weights stamp at which {@link #cachedValue} was computed (-1 before the first computation). */
  private int cachedValueWeightsStamp;
  /** CRF weights stamp at which {@link #cachedGradient} was computed (-1 before the first computation). */
  private int cachedGradientWeightsStamp;
  /** Sum of the component objectives' values at {@link #cachedValueWeightsStamp}. */
  private double cachedValue = Double.NEGATIVE_INFINITY;
  /** Sum of the component objectives' gradients; sized to the CRF's parameter count at construction. */
  private final double[] cachedGradient;
  /** Component objectives whose values and gradients are summed. */
  private final Optimizable.ByGradientValue[] optimizables;
  /** CRF whose parameters are read and written by this optimizable. */
  private final CRF crf;

  /**
   * @param crf CRF whose parameters we wish to estimate.
   * @param opts Optimizable.ByGradientValue objective functions.
   *
   * Parameters are estimated by maximizing the sum of the individual
   * objective functions. The {@code opts} array is stored by reference,
   * not copied.
   */
  public CRFOptimizableByGradientValues (CRF crf, Optimizable.ByGradientValue[] opts) {
    this.crf = crf;
    this.optimizables = opts;
    this.cachedGradient = new double[crf.parameters.getNumFactors()];
    // -1 marks both caches as not yet computed.
    // NOTE(review): assumes the CRF's weightsValueChangeStamp is never -1 — confirm.
    this.cachedValueWeightsStamp = -1;
    this.cachedGradientWeightsStamp = -1;
  }

  /** Returns the number of parameters currently held by the CRF. */
  public int getNumParameters () {
    return crf.parameters.getNumFactors();
  }

  /** Copies the CRF's parameters into {@code buffer}. */
  public void getParameters (double[] buffer) {
    crf.parameters.getParameters(buffer);
  }

  /** Returns the CRF parameter at {@code index}. */
  public double getParameter (int index) {
    return crf.parameters.getParameter(index);
  }

  /**
   * Replaces the CRF's parameters with {@code buff}. It then notifies the CRF
   * that its weights changed, which is what invalidates the cached value and gradient.
   */
  public void setParameters (double [] buff) {
    crf.parameters.setParameters(buff);
    crf.weightsValueChanged();
  }

  /**
   * Sets a single CRF parameter and notifies the CRF that its weights changed.
   */
  public void setParameter (int index, double value) {
    crf.parameters.setParameter(index, value);
    crf.weightsValueChanged();
  }

  /**
   * Returns the sum of the values of all component objective functions.
   * <p>
   * The sum is recomputed (and logged) only when the CRF weights stamp has
   * changed since the last computation; otherwise the cached sum is returned.
   */
  public double getValue () {
    if (crf.weightsValueChangeStamp != cachedValueWeightsStamp) {
      // The cached value is not up to date; it was calculated for a different set of CRF weights.
      cachedValue = 0;
      for (Optimizable.ByGradientValue opt : optimizables)
        cachedValue += opt.getValue();
      cachedValueWeightsStamp = crf.weightsValueChangeStamp; // cachedValue is now no longer stale
      logger.info ("getValue() = "+cachedValue);
    }
    return cachedValue;
  }

  /**
   * Writes the gradient of {@link #getValue()} into {@code buffer}. This is the
   * element-wise sum of the component objectives' gradients.
   * <p>
   * The gradient is recomputed only when the CRF weights stamp has changed
   * since the last computation; otherwise the cached gradient is copied out.
   *
   * @param buffer destination; must hold at least as many entries as the CRF
   *        had parameters when this object was constructed
   */
  public void getValueGradient (double [] buffer) {
    if (cachedGradientWeightsStamp != crf.weightsValueChangeStamp) {
      // Refresh the cached value for the same weights as well.
      // NOTE(review): presumably some components expect getValue() to run
      // before getValueGradient() — confirm.
      getValue ();
      MatrixOps.setAll(cachedGradient, 0);
      // Size the scratch buffer to the parameter vector, not to the caller's
      // buffer. Each component gradient is summed element-wise into
      // cachedGradient, so the two lengths must agree.
      double[] componentGradient = new double[cachedGradient.length];
      for (Optimizable.ByGradientValue opt : optimizables) {
        // Clear before each component in case it accumulates into the buffer.
        MatrixOps.setAll(componentGradient, 0);
        opt.getValueGradient(componentGradient);
        MatrixOps.plusEquals(cachedGradient, componentGradient);
      }
      cachedGradientWeightsStamp = crf.weightsValueChangeStamp;
    }
    System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy