
cc.mallet.classify.constraints.ge.MaxEntFLGEConstraints Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
The newest version!
/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify.constraints.ge;
import gnu.trove.TDoubleArrayList;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntObjectHashMap;
import java.util.BitSet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
/**
* Abstract expectation constraint for use with Generalized Expectation (GE).
*
* @author Gregory Druck
*/
public abstract class MaxEntFLGEConstraints implements MaxEntGEConstraint {

  // when true, feature values from the FeatureVector scale each contribution;
  // when false, every firing constrained feature contributes as if its value were 1
  protected boolean useValues;
  protected int numLabels;
  // also used as the key of the "default feature" constraint (label
  // regularization), which is treated as firing on every instance
  protected int numFeatures;

  // maps between input feature indices and constraints
  protected TIntObjectHashMap<MaxEntFLGEConstraint> constraints;

  // cache of set of constrained features that fire at last FeatureVector
  // provided in preprocess call
  protected TIntArrayList indexCache;
  // values parallel to indexCache; only populated when useValues is true
  protected TDoubleArrayList valueCache;

  /**
   * Creates an empty constraint set.
   *
   * @param numFeatures number of input features; this value is also the map key
   *                    under which the default-feature constraint is looked up
   * @param numLabels number of labels; length of each expectation array
   * @param useValues whether feature values (rather than mere presence) are used
   */
  public MaxEntFLGEConstraints(int numFeatures, int numLabels, boolean useValues) {
    this.numFeatures = numFeatures;
    this.numLabels = numLabels;
    this.useValues = useValues;
    this.constraints = new TIntObjectHashMap<MaxEntFLGEConstraint>();
    this.indexCache = new TIntArrayList();
    this.valueCache = new TDoubleArrayList();
  }

  /**
   * Registers a constraint for a feature.
   *
   * @param fi feature index the constraint is attached to
   * @param ex target values for the constraint
   * @param weight weight of the constraint
   */
  public abstract void addConstraint(int fi, double[] ex, double weight);

  /**
   * Sums the per-constraint values for the given label over the features
   * currently held in the cache.
   *
   * Note that {@code input} itself is not read here: the result is computed
   * from the cache filled by the most recent {@link #preProcess(FeatureVector)}
   * call, so that call must have been made with the same vector beforehand.
   *
   * @param input the vector the cache was built from (unused in this method)
   * @param label label index passed to each constraint's {@code getValue}
   * @return sum of constraint values, scaled by cached feature values when
   *         {@code useValues} is true
   */
  public double getCompositeConstraintFeatureValue(FeatureVector input, int label) {
    double value = 0;
    for (int i = 0; i < indexCache.size(); i++) {
      MaxEntFLGEConstraint constraint = constraints.get(indexCache.getQuick(i));
      if (useValues) {
        value += constraint.getValue(label) * valueCache.getQuick(i);
      }
      else {
        value += constraint.getValue(label);
      }
    }
    return value;
  }

  /**
   * Adds the contribution of one instance to the expectation arrays of every
   * constraint whose feature fires in {@code input}.
   *
   * The expectation arrays start out as {@code null}; {@link #zeroExpectations()}
   * must have been called before this method.
   *
   * @param input feature vector of the instance; re-cached via preProcess
   * @param dist per-label values, indexed 0..numLabels-1
   * @param weight multiplier applied to every entry of {@code dist}
   */
  public void computeExpectations(FeatureVector input, double[] dist, double weight) {
    preProcess(input);
    for (int li = 0; li < numLabels; li++) {
      double p = weight * dist[li];
      for (int i = 0; i < indexCache.size(); i++) {
        MaxEntFLGEConstraint constraint = constraints.get(indexCache.getQuick(i));
        if (useValues) {
          constraint.expectation[li] += p * valueCache.getQuick(i);
        }
        else {
          constraint.expectation[li] += p;
        }
      }
    }
  }

  /**
   * Replaces the expectation array of every constraint with a fresh
   * zero-filled array of length {@code numLabels}.
   */
  public void zeroExpectations() {
    for (int fi : constraints.keys()) {
      constraints.get(fi).expectation = new double[numLabels];
    }
  }

  /**
   * Accumulates, for every constraint, the weighted count of its feature over
   * {@code data}, and records which instances are touched by any constraint.
   *
   * @param data instances whose data objects are cast to FeatureVector
   * @return bit set in which bit {@code i} is set when the i-th instance (in
   *         iteration order) contains at least one constrained feature, or
   *         when a default-feature constraint (key {@code numFeatures}) exists
   */
  public BitSet preProcess(InstanceList data) {
    // count
    int ii = 0;
    BitSet bitSet = new BitSet(data.size());
    for (Instance instance : data) {
      double weight = data.getInstanceWeight(instance);
      FeatureVector fv = (FeatureVector)instance.getData();
      for (int loc = 0; loc < fv.numLocations(); loc++) {
        int fi = fv.indexAtLocation(loc);
        if (constraints.containsKey(fi)) {
          if (useValues) {
            constraints.get(fi).count += weight * fv.valueAtLocation(loc);
          }
          else {
            constraints.get(fi).count += weight;
          }
          bitSet.set(ii);
        }
      }
      // default feature, for label regularization
      if (constraints.containsKey(numFeatures)) {
        bitSet.set(ii);
        constraints.get(numFeatures).count += weight;
      }
      // advance only after all marking for this instance is done, so that the
      // default-feature bit lands on the current instance rather than the next
      ii++;
    }
    return bitSet;
  }

  /**
   * Rebuilds the index/value caches with the constrained features present in
   * {@code input}, followed by the default feature if it is constrained.
   *
   * @param input feature vector to scan
   */
  public void preProcess(FeatureVector input) {
    indexCache.resetQuick();
    if (useValues) valueCache.resetQuick();
    // cache constrained input features
    for (int loc = 0; loc < input.numLocations(); loc++) {
      int fi = input.indexAtLocation(loc);
      if (constraints.containsKey(fi)) {
        indexCache.add(fi);
        if (useValues) valueCache.add(input.valueAtLocation(loc));
      }
    }
    // default feature, for label regularization
    if (constraints.containsKey(numFeatures)) {
      indexCache.add(numFeatures);
      // the default feature is cached with value 1
      if (useValues) valueCache.add(1);
    }
  }

  /**
   * State held for a single constrained feature: its target, the running
   * expectation, the accumulated count and the constraint weight.
   */
  protected abstract class MaxEntFLGEConstraint {

    protected double[] target;
    // null until zeroExpectations() is called on the enclosing object
    protected double[] expectation;
    protected double count;
    protected double weight;

    /**
     * @param target target values; stored by reference, not copied
     * @param weight weight of this constraint
     */
    public MaxEntFLGEConstraint(double[] target, double weight) {
      this.target = target;
      this.weight = weight;
      this.expectation = null;
      this.count = 0;
    }

    /** @return the target array passed to the constructor (not a copy) */
    public double[] getTarget() {
      return target;
    }

    /** @return the current expectation array, or null if never zeroed */
    public double[] getExpectation() {
      return expectation;
    }

    /** @return the count accumulated by preProcess(InstanceList) */
    public double getCount() {
      return count;
    }

    /** @return the weight passed to the constructor */
    public double getWeight() {
      return weight;
    }

    /**
     * @param li label index
     * @return this constraint's value for label {@code li}; defined by subclasses
     */
    public abstract double getValue(int li);
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy