All Downloads are FREE. Search and download functionality uses the official Maven repository.

cc.mallet.fst.CRFOptimizableByGradientValues Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

There is a newer version: 2.0.12
Show newest version
package cc.mallet.fst;

import java.util.logging.Logger;

import cc.mallet.optimize.Optimizable;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;

/**
 * A CRF objective function that is the sum of multiple
 * objective functions that implement Optimizable.ByGradientValue.
 *
 * <p>Both the summed value and the summed gradient are cached. Each cache
 * remembers the value of <code>crf.weightsValueChangeStamp</code> it was
 * computed for and is recomputed only when the CRF's current stamp differs.
 *
 * @author Gregory Druck
 * @author Gaurav Chandalia
 */
public class CRFOptimizableByGradientValues implements Optimizable.ByGradientValue {
	private static final Logger logger = MalletLogger.getLogger(CRFOptimizableByGradientValues.class.getName());

	/** Value of crf.weightsValueChangeStamp for which cachedValue was computed; -1 until first computed. */
	private int cachedValueWeightsStamp;
	/** Value of crf.weightsValueChangeStamp for which cachedGradient was computed; -1 until first computed. */
	private int cachedGradientWeightsStamp;
	/** Sum of the values of all wrapped objective functions at the cached stamp. */
	private double cachedValue = Double.NEGATIVE_INFINITY;
	/** Sum of the gradients of all wrapped objective functions; one entry per CRF parameter. */
	private final double[] cachedGradient;
	/** The objective functions whose values and gradients are summed. */
	private final Optimizable.ByGradientValue[] optimizables;
	/** The CRF whose parameters are read and written through this object. */
	private final CRF crf;

	/**
	 * @param crf CRF whose parameters we wish to estimate.
	 * @param opts Optimizable.ByGradientValue objective functions.
	 *             The array is stored by reference, not copied.
	 *
	 * Parameters are estimated by maximizing the sum of the individual
	 * objective functions.
	 */
	public CRFOptimizableByGradientValues (CRF crf, Optimizable.ByGradientValue[] opts) {
		this.crf = crf;
		this.optimizables = opts;
		this.cachedGradient = new double[crf.parameters.getNumFactors()];
		// -1 marks both caches as "never computed".
		// NOTE(review): assumes the CRF's change stamp is never -1 -- confirm against CRF.
		this.cachedValueWeightsStamp = -1;
		this.cachedGradientWeightsStamp = -1;
	}

	/** Returns the number of parameters, as reported by the CRF's parameter object. */
	public int getNumParameters () {
		return crf.parameters.getNumFactors();
	}

	/** Copies the CRF's current parameters into <code>buffer</code> (delegates to the CRF's parameter object). */
	public void getParameters (double[] buffer) {
		crf.parameters.getParameters(buffer);
	}

	/** Returns the CRF parameter at position <code>index</code>. */
	public double getParameter (int index) {
		return crf.parameters.getParameter(index);
	}

	/** Sets all CRF parameters from <code>buff</code> and notifies the CRF that its weights changed. */
	public void setParameters (double [] buff) {
		crf.parameters.setParameters(buff);
		// Presumably advances crf.weightsValueChangeStamp so the caches below go stale -- confirm in CRF.
		crf.weightsValueChanged();
	}

	/** Sets the CRF parameter at position <code>index</code> and notifies the CRF that its weights changed. */
	public void setParameter (int index, double value) {
		crf.parameters.setParameter(index, value);
		// Presumably advances crf.weightsValueChangeStamp so the caches below go stale -- confirm in CRF.
		crf.weightsValueChanged();
	}

	/**
	 * Returns the sum of the values of all wrapped objective functions.
	 * The sum is recomputed (and logged) only when the CRF's weight change
	 * stamp differs from the one the cached value was computed for.
	 */
	public double getValue () {
		if (crf.weightsValueChangeStamp != cachedValueWeightsStamp) {
			// The cached value is not up to date; it was calculated for a different set of CRF weights.
			cachedValue = 0;
			for (int i = 0; i < optimizables.length; i++) {
				cachedValue += optimizables[i].getValue();
			}

			cachedValueWeightsStamp = crf.weightsValueChangeStamp;  // cachedValue is now no longer stale
			logger.info ("getValue() = "+cachedValue);
		}
		return cachedValue;
	}

	/**
	 * Writes the sum of the gradients of all wrapped objective functions
	 * into <code>buffer</code>.
	 *
	 * @param buffer destination array; must hold at least
	 *               {@link #getNumParameters()} entries. Only the first
	 *               <code>getNumParameters()</code> entries are written.
	 */
	public void getValueGradient (double [] buffer) {
		if (cachedGradientWeightsStamp != crf.weightsValueChangeStamp) {
			// Refresh the summed value before the gradients.
			// NOTE(review): presumably some objectives compute their gradient as a
			// by-product of getValue() -- confirm before removing this call.
			getValue ();
			MatrixOps.setAll(cachedGradient, 0);
			// Scratch space is sized by the number of parameters rather than by the
			// caller's buffer, so it always matches cachedGradient when summed below.
			double[] scratch = new double[cachedGradient.length];
			for (int i = 0; i < optimizables.length; i++) {
				// Cleared each round in case an objective accumulates into the array instead of overwriting it.
				MatrixOps.setAll(scratch, 0);
				optimizables[i].getValueGradient(scratch);
				MatrixOps.plusEquals(cachedGradient, scratch);
			}
			cachedGradientWeightsStamp = crf.weightsValueChangeStamp;
		}
		System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy