All downloads are free. Search and download functionality uses the official Maven repository.

cc.mallet.classify.constraints.ge.MaxEntKLFLGEConstraints Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

There is a newer version: 2.0.12
Show newest version
/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.classify.constraints.ge;

import cc.mallet.types.MatrixOps;
import cc.mallet.util.Maths;

/**
 * Expectation constraints for use with GE that penalize the KL divergence
 * between each constraint's target label distribution and its normalized
 * expectation ({@code expectation[label] / count}).
 *
 * <p>Several constraints are kept together in one object, keyed by the
 * {@code fi} index passed to {@link #addConstraint(int, double[], double)}
 * (presumably a feature index — TODO confirm). This lets them be evaluated
 * in a single pass, which is more efficient.
 *
 * @author Gregory Druck
 */
public class MaxEntKLFLGEConstraints extends MaxEntFLGEConstraints {

  /**
   * Creates an empty group of KL-penalized constraints.
   *
   * @param numFeatures forwarded to the superclass constructor
   * @param numLabels forwarded to the superclass constructor; also bounds the
   *     label loop in {@link #getValue()}
   * @param useValues forwarded to the superclass constructor (its exact
   *     meaning is defined there — TODO confirm)
   */
  public MaxEntKLFLGEConstraints(int numFeatures, int numLabels, boolean useValues) {
    super(numFeatures, numLabels, useValues);
  }

  /**
   * Computes the weighted sum of negative KL divergences.
   *
   * <p>For every constraint with {@code count > 0}, this sums
   * {@code target[y] * (log(expectation[y] / count) - log(target[y]))} over
   * the labels {@code y} with {@code target[y] > 0}. That sum is multiplied
   * by the constraint's weight and added to the total. Constraints with a
   * non-positive (or NaN) count are skipped.
   *
   * @return the weighted total, or {@link Double#NEGATIVE_INFINITY} as soon
   *     as any counted constraint puts positive target mass on a label whose
   *     expectation is exactly zero
   */
  @Override
  public double getValue() {
    double total = 0.0;
    for (int key : constraints.keys()) {
      MaxEntFLGEConstraint c = constraints.get(key);
      // Written as !(x > 0) rather than x <= 0 so that a NaN count is skipped.
      if (!(c.count > 0.0)) {
        continue;
      }
      double negKL = 0.0;
      for (int y = 0; y < numLabels; ++y) {
        double p = c.target[y];
        // Labels without positive target mass contribute nothing (0 * log 0 := 0).
        if (!(p > 0.0)) {
          continue;
        }
        double e = c.expectation[y];
        if (e == 0.0) {
          // log(0) diverges: the penalty is unbounded.
          return Double.NEGATIVE_INFINITY;
        }
        // p * log(q) - p * log(p), with q = e / count.
        negKL += p * (Math.log(e / c.count) - Math.log(p));
      }
      assert(!Double.isNaN(negKL) && !Double.isInfinite(negKL));
      total += negKL * c.weight;
    }
    return total;
  }

  /**
   * Registers a KL constraint under index {@code fi}, replacing any
   * constraint already stored there.
   *
   * @param fi key under which the constraint is stored
   * @param ex target label distribution. With assertions enabled, it must sum
   *     to (almost) 1. The array is handed to the constraint as-is, with no
   *     copy made here.
   * @param weight multiplier applied to this constraint's term in
   *     {@link #getValue()}
   */
  @Override
  public void addConstraint(int fi, double[] ex, double weight) {
    assert(Maths.almostEquals(MatrixOps.sum(ex),1));
    MaxEntFLGEConstraint created = new MaxEntKLFLGEConstraint(ex, weight);
    constraints.put(fi, created);
  }

  /** A single constraint whose per-label value is {@code weight * target / expectation}. */
  protected class MaxEntKLFLGEConstraint extends MaxEntFLGEConstraint {

    /**
     * @param target target label distribution, passed to the superclass
     * @param weight constraint weight, passed to the superclass
     */
    public MaxEntKLFLGEConstraint(double[] target, double weight) {
      super(target, weight);
    }

    /**
     * Returns {@code weight * target[li] / expectation[li]}.
     *
     * <p>This equals the weight times the derivative of
     * {@code target[li] * log(expectation[li] / count)} with respect to
     * {@code expectation[li]}.
     *
     * @param li label index
     * @return 0 when both target and expectation are zero for {@code li};
     *     otherwise the ratio above. It is {@code +/-Infinity} if only the
     *     expectation is zero.
     */
    @Override
    public double getValue(int li) {
      assert(this.count != 0);
      double p = this.target[li];
      double q = this.expectation[li];
      // Avoid 0/0 = NaN for labels with neither target nor expected mass.
      return (p == 0 && q == 0) ? 0 : this.weight * (p / q);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy