All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.fst.MaxLatticeDefault Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

/** 
@author Fernando Pereira <a href="mailto:pereira@cis.upenn.edu">pereira@cis.upenn.edu</a>
@author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
package cc.mallet.fst;



import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;

import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

import cc.mallet.types.ArraySequence;
import cc.mallet.types.Sequence;
import cc.mallet.types.SequencePairAlignment;

import cc.mallet.fst.Transducer.State;
import cc.mallet.fst.Transducer.TransitionIterator;

import cc.mallet.util.MalletLogger;
import cc.mallet.util.search.AStar;
import cc.mallet.util.search.AStarState;
import cc.mallet.util.search.SearchNode;
import cc.mallet.util.search.SearchState;

/** Default, full dynamic programming version of the Viterbi "Max-(Product)-Lattice" algorithm.
 * 
 * @author Fernando Pereira
 * @author Andrew McCallum 
 */
public class MaxLatticeDefault implements MaxLattice
{
	private static Logger logger = MalletLogger.getLogger(MaxLatticeDefault.class.getName());
	//{ logger.setLevel(Level.INFO); }

	// Transducer whose states and transitions define the lattice.
	private Transducer t;
	// Input sequence being decoded, and the output sequence handed to the constructor (may be null).
	private Sequence input, providedOutput;
	// Number of lattice positions; the constructor sets it to input.size()+1.
	private int latticeLength;
	// lattice[inputPosition][stateIndex]; entries are created lazily by getViterbiNode and may be null.
	private ViterbiNode[][] lattice;
	// Head and tail of the linked list of weight caches; getCache moves used caches to the head
	// and recycles the tail once maxCaches have been allocated.
	private WeightCache first, last;
	// caches[inputPosition] is the weight cache currently bound to that position, or null.
	private WeightCache[] caches;
	// Number of caches allocated so far, and the most that may ever be allocated.
	private int numCaches, maxCaches;
	
	/** Returns the transducer this lattice was built over. */
	public Transducer getTransducer ()
	{
		return this.t;
	}
	/** Returns the input sequence that was passed to the constructor. */
	public Sequence getInput()
	{
		return this.input;
	}
	/** Returns the output sequence that was passed to the constructor; may be null. */
	public Sequence getProvidedOutput()
	{
		return this.providedOutput;
	}

	/**
	 * One cell of the Viterbi lattice: a transducer state at a given input position.
	 * It also serves as the search state for the A* n-best search, which starts from
	 * nodes at the last lattice position and expands towards position 0.
	 */
	private class ViterbiNode implements AStarState {
		int inputPosition;                              // Lattice position of this node
		State state;                                    // Transducer state this node stands for
		Object output;                                  // Output recorded by the forward pass on a transition into this node
		double delta = Transducer.IMPOSSIBLE_WEIGHT;    // Best forward weight found so far for this node
		ViterbiNode maxWeightPredecessor = null;        // Node at the previous position that produced delta

		ViterbiNode (int inputPosition, State state) {
			this.state = state;
			this.inputPosition = inputPosition;
		}

		/** The one method required by AStarState: the negated forward weight of this node. */
		public double completionCost () {
			return -delta;
		}

		/** The search is complete at position 0 in a state whose initial weight is not impossible. */
		public boolean isFinal() {
			if (inputPosition != 0)
				return false;
			return state.getInitialWeight() > Transducer.IMPOSSIBLE_WEIGHT;
		}

		/**
		 * Enumerates the nodes at the previous input position that have a
		 * possible transition into the enclosing node.
		 */
		private class PreviousStateIterator extends AStarState.NextStateIterator {
			private int prev;            // Next source-state index to examine
			private boolean found;       // True while prev already points at a possible predecessor
			private double weight;       // Transition weight of the predecessor returned most recently
			private double[] weights;    // weights[s]: weight of s -> this state; stays null at position 0

			private PreviousStateIterator() {
				prev = 0;
				if (inputPosition <= 0)
					return;
				final int target = state.getIndex();
				weights = new double[t.numStates()];
				final WeightCache incoming = getCache(inputPosition-1);
				for (int source = 0; source < t.numStates(); source++)
					weights[source] = incoming.weight[source][target];
			}

			/** Advances prev to the next source state with a possible transition, if not already there. */
			private void lookAhead() {
				if (weights == null || found)
					return;
				while (prev < t.numStates()) {
					if (weights[prev] > Transducer.IMPOSSIBLE_WEIGHT) {
						found = true;
						return;
					}
					prev++;
				}
			}

			public boolean hasNext() {
				lookAhead();
				if (weights == null)
					return false;
				return prev < t.numStates();
			}

			public SearchState nextState() {
				lookAhead();
				final int source = prev++;
				weight = weights[source];
				found = false;    // force the next lookAhead to scan onwards
				return getViterbiNode(inputPosition-1, source);
			}

			/** Cost of the step just returned by nextState: the negated transition weight. */
			public double cost() {
				return -weight;
			}

			/** Transition weight of the step just returned by nextState. */
			public double weight() {
				return weight;
			}
		}

		public NextStateIterator getNextStates() {
			return new PreviousStateIterator();
		}
	}

	/**
	 * Table of transition weights leaving one lattice position, linked into
	 * the list of caches that getCache maintains.
	 */
	private class WeightCache {
		private WeightCache prev, next;    // Neighbours in the cache list
		private double[][] weight;         // weight[sourceState][destinationState]
		private int position;              // Input position this table is currently bound to

		private WeightCache(int position) {
			weight = new double[t.numStates()][t.numStates()];
			init(position);
		}

		/** Binds the table to the given position and marks every transition as impossible. */
		private void init(int position) {
			this.position = position;
			final int numStates = t.numStates();
			for (int source = 0; source < numStates; source++) {
				final double[] row = weight[source];
				for (int destination = 0; destination < numStates; destination++)
					row[destination] = Transducer.IMPOSSIBLE_WEIGHT;
			}
		}
	}

	/**
	 * Returns the table of transition weights leaving lattice position
	 * <code>position</code>, building it if it is not currently cached.
	 * At most maxCaches tables are ever allocated; once that many exist, the
	 * table at the tail of the cache list is recycled for the new position.
	 * The returned table is moved to the head of the list.
	 *
	 * @param position input position whose outgoing transition weights are wanted
	 * @return the weight table bound to that position
	 */
	private WeightCache getCache(int position) {
		WeightCache cache = caches[position];
		if (cache == null) {            // No cache for this position
			if (numCaches < maxCaches)  { // Create another cache
				cache = new WeightCache(position);
				if (numCaches++ == 0)
					first = last = cache;
			}
			else {                        // Steal the cache at the tail of the list
				cache = last;
				caches[cache.position] = null;
				cache.init(position);
			}
			// Fill in the weights of transitions leaving every source state that the
			// forward pass reached at this position; all other entries stay impossible.
			for (int i = 0; i < t.numStates(); i++) {
				if (lattice[position][i] == null || lattice[position][i].delta == Transducer.IMPOSSIBLE_WEIGHT)
					continue;
				State s = t.getState(i);
				TransitionIterator iter =
					s.transitionIterator (input, position, providedOutput, position);
				while (iter.hasNext()) {
					State d = iter.next();
					cache.weight[i][d.getIndex()] = iter.getWeight();
				}
			}
			caches[position] = cache;
		}
		if (cache != first) {           // Move to front
			if (cache == last)
				last = cache.prev;
			// Unlink from both neighbours. The successor's back-pointer must be
			// repaired too; otherwise a later eviction reads a stale prev when it
			// updates last and cuts live caches out of the list.
			if (cache.prev != null)
				cache.prev.next = cache.next;
			if (cache.next != null)
				cache.next.prev = cache.prev;
			cache.next = first;
			cache.prev = null;
			first.prev = cache;
			first = cache;
		}
		return cache;
	}

	/**
	 * Returns the lattice node for state stateIndex at input position ip,
	 * creating and storing it first if the lattice has no node there yet.
	 */
	protected ViterbiNode getViterbiNode (int ip, int stateIndex)
	{
		ViterbiNode node = lattice[ip][stateIndex];
		if (node == null) {
			node = new ViterbiNode (ip, t.getState (stateIndex));
			lattice[ip][stateIndex] = node;
		}
		return node;
	}
	
	public MaxLatticeDefault (Transducer t, Sequence inputSequence) 
	{
		this (t, inputSequence, null, 100000);
	}
	
	public MaxLatticeDefault (Transducer t, Sequence inputSequence, Sequence outputSequence) 
	{
		this (t, inputSequence, outputSequence, 100000);
	}

	/** Initiate Viterbi decoding of the inputSequence, constrained to match non-null parts of the outputSequence.
	 * Only the forward pass is performed here; the backward pass happens when best paths are requested.
	 * maxCaches indicates how much state information to memoize in n-best decoding; values below 1 are treated as 1. */
	public MaxLatticeDefault (Transducer t, Sequence inputSequence, Sequence outputSequence, int maxCaches) 
	{
		this.t = t;
		this.maxCaches = Math.max (maxCaches, 1);
		assert (inputSequence != null);

		final boolean tracing = logger.isLoggable (Level.FINE);
		if (tracing) {
			logger.fine ("Starting ViterbiLattice");
			logger.fine ("Input: ");
			for (int ip = 0; ip < inputSequence.size(); ip++)
				logger.fine (" " + inputSequence.get(ip));
			logger.fine ("\nOutput: ");
			if (outputSequence != null) {
				for (int op = 0; op < outputSequence.size(); op++)
					logger.fine (" " + outputSequence.get(op));
			}
			else {
				logger.fine ("null");
			}
			logger.fine ("\n");
		}

		this.input = inputSequence;
		this.providedOutput = outputSequence;
		this.latticeLength = input.size()+1;
		final int numStates = t.numStates();
		this.lattice = new ViterbiNode[latticeLength][numStates];
		this.caches = new WeightCache[latticeLength-1];

		// Viterbi Forward: seed position 0 with every state that may start a path.
		logger.fine ("Starting Viterbi");
		int numInitialStates = 0;
		for (int stateIndex = 0; stateIndex < numStates; stateIndex++) {
			final double initialWeight = t.getState(stateIndex).getInitialWeight();
			if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) {
				getViterbiNode (0, stateIndex).delta = initialWeight;
				numInitialStates++;
			}
		}
		if (numInitialStates == 0)
			logger.warning ("Viterbi: No initial states!");

		// Relax every transition out of each reached node, one input position at a time.
		final int lastTransition = latticeLength-2;
		for (int ip = 0; ip <= lastTransition; ip++) {
			for (int stateIndex = 0; stateIndex < numStates; stateIndex++) {
				final ViterbiNode sourceNode = lattice[ip][stateIndex];
				if (sourceNode == null || sourceNode.delta == Transducer.IMPOSSIBLE_WEIGHT)
					continue;
				final State source = t.getState(stateIndex);
				final TransitionIterator iter = source.transitionIterator (input, ip, providedOutput, ip);
				if (tracing)
					logger.fine (" Starting Viterbi transition iteration from state "
							+ source.getName() + " on input " + input.get(ip));
				while (iter.hasNext()) {
					final State destination = iter.next();
					if (tracing)
						logger.fine ("Viterbi[inputPos="+ip
								+"][source="+source.getName()
								+"][dest="+destination.getName()+"]");
					final ViterbiNode destinationNode = getViterbiNode (ip+1, destination.getIndex());
					// Recorded for every transition into the node, not only the best one.
					destinationNode.output = iter.getOutput();
					double weight = sourceNode.delta + iter.getWeight();
					if (ip == lastTransition)
						weight += destination.getFinalWeight();   // the last step also pays the final weight
					if (weight > destinationNode.delta) {
						if (tracing)
							logger.fine ("Viterbi[inputPos="+ip
									+"][source][dest="+destination.getName()
									+"] weight increased to "+weight+" by source="+
									source.getName());
						destinationNode.delta = weight;
						destinationNode.maxWeightPredecessor = sourceNode;
					}
				}
			}
		}
	}
	
	/**
	 * Returns the forward Viterbi weight stored for state stateIndex at lattice
	 * position ip. The lattice node is created if it does not exist yet, in
	 * which case its weight is Transducer.IMPOSSIBLE_WEIGHT.
	 *
	 * @param ip lattice position
	 * @param stateIndex index of the transducer state
	 * @return the node's delta
	 * @throws IllegalStateException if no lattice is stored
	 */
	public double getDelta (int ip, int stateIndex) {
		if (lattice == null)
			throw new IllegalStateException ("Attempt to call getDelta() when lattice not stored.");
		return getViterbiNode (ip, stateIndex).delta;
	}

	private List> viterbiNodeAlignmentCache = null;

	/**
   * Perform the backward pass of Viterbi, returning the n-best sequences of
   * ViterbiNodes. Each ViterbiNode contains the state, output symbol, and other
   * information. Note that the length of each ViterbiNode Sequence is
   * inputLength+1, because the first element of the sequence is the start
   * state, and the first input/output symbols occur on the transition from a
   * start-state to the next state. These first input/output symbols are stored
   * in the second ViterbiNode in the sequence. The last ViterbiNode in the
   * sequence corresponds to the final state and has the last input/output
   * symbols.
   */
	public List> bestViterbiNodeSequences (int n) {
		if (viterbiNodeAlignmentCache != null && viterbiNodeAlignmentCache.size() >= n)
			return viterbiNodeAlignmentCache;
		int numFinal = 0;
		for (int i = 0; i < t.numStates(); i++) {
			if (lattice[latticeLength-1][i] != null && lattice[latticeLength-1][i].delta > Transducer.IMPOSSIBLE_WEIGHT)
				numFinal++;
		}
		ViterbiNode[] finalNodes = new ViterbiNode[numFinal];
		int f = 0;
		for (int i = 0; i < t.numStates(); i++) {
			if (lattice[latticeLength-1][i] != null && lattice[latticeLength-1][i].delta > Transducer.IMPOSSIBLE_WEIGHT)
				finalNodes[f++] = lattice[latticeLength-1][i];
		}
		AStar search = new AStar(finalNodes, latticeLength * t.numStates());
		List> outputs = new ArrayList>(n);
		for (int i = 0; i < n && search.hasNext(); i++) {
		  // gsc: removing unnecessary cast
			SearchNode ans = search.next();
			double weight = -ans.getCost();
			ViterbiNode[] seq = new ViterbiNode[latticeLength];
			// Commented out so we get the start state ViterbiNode -akm 12/2007
			//ans = ans.getParent(); // ans now corresponds to the Viterbi node after the first transition
			for (int j = 0; j < latticeLength; j++) {
				ViterbiNode v = (ViterbiNode)ans.getState();
				assert(v.inputPosition == j);  // was == j+1
				seq[j] = v;
				ans = ans.getParent();
			}
			outputs.add(new SequencePairAlignment(input, new ArraySequence(seq), weight));
		}
		viterbiNodeAlignmentCache = outputs;
		return outputs;
	}


	private List> stateAlignmentCache = null;

	/**
   * Perform the backward pass of Viterbi, returning the n-best sequences of
   * States. Note that the length of each State Sequence is inputLength+1,
   * because the first element of the sequence is the start state, and the first
   * input/output symbols occur on the transition from a start state to the next
   * state. The last State in the sequence corresponds to the final state.
   */	
	public List> bestStateAlignments (int n) {
		if (stateAlignmentCache != null && stateAlignmentCache.size() >= n)
			return stateAlignmentCache;
		bestViterbiNodeSequences(n); // ensure that viterbiNodeAlignmentCache has at least size n
		ArrayList> ret = new ArrayList>(n);
		for (int i = 0; i < n; i++) {
			State[] ss = new State[latticeLength];
			Sequence vs = viterbiNodeAlignmentCache.get(i).output();
			for (int j = 0; j < latticeLength; j++)
				ss[j] = vs.get(j).state; // Here is where we grab the state from the ViterbiNode
			ret.add(new SequencePairAlignment(input, new ArraySequence(ss), viterbiNodeAlignmentCache.get(i).getWeight()));
		}
		stateAlignmentCache = ret;
		return ret;
	}
	
	/**
	 * Returns the single best state alignment, i.e. the first element of
	 * bestStateAlignments(1). List.get fails with an IndexOutOfBoundsException
	 * if that list is empty.
	 */
	public SequencePairAlignment<Object,State> bestStateAlignment () {
		return bestStateAlignments(1).get(0);
	}

	public List> bestStateSequences(int n) {
		List> a = bestStateAlignments(n);
		ArrayList> ret = new ArrayList>(n);
		for (int i = 0; i < n; i++)
			ret.add (a.get(i).output());
		return ret;
	}
	
	/**
	 * Returns the state sequence of the single best alignment, i.e. the output
	 * side of the first element of bestStateAlignments(1). List.get fails with
	 * an IndexOutOfBoundsException if that list is empty.
	 */
	public Sequence<State> bestStateSequence() {
		return bestStateAlignments(1).get(0).output();
	}
	
	private List> outputAlignmentCache = null;

	public List> bestOutputAlignments (int n) {
		if (outputAlignmentCache != null && outputAlignmentCache.size() >= n)
			return outputAlignmentCache;
		bestViterbiNodeSequences(n); // ensure that viterbiNodeAlignmentCache has at least size n
		ArrayList> ret = new ArrayList>(n);
		for (int i = 0; i < n; i++) {
			Object[] ss = new Object[latticeLength-1];
			Sequence vs = viterbiNodeAlignmentCache.get(i).output();
			for (int j = 0; j < latticeLength-1; j++)
				ss[j] = vs.get(j+1).output; // Here is where we grab the output from the ViterbiNode destination
			ret.add(new SequencePairAlignment(input, new ArraySequence(ss), viterbiNodeAlignmentCache.get(i).getWeight()));
		}
		outputAlignmentCache = ret;
		return ret;
	}	
	
	/**
	 * Returns the single best output alignment, i.e. the first element of
	 * bestOutputAlignments(1). List.get fails with an IndexOutOfBoundsException
	 * if that list is empty.
	 */
	public SequencePairAlignment<Object,Object> bestOutputAlignment () {
		return bestOutputAlignments(1).get(0);
	}

	public List> bestOutputSequences (int n) {
		bestOutputAlignments(n); // ensure that outputAlignmentCache has at least size n
		ArrayList> ret = new ArrayList>(n);
		for (int i = 0; i < n; i++)
			ret.add (outputAlignmentCache.get(i).output());
		return ret;
		// TODO consider caching this result
	}
	
	/**
	 * Returns the output sequence of the single best alignment, i.e. the
	 * output side of the first element of bestOutputAlignments(1). List.get
	 * fails with an IndexOutOfBoundsException if that list is empty.
	 */
	public Sequence<Object> bestOutputSequence () {
		return bestOutputAlignments(1).get(0).output();
	}
	
	/** Returns the weight of the single best path, taken from the first element of bestOutputAlignments(1). */
	public double bestWeight() {
		return bestOutputAlignments(1).get(0).getWeight();
	}
	
	
	/** Increment states and transitions with a count of 1.0 along the best state sequence.
	 *  This provides for a so-called "Viterbi training" approximation.
	 *
	 *  The best path holds one ViterbiNode per lattice position, i.e. one more
	 *  than the number of input symbols: node 0 is the start state, and the
	 *  output of the transition that consumes input ip is stored in node ip+1.
	 *
	 *  @param incrementor receives the initial-state, final-state and per-transition increments
	 *  @throws IllegalStateException if, for some step of the path, the number of
	 *          transitions matching the destination state and output is not exactly one */
	public void incrementTransducer (Transducer.Incrementor incrementor)
	{
		// We are only going to increment along the single best path ".get(0)" below.
		// We could consider having a version of this method:
		// incrementTransducer(Transducer.Incrementor incrementor, double[] counts)
		// where the number of n-best paths to increment would be determined by counts.length
		SequencePairAlignment<Object,ViterbiNode> viterbiNodeAlignment = this.bestViterbiNodeSequences(1).get(0);
		Sequence<ViterbiNode> path = viterbiNodeAlignment.output();
		int sequenceLength = path.size();
		// The node path includes the start state, so it is one longer than the input.
		assert (sequenceLength == viterbiNodeAlignment.input().size() + 1);
		// Increment the initial state
		incrementor.incrementInitialState(path.get(0).state, 1.0);
		// Increment the final state
		incrementor.incrementFinalState(path.get(sequenceLength-1).state, 1.0);
		// One transition per input symbol: from node ip to node ip+1, including the last one.
		for (int ip = 0; ip < sequenceLength-1; ip++) {
			ViterbiNode source = path.get(ip);
			ViterbiNode destination = path.get(ip+1);
			TransitionIterator iter =
				source.state.transitionIterator (input, ip, providedOutput, ip);
			// xxx This assumes that a transition is completely
			// identified, and made unique by its destination state and
			// output.  This may not be true!
			int numIncrements = 0;
			while (iter.hasNext()) {
				// The output of this transition was recorded on the node it enters.
				if (iter.next().equals (destination.state)
						&& iter.getOutput().equals (destination.output)) {
					incrementor.incrementTransition(iter, 1.0);
					numIncrements++;
				}
			}
			if (numIncrements > 1)
				throw new IllegalStateException ("More than one satisfying transition found.");
			if (numIncrements == 0)
				throw new IllegalStateException ("No satisfying transition found.");
		}
	}

	/**
	 * Returns the fraction of positions at which the best output sequence and
	 * referenceOutput agree, comparing elements by their toString() values.
	 * The number of matches is logged at INFO level.
	 */
	public double elementwiseAccuracy (Sequence referenceOutput)
	{
		final Sequence predicted = bestOutputSequence();
		assert (referenceOutput.size() == predicted.size());
		int matches = 0;
		for (int pos = 0; pos < predicted.size(); pos++) {
			//logger.fine("tokenAccuracy: ref: "+referenceOutput.get(pos)+" viterbi: "+predicted.get(pos));
			final String expected = referenceOutput.get(pos).toString();
			final String actual = predicted.get(pos).toString();
			if (expected.equals (actual))
				matches++;
		}
		logger.info ("Number correct: " + matches + " out of " + predicted.size());
		return ((double)matches)/predicted.size();
	}

	/**
	 * Returns the fraction of positions at which the best output sequence and
	 * referenceOutput agree, comparing elements by their toString() values.
	 * When out is non-null, each predicted element is also printed to it, one
	 * per line. The number of matches is logged at INFO level.
	 */
	public double tokenAccuracy (Sequence referenceOutput, PrintWriter out)
	{
		final Sequence predicted = bestOutputSequence();
		assert (referenceOutput.size() == predicted.size());
		final boolean echo = (out != null);
		int matches = 0;
		for (int pos = 0; pos < predicted.size(); pos++) {
			//logger.fine("tokenAccuracy: ref: "+referenceOutput.get(pos)+" viterbi: "+predicted.get(pos));
			final String actual = predicted.get(pos).toString();
			if (echo)
				out.println(actual);
			if (referenceOutput.get(pos).toString().equals (actual))
				matches++;
		}
		logger.info ("Number correct: " + matches + " out of " + predicted.size());
		return ((double)matches)/predicted.size();
	}

	
	/** MaxLatticeFactory that produces MaxLatticeDefault lattices using the three-argument constructor. */
	public static class Factory extends MaxLatticeFactory implements Serializable
	{
		/** Creates a MaxLatticeDefault for the given transducer, input sequence and output sequence. */
		public MaxLattice newMaxLattice (Transducer trans, Sequence inputSequence, Sequence outputSequence)
		{
			return new MaxLatticeDefault (trans, inputSequence, outputSequence);
		}

		private static final long serialVersionUID = 1;
		// Version tag written at the start of this class's serialized data.
		private static final int CURRENT_SERIAL_VERSION = 1;

		// This class writes only the version number; it declares no fields of its own to save.
		private void writeObject(ObjectOutputStream out) throws IOException {
			out.writeInt(CURRENT_SERIAL_VERSION);
		}
		// Consumes the version number written by writeObject; the value is not inspected.
		private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
			in.readInt();
		}


	}

}