
cc.mallet.classify.constraints.ge.MaxEntKLFLGEConstraints Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify.constraints.ge;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Maths;
/**
 * Group of expectation constraints for use with GE whose objective
 * penalizes the KL divergence of the model expectation from a target
 * distribution.
 *
 * Several constraints (keyed by feature index) are held in one object
 * to make things more efficient.
 *
 * @author Gregory Druck
 */
public class MaxEntKLFLGEConstraints extends MaxEntFLGEConstraints {

  /**
   * Creates an empty constraint group; all three arguments are forwarded
   * unchanged to the superclass constructor.
   */
  public MaxEntKLFLGEConstraints(int numFeatures, int numLabels, boolean useValues) {
    super(numFeatures, numLabels, useValues);
  }

  /**
   * Computes the weighted sum, over every stored constraint with a positive
   * count, of the negative KL divergence between the constraint's target
   * distribution and its expectation normalized by the count.
   *
   * @return the accumulated value, or {@code Double.NEGATIVE_INFINITY} as soon
   *         as a label with positive target mass has an expectation of exactly 0
   */
  public double getValue() {
    double total = 0.0;
    for (int fi : constraints.keys()) {
      MaxEntFLGEConstraint current = constraints.get(fi);
      // Constraints whose count is not positive contribute nothing.
      if (!(current.count > 0.0)) {
        continue;
      }

      double negativeKl = 0.0;
      for (int li = 0; li < numLabels; li++) {
        double p = current.target[li];
        // Labels without target mass add no term to the divergence.
        if (!(p > 0.0)) {
          continue;
        }
        // Non-zero target with zero expectation: infinite penalty.
        if (current.expectation[li] == 0.0) {
          return Double.NEGATIVE_INFINITY;
        }
        // p*log(q) - p*log(p), i.e. one term of the negative KL,
        // with q = expectation / count.
        negativeKl += p * (Math.log(current.expectation[li] / current.count) - Math.log(p));
      }

      assert !(Double.isNaN(negativeKl) || Double.isInfinite(negativeKl));
      total += negativeKl * current.weight;
    }
    return total;
  }

  /**
   * Registers a KL constraint for feature index {@code fi}, replacing whatever
   * the underlying map previously held under that key.
   *
   * @param fi     feature index used as the map key
   * @param ex     target distribution; asserted to sum to (almost) 1
   * @param weight weight handed to the new constraint
   */
  @Override
  public void addConstraint(int fi, double[] ex, double weight) {
    // The sum is only evaluated when assertions are enabled.
    assert Maths.almostEquals(MatrixOps.sum(ex), 1);
    MaxEntKLFLGEConstraint created = new MaxEntKLFLGEConstraint(ex, weight);
    constraints.put(fi, created);
  }

  /**
   * Single KL constraint: a target distribution plus a weight, both passed
   * straight to the superclass constructor.
   */
  protected class MaxEntKLFLGEConstraint extends MaxEntFLGEConstraint {

    public MaxEntKLFLGEConstraint(double[] target, double weight) {
      super(target, weight);
    }

    /**
     * Returns {@code weight * target[li] / expectation[li]} for label
     * {@code li}, or 0 when both the target and the expectation are 0.
     * NOTE(review): presumably the per-label factor used when computing
     * the gradient -- confirm against the caller.
     *
     * @param li label index
     */
    @Override
    public double getValue(int li) {
      assert this.count != 0;
      // Anything other than the 0/0 case goes through the plain ratio.
      if (this.target[li] != 0 || this.expectation[li] != 0) {
        return this.weight * (this.target[li] / this.expectation[li]);
      }
      return 0;
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy