cc.mallet.grmm.learning.PseudolikelihoodACRFTrainer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
The newest version!
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning;
import cc.mallet.grmm.types.*;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.SparseVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.grmm.util.CachingOptimizable;
import gnu.trove.THashMap;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
/**
* Created: Mar 15, 2005
*
* @author count
*/
/* Observed ("empirical") feature counts collected from the training
 * data with the true labels, indexed [template][assignment index].
 * Filled once by collectConstraints(). */
SparseVector constraints[][];
/* Vectors that contain the expected value over the
 * labels of all the features, having seen the training data
 * (but not the training labels). Recomputed on every call to
 * computeValue(); same indexing as constraints[][].
 */
SparseVector expectations[][];
// Counterparts of the above for each template's default weights,
// one SparseVector per template (indexed by assignment number).
SparseVector defaultConstraints[];
SparseVector defaultExpectations[];
/* Asks every template to allocate its weights for the given training
 * set and accumulates the total number of trainable parameters. */
private void initWeights (InstanceList training)
{
  // Would be nicer to share this with ACRF itself, but this maxable does
  // not extend the ACRF maximizable, so its initWeights() can't be reused.
  for (ACRF.Template template : templates) {
    numParameters += template.initWeights (training);
  }
}
/* Initialize constraints[][] and expectations[][]
* to have the same dimensions as weights, but to
* be all zero.
*/
/* Allocates constraints[][]/expectations[][] (and the default-weight
 * counterparts) with the same sparsity pattern as the current weights,
 * with every entry zeroed. */
private void initConstraintsExpectations ()
{
  int numTmpls = templates.length;

  // Default (per-assignment) weights first.
  defaultConstraints = new SparseVector [numTmpls];
  defaultExpectations = new SparseVector [numTmpls];
  for (int t = 0; t < numTmpls; t++) {
    SparseVector dflt = templates[t].getDefaultWeights ();
    defaultConstraints[t] = (SparseVector) dflt.cloneMatrixZeroed ();
    defaultExpectations[t] = (SparseVector) dflt.cloneMatrixZeroed ();
  }

  // Then one zeroed clone per per-assignment weight vector.
  constraints = new SparseVector [numTmpls][];
  expectations = new SparseVector [numTmpls][];
  for (int t = 0; t < numTmpls; t++) {
    SparseVector[] wts = templates[t].getWeights ();
    constraints[t] = new SparseVector [wts.length];
    expectations[t] = new SparseVector [wts.length];
    for (int a = 0; a < wts.length; a++) {
      constraints[t][a] = (SparseVector) wts[a].cloneMatrixZeroed ();
      expectations[t][a] = (SparseVector) wts[a].cloneMatrixZeroed ();
    }
  }
}
/**
* Set all expectations to 0 after they've been
* initialized.
*/
/**
 * Zeroes every expectation vector (both the per-assignment vectors and
 * the default-weight vectors) before a fresh accumulation pass.
 */
void resetExpectations ()
{
  for (int t = 0; t < expectations.length; t++) {
    defaultExpectations[t].setAll (0.0);
    for (SparseVector expectation : expectations[t]) {
      expectation.setAll (0.0);
    }
  }
}
/**
 * Builds the optimizable objective over the given ACRF and training set:
 * allocates weights and constraint/expectation accumulators, then
 * collects the observed-feature constraints from the training data.
 */
protected Maxable (ACRF acrf, InstanceList ilist)
{
logger.finest ("Initializing OptimizableACRF.");
this.acrf = acrf;
templates = acrf.getTemplates ();
fixedTmpls = acrf.getFixedTemplates ();
/* allocate for weights, constraints and expectations */
this.trainData = ilist;
initWeights(trainData);
initConstraintsExpectations();
int numInstances = trainData.size();
// Force value/gradient recomputation on the first request.
cachedValueStale = cachedGradientStale = true;
/*
if (cacheUnrolledGraphs) {
unrolledGraphs = new UnrolledGraph [numInstances];
}
*/
logger.info("Number of training instances = " + numInstances );
logger.info("Number of parameters = " + numParameters );
describePrior();
logger.fine("Computing constraints");
// Constraints depend only on the data, so they are computed once here.
collectConstraints (trainData);
}
/** Logs the prior in effect (a Gaussian with the configured variance). */
private void describePrior ()
{
  String message = "Using gaussian prior with variance " + gaussianPriorVariance;
  logger.info (message);
}
/** Total number of trainable parameters (default plus per-feature weights). */
public int getNumParameters ()
{
  return numParameters;
}
/* NOTE(review): this comment looks stale -- nothing in getParameters or
 * setParametersInternal below negates any value. Original text kept:
 * "Negate initialValue and finalValue because the parameters are in
 * terms of 'weights', not 'values'."
 */
/**
 * Copies the current parameters into buf. Layout: the default weights of
 * every template first (in template order), then each template's
 * per-assignment weight vectors. setParametersInternal() and
 * computeValueGradient() rely on the same layout.
 *
 * @param buf destination array; must have length exactly numParameters
 * @throws IllegalArgumentException if buf has the wrong length
 */
public void getParameters (double[] buf) {
  if ( buf.length != numParameters )
    // Fixed message: the previous concatenation produced a doubled space
    // ("of the  correct dimensions").
    throw new IllegalArgumentException("Argument is not of the correct dimensions");
  int idx = 0;
  // Default weights first.
  for (int tidx = 0; tidx < templates.length; tidx++) {
    ACRF.Template tmpl = templates [tidx];
    SparseVector defaults = tmpl.getDefaultWeights ();
    double[] values = defaults.getValues();
    System.arraycopy (values, 0, buf, idx, values.length);
    idx += values.length;
  }
  // Then every per-assignment weight vector.
  for (int tidx = 0; tidx < templates.length; tidx++) {
    ACRF.Template tmpl = templates [tidx];
    SparseVector[] weights = tmpl.getWeights();
    for (int assn = 0; assn < weights.length; assn++) {
      double[] values = weights [assn].getValues ();
      System.arraycopy (values, 0, buf, idx, values.length);
      idx += values.length;
    }
  }
}
/* Reads a flat parameter vector back into the templates' weight
 * structures; layout matches getParameters() (defaults first, then
 * per-assignment vectors). Marks cached value/gradient stale. */
protected void setParametersInternal (double[] params)
{
  cachedValueStale = cachedGradientStale = true;
  int pos = 0;

  // Default weights first, in template order.
  for (ACRF.Template tmpl : templates) {
    double[] dst = tmpl.getDefaultWeights ().getValues ();
    System.arraycopy (params, pos, dst, 0, dst.length);
    pos += dst.length;
  }

  // Then every per-assignment weight vector.
  for (ACRF.Template tmpl : templates) {
    for (SparseVector weightVec : tmpl.getWeights ()) {
      double[] dst = weightVec.getValues ();
      System.arraycopy (params, pos, dst, 0, dst.length);
      pos += dst.length;
    }
  }
}
// Functions for unit tests to get constraints and expectations
// I'm too lazy to make a deep copy. Callers should not
// modify these.
/** Expectation vectors for template cnum. Not a copy: callers must not modify. */
public SparseVector[] getExpectations (int cnum)
{
  return expectations[cnum];
}
/** Constraint vectors for template cnum. Not a copy: callers must not modify. */
public SparseVector[] getConstraints (int cnum)
{
  return constraints[cnum];
}
/** print weights */
public void printParameters()
{
double[] buf = new double[numParameters];
getParameters(buf);
int len = buf.length;
for (int w = 0; w < len; w++)
System.out.print(buf[w] + "\t");
System.out.println();
}
/**
 * Computes the penalized objective over the training set and, as a side
 * effect, fills in expectations[][] for the next gradient computation.
 * Returns Double.NEGATIVE_INFINITY (rather than throwing) when an
 * instance that used to be finite becomes infinite, or any value is NaN.
 */
protected double computeValue () {
double retval = 0.0;
int numInstances = trainData.size();
long start = System.currentTimeMillis();
long unrollTime = 0;
/* Instance values must either always or never be included in
* the total values; we can't just sometimes skip a value
* because it is infinite, that throws off the total values.
* We only allow an instance to have infinite value if it happens
* from the start (we don't compute the value for the instance
* after the first round. If any other instance has infinite
* value after that it is an error. */
boolean initializingInfiniteValues = false;
if (infiniteValues == null) {
/* We could initialize bitset with one slot for every
* instance, but it is *probably* cheaper not to, taking the
* time hit to allocate the space if a bit becomes
* necessary. */
infiniteValues = new BitSet ();
initializingInfiniteValues = true;
}
/* Clear the sufficient statistics that we are about to fill */
resetExpectations();
/* Fill in expectations for each instance */
for (int i = 0; i < numInstances; i++)
{
Instance instance = trainData.get(i);
/* Compute marginals for each clique */
long unrollStart = System.currentTimeMillis ();
ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph (instance, templates, fixedTmpls);
long unrollEnd = System.currentTimeMillis ();
unrollTime += (unrollEnd - unrollStart);
if (unrolled.numVariables () == 0) continue; // Happens if all nodes are pruned.
/* Save the expected value of each feature for when we
compute the gradient. */
Assignment observations = unrolled.getAssignment ();
double value = collectExpectationsAndValue (unrolled, observations);
if (Double.isInfinite(value))
{
if (initializingInfiniteValues) {
// First pass: remember this instance and exclude it from now on.
logger.warning ("Instance " + instance.getName() +
" has infinite value; skipping.");
infiniteValues.set (i);
continue;
} else if (!infiniteValues.get(i)) {
// Previously finite instance turned infinite: give up on this evaluation.
logger.warning ("Infinite value on instance "+instance.getName()+
"returning -infinity");
return Double.NEGATIVE_INFINITY;
/*
printDebugInfo (unrolled);
throw new IllegalStateException
("Instance " + instance.getName()+ " used to have non-infinite"
+ " value, but now it has infinite value.");
*/
}
} else if (Double.isNaN (value)) {
System.out.println("NaN on instance "+i+" : "+instance.getName ());
printDebugInfo (unrolled);
/* throw new IllegalStateException
("Value is NaN in ACRF.getValue() Instance "+i);
*/
logger.warning ("Value is NaN in ACRF.getValue() Instance "+i+" : "+
"returning -infinity... ");
return Double.NEGATIVE_INFINITY;
} else {
retval += value;
}
}
/* Incorporate Gaussian prior on parameters. This means
that for each weight, we subtract w^2 / (2 * variance) from the
log probability. */
double priorDenom = 2 * gaussianPriorVariance;
for (int tidx = 0; tidx < templates.length; tidx++) {
SparseVector[] weights = templates [tidx].getWeights ();
for (int j = 0; j < weights.length; j++) {
for (int fnum = 0; fnum < weights[j].numLocations(); fnum++) {
double w = weights [j].valueAtLocation (fnum);
// Infinite/NaN weights are skipped so they don't poison the total.
if (weightValid (w, tidx, j)) {
retval += -w*w/priorDenom;
}
}
}
}
long end = System.currentTimeMillis ();
logger.info ("ACRF Inference time (ms) = "+(end-start));
logger.info ("ACRF unroll time (ms) = "+unrollTime);
logger.info ("getValue (loglikelihood) = "+retval);
return retval;
}
/**
 * Computes the gradient of the penalized log likelihood of the ACRF and
 * places it in grad[]. Layout matches getParameters(): every template's
 * default weights first, then all per-assignment weight vectors.
 *
 * Gradient is: constraint - expectation - weight/gaussianPriorVariance.
 *
 * Fix: the infinite-weight warning previously passed the sparse location
 * j directly to the input alphabet; it now translates it to the feature
 * index via indexAtLocation(), as the printGradient branch already did,
 * so the warning names the correct feature.
 */
protected void computeValueGradient (double[] grad)
{
  /* Index into current element of the gradient array. */
  int gidx = 0;
  // First do gradient wrt defaultWeights
  for (int tidx = 0; tidx < templates.length; tidx++) {
    SparseVector theseWeights = templates[tidx].getDefaultWeights ();
    SparseVector theseConstraints = defaultConstraints [tidx];
    SparseVector theseExpectations = defaultExpectations [tidx];
    for (int j = 0; j < theseWeights.numLocations(); j++) {
      double weight = theseWeights.valueAtLocation (j);
      double constraint = theseConstraints.valueAtLocation (j);
      double expectation = theseExpectations.valueAtLocation (j);
      if (printGradient) {
        System.out.println(" gradient ["+gidx+"] = "+constraint+" (ctr) - "+expectation+" (exp) - "+
            (weight / gaussianPriorVariance)+" (reg) [feature=DEFAULT]");
      }
      grad [gidx++] = constraint - expectation - (weight / gaussianPriorVariance);
    }
  }
  // Now do other weights
  for (int tidx = 0; tidx < templates.length; tidx++) {
    ACRF.Template tmpl = templates [tidx];
    SparseVector[] weights = tmpl.getWeights ();
    for (int i = 0; i < weights.length; i++) {
      SparseVector thisWeightVec = weights [i];
      SparseVector thisConstraintVec = constraints [tidx][i];
      SparseVector thisExpectationVec = expectations [tidx][i];
      for (int j = 0; j < thisWeightVec.numLocations(); j++) {
        double w = thisWeightVec.valueAtLocation (j);
        double gradient; // Computed below
        double constraint = thisConstraintVec.valueAtLocation(j);
        double expectation = thisExpectationVec.valueAtLocation(j);
        /* A parameter may be set to -infinity by an external user.
         * We set gradient to 0 because the parameter's value can
         * never change anyway and it will mess up future calculations
         * on the matrix. */
        if (Double.isInfinite(w)) {
          logger.warning("Infinite weight for node index " +i+
              " feature " +
              acrf.getInputAlphabet().lookupObject (thisWeightVec.indexAtLocation (j)) );
          gradient = 0.0;
        } else {
          gradient = constraint
              - (w/gaussianPriorVariance)
              - expectation;
        }
        if (printGradient) {
          int idx = thisWeightVec.indexAtLocation (j);
          Object fname = acrf.getInputAlphabet ().lookupObject (idx);
          System.out.println(" gradient ["+gidx+"] = "+constraint+" (ctr) - "+expectation+" (exp) - "+
              (w / gaussianPriorVariance)+" (reg) [feature="+fname+"]");
        }
        grad [gidx++] = gradient;
      }
    }
  }
}
/**
 * For every feature f_k, computes the expected value of f_k
 * over all possible label assignments, given the list of instances
 * we have.
 *
 * These values are accumulated into expectations[][] and
 * defaultExpectations[]; that is, expectations[i][j] receives the
 * expected feature counts for clique template i and label assignment j.
 */
private double collectExpectationsAndValue (ACRF.UnrolledGraph unrolled, Assignment observations)
{
double value = 0.0;
for (CliquesIterator it = makeCliquesIterator (unrolled, observations); it.hasNext();) {
it.advance ();
// Local conditional distribution (log space) for the current block.
TableFactor ptl = (TableFactor) it.localConditional ();
double logZ = ptl.logsum ();  // log normalizer of the local conditional
ACRF.UnrolledVarSet[] cliques = it.cliques ();
// Working assignment: starts from the observed labels; the twiddled
// variables are overwritten below, one assignment at a time.
Assignment assn = (Assignment) observations.duplicate ();
// for each assigment to the clique
// xxx SLOW this will need to be sparsified
AssignmentIterator assnIt = ptl.assignmentIterator ();
while (assnIt.hasNext ()) {
// Normalized probability of this assignment under the local conditional.
double marginal = Math.exp (ptl.logValue (assnIt) - logZ);
// This is ugly need to map from assignments to the single twiddled variable to clique assignments
Assignment currentAssn = assnIt.assignment ();
for (int vi = 0; vi < currentAssn.numVariables (); vi++) {
Variable var = currentAssn.getVariable (vi);
assn.setValue (0, var, currentAssn.get (var));
}
// Accumulate marginal-weighted feature counts for every affected clique.
for (int cidx = 0; cidx < cliques.length; cidx++) {
ACRF.UnrolledVarSet clique = cliques[cidx];
int tidx = clique.getTemplate().index;
// NOTE(review): index -1 appears to mark non-trainable (fixed)
// templates -- confirm against ACRF.Template.
if (tidx == -1) continue;
int assnIdx = clique.lookupNumberOfAssignment (assn);
expectations [tidx][assnIdx].plusEqualsSparse (clique.getFv (), marginal);
if (defaultExpectations[tidx].location (assnIdx) != -1)
defaultExpectations [tidx].incrementValue (assnIdx, marginal);
}
assnIt.advance ();
}
// Pseudolikelihood-style contribution: log probability of the observed
// labels under this local conditional.
value += (ptl.logValue (observations) - logZ);
}
return value;
}
/* Accumulates observed feature counts into constraints[][] (and
 * defaultConstraints[]) for one unrolled graph, counting each clique's
 * observed assignment once. */
private void collectConstraintsForGraph (ACRF.UnrolledGraph unrolled, Assignment observations)
{
for (CliquesIterator it = makeCliquesIterator (unrolled, observations); it.hasNext();) {
it.advance ();
ACRF.UnrolledVarSet[] cliques = it.cliques ();
for (int cidx = 0; cidx < cliques.length; cidx++) {
ACRF.UnrolledVarSet clique = cliques[cidx];
int tidx = clique.getTemplate().index;
// NOTE(review): negative index appears to mark non-trainable (fixed)
// templates -- confirm against ACRF.Template. (Elsewhere the check is
// written as tidx == -1.)
if (tidx < 0) continue;
int assnIdx = clique.lookupNumberOfAssignment (observations);
constraints [tidx][assnIdx].plusEqualsSparse (clique.getFv (), 1.0);
// Default weights may not have an entry for this assignment.
if (defaultConstraints[tidx].location (assnIdx) != -1)
defaultConstraints [tidx].incrementValue (assnIdx, 1.0);
}
}
}
/** Tallies observed feature counts (constraints) over the whole instance list. */
public void collectConstraints (InstanceList ilist)
{
  int size = ilist.size ();
  for (int i = 0; i < size; i++) {
    logger.finest ("*** Collecting constraints for instance " + i);
    Instance instance = ilist.get (i);
    ACRF.UnrolledGraph graph = new ACRF.UnrolledGraph (instance, templates, null, true);
    collectConstraintsForGraph (graph, graph.getAssignment ());
  }
}
/**
 * Writes the current value gradient to fileName, one component per line.
 * Logs (rather than propagates) a failure to open the file.
 *
 * Fix: the PrintStream is now closed in a finally block, so the file
 * handle is no longer leaked if an exception occurs after it is opened.
 */
void dumpGradientToFile (String fileName)
{
  double[] grad = new double [getNumParameters ()];
  getValueGradient (grad);
  PrintStream out = null;
  try {
    out = new PrintStream (new FileOutputStream (fileName));
    for (int i = 0; i < numParameters; i++) {
      out.println (grad[i]);
    }
  } catch (IOException e) {
    System.err.println("Could not open output file.");
    e.printStackTrace ();
  } finally {
    if (out != null) out.close ();
  }
}
/** Prints the default constraints and expectations of every template to stdout. */
void dumpDefaults ()
{
  dumpVectorsWithHeading ("Default constraints", defaultConstraints);
  dumpVectorsWithHeading ("Default expectations", defaultExpectations);
}

// Prints a heading, then each template's vector preceded by its index.
private void dumpVectorsWithHeading (String heading, SparseVector[] vecs)
{
  System.out.println(heading);
  for (int i = 0; i < vecs.length; i++) {
    System.out.println("Template "+i);
    vecs[i].print ();
  }
}
/** Dumps the model and, for every clique, its assignment and factor (debugging aid). */
void printDebugInfo (ACRF.UnrolledGraph unrolled)
{
  acrf.print (System.err);
  Assignment assn = unrolled.getAssignment ();
  Iterator it = unrolled.varSetIterator ();
  while (it.hasNext ()) {
    ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next ();
    System.out.println("Clique "+clique);
    dumpAssnForClique (assn, clique);
    Factor factor = unrolled.factorOf (clique);
    System.out.println("Value = "+factor.value (assn));
    System.out.println(factor);
  }
}
/** Prints each variable in the clique with its assigned object and index. */
void dumpAssnForClique (Assignment assn, ACRF.UnrolledVarSet clique)
{
  Iterator it = clique.iterator ();
  while (it.hasNext ()) {
    Variable variable = (Variable) it.next ();
    System.out.println(variable+" ==> "+assn.getObject (variable)
        +" ("+assn.get (variable)+")");
  }
}
/**
 * Checks that a weight is finite, logging a warning if not.
 *
 * Fix: the warning messages were missing the space between the clique
 * number and "assignment" (e.g. "clique 3assignment 5") and misspelled
 * "NaN" as "Nan"; both are corrected.
 *
 * @param w    the weight value to check
 * @param cnum clique-template index (used only in the log message)
 * @param j    assignment index (used only in the log message)
 * @return true iff w is neither infinite nor NaN
 */
private boolean weightValid (double w, int cnum, int j)
{
  if (Double.isInfinite (w)) {
    logger.warning ("Weight is infinite for clique "+cnum+" assignment "+j);
    return false;
  }
  if (Double.isNaN (w)) {
    logger.warning ("Weight is NaN for clique "+cnum+" assignment "+j);
    return false;
  }
  return true;
}
} // OptimizableACRF
}