// cc.mallet.fst.semi_supervised.constraints.OneLabelGEConstraints Maven / Gradle / Ivy
// Go to download
// Show more of this group Show more artifacts with this name
// Show all versions of mallet Show documentation
// MALLET is a Java-based package for statistical natural language processing,
// document classification, clustering, topic modeling, information extraction,
// and other machine learning applications to text.
// The newest version!
/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst.semi_supervised.constraints;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.BitSet;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
/**
* A set of constraints on distributions over single
* labels conditioned on the presence of input features.
*
* Subclasses are to be used with GE.
*
* Multiple constraints are grouped together here
* to make things more efficient.
*
* @author Gregory Druck
*/
public abstract class OneLabelGEConstraints implements GEConstraint {

  // Maps an input feature index to the constraint registered for that feature.
  protected TIntObjectHashMap<OneLabelGEConstraint> constraints;

  // Translates state indices into label indices. It is null after the
  // no-argument constructor until setStateLabelMap is called.
  protected StateLabelMap map;

  // Indices of the constrained features that fired in the FeatureVector most
  // recently passed to preProcess(FeatureVector).
  protected TIntArrayList cache;

  /**
   * Creates an empty constraint set. The state-label map is left unset and
   * must be supplied through {@link #setStateLabelMap(StateLabelMap)} before
   * any method that reads it is called.
   */
  public OneLabelGEConstraints() {
    this.constraints = new TIntObjectHashMap<OneLabelGEConstraint>();
    this.cache = new TIntArrayList();
  }

  /**
   * Creates a constraint set backed by the given map. The map is stored by
   * reference, not copied.
   *
   * @param constraints constraints keyed by input feature index
   * @param map state-to-label mapping
   */
  protected OneLabelGEConstraints(TIntObjectHashMap<OneLabelGEConstraint> constraints, StateLabelMap map) {
    this.constraints = constraints;
    this.map = map;
    this.cache = new TIntArrayList();
  }

  /**
   * Registers a constraint for an input feature. The concrete constraint type
   * is chosen by the subclass.
   *
   * @param fi input feature index the constraint is attached to
   * @param target target values for the constraint (presumably one entry per
   *     label -- confirm against the subclass implementation)
   * @param weight weight of the constraint
   */
  public abstract void addConstraint(int fi, double[] target, double weight);

  /**
   * @return always {@code true}
   */
  public boolean isOneStateConstraint() {
    return true;
  }

  /**
   * Replaces the state-to-label mapping used by this constraint set.
   *
   * @param map the new mapping
   */
  public void setStateLabelMap(StateLabelMap map) {
    this.map = map;
  }

  /**
   * Rebuilds {@link #cache} with the constrained features that fire in
   * {@code fv}. The cached indices are what
   * {@link #getCompositeConstraintFeatureValue(FeatureVector, int, int, int)}
   * subsequently sums over.
   *
   * @param fv feature vector to scan
   */
  public void preProcess(FeatureVector fv) {
    collectConstrainedFeatures(fv, cache);
  }

  /**
   * Finds the instances that contain constrained input features and updates
   * the occurrence count of every constraint along the way.
   *
   * <p>Each instance's data is cast to {@link FeatureVectorSequence}. For every
   * position, each constrained feature present increments that constraint's
   * count. A constraint keyed by the alphabet size (an index no real feature
   * uses) is counted once per position regardless of the features present.
   *
   * @param data instances to scan
   * @return a bit set whose i-th bit is set when the i-th instance triggered
   *     at least one constraint
   */
  public BitSet preProcess(InstanceList data) {
    BitSet bitSet = new BitSet(data.size());
    int ii = 0;
    for (Instance instance : data) {
      FeatureVectorSequence fvs = (FeatureVectorSequence) instance.getData();
      for (int ip = 0; ip < fvs.size(); ip++) {
        FeatureVector fv = fvs.get(ip);
        for (int loc = 0; loc < fv.numLocations(); loc++) {
          int fi = fv.indexAtLocation(loc);
          if (constraints.containsKey(fi)) {
            constraints.get(fi).count += 1;
            bitSet.set(ii);
          }
        }
        // A constraint stored under the alphabet size applies at every position.
        int defaultIndex = fv.getAlphabet().size();
        if (constraints.containsKey(defaultIndex)) {
          bitSet.set(ii);
          constraints.get(defaultIndex).count += 1;
        }
      }
      ii++;
    }
    return bitSet;
  }

  /**
   * Sums the values of all cached constraints for the label of the
   * destination state.
   *
   * <p>This relies on {@link #preProcess(FeatureVector)} having been called
   * for the current position; {@code fv}, {@code ip} and {@code si1} are not
   * read by this implementation.
   *
   * @param fv feature vector at the current position (unused here)
   * @param ip position in the sequence (unused here)
   * @param si1 source state index (unused here)
   * @param si2 destination state index; mapped to a label index via the
   *     state-label map
   * @return the sum of {@code getValue(label)} over the cached constraints
   */
  public double getCompositeConstraintFeatureValue(FeatureVector fv, int ip, int si1, int si2) {
    double value = 0;
    int li2 = map.getLabelIndex(si2);
    for (int i = 0; i < cache.size(); i++) {
      value += constraints.get(cache.getQuick(i)).getValue(li2);
    }
    return value;
  }

  /**
   * @return the value of this constraint set, as defined by the subclass
   */
  public abstract double getValue();

  /**
   * Replaces every constraint's expectation array with a fresh zero-filled
   * array of length {@code map.getNumLabels()}. The state-label map must
   * therefore be set before this is called on a non-empty constraint set.
   */
  public void zeroExpectations() {
    for (int fi : constraints.keys()) {
      constraints.get(fi).expectation = new double[map.getNumLabels()];
    }
  }

  /**
   * Accumulates, for every constrained feature that fires at a position, the
   * exponentiated gamma of each non-start-label state into the constraint's
   * expectation for that state's label.
   *
   * <p>{@code null} lattices are skipped. Within this class the expectation
   * arrays are only allocated by {@link #zeroExpectations()}, so that method
   * must have been called first.
   *
   * @param lattices lattices to accumulate over; entries may be {@code null}
   */
  public void computeExpectations(ArrayList<SumLattice> lattices) {
    // A local list is used instead of the cache field so that the result of
    // the last preProcess(FeatureVector) call is left untouched.
    TIntArrayList active = new TIntArrayList();
    for (int i = 0; i < lattices.size(); i++) {
      SumLattice lattice = lattices.get(i);
      if (lattice == null) {
        continue;
      }
      FeatureVectorSequence fvs = (FeatureVectorSequence) lattice.getInput();
      double[][] gammas = lattice.getGammas();
      for (int ip = 0; ip < fvs.size(); ++ip) {
        collectConstrainedFeatures(fvs.getFeatureVector(ip), active);
        if (active.size() == 0) {
          // No constraint fires here, so no expectation would be updated.
          continue;
        }
        for (int s = 0; s < map.getNumStates(); ++s) {
          int li = map.getLabelIndex(s);
          if (li != StateLabelMap.START_LABEL) {
            // Input position ip is read from row ip + 1 of the gammas
            // (presumably row 0 precedes the first input -- confirm in SumLattice).
            double gammaProb = Math.exp(gammas[ip + 1][s]);
            for (int j = 0; j < active.size(); j++) {
              constraints.get(active.getQuick(j)).expectation[li] += gammaProb;
            }
          }
        }
      }
    }
  }

  // Clears out and fills it with the indices of the constrained features that
  // fire in fv, followed by the alphabet-size index when a constraint is
  // registered under it.
  private void collectConstrainedFeatures(FeatureVector fv, TIntArrayList out) {
    out.resetQuick();
    for (int loc = 0; loc < fv.numLocations(); loc++) {
      int fi = fv.indexAtLocation(loc);
      if (constraints.containsKey(fi)) {
        out.add(fi);
      }
    }
    int defaultIndex = fv.getAlphabet().size();
    if (constraints.containsKey(defaultIndex)) {
      out.add(defaultIndex);
    }
  }

  /**
   * A single constraint attached to one input feature. It holds the target
   * values, the accumulated expectation, an occurrence count and a weight;
   * subclasses define how these combine into a per-label value.
   */
  protected abstract class OneLabelGEConstraint {

    protected double[] target;
    // null until assigned; within the enclosing class only zeroExpectations assigns it.
    protected double[] expectation;
    // Incremented by preProcess(InstanceList) each time the feature is seen.
    protected double count;
    protected double weight;

    /**
     * Creates a constraint with no expectation array and a count of zero.
     * The target array is stored by reference, not copied.
     *
     * @param target target values for this constraint
     * @param weight weight of this constraint
     */
    public OneLabelGEConstraint(double[] target, double weight) {
      this.target = target;
      this.weight = weight;
      this.expectation = null;
      this.count = 0;
    }

    /** @return the occurrence count of this constraint's feature */
    public double getCount() {
      return count;
    }

    /** @return the target array (the internal array, not a copy) */
    public double[] getTarget() {
      return target;
    }

    /** @return the expectation array (the internal array, not a copy); may be {@code null} */
    public double[] getExpectation() {
      return expectation;
    }

    /** @return the weight of this constraint */
    public double getWeight() {
      return weight;
    }

    /**
     * @param li label index
     * @return the value of this constraint for the given label, as defined
     *     by the subclass
     */
    public abstract double getValue(int li);
  }
}