cc.mallet.fst.semi_supervised.constraints.TwoLabelKLGEConstraints Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
The newest version!
/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst.semi_supervised.constraints;
import java.util.ArrayList;
import gnu.trove.TIntIntHashMap;
import cc.mallet.fst.semi_supervised.StateLabelMap;
/**
 * A set of constraints on distributions over consecutive
 * labels conditioned on input features.
 *
 * <p>This is to be used with GE, and penalizes the
 * KL divergence between model and target distributions.
 *
 * <p>Multiple constraints are grouped together here
 * to make things more efficient.
 *
 * @author Gregory Druck
 */
public class TwoLabelKLGEConstraints extends TwoLabelGEConstraints {

  /**
   * Creates a constraint set by delegating to the no-argument superclass
   * constructor. Constraints are registered afterwards through
   * {@code addConstraint}.
   */
  public TwoLabelKLGEConstraints() {
    super();
  }

  /**
   * Constructor used only by {@link #copy()}; it forwards all three arguments
   * unchanged to the superclass constructor.
   *
   * @param constraintsList the constraints, in insertion order
   * @param constraintsMap  maps an input feature index to a position in
   *                        {@code constraintsList}
   * @param map             the state/label map supplying the number of labels
   */
  private TwoLabelKLGEConstraints(ArrayList<TwoLabelGEConstraint> constraintsList,
      TIntIntHashMap constraintsMap, StateLabelMap map) {
    super(constraintsList, constraintsMap, map);
  }

  /**
   * Builds a new {@code TwoLabelKLGEConstraints} from this instance's
   * constraint list, index map and label map.
   *
   * <p>NOTE(review): whether the new object shares or duplicates that state is
   * decided by the superclass constructor, which is not visible here.
   *
   * @return the newly constructed constraint set
   */
  public GEConstraint copy() {
    return new TwoLabelKLGEConstraints(this.constraintsList, this.constraintsMap, this.map);
  }

  /**
   * Registers a KL constraint for the input feature {@code fi}.
   *
   * <p>The constraint is appended to {@code constraintsList} and its position
   * in that list is stored under {@code fi} in {@code constraintsMap}.
   *
   * @param fi     index of the input feature the constraint is attached to
   * @param target target values indexed as {@code [previous label][current label]}
   * @param weight multiplier applied to this constraint's contribution
   */
  @Override
  public void addConstraint(int fi, double[][] target, double weight) {
    constraintsList.add(new TwoLabelKLGEConstraint(target, weight));
    constraintsMap.put(fi, constraintsList.size() - 1);
  }

  /**
   * Computes the weighted sum, over all registered constraints with a positive
   * count, of the negative KL divergence between the target and the
   * count-normalized expectation.
   *
   * <p>For each label pair with a positive target {@code p} and expectation
   * {@code e}, the term added is {@code p * (log(e / count) - log(p))}.
   * Constraints whose count is not positive are skipped.
   *
   * @return the accumulated value, or {@link Double#NEGATIVE_INFINITY} as soon
   *         as a label pair with positive target has an expectation of exactly
   *         zero
   */
  @Override
  public double getValue() {
    double value = 0.0;
    for (int fi : constraintsMap.keys()) {
      TwoLabelGEConstraint constraint = constraintsList.get(constraintsMap.get(fi));
      if (constraint.count > 0.0) {
        // Loop-invariant; read it once per constraint rather than on every
        // iteration of both loops. Kept inside the count guard so the label
        // map is only consulted when a constraint is actually evaluated.
        final int numLabels = map.getNumLabels();
        double constraintValue = 0.0;
        for (int prevLi = 0; prevLi < numLabels; prevLi++) {
          for (int currLi = 0; currLi < numLabels; currLi++) {
            final double p = constraint.target[prevLi][currLi];
            if (p > 0.0) {
              final double e = constraint.expectation[prevLi][currLi];
              if (e == 0.0) {
                // The model puts no mass where the target does, so the
                // divergence is unbounded.
                return Double.NEGATIVE_INFINITY;
              }
              // p*log(q) - p*log(p), i.e. negative KL, with q = e / count
              constraintValue += p * (Math.log(e / constraint.count) - Math.log(p));
            }
          }
        }
        assert !Double.isNaN(constraintValue) && !Double.isInfinite(constraintValue);
        value += constraintValue * constraint.weight;
      }
    }
    return value;
  }

  /**
   * A single two-label constraint whose per-label-pair value is
   * {@code weight * target / expectation}.
   */
  protected class TwoLabelKLGEConstraint extends TwoLabelGEConstraint {

    /**
     * Creates a constraint by passing both arguments to the superclass
     * constructor.
     *
     * @param target target values indexed as {@code [previous label][current label]}
     * @param weight multiplier applied to this constraint's contribution
     */
    public TwoLabelKLGEConstraint(double[][] target, double weight) {
      super(target, weight);
    }

    /**
     * Returns {@code weight * target / expectation} for one label pair.
     *
     * <p>This ratio is the derivative of {@code weight * target * log(expectation)}
     * with respect to the expectation; presumably callers use it as the
     * gradient contribution of this label pair (callers are not visible here).
     *
     * @param liPrev index of the previous label
     * @param liCurr index of the current label
     * @return {@code 0} when both target and expectation are zero; otherwise
     *         the weighted ratio, which is infinite when the target is
     *         non-zero and the expectation is zero
     */
    @Override
    public double getValue(int liPrev, int liCurr) {
      assert this.count != 0;
      // 0/0 would be NaN; a label pair with no target and no expected mass
      // is defined to contribute nothing.
      if (this.target[liPrev][liCurr] == 0 && this.expectation[liPrev][liCurr] == 0) {
        return 0;
      }
      return this.weight * (this.target[liPrev][liCurr] / this.expectation[liPrev][liCurr]);
    }
  }
}