All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.fst.SumLatticeScaling Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.fst;

import java.io.IOException;

import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.logging.Logger;

import cc.mallet.fst.Transducer.State;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;

/**
 * Forward-backward ("sum") lattice over a {@link Transducer} that works in
 * probability space rather than log space. To keep forward (alpha) and
 * backward (beta) values within floating-point range, the values at every
 * input position are divided by their sum. The cumulative log of those
 * divisors is kept in {@link #alphaLogScaling} / {@link #betaLogScaling}, so
 * that the unscaled alpha at position ip is
 * {@code node.alpha * exp(alphaLogScaling[ip])} (likewise for beta).
 * Gammas and xis are stored as log probabilities with the scaling folded back
 * in.
 */
public class SumLatticeScaling implements SumLattice {
	private static final Logger logger = MalletLogger
			.getLogger(SumLatticeScaling.class.getName());
	/** Whether xis are saved by the constructors that take no saveXis argument. */
	protected static boolean saveXis = false;

	// "ip" == "input position", "op" == "output position", "i" == "state index"
	@SuppressWarnings("unchecked")
	Sequence input, output;
	Transducer t;
	// Log of the summed weight of all complete paths, or IMPOSSIBLE_WEIGHT.
	double totalWeight;
	LatticeNode[][] nodes; // indexed by ip,i; null where a state is unreachable
	// Cumulative log scaling factors for alphas / betas, indexed by ip.
	double[] alphaLogScaling, betaLogScaling;
	// Log scaling of the normalizer (alphaLogScaling at the last position).
	double zLogScaling;
	int latticeLength;
	double[][] gammas; // log state posteriors, indexed by ip,i
	double[][][] xis; // log transition posteriors, indexed by ip,i,j; saved only if saveXis is true
	LabelVector[] labelings; // indexed by op; computed only when an outputAlphabet is given

	// Ensure that instances cannot easily be created by a zero arg constructor.
	protected SumLatticeScaling() {
	}

	/**
	 * Returns the node for (ip, stateIndex), creating it (with NaN alpha and
	 * beta) if it does not exist yet.
	 */
	protected LatticeNode getLatticeNode(int ip, int stateIndex) {
		if (nodes[ip][stateIndex] == null)
			nodes[ip][stateIndex] = new LatticeNode(ip, t.getState(stateIndex));
		return nodes[ip][stateIndex];
	}

	/** Unconstrained lattice, no incrementor, xis saved per the static default. */
	@SuppressWarnings("unchecked")
	public SumLatticeScaling(Transducer trans, Sequence input) {
		this(trans, input, null, (Transducer.Incrementor) null, saveXis, null);
	}

	/** Unconstrained lattice, no incrementor. */
	@SuppressWarnings("unchecked")
	public SumLatticeScaling(Transducer trans, Sequence input, boolean saveXis) {
		this(trans, input, null, (Transducer.Incrementor) null, saveXis, null);
	}

	/** Unconstrained lattice that reports expected counts to the incrementor. */
	@SuppressWarnings("unchecked")
	public SumLatticeScaling(Transducer trans, Sequence input,
			Transducer.Incrementor incrementor) {
		this(trans, input, null, incrementor, saveXis, null);
	}

	/** Lattice constrained to the given output sequence (may be null). */
	@SuppressWarnings("unchecked")
	public SumLatticeScaling(Transducer trans, Sequence input, Sequence output) {
		this(trans, input, output, (Transducer.Incrementor) null, saveXis, null);
	}

	// You may pass null for output, meaning that the lattice
	// is not constrained to match the output
	@SuppressWarnings("unchecked")
	public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
			Transducer.Incrementor incrementor) {
		this(trans, input, output, incrementor, saveXis, null);
	}

	@SuppressWarnings("unchecked")
	public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
			Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet) {
		this(trans, input, output, incrementor, saveXis, outputAlphabet);
	}

	// You may pass null for output, meaning that the lattice
	// is not constrained to match the output
	@SuppressWarnings("unchecked")
	public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
			Transducer.Incrementor incrementor, boolean saveXis) {
		this(trans, input, output, incrementor, saveXis, null);
	}

	/**
	 * Builds the lattice and runs the scaled forward-backward computation.
	 *
	 * @param trans the transducer whose states and transitions form the lattice
	 * @param input the input sequence
	 * @param output output sequence to constrain the lattice to, or null for
	 *            an unconstrained lattice
	 * @param incrementor if non-null, receives the initial-state, transition
	 *            and final-state posterior probabilities
	 * @param saveXis whether to keep the log transition posteriors (xis)
	 * @param outputAlphabet if non-null, per-position posterior distributions
	 *            over transition outputs are computed and made available via
	 *            {@link #getLabelingAtPosition(int)}
	 */
	@SuppressWarnings("unchecked")
	public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
			Transducer.Incrementor incrementor, boolean saveXis,
			LabelAlphabet outputAlphabet) {
		assert (output == null || input.size() == output.size());

		this.t = trans;
		this.input = input;
		this.output = output;
		latticeLength = input.size() + 1;
		int numStates = t.numStates();
		nodes = new LatticeNode[latticeLength][numStates];
		alphaLogScaling = new double[latticeLength];
		betaLogScaling = new double[latticeLength];
		gammas = new double[latticeLength][numStates];
		if (saveXis)
			xis = new double[latticeLength][numStates][numStates];
		for (int ip = 0; ip < latticeLength; ip++) {
			for (int i = 0; i < numStates; i++) {
				gammas[ip][i] = Transducer.IMPOSSIBLE_WEIGHT;
				if (saveXis)
					for (int j = 0; j < numStates; j++)
						xis[ip][i][j] = Transducer.IMPOSSIBLE_WEIGHT;
			}
		}

		double Z = forwardPass();
		if (Double.isNaN(Z)) {
			// No complete path with positive weight: gammas/xis stay IMPOSSIBLE.
			totalWeight = Transducer.IMPOSSIBLE_WEIGHT;
			return;
		}
		totalWeight = Math.log(Z) + zLogScaling;

		double[][] outputCounts = null;
		if (outputAlphabet != null)
			outputCounts = new double[latticeLength - 1][outputAlphabet.size()];

		backwardPass(Z, incrementor, outputAlphabet, outputCounts);

		if (outputCounts != null) {
			labelings = new LabelVector[latticeLength - 1];
			for (int op = 0; op < latticeLength - 1; op++)
				labelings[op] = new LabelVector(outputAlphabet, outputCounts[op]);
		}
	}

	/**
	 * Runs the scaled forward recursion. It fills the node alphas,
	 * {@link #alphaLogScaling} and {@link #zLogScaling}.
	 *
	 * @return the scaled normalizer Z (the true normalizer is
	 *         {@code Z * exp(zLogScaling)}), or NaN if no complete path through
	 *         the lattice carries positive weight
	 */
	private double forwardPass() {
		int numStates = t.numStates();
		logger.fine("Starting Forward pass");
		boolean atLeastOneInitialState = false;
		for (int i = 0; i < numStates; i++) {
			double initialWeight = t.getState(i).getInitialWeight();
			if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) {
				getLatticeNode(0, i).alpha = Math.exp(initialWeight);
				atLeastOneInitialState = true;
			}
		}
		if (!atLeastOneInitialState) {
			logger.warning("There are no starting states!");
			return Double.NaN;
		}
		if (!rescaleAlphas(0))
			return Double.NaN;

		for (int ip = 0; ip < latticeLength - 1; ip++) {
			for (int i = 0; i < numStates; i++) {
				if (isInvalidNode(ip, i))
					continue;
				double sourceAlpha = nodes[ip][i].alpha;
				TransitionIterator iter = t.getState(i).transitionIterator(
						input, ip, output, ip);
				while (iter.hasNext()) {
					State destination = iter.next();
					LatticeNode destinationNode = getLatticeNode(ip + 1,
							destination.getIndex());
					if (Double.isNaN(destinationNode.alpha))
						destinationNode.alpha = 0;
					destinationNode.output = iter.getOutput();
					destinationNode.alpha += sourceAlpha
							* Math.exp(iter.getWeight());
				}
			}
			// No forward mass at ip+1 means the sequence cannot be generated;
			// report that instead of dividing by a zero sum.
			if (!rescaleAlphas(ip + 1))
				return Double.NaN;
		}

		int last = latticeLength - 1;
		zLogScaling = alphaLogScaling[last];
		double z = 0;
		for (int i = 0; i < numStates; i++)
			if (nodes[last][i] != null)
				z += nodes[last][i].alpha
						* Math.exp(t.getState(i).getFinalWeight());
		// z == 0 (e.g. every reachable final weight is impossible) would make
		// the backward pass divide by zero, so it is treated like "no path".
		return z > 0 ? z : Double.NaN;
	}

	/**
	 * Runs the scaled backward recursion. It fills betas,
	 * {@link #betaLogScaling}, gammas and (if allocated) xis. It also reports
	 * posteriors to the incrementor and accumulates per-position output
	 * counts.
	 *
	 * @param Z scaled normalizer returned by {@link #forwardPass()}; must be > 0
	 * @param incrementor receives posterior probabilities, or null
	 * @param outputAlphabet alphabet used to index outputCounts, or null
	 * @param outputCounts per-position accumulators for transition outputs,
	 *            or null when outputAlphabet is null
	 */
	private void backwardPass(double Z, Transducer.Incrementor incrementor,
			LabelAlphabet outputAlphabet, double[][] outputCounts) {
		int numStates = t.numStates();
		int last = latticeLength - 1;

		for (int i = 0; i < numStates; i++) {
			LatticeNode node = nodes[last][i];
			if (node == null)
				continue;
			State s = t.getState(i);
			node.beta = Math.exp(s.getFinalWeight());
			double gamma = node.alpha * node.beta / Z;
			gammas[last][i] = Math.log(gamma);
			if (incrementor != null) {
				double p = gamma;
				assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p
						+ ", gamma=" + gammas[last][i];
				incrementor.incrementFinalState(s, p);
			}
		}
		rescaleBetas(last);

		for (int ip = last - 1; ip >= 0; ip--) {
			// Combined scaling that turns scaled alpha*beta/Z into a posterior.
			double logScaling = alphaLogScaling[ip] + betaLogScaling[ip + 1]
					- zLogScaling;
			double pscaling = Math.exp(logScaling);
			for (int i = 0; i < numStates; i++) {
				if (isInvalidNode(ip, i))
					continue;
				LatticeNode node = nodes[ip][i];
				// Start every reachable node at zero so that a dead end (no
				// successor) gets beta 0 / gamma IMPOSSIBLE rather than NaN,
				// which would otherwise poison rescaleBetas for the position.
				node.beta = 0;
				TransitionIterator iter = t.getState(i).transitionIterator(
						input, ip, output, ip);
				while (iter.hasNext()) {
					int j = iter.next().getIndex();
					LatticeNode destinationNode = nodes[ip + 1][j];
					if (destinationNode == null)
						continue;
					double transitionProb = Math.exp(iter.getWeight());
					node.beta += destinationNode.beta * transitionProb;
					double xi = node.alpha * transitionProb
							* destinationNode.beta / Z;
					if (xis != null)
						xis[ip][i][j] = Math.log(xi) + logScaling;
					if (incrementor != null || outputCounts != null) {
						double p = xi * pscaling;
						assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p
								+ ", xis[" + ip + "][" + i + "][" + j
								+ "]=" + xi;
						if (incrementor != null)
							incrementor.incrementTransition(iter, p);
						if (outputCounts != null) {
							int outputIndex = outputAlphabet.lookupIndex(
									iter.getOutput(), false);
							assert (outputIndex >= 0);
							outputCounts[ip][outputIndex] += p;
						}
					}
				}
				gammas[ip][i] = Math.log(node.alpha * node.beta / Z)
						+ logScaling;
			}
			// re-scale betas so that they are normalized
			rescaleBetas(ip);
		}

		if (incrementor != null)
			for (int i = 0; i < numStates; i++) {
				double p = Math.exp(gammas[0][i]);
				assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p;
				incrementor.incrementInitialState(t.getState(i), p);
			}
	}

	/** A node is invalid if it was never created or never received an alpha. */
	private boolean isInvalidNode(int ip, int i) {
		return nodes[ip][i] == null || Double.isNaN(nodes[ip][i].alpha);
	}

	/**
	 * Divides the valid alphas at ip by their sum and records the cumulative
	 * log scaling factor.
	 *
	 * @return false (leaving the position untouched) if the alphas at ip do
	 *         not have a positive sum
	 */
	private boolean rescaleAlphas(int ip) {
		double sumAlpha = 0;
		for (int i = 0; i < t.numStates(); i++) {
			if (!isInvalidNode(ip, i))
				sumAlpha += nodes[ip][i].alpha;
		}
		if (!(sumAlpha > 0)) // also catches NaN
			return false;
		alphaLogScaling[ip] = Math.log(sumAlpha)
				+ (ip == 0 ? 0 : alphaLogScaling[ip - 1]);
		for (int i = 0; i < t.numStates(); i++) {
			if (!isInvalidNode(ip, i))
				nodes[ip][i].alpha /= sumAlpha;
		}
		return true;
	}

	/**
	 * Divides the valid betas at ip by their sum and records the cumulative
	 * log scaling factor (accumulated from the end of the lattice).
	 */
	private void rescaleBetas(int ip) {
		double sumBeta = 0;
		for (int i = 0; i < t.numStates(); i++) {
			if (!isInvalidNode(ip, i))
				sumBeta += nodes[ip][i].beta;
		}
		assert sumBeta > 0 : "Invalid sum over betas for ip=" + ip;
		betaLogScaling[ip] = Math.log(sumBeta)
				+ (ip == latticeLength - 1 ? 0 : betaLogScaling[ip + 1]);
		for (int i = 0; i < t.numStates(); i++) {
			if (!isInvalidNode(ip, i))
				nodes[ip][i].beta /= sumBeta;
		}
	}

	/** Log transition posteriors indexed by ip,i,j; null unless saveXis was set. */
	public double[][][] getXis() {
		return xis;
	}

	/** Log state posteriors indexed by ip,i. */
	public double[][] getGammas() {
		return gammas;
	}

	/** Log of the summed weight of all complete paths, or IMPOSSIBLE_WEIGHT. */
	public double getTotalWeight() {
		return totalWeight;
	}

	public double getGammaWeight(int inputPosition, State s) {
		return gammas[inputPosition][s.getIndex()];
	}

	public double getGammaWeight(int inputPosition, int stateIndex) {
		return gammas[inputPosition][stateIndex];
	}

	public double getGammaProbability(int inputPosition, State s) {
		return Math.exp(gammas[inputPosition][s.getIndex()]);
	}

	public double getGammaProbability(int inputPosition, int stateIndex) {
		return Math.exp(gammas[inputPosition][stateIndex]);
	}

	public double getXiProbability(int ip, State s1, State s2) {
		return Math.exp(getXiWeight(ip, s1, s2));
	}

	/**
	 * @throws IllegalStateException if the lattice was built without saveXis
	 */
	public double getXiWeight(int ip, State s1, State s2) {
		if (xis == null)
			throw new IllegalStateException("xis were not saved.");
		int i = s1.getIndex();
		int j = s2.getIndex();
		return xis[ip][i][j];
	}

	public int length() {
		return latticeLength;
	}

	/**
	 * Returns the unscaled forward value (probability space) for s at ip, or
	 * NaN if the state is not in the lattice at that position. Does not
	 * create lattice nodes.
	 */
	public double getAlpha(int ip, State s) {
		LatticeNode node = nodes[ip][s.getIndex()];
		if (node == null)
			return Double.NaN;
		return node.alpha * Math.exp(alphaLogScaling[ip]);
	}

	/**
	 * Returns the unscaled backward value (probability space) for s at ip, or
	 * NaN if the state is not in the lattice at that position. Does not
	 * create lattice nodes.
	 */
	public double getBeta(int ip, State s) {
		LatticeNode node = nodes[ip][s.getIndex()];
		if (node == null)
			return Double.NaN;
		return node.beta * Math.exp(betaLogScaling[ip]);
	}

	/**
	 * Returns the posterior distribution over transition outputs at the given
	 * output position, accumulated from the transition posteriors.
	 *
	 * @throws IllegalStateException if no outputAlphabet was supplied at
	 *             construction or the lattice has no path of positive weight
	 */
	public LabelVector getLabelingAtPosition(int outputPosition) {
		if (labelings == null)
			throw new IllegalStateException(
					"Labelings were not computed: construct SumLatticeScaling with a non-null outputAlphabet.");
		return labelings[outputPosition];
	}

	public Sequence getInput() {
		return input;
	}

	public Transducer getTransducer() {
		return t;
	}

	/** Per-(position, state) cell holding scaled alpha and beta values. */
	protected class LatticeNode {
		int inputPosition;
		State state;
		Object output;
		double alpha = Double.NaN;
		double beta = Double.NaN;

		LatticeNode(int inputPosition, State state) {
			this.inputPosition = inputPosition;
			this.state = state;
		}
	}

	/** Factory producing {@link SumLatticeScaling} lattices. */
	public static class Factory extends SumLatticeFactory implements
			Serializable {
		@SuppressWarnings("unchecked")
		public SumLattice newSumLattice(Transducer trans, Sequence input,
				Sequence output, Transducer.Incrementor incrementor,
				boolean saveXis, LabelAlphabet outputAlphabet) {
			return new SumLatticeScaling(trans, input, output, incrementor,
					saveXis, outputAlphabet);
		}

		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 1;

		private void writeObject(ObjectOutputStream out) throws IOException {
			out.writeInt(CURRENT_SERIAL_VERSION);
		}

		private void readObject(ObjectInputStream in) throws IOException,
				ClassNotFoundException {
			@SuppressWarnings("unused")
			int version = in.readInt();
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy