All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.fst.FeatureTransducer Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

There is a newer version: 2.0.12
Show newest version
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




/** 
   @author Andrew McCallum [email protected]
 */

package cc.mallet.fst;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Multinomial;
import cc.mallet.types.Sequence;

import cc.mallet.util.MalletLogger;

public class FeatureTransducer extends Transducer
{
	private static Logger logger = MalletLogger.getLogger(FeatureTransducer.class.getName());
	
	// These next two dictionaries may be the same
	Alphabet inputAlphabet;
	Alphabet outputAlphabet;
	ArrayList states = new ArrayList ();
	ArrayList initialStates = new ArrayList ();
	HashMap name2state = new HashMap ();
	Multinomial.Estimator initialStateCounts;
	Multinomial.Estimator finalStateCounts;
	boolean trainable = false;

	public FeatureTransducer (Alphabet inputAlphabet,
														Alphabet outputAlphabet)
	{
		this.inputAlphabet = inputAlphabet;
		this.outputAlphabet = outputAlphabet;
		// xxx When should these be frozen?
	}

	public FeatureTransducer (Alphabet dictionary)
	{
		this (dictionary, dictionary);
	}

	public FeatureTransducer ()
	{
		this (new Alphabet ());
	}

	public Alphabet getInputAlphabet () { return inputAlphabet; }
	public Alphabet getOutputAlphabet () { return outputAlphabet; }
	
	public void addState (String name, double initialWeight, double finalWeight,
												int[] inputs, int[] outputs, double[] weights,
												String[] destinationNames)
	{
		if (name2state.get(name) != null)
			throw new IllegalArgumentException ("State with name `"+name+"' already exists.");
		State s = new State (name, states.size(), initialWeight, finalWeight,
												 inputs, outputs, weights, destinationNames, this);
		states.add (s);
		if (initialWeight < IMPOSSIBLE_WEIGHT)
			initialStates.add (s);
		name2state.put (name, s);
		setTrainable (false);
	}

	public void addState (String name, double initialWeight, double finalWeight,
												Object[] inputs, Object[] outputs, double[] weights,
												String[] destinationNames)
	{
		this.addState (name, initialWeight, finalWeight,
									 inputAlphabet.lookupIndices (inputs, true),
									 outputAlphabet.lookupIndices (outputs, true),
									 weights, destinationNames);
	}

	public int numStates () { return states.size(); }

	public Transducer.State getState (int index) {
		return states.get(index); }
	
	public Iterator initialStateIterator () { return initialStates.iterator (); }

	public boolean isTrainable () { return trainable; }

	public void setTrainable (boolean f)
	{
		trainable = f;
		if (f) {
			// This wipes away any previous counts we had.
			// It also potentially allocates an esimator of a new size if
			// the number of states has increased.
			initialStateCounts = new Multinomial.LaplaceEstimator (states.size());
			finalStateCounts = new Multinomial.LaplaceEstimator (states.size());
		} else {
			initialStateCounts = null;
			finalStateCounts = null;
		}
		for (int i = 0; i < numStates(); i++)
			((State)getState(i)).setTrainable(f);
	}

	public void reset ()
	{
		if (trainable) {
			initialStateCounts.reset ();
			finalStateCounts.reset ();
			for (int i = 0; i < numStates(); i++)
				((State)getState(i)).reset ();
		}
	}

	public void estimate ()
	{
		if (initialStateCounts == null || finalStateCounts == null)
			throw new IllegalStateException ("This transducer not currently trainable.");
		Multinomial initialStateDistribution = initialStateCounts.estimate ();
		Multinomial finalStateDistribution = finalStateCounts.estimate ();
		for (int i = 0; i < states.size(); i++) {
			State s = states.get (i);
			s.initialWeight = initialStateDistribution.logProbability (i);
			s.finalWeight = finalStateDistribution.logProbability (i);
			s.estimate ();
		}
	}

  // Note that this is a non-static inner class, so we have access to all of
  // FeatureTransducer's instance variables.
	public class State extends Transducer.State
	{
		String name;
		int index;
		double initialWeight, finalWeight;
		Transition[] transitions;
		gnu.trove.TIntObjectHashMap input2transitions;
		Multinomial.Estimator transitionCounts;
		FeatureTransducer transducer;

		// Note that you cannot add transitions to a state once it is created.
		protected State (String name, int index, double initialWeight, double finalWeight,
										 int[] inputs, int[] outputs, double[] weights,
										 String[] destinationNames, FeatureTransducer transducer)
		{
			assert (inputs.length == outputs.length
							&& inputs.length == weights.length
							&& inputs.length == destinationNames.length);
			this.transducer = transducer;
			this.name = name;
			this.index = index;
			this.initialWeight = initialWeight;
			this.finalWeight = finalWeight;
			this.transitions = new Transition[inputs.length];
			this.input2transitions = new gnu.trove.TIntObjectHashMap ();
			transitionCounts = null;
			for (int i = 0; i < inputs.length; i++) {
				// This constructor places the transtion into this.input2transitions
				transitions[i] = new Transition (inputs[i], outputs[i],
																				 weights[i], this, destinationNames[i]);
				transitions[i].index = i;
			}
		}
		
		public Transducer getTransducer () { return transducer; } 
		public double getInitialWeight () { return initialWeight; }
		public double getFinalWeight () { return finalWeight; }
		public void setInitialWeight (double v) { initialWeight = v; }
		public void setFinalWeight (double v) { finalWeight = v; }

		private void setTrainable (boolean f)
		{
			if (f)
				transitionCounts = new Multinomial.LaplaceEstimator (transitions.length);
			else
				transitionCounts = null;
		}

		// Temporarily here for debugging
		public Multinomial.Estimator getTransitionEstimator()
		{
			return transitionCounts;
		}

		private void reset ()
		{
			if (transitionCounts != null)
				transitionCounts.reset();
		}

		public int getIndex () { return index; }

		public Transducer.TransitionIterator transitionIterator (Sequence input,
																														 int inputPosition,
																														 Sequence output,
																														 int outputPosition)
		{
			if (inputPosition < 0 || outputPosition < 0 || output != null)
				throw new UnsupportedOperationException ("Not yet implemented.");
			if (input == null)
				return transitionIterator ();
      return transitionIterator (input, inputPosition);
		}

		public Transducer.TransitionIterator transitionIterator (Sequence inputSequence,
																														 int inputPosition)
		{
			int inputIndex = inputAlphabet.lookupIndex (inputSequence.get(inputPosition), false);
			if (inputIndex == -1)
				throw new IllegalArgumentException ("Input not in dictionary.");
			return transitionIterator (inputIndex);
		}

		public Transducer.TransitionIterator transitionIterator (Object o)
		{
			int inputIndex = inputAlphabet.lookupIndex (o, false);
			if (inputIndex == -1)
				throw new IllegalArgumentException ("Input not in dictionary.");
			return transitionIterator (inputIndex);
		}
		
		public Transducer.TransitionIterator transitionIterator (int input)
		{
			return new TransitionIterator (this, input);
		}
		
		public Transducer.TransitionIterator transitionIterator ()
		{
			return new TransitionIterator (this);
		}

		public String getName ()
		{
			return name;
		}

		public void incrementInitialCount (double count)
		{
			if (initialStateCounts == null)
				throw new IllegalStateException ("Transducer is not currently trainable.");
			initialStateCounts.increment (index, count);
		}

		public void incrementFinalCount (double count)
		{
			if (finalStateCounts == null)
				throw new IllegalStateException ("Transducer is not currently trainable.");
			finalStateCounts.increment (index, count);
		}

		private void estimate ()
		{
			if (transitionCounts == null)
				throw new IllegalStateException ("Transducer is not currently trainable.");
			Multinomial transitionDistribution = transitionCounts.estimate ();
			for (int i = 0; i < transitions.length; i++)
				transitions[i].weight = transitionDistribution.logProbability (i);
		}

		private static final long serialVersionUID = 1;
	}

	@SuppressWarnings("serial")
  protected class TransitionIterator extends Transducer.TransitionIterator
	{
		// If "index" is >= -1 we are going through all FeatureState.transitions[] by index.
		// If "index" is -2, we are following the chain of FeatureTransition.nextWithSameInput,
		// and "transition" is already initialized to the first transition.
		// If "index" is -3, we are following the chain of FeatureTransition.nextWithSameInput,
		// and the next transition should be found by following the chain.
		int index;
		Transition transition;
		State source;
		int input;

		// Iterate through all transitions, independent of input
		public TransitionIterator (State source)
		{
			//System.out.println ("FeatureTransitionIterator over all");
			this.source = source;
			this.input = -1;
			this.index = -1;
			this.transition = null;
		}
		
		public TransitionIterator (State source, int input)
		{
			//System.out.println ("SymbolTransitionIterator over "+input);
			this.source = source;
			this.input = input;
			this.index = -2;
			this.transition = (Transition) source.input2transitions.get (input);
		}

		public boolean hasNext ()
		{
			if (index >= -1) {
				//System.out.println ("hasNext index " + index);
				return (index < source.transitions.length-1);
			}
      return (index == -2 ? transition != null : transition.nextWithSameInput != null);
		};

		public Transducer.State nextState ()
		{
			if (index >= -1)
				transition = source.transitions[++index];
			else if (index == -2)
				index = -3;
			else
				transition = transition.nextWithSameInput;
			return transition.getDestinationState();
		}

		public int getIndex () { return index; }
		public Object getInput () { return inputAlphabet.lookupObject(transition.input); }
		public Object getOutput () { return outputAlphabet.lookupObject(transition.output); }
		public double getWeight () { return transition.weight; }
		public Transducer.State getSourceState () { return source; }
		public Transducer.State getDestinationState () {
			return transition.getDestinationState ();	}

		public void incrementCount (double count) {
			logger.info ("FeatureTransducer incrementCount "+count);
			source.transitionCounts.increment (transition.index, count); }
		
	}

	// Note: this class has a natural ordering that is inconsistent with equals.
	protected class Transition
	{
		int input, output;
		double weight;
		int index;
		String destinationName;
		State destination = null;
		Transition nextWithSameInput;

		public Transition (int input, int output, double weight,
											 State sourceState, String destinationName)
		{
			this.input = input;
			this.output = output;
			this.weight = weight;
			this.nextWithSameInput = (Transition) sourceState.input2transitions.get (input);
			sourceState.input2transitions.put (input, this);
			// this.index is set by the caller of this constructor
			this.destinationName = destinationName;
		}

		public State getDestinationState ()
		{
			if (destination == null) {
				destination = name2state.get (destinationName);
				assert (destination != null);
			}
			return destination;
		}

	}

	private static final long serialVersionUID = 1;

	

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy