
cc.mallet.grmm.learning.PwplACRFTrainer (jcore-mallet-2.0.9)
MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
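The class below trains arbitrary-structure CRFs with piecewise pseudolikelihood (PwPL): rather than normalizing over all label sequences, each clique contributes one local conditional per member variable, normalized only over that variable's outcomes. The snippet below is an illustrative sketch of that per-variable term with made-up numbers; PwplTermSketch and its sumLogProb (a naive stand-in for cc.mallet.util.Maths.sumLogProb used in the source) are hypothetical, not MALLET API.

// Illustrative sketch of the per-variable PwPL term (hypothetical numbers).
public class PwplTermSketch {
  // Naive log-sum-exp, standing in for cc.mallet.util.Maths.sumLogProb.
  static double sumLogProb (double[] vals) {
    double max = Double.NEGATIVE_INFINITY;
    for (double v : vals) max = Math.max (max, v);
    double sum = 0.0;
    for (double v : vals) sum += Math.exp (v - max);
    return max + Math.log (sum);
  }

  public static void main (String[] args) {
    // Log factor values s(y_m) for the M = 3 outcomes of one target variable,
    // with the other variables in its clique held at their observed values.
    double[] vals = { 1.2, -0.3, 0.5 };
    int observed = 0;                 // index of the observed outcome
    double logZ = sumLogProb (vals);  // normalizer over this variable only
    // Contribution of this (clique, variable) pair to the PwPL objective:
    System.out.println ("log p(y_v | rest of clique) = " + (vals[observed] - logZ));
  }
}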
/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://mallet.cs.umass.edu/
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.types.*;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.*;
import cc.mallet.util.FileUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import cc.mallet.util.Timing;
import cc.mallet.grmm.util.CachingOptimizable;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
/**
* Implementation of piecewise PL (Sutton and McCallum, 2007)
*
* NB The wrong-wrong options are for an extension that we tried that never quite worked
*
* Created: Mar 15, 2005
*
* @author
*/
/**
* Gradient is
* constraint - expectation - parameters/gaussianPriorVariance
*/
protected void computeValueGradient (double[] grad)
{
/* Index into current element of cachedGradient[] array. */
int gidx = 0;
// First do gradient wrt defaultWeights
for (int tidx = 0; tidx < templates.length; tidx++) {
SparseVector theseWeights = templates[tidx].getDefaultWeights ();
SparseVector theseConstraints = defaultConstraints[tidx];
SparseVector theseExpectations = defaultExpectations[tidx];
for (int j = 0; j < theseWeights.numLocations (); j++) {
double weight = theseWeights.valueAtLocation (j);
double constraint = theseConstraints.valueAtLocation (j);
double expectation = theseExpectations.valueAtLocation (j);
if (PwplACRFTrainer.printGradient) {
System.out.println (" gradient [" + gidx + "] = " + constraint + " (ctr) - " + expectation + " (exp) - " + (weight / gaussianPriorVariance) + " (reg) [feature=DEFAULT]");
}
grad[gidx++] = constraint - expectation - (weight / gaussianPriorVariance);
}
}
// Now do other weights
for (int tidx = 0; tidx < templates.length; tidx++) {
ACRF.Template tmpl = templates[tidx];
SparseVector[] weights = tmpl.getWeights ();
for (int i = 0; i < weights.length; i++) {
SparseVector thisWeightVec = weights[i];
SparseVector thisConstraintVec = constraints[tidx][i];
SparseVector thisExpectationVec = expectations[tidx][i];
for (int j = 0; j < thisWeightVec.numLocations (); j++) {
double w = thisWeightVec.valueAtLocation (j);
double gradient; // Computed below
double constraint = thisConstraintVec.valueAtLocation (j);
double expectation = thisExpectationVec.valueAtLocation (j);
/* A parameter may be set to -infinity by an external user.
* We set gradient to 0 because the parameter's value can
* never change anyway and it will mess up future calculations
* on the matrix. */
if (Double.isInfinite (w)) {
PwplACRFTrainer.logger.warning ("Infinite weight for node index " + i + " feature " + acrf.getInputAlphabet ().lookupObject (j));
gradient = 0.0;
} else {
gradient = constraint - (w / gaussianPriorVariance) - expectation;
}
if (PwplACRFTrainer.printGradient) {
int idx = thisWeightVec.indexAtLocation (j);
Object fname = acrf.getInputAlphabet ().lookupObject (idx);
System.out.println (" gradient [" + gidx + "] = " + constraint + " (ctr) - " + expectation + " (exp) - " + (w / gaussianPriorVariance) + " (reg) [feature=" + fname + "]");
}
grad[gidx++] = gradient;
}
}
}
}
/**
* For every feature f_k, computes the expected value of f_k
* over all possible label sequences given the list of instances
* we have.
*
* These values are stored in collector, that is,
* collector[i][j][k] gets the expected value for the
* feature for clique i, label assignment j, and input features k.
*/
private double collectExpectationsAndValue (ACRF.UnrolledGraph unrolled, Assignment observations, int inum)
{
double value = 0.0;
for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next ();
ACRF.Template tmpl = clique.getTemplate ();
int tidx = tmpl.index;
if (tidx == -1) continue;
for (int vi = 0; vi < clique.size (); vi++) {
Variable target = clique.get (vi);
value += computeValueGradientForAssn (observations, clique, target);
}
}
switch (wrongWrongType) {
case NO_WRONG_WRONG:
break;
case CONDITION_WW:
value += addConditionalWW (unrolled, inum);
break;
default:
throw new IllegalStateException ();
}
return value;
}
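/**
* Adds the extra "wrong-wrong" conditional terms recorded for instance
* inum. Each WrongWrong is re-resolved against the freshly unrolled graph
* and scored with the same per-variable local conditional as the ordinary
* PwPL terms.
*/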
private double addConditionalWW (ACRF.UnrolledGraph unrolled, int inum)
{
double value = 0;
if (allWrongWrongs != null) {
List wrongs = allWrongWrongs[inum];
for (Iterator it = wrongs.iterator (); it.hasNext ();) {
WrongWrong ww = (WrongWrong) it.next ();
Variable target = ww.findVariable (unrolled);
ACRF.UnrolledVarSet clique = ww.findVarSet (unrolled);
Assignment wrong = Assignment.makeFromSingleIndex (clique, ww.assnIdx);
// System.out.println ("Computing for WW: "+clique+" idx "+ww.assnIdx+" target "+target);
value += computeValueGradientForAssn (wrong, clique, target);
}
}
return value;
}
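/**
* Computes one per-variable PwPL term: the log probability of target's
* observed value conditioned on the other variables in clique, normalized
* over target's outcomes only. As a side effect, accumulates the local
* marginals exp(vals[assnIdx] - logZ) into expectations[][], which form
* the expectation part of the gradient in computeValueGradient.
*/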
private double computeValueGradientForAssn (Assignment observations, ACRF.UnrolledVarSet clique, Variable target)
{
numCvgaCalls++;
Timing timing = new Timing ();
ACRF.Template tmpl = clique.getTemplate ();
int tidx = tmpl.index;
Assignment cliqueAssn = Assignment.restriction (observations, clique);
int M = target.getNumOutcomes ();
double[] vals = new double [M];
int[] singles = new int [M];
for (int assnIdx = 0; assnIdx < M; assnIdx++) {
cliqueAssn.setValue (target, assnIdx);
vals[assnIdx] = computeLogFactorValue (cliqueAssn, tmpl, clique.getFv ());
singles[assnIdx] = cliqueAssn.singleIndex ();
}
double logZ = Maths.sumLogProb (vals);
for (int assnIdx = 0; assnIdx < M; assnIdx++) {
double marginal = Math.exp (vals[assnIdx] - logZ);
int expIdx = singles[assnIdx];
expectations[tidx][expIdx].plusEqualsSparse (clique.getFv (), marginal);
if (defaultExpectations[tidx].location (expIdx) != -1) {
defaultExpectations[tidx].incrementValue (expIdx, marginal);
}
}
int observedVal = observations.get (target);
timePerCvgaCall += timing.elapsedTime ();
return vals[observedVal] - logZ;
}
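/**
* Log value of the clique factor at cliqueAssn: the weight vector for that
* assignment dotted with the clique's feature vector, plus the default
* (bias) weight for the assignment.
*/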
private double computeLogFactorValue (Assignment cliqueAssn, ACRF.Template tmpl, FeatureVector fv)
{
SparseVector[] weights = tmpl.getWeights ();
int idx = cliqueAssn.singleIndex ();
SparseVector w = weights[idx];
double dp = w.dotProduct (fv);
dp += tmpl.getDefaultWeight (idx);
return dp;
}
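/**
* Accumulates empirical feature counts, the "constraint" part of the
* gradient. Each clique's feature vector is added with weight
* clique.size(), one count per member variable, matching the per-variable
* terms summed in collectExpectationsAndValue; wrong-wrong assignments
* are counted with weight 1.0.
*/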
public void collectConstraints (InstanceList ilist)
{
for (int inum = 0; inum < ilist.size (); inum++) {
PwplACRFTrainer.logger.finest ("*** Collecting constraints for instance " + inum);
Instance inst = ilist.get (inum);
ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph (inst, templates, null, false);
for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next ();
int tidx = clique.getTemplate ().index;
if (tidx == -1) continue;
int assn = clique.lookupAssignmentNumber ();
constraints[tidx][assn].plusEqualsSparse (clique.getFv (), clique.size ());
if (defaultConstraints[tidx].location (assn) != -1) {
defaultConstraints[tidx].incrementValue (assn, clique.size ());
}
}
// constraints for wrong-wrongs for instance
if (allWrongWrongs != null) {
List wrongs = allWrongWrongs[inum];
for (Iterator wwIt = wrongs.iterator (); wwIt.hasNext ();) {
WrongWrong ww = (WrongWrong) wwIt.next ();
ACRF.UnrolledVarSet clique = ww.findVarSet (unrolled);
int tidx = clique.getTemplate ().index;
int wrong2rightId = ww.assnIdx;
constraints[tidx][wrong2rightId].plusEqualsSparse (clique.getFv (), 1.0);
if (defaultConstraints[tidx].location (wrong2rightId) != -1) {
defaultConstraints[tidx].incrementValue (wrong2rightId, 1.0);
}
}
}
}
}
void dumpGradientToFile (String fileName)
{
try {
double[] grad = new double [getNumParameters ()];
getValueGradient (grad);
PrintStream w = new PrintStream (new FileOutputStream (fileName));
for (int i = 0; i < numParameters; i++) {
w.println (grad[i]);
}
w.close ();
} catch (IOException e) {
System.err.println ("Could not open output file.");
e.printStackTrace ();
}
}
void dumpDefaults ()
{
System.out.println ("Default constraints");
for (int i = 0; i < defaultConstraints.length; i++) {
System.out.println ("Template " + i);
defaultConstraints[i].print ();
}
System.out.println ("Default expectations");
for (int i = 0; i < defaultExpectations.length; i++) {
System.out.println ("Template " + i);
defaultExpectations[i].print ();
}
}
void printDebugInfo (ACRF.UnrolledGraph unrolled)
{
acrf.print (System.err);
Assignment assn = unrolled.getAssignment ();
for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next ();
System.out.println ("Clique " + clique);
dumpAssnForClique (assn, clique);
Factor ptl = unrolled.factorOf (clique);
System.out.println ("Value = " + ptl.value (assn));
System.out.println (ptl);
}
}
void dumpAssnForClique (Assignment assn, ACRF.UnrolledVarSet clique)
{
for (Iterator it = clique.iterator (); it.hasNext ();) {
Variable var = (Variable) it.next ();
System.out.println (var + " ==> " + assn.getObject (var)
+ " (" + assn.get (var) + ")");
}
}
private boolean weightValid (double w, int cnum, int j)
{
if (Double.isInfinite (w)) {
PwplACRFTrainer.logger.warning ("Weight is infinite for clique " + cnum + "assignment " + j);
return false;
} else if (Double.isNaN (w)) {
PwplACRFTrainer.logger.warning ("Weight is Nan for clique " + cnum + "assignment " + j);
return false;
} else {
return true;
}
}
// WRONG WRONG HANDLING
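/**
* Records a high-probability wrong assignment by graph index rather than
* by object reference, so that it can be re-resolved (findVariable,
* findVarSet) against a graph that is unrolled again later for the same
* instance.
*/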
private class WrongWrong {
int varIdx;
int vsIdx;
int assnIdx;
public WrongWrong (ACRF.UnrolledGraph graph, VarSet vs, Variable var, int assnIdx)
{
varIdx = graph.getIndex (var);
vsIdx = graph.getIndex (vs);
this.assnIdx = assnIdx;
}
public ACRF.UnrolledVarSet findVarSet (ACRF.UnrolledGraph unrolled)
{
return unrolled.getUnrolledVarSet (vsIdx);
}
public Variable findVariable (ACRF.UnrolledGraph unrolled)
{
return unrolled.get (varIdx);
}
}
private List[] allWrongWrongs;
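/**
* Scans the training set for "wrong-wrong" assignments: clique assignments
* whose marginal under the current model exceeds wrongWrongThreshold and
* that are correct on one variable but wrong elsewhere in the clique (see
* isWrong2RightAssn). The constraints are then recollected so that these
* extra conditioning terms are reflected in the gradient.
*/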
private void addWrongWrong (InstanceList training)
{
allWrongWrongs = new List [training.size ()];
int totalAdded = 0;
// if (!acrf.isCacheUnrolledGraphs ()) {
// throw new IllegalStateException ("Wrong-wrong won't work without caching unrolled graphs.");
// }
for (int i = 0; i < training.size (); i++) {
allWrongWrongs[i] = new ArrayList ();
int numAdded = 0;
Instance instance = training.get (i);
ACRF.UnrolledGraph unrolled = acrf.unroll (instance);
if (unrolled.factors ().size () == 0) {
System.err.println ("WARNING: FactorGraph for instance " + instance.getName () + " : no factors.");
continue;
}
Inferencer inf = acrf.getInferencer ();
inf.computeMarginals (unrolled);
Assignment target = unrolled.getAssignment ();
for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
ACRF.UnrolledVarSet vs = (ACRF.UnrolledVarSet) it.next ();
Factor marg = inf.lookupMarginal (vs);
for (AssignmentIterator assnIt = vs.assignmentIterator (); assnIt.hasNext (); assnIt.advance ()) {
if (marg.value (assnIt) > wrongWrongThreshold) {
Assignment assn = assnIt.assignment ();
for (int vi = 0; vi < vs.size (); vi++) {
Variable var = vs.get (vi);
if (isWrong2RightAssn (target, assn, var)) {
int assnIdx = assn.singleIndex ();
// System.out.println ("Computing for WW: "+vs+" idx "+assnIdx+" target "+var);
allWrongWrongs[i].add (new WrongWrong (unrolled, vs, var, assnIdx));
numAdded++;
}
}
}
}
}
logger.info ("WrongWrongs: Instance " + i + " : " + instance.getName () + " Num added = " + numAdded);
totalAdded += numAdded;
}
resetConstraints ();
collectConstraints (training);
forceStale ();
logger.info ("Total timesteps = " + totalTimesteps (training));
logger.info ("Total WrongWrongs = " + totalAdded);
}
private int totalTimesteps (InstanceList ilist)
{
int total = 0;
for (int i = 0; i < ilist.size (); i++) {
Instance inst = ilist.get (i);
Sequence seq = (Sequence) inst.getData ();
total += seq.size ();
}
return total;
}
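/**
* Returns true if assn disagrees with the true assignment target on some
* variable other than toExclude while agreeing with target on toExclude
* itself, i.e. the neighborhood is wrong but the target variable is
* right.
*/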
private boolean isWrong2RightAssn (Assignment target, Assignment assn, Variable toExclude)
{
Variable[] vars = assn.getVars ();
for (int i = 0; i < vars.length; i++) {
Variable variable = vars[i];
if ((variable != toExclude) && (assn.get (variable) != target.get (variable))) {
// return true;
return assn.get (toExclude) == target.get (toExclude);
}
}
return false;
}
} // OptimizableACRF
}