
/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst.semi_supervised;
import java.util.ArrayList;
import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.LogNumber;
/**
* Runs the dynamic programming algorithm of [Mann and McCallum 08] for
* computing the gradient of a Generalized Expectation constraint that
* considers a single label of a linear chain CRF.
*
* See:
* "Generalized Expectation Criteria for Semi-Supervised Learning of Conditional Random Fields"
* Gideon Mann and Andrew McCallum
* ACL 2008
*
* gdruck NOTE: This version of the GE lattice computes the gradient
* for all constraints simultaneously.
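*
* A minimal usage sketch (hypothetical names; within MALLET this lattice is
* normally built by a GE optimizable, with log-domain marginals assumed to
* come from a forward-backward SumLattice that saved its xis):
* <pre>{@code
* CRF crf = ...;                              // model being trained
* FeatureVectorSequence fvs = ...;            // one unlabeled input sequence
* SumLattice marginals = ...;                 // e.g. SumLatticeDefault with xis saved
* ArrayList<GEConstraint> constraints = ...;  // GE constraints to apply
* CRF.Factors gradient = ...;                 // accumulator to increment
* int[][] reverseTrans = ...;                 // source states per destination
* int[][] reverseTransIndices = ...;          // transition indices per destination
* new GELattice(fvs, marginals.getGammas(), marginals.getXis(), crf,
*     reverseTrans, reverseTransIndices, gradient, constraints, false);
* }</pre>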
*
* @author Gregory Druck
* @author Gaurav Chandalia
* @author Gideon Mann
*/
public class GELattice {
// input length + 1
protected int latticeLength;
// the model
protected Transducer transducer;
// number of states in the FST
protected int numStates;
// dynamic programming lattice
protected LatticeNode[][] lattice;
// cache of the dot product between violation and
// constraint features; a null entry means the value is zero
protected LogNumber[][][] dotCache;
/**
* @param fvs Input FeatureVectorSequence
* @param gammas Marginals over single states
* @param xis Marginals over pairs of states
* @param transducer Transducer
* @param reverseTrans Source state indices for each destination state
* @param reverseTransIndices Transition indices for each destination state
* @param gradient Gradient to increment
* @param constraints List of constraints
* @param check Whether to run the debugging test to verify correctness (will be much slower if true)
*/
public GELattice(
FeatureVectorSequence fvs, double[][] gammas, double[][][] xis,
Transducer transducer, int[][] reverseTrans, int[][] reverseTransIndices, CRF.Factors gradient,
ArrayList<GEConstraint> constraints, boolean check) {
assert(gradient != null);
latticeLength = fvs.size() + 1;
this.transducer = transducer;
numStates = transducer.numStates();
// lattice
lattice = new LatticeNode[latticeLength][numStates];
for (int ip = 0; ip < latticeLength; ++ip) {
for (int a = 0; a < numStates; ++a) {
lattice[ip][a] = new LatticeNode();
}
}
dotCache = new LogNumber[latticeLength][numStates][numStates];
// TODO maybe this should be cached?
// Separate lists for constraints that look at one vs two states.
ArrayList<GEConstraint> constraints1 = new ArrayList<GEConstraint>();
ArrayList<GEConstraint> constraints2 = new ArrayList<GEConstraint>();
for (GEConstraint constraint : constraints) {
if (constraint.isOneStateConstraint()) {
constraints1.add(constraint);
}
else {
constraints2.add(constraint);
}
}
CRF crf = (CRF)transducer;
double dotEx = this.runForward(crf, constraints1, constraints2, gammas, xis, reverseTrans, fvs);
this.runBackward(crf, gammas, xis, reverseTrans, reverseTransIndices, fvs, dotEx, gradient);
if (check) {
this.check(constraints, gammas, xis, fvs);
}
}
/**
* Run forward pass of dynamic programming algorithm
*
* @param crf CRF
* @param constraints1 Constraints that consider one state.
* @param constraints2 Constraints that consider two states.
* @param gammas Marginals over single states
* @param xis Marginals over pairs of states
* @param reverseTrans Source state indices for each destination state
* @param fvs Input FeatureVectorSequence
* @return Expectation of constraint features dotted with violation terms
*/
private double runForward(CRF crf, ArrayList<GEConstraint> constraints1, ArrayList<GEConstraint> constraints2, double[][] gammas,
double[][][] xis, int[][] reverseTrans, FeatureVectorSequence fvs) {
double dotEx = 0;
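// accumulates the model expectation of the total constraint-feature value,
// E[f]; returned as dotEx for the covariance term used in runBackward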
LogNumber[] oneStateValueCache = new LogNumber[numStates];
LogNumber nuAlpha = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
LogNumber temp = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
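// LogNumber stores a signed value in log space; initializing with
// IMPOSSIBLE_WEIGHT (the log of zero) represents the value 0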
for (int ip = 0; ip < latticeLength-1; ++ip) {
FeatureVector fv = fvs.get(ip);
// speed things up by giving the constraints an
// opportunity to cache, for example, which
// constrained input features appear in this
// FeatureVector
for (GEConstraint constraint : constraints1) {
constraint.preProcess(fv);
}
for (GEConstraint constraint : constraints2) {
constraint.preProcess(fv);
}
boolean[] oneStateValComputed = new boolean[numStates];
for (int prev = 0; prev < numStates; prev++) {
nuAlpha.set(Transducer.IMPOSSIBLE_WEIGHT,true);
if (ip != 0) {
int[] prevPrevs = reverseTrans[prev];
// calculate only once: \sum_{y_{i-1}} w_a(y_{i-1}, y_i)
for (int ppi = 0; ppi < prevPrevs.length; ppi++) {
nuAlpha.plusEquals(lattice[ip-1][prevPrevs[ppi]].alpha[prev]);
}
}
assert (!Double.isNaN(nuAlpha.logVal));
CRF.State prevState = (CRF.State)crf.getState(prev);
LatticeNode node = lattice[ip][prev];
double[] xi = xis[ip][prev];
double gamma = gammas[ip][prev];
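// log-domain marginals: gamma = log p(y_ip = prev | x) and
// xi[curr] = log p(y_ip = prev, y_ip+1 = curr | x)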
for (int ci = 0; ci < prevState.numDestinations(); ci++) {
int curr = prevState.getDestinationState(ci).getIndex();
double dot = 0;
for (GEConstraint constraint : constraints2) {
dot += constraint.getCompositeConstraintFeatureValue(fv, ip, prev, curr);
}
// avoid recomputing one-state constraint features #labels times
if (!oneStateValComputed[curr]) {
double osVal = 0;
for (GEConstraint constraint : constraints1) {
osVal += constraint.getCompositeConstraintFeatureValue(fv, ip, prev, curr);
}
if (osVal < 0) {
dotEx += Math.exp(gammas[ip+1][curr]) * osVal;
oneStateValueCache[curr] = new LogNumber(Math.log(-osVal),false);
}
else if (osVal > 0) {
dotEx += Math.exp(gammas[ip+1][curr]) * osVal;
oneStateValueCache[curr] = new LogNumber(Math.log(osVal),true);
}
else {
oneStateValueCache[curr] = null;
}
oneStateValComputed[curr] = true;
}
// combine the one and two state constraint feature values
if (dot == 0 && oneStateValueCache[curr] == null) {
dotCache[ip][prev][curr] = null;
}
else if (dot == 0 && oneStateValueCache[curr] != null) {
dotCache[ip][prev][curr] = oneStateValueCache[curr];
}
else {
dotEx += Math.exp(xi[curr]) * dot;
if (dot < 0) {
dotCache[ip][prev][curr] = new LogNumber(Math.log(-dot),false);
}
else {
dotCache[ip][prev][curr] = new LogNumber(Math.log(dot),true);
}
if (oneStateValueCache[curr] != null) {
dotCache[ip][prev][curr].plusEquals(oneStateValueCache[curr]);
}
}
// update the dynamic programming table
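// alpha recurrence from the forward pass, in probability space:
//   alpha_ip(prev,curr) = p(y_ip=prev, y_ip+1=curr) * f(x, ip, prev, curr)
//     + p(y_ip+1=curr | y_ip=prev) * \sum_pp alpha_ip-1(pp, prev)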
if (dotCache[ip][prev][curr] != null) {
temp.set(xi[curr],true);
temp.timesEquals(dotCache[ip][prev][curr]);
node.alpha[curr].plusEquals(temp);
}
if (gamma == Transducer.IMPOSSIBLE_WEIGHT) {
node.alpha[curr] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
} else {
temp.set(xi[curr] - gamma,true);
temp.timesEquals(nuAlpha);
node.alpha[curr].plusEquals(temp);
}
assert (!Double.isNaN(node.alpha[curr].logVal)) : "xi: " + xi[curr] + ", gamma: "
+ gamma + ", constraint feature: " + dotCache[ip][prev][curr]
+ ", nuApha: " + nuAlpha + " dot: " + dot;
}
}
}
return dotEx;
}
/**
* Run backward pass of dynamic programming algorithm
*
* @param crf CRF
* @param gammas Marginals over single states
* @param xis Marginals over pairs of states
* @param reverseTrans Source state indices for each destination state
* @param reverseTransIndices Transition indices for each destination state
* @param fvs Input FeatureVectorSequence
* @param dotEx Expectation of constraint features dot violation terms
* @param gradient Gradient to increment
*/
private void runBackward(CRF crf, double[][] gammas, double[][][] xis, int[][] reverseTrans, int[][] reverseTransIndices,
FeatureVectorSequence fvs, double dotEx, CRF.Factors gradient) {
LogNumber nuBeta = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
LogNumber dot = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
LogNumber temp = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
LogNumber temp2 = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
LogNumber nextDot;
for (int ip = latticeLength-2; ip >= 0; --ip) {
for (int curr = 0; curr < numStates; ++curr) {
nuBeta.set(Transducer.IMPOSSIBLE_WEIGHT,true);
dot.set(Transducer.IMPOSSIBLE_WEIGHT,true);
// calculate only once: \sum_{y_{i+1}} w_b(y_i, y_{i+1})
CRF.State currState = (CRF.State)crf.getState(curr);
for (int ni = 0; ni < currState.numDestinations(); ni++){
int next = currState.getDestinationState(ni).getIndex();
nuBeta.plusEquals(lattice[ip+1][curr].beta[next]);
assert(!Double.isNaN(nuBeta.logVal));
nextDot = dotCache[ip+1][curr][next];
if (nextDot != null) {
double xi = xis[ip+1][curr][next];
temp.set(xi,true);
temp.timesEquals(nextDot);
dot.plusEquals(temp);
}
}
double gamma = gammas[ip+1][curr];
int[] prevStates = reverseTrans[curr];
for (int pi = 0; pi < prevStates.length; pi++) {
int prev = prevStates[pi];
CRF.State crfState = (CRF.State)crf.getState(prev);
LatticeNode node = lattice[ip][prev];
double xi = xis[ip][prev][curr];
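// beta recurrence from the backward pass, in probability space:
//   beta_ip(prev,curr) = p(y_ip=prev | y_ip+1=curr)
//     * (\sum_next p(y_ip+1=curr, y_ip+2=next) * f(x, ip+1, curr, next)
//        + \sum_next beta_ip+1(curr, next))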
if (gamma == Transducer.IMPOSSIBLE_WEIGHT) {
node.beta[curr] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
} else {
// constraint feature values cached in Forward pass
temp.set(dot.logVal,dot.sign);
temp.plusEquals(nuBeta);
temp2.set(xi-gamma,true);
temp.timesEquals(temp2);
node.beta[curr].plusEquals(temp);
}
assert(!Double.isNaN(node.beta[curr].logVal))
: "xi: " + xi + ", gamma: " + gamma +
", log(constraint feature): " + dotCache[ip][prev][curr];
// compute and update gradient!
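// the GE gradient for each transition is a covariance:
//   E[f * 1(y_ip=prev, y_ip+1=curr)] - E[f] * E[1(y_ip=prev, y_ip+1=curr)]
// alpha + beta supply the first term; transProb * dotEx the second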
double transProb = Math.exp(xi);
double covFirstTerm = node.alpha[curr].exp() + node.beta[curr].exp();
double contribution = (covFirstTerm - (transProb * dotEx));
int nwi = crfState.getWeightNames(reverseTransIndices[curr][pi]).length;
int weightsIndex;
for (int wi = 0; wi < nwi; wi++) {
weightsIndex = ((CRF)transducer).getWeightsIndex(crfState.getWeightNames(reverseTransIndices[curr][pi])[wi]);
gradient.weights[weightsIndex].plusEqualsSparse (fvs.get(ip), contribution);
gradient.defaultWeights[weightsIndex] += contribution;
}
}
}
}
}
/**
* Verifies the correctness of the lattice computations.
*/
public void check(ArrayList<GEConstraint> constraints, double[][] gammas, double[][][] xis, FeatureVectorSequence fvs) {
// direct computation of the expectation of the total constraint-feature value
double ex1 = 0.0;
for (int ip = 0; ip < latticeLength-1; ++ip) {
for (int si1 = 0; si1 < numStates; si1++) {
for (int si2 = 0; si2 < numStates; si2++) {
double dot = 0;
for (GEConstraint constraint : constraints) {
dot += constraint.getCompositeConstraintFeatureValue(fvs.get(ip), ip, si1, si2);
}
double prob = Math.exp(xis[ip][si1][si2]);
ex1 += prob * dot;
}
}
}
double ex2 = 0.0;
for (int ip = 0; ip < latticeLength-1; ++ip) {
double ex3 = 0.0;
for (int s1 = 0; s1 < numStates; ++s1) {
LatticeNode node = lattice[ip][s1];
for (int s2 = 0; s2 < numStates; ++s2) {
ex3 += node.alpha[s2].exp() + node.beta[s2].exp();
}
}
// each position's total should equal the directly computed expectation
assert(Math.abs(ex1 - ex3) < 1e-6) : ex1 + " " + ex3;
ex2 += ex3;
}
ex2 = ex2 / (latticeLength - 1);
// the per-position average should also equal the directly computed expectation
assert(Math.abs(ex1 - ex2) < 1e-6) : ex1 + " " + ex2;
}
public LogNumber getAlpha(int ip, int s1, int s2) {
return lattice[ip][s1].alpha[s2];
}
public LogNumber getBeta(int ip, int s1, int s2) {
return lattice[ip][s1].beta[s2];
}
/**
* Contains forward-backward vectors corresponding to an input position and a
* state index.
*/
protected class LatticeNode {
// for the node at input position ip and state a, alpha[s] and beta[s]
// track the lattice values for the transition from a to state s at ip+1
protected LogNumber[] alpha;
protected LogNumber[] beta;
public LatticeNode() {
alpha = new LogNumber[numStates];
beta = new LogNumber[numStates];
for (int si = 0; si < numStates; ++si) {
alpha[si] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
beta[si] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
}
}
}
}