All downloads are free. The search and download features use the official Maven repository.

cc.mallet.fst.SumLatticeBeam Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.fst;

import java.util.ArrayList;
import java.util.logging.Level;
import java.util.logging.Logger;

import cc.mallet.fst.Transducer.State;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;



//******************************************************************************
//CPAL - NEW "BEAM" Version of Forward Backward
//******************************************************************************


public class SumLatticeBeam implements SumLattice  // CPAL - like Lattice but using max-product to get the viterbiPath
{


	// CPAL - these worked well for nettalk
	//private int beamWidth = 10;
	//private double KLeps = .005;

	// Flag exposed through getUseForwardBackwardBeam/setUseForwardBackwardBeam.
	boolean UseForwardBackwardBeam = false;
	// Beam width used once curIter != 0.  Static, so it is shared by every
	// SumLatticeBeam instance.
	protected static int beamWidth = 3;
	// Selects the pruning rule in the forward pass: > 0 uses the KL position,
	// == 0 uses the Rmin threshold position, < 0 uses the larger of the two
	// (evaluated with -KLeps).
	private double KLeps = 0;
	// Threshold handed to the sorted list; when it is not positive the
	// "STRAWMAN" threshold is used with -Rmin instead.
	private double Rmin = 0.1;
	// Per input position: the number of states kept in the beam by the forward pass.
	private double nstatesExpl[];
	// Current iteration as set by setCurIter; 0 means "use all states as the beam".
	private int curIter = 0;
	int tctIter = 0;    // The number of times we have been called this iteration
	// Mean of nstatesExpl, stored at the end of the forward pass.
	private double curAvgNstatesExpl;





	/** Returns the beam width, which is a static value shared by all instances. */
	public int getBeamWidth ()
	{
		return SumLatticeBeam.beamWidth;
	}

	/**
	 * Sets the beam width.
	 *
	 * <p>The underlying field is static, so the new width applies to every
	 * {@code SumLatticeBeam} instance, not only to this one.  The field is
	 * therefore assigned through the class name rather than through
	 * {@code this}.
	 *
	 * @param beamWidth the beam width to use from now on
	 */
	public void setBeamWidth (int beamWidth)
	{
		SumLatticeBeam.beamWidth = beamWidth;
	}

	/** Returns the call counter that {@link #incIter()} advances and {@link #setCurIter(int)} resets. */
	public int getTctIter()
	{
		return tctIter;
	}

	/**
	 * Records the current iteration number and resets the per-iteration
	 * call counter to zero.
	 *
	 * @param curIter the iteration number; 0 makes the lattice use all states as its beam
	 */
	public void setCurIter (int curIter)
	{
		tctIter = 0;
		this.curIter = curIter;
	}

	/** Advances the per-iteration call counter by one. */
	public void incIter ()
	{
		tctIter += 1;
	}

	/**
	 * Sets the KL epsilon that selects the pruning rule of the forward pass
	 * (positive: KL position, zero: Rmin threshold, negative: the larger of both).
	 *
	 * @param KLeps the new epsilon
	 */
	public void setKLeps (double KLeps) {
		this.KLeps = KLeps;
	}

	/**
	 * Sets the threshold handed to the sorted beam list.
	 *
	 * @param Rmin the new threshold
	 */
	public void setRmin (double Rmin)
	{
		this.Rmin = Rmin;
	}

	/**
	 * Returns the per-position count of states kept in the beam.
	 * The internal array itself is returned, not a copy.
	 */
	public double[] getNstatesExpl() {
		return this.nstatesExpl;
	}

	/** Returns the current value of the forward-backward beam flag. */
	public boolean getUseForwardBackwardBeam()
	{
		return UseForwardBackwardBeam;
	}

	/**
	 * Sets the forward-backward beam flag.
	 *
	 * @param state the new flag value
	 */
	public void setUseForwardBackwardBeam (boolean state)
	{
		UseForwardBackwardBeam = state;
	}






	// Logger named after this class.
	private static Logger logger = MalletLogger.getLogger(SumLatticeBeam.class.getName());

	// "ip" == "input position", "op" == "output position", "i" == "state index"
	// The transducer whose states and transitions define the lattice.
	Transducer t;
	// Total (unnormalized, log-space) weight of the lattice; used as the normalizer.
	double weight;
	// The sequences this lattice was built for; output may be null.
	Sequence input, output;
	// Lazily created through getLatticeNode; entries pruned by the beam are set back to null.
	LatticeNode[][] nodes;			 // indexed by ip,i
	// Number of input positions plus one.
	int latticeLength;
	int curBeamWidth;               // CPAL - can be adapted if maximizer is confused

	// xxx Now that we are incrementing here directly, there isn't
	// necessarily a need to save all these arrays...
	// log(probability) of being in state "i" at input position "ip"
	double[][] gammas;					 // indexed by ip,i
	double[][][] xis;            // indexed by ip,i,j; saved only if saveXis is true;

	LabelVector labelings[];			 // indexed by op, created only if "outputAlphabet" is non-null in constructor

	/**
	 * Returns the lattice node for the given input position and state,
	 * creating and caching it on first access.
	 *
	 * @param ip input position
	 * @param stateIndex index of the transducer state
	 * @return the (possibly freshly created) node stored at {@code nodes[ip][stateIndex]}
	 */
	private LatticeNode getLatticeNode (int ip, int stateIndex)
	{
		LatticeNode node = nodes[ip][stateIndex];
		if (node == null) {
			node = new LatticeNode (ip, t.getState (stateIndex));
			nodes[ip][stateIndex] = node;
		}
		return node;
	}

	/**
	 * Builds a beam lattice that neither saves xis nor creates per-position
	 * labelings.  {@code output} may be null, meaning that the lattice is
	 * not constrained to match the output.
	 */
	public SumLatticeBeam (Transducer t, Sequence input, Sequence output, Transducer.Incrementor incrementor)
	{
		this (t, input, output, incrementor, false, (LabelAlphabet) null);
	}

	/**
	 * Builds a beam lattice without per-position labelings, optionally
	 * saving the xis.  {@code output} may be null, meaning that the lattice
	 * is not constrained to match the output.
	 */
	public SumLatticeBeam (Transducer t, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis)
	{
		this (t, input, output, incrementor, saveXis, (LabelAlphabet) null);
	}

	// If outputAlphabet is non-null, this will create a LabelVector
	// for each position in the output sequence indicating the
	// probability distribution over possible outputs at that time
	// index
	public SumLatticeBeam (Transducer t, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet)
	{
		this.t = t;
		if (false && logger.isLoggable (Level.FINE)) {
			logger.fine ("Starting Lattice");
			logger.fine ("Input: ");
			for (int ip = 0; ip < input.size(); ip++)
				logger.fine (" " + input.get(ip));
			logger.fine ("\nOutput: ");
			if (output == null)
				logger.fine ("null");
			else
				for (int op = 0; op < output.size(); op++)
					logger.fine (" " + output.get(op));
			logger.fine ("\n");
		}

		// Initialize some structures
		this.input = input;
		this.output = output;
		// xxx Not very efficient when the lattice is actually sparse,
		// especially when the number of states is large and the
		// sequence is long.
		latticeLength = input.size()+1;
		int numStates = t.numStates();
		nodes = new LatticeNode[latticeLength][numStates];
		// xxx Yipes, this could get big; something sparse might be better?
		gammas = new double[latticeLength][numStates];
		if (saveXis) xis = new double[latticeLength][numStates][numStates];

		double outputCounts[][] = null;
		if (outputAlphabet != null)
			outputCounts = new double[latticeLength][outputAlphabet.size()];

		for (int i = 0; i < numStates; i++) {
			for (int ip = 0; ip < latticeLength; ip++)
				gammas[ip][i] = Transducer.IMPOSSIBLE_WEIGHT;
			if (saveXis)
				for (int j = 0; j < numStates; j++)
					for (int ip = 0; ip < latticeLength; ip++)
						xis[ip][i][j] = Transducer.IMPOSSIBLE_WEIGHT;
		}

		// Forward pass
		logger.fine ("Starting Foward pass");
		boolean atLeastOneInitialState = false;
		for (int i = 0; i < numStates; i++) {
			double initialWeight = t.getState(i).getInitialWeight();
			//System.out.println ("Forward pass initialWeight = "+initialWeight);
			if (initialWeight < Transducer.IMPOSSIBLE_WEIGHT) {
				getLatticeNode(0, i).alpha = initialWeight;
				//System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha);
				atLeastOneInitialState = true;
			}
		}
		if (atLeastOneInitialState == false)
			logger.warning ("There are no starting states!");


		// CPAL - a sorted list for our beam experiments
		NBestSlist[] slists = new NBestSlist[latticeLength];
		// CPAL - used for stats
		nstatesExpl = new double[latticeLength];
		// CPAL - used to adapt beam if optimizer is getting confused
		// tctIter++;
		if(curIter == 0) {
			curBeamWidth = numStates;
		} else if(tctIter > 1 && curIter != 0) {
			//curBeamWidth = Math.min((int)Math.round(curAvgNstatesExpl*2),numStates);
			//System.out.println ("Doubling Minimum Beam Size to: "+curBeamWidth);
			curBeamWidth = beamWidth;
		} else {
			curBeamWidth = beamWidth;
		}

		// ************************************************************
		for (int ip = 0; ip < latticeLength-1; ip++) {

			// CPAL - add this to construct the beam
			// ***************************************************

			// CPAL - sets up the sorted list
			slists[ip] = new NBestSlist(numStates);
			// CPAL - set the
			slists[ip].setKLMinE(curBeamWidth);
			slists[ip].setKLeps(KLeps);
			slists[ip].setRmin(Rmin);

			for(int i = 0 ; i< numStates ; i++){
				if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT)
					continue;
				//State s = t.getState(i);
				// CPAL - give the NB viterbi node the (Weight, position)
				NBForBackNode cnode = new NBForBackNode(nodes[ip][i].alpha, i);
				slists[ip].push(cnode);

			}

			// CPAL - unlike std. n-best beam we now filter the list based
			// on a KL divergence like measure
			// ***************************************************
			// use method which computes the cumulative log sum and
			// finds the point at which the sum is within KLeps
			int KLMaxPos=1;
			int RminPos=1;


			if(KLeps > 0) {
				KLMaxPos = slists[ip].getKLpos();
				nstatesExpl[ip]=(double)KLMaxPos;
			} else if(KLeps == 0) {

				if(Rmin > 0) {
					RminPos = slists[ip].getTHRpos();
				} else {
					slists[ip].setRmin(-Rmin);
					RminPos = slists[ip].getTHRposSTRAWMAN();
				}
				nstatesExpl[ip]=(double)RminPos;

			} else {
				// Trick, negative values for KLeps mean use the max of KL an Rmin
				slists[ip].setKLeps(-KLeps);
				KLMaxPos = slists[ip].getKLpos();

				//RminPos = slists[ip].getTHRpos();

				if(Rmin > 0) {
					RminPos = slists[ip].getTHRpos();
				} else {
					slists[ip].setRmin(-Rmin);
					RminPos = slists[ip].getTHRposSTRAWMAN();
				}

				if(KLMaxPos > RminPos) {
					nstatesExpl[ip]=(double)KLMaxPos;
				} else {
					nstatesExpl[ip]=(double)RminPos;
				}
			}
			//System.out.println(nstatesExpl[ip] + " ");

			// CPAL - contemplating setting values to something else
			int tmppos;
			for (int i = (int) nstatesExpl[ip]+1; i < slists[ip].size(); i++) {
				tmppos = slists[ip].getPosByIndex(i);
				nodes[ip][tmppos].alpha = Transducer.IMPOSSIBLE_WEIGHT;
				nodes[ip][tmppos] = null;   // Null is faster and seems to work the same
			}
			// - done contemplation

			//for (int i = 0; i < numStates; i++) {
			for(int jj=0 ; jj< nstatesExpl[ip]; jj++) {

				int i = slists[ip].getPosByIndex(jj);

				// CPAL - dont need this anymore
				// should be taken care of in the lists
				//if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT)
				// xxx if we end up doing this a lot,
				// we could save a list of the non-null ones
				//	continue;


				State s = t.getState(i);

				TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
				if (logger.isLoggable (Level.FINE))
					logger.fine (" Starting Foward transition iteration from state "
							+ s.getName() + " on input " + input.get(ip).toString()
							+ " and output "
							+ (output==null ? "(null)" : output.get(ip).toString()));
				while (iter.hasNext()) {
					State destination = iter.nextState();
					if (logger.isLoggable (Level.FINE))
						logger.fine ("Forward Lattice[inputPos="+ip
								+"][source="+s.getName()
								+"][dest="+destination.getName()+"]");
					LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex());
					destinationNode.output = iter.getOutput();
					double transitionWeight = iter.getWeight();
					if (logger.isLoggable (Level.FINE))
						logger.fine ("transitionWeight="+transitionWeight
								+" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha
								+" destinationNode.alpha="+destinationNode.alpha);
					destinationNode.alpha = Transducer.sumLogProb (destinationNode.alpha,
							nodes[ip][i].alpha + transitionWeight);
					//System.out.println ("destinationNode.alpha <- "+destinationNode.alpha);
				}
			}
		}

		//System.out.println("Mean Nodes Explored: " + MatrixOps.mean(nstatesExpl));
		curAvgNstatesExpl = MatrixOps.mean(nstatesExpl);

		// Calculate total cost of Lattice.  This is the normalizer
		weight = Transducer.IMPOSSIBLE_WEIGHT;
		for (int i = 0; i < numStates; i++)
			if (nodes[latticeLength-1][i] != null) {
				// Note: actually we could sum at any ip index,
				// the choice of latticeLength-1 is arbitrary
				//System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha);
				//System.out.println ("Ending beta,  state["+i+"] = "+t.getState(i).finalWeight);
				weight = Transducer.sumLogProb (weight,
						(nodes[latticeLength-1][i].alpha + t.getState(i).getFinalWeight()));
			}
		// Weight is now an "unnormalized weight" of the entire Lattice
		//assert (weight >= 0) : "weight = "+weight;

		// If the sequence has -infinite weight, just return.
		// Usefully this avoids calling any incrementX methods.
		// It also relies on the fact that the gammas[][] and .alpha and .beta values
		// are already initialized to values that reflect -infinite weight
		// xxx Although perhaps not all (alphas,betas) exactly correctly reflecting?
		if (weight == Transducer.IMPOSSIBLE_WEIGHT)
			return;

		// Backward pass
		for (int i = 0; i < numStates; i++)
			if (nodes[latticeLength-1][i] != null) {
				State s = t.getState(i);
				nodes[latticeLength-1][i].beta = s.getFinalWeight();
				gammas[latticeLength-1][i] =
					nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - weight;
				if (incrementor != null) {
					double p = Math.exp(gammas[latticeLength-1][i]);
					assert (p > Transducer.IMPOSSIBLE_WEIGHT && !Double.isNaN(p))
					: "p="+p+" gamma="+gammas[latticeLength-1][i];
					incrementor.incrementFinalState(s, p);
				}
			}

		for (int ip = latticeLength-2; ip >= 0; ip--) {
			for (int i = 0; i < numStates; i++) {
				if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT)
					// Note that skipping here based on alpha means that beta values won't
					// be correct, but since alpha is infinite anyway, it shouldn't matter.
					continue;
				State s = t.getState(i);
				TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
				while (iter.hasNext()) {
					State destination = iter.nextState();
					if (logger.isLoggable (Level.FINE))
						logger.fine ("Backward Lattice[inputPos="+ip
								+"][source="+s.getName()
								+"][dest="+destination.getName()+"]");
					int j = destination.getIndex();
					LatticeNode destinationNode = nodes[ip+1][j];
					if (destinationNode != null) {
						double transitionWeight = iter.getWeight();
						assert (!Double.isNaN(transitionWeight));
						//							assert (transitionWeight >= 0);  Not necessarily
						double oldBeta = nodes[ip][i].beta;
						assert (!Double.isNaN(nodes[ip][i].beta));
						nodes[ip][i].beta = Transducer.sumLogProb (nodes[ip][i].beta,
								destinationNode.beta + transitionWeight);
						assert (!Double.isNaN(nodes[ip][i].beta))
						: "dest.beta="+destinationNode.beta+" trans="+transitionWeight+" sum="+(destinationNode.beta+transitionWeight)
						+ " oldBeta="+oldBeta;
						double xi = nodes[ip][i].alpha + transitionWeight + nodes[ip+1][j].beta - weight;
						if (saveXis) xis[ip][i][j] = xi;
						assert (!Double.isNaN(nodes[ip][i].alpha));
						assert (!Double.isNaN(transitionWeight));
						assert (!Double.isNaN(nodes[ip+1][j].beta));
						assert (!Double.isNaN(weight));
						if (incrementor != null || outputAlphabet != null) {
							double p = Math.exp(xi);
							assert (p > Transducer.IMPOSSIBLE_WEIGHT && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+xi;
							if (incrementor != null)
								incrementor.incrementTransition(iter, p);
							if (outputAlphabet != null) {
								int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false);
								assert (outputIndex >= 0);
								// xxx This assumes that "ip" == "op"!
								outputCounts[ip][outputIndex] += p;
								//System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p);
							}
						}
					}
				}
				gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - weight;
			}

			if(true){
				// CPAL - check the normalization
				double checknorm = Transducer.IMPOSSIBLE_WEIGHT;
				for (int i = 0; i < numStates; i++)
					if (nodes[ip][i] != null) {
						// Note: actually we could sum at any ip index,
						// the choice of latticeLength-1 is arbitrary
						//System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha);
						//System.out.println ("Ending beta,  state["+i+"] = "+t.getState(i).finalWeight);
						checknorm = Transducer.sumLogProb (checknorm, gammas[ip][i]);
					}
				// System.out.println ("Check Gamma, sum="+checknorm);
				// CPAL - done check of normalization

				// CPAL - normalize
				for (int i = 0; i < numStates; i++)
					if (nodes[ip][i] != null) {
						gammas[ip][i] = gammas[ip][i] - checknorm;
					}
				//System.out.println ("Check Gamma, sum="+checknorm);
				// CPAL - normalization
			}
		}
		if (incrementor != null)
			for (int i = 0; i < numStates; i++) {
				double p = Math.exp(gammas[0][i]);
				assert (p > Transducer.IMPOSSIBLE_WEIGHT && !Double.isNaN(p));
				incrementor.incrementInitialState(t.getState(i), p);
			}
		if (outputAlphabet != null) {
			labelings = new LabelVector[latticeLength];
			for (int ip = latticeLength-2; ip >= 0; ip--) {
				assert (Math.abs(1.0-MatrixOps.sum (outputCounts[ip])) < 0.000001);;
				labelings[ip] = new LabelVector (outputAlphabet, outputCounts[ip]);
			}
		}

	}
	
	/** Returns the input sequence this lattice was built for. */
	public Sequence getInput()
	{
		return this.input;
	}

	// CPAL - a simple node holding a weight and position of the state
	private class NBForBackNode
	{
		double weight;
		int pos;
		NBForBackNode(double weight, int pos)
		{
			this.weight = weight;
			this.pos = pos;
		}
	}

	private class NBestSlist
	{
		// Backing list; the accessors below cast its elements to NBForBackNode.
		ArrayList list = new ArrayList();
		// Capacity passed to the constructor.
		int MaxElements;
		// Minimum element count, set through setKLMinE.
		int KLMinElements;
		// Position computed by getKLpos.
		int KLMaxPos;
		// Epsilon set through setKLeps.
		double KLeps;
		// Threshold factor set through setRmin and read by the getTHRpos* methods.
		double Rmin;

		/**
		 * Creates an empty list.
		 *
		 * @param MaxElements the capacity to remember for this list
		 */
		NBestSlist(int MaxElements) {
			this.MaxElements = MaxElements;
		}

		/**
		 * Stores the minimum element count.
		 *
		 * @param KLMinElements the new minimum
		 * @return always {@code true}
		 */
		boolean setKLMinE(int KLMinElements)
		{
			this.KLMinElements = KLMinElements;
			return true;
		}

		/** Returns the number of nodes currently held. */
		int size() {
			return this.list.size();
		}

		/** Returns whether the list holds no nodes. */
		boolean empty() {
			return this.list.isEmpty();
		}

		/** Removes and returns the first element of the list. */
		Object pop() {
			Object head = list.remove(0);
			return head;
		}

		/**
		 * Returns the state position stored in the node at list index {@code ii}.
		 */
		int getPosByIndex(int ii)
		{
			return ((NBForBackNode) list.get(ii)).pos;
		}

		/**
		 * Returns the weight stored in the node at list index {@code ii}.
		 */
		double getWeightByIndex(int ii)
		{
			return ((NBForBackNode) list.get(ii)).weight;
		}

		/** Stores the epsilon used by this list. */
		void setKLeps(double KLeps)
		{
			this.KLeps = KLeps;
		}

		/** Stores the threshold factor used by the getTHRpos* methods. */
		void setRmin(double Rmin)
		{
			this.Rmin = Rmin;
		}

		int getTHRpos(){

			NBForBackNode tn;
			double lc1, lc2;


			tn = (NBForBackNode)list.get(0);
			lc1 = tn.weight;
			tn = (NBForBackNode)list.get(list.size()-1);
			lc2 = tn.weight;

			double minc = lc1 - lc2;
			double mincTHR = minc - minc*Rmin;

			for(int i=1;i mincTHR){
					return i+1;
				}

			}

			return list.size();

		}

		int getTHRposSTRAWMAN(){

			NBForBackNode tn;
			double lc1, lc2;


			tn = (NBForBackNode)list.get(0);
			lc1 = tn.weight;

			double mincTHR = -lc1*Rmin;

			//double minc = lc1 - lc2;
			//double mincTHR = minc - minc*Rmin;

			for(int i=1;i0){
				//    int asdf=1;
				//}

				if (i==0) {
					CSNLP[i] = lc;
				} else {
					CSNLP[i] = Transducer.sumLogProb(CSNLP[i-1], lc);
				}
			}

			// normalize
			for(int i=0;i= KLMinElements) {
						return KLMaxPos;
					} else if(list.size() >= KLMinElements){
						return KLMinElements;
					}
				}
			}

			KLMaxPos = list.size();
			return KLMaxPos;
		}

		ArrayList push(NBForBackNode vn)
		{
			double tc = vn.weight;
			boolean atEnd = true;

			for(int i=0;iMaxElements) {
				list.remove(MaxElements);
			}

			//double f = o.totalWeight[o.nextBestStateIndex];
			//boolean atEnd = true;
			//for(int i=0; i requiredSegment  as indicated by
	        constrainedSequence 
	       @param inputSequence input sequence
	       @param outputSequence output sequence
	       @param requiredSegment segment of sequence that must be labelled
	       @param constrainedSequence lattice must have labels of this
	       sequence from  requiredSegment.start  to 
	       requiredSegment.end  correctly
	 */
	// Derives the constraints array from the required segment and delegates
	// to the constrained constructor with no incrementor and no output alphabet.
	SumLatticeBeam (Transducer t, Sequence inputSequence, Sequence outputSequence, Segment requiredSegment, Sequence constrainedSequence)
	{
		this (t,
				inputSequence,
				outputSequence,
				(Transducer.Incrementor) null,
				(LabelAlphabet) null,
				makeConstraints (t, inputSequence, outputSequence, requiredSegment, constrainedSequence));
	}
	/**
	 * Translates a required segment into the constraints array understood by
	 * the constrained constructor.
	 *
	 * <p>The array has one entry per lattice position (sequence length plus
	 * one for the start state).  A positive value {@code k} means every path
	 * must pass through state {@code k-1}, a negative value {@code -k} means
	 * no path may pass through state {@code k-1}, and 0 means "don't care".
	 *
	 * @throws IllegalArgumentException if the two sequences differ in length,
	 *         or if the segment's in-tag does not name a state when the
	 *         trailing negative constraint is needed
	 */
	private static int[] makeConstraints (Transducer t, Sequence inputSequence, Sequence outputSequence, Segment requiredSegment, Sequence constrainedSequence) {
		if (constrainedSequence.size () != inputSequence.size ())
			throw new IllegalArgumentException ("constrainedSequence.size [" + constrainedSequence.size () + "] != inputSequence.size [" + inputSequence.size () + "]");

		// A fresh int array is already all zeros, i.e. "don't care" everywhere.
		// It includes 1 extra node for the start state.
		int [] constraints = new int [constrainedSequence.size() + 1];

		// Positive constraints: inside the segment the path must follow
		// constrainedSequence.
		int pos = requiredSegment.getStart ();
		while (pos <= requiredSegment.getEnd()) {
			int stateIndex = t.stateIndexOfString ((String)constrainedSequence.get (pos));
			if (stateIndex == -1)
				logger.warning ("Could not find state " + constrainedSequence.get (pos) + ". Check that state labels match startTages and inTags, and that all labels are seen in training data.");
			// An unknown state yields -1 + 1 == 0, i.e. no constraint.
			constraints[pos+1] = stateIndex + 1;
			pos++;
		}

		// Negative constraint: the state right after the segment must not be
		// the segment's continue (in-)tag.
		if (requiredSegment.getEnd() + 2 < constraints.length) {
			String inTag = requiredSegment.getInTag().toString();
			int inTagState = t.stateIndexOfString (inTag);
			if (inTagState == -1)
				throw new IllegalArgumentException ("Could not find state " + inTag + ". Check that state labels match startTags and InTags.");
			constraints[requiredSegment.getEnd() + 2] = - (inTagState + 1);
		}

		logger.fine ("Segment:\n" + requiredSegment.sequenceToString () +
				"\nconstrainedSequence:\n" + constrainedSequence +
		"\nConstraints:\n");
		for (int constraint : constraints) {
			logger.fine (constraint + "\t");
		}
		logger.fine ("");
		return constraints;
	}




	// culotta: constructor for constrained lattice
	/** Create a lattice that constrains its transitions such that the
	 *  pairs in "constraints" are adhered
	 * to. constraints is an array where each entry is the index of
	 * the required label at that position. An entry of 0 means there
	 * are no constraints on that . Positive values
	 * mean the path must pass through that state. Negative values
	 * mean the path must _not_ pass through that state. NOTE -
	 * constraints.length must be equal to output.size() + 1. A
	 * lattice has one extra position for the initial
	 * state. Generally, this should be unconstrained, since it does
	 * not produce an observation.
	 */
	public SumLatticeBeam (Transducer t, Sequence input, Sequence output, Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet, int [] constraints)
	{
		this.t = t;
		if (false && logger.isLoggable (Level.FINE)) {
			logger.fine ("Starting Lattice");
			logger.fine ("Input: ");
			for (int ip = 0; ip < input.size(); ip++)
				logger.fine (" " + input.get(ip));
			logger.fine ("\nOutput: ");
			if (output == null)
				logger.fine ("null");
			else
				for (int op = 0; op < output.size(); op++)
					logger.fine (" " + output.get(op));
			logger.fine ("\n");
		}

		// Initialize some structures
		this.input = input;
		this.output = output;
		// xxx Not very efficient when the lattice is actually sparse,
		// especially when the number of states is large and the
		// sequence is long.
		latticeLength = input.size()+1;
		int numStates = t.numStates();
		nodes = new LatticeNode[latticeLength][numStates];
		// xxx Yipes, this could get big; something sparse might be better?
		gammas = new double[latticeLength][numStates];
		// xxx Move this to an ivar, so we can save it?  But for what?
		// Commenting this out, because it's a memory hog and not used right now.
		//  Uncomment and conditionalize under a flag if ever needed. -cas
		// double xis[][][] = new double[latticeLength][numStates][numStates];
		double outputCounts[][] = null;
		if (outputAlphabet != null)
			outputCounts = new double[latticeLength][outputAlphabet.size()];

		for (int i = 0; i < numStates; i++) {
			for (int ip = 0; ip < latticeLength; ip++)
				gammas[ip][i] = Transducer.IMPOSSIBLE_WEIGHT;
			/* Commenting out xis -cas
			for (int j = 0; j < numStates; j++)
				for (int ip = 0; ip < latticeLength; ip++)
					xis[ip][i][j] = Transducer.IMPOSSIBLE_WEIGHT;
			 */
		}

		// Forward pass
		logger.fine ("Starting Constrained Foward pass");

		// ensure that at least one state has initial weight less than Infinity
		// so we can start from there
		boolean atLeastOneInitialState = false;
		for (int i = 0; i < numStates; i++) {
			double initialWeight = t.getState(i).getInitialWeight();
			//System.out.println ("Forward pass initialWeight = "+initialWeight);
			if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) {
				getLatticeNode(0, i).alpha = initialWeight;
				//System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha);
				atLeastOneInitialState = true;
			}
		}

		if (atLeastOneInitialState == false)
			logger.warning ("There are no starting states!");
		for (int ip = 0; ip < latticeLength-1; ip++)
			for (int i = 0; i < numStates; i++) {
				logger.fine ("ip=" + ip+", i=" + i);
				// check if this node is possible at this . if not, skip it.
				if (constraints[ip] > 0) { // must be in state indexed by constraints[ip] - 1
					if (constraints[ip]-1 != i) {
						logger.fine ("Current state does not match positive constraint. position="+ip+", constraint="+(constraints[ip]-1)+", currState="+i);
						continue;
					}
				}
				else if (constraints[ip] < 0) { // must _not_ be in state indexed by constraints[ip]
					if (constraints[ip]+1 == -i) {
						logger.fine ("Current state does not match negative constraint. position="+ip+", constraint="+(constraints[ip]+1)+", currState="+i);
						continue;
					}
				}
				if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT) {
					// xxx if we end up doing this a lot,
					// we could save a list of the non-null ones
					if (nodes[ip][i] == null) logger.fine ("nodes[ip][i] is NULL");
					else if (nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT) logger.fine ("nodes[ip][i].alpha is Inf");
					logger.fine ("-INFINITE weight or NULL...skipping");
					continue;
				}
				State s = t.getState(i);

				TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
				if (logger.isLoggable (Level.FINE))
					logger.fine (" Starting Forward transition iteration from state "
							+ s.getName() + " on input " + input.get(ip).toString()
							+ " and output "
							+ (output==null ? "(null)" : output.get(ip).toString()));
				while (iter.hasNext()) {
					State destination = iter.nextState();
					boolean legalTransition = true;
					// check constraints to see if node at  can transition to destination
					if (ip+1 < constraints.length && constraints[ip+1] > 0 && ((constraints[ip+1]-1) != destination.getIndex())) {
						logger.fine ("Destination state does not match positive constraint. Assigning -infinite weight. position="+(ip+1)+", constraint="+(constraints[ip+1]-1)+", source ="+i+", destination="+destination.getIndex());
						legalTransition = false;
					}
					else if (((ip+1) < constraints.length) && constraints[ip+1] < 0 && (-(constraints[ip+1]+1) == destination.getIndex())) {
						logger.fine ("Destination state does not match negative constraint. Assigning -infinite weight. position="+(ip+1)+", constraint="+(constraints[ip+1]+1)+", destination="+destination.getIndex());
						legalTransition = false;
					}

					if (logger.isLoggable (Level.FINE))
						logger.fine ("Forward Lattice[inputPos="+ip
								+"][source="+s.getName()
								+"][dest="+destination.getName()+"]");
					LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex());
					destinationNode.output = iter.getOutput();
					double transitionWeight = iter.getWeight();
					if (legalTransition) {
						//if (logger.isLoggable (Level.FINE))
						logger.fine ("transitionWeight="+transitionWeight
								+" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha
								+" destinationNode.alpha="+destinationNode.alpha);
						destinationNode.alpha = Transducer.sumLogProb (destinationNode.alpha,
								nodes[ip][i].alpha + transitionWeight);
						//System.out.println ("destinationNode.alpha <- "+destinationNode.alpha);
						logger.fine ("Set alpha of latticeNode at ip = "+ (ip+1) + " stateIndex = " + destination.getIndex() + ", destinationNode.alpha = " + destinationNode.alpha);
					}
					else {
						// this is an illegal transition according to our
						// constraints, so set its prob to 0 . NO, alpha's are
						// unnormalized weights...set to Inf //
						// destinationNode.alpha = 0.0;
//						destinationNode.alpha = Transducer.IMPOSSIBLE_WEIGHT;
						logger.fine ("Illegal transition from state " + i + " to state " + destination.getIndex() + ". Setting alpha to Inf");
					}
				}
			}

		// Calculate total weight of Lattice.  This is the normalizer
		weight = Transducer.IMPOSSIBLE_WEIGHT;
		for (int i = 0; i < numStates; i++)
			if (nodes[latticeLength-1][i] != null) {
				// Note: actually we could sum at any ip index,
				// the choice of latticeLength-1 is arbitrary
				//System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha);
				//System.out.println ("Ending beta,  state["+i+"] = "+t.getState(i).finalWeight);
				if (constraints[latticeLength-1] > 0 && i != constraints[latticeLength-1]-1)
					continue;
				if (constraints[latticeLength-1] < 0 && -i == constraints[latticeLength-1]+1)
					continue;
				logger.fine ("Summing final lattice weight. state="+i+", alpha="+nodes[latticeLength-1][i].alpha + ", final weight = "+t.getState(i).getFinalWeight());
				weight = Transducer.sumLogProb (weight,
						(nodes[latticeLength-1][i].alpha + t.getState(i).getFinalWeight()));
			}
		// Weight is now an "unnormalized weight" of the entire Lattice
		//assert (weight >= 0) : "weight = "+weight;

		// If the sequence has -infinite weight, just return.
		// Usefully this avoids calling any incrementX methods.
		// It also relies on the fact that the gammas[][] and .alpha and .beta values
		// are already initialized to values that reflect -infinite weight
		// xxx Although perhaps not all (alphas,betas) exactly correctly reflecting?
		if (weight == Transducer.IMPOSSIBLE_WEIGHT)
			return;

		// Backward pass
		for (int i = 0; i < numStates; i++)
			if (nodes[latticeLength-1][i] != null) {
				State s = t.getState(i);
				nodes[latticeLength-1][i].beta = s.getFinalWeight();
				gammas[latticeLength-1][i] =
					nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - weight;
				if (incrementor != null) {
					double p = Math.exp(gammas[latticeLength-1][i]);
					assert (p >= 0 && p <= 1.0 && !Double.isNaN(p)) : "p="+p+" gamma="+gammas[latticeLength-1][i];
					incrementor.incrementFinalState(s, p);
				}
			}
		for (int ip = latticeLength-2; ip >= 0; ip--) {
			for (int i = 0; i < numStates; i++) {
				if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT)
					// Note that skipping here based on alpha means that beta values won't
					// be correct, but since alpha is infinite anyway, it shouldn't matter.
					continue;
				State s = t.getState(i);
				TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
				while (iter.hasNext()) {
					State destination = iter.nextState();
					if (logger.isLoggable (Level.FINE))
						logger.fine ("Backward Lattice[inputPos="+ip
								+"][source="+s.getName()
								+"][dest="+destination.getName()+"]");
					int j = destination.getIndex();
					LatticeNode destinationNode = nodes[ip+1][j];
					if (destinationNode != null) {
						double transitionWeight = iter.getWeight();
						assert (!Double.isNaN(transitionWeight));
						//							assert (transitionWeight >= 0);  Not necessarily
						double oldBeta = nodes[ip][i].beta;
						assert (!Double.isNaN(nodes[ip][i].beta));
						nodes[ip][i].beta = Transducer.sumLogProb (nodes[ip][i].beta,
								destinationNode.beta + transitionWeight);
						assert (!Double.isNaN(nodes[ip][i].beta))
						: "dest.beta="+destinationNode.beta+" trans="+transitionWeight+" sum="+(destinationNode.beta+transitionWeight)
						+ " oldBeta="+oldBeta;
						// xis[ip][i][j] = nodes[ip][i].alpha + transitionWeight + nodes[ip+1][j].beta - weight;
						assert (!Double.isNaN(nodes[ip][i].alpha));
						assert (!Double.isNaN(transitionWeight));
						assert (!Double.isNaN(nodes[ip+1][j].beta));
						assert (!Double.isNaN(weight));
						if (incrementor != null || outputAlphabet != null) {
							double xi = nodes[ip][i].alpha + transitionWeight + nodes[ip+1][j].beta - weight;
							double p = Math.exp(xi);
							assert (p >= 0 && p <= 1.0 && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+-xi;
							if (incrementor != null)
								incrementor.incrementTransition(iter, p);
							if (outputAlphabet != null) {
								int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false);
								assert (outputIndex >= 0);
								// xxx This assumes that "ip" == "op"!
								outputCounts[ip][outputIndex] += p;
								//System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p);
							}
						}
					}
				}
				gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - weight;
			}
		}
		if (incrementor != null)
			for (int i = 0; i < numStates; i++) {
				double p = Math.exp(gammas[0][i]);
				assert (p >= 0.0 && p <= 1.0 && !Double.isNaN(p));
				incrementor.incrementInitialState(t.getState(i), p);
			}
		if (outputAlphabet != null) {
			labelings = new LabelVector[latticeLength];
			for (int ip = latticeLength-2; ip >= 0; ip--) {
				assert (Math.abs(1.0-MatrixOps.sum (outputCounts[ip])) < 0.000001);;
				labelings[ip] = new LabelVector (outputAlphabet, outputCounts[ip]);
			}
		}
	}

	/**
	 * Returns the total weight of the lattice, i.e. the normalizer obtained by
	 * summing over the final-position alphas plus final-state weights.
	 * This is an unnormalized weight, not a probability.
	 *
	 * @return the lattice's total (unnormalized) weight
	 */
	public double getTotalWeight ()
	{
		final double total = weight;
		// A NaN normalizer would indicate a numerical failure in the lattice computation.
		assert (!Double.isNaN(total));
		return total;
	}

	// No, this.weight is an "unnormalized weight"
	//public double getProbability () { return Math.exp (weight); }

	/**
	 * Returns the stored gamma value for a state at an input position.
	 * Gammas are stored as alpha + beta - total weight, so this is the
	 * un-exponentiated form of what {@code getGammaProbability} returns.
	 *
	 * @param inputPosition index into the lattice along the input sequence
	 * @param s             the state whose gamma is requested
	 * @return the gamma weight stored for that position and state
	 */
	public double getGammaWeight (int inputPosition, State s)
	{
		final int stateIndex = s.getIndex();
		return gammas[inputPosition][stateIndex];
	}

	/**
	 * Returns the exponential of the stored gamma value for a state at an
	 * input position.
	 *
	 * @param inputPosition index into the lattice along the input sequence
	 * @param s             the state whose gamma is requested
	 * @return {@code Math.exp} of the gamma weight for that position and state
	 */
	public double getGammaProbability (int inputPosition, State s)
	{
		final int stateIndex = s.getIndex();
		final double gammaWeight = gammas[inputPosition][stateIndex];
		return Math.exp (gammaWeight);
	}
	
	/**
	 * Returns the internal xis array itself (not a copy). The xi accessors of
	 * this class treat a {@code null} array as "xis were not saved", so callers
	 * should be prepared for {@code null}.
	 *
	 * @return the xis array, possibly {@code null}
	 */
	public double[][][] getXis ()
	{
		final double[][][] savedXis = xis;
		return savedXis;
	}

	/**
	 * Returns the internal gammas array itself (not a copy), indexed as
	 * {@code [inputPosition][stateIndex]}.
	 *
	 * @return the gammas array
	 */
	public double[][] getGammas ()
	{
		final double[][] allGammas = gammas;
		return allGammas;
	}
	

	/**
	 * Returns the exponential of the stored xi value for the transition from
	 * {@code s1} to {@code s2} at input position {@code ip}.
	 *
	 * @param ip input position of the transition's source
	 * @param s1 source state
	 * @param s2 destination state
	 * @return {@code Math.exp} of the stored xi entry
	 * @throws IllegalStateException if the xis array is {@code null}
	 */
	public double getXiProbability (int ip, State s1, State s2)
	{
		if (xis == null) {
			throw new IllegalStateException ("xis were not saved.");
		}

		final int from = s1.getIndex ();
		final int to = s2.getIndex ();
		final double xiWeight = xis[ip][from][to];
		return Math.exp (xiWeight);
	}

	/**
	 * Returns the stored xi value for the transition from {@code s1} to
	 * {@code s2} at input position {@code ip}, without exponentiating it.
	 *
	 * @param ip input position of the transition's source
	 * @param s1 source state
	 * @param s2 destination state
	 * @return the stored xi entry
	 * @throws IllegalStateException if the xis array is {@code null}
	 */
	public double getXiWeight (int ip, State s1, State s2)
	{
		if (xis == null) {
			throw new IllegalStateException ("xis were not saved.");
		}

		final int from = s1.getIndex ();
		final int to = s2.getIndex ();
		return xis[ip][from][to];
	}

	/**
	 * Returns the length of this lattice.
	 *
	 * @return the value of {@code latticeLength}
	 */
	public int length ()
	{
		return latticeLength;
	}

	/**
	 * Returns the alpha value of the lattice node for state {@code s} at
	 * input position {@code ip}. The node is obtained through
	 * {@code getLatticeNode}.
	 *
	 * @param ip input position in the lattice
	 * @param s  the state whose alpha is requested
	 * @return the node's alpha
	 */
	public double getAlpha (int ip, State s)
	{
		return getLatticeNode (ip, s.getIndex ()).alpha;
	}

	/**
	 * Returns the beta value of the lattice node for state {@code s} at
	 * input position {@code ip}. The node is obtained through
	 * {@code getLatticeNode}.
	 *
	 * @param ip input position in the lattice
	 * @param s  the state whose beta is requested
	 * @return the node's beta
	 */
	public double getBeta (int ip, State s)
	{
		return getLatticeNode (ip, s.getIndex ()).beta;
	}

	/**
	 * Returns the label distribution stored for an output position.
	 *
	 * @param outputPosition index into the stored labelings
	 * @return the labeling at that position, or {@code null} when no
	 *         labelings array exists
	 */
	public LabelVector getLabelingAtPosition (int outputPosition)
	{
		return (labelings == null) ? null : labelings[outputPosition];
	}

	/**
	 * Returns the transducer held by this lattice.
	 *
	 * @return the transducer {@code t}
	 */
	public Transducer getTransducer () {
		final Transducer transducer = t;
		return transducer;
	}


	// A container for some information about a particular input position and state
	private class LatticeNode
	{
		int inputPosition;
		// outputPosition not really needed until we deal with asymmetric epsilon.
		State state;
		Object output;
		double alpha = Transducer.IMPOSSIBLE_WEIGHT;
		double beta = Transducer.IMPOSSIBLE_WEIGHT;
		LatticeNode (int inputPosition, State state)	{
			this.inputPosition = inputPosition;
			this.state = state;
			assert (this.alpha == Transducer.IMPOSSIBLE_WEIGHT);	// xxx Remove this check
		}
	}
	
	/**
	 * A {@code SumLatticeFactory} that creates {@code SumLatticeBeam} lattices
	 * using a beam width chosen when the factory is constructed.
	 */
	public static class Factory extends SumLatticeFactory
	{
		// Beam width applied before each lattice this factory creates.
		int bw;

		/**
		 * @param beamWidth the beam width to apply to lattices created by this factory
		 */
		public Factory (int beamWidth) {
			bw = beamWidth;
		}

		/**
		 * Creates a new {@code SumLatticeBeam} for the given arguments after
		 * applying this factory's beam width.
		 *
		 * Note that {@code SumLatticeBeam.beamWidth} is a static field, so the
		 * assignment below is visible to every {@code SumLatticeBeam}, not just
		 * the one returned here.
		 *
		 * @return the newly constructed lattice
		 */
		public SumLattice newSumLattice (Transducer trans, Sequence input, Sequence output, 
				Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet)
		{
			// Set the beam width BEFORE constructing.  The previous version assigned it
			// in an anonymous subclass's instance initializer, which only runs after the
			// SumLatticeBeam constructor has returned; any work the constructor does that
			// reads beamWidth would therefore still see the old value.
			SumLatticeBeam.beamWidth = bw;
			return new SumLatticeBeam (trans, input, output, incrementor, saveXis, outputAlphabet);
		}


	}

}	




© 2015 - 2025 Weber Informatics LLC | Privacy Policy