cc.mallet.fst.CRF Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Andrew McCallum [email protected]
*/
package cc.mallet.fst;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import java.text.DecimalFormat;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureInducer;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.IndexedSparseVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.types.Sequence;
import cc.mallet.types.SparseVector;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.util.ArrayUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
/* There are several different kinds of numeric values:
"weights" range from -Inf to Inf. High weights make a path more
likely. These don't appear directly in Transducer.java, but appear
as parameters to many subclasses, such as CRFs. Weights are also
often summed, or combined in a dot product with feature vectors.
"unnormalized costs" range from -Inf to Inf. High costs make a
path less likely. Unnormalized costs can be obtained from negated
weights or negated sums of weights. These are often returned by a
TransitionIterator's getValue() method. The LatticeNode.alpha
values are unnormalized costs.
"normalized costs" range from 0 to Inf. High costs make a path
less likely. Normalized costs can safely be considered as the
-log(probability) of some event. They can be obtained by
subtracting a (negative) normalizer from unnormalized costs, for
example, subtracting the total cost of a lattice. Typically
initialCosts and finalCosts are examples of normalized costs, but
they are also allowed to be unnormalized costs. The gammas[][],
stateGammas[], and transitionXis[][] are all normalized costs, as
well as the return value of Lattice.getValue().
"probabilities" range from 0 to 1. High probabilities make a path
more likely. They are obtained from normalized costs by negating and
exponentiating (probability = exp(-cost)); conversely, a normalized
cost is the negated log of a probability.
"sums of probabilities" range from 0 to positive numbers. They are
the sum of several probabilities. These are passed to the
incrementCount() methods.
*/
/**
* Represents a CRF model.
*/
public class CRF extends Transducer implements Serializable
{
private static Logger logger = MalletLogger.getLogger(CRF.class.getName());

// Separator used when composing compound state names (e.g. for bi-label states).
static final String LABEL_SEPARATOR = ",";

protected Alphabet inputAlphabet;   // dictionary of observation features
protected Alphabet outputAlphabet;  // dictionary of output labels
protected ArrayList states = new ArrayList ();          // all states, indexed by state index
protected ArrayList initialStates = new ArrayList ();   // states whose initial weight is above IMPOSSIBLE_WEIGHT
protected HashMap name2state = new HashMap ();          // state name -> State
protected Factors parameters = new Factors ();          // all learnable weights of this CRF
//SparseVector[] weights;
//double[] defaultWeights;	// parameters for default feature
//Alphabet weightAlphabet = new Alphabet ();
//boolean[] weightsFrozen;

// FeatureInduction can fill this in
protected FeatureSelection globalFeatureSelection;
// "featureSelections" is on a per- weights[i] basis, and over-rides
// (permanently disabling) FeatureInducer's and
// setWeightsDimensionsAsIn() from using these features on these transitions
protected FeatureSelection[] featureSelections;
// Store here the induced feature conjunctions so that these conjunctions can be added to test instances before transduction
protected ArrayList featureInducers = new ArrayList();

// An integer index that gets incremented each time this CRFs parameters get changed
protected int weightsValueChangeStamp = 0;
// An integer index that gets incremented each time this CRFs parameters' structure get changed
protected int weightsStructureChangeStamp = 0;
protected int cachedNumParametersStamp = -1; // A copy of weightsStructureChangeStamp the last time numParameters was calculated
protected int numParameters;
/** A simple, transparent container to hold the parameters or sufficient statistics for the CRF. */
/** A simple, transparent container to hold the parameters or sufficient statistics for the CRF.
 * The same structure is used both for the model's weights and for accumulators
 * (expectations, constraints, gradients) shaped like them. */
public static class Factors implements Serializable {
	public Alphabet weightAlphabet;   // maps weight-set names to "weight indices"
	public SparseVector[] weights;    // parameters on transitions, indexed by "weight index"
	public double[] defaultWeights;   // parameters for default features, indexed by "weight index"
	public boolean[] weightsFrozen;   // flag, if true indicating that the weights of this "weight index" should not be changed by learning, indexed by "weight index"
	public double [] initialWeights;  // indexed by state index
	public double [] finalWeights;    // indexed by state index

	/** Construct a new empty Factors with a new empty weightsAlphabet, 0-length initialWeights and finalWeights, and the other arrays null. */
	public Factors () {
		weightAlphabet = new Alphabet();
		initialWeights = new double[0];
		finalWeights = new double[0];
		// Leave the rest as null.  They will get set later by addState() and addWeight()
		// Alternatively, we could create zero-length arrays
	}

	/** Construct new Factors by mimicking the structure of the other one, but with zero values.
	 * Always simply point to the other's Alphabet; do not clone it. */
	public Factors (Factors other) {
		weightAlphabet = other.weightAlphabet;
		weights = new SparseVector[other.weights.length];
		for (int i = 0; i < weights.length; i++)
			weights[i] = (SparseVector) other.weights[i].cloneMatrixZeroed();
		defaultWeights = new double[other.defaultWeights.length];
		weightsFrozen = other.weightsFrozen; // We don't copy here because we want "expectation" and "constraint" factors to get changes to a CRF.parameters factor.  Alternatively we declare freezing to be a change of structure, and force reallocation of "expectations", etc.
		initialWeights = new double[other.initialWeights.length];
		finalWeights = new double[other.finalWeights.length];
	}

	/** Construct new Factors by copying the other one. */
	public Factors (Factors other, boolean cloneAlphabet) {
		weightAlphabet = cloneAlphabet ? (Alphabet) other.weightAlphabet.clone() : other.weightAlphabet;
		weights = new SparseVector[other.weights.length];
		for (int i = 0; i < weights.length; i++)
			weights[i] = (SparseVector) other.weights[i].cloneMatrix();
		defaultWeights = other.defaultWeights.clone();
		weightsFrozen = other.weightsFrozen; // NOTE: shared, not copied, same as the zeroing constructor above
		initialWeights = other.initialWeights.clone();
		finalWeights = other.finalWeights.clone();
	}

	/** Construct a new Factors with the same structure as the parameters of 'crf', but with values initialized to zero.
	 * This method is typically used to allocate storage for sufficient statistics, expectations, constraints, etc. */
	public Factors (CRF crf) {
		// TODO Change this implementation to this(crf.parameters)
		weightAlphabet = crf.parameters.weightAlphabet; // TODO consider cloning this instead
		weights = new SparseVector[crf.parameters.weights.length];
		for (int i = 0; i < weights.length; i++)
			weights[i] = (SparseVector) crf.parameters.weights[i].cloneMatrixZeroed();
		defaultWeights = new double[crf.parameters.weights.length];
		weightsFrozen = crf.parameters.weightsFrozen;
		assert (crf.numStates() == crf.parameters.initialWeights.length);
		assert (crf.parameters.initialWeights.length == crf.parameters.finalWeights.length);
		initialWeights = new double[crf.parameters.initialWeights.length];
		finalWeights = new double[crf.parameters.finalWeights.length];
	}

	/** Total number of scalar parameters held: one initial and one final weight per state,
	 * one default weight per weight index, plus every allocated sparse-vector location. */
	public int getNumFactors () {
		assert (initialWeights.length == finalWeights.length);
		assert (defaultWeights.length == weights.length);
		int ret = initialWeights.length + finalWeights.length + defaultWeights.length;
		for (int i = 0; i < weights.length; i++)
			ret += weights[i].numLocations();
		return ret;
	}

	/** Set all parameter values to zero, preserving structure. */
	public void zero () {
		for (int i = 0; i < weights.length; i++)
			weights[i].setAll(0);
		Arrays.fill(defaultWeights, 0);
		Arrays.fill(initialWeights, 0);
		Arrays.fill(finalWeights, 0);
	}

	/** True if 'other' has the same array lengths and sparse-vector sizes as this.
	 * Cheap structural compatibility check before element-wise operations. */
	public boolean structureMatches (Factors other) {
		if (weightAlphabet.size() != other.weightAlphabet.size()) return false;
		if (weights.length != other.weights.length) return false;
		// gsc: checking each SparseVector's size within weights.
		for (int i = 0; i < weights.length; i++)
			if (weights[i].numLocations() != other.weights[i].numLocations()) return false;
		// Note that we are not checking the indices of the SparseVectors in weights
		if (defaultWeights.length != other.defaultWeights.length) return false;
		assert (initialWeights.length == finalWeights.length);
		if (initialWeights.length != other.initialWeights.length) return false;
		return true;
	}

	/** Assert (only active with -ea) that no parameter is NaN. */
	public void assertNotNaN () {
		for (int i = 0; i < weights.length; i++)
			assert (!weights[i].isNaN());
		assert (!MatrixOps.isNaN(defaultWeights));
		assert (!MatrixOps.isNaN(initialWeights));
		assert (!MatrixOps.isNaN(finalWeights));
	}

	// gsc: checks all weights to make sure there are no NaN or Infinite values,
	// this method can be called for checking the weights of constraints and
	// expectations but not for crf.parameters since it can have infinite
	// weights associated with states that are not likely.
	public void assertNotNaNOrInfinite () {
		for (int i = 0; i < weights.length; i++)
			assert (!weights[i].isNaNOrInfinite());
		assert (!MatrixOps.isNaNOrInfinite(defaultWeights));
		assert (!MatrixOps.isNaNOrInfinite(initialWeights));
		assert (!MatrixOps.isNaNOrInfinite(finalWeights));
	}

	/** this += other * factor, over all parameters, ignoring weightsFrozen. */
	public void plusEquals (Factors other, double factor) {
		plusEquals(other, factor, false);
	}

	/** this += other * factor; if obeyWeightsFrozen, frozen weight indices
	 * (transition and default weights only) are left unchanged. */
	public void plusEquals (Factors other, double factor, boolean obeyWeightsFrozen) {
		for (int i = 0; i < weights.length; i++) {
			if (obeyWeightsFrozen && weightsFrozen[i]) continue;
			this.weights[i].plusEqualsSparse(other.weights[i], factor);
			this.defaultWeights[i] += other.defaultWeights[i] * factor;
		}
		// Initial/final weights are never frozen here (no per-state freeze mechanism)
		for (int i = 0; i < initialWeights.length; i++) {
			this.initialWeights[i] += other.initialWeights[i] * factor;
			this.finalWeights[i] += other.finalWeights[i] * factor;
		}
	}

	/** Return the log(p(parameters)) according to a zero-mean Gaussian with given variance.
	 * Infinite weights (used to mark impossible transitions/states) are skipped. */
	public double gaussianPrior (double variance) {
		double value = 0;
		double priorDenom = 2 * variance;
		assert (initialWeights.length == finalWeights.length);
		for (int i = 0; i < initialWeights.length; i++) {
			if (!Double.isInfinite(initialWeights[i])) value -= initialWeights[i] * initialWeights[i] / priorDenom;
			if (!Double.isInfinite(finalWeights[i])) value -= finalWeights[i] * finalWeights[i] / priorDenom;
		}
		double w;
		for (int i = 0; i < weights.length; i++) {
			if (!Double.isInfinite(defaultWeights[i])) value -= defaultWeights[i] * defaultWeights[i] / priorDenom;
			for (int j = 0; j < weights[i].numLocations(); j++) {
				w = weights[i].valueAtLocation (j);
				if (!Double.isInfinite(w)) value -= w * w / priorDenom;
			}
		}
		return value;
	}

	/** Subtract from this the Gaussian-prior gradient computed at 'other':
	 * this -= other / variance, skipping infinite and frozen weights. */
	public void plusEqualsGaussianPriorGradient (Factors other, double variance) {
		assert (initialWeights.length == finalWeights.length);
		for (int i = 0; i < initialWeights.length; i++) {
			// gsc: checking initial/final weights of crf.parameters as well since we could
			// have a state machine where some states have infinite initial and/or final weight
			if (!Double.isInfinite(initialWeights[i]) && !Double.isInfinite(other.initialWeights[i]))
				initialWeights[i] -= other.initialWeights[i] / variance;
			if (!Double.isInfinite(finalWeights[i]) && !Double.isInfinite(other.finalWeights[i]))
				finalWeights[i] -= other.finalWeights[i] / variance;
		}
		double w, ow;
		for (int i = 0; i < weights.length; i++) {
			if (weightsFrozen[i]) continue;
			// TODO Note that there doesn't seem to be a way to freeze the initialWeights and finalWeights
			// TODO Should we also obey FeatureSelection here?  No need; it is enforced by the creation of the weights.
			if (!Double.isInfinite(defaultWeights[i])) defaultWeights[i] -= other.defaultWeights[i] / variance;
			for (int j = 0; j < weights[i].numLocations(); j++) {
				w = weights[i].valueAtLocation (j);
				ow = other.weights[i].valueAtLocation (j);
				if (!Double.isInfinite(w)) weights[i].setValueAtLocation(j, w - (ow/variance));
			}
		}
	}

	/** Return the log(p(parameters)) according to a a hyperbolic curve that is a smooth approximation to an L1 prior.
	 * NOTE(review): the name's "hyberbolic" typo is preserved because it is part of the public API. */
	public double hyberbolicPrior (double slope, double sharpness) {
		double value = 0;
		assert (initialWeights.length == finalWeights.length);
		for (int i = 0; i < initialWeights.length; i++) {
			if (!Double.isInfinite(initialWeights[i]))
				value -= (slope / sharpness	* Math.log (Maths.cosh (sharpness * -initialWeights[i])));
			if (!Double.isInfinite(finalWeights[i]))
				value -= (slope / sharpness	* Math.log (Maths.cosh (sharpness * -finalWeights[i])));
		}
		double w;
		for (int i = 0; i < weights.length; i++) {
			value -= (slope / sharpness	* Math.log (Maths.cosh (sharpness * defaultWeights[i])));
			for (int j = 0; j < weights[i].numLocations(); j++) {
				w = weights[i].valueAtLocation(j);
				if (!Double.isInfinite(w))
					value -= (slope / sharpness	* Math.log (Maths.cosh (sharpness * w)));
			}
		}
		return value;
	}

	/** Add to this the gradient of the hyperbolic prior computed at 'other'. */
	public void plusEqualsHyperbolicPriorGradient (Factors other, double slope, double sharpness) {
		// TODO This method could use some careful checking over, especially for flipped negations
		assert (initialWeights.length == finalWeights.length);
		double ss = slope * sharpness;
		for (int i = 0; i < initialWeights.length; i++) {
			// gsc: checking initial/final weights of crf.parameters as well since we could
			// have a state machine where some states have infinite initial and/or final weight
			if (!Double.isInfinite(initialWeights[i]) && !Double.isInfinite(other.initialWeights[i]))
				initialWeights[i] += ss * Maths.tanh (-other.initialWeights[i]);
			if (!Double.isInfinite(finalWeights[i]) && !Double.isInfinite(other.finalWeights[i]))
				finalWeights[i] += ss * Maths.tanh (-other.finalWeights[i]);
		}
		double w, ow;
		for (int i = 0; i < weights.length; i++) {
			if (weightsFrozen[i]) continue;
			// TODO Note that there doesn't seem to be a way to freeze the initialWeights and finalWeights
			// TODO Should we also obey FeatureSelection here?  No need; it is enforced by the creation of the weights.
			if (!Double.isInfinite(defaultWeights[i])) defaultWeights[i] += ss * Maths.tanh(-other.defaultWeights[i]);
			for (int j = 0; j < weights[i].numLocations(); j++) {
				w = weights[i].valueAtLocation (j);
				ow = other.weights[i].valueAtLocation (j);
				if (!Double.isInfinite(w)) weights[i].setValueAtLocation(j, w + (ss * Maths.tanh(-ow)));
			}
		}
	}

	/** Instances of this inner class can be passed to various inference methods, which can then
	 * gather/increment sufficient statistics counts into the containing Factor instance.
	 * (Intentionally a non-static inner class: it writes into the enclosing Factors' arrays.) */
	public class Incrementor implements Transducer.Incrementor {
		public void incrementFinalState(Transducer.State s, double count) {
			finalWeights[s.getIndex()] += count;
		}
		public void incrementInitialState(Transducer.State s, double count) {
			initialWeights[s.getIndex()] += count;
		}
		public void incrementTransition(Transducer.TransitionIterator ti, double count) {
			int index = ti.getIndex();
			CRF.State source = (CRF.State)ti.getSourceState();
			// A transition may contribute to several weight sets; increment each one.
			int nwi = source.weightsIndices[index].length;
			int weightsIndex;
			for (int wi = 0; wi < nwi; wi++) {
				weightsIndex = source.weightsIndices[index][wi];
				// For frozen weights, don't even gather their sufficient statistics; this is how we ensure that the gradient for these will be zero
				if (weightsFrozen[weightsIndex]) continue;
				// TODO Should we also obey FeatureSelection here?  No need; it is enforced by the creation of the weights.
				weights[weightsIndex].plusEqualsSparse ((FeatureVector)ti.getInput(), count);
				defaultWeights[weightsIndex] += count;
			}
		}
	}

	/** Sum of absolute values of all parameters, skipping impossible (<= IMPOSSIBLE_WEIGHT) initial/final weights. */
	public double getParametersAbsNorm ()
	{
		double ret = 0;
		for (int i = 0; i < initialWeights.length; i++) {
			if (initialWeights[i] > Transducer.IMPOSSIBLE_WEIGHT)
				ret += Math.abs(initialWeights[i]);
			if (finalWeights[i] > Transducer.IMPOSSIBLE_WEIGHT)
				ret += Math.abs(finalWeights[i]);
		}
		for (int i = 0; i < weights.length; i++) {
			ret += Math.abs(defaultWeights[i]);
			int nl = weights[i].numLocations();
			for (int j = 0; j < nl; j++)
				ret += Math.abs(weights[i].valueAtLocation(j));
		}
		return ret;
	}

	/** Like Incrementor, but scales every increment by a per-instance weight. */
	public class WeightedIncrementor implements Transducer.Incrementor {
		double instanceWeight = 1.0;
		public WeightedIncrementor (double instanceWeight) {
			this.instanceWeight = instanceWeight;
		}
		public void incrementFinalState(Transducer.State s, double count) {
			finalWeights[s.getIndex()] += count * instanceWeight;
		}
		public void incrementInitialState(Transducer.State s, double count) {
			initialWeights[s.getIndex()] += count * instanceWeight;
		}
		public void incrementTransition(Transducer.TransitionIterator ti, double count) {
			int index = ti.getIndex();
			CRF.State source = (CRF.State)ti.getSourceState();
			int nwi = source.weightsIndices[index].length;
			int weightsIndex;
			count *= instanceWeight; // fold the instance weight in once, up front
			for (int wi = 0; wi < nwi; wi++) {
				weightsIndex = source.weightsIndices[index][wi];
				// For frozen weights, don't even gather their sufficient statistics; this is how we ensure that the gradient for these will be zero
				if (weightsFrozen[weightsIndex]) continue;
				// TODO Should we also obey FeatureSelection here?  No need; it is enforced by the creation of the weights.
				weights[weightsIndex].plusEqualsSparse ((FeatureVector)ti.getInput(), count);
				defaultWeights[weightsIndex] += count;
			}
		}
	}

	/** Flatten all parameters into 'buffer'.  Layout: interleaved (initial,final) per state,
	 * then per weight index: default weight followed by its sparse-vector values.
	 * Must match getParameter()/setParameters()/setParameter() ordering.
	 * @throws IllegalArgumentException if buffer.length != getNumFactors() */
	public void getParameters (double[] buffer)
	{
		if (buffer.length != getNumFactors ())
			throw new IllegalArgumentException ("Expected size of buffer: " + getNumFactors() + ", actual size: " + buffer.length);
		int pi = 0;
		for (int i = 0; i < initialWeights.length; i++) {
			buffer[pi++] = initialWeights[i];
			buffer[pi++] = finalWeights[i];
		}
		for (int i = 0; i < weights.length; i++) {
			buffer[pi++] = defaultWeights[i];
			int nl = weights[i].numLocations();
			for (int j = 0; j < nl; j++)
				buffer[pi++] = weights[i].valueAtLocation(j);
		}
	}

	/** Return the single parameter at 'index' in the flattened layout used by getParameters(). */
	public double getParameter (int index) {
		int numStateParms = 2 * initialWeights.length;
		if (index < numStateParms) {
			// Even indices are initial weights, odd are final weights
			if (index % 2 == 0)
				return initialWeights[index/2];
			return finalWeights[index/2];
		}
		index -= numStateParms;
		for (int i = 0; i < weights.length; i++) {
			if (index == 0)
				return this.defaultWeights[i];
			index--;
			if (index < weights[i].numLocations())
				return weights[i].valueAtLocation (index);
			index -= weights[i].numLocations();
		}
		throw new IllegalArgumentException ("index too high = "+index);
	}

	/** Overwrite all parameters from 'buff', using the same layout as getParameters(). */
	public void setParameters (double [] buff) {
		assert (buff.length == getNumFactors());
		int pi = 0;
		for (int i = 0; i < initialWeights.length; i++) {
			initialWeights[i] = buff[pi++];
			finalWeights[i] = buff[pi++];
		}
		for (int i = 0; i < weights.length; i++) {
			this.defaultWeights[i] = buff[pi++];
			int nl = weights[i].numLocations();
			for (int j = 0; j < nl; j++)
				weights[i].setValueAtLocation (j, buff[pi++]);
		}
	}

	/** Set the single parameter at 'index' (same flattened layout as getParameter()). */
	public void setParameter (int index, double value) {
		int numStateParms = 2 * initialWeights.length;
		if (index < numStateParms) {
			if (index % 2 == 0)
				initialWeights[index/2] = value;
			else
				finalWeights[index/2] = value;
		} else {
			index -= numStateParms;
			for (int i = 0; i < weights.length; i++) {
				if (index == 0) {
					defaultWeights[i] = value;
					return;
				}
				index--;
				if (index < weights[i].numLocations()) {
					weights[i].setValueAtLocation (index, value);
					return;
				} else {
					index -= weights[i].numLocations();
				}
			}
			throw new IllegalArgumentException ("index too high = "+index);
		}
	}

	// gsc: Serialization for Factors
	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 1;
	private void writeObject (ObjectOutputStream out) throws IOException {
		out.writeInt (CURRENT_SERIAL_VERSION);
		out.writeObject (weightAlphabet);
		out.writeObject (weights);
		out.writeObject (defaultWeights);
		out.writeObject (weightsFrozen);
		out.writeObject (initialWeights);
		out.writeObject (finalWeights);
	}
	private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
		int version = in.readInt (); // version currently unused; read to keep the stream aligned
		weightAlphabet = (Alphabet) in.readObject ();
		weights = (SparseVector[]) in.readObject ();
		defaultWeights = (double[]) in.readObject ();
		weightsFrozen = (boolean[]) in.readObject ();
		initialWeights = (double[]) in.readObject ();
		finalWeights = (double[]) in.readObject ();
	}
}
/** Construct a CRF using the alphabets carried by the given pipes.
 * Note both alphabets are taken from inputPipe: its data alphabet supplies the
 * input features and its target alphabet supplies the output labels. */
public CRF (Pipe inputPipe, Pipe outputPipe)
{
	super (inputPipe, outputPipe);
	this.inputAlphabet = inputPipe.getDataAlphabet();
	this.outputAlphabet = inputPipe.getTargetAlphabet();
	//inputAlphabet.stopGrowth();
}
/** Construct a CRF directly from alphabets, wrapping them in a no-op input pipe.
 * The input alphabet is frozen (stopGrowth) so new features cannot be added later. */
public CRF (Alphabet inputAlphabet, Alphabet outputAlphabet)
{
	super (new Noop(inputAlphabet, outputAlphabet), null);
	inputAlphabet.stopGrowth();
	logger.info ("CRF input dictionary size = "+inputAlphabet.size());
	//xxx outputAlphabet.stopGrowth();
	this.inputAlphabet = inputAlphabet;
	this.outputAlphabet = outputAlphabet;
}
/** Create a CRF whose states and weights are a copy of those from another CRF. */
public CRF (CRF other)
{
	// This assumes that "other" has non-null inputPipe and outputPipe. We'd need to add another constructor to handle this if not.
	this (other.getInputPipe (), other.getOutputPipe ());
	copyStatesAndWeightsFrom (other);
	assertWeightsLength ();
}
/** Deep-copy the state graph and all weights from initialCRF into this CRF.
 * Transition weights are copied via the Factors copy constructor; states are
 * rebuilt one at a time through addState() so that this CRF's own bookkeeping
 * (name2state, initialStates, parameters' initial/final weight arrays) is
 * reconstructed consistently. */
private void copyStatesAndWeightsFrom (CRF initialCRF)
{
	this.parameters = new Factors (initialCRF.parameters, true);  // This will copy all the transition weights
	this.parameters.weightAlphabet = (Alphabet) initialCRF.parameters.weightAlphabet.clone();
	//weightAlphabet = (Alphabet) initialCRF.weightAlphabet.clone ();
	//weights = new SparseVector [initialCRF.weights.length];

	states.clear ();
	// Clear these, because they will be filled by this.addState()
	this.parameters.initialWeights = new double[0];
	this.parameters.finalWeights = new double[0];

	for (int i = 0; i < initialCRF.states.size(); i++) {
		State s = (State) initialCRF.getState (i);
		// Translate the source state's weight indices back into weight names,
		// so addState() can resolve them against this CRF's (cloned) alphabet.
		String[][] weightNames = new String[s.weightsIndices.length][];
		for (int j = 0; j < weightNames.length; j++) {
			int[] thisW = s.weightsIndices[j];
			weightNames[j] = (String[]) initialCRF.parameters.weightAlphabet.lookupObjects(thisW, new String [s.weightsIndices[j].length]);
		}
		addState (s.name, initialCRF.parameters.initialWeights[i], initialCRF.parameters.finalWeights[i],
				s.destinationNames, s.labels, weightNames);
	}

	featureSelections = initialCRF.featureSelections.clone ();
	// yyy weightsFrozen = (boolean[]) initialCRF.weightsFrozen.clone();
}
/** The dictionary of input (observation) features. */
public Alphabet getInputAlphabet () { return inputAlphabet; }
/** The dictionary of output labels. */
public Alphabet getOutputAlphabet () { return outputAlphabet; }
/** This method should be called whenever the CRFs weights (parameters) have their structure/arity/number changed.
 * A structure change implies a value change, so both stamps are bumped;
 * callers compare stamps to invalidate caches (e.g. cachedNumParametersStamp). */
public void weightsStructureChanged () {
	weightsStructureChangeStamp++;
	weightsValueChangeStamp++;
}
/** This method should be called whenever the CRFs weights (parameters) are changed (values only, not structure). */
public void weightsValueChanged () {
	weightsValueChangeStamp++;
}
// This method can be over-ridden in subclasses of CRF to return subclasses of CRF.State
/** Factory hook for state creation; addState() calls this so subclasses can
 * substitute their own State subclass. */
protected CRF.State newState (String name, int index,
		double initialWeight, double finalWeight,
		String[] destinationNames,
		String[] labelNames,
		String[][] weightNames,
		CRF crf)
{
	return new State (name, index, initialWeight, finalWeight,
			destinationNames, labelNames, weightNames, crf);
}
/** Add a state to the CRF.  This is the most general form: each out-going
 * transition i goes to destinationNames[i], emits labelNames[i], and draws on
 * the (possibly several) weight sets named in weightNames[i].
 * @throws IllegalArgumentException if a state with this name already exists. */
public void addState (String name, double initialWeight, double finalWeight,
		String[] destinationNames,
		String[] labelNames,
		String[][] weightNames)
{
	assert (weightNames.length == destinationNames.length);
	assert (labelNames.length == destinationNames.length);
	weightsStructureChanged();  // adding a state changes the parameter structure
	if (name2state.get(name) != null)
		throw new IllegalArgumentException ("State with name `"+name+"' already exists.");
	parameters.initialWeights = MatrixOps.append(parameters.initialWeights, initialWeight);
	parameters.finalWeights = MatrixOps.append(parameters.finalWeights, finalWeight);
	State s = newState (name, states.size(), initialWeight, finalWeight,
			destinationNames, labelNames, weightNames, this);
	//s.print ();
	states.add (s);
	// Only states with a possible (> IMPOSSIBLE_WEIGHT) initial weight can start a path
	if (initialWeight > IMPOSSIBLE_WEIGHT)
		initialStates.add (s);
	name2state.put (name, s);
}
/** Convenience form of addState in which each out-going transition uses exactly
 * one weight set, named by the corresponding entry of weightNames. */
public void addState (String name, double initialWeight, double finalWeight,
		String[] destinationNames,
		String[] labelNames,
		String[] weightNames)
{
	// Wrap each weight-set name in its own singleton array, then delegate
	// to the general (String[][]) form.
	String[][] wrapped = new String[weightNames.length][];
	for (int i = 0; i < wrapped.length; i++)
		wrapped[i] = new String[] { weightNames[i] };
	this.addState (name, initialWeight, finalWeight, destinationNames, labelNames, wrapped);
}
/** Default gives separate parameters to each transition: the weight-set name
 * for transition i is "source->destination:label". */
public void addState (String name, double initialWeight, double finalWeight,
		String[] destinationNames,
		String[] labelNames)
{
	assert (destinationNames.length == labelNames.length);
	int n = labelNames.length;
	String[] transitionWeightNames = new String[n];
	for (int i = 0; i < n; i++)
		transitionWeightNames[i] = name + "->" + destinationNames[i] + ":" + labelNames[i];
	this.addState (name, initialWeight, finalWeight, destinationNames, labelNames, transitionWeightNames);
}
/** Add a state with parameters equal zero, and labels on out-going arcs
    the same name as their destination state names. */
public void addState (String name, String[] destinationNames)
{
	this.addState (name, 0, 0, destinationNames, destinationNames);
}
/** Add a group of states that are fully connected with each other,
 * with parameters equal zero, and labels on their out-going arcs
 * the same name as their destination state names. */
public void addFullyConnectedStates (String[] stateNames)
{
	// Every state can transition to every state (including itself).
	for (String stateName : stateNames)
		addState (stateName, stateNames);
}
/** Create one state per output label and fully connect them.
 * Assumes every entry in outputAlphabet is a String. */
public void addFullyConnectedStatesForLabels ()
{
	int numLabels = outputAlphabet.size();
	String[] labels = new String[numLabels];
	for (int i = 0; i < numLabels; i++) {
		Object labelObject = outputAlphabet.lookupObject(i);
		logger.fine ("CRF: outputAlphabet.lookup class = "+
				labelObject.getClass().getName());
		labels[i] = (String) labelObject;
	}
	addFullyConnectedStates (labels);
}
/** Add a dedicated start state with the empty string as its name. */
public void addStartState ()
{
	addStartState ("");
}
/** Add a dedicated start state named 'name' that can transition to every
 * existing state; all pre-existing states become impossible as initial states.
 * NOTE(review): states already recorded in the initialStates list are not
 * removed from it here, even though their initial weight becomes impossible —
 * confirm downstream code re-checks initial weights. */
public void addStartState (String name)
{
	int numExisting = numStates ();
	String[] destinations = new String[numExisting];
	for (int i = 0; i < numExisting; i++) {
		parameters.initialWeights[i] = IMPOSSIBLE_WEIGHT;
		destinations[i] = getState(i).getName();
	}
	addState (name, 0, 0.0, destinations, destinations);  // initialWeight of 0.0
}
/** Make 'state' the unique possible start state: its initial weight becomes 0
 * and every other state's initial weight becomes IMPOSSIBLE_WEIGHT. */
public void setAsStartState (State state)
{
	for (int i = 0; i < numStates(); i++) {
		Transducer.State candidate = getState (i);
		candidate.setInitialWeight (candidate == state ? 0 : IMPOSSIBLE_WEIGHT);
	}
	weightsValueChanged();  // values changed but structure did not
}
/** As labelConnectionsIn(trainingSet, start) with no designated start label. */
private boolean[][] labelConnectionsIn (InstanceList trainingSet)
{
	return labelConnectionsIn (trainingSet, null);
}
/** Return a label-bigram adjacency matrix: connections[a][b] is true iff label b
 * immediately follows label a somewhere in the training targets.  If 'start' is
 * non-null, that label is additionally connected to every label. */
private boolean[][] labelConnectionsIn (InstanceList trainingSet, String start)
{
	int numLabels = outputAlphabet.size();
	boolean[][] connections = new boolean[numLabels][numLabels];
	// Mark every label bigram observed in the training targets.
	for (int i = 0; i < trainingSet.size(); i++) {
		Instance instance = trainingSet.get(i);
		FeatureSequence output = (FeatureSequence) instance.getTarget();
		for (int j = 1; j < output.size(); j++) {
			int from = outputAlphabet.lookupIndex (output.get(j-1));
			int to = outputAlphabet.lookupIndex (output.get(j));
			assert (from >= 0 && to >= 0);
			connections[from][to] = true;
		}
	}
	// Handle start state: it may transition to any label.
	if (start != null) {
		int startIndex = outputAlphabet.lookupIndex (start);
		for (int j = 0; j < outputAlphabet.size(); j++) {
			connections[startIndex][j] = true;
		}
	}
	return connections;
}
/**
 * Add states to create a first-order Markov model on labels, adding only
 * those transitions that occur in the given trainingSet.
 */
public void addStatesForLabelsConnectedAsIn (InstanceList trainingSet)
{
	int numLabels = outputAlphabet.size();
	boolean[][] connections = labelConnectionsIn (trainingSet);
	for (int i = 0; i < numLabels; i++) {
		// Collect the labels reachable from label i, in alphabet order.
		ArrayList destinations = new ArrayList();
		for (int j = 0; j < numLabels; j++)
			if (connections[i][j])
				destinations.add ((String)outputAlphabet.lookupObject(j));
		String[] destinationNames =
			(String[]) destinations.toArray (new String[destinations.size()]);
		addState ((String)outputAlphabet.lookupObject(i), destinationNames);
	}
}
/**
 * Add as many states as there are labels, but don't create separate weights
 * for each source-destination pair of states.  Instead have all the incoming
 * transitions to a state share the same weights.
 */
public void addStatesForHalfLabelsConnectedAsIn (InstanceList trainingSet)
{
	int numLabels = outputAlphabet.size();
	boolean[][] connections = labelConnectionsIn (trainingSet);
	for (int i = 0; i < numLabels; i++) {
		// Collect the labels reachable from label i, in alphabet order.
		ArrayList destinations = new ArrayList();
		for (int j = 0; j < numLabels; j++)
			if (connections[i][j])
				destinations.add ((String)outputAlphabet.lookupObject(j));
		String[] destinationNames =
			(String[]) destinations.toArray (new String[destinations.size()]);
		// Using the destination names as the weight-set names makes every
		// transition into a given state share one weight set.
		addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
				destinationNames, destinationNames, destinationNames);
	}
}
/**
 * Add as many states as there are labels, but don't create separate
 * observational-test-weights for each source-destination pair of
 * states---instead have all the incoming transitions to a state share the
 * same observational-feature-test weights.  However, do create separate
 * default feature for each transition, (which acts as an HMM-style transition
 * probability).
 */
public void addStatesForThreeQuarterLabelsConnectedAsIn (InstanceList trainingSet)
{
	int numLabels = outputAlphabet.size();
	boolean[][] connections = labelConnectionsIn (trainingSet);
	for (int i = 0; i < numLabels; i++) {
		int numDestinations = 0;
		for (int j = 0; j < numLabels; j++)
			if (connections[i][j]) numDestinations++;
		String[] destinationNames = new String[numDestinations];
		String[][] weightNames = new String[numDestinations][];
		int destinationIndex = 0;
		for (int j = 0; j < numLabels; j++)
			if (connections[i][j]) {
				String labelName = (String)outputAlphabet.lookupObject(j);
				destinationNames[destinationIndex] = labelName;
				// Each transition draws on two weight sets:
				weightNames[destinationIndex] = new String[2];
				// The "half-labels" will include all observed tests
				weightNames[destinationIndex][0] = labelName;
				// The "transition" weights will include only the default feature
				String wn = (String)outputAlphabet.lookupObject(i) + "->" + (String)outputAlphabet.lookupObject(j);
				weightNames[destinationIndex][1] = wn;
				int wi = getWeightsIndex (wn);
				// A new empty FeatureSelection won't allow any features here, so we only
				// get the default feature for transitions
				featureSelections[wi] = new FeatureSelection(trainingSet.getDataAlphabet());
				destinationIndex++;
			}
		addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
				destinationNames, destinationNames, weightNames);
	}
}
/** Fully-connected variant of addStatesForThreeQuarterLabelsConnectedAsIn:
 * every label pair gets a transition, with shared observational weights per
 * destination plus a default-feature-only weight set per transition. */
public void addFullyConnectedStatesForThreeQuarterLabels (InstanceList trainingSet)
{
	int numLabels = outputAlphabet.size();
	for (int i = 0; i < numLabels; i++) {
		String[] destinationNames = new String[numLabels];
		String[][] weightNames = new String[numLabels][];
		for (int j = 0; j < numLabels; j++) {
			String labelName = (String)outputAlphabet.lookupObject(j);
			destinationNames[j] = labelName;
			weightNames[j] = new String[2];
			// The "half-labels" will include all observational tests
			weightNames[j][0] = labelName;
			// The "transition" weights will include only the default feature
			String wn = (String)outputAlphabet.lookupObject(i) + "->" + (String)outputAlphabet.lookupObject(j);
			weightNames[j][1] = wn;
			int wi = getWeightsIndex (wn);
			// A new empty FeatureSelection won't allow any features here, so we only
			// get the default feature for transitions
			featureSelections[wi] = new FeatureSelection(trainingSet.getDataAlphabet());
		}
		addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
				destinationNames, destinationNames, weightNames);
	}
}
public void addFullyConnectedStatesForBiLabels ()
{
	int n = outputAlphabet.size();
	String[] labels = new String[n];
	// Entries in the outputAlphabet are assumed to be Strings.
	for (int i = 0; i < n; i++) {
		logger.fine ("CRF: outputAlphabet.lookup class = "+
		             outputAlphabet.lookupObject(i).getClass().getName());
		labels[i] = (String) outputAlphabet.lookupObject(i);
	}
	// One state per ordered label pair; a transition emitting label k moves
	// from state (prev,curr) to state (curr,k).
	for (String prev : labels) {
		for (String curr : labels) {
			String[] destinationNames = new String[n];
			for (int k = 0; k < n; k++)
				destinationNames[k] = curr + LABEL_SEPARATOR + labels[k];
			addState (prev + LABEL_SEPARATOR + curr, 0.0, 0.0,
			          destinationNames, labels);
		}
	}
}
/**
 * Add states to create a second-order Markov model on labels, adding only
 * those transitions that occur in the given trainingSet.
 */
public void addStatesForBiLabelsConnectedAsIn (InstanceList trainingSet)
{
	int numLabels = outputAlphabet.size();
	boolean[][] connections = labelConnectionsIn (trainingSet);
	for (int i = 0; i < numLabels; i++) {
		for (int j = 0; j < numLabels; j++) {
			if (!connections[i][j])
				continue;
			// State (i,j) exists only if i->j was observed; its destinations are
			// the states (j,k) for every observed transition j->k.
			ArrayList<String> destNames = new ArrayList<String>();
			ArrayList<String> destLabels = new ArrayList<String>();
			for (int k = 0; k < numLabels; k++) {
				if (!connections[j][k])
					continue;
				String labelK = (String) outputAlphabet.lookupObject(k);
				destNames.add ((String) outputAlphabet.lookupObject(j)
				               + LABEL_SEPARATOR + labelK);
				destLabels.add (labelK);
			}
			addState ((String) outputAlphabet.lookupObject(i) + LABEL_SEPARATOR +
			          (String) outputAlphabet.lookupObject(j), 0.0, 0.0,
			          destNames.toArray (new String[0]),
			          destLabels.toArray (new String[0]));
		}
	}
}
public void addFullyConnectedStatesForTriLabels ()
{
	int n = outputAlphabet.size();
	String[] labels = new String[n];
	// Entries in the outputAlphabet are assumed to be Strings.
	for (int i = 0; i < n; i++) {
		logger.fine ("CRF: outputAlphabet.lookup class = "+
		             outputAlphabet.lookupObject(i).getClass().getName());
		labels[i] = (String) outputAlphabet.lookupObject(i);
	}
	// One state per label trigram; emitting label l shifts the history left,
	// moving from (first,second,third) to (second,third,l).
	for (String first : labels) {
		for (String second : labels) {
			for (String third : labels) {
				String[] destinationNames = new String[n];
				for (int l = 0; l < n; l++)
					destinationNames[l] = second + LABEL_SEPARATOR + third + LABEL_SEPARATOR + labels[l];
				addState (first + LABEL_SEPARATOR + second + LABEL_SEPARATOR + third,
				          0.0, 0.0, destinationNames, labels);
			}
		}
	}
}
public void addSelfTransitioningStateForAllLabels (String name)
{
	int n = outputAlphabet.size();
	String[] labels = new String[n];
	String[] destinationNames = new String[n];
	// A single state that emits every label and always returns to itself.
	Arrays.fill (destinationNames, name);
	// Entries in the outputAlphabet are assumed to be Strings.
	for (int i = 0; i < n; i++) {
		logger.fine ("CRF: outputAlphabet.lookup class = "+
		             outputAlphabet.lookupObject(i).getClass().getName());
		labels[i] = (String) outputAlphabet.lookupObject(i);
	}
	addState (name, 0.0, 0.0, destinationNames, labels);
}
/** Joins the given labels into a single state/weight name, separated by
 * LABEL_SEPARATOR. Replaces a manual StringBuffer/separator loop with the
 * equivalent String.join. */
private String concatLabels(String[] labels)
{
	return String.join(LABEL_SEPARATOR, labels);
}
/** Returns the name of the label k-gram obtained by appending {@code next}
 * to {@code history}: the last (k-1) entries of history, then next, joined
 * by LABEL_SEPARATOR. (k==1 yields just {@code next}.)
 * Uses an unsynchronized local StringBuilder instead of StringBuffer. */
private String nextKGram(String[] history, int k, String next)
{
	StringBuilder buf = new StringBuilder();
	int start = history.length + 1 - k;
	for (int i = start; i < history.length; i++)
		buf.append(history[i]).append(LABEL_SEPARATOR);
	buf.append(next);
	return buf.toString();
}
/** A transition prev->curr is allowed unless the joined pair matches the
 * "no" pattern, and (when a "yes" pattern is given) only if it matches "yes". */
private boolean allowedTransition(String prev, String curr,
                                  Pattern no, Pattern yes)
{
	String pair = concatLabels(new String[]{prev, curr});
	boolean vetoed = (no != null) && no.matcher(pair).matches();
	boolean permitted = (yes == null) || yes.matcher(pair).matches();
	return !vetoed && permitted;
}
/** A label history is allowed iff every adjacent pair in it is an allowed transition. */
private boolean allowedHistory(String[] history, Pattern no, Pattern yes) {
	for (int i = 0; i + 1 < history.length; i++) {
		if (!allowedTransition(history[i], history[i+1], no, yes))
			return false;
	}
	return true;
}
/**
* Assumes that the CRF's output alphabet contains
* Strings. Creates an order-n CRF with input
* predicates and output labels given by trainingSet
* and order, connectivity, and weights given by the remaining
* arguments.
*
* @param trainingSet the training instances
* @param orders an array of increasing non-negative numbers giving
* the orders of the features for this CRF. The largest number
* n is the Markov order of the CRF. States are
* n-tuples of output labels. Each of the other numbers
* k in orders
represents a weight set shared
* by all destination states whose last (most recent) k
* labels agree. If orders
is null
, an
* order-0 CRF is built.
* @param defaults If non-null, it must be the same length as
* orders
, with true
positions indicating
* that the weight set for the corresponding order contains only the
* weight for a default feature; otherwise, the weight set has
* weights for all features built from input predicates.
* @param start The label that represents the context of the start of
* a sequence. It may be also used for sequence labels. If no label of
* this name exists, one will be added. Connections will be added between
* the start label and all other labels, even if fullyConnected is
* false. This argument may be null, in which case no special
* start state is added.
* @param forbidden If non-null, specifies what pairs of successive
* labels are not allowed, both for constructing order-n
* states or for transitions. A label pair (u,v)
* is not allowed if u + "," + v matches
* forbidden
.
* @param allowed If non-null, specifies what pairs of successive
* labels are allowed, both for constructing order-n
* states or for transitions. A label pair (u,v)
* is allowed only if u + "," + v matches
* allowed
.
* @param fullyConnected Whether to include all allowed transitions,
* even those not occurring in trainingSet
,
* @return The name of the start state.
*
*/
public String addOrderNStates(InstanceList trainingSet, int[] orders,
                              boolean[] defaults, String start,
                              Pattern forbidden, Pattern allowed,
                              boolean fullyConnected)
{
	boolean[][] connections = null;
	if (start != null)
		outputAlphabet.lookupIndex (start);  // make sure the start label exists
	if (!fullyConnected)
		connections = labelConnectionsIn (trainingSet, start);
	int order = -1;
	// Bug fix: validate defaults before dereferencing orders. The original
	// evaluated orders.length first and threw NullPointerException (instead
	// of this IllegalArgumentException) when orders == null but defaults != null.
	if (defaults != null && (orders == null || defaults.length != orders.length))
		throw new IllegalArgumentException("Defaults must be null or match orders");
	if (orders == null)
		order = 0;
	else
	{
		// orders must be strictly ascending and non-negative; the last entry
		// is the Markov order of the model.
		for (int i = 0; i < orders.length; i++) {
			if (orders[i] <= order)
				throw new IllegalArgumentException("Orders must be non-negative and in ascending order");
			order = orders[i];
		}
		if (order < 0) order = 0;
	}
	if (order > 0)
	{
		// Enumerate every length-`order` label history with an odometer over
		// historyIndexes; each allowed history becomes one state.
		int[] historyIndexes = new int[order];
		String[] history = new String[order];
		String label0 = (String)outputAlphabet.lookupObject(0);
		for (int i = 0; i < order; i++)
			history[i] = label0;
		int numLabels = outputAlphabet.size();
		while (historyIndexes[0] < numLabels)
		{
			logger.fine("Preparing " + concatLabels(history));
			if (allowedHistory(history, forbidden, allowed))
			{
				String stateName = concatLabels(history);
				int nt = 0;  // number of transitions actually created
				String[] destNames = new String[numLabels];
				String[] labelNames = new String[numLabels];
				String[][] weightNames = new String[numLabels][orders.length];
				for (int nextIndex = 0; nextIndex < numLabels; nextIndex++)
				{
					String next = (String)outputAlphabet.lookupObject(nextIndex);
					if (allowedTransition(history[order-1], next, forbidden, allowed)
							&& (fullyConnected ||
									connections[historyIndexes[order-1]][nextIndex]))
					{
						destNames[nt] = nextKGram(history, order, next);
						labelNames[nt] = next;
						// One weight set per requested order, named by the (orders[i]+1)-gram.
						for (int i = 0; i < orders.length; i++)
						{
							weightNames[nt][i] = nextKGram(history, orders[i]+1, next);
							if (defaults != null && defaults[i]) {
								int wi = getWeightsIndex (weightNames[nt][i]);
								// Using empty feature selection gives us only the
								// default features
								featureSelections[wi] =
									new FeatureSelection(trainingSet.getDataAlphabet());
							}
						}
						nt++;
					}
				}
				if (nt < numLabels)
				{
					// Trim the arrays down to the transitions actually added.
					destNames = Arrays.copyOf (destNames, nt);
					labelNames = Arrays.copyOf (labelNames, nt);
					weightNames = Arrays.copyOf (weightNames, nt);
				}
				for (int i = 0; i < destNames.length; i++)
				{
					StringBuilder b = new StringBuilder();
					for (int j = 0; j < orders.length; j++)
						b.append(" ").append(weightNames[i][j]);
					logger.fine(stateName + "->" + destNames[i] +
											"(" + labelNames[i] + ")" + b.toString());
				}
				addState (stateName, 0.0, 0.0, destNames, labelNames, weightNames);
			}
			// Advance the history odometer, least-significant position last.
			for (int o = order-1; o >= 0; o--)
				if (++historyIndexes[o] < numLabels)
				{
					history[o] = (String)outputAlphabet.lookupObject(historyIndexes[o]);
					break;
				} else if (o > 0)
				{
					historyIndexes[o] = 0;
					history[o] = label0;
				}
		}
		// The start state's name is the start label repeated `order` times.
		for (int i = 0; i < order; i++)
			history[i] = start;
		return concatLabels(history);
	}
	// Order-0 model: one state per label, fully connected, one weight set per label.
	String[] stateNames = new String[outputAlphabet.size()];
	for (int s = 0; s < outputAlphabet.size(); s++)
		stateNames[s] = (String)outputAlphabet.lookupObject(s);
	for (int s = 0; s < outputAlphabet.size(); s++)
		addState(stateNames[s], 0.0, 0.0, stateNames, stateNames, stateNames);
	return start;
}
/** Looks up a state by its name; returns null if no such state exists. */
public State getState (String name)
{
	State state = name2state.get (name);
	return state;
}
/** Replaces the weight vector at weightsIndex.
 * Note: the structure-changed stamp is bumped before validation, matching
 * the original behavior even when the index check fails. */
public void setWeights (int weightsIndex, SparseVector transitionWeights)
{
	weightsStructureChanged();
	boolean inBounds = weightsIndex >= 0 && weightsIndex < parameters.weights.length;
	if (!inBounds)
		throw new IllegalArgumentException ("weightsIndex "+weightsIndex+" is out of bounds");
	parameters.weights[weightsIndex] = transitionWeights;
}
/** Replaces the weight vector for the named weight set (creating the set if new). */
public void setWeights (String weightName, SparseVector transitionWeights)
{
	int widx = getWeightsIndex (weightName);
	setWeights (widx, transitionWeights);
}
/** Returns the name of the weight set at weightIndex. */
public String getWeightsName (int weightIndex)
{
	Object name = parameters.weightAlphabet.lookupObject (weightIndex);
	return (String) name;
}
/** Returns the weight vector for the named weight set (creating the set if new). */
public SparseVector getWeights (String weightName)
{
	int widx = getWeightsIndex (weightName);
	return parameters.weights[widx];
}
/** Returns the weight vector at weightIndex (the live object, not a copy). */
public SparseVector getWeights (int weightIndex)
{
	return parameters.weights[weightIndex];
}
/** Returns the per-weight-set default weights (the live array, not a copy). */
public double[] getDefaultWeights () {
	return parameters.defaultWeights;
}
/** Returns all weight vectors, indexed by weight set (the live array, not a copy). */
public SparseVector[] getWeights () {
	return parameters.weights;
}
/** Replaces all weight vectors at once; flags the structure change so cached
 * layouts are invalidated. */
public void setWeights (SparseVector[] m) {
	weightsStructureChanged();
	parameters.weights = m;
}
/** Replaces the whole default-weight array; flags the structure change. */
public void setDefaultWeights (double[] w) {
	weightsStructureChanged();
	parameters.defaultWeights = w;
}
/** Sets one default weight. Only a value change (not a structure change),
 * so the lighter-weight stamp is bumped. */
public void setDefaultWeight (int widx, double val) {
	weightsValueChanged();
	parameters.defaultWeights[widx] = val;
}
// Support for making cc.mallet.optimize.Optimizable CRFs
/** Returns true if the weight set at weightsIndex is frozen (excluded from training). */
public boolean isWeightsFrozen (int weightsIndex)
{
	return parameters.weightsFrozen [weightsIndex];
}
/**
 * Freezes a set of weights to their current values.
 * Frozen weights are used for labeling sequences (as in transduce),
 * but are not modified by the train methods.
 * @param weightsIndex Index of weight set to freeze.
 */
public void freezeWeights (int weightsIndex)
{
	parameters.weightsFrozen [weightsIndex] = true;
}
/**
 * Freezes a set of weights to their current values.
 * Frozen weights are used for labeling sequences (as in transduce),
 * but are not modified by the train methods.
 * @param weightsName Name of weight set to freeze.
 */
public void freezeWeights (String weightsName)
{
	int widx = getWeightsIndex (weightsName);
	freezeWeights (widx);
}
/**
 * Unfreezes a set of weights so the train methods may modify them again.
 * Frozen weights are used for labeling sequences (as in transduce),
 * but are not modified by the train methods.
 * @param weightsName Name of weight set to unfreeze.
 */
public void unfreezeWeights (String weightsName)
{
	int widx = getWeightsIndex (weightsName);
	parameters.weightsFrozen[widx] = false;
}
/** Installs fs as the feature selection for weight set weightIdx, restricting
 * which input features that weight set may use. */
public void setFeatureSelection (int weightIdx, FeatureSelection fs)
{
	featureSelections [weightIdx] = fs;
	weightsStructureChanged(); // Is this necessary? -akm 11/2007
}
/** Sizes the weight structure to the features occurring in trainingData,
 * with the "unsupported features trick" disabled.
 * See {@link #setWeightsDimensionAsIn(InstanceList, boolean)}. */
public void setWeightsDimensionAsIn (InstanceList trainingData) {
	setWeightsDimensionAsIn(trainingData, false);
}
// gsc: changing this to consider the case when trainingData is a mix of labeled and unlabeled data,
// and we want to use the unlabeled data as well to set some weights (while using the unsupported trick)
// *note*: 'target' sequence of an unlabeled instance is either null or is of size zero.
/**
 * Rebuilds each weight vector so it has a slot for every feature that occurs
 * on a transition in trainingData (plus any features already present),
 * preserving existing weight values.
 * @param trainingData instances scanned to find which (weight set, feature) pairs occur
 * @param useSomeUnsupportedTrick if true and the model already has nonzero
 *        parameters, also allocate weights for features seen on high-probability
 *        transitions under the current model, so some otherwise-unsupported
 *        features can later receive (negative) weight
 */
public void setWeightsDimensionAsIn (InstanceList trainingData, boolean useSomeUnsupportedTrick)
{
	// Per weight set: which feature indices will get a weight slot.
	final BitSet[] weightsPresent;
	int numWeights = 0;
	// The value doesn't actually change, because the "new" parameters will have zero value
	// but the gradient changes because the parameters now have different layout.
	weightsStructureChanged();
	weightsPresent = new BitSet[parameters.weights.length];
	for (int i = 0; i < parameters.weights.length; i++)
		weightsPresent[i] = new BitSet();
	// Put in the weights that are already there
	for (int i = 0; i < parameters.weights.length; i++)
		for (int j = parameters.weights[i].numLocations()-1; j >= 0; j--)
			weightsPresent[i].set (parameters.weights[i].indexAtLocation(j));
	// Put in the weights in the training set
	for (int i = 0; i < trainingData.size(); i++) {
		Instance instance = trainingData.get(i);
		FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
		FeatureSequence output = (FeatureSequence) instance.getTarget();
		// gsc: trainingData can have unlabeled instances as well
		if (output != null && output.size() > 0) {
			// Do it for the paths consistent with the labels...
			sumLatticeFactory.newSumLattice (this, input, output, new Transducer.Incrementor() {
				public void incrementTransition (Transducer.TransitionIterator ti, double count) {
					State source = (CRF.State)ti.getSourceState();
					FeatureVector input = (FeatureVector)ti.getInput();
					int index = ti.getIndex();
					int nwi = source.weightsIndices[index].length;
					for (int wi = 0; wi < nwi; wi++) {
						int weightsIndex = source.weightsIndices[index][wi];
						for (int i = 0; i < input.numLocations(); i++) {
							int featureIndex = input.indexAtLocation(i);
							// Honor both the global and the per-weight-set feature selections.
							if ((globalFeatureSelection == null || globalFeatureSelection.contains(featureIndex))
									&& (featureSelections == null
											|| featureSelections[weightsIndex] == null
											|| featureSelections[weightsIndex].contains(featureIndex)))
								weightsPresent[weightsIndex].set (featureIndex);
						}
					}
				}
				public void incrementInitialState (Transducer.State s, double count) { }
				public void incrementFinalState (Transducer.State s, double count) { }
			});
		}
		// ...and also do it for the paths selected by the current model (so we will get some negative weights)
		if (useSomeUnsupportedTrick && this.getParametersAbsNorm() > 0) {
			if (i == 0)
				logger.fine ("CRF: Incremental training detected. Adding weights for some unsupported features...");
			// (do this once some training is done)
			sumLatticeFactory.newSumLattice (this, input, null, new Transducer.Incrementor() {
				public void incrementTransition (Transducer.TransitionIterator ti, double count) {
					if (count < 0.2) // Only create features for transitions with probability above 0.2
						return; // This 0.2 is somewhat arbitrary -akm
					State source = (CRF.State)ti.getSourceState();
					FeatureVector input = (FeatureVector)ti.getInput();
					int index = ti.getIndex();
					int nwi = source.weightsIndices[index].length;
					for (int wi = 0; wi < nwi; wi++) {
						int weightsIndex = source.weightsIndices[index][wi];
						for (int i = 0; i < input.numLocations(); i++) {
							int featureIndex = input.indexAtLocation(i);
							if ((globalFeatureSelection == null || globalFeatureSelection.contains(featureIndex))
									&& (featureSelections == null
											|| featureSelections[weightsIndex] == null
											|| featureSelections[weightsIndex].contains(featureIndex)))
								weightsPresent[weightsIndex].set (featureIndex);
						}
					}
				}
				public void incrementInitialState (Transducer.State s, double count) { }
				public void incrementFinalState (Transducer.State s, double count) { }
			});
		}
	}
	// Rebuild each weight vector with exactly the indices collected above,
	// then copy the previous weight values back in.
	SparseVector[] newWeights = new SparseVector[parameters.weights.length];
	for (int i = 0; i < parameters.weights.length; i++) {
		int numLocations = weightsPresent[i].cardinality ();
		logger.fine ("CRF weights["+parameters.weightAlphabet.lookupObject(i)+"] num features = "+numLocations);
		int[] indices = new int[numLocations];
		for (int j = 0; j < numLocations; j++) {
			indices[j] = weightsPresent[i].nextSetBit (j == 0 ? 0 : indices[j-1]+1);
			//System.out.println ("CRF4 has index "+indices[j]);
		}
		newWeights[i] = new IndexedSparseVector (indices, new double[numLocations],
		                                         numLocations, numLocations, false, false, false);
		newWeights[i].plusEqualsSparse (parameters.weights[i]); // Put in the previous weights
		numWeights += (numLocations + 1); // +1 counts this set's default weight
	}
	logger.info("Number of weights = "+numWeights);
	parameters.weights = newWeights;
}
/** Gives every weight set a dense slot for every input feature (or, when a
 * feature selection is installed for a set, for every selected feature),
 * preserving existing weight values. */
public void setWeightsDimensionDensely ()
{
	weightsStructureChanged();
	SparseVector[] newWeights = new SparseVector [parameters.weights.length];
	int max = inputAlphabet.size();
	int numWeights = 0;
	logger.info ("CRF using dense weights, num input features = "+max);
	for (int i = 0; i < parameters.weights.length; i++) {
		int nfeatures;
		if (featureSelections[i] == null) {
			// No selection: a dense vector over the whole input alphabet.
			nfeatures = max;
			newWeights [i] = new SparseVector (null, new double [max],
			                                   max, max, false, false, false);
		} else {
			// Respect the featureSelection
			FeatureSelection fs = featureSelections[i];
			nfeatures = fs.getBitSet ().cardinality ();
			int[] idxs = new int [nfeatures];
			int j = 0, thisIdx = -1;
			// Collect the selected feature indices in ascending order.
			while ((thisIdx = fs.nextSelectedIndex (thisIdx + 1)) >= 0) {
				idxs[j++] = thisIdx;
			}
			newWeights[i] = new IndexedSparseVector (idxs, new double [nfeatures], nfeatures, nfeatures, false, false, false);
		}
		newWeights [i].plusEqualsSparse (parameters.weights [i]); // keep previous values
		numWeights += (nfeatures + 1); // +1 counts this set's default weight
	}
	logger.info("Number of weights = "+numWeights);
	parameters.weights = newWeights;
}
// Create a new weight Vector if weightName is new.
/** Returns the index of the named weight set, growing all parallel
 * per-weight-set arrays by one when the name is new. Throws
 * IllegalArgumentException if the weight alphabet is frozen and the name is
 * unknown. Manual copy loops replaced with Arrays.copyOf (which pads the new
 * slot with 0/null, matching the original explicit initialization). */
public int getWeightsIndex (String weightName)
{
	int wi = parameters.weightAlphabet.lookupIndex (weightName);
	if (wi == -1)
		throw new IllegalArgumentException ("Alphabet frozen, and no weight with name "+ weightName);
	if (parameters.weights == null) {
		// First weight set ever: allocate all parallel arrays at size 1.
		assert (wi == 0);
		parameters.weights = new SparseVector[] { new IndexedSparseVector () };
		parameters.defaultWeights = new double[] { 0 };
		featureSelections = new FeatureSelection[] { null };
		parameters.weightsFrozen = new boolean [1];
		weightsStructureChanged();
	} else if (wi == parameters.weights.length) {
		// New name appended to the alphabet: grow every parallel array by one.
		int n = parameters.weights.length;
		parameters.weights = Arrays.copyOf (parameters.weights, n + 1);
		parameters.weights[wi] = new IndexedSparseVector ();
		parameters.defaultWeights = Arrays.copyOf (parameters.defaultWeights, n + 1);
		featureSelections = Arrays.copyOf (featureSelections, n + 1);
		parameters.weightsFrozen = ArrayUtils.append (parameters.weightsFrozen, false);
		weightsStructureChanged();
	}
	return wi;
}
/** Sanity check: all per-weight-set arrays must stay parallel to parameters.weights. */
private void assertWeightsLength ()
{
	if (parameters.weights == null)
		return;
	assert parameters.defaultWeights != null;
	assert featureSelections != null;
	assert parameters.weightsFrozen != null;
	int n = parameters.weights.length;
	assert parameters.defaultWeights.length == n;
	assert featureSelections.length == n;
	assert parameters.weightsFrozen.length == n;
}
public int numStates () { return states.size(); }
/** Returns the state at the given position in the state list. */
public Transducer.State getState (int index) {
	return states.get(index); }
/** Iterator over the states that may begin a sequence.
 * (Raw Iterator type preserved for API compatibility.) */
public Iterator initialStateIterator () {
	return initialStates.iterator (); }
public boolean isTrainable () { return true; }
// gsc: accessor methods
/** Returns the stamp bumped by weightsValueChanged(); callers can compare
 * stamps to detect weight-value changes. */
public int getWeightsValueChangeStamp() {
	return weightsValueChangeStamp;
}
// kedar: access structure stamp method
/** Returns the stamp bumped by weightsStructureChanged(); callers can compare
 * stamps to detect weight-layout changes. */
public int getWeightsStructureChangeStamp() {
	return weightsStructureChangeStamp;
}
/** Returns the Factors object holding all of this CRF's parameters
 * (weights, default weights, initial/final weights). */
public Factors getParameters ()
{
	return parameters;
}
// gsc
/** Sum of absolute values of all parameters: initial/final weights of every
 * state, plus every weight set's default weight and weight-vector abs norm. */
public double getParametersAbsNorm ()
{
	double norm = 0;
	int ns = numStates();
	for (int s = 0; s < ns; s++)
		norm += Math.abs (parameters.initialWeights[s]) + Math.abs (parameters.finalWeights[s]);
	for (int w = 0; w < parameters.weights.length; w++)
		norm += Math.abs (parameters.defaultWeights[w]) + parameters.weights[w].absNorm();
	return norm;
}
/** Only sets the parameter from the first group of parameters. */
public void setParameter (int sourceStateIndex, int destStateIndex, int featureIndex, double value)
{
	// Delegates with weightIndex 0, i.e. the first weight set on the transition.
	setParameter(sourceStateIndex, destStateIndex, featureIndex, 0, value);
}
/** Sets one weight on the transition sourceState -> destState, in the
 * weightIndex-th weight set of that transition. A negative featureIndex
 * addresses the weight set's default (bias) weight.
 * @throws IllegalArgumentException if no such transition exists.
 * (Fixed misspelling "transtition" in the exception message.) */
public void setParameter (int sourceStateIndex, int destStateIndex, int featureIndex, int weightIndex, double value)
{
	weightsValueChanged();
	State source = (State)getState(sourceStateIndex);
	State dest = (State) getState(destStateIndex);
	// Locate the transition row whose destination is `dest`.
	int rowIndex;
	for (rowIndex = 0; rowIndex < source.destinationNames.length; rowIndex++)
		if (source.destinationNames[rowIndex].equals (dest.name))
			break;
	if (rowIndex == source.destinationNames.length)
		throw new IllegalArgumentException ("No transition from state "+sourceStateIndex+" to state "+destStateIndex+".");
	int weightsIndex = source.weightsIndices[rowIndex][weightIndex];
	if (featureIndex < 0)
		parameters.defaultWeights[weightsIndex] = value;
	else {
		parameters.weights[weightsIndex].setValue (featureIndex, value);
	}
}
/** Only gets the parameter from the first group of parameters. */
public double getParameter (int sourceStateIndex, int destStateIndex, int featureIndex)
{
	// Delegates with weightIndex 0, i.e. the first weight set on the transition.
	return getParameter(sourceStateIndex,destStateIndex,featureIndex,0);
}
/** Returns one weight on the transition sourceState -> destState, from the
 * weightIndex-th weight set of that transition. A negative featureIndex
 * addresses the weight set's default (bias) weight.
 * @throws IllegalArgumentException if no such transition exists.
 * (Fixed misspelling "transtition" in the exception message.) */
public double getParameter (int sourceStateIndex, int destStateIndex, int featureIndex, int weightIndex)
{
	State source = (State)getState(sourceStateIndex);
	State dest = (State) getState(destStateIndex);
	// Locate the transition row whose destination is `dest`.
	int rowIndex;
	for (rowIndex = 0; rowIndex < source.destinationNames.length; rowIndex++)
		if (source.destinationNames[rowIndex].equals (dest.name))
			break;
	if (rowIndex == source.destinationNames.length)
		throw new IllegalArgumentException ("No transition from state "+sourceStateIndex+" to state "+destStateIndex+".");
	int weightsIndex = source.weightsIndices[rowIndex][weightIndex];
	if (featureIndex < 0)
		return parameters.defaultWeights[weightsIndex];
	return parameters.weights[weightsIndex].value (featureIndex);
}
/** Returns the total parameter count: initial+final weight per state, one
 * default weight per weight set, and one weight per sparse-vector location.
 * The count is cached against the weights-structure stamp. */
public int getNumParameters () {
	if (cachedNumParametersStamp != weightsStructureChangeStamp) {
		this.numParameters = 2 * this.numStates() + this.parameters.defaultWeights.length;
		for (int i = 0; i < parameters.weights.length; i++)
			numParameters += parameters.weights[i].numLocations();
		// Bug fix: record the stamp so the cache is actually effective.  The
		// original never updated it, so the count was recomputed on every call.
		this.cachedNumParametersStamp = weightsStructureChangeStamp;
	}
	return this.numParameters;
}
/** This method is deprecated. Applies this CRF's feature selection and any
 * induced feature conjunctions to {@code testing}, then returns the Viterbi
 * (best) output sequence for each instance. Asserts that each instance's
 * input and target sequences have equal length. */
// But it is here as a reminder to do something about induceFeaturesFor(). */
@Deprecated
public Sequence[] predict (InstanceList testing) {
	testing.setFeatureSelection(this.globalFeatureSelection);
	// Re-create any induced feature conjunctions on the test data.
	for (int i = 0; i < featureInducers.size(); i++) {
		FeatureInducer klfi = (FeatureInducer)featureInducers.get(i);
		klfi.induceFeaturesFor (testing, false, false);
	}
	Sequence[] ret = new Sequence[testing.size()];
	for (int i = 0; i < testing.size(); i++) {
		Instance instance = testing.get(i);
		Sequence input = (Sequence) instance.getData();
		Sequence trueOutput = (Sequence) instance.getTarget();
		assert (input.size() == trueOutput.size());
		// Best (max-probability) path under the current model.
		Sequence predOutput = new MaxLatticeDefault(this, input).bestOutputSequence();
		assert (predOutput.size() == trueOutput.size());
		ret[i] = predOutput;
	}
	return ret;
}
/** This method is deprecated. It always throws: feature induction must be
 * applied to the evaluation data via induceFeaturesFor(InstanceList) instead. */
@Deprecated
public void evaluate (TransducerEvaluator eval, InstanceList testing) {
	throw new IllegalStateException ("This method is no longer usable. Use CRF.induceFeaturesFor() instead.");
	/*
	testing.setFeatureSelection(this.globalFeatureSelection);
	for (int i = 0; i < featureInducers.size(); i++) {
		FeatureInducer klfi = (FeatureInducer)featureInducers.get(i);
		klfi.induceFeaturesFor (testing, false, false);
	}
	eval.evaluate (this, true, 0, true, 0.0, null, null, testing);
	*/
}
/** When the CRF has done feature induction, these new feature conjunctions must be
 * created in the test or validation data in order for them to take effect. */
public void induceFeaturesFor (InstanceList instances) {
	instances.setFeatureSelection (this.globalFeatureSelection);
	for (FeatureInducer inducer : featureInducers)
		inducer.induceFeaturesFor (instances, false, false);
}
// TODO Put support to Optimizable here, including getValue(InstanceList)??
/** Prints this CRF (states, transitions, weights) to System.out, auto-flushing. */
public void print ()
{
	print (new PrintWriter (new OutputStreamWriter (System.out), true));
}
/** Prints a human-readable dump of this CRF to {@code out}: every state with
 * its initial/final weights and transitions, followed by every weight set's
 * default weight and nonzero feature weights sorted by value. */
public void print (PrintWriter out)
{
	out.println ("*** CRF STATES ***");
	for (int i = 0; i < numStates (); i++) {
		State s = (State) getState (i);
		out.print ("STATE NAME=\"");
		out.print (s.name); out.print ("\" ("); out.print (s.destinations.length); out.print (" outgoing transitions)\n");
		out.print ("  "); out.print ("initialWeight = "); out.print (parameters.initialWeights[i]); out.print ('\n');
		out.print ("  "); out.print ("finalWeight = "); out.print (parameters.finalWeights[i]); out.print ('\n');
		out.println ("  transitions:");
		for (int j = 0; j < s.destinations.length; j++) {
			out.print ("    "); out.print (s.name); out.print (" -> "); out.println (s.getDestinationState (j).getName ());
			// Each transition may have several weight sets; list their names.
			for (int k = 0; k < s.weightsIndices[j].length; k++) {
				out.print ("        WEIGHTS = \"");
				int widx = s.weightsIndices[j][k];
				out.print (parameters.weightAlphabet.lookupObject (widx).toString ());
				out.print ("\"\n");
			}
		}
		out.println ();
	}
	if (parameters.weights == null)
		out.println ("\n\n*** NO WEIGHTS ***");
	else {
		out.println ("\n\n*** CRF WEIGHTS ***");
		for (int widx = 0; widx < parameters.weights.length; widx++) {
			out.println ("WEIGHTS NAME = " + parameters.weightAlphabet.lookupObject (widx));
			out.print (": <DEFAULT_FEATURE> = "); out.print (parameters.defaultWeights[widx]); out.print ('\n');
			SparseVector transitionWeights = parameters.weights[widx];
			if (transitionWeights.numLocations () == 0)
				continue;
			// Rank the features of this weight set by weight value for printing.
			RankedFeatureVector rfv = new RankedFeatureVector (inputAlphabet, transitionWeights);
			for (int m = 0; m < rfv.numLocations (); m++) {
				double v = rfv.getValueAtRank (m);
				//int index = rfv.indexAtLocation (rfv.getIndexAtRank (m));  // This doesn't make any sense.  How did this ever work?  -akm 12/2007
				int index = rfv.getIndexAtRank (m);
				Object feature = inputAlphabet.lookupObject (index);
				if (v != 0) {
					out.print (": "); out.print (feature); out.print (" = "); out.println (v);
				}
			}
		}
	}
	out.flush ();
}
/** Serializes this CRF to file f. IOExceptions are reported to stderr rather
 * than propagated (best-effort, preserving the original contract).
 * Fixed a resource leak: the original called close() outside any finally
 * block, leaking the stream when writeObject threw; try-with-resources now
 * guarantees the stream is closed. */
public void write (File f) {
	try (ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream (f))) {
		oos.writeObject (this);
	}
	catch (IOException e) {
		System.err.println ("Exception writing file " + f + ": " + e);
	}
}
// gsc: Serialization for CRF class
private static final long serialVersionUID = 1;
// Format version written first by writeObject; bump when the serialized form changes.
private static final int CURRENT_SERIAL_VERSION = 1;
/** Custom serialization: version int first, then fields in an order that
 * must exactly mirror readObject. */
private void writeObject (ObjectOutputStream out) throws IOException {
	out.writeInt (CURRENT_SERIAL_VERSION);
	out.writeObject (inputAlphabet);
	out.writeObject (outputAlphabet);
	out.writeObject (states);
	out.writeObject (initialStates);
	out.writeObject (name2state);
	out.writeObject (parameters);
	out.writeObject (globalFeatureSelection);
	out.writeObject (featureSelections);
	out.writeObject (featureInducers);
	out.writeInt (weightsValueChangeStamp);
	out.writeInt (weightsStructureChangeStamp);
	out.writeInt (cachedNumParametersStamp);
	out.writeInt (numParameters);
}
/** Custom deserialization; field order must exactly mirror writeObject.
 * The version int is read but currently unused. */
@SuppressWarnings("unchecked")
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
	in.readInt ();  // serial version, currently ignored
	inputAlphabet = (Alphabet) in.readObject ();
	outputAlphabet = (Alphabet) in.readObject ();
	states = (ArrayList) in.readObject ();
	initialStates = (ArrayList) in.readObject ();
	name2state = (HashMap) in.readObject ();
	parameters = (Factors) in.readObject ();
	globalFeatureSelection = (FeatureSelection) in.readObject ();
	featureSelections = (FeatureSelection[]) in.readObject ();
	featureInducers = (ArrayList) in.readObject ();
	weightsValueChangeStamp = in.readInt ();
	weightsStructureChangeStamp = in.readInt ();
	cachedNumParametersStamp = in.readInt ();
	numParameters = in.readInt ();
}
// Why is this "static"? Couldn't it be a non-static inner class? (In Transducer also) -akm 12/2007
public static class State extends Transducer.State implements Serializable
{
	// Parameters indexed by destination state, feature index
	String name;                // unique state name; key into CRF.name2state
	int index;                  // position of this state in the owning CRF's state list
	String[] destinationNames;  // destination-state names, parallel to destinations
	State[] destinations;       // N.B. elements are null until getDestinationState(int) is called
	int[][] weightsIndices;     // contains indices into CRF.weights[],
	String[] labels;            // output label of each transition, parallel to destinations
	CRF crf;                    // owning CRF; all weight values live in crf.parameters
	// No arg constructor so serialization works; fields are filled in by readObject.
	protected State() {
		super ();
	}
	/** Builds a state with the given transitions. destinationNames, labelNames
	 * and weightNames are parallel arrays, one entry per outgoing transition;
	 * weightNames[i] lists the weight sets used on transition i (looked up,
	 * and created if new, via crf.getWeightsIndex). Also records the
	 * initial/final weights into crf.parameters. */
	protected State (String name, int index,
	                 double initialWeight, double finalWeight,
	                 String[] destinationNames,
	                 String[] labelNames,
	                 String[][] weightNames,
	                 CRF crf)
	{
		super ();
		assert (destinationNames.length == labelNames.length);
		assert (destinationNames.length == weightNames.length);
		this.name = name;
		this.index = index;
		// Note: setting these parameters here is actually redundant; they were set already in CRF.addState(...)
		// I'm considering removing initialWeight and finalWeight as arguments to this constructor, but need to think more -akm 12/2007
		// If CRF.State were non-static, then this constructor could add the state to the list of states, and put it in the name2state also.
		crf.parameters.initialWeights[index] = initialWeight;
		crf.parameters.finalWeights[index] = finalWeight;
		this.destinationNames = new String[destinationNames.length];
		this.destinations = new State[labelNames.length];  // lazily resolved by getDestinationState
		this.weightsIndices = new int[labelNames.length][];
		this.labels = new String[labelNames.length];
		this.crf = crf;
		for (int i = 0; i < labelNames.length; i++) {
			// Make sure this label appears in our output Alphabet
			crf.outputAlphabet.lookupIndex (labelNames[i]);
			this.destinationNames[i] = destinationNames[i];
			this.labels[i] = labelNames[i];
			this.weightsIndices[i] = new int[weightNames[i].length];
			for (int j = 0; j < weightNames[i].length; j++)
				this.weightsIndices[i][j] = crf.getWeightsIndex (weightNames[i][j]);
		}
		crf.weightsStructureChanged();
	}
public Transducer getTransducer () { return crf; }
public double getInitialWeight () { return crf.parameters.initialWeights[index]; }
public void setInitialWeight (double c) { crf.parameters.initialWeights[index]= c; }
public double getFinalWeight () { return crf.parameters.finalWeights[index]; }
public void setFinalWeight (double c) { crf.parameters.finalWeights[index] = c; }
	/** Prints this state's name, weights and destination names to System.out. */
	public void print () {
		System.out.println ("State #" + index + " \"" + name + "\"");
		System.out.println ("initialWeight=" + crf.parameters.initialWeights[index] + ", finalWeight=" + crf.parameters.finalWeights[index]);
		System.out.println ("#destinations=" + destinations.length);
		for (int i = 0; i < destinations.length; i++) {
			System.out.println ("-> " + destinationNames[i]);
		}
	}
public int numDestinations () { return destinations.length;}
public String[] getWeightNames (int index) {
int[] indices = this.weightsIndices[index];
String[] ret = new String[indices.length];
for (int i=0; i < ret.length; i++)
ret[i] = crf.parameters.weightAlphabet.lookupObject(indices[i]).toString();
return ret;
}
/** Attaches the named weight set to the outgoing transition at {@code didx}.
    getWeightsIndex registers the name with the CRF (creating new weight structures if needed). */
public void addWeight (int didx, String weightName) {
int widx = crf.getWeightsIndex (weightName);
weightsIndices[didx] = ArrayUtils.append (weightsIndices[didx], widx);
}
/** Output label emitted by the transition at {@code index}. */
public String getLabelName (int index) {
return labels [index];
}
/**
 * Returns the destination State for the transition at {@code index},
 * resolving it by name on first use and caching the result in
 * {@code destinations}.
 *
 * @throws IllegalArgumentException if the stored destination name is not
 *         registered in the CRF's name-to-state map.
 */
public State getDestinationState (int index)
{
	State dest = destinations[index];
	if (dest == null) {
		// Lazily resolve the name; states may be created in any order.
		dest = crf.name2state.get (destinationNames[index]);
		if (dest == null)
			throw new IllegalArgumentException ("this.name="+this.name+" index="+index+" destinationNames[index]="+destinationNames[index]+" name2state.size()="+ crf.name2state.size());
		destinations[index] = dest;
	}
	return dest;
}
/**
 * Returns an iterator over the transitions leaving this state for the given
 * input position, optionally constrained to the output at {@code outputPosition}.
 *
 * @throws UnsupportedOperationException for epsilon transitions (negative
 *         positions) or a null input sequence — CRFs are not generative.
 */
public Transducer.TransitionIterator transitionIterator (Sequence inputSequence, int inputPosition,
		Sequence outputSequence, int outputPosition)
{
	if (inputPosition < 0 || outputPosition < 0)
		throw new UnsupportedOperationException ("Epsilon transitions not implemented.");
	if (inputSequence == null)
		throw new UnsupportedOperationException ("CRFs are not generative models; must have an input sequence.");
	// A null output sequence means "unconstrained": score every transition.
	String constrainedOutput = null;
	if (outputSequence != null)
		constrainedOutput = (String) outputSequence.get(outputPosition);
	return new TransitionIterator (this, (FeatureVectorSequence) inputSequence, inputPosition, constrainedOutput, crf);
}
/** Convenience form: iterate transitions for a single feature vector, optionally constrained to {@code output}. */
public Transducer.TransitionIterator transitionIterator (FeatureVector fv, String output)
{
return new TransitionIterator (this, fv, output, crf);
}
/** This state's name (unique within the CRF). */
public String getName () { return name; }
/** This state's position in the CRF's state list; used to index the shared parameter arrays. */
// "final" to make it efficient inside incrementTransition
public final int getIndex () { return index; }
// Serialization
// For class State
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
// Writes this state's fields in a fixed order; readObject must mirror this order exactly.
private void writeObject (ObjectOutputStream out) throws IOException {
// Version tag first, so future formats can be distinguished on read.
out.writeInt (CURRENT_SERIAL_VERSION);
out.writeObject(name);
out.writeInt(index);
out.writeObject(destinationNames);
out.writeObject(destinations);
out.writeObject(weightsIndices);
out.writeObject(labels);
out.writeObject(crf);
}
// Restores fields in the exact order writeObject emitted them.
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
in.readInt ();  // serial version tag — currently unused
name = (String) in.readObject();
index = in.readInt();
destinationNames = (String[]) in.readObject();
destinations = (CRF.State[]) in.readObject();
weightsIndices = (int[][]) in.readObject();
labels = (String[]) in.readObject();
crf = (CRF) in.readObject();
}
}
protected static class TransitionIterator extends Transducer.TransitionIterator implements Serializable
{
State source;
int index, nextIndex;
protected double[] weights;
FeatureVector input;
CRF crf;
public TransitionIterator (State source,
FeatureVectorSequence inputSeq,
int inputPosition,
String output, CRF crf)
{
this (source, inputSeq.get(inputPosition), output, crf);
}
protected TransitionIterator (State source,
FeatureVector fv,
String output, CRF crf)
{
this.source = source;
this.crf = crf;
this.input = fv;
this.weights = new double[source.destinations.length];
int nwi, swi;
for (int transIndex = 0; transIndex < source.destinations.length; transIndex++) {
// xxx Or do we want output.equals(...) here?
if (output == null || output.equals(source.labels[transIndex])) {
// Here is the dot product of the feature weights with the lambda weights
// for one transition
weights[transIndex] = 0;
nwi = source.weightsIndices[transIndex].length;
for (int wi = 0; wi < nwi; wi++) {
swi = source.weightsIndices[transIndex][wi];
weights[transIndex] += (crf.parameters.weights[swi].dotProduct (fv)
// include with implicit weight 1.0 the default feature
+ crf.parameters.defaultWeights[swi]);
}
assert (!Double.isNaN(weights[transIndex]));
assert (weights[transIndex] != Double.POSITIVE_INFINITY);
}
else
weights[transIndex] = IMPOSSIBLE_WEIGHT;
}
// Prepare nextIndex, pointing at the next non-impossible transition
nextIndex = 0;
while (nextIndex < source.destinations.length && weights[nextIndex] == IMPOSSIBLE_WEIGHT)
nextIndex++;
}
public boolean hasNext () { return nextIndex < source.destinations.length; }
public Transducer.State nextState ()
{
assert (nextIndex < source.destinations.length);
index = nextIndex;
nextIndex++;
while (nextIndex < source.destinations.length && weights[nextIndex] == IMPOSSIBLE_WEIGHT)
nextIndex++;
return source.getDestinationState (index);
}
// These "final"s are just to try to make this more efficient. Perhaps some of them will have to go away
public final int getIndex () { return index; }
public final Object getInput () { return input; }
public final Object getOutput () { return source.labels[index]; }
public final double getWeight () { return weights[index]; }
public final Transducer.State getSourceState () { return source; }
public final Transducer.State getDestinationState () { return source.getDestinationState (index); }
// Serialization
// TransitionIterator
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
private static final int NULL_INTEGER = -1;
private void writeObject (ObjectOutputStream out) throws IOException {
out.writeInt (CURRENT_SERIAL_VERSION);
out.writeObject (source);
out.writeInt (index);
out.writeInt (nextIndex);
out.writeObject(weights);
out.writeObject (input);
out.writeObject(crf);
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
in.readInt ();
source = (State) in.readObject();
index = in.readInt ();
nextIndex = in.readInt ();
weights = (double[]) in.readObject();
input = (FeatureVector) in.readObject();
crf = (CRF) in.readObject();
}
public String describeTransition (double cutoff)
{
DecimalFormat f = new DecimalFormat ("0.###");
StringBuffer buf = new StringBuffer ();
buf.append ("Value: " + f.format (-getWeight ()) + "
\n");
try {
int[] theseWeights = source.weightsIndices[index];
for (int i = 0; i < theseWeights.length; i++) {
int wi = theseWeights[i];
SparseVector w = crf.parameters.weights[wi];
buf.append ("WEIGHTS
\n" + crf.parameters.weightAlphabet.lookupObject (wi) + "
\n");
buf.append (" d.p. = "+f.format (w.dotProduct (input))+"
\n");
double[] vals = new double[input.numLocations ()];
double[] absVals = new double[input.numLocations ()];
for (int k = 0; k < vals.length; k++) {
int index = input.indexAtLocation (k);
vals[k] = w.value (index) * input.value (index);
absVals[k] = Math.abs (vals[k]);
}
buf.append ("DEFAULT " + f.format (crf.parameters.defaultWeights[wi]) + "
\n");
RankedFeatureVector rfv = new RankedFeatureVector (crf.inputAlphabet, input.getIndices (), absVals);
for (int rank = 0; rank < absVals.length; rank++) {
int fidx = rfv.getIndexAtRank (rank);
Object fname = crf.inputAlphabet.lookupObject (input.indexAtLocation (fidx));
if (absVals[fidx] < cutoff) break; // Break looping over features
if (vals[fidx] != 0) {
buf.append (fname + " " + f.format (vals[fidx]) + "
\n");
}
}
}
} catch (Exception e) {
System.err.println ("Error writing transition descriptions.");
e.printStackTrace ();
buf.append ("ERROR WHILE WRITING OUTPUT...\n");
}
return buf.toString ();
}
}
}