All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.fst.semi_supervised.CRFOptimizableByEntropyRegularization Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2009 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised;

import java.io.Serializable;
import java.util.logging.Logger;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.InstanceList;
import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.util.MalletLogger;

/**
 * A CRF objective function that is the entropy of the CRF's
 * predictions on unlabeled data.
 * 
 * References:
 * Feng Jiao, Shaojun Wang, Chi-Hoon Lee, Russell Greiner, Dale Schuurmans
 * "Semi-supervised conditional random fields for improved sequence segmentation and labeling"
 * ACL 2006
 *
 * Gideon Mann, Andrew McCallum
 * "Efficient Computation of Entropy Gradient for Semi-Supervised Conditional Random Fields"
 * HLT/NAACL 2007
 *
 * @author Gaurav Chandalia
 * @author Gregory Druck
 */
public class CRFOptimizableByEntropyRegularization implements Optimizable.ByGradientValue,
                                                   Serializable {
  private static Logger logger = MalletLogger.getLogger(CRFOptimizableByEntropyRegularization.class.getName());

	private int cachedValueWeightsStamp = -1;
	private int cachedGradientWeightsStamp = -1;
  
  // model's expectations according to entropy reg.
  protected CRF.Factors expectations;
  // used to update gradient
  protected Transducer.Incrementor incrementor;

  // contains labeled and unlabeled data
  protected InstanceList data;
  // the model
  protected CRF crf;

  // scale entropy values,
  // typically, (entropyRegGamma * numLabeled / numUnlabeled)
  protected double scalingFactor;

  // log probability of all the sequences, this is also the entropy due to all
  // the instances (updated in computeExpectations())
  protected double cachedValue;
  // gradient due to this optimizable (updated in getValueGradient())
  protected double[] cachedGradient;

  /**
   * Initializes the structures.
   */
  public CRFOptimizableByEntropyRegularization(CRF crf, InstanceList ilist,
                                    double scalingFactor) {
    data = ilist;
    this.crf = crf;
    this.scalingFactor = scalingFactor;

    // initialize the expectations using the model
    expectations = new CRF.Factors(crf);
    incrementor = expectations.new Incrementor();

    cachedValue = 0.0;
    cachedGradient = new double[crf.getParameters().getNumFactors()];
  }

  /**
   * Initializes the structures (sets the scaling factor to 1.0).
   */
  public CRFOptimizableByEntropyRegularization(CRF crf, InstanceList ilist) {
    this(crf, ilist, 1.0);
  }

  public double getScalingFactor() {
    return scalingFactor;
  }

  public void setScalingFactor(double scalingFactor) {
    this.scalingFactor = scalingFactor;
  }

  /**
   * Resets, computes and fills expectations from all instances, also updating
   * the entropy value. 

* * Analogous to CRFOptimizableByLabelLikelihood.getExpectationValue. */ public void computeExpectations() { expectations.zero(); // now, update the expectations due to each instance for entropy reg. for (int ii = 0; ii < data.size(); ii++) { FeatureVectorSequence input = (FeatureVectorSequence) data.get(ii).getData(); SumLattice lattice = new SumLatticeDefault(crf,input, true); // udpate the expectations EntropyLattice entropyLattice = new EntropyLattice( input, lattice.getGammas(), lattice.getXis(), crf, incrementor, scalingFactor); cachedValue += entropyLattice.getEntropy(); } } public double getValue() { if (crf.getWeightsValueChangeStamp() != cachedValueWeightsStamp) { // The cached value is not up to date; it was calculated for a different set of CRF weights. cachedValueWeightsStamp = crf.getWeightsValueChangeStamp(); cachedValue = 0; computeExpectations(); cachedValue = scalingFactor * cachedValue; assert(!Double.isNaN(cachedValue) && !Double.isInfinite(cachedValue)) : "Likelihood due to Entropy Regularization is NaN/Infinite"; logger.info("getValue() (entropy regularization) = " + cachedValue); } return cachedValue; } public void getValueGradient(double[] buffer) { if (cachedGradientWeightsStamp != crf.getWeightsValueChangeStamp()) { cachedGradientWeightsStamp = crf.getWeightsValueChangeStamp(); // cachedGradient will soon no longer be stale getValue(); // if this fails then look in computeExpectations expectations.assertNotNaNOrInfinite(); // fill up gradient expectations.getParameters(cachedGradient); } System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length); } // some get/set methods that have to be implemented public int getNumParameters() { return crf.getParameters().getNumFactors(); } public void getParameters(double[] buffer) { crf.getParameters().getParameters(buffer); } public void setParameters(double[] buffer) { crf.getParameters().setParameters(buffer); crf.weightsValueChanged(); } public double getParameter(int index) { return crf.getParameters().getParameter(index); } public void setParameter(int index, double value) { crf.getParameters().setParameter(index, value); crf.weightsValueChanged(); } // serialization stuff private static final long serialVersionUID = 1; }