All downloads are free. Search and download functionality uses the official Maven repository.

edu.berkeley.nlp.PCFGLA.Grammar Maven / Gradle / Ivy

Go to download

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).

The newest version!
package edu.berkeley.nlp.PCFGLA;

import edu.berkeley.nlp.PCFGLA.smoothing.*;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.*;
import edu.berkeley.nlp.util.PriorityQueue;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.*;

/**
 * Simple implementation of a PCFG grammar, offering the ability to look up
 * rules by their child symbols. Rule probability estimates are relative
 * frequencies of the (expected) rule counts tallied from the training trees.
 */
public class Grammar implements java.io.Serializable {
	
	/**
	 * Strategies for perturbing rule counts before the first EM iteration.
	 *
	 * @author leon
	 */
	public static enum RandomInitializationType {
		/** Add a small amount of uniform random noise to every rule count. */
		INITIALIZE_WITH_SMALL_RANDOMIZATION,
		/**
		 * Multiply counts by a random factor, as in the Matsuzaki, Miyao, and Tsujii
		 * (2005) latent-annotation PCFG paper (see generateMMTRandomNumber).
		 */
		INITIALIZE_LIKE_MMT
	}
	
	/**
	 * Checked exception type signalling that a rule could not be found in the grammar.
	 * NOTE(review): no throw site is visible in this section of the class.
	 */
	public static class RuleNotFoundException extends Exception {
		private static final long serialVersionUID = 2L;

		/** Creates an exception with no detail message. */
		public RuleNotFoundException() {
			super();
		}
	}
	
	/** Number of split levels; set by writeData to floor(log2(numSubStates[1])). */
	public int finalLevel;

	/**
	 * isGrammarTag[s] is set to true by tallyParentCounts whenever state s is the
	 * parent of a tallied unary or binary rule.
	 */
	public boolean[] isGrammarTag;
	// NOTE(review): not read in this section; the entropic-prior branch in optimize() is commented out.
	public boolean useEntropicPrior = false;
	
	/** Binary rules indexed by parent state; populated by addBinary. */
	private List[] binaryRulesWithParent;
	/** Binary rules indexed by left-child state; populated by addBinary. */
	private List[] binaryRulesWithLC;
	/** Binary rules indexed by right-child state; populated by addBinary. */
	private List[] binaryRulesWithRC;
	// Array-based binary rule indices (by left child, right child, parent);
	// presumably built from the lists above elsewhere in the class — TODO confirm.
	private BinaryRule[][] splitRulesWithLC;
	private BinaryRule[][] splitRulesWithRC;
	private BinaryRule[][] splitRulesWithP;
	/** Unary rules indexed by parent state; populated (duplicate-free) by addUnary. */
	public List[] unaryRulesWithParent;
	/** Unary rules indexed by child state; populated by addUnary. */
	public List[] unaryRulesWithC;
  // Returned as-is by getSumProductClosedUnaryRulesByParent; not assigned in this section.
  private List[] sumProductClosedUnaryRulesWithParent;
	
	/** the number of states */
	public short numStates;
	
	/** the number of substates per state */
	public short[] numSubStates;
	
	//private List allRules;
	
	/** Maps each registered binary rule to itself (see addBinary), i.e. a canonical instance per rule. */
	public Map binaryRuleMap;
	// Reusable probe rule created in the constructor; presumably used for map lookups — TODO confirm.
	BinaryRule bSearchRule;
	/** Maps each registered unary rule to itself (see addUnary). */
	public Map unaryRuleMap;
	// Reusable probe rule created in the constructor; presumably used for map lookups — TODO confirm.
	UnaryRule uSearchRule;
	
	/**
	 * Expected unary rule counts accumulated by tallyStateSetTree; replaced by a
	 * fresh table once the counts have been turned into rule scores.
	 */
	UnaryCounterTable unaryRuleCounter = null;
	
	/**
	 * Expected binary rule counts accumulated by tallyStateSetTree; replaced by a
	 * fresh table once the counts have been turned into rule scores.
	 */
	BinaryCounterTable binaryRuleCounter = null;
	
	/** Total count per (parent state, parent substate); recomputed by tallyParentCounts and read by normalize. */
	CounterMap symbolCounter = new CounterMap();
	
	private static final long serialVersionUID = 1L;
	
	/** The global "tags" numberer, obtained in the constructor. */
	protected Numberer tagNumberer;
	
	// Closed unary rule lists (presumably under sum and Viterbi closure), indexed by
	// parent and by child state; allocated as empty lists in init().
	public List[] closedSumRulesWithParent = null;
	public List[] closedSumRulesWithChild = null;
	
	public List[] closedViterbiRulesWithParent = null;
	public List[] closedViterbiRulesWithChild = null;
	
	// Array forms of the closed unary rules, presumably indexed by parent (P) or
	// child (C) state; not assigned in this section.
	public UnaryRule[][] closedSumRulesWithP = null;
	public UnaryRule[][] closedSumRulesWithC = null;
	
	public UnaryRule[][] closedViterbiRulesWithP = null;
	public UnaryRule[][] closedViterbiRulesWithC = null;
	
	// Allocated as empty maps in init(); contents are managed outside this section.
	private Map bestSumRulesUnderMax = null;
	private Map bestViterbiRulesUnderMax = null;
	/** Threshold value passed to the constructor. */
	public double threshold;
	
	/** Smoother supplied to the constructor or via setSmoother (presumably applied by smooth()). */
	public Smoother smoother = null;
	
	/** A policy giving what state to go to next, starting from a
	 * given state, going to a given state.
	 * Allocated in init() as int[numStates][numStates], so it is indexed by
	 * the start state and the end state only (there are no substate dimensions).
	 */
	private int [][] closedViterbiPaths = null;
	// Same int[numStates][numStates] shape, presumably for the sum closure; allocated in init().
	private int [][] closedSumPaths=null;
	
	/** Flag passed to the constructor. NOTE(review): its visible use in init() is commented out. */
	public boolean findClosedPaths;
	
	/** If we are in logarithm mode, then this grammar's scores are all
	 * given as logarithms.  The default is to have a score plus a scale factor.
	 * The constructor initializes this to false.
	 */
	boolean logarithmMode;
	
	/**
	 * Per-state trees recording how each state's substates were split; copied
	 * from the previous grammar in the constructor, or built fresh when there is none.
	 */
	public Tree[] splitTrees;
		
	/**
	 * Zeroes every entry of the closed unary path tables: first the sum table,
	 * then the Viterbi table.
	 */
	public void clearUnaryIntermediates() {
		ArrayUtil.fill(this.closedSumPaths, 0);
		ArrayUtil.fill(this.closedViterbiPaths, 0);
	}
	
	
	/**
	 * Registers a binary rule in the by-parent, by-left-child and by-right-child
	 * indices, and maps it to itself in {@link #binaryRuleMap}.
	 * Unlike addUnary, no duplicate check is made on the index lists.
	 *
	 * @param br the rule to register
	 */
	public void addBinary(BinaryRule br) {
		List byParent = binaryRulesWithParent[br.parentState];
		byParent.add(br);
		List byLeftChild = binaryRulesWithLC[br.leftChildState];
		byLeftChild.add(br);
		List byRightChild = binaryRulesWithRC[br.rightChildState];
		byRightChild.add(br);
		binaryRuleMap.put(br, br);
	}
	
	/**
	 * Registers a unary rule in the by-parent and by-child indices and maps it
	 * to itself in {@link #unaryRuleMap}. A rule already present in its parent's
	 * list (by {@code equals}) is ignored.
	 *
	 * @param ur the rule to register
	 */
	public void addUnary(UnaryRule ur) {
		List sameParent = unaryRulesWithParent[ur.parentState];
		if (sameParent.contains(ur)) {
			return; // already registered: keep the indices duplicate-free
		}
		sameParent.add(ur);
		unaryRulesWithC[ur.childState].add(ur);
		unaryRuleMap.put(ur, ur);
	}

	/**
	 * @return the global "tags" numberer obtained when this grammar was constructed
	 */
	public Numberer getTagNumberer() {
		final Numberer tags = this.tagNumberer;
		return tags;
	}

//	@SuppressWarnings("unchecked")
//	public List getBinaryRulesByParent(int state) {
//		if (state >= binaryRulesWithParent.length) {
//			return Collections.EMPTY_LIST;
//		}
//		return binaryRulesWithParent[state];
//	}
//	
  /**
   * Returns the unary rules whose parent is {@code state}: the internal index
   * list itself (not a copy), or an immutable empty list when {@code state} is
   * beyond the indexed states.
   */
  @SuppressWarnings("unchecked")
  public List getUnaryRulesByParent(int state) {
    boolean indexed = state < unaryRulesWithParent.length;
    return indexed ? unaryRulesWithParent[state] : Collections.EMPTY_LIST;
  }
  
  /**
   * Exposes the backing array of sum-product closed unary rules (by parent
   * state) directly, not a copy. NOTE(review): the field is not assigned in
   * this section, so callers may receive null — confirm.
   */
  @SuppressWarnings("unchecked")
  public List[] getSumProductClosedUnaryRulesByParent() {
    final List[] byParent = this.sumProductClosedUnaryRulesWithParent;
    return byParent;
  }
  
	/**
	 * Returns the binary rules whose left child is {@code state}: the internal
	 * index list (not a copy), or an immutable empty list when {@code state}
	 * is beyond the indexed states.
	 */
	@SuppressWarnings("unchecked")
	public List getBinaryRulesByLeftChild(int state) {
		boolean indexed = state < binaryRulesWithLC.length;
		return indexed ? binaryRulesWithLC[state] : Collections.EMPTY_LIST;
	}
	
	/**
	 * Returns the binary rules whose right child is {@code state}: the internal
	 * index list (not a copy), or an immutable empty list when {@code state}
	 * is beyond the indexed states.
	 */
	@SuppressWarnings("unchecked")
	public List getBinaryRulesByRightChild(int state) {
		boolean indexed = state < binaryRulesWithRC.length;
		return indexed ? binaryRulesWithRC[state] : Collections.EMPTY_LIST;
	}
	
	/**
	 * Returns the unary rules whose child is {@code state}: the internal index
	 * list (not a copy), or an immutable empty list when {@code state} is
	 * beyond the indexed states.
	 */
	@SuppressWarnings("unchecked")
	public List getUnaryRulesByChild(int state) {
		boolean indexed = state < unaryRulesWithC.length;
		return indexed ? unaryRulesWithC[state] : Collections.EMPTY_LIST;
	}
	
	/**
	 * Legacy rendering that has been disabled; always returns {@code null}.
	 * Use {@link #toString()} for a sorted listing of the grammar's rules.
	 */
	public String toString_old() {
		return null;
	}

  /**
   * Writes every split binary rule (grouped by parent state), followed by every
   * closed Viterbi unary rule (grouped by parent state), to {@code w}. Each rule's
   * {@code toString()} is written verbatim with no separator, and {@code w} is then
   * flushed but not closed.
   *
   * <p>Side effect: {@link #finalLevel} is set to floor(log2(numSubStates[1])), or
   * 0 when that count is not positive.
   *
   * @param w destination writer (owned by the caller)
   * @throws IOException if writing to or flushing {@code w} fails
   */
  public void writeData(Writer w) throws IOException {
    // Exact integer floor(log2(n)); avoids the rounding error of Math.log(n) / Math.log(2).
    int nSubStates = numSubStates[1];
    finalLevel = nSubStates > 0 ? 31 - Integer.numberOfLeadingZeros(nSubStates) : 0;
    // Write straight to w (not through a PrintWriter, which swallows errors) so
    // that I/O failures surface as the declared IOException.
    for (int state = 0; state < numStates; state++) {
      for (BinaryRule rule : this.splitRulesWithP(state)) {
        w.write(rule.toString());
      }
    }
    for (int state = 0; state < numStates; state++) {
      for (UnaryRule rule : this.getClosedViterbiUnaryRulesByParent(state)) {
        w.write(rule.toString());
      }
    }
    w.flush();
  }

	
	/**
	 * Renders the grammar as the concatenation, in the order produced by
	 * CollectionUtils.sort, of the string forms of all split binary rules and all
	 * closed (sum) unary rules. No separator is added between rules; presumably
	 * each rule's own toString() ends with a line break — TODO confirm.
	 */
	@Override
	public String toString() {
		List<String> lines = new ArrayList<String>();
		for (int state = 0; state < numStates; state++) {
			for (BinaryRule binary : this.splitRulesWithP(state)) {
				lines.add(binary.toString());
			}
		}
		for (int state = 0; state < numStates; state++) {
			for (UnaryRule unary : this.getClosedSumUnaryRulesByParent(state)) {
				lines.add(unary.toString());
			}
		}
		StringBuilder rendered = new StringBuilder();
		for (String line : CollectionUtils.sort(lines)) {
			rendered.append(line);
		}
		return rendered.toString();
	}

	public int getNumberOfRules() {
		int nRules = 0;
		for (int state = 0; state < numStates; state++) {
			BinaryRule[] parentRules = this.splitRulesWithP(state);
			for (int i = 0; i < parentRules.length; i++) {
				BinaryRule bRule = parentRules[i];
				double[][][] scores = bRule.getScores2();
				for (int j=0; j unaries = this.getUnaryRulesByParent(state);
//			for (UnaryRule uRule : unaries){
				if (uRule.childState==uRule.parentState) continue;
				double[][] scores = uRule.getScores2();
				for (int j=0; j unaries = this.getUnaryRulesByParent(state1);
			for (UnaryRule uRule : unaries){
				UnaryRule uRule2 = (UnaryRule)unaryRuleMap.get(uRule);
				if (!uRule.getScores2().equals(uRule2.getScores2()))
					System.out.print("BY PARENT:\n" +uRule + "" + uRule2+ "\n");
			}
		}
		//System.out.println("VITERBI CLOSED");
		for (int state1 = 0; state1 < numStates; state1++) {
			UnaryRule[] unaries = this.getClosedViterbiUnaryRulesByParent(state1);
			for (int r = 0; r < unaries.length; r++) {
				UnaryRule uRule = unaries[r];
				//System.out.print(uRule);
				UnaryRule uRule2 = (UnaryRule)unaryRuleMap.get(uRule);
				if (unariesAreNotEqual(uRule,uRule2))
					System.out.print("VITERBI CLOSED:\n" + uRule + "" + uRule2+ "\n");
			}
		}
		
		/*System.out.println("FROM RULE MAP");
		for (UnaryRule uRule : unaryRuleMap.keySet()){
			System.out.print(uRule);
		}*/
		
		//System.out.println("AND NOW THE BINARIES");
		//System.out.println("BY PARENT");
		for (int state1 = 0; state1 < numStates; state1++) {
			BinaryRule[] parentRules = this.splitRulesWithP(state1);
			for (int i = 0; i < parentRules.length; i++) {
				BinaryRule bRule = parentRules[i];
				BinaryRule bRule2 = (BinaryRule)binaryRuleMap.get(bRule);
				if (!bRule.getScores2().equals(bRule2.getScores2()))
					System.out.print("BINARY: "+bRule + "" + bRule2 + "\n");
			}
		}
		/*
		System.out.println("FROM RULE MAP");
		for (BinaryRule bRule : binaryRuleMap.keySet()){
			System.out.print(bRule);
		}*/
		
		
	}
	
	public boolean unariesAreNotEqual(UnaryRule u1, UnaryRule u2){
		// two cases:
		// 1. u2 is null and u1 is a selfRule
		if (u2==null){
			return false;
			/*double[][] s1 = u1.getScores2();
			for (int i=0; i();
		unaryRuleMap = new HashMap();
		//allRules = new ArrayList();
		bestSumRulesUnderMax = new HashMap();
		bestViterbiRulesUnderMax = new HashMap();
		binaryRulesWithParent = new List[numStates];
		binaryRulesWithLC = new List[numStates];
		binaryRulesWithRC = new List[numStates];
		unaryRulesWithParent = new List[numStates];
		unaryRulesWithC = new List[numStates];
		closedSumRulesWithParent = new List[numStates];
		closedSumRulesWithChild = new List[numStates];
		closedViterbiRulesWithParent = new List[numStates];
		closedViterbiRulesWithChild = new List[numStates];
		isGrammarTag = new boolean[numStates];
		
		//if (findClosedPaths) {
			closedViterbiPaths = new int[numStates][numStates];
		//}
		closedSumPaths = new int[numStates][numStates];
		
		for (short s = 0; s < numStates; s++) {
			binaryRulesWithParent[s] = new ArrayList();
			binaryRulesWithLC[s] = new ArrayList();
			binaryRulesWithRC[s] = new ArrayList();
			unaryRulesWithParent[s] = new ArrayList();
			unaryRulesWithC[s] = new ArrayList();
			closedSumRulesWithParent[s] = new ArrayList();
			closedSumRulesWithChild[s] = new ArrayList();
			closedViterbiRulesWithParent[s] = new ArrayList();
			closedViterbiRulesWithChild[s] = new ArrayList();
			
			double[][] scores = new double[numSubStates[s]][numSubStates[s]];
			for (int i=0; i>
	 * trainTrees, Grammar old_grammar) { this.tagNumberer =
	 * Numberer.getGlobalNumberer("tags"); unaryRuleCounter = new Counter();
	 * binaryRuleCounter = new Counter(); symbolCounter = new
	 * CounterMap(); numStates = tagNumberer.total();
	 * numSubStates = old_grammar.numSubStates; init();
	 * 
	 * for (Tree trainTree : trainTrees) { tallyStateSetTree(trainTree,
	 * old_grammar); } for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
	 * double unaryProbability = unaryRuleCounter.getCount(unaryRule) /
	 * symbolCounter.getCount(unaryRule.getParentState(),unaryRule.getParentSubState());
	 * unaryRule.setScore(Math.log(unaryProbability)); addUnary(unaryRule); } for
	 * (BinaryRule binaryRule : binaryRuleCounter.keySet()) { double
	 * binaryProbability = binaryRuleCounter.getCount(binaryRule) /
	 * symbolCounter.getCount(binaryRule.getParentState(),binaryRule.getParentSubState());
	 * binaryRule.setScore(Math.log(binaryProbability)); addBinary(binaryRule); } }
	 */
	
	/**
	 * This constructor generates a grammar with the rule probabilities read as
	 * though there were no substates, but with a bit of randomness added. This is
	 * the way we should initialize the EM algorithm.
	 * 
	 * @param trainTrees
	 *          The training trees, which don't need to have their inside-outside
	 *          probabilities calculated correctly.
	 * @param randomness
	 *          The size of the region to be uniformly sampled from in adding
	 *          extra random weight to the rules.
	 */
	/*
	 * comment out unused constructor public Grammar(List>
	 * trainTrees, int[] nSubStates, int maxN, double randomness) {
	 * this.tagNumberer = Numberer.getGlobalNumberer("tags"); unaryRuleCounter =
	 * new Counter(); binaryRuleCounter = new Counter();
	 * symbolCounter = new CounterMap(); numStates =
	 * tagNumberer.total(); numSubStates = nSubStates; maxNumSubStates = maxN;
	 * init();
	 * 
	 * //tally trees as though there were no subsymbols for (Tree
	 * trainTree : trainTrees) { tallyUninitializedStateSetTree(trainTree); }
	 * //add randomness Random random = new Random(); for (UnaryRule unaryRule :
	 * unaryRuleCounter.keySet()) { double r = random.nextDouble()*randomness;
	 * unaryRuleCounter.incrementCount(unaryRule,r); } for (BinaryRule binaryRule :
	 * binaryRuleCounter.keySet()) { double r = random.nextDouble()*randomness;
	 * binaryRuleCounter.incrementCount(binaryRule,r); } //re-tally the parent
	 * counts because adding the randomness ruined them symbolCounter = new
	 * CounterMap(); for (UnaryRule unaryRule :
	 * unaryRuleCounter.keySet()) { symbolCounter.incrementCount(
	 * unaryRule.getParentState(), unaryRule.getParentSubState(),
	 * unaryRuleCounter.getCount(unaryRule)); } for (BinaryRule binaryRule :
	 * binaryRuleCounter.keySet()) {
	 * symbolCounter.incrementCount(binaryRule.getParentState(),binaryRule.getParentSubState(),
	 * binaryRuleCounter.getCount(binaryRule)); } //set the scores of all the
	 * rules based on these counts for (UnaryRule unaryRule :
	 * unaryRuleCounter.keySet()) { double unaryProbability =
	 * unaryRuleCounter.getCount(unaryRule) /
	 * symbolCounter.getCount(unaryRule.getParentState(),
	 * unaryRule.getParentSubState());
	 * unaryRule.setScore(Math.log(unaryProbability)); addUnary(unaryRule); } for
	 * (BinaryRule binaryRule : binaryRuleCounter.keySet()) { double
	 * binaryProbability = binaryRuleCounter.getCount(binaryRule) /
	 * symbolCounter.getCount(binaryRule.getParentState(),binaryRule.getParentSubState());
	 * binaryRule.setScore(Math.log(binaryProbability)); addBinary(binaryRule); } }
	 */
	
	/**
	 * Rather than calling some all-in-one constructor that takes a list of trees
	 * as training data, you call Grammar() to create an empty grammar, call
	 * tallyTree() repeatedly to include all the training data, then call
	 * optimize() to take it into account.
	 * 
	 * @param oldGrammar
	 *          This is the previous grammar. We use this to copy the split trees
	 *          that record how each state is split recursively. These parameters
	 *          are intialized if oldGrammar is null.
	 */
	@SuppressWarnings("unchecked")
	public Grammar(short[] nSubStates, boolean findClosedPaths,
			 Smoother smoother, Grammar oldGrammar, double thresh) {
		this.tagNumberer = Numberer.getGlobalNumberer("tags");
		this.findClosedPaths = findClosedPaths;
		this.smoother = smoother;
		this.threshold = thresh;
		unaryRuleCounter = new UnaryCounterTable(nSubStates);
		binaryRuleCounter = new BinaryCounterTable(nSubStates);
		symbolCounter = new CounterMap();
		numStates = (short)nSubStates.length;
		numSubStates = nSubStates;
		bSearchRule = new BinaryRule((short)0,(short)0,(short)0);
		uSearchRule = new UnaryRule((short)0,(short)0);
		logarithmMode = false;
		if (oldGrammar!=null) {
			splitTrees = oldGrammar.splitTrees;
		} else {
			splitTrees = new Tree[numStates];
			boolean hasAnySplits = false;
			for (int tag=0; !hasAnySplits && tag1;
			}
			for (int tag=0; tag> children = new ArrayList>(numSubStates[tag]);
				if (hasAnySplits) {
					for (short substate=0; substate(substate));
					}
				}
				splitTrees[tag] = new Tree((short)0,children);
			}
		}
		init();
	}

	/**
	 * Replaces the smoother held by this grammar.
	 *
	 * @param smoother the new smoother (may be null, as in the field's initial value)
	 */
	public void setSmoother(Smoother smoother) {
		this.smoother = smoother;
	}
	
	/**
	 * Draws a multiplicative noise factor exp(u * ln 3), where u = 2 * r.nextDouble() - 1
	 * lies in [-1, 1). The result therefore lies in [1/3, 3) and its logarithm is
	 * uniformly distributed. Consumes exactly one nextDouble() from {@code r}.
	 *
	 * @param r source of randomness
	 * @return the random factor
	 */
	public static double generateMMTRandomNumber(Random r) {
		double exponent = (r.nextDouble() * 2 - 1) * Math.log(3);
		return Math.exp(exponent);
	}
	
	/**
	 * Turns the accumulated rule counts into rule scores.
	 *
	 * <p>The steps are:
	 * <ol>
	 * <li>Re-initialize the rule indices with {@code init()}.</li>
	 * <li>If {@code randomness > 0}, add independent noise
	 * {@code GrammarTrainer.RANDOM.nextDouble() * randomness} to every unary
	 * (child, parent) cell and every binary (left, right, parent) cell.
	 * Missing (null) rows are allocated first.</li>
	 * <li>Call {@code normalize()} and then {@code smooth(false)}.</li>
	 * </ol>
	 *
	 * <p>Only the small-uniform-noise scheme is applied here;
	 * {@link RandomInitializationType#INITIALIZE_LIKE_MMT} is not wired in.
	 *
	 * @param randomness width of the uniform noise interval; a value {@code <= 0} disables noise
	 */
	public void optimize(double randomness) {
		init();
		if (randomness > 0.0) {
			Random random = GrammarTrainer.RANDOM;
			// Perturb every (child substate, parent substate) cell of every unary rule.
			for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
				double[][] counts = unaryRuleCounter.getCount(unaryRule);
				for (int child = 0; child < counts.length; child++) {
					if (counts[child] == null) {
						counts[child] = new double[numSubStates[unaryRule.getParentState()]];
					}
					double[] row = counts[child];
					for (int parent = 0; parent < row.length; parent++) {
						row[parent] += random.nextDouble() * randomness;
					}
				}
				unaryRuleCounter.setCount(unaryRule, counts);
			}
			// Perturb every (left, right, parent substate) cell of every binary rule.
			for (BinaryRule binaryRule : binaryRuleCounter.keySet()) {
				double[][][] counts = binaryRuleCounter.getCount(binaryRule);
				for (double[][] byLeft : counts) {
					for (int right = 0; right < byLeft.length; right++) {
						if (byLeft[right] == null) {
							byLeft[right] = new double[numSubStates[binaryRule.getParentState()]];
						}
						double[] cell = byLeft[right];
						for (int parent = 0; parent < cell.length; parent++) {
							cell[parent] += random.nextDouble() * randomness;
						}
					}
				}
				binaryRuleCounter.setCount(binaryRule, counts);
			}
		}
		normalize();
		// smooth(false) also adds the resulting rules to the lookup arrays.
		smooth(false);
	}
	
	public void removeUnlikelyRules(double thresh, double power){
		//System.out.print("Removing everything below "+thresh+" and rasiing rules to the " +power+"th power... ");
		if (isLogarithmMode()) power = Math.log(power);
		int total=0, removed = 0;
		for (int state = 0; state < numStates; state++) {
			for (int r=0; r0){
//			removeUnlikelyRules(threshold);
//			normalize();
//		}
		
		// compress and add the rules
		for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
			double[][] unaryCounts = unaryRuleCounter.getCount(unaryRule);
			for (int i = 0; i < unaryCounts.length; i++) {
				if (unaryCounts[i]==null)
					continue;
				/** allZero records if all probabilities are 0.  If so,
				 * we want to null out the matrix element.
				 */
				double allZero = 0;
				int j=0;
				while (allZero == 0 && j < unaryCounts[i].length){
					allZero += unaryCounts[i][j++];
				}
				if (allZero==0) {
					unaryCounts[i] = null;
				}
			}
			unaryRule.setScores2(unaryCounts);
			addUnary(unaryRule);
		}
		computePairsOfUnaries();
		for (BinaryRule binaryRule : binaryRuleCounter.keySet()) {
			double[][][] binaryCounts = binaryRuleCounter.getCount(binaryRule);
			for (int i = 0; i < binaryCounts.length; i++) {
				for (int j = 0; j < binaryCounts[i].length; j++) {
					if (binaryCounts[i][j]==null)
						continue;
					/** allZero records if all probabilities are 0.  If so,
					 * we want to null out the matrix element.
					 */
					double allZero = 0;
					int k=0;
					while (allZero == 0 && k < binaryCounts[i][j].length){
						allZero += binaryCounts[i][j][k++];
					}
					if (allZero==0) {
						binaryCounts[i][j] = null;
					}
				}
			}
			binaryRule.setScores2(binaryCounts);
			addBinary(binaryRule);
		}
		// Reset all counters:
		unaryRuleCounter = new UnaryCounterTable(numSubStates);
		binaryRuleCounter = new BinaryCounterTable(numSubStates);
		symbolCounter = new CounterMap();
		/*
		// tally usage of closed unary rule paths
		if (findClosedPaths) {
			int maxSize = numStates * numStates;
			int size = 0;
			for (int i=0; i();
		
	}
	
	/**
	 * Normalize the unary & binary probabilities so that they sum to 1 for each parent.
	 * The binaryRuleCounter and unaryRuleCounter are assumed to contain probabilities,
	 * NOT log probabilities!
	 */
	public void normalize() {
		// tally the parent counts
		tallyParentCounts();
		// turn the rule scores into fractions
		for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
			double[][] unaryCounts = unaryRuleCounter.getCount(unaryRule);
			int parentState = unaryRule.getParentState();
			int nParentSubStates = numSubStates[parentState];
			int nChildStates = numSubStates[unaryRule.childState];
			double[] parentCount = new double[nParentSubStates];
			for (int i = 0; i < nParentSubStates; i++) {
				parentCount[i] = symbolCounter.getCount(parentState, i);
			}
			boolean allZero = true;
			for (int j = 0; j < nChildStates; j++) {
				if (unaryCounts[j]==null) continue;
				for (int i = 0; i < nParentSubStates; i++) {
					if (parentCount[i]!=0){
						double nVal = (unaryCounts[j][i] / parentCount[i]);
						if (nVal
	 * This assumes that the unaryRuleCounter and binaryRuleCounter contain probabilities,
	 * NOT log probabilities! 
	 */
	private void tallyParentCounts() {
		// Rebuild symbolCounter from scratch. For every (parent state, parent
		// substate), add up the counts of all unary and binary rules headed by it.
		// Every tallied parent is flagged in isGrammarTag. Self-loop unaries
		// (X -> X) set that flag but contribute no mass.
		symbolCounter = new CounterMap();
		for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
			double[][] unaryCounts = unaryRuleCounter.getCount(unaryRule);
			int parentState = unaryRule.getParentState();
			isGrammarTag[parentState] = true;
			if (unaryRule.childState == parentState) {
				continue;
			}
			double[] mass = new double[numSubStates[parentState]];
			for (double[] childRow : unaryCounts) {
				if (childRow == null) {
					continue;
				}
				for (int p = 0; p < mass.length; p++) {
					mass[p] += childRow[p];
				}
			}
			for (int p = 0; p < mass.length; p++) {
				symbolCounter.incrementCount(parentState, p, mass[p]);
			}
		}
		for (BinaryRule binaryRule : binaryRuleCounter.keySet()) {
			double[][][] binaryCounts = binaryRuleCounter.getCount(binaryRule);
			int parentState = binaryRule.parentState;
			isGrammarTag[parentState] = true;
			double[] mass = new double[numSubStates[parentState]];
			for (double[][] leftRow : binaryCounts) {
				for (double[] cell : leftRow) {
					if (cell == null) {
						continue;
					}
					for (int p = 0; p < mass.length; p++) {
						mass[p] += cell[p];
					}
				}
			}
			for (int p = 0; p < mass.length; p++) {
				symbolCounter.incrementCount(parentState, p, mass[p]);
			}
		}
	}
	
	/**
	 * Entry point for tallying expected rule counts from one training tree. The
	 * tree's inside/outside scores must already have been computed under
	 * {@code old_grammar}.
	 *
	 * <p>Leaves and preterminals are ignored. The root label must have exactly one
	 * substate; otherwise an error is printed and the JVM exits with status 1.
	 * A tree whose root inside score is 0 is skipped with a message.
	 *
	 * @param tree        training tree with scored StateSet labels
	 * @param old_grammar grammar whose rule scores produced the tree's scores
	 */
	public void tallyStateSetTree(Tree tree, Grammar old_grammar) {
		if (tree.isLeaf() || tree.isPreTerminal()) {
			return;
		}
		StateSet root = (StateSet) tree.getLabel();
		// The root must be unsplit: the inside score of its single substate is
		// the (scaled) probability of the whole tree.
		if (root.numSubStates() != 1) {
			System.err.println("The top symbol is split!");
			System.out.println(tree);
			System.exit(1);
		}
		double treeScore = root.getIScore(0);
		int treeScale = root.getIScale();
		if (treeScore == 0) {
			System.out.println("Something is wrong with this tree. I will skip it.");
			return;
		}
		tallyStateSetTree(tree, treeScore, treeScale, old_grammar);
	}
	
	/**
	 * Recursively adds the expected (posterior) counts of the rules used at every
	 * internal node of {@code tree} to {@link #unaryRuleCounter} and
	 * {@link #binaryRuleCounter}.
	 *
	 * <p>For each substate combination of a rule applied at a node, the added count is
	 * {@code ruleScore * insideScore(children) * outsideScore(parent) / tree_score},
	 * multiplied by {@code ScalingTools.calcScaleFactor(parentOScale + childIScales - tree_scale)}.
	 * Leaves and preterminals contribute nothing. If {@code old_grammar} has no
	 * score array for a binary rule, every substate combination is scored as 1.0.
	 *
	 * @param tree        node whose StateSet labels carry inside/outside scores and scales
	 * @param tree_score  inside score of the unsplit root (a value of 0 is treated as 1)
	 * @param tree_scale  scale exponent of {@code tree_score}
	 * @param old_grammar grammar supplying the rule scores used to score {@code tree}
	 * @throws Error if a node has more than two children
	 */
	public void tallyStateSetTree(Tree tree, double tree_score, double tree_scale,
			Grammar old_grammar) {
		if (tree.isLeaf())
			return;
		if (tree.isPreTerminal())
			return;
		// Treat a zero tree score as 1 so the divisions below are defined; the
		// substituted value is also what the recursive calls receive.
		if (tree_score == 0)
			tree_score = 1;
		List<Tree<StateSet>> children = tree.getChildren();
		StateSet parent = (StateSet) tree.getLabel();
		short parentState = parent.getState();
		int nParentSubStates = numSubStates[parentState];
		switch (children.size()) {
		case 0:
			// A leaf (a preterminal, if we count the words themselves): nothing to do.
			break;
		case 1: {
			StateSet child = children.get(0).getLabel();
			short childState = child.getState();
			int nChildSubStates = numSubStates[childState];
			UnaryRule urule = new UnaryRule(parentState, childState);
			// NOTE(review): unlike the binary case, a null score array (rule absent from
			// old_grammar) is not handled here — confirm getUnaryScore never returns null.
			double[][] oldUScores = old_grammar.getUnaryScore(urule); // rule score
			double[][] ucounts = unaryRuleCounter.getCount(urule);
			if (ucounts == null) ucounts = new double[nChildSubStates][];
			double scalingFactor = ScalingTools.calcScaleFactor(
					parent.getOScale() + child.getIScale() - tree_scale);
			for (short i = 0; i < nChildSubStates; i++) {
				if (oldUScores[i] == null) continue;
				double cIS = child.getIScore(i);
				if (cIS == 0) continue;
				if (ucounts[i] == null) ucounts[i] = new double[nParentSubStates];
				for (short j = 0; j < nParentSubStates; j++) {
					double pOS = parent.getOScore(j); // parent outside score
					if (pOS == 0) continue;
					double rS = oldUScores[i][j];
					if (rS == 0) continue;
					double expectedCount = (rS * cIS / tree_score) * scalingFactor * pOS;
					ucounts[i][j] += expectedCount;
				}
			}
			unaryRuleCounter.setCount(urule, ucounts);
			break;
		}
		case 2: {
			StateSet leftChild = children.get(0).getLabel();
			short lChildState = leftChild.getState();
			StateSet rightChild = children.get(1).getLabel();
			short rChildState = rightChild.getState();
			int nLeftChildSubStates = numSubStates[lChildState];
			int nRightChildSubStates = numSubStates[rChildState];
			BinaryRule brule = new BinaryRule(parentState, lChildState, rChildState);
			double[][][] oldBScores = old_grammar.getBinaryScore(brule); // rule score
			if (oldBScores == null) {
				// Rule was not in old_grammar: score every substate combination as 1.0.
				oldBScores = new double[nLeftChildSubStates][nRightChildSubStates][nParentSubStates];
				ArrayUtil.fill(oldBScores, 1.0);
			}
			double[][][] bcounts = binaryRuleCounter.getCount(brule);
			if (bcounts == null) bcounts = new double[nLeftChildSubStates][nRightChildSubStates][];
			double scalingFactor = ScalingTools.calcScaleFactor(parent.getOScale()
					+ leftChild.getIScale() + rightChild.getIScale() - tree_scale);
			for (short i = 0; i < nLeftChildSubStates; i++) {
				double lcIS = leftChild.getIScore(i);
				if (lcIS == 0) continue;
				for (short j = 0; j < nRightChildSubStates; j++) {
					if (oldBScores[i][j] == null) continue;
					double rcIS = rightChild.getIScore(j);
					if (rcIS == 0) continue;
					// allocate parent array lazily
					if (bcounts[i][j] == null) bcounts[i][j] = new double[nParentSubStates];
					for (short k = 0; k < nParentSubStates; k++) {
						double pOS = parent.getOScore(k); // parent outside score
						if (pOS == 0) continue;
						double rS = oldBScores[i][j][k];
						if (rS == 0) continue;
						double expectedCount = (rS * lcIS / tree_score) * rcIS * scalingFactor * pOS;
						bcounts[i][j][k] += expectedCount;
					}
				}
			}
			binaryRuleCounter.setCount(brule, bcounts);
			break;
		}
		default:
			throw new Error("Malformed tree: more than two children");
		}
		
		for (Tree child : children) {
			tallyStateSetTree(child, tree_score, tree_scale, old_grammar);
		}
	}
	
	/**
	 * Recursively walks {@code tree} and, for every unary or binary production
	 * above the preterminal level, creates a rule carrying a zero-filled score
	 * array and adds 1.0 to that rule's count in {@code unaryRuleCounter} /
	 * {@code binaryRuleCounter}. The array dimensions come from the substate
	 * counts stored on the node labels.
	 * Leaves and preterminal nodes are skipped; the lexicon handles preterminals.
	 *
	 * @param tree the (sub)tree whose productions should be tallied
	 * @throws Error if a node has more than two children
	 */
	public void tallyUninitializedStateSetTree(Tree<StateSet> tree) {
		// Leaves carry no rule; the lexicon handles preterminal nodes.
		if (tree.isLeaf() || tree.isPreTerminal())
			return;
		List<Tree<StateSet>> children = tree.getChildren();
		StateSet parent = tree.getLabel();
		short parentState = parent.getState();
		// Substate counts are read from the tree labels, not from this
		// grammar's numSubStates array.
		int nParentSubStates = parent.numSubStates();
		switch (children.size()) {
		case 0:
			// This is a leaf (a preterminal node, if we count the words
			// themselves), nothing to do
			break;
		case 1:
			StateSet child = children.get(0).getLabel();
			short childState = child.getState();
			int nChildSubStates = child.numSubStates();
			// Unary scores are laid out [childSubState][parentSubState].
			double[][] counts = new double[nChildSubStates][nParentSubStates];
			UnaryRule urule = new UnaryRule(parentState, childState, counts);
			unaryRuleCounter.incrementCount(urule, 1.0);
			break;
		case 2:
			StateSet leftChild = children.get(0).getLabel();
			short lChildState = leftChild.getState();
			StateSet rightChild = children.get(1).getLabel();
			short rChildState = rightChild.getState();
			int nLeftChildSubStates = leftChild.numSubStates();
			int nRightChildSubStates = rightChild.numSubStates();
			// Binary scores are laid out [leftSubState][rightSubState][parentSubState].
			double[][][] bcounts = new double[nLeftChildSubStates][nRightChildSubStates][nParentSubStates];
			BinaryRule brule = new BinaryRule(parentState, lChildState, rChildState, bcounts);
			binaryRuleCounter.incrementCount(brule, 1.0);
			break;
		default:
			throw new Error("Malformed tree: more than two children");
		}

		for (Tree<StateSet> child : children) {
			tallyUninitializedStateSetTree(child);
		}
	}
	
	/*public void tallyChart(Pair chart, double tree_score,	Grammar old_grammar) {
		double[][][][] iScore = chart.getFirst();
		double[][][][] oScore = chart.getSecond();
		if (tree.isLeaf())
			return;
		if (tree.isPreTerminal())
			return;
		List> children = tree.getChildren();
		StateSet parent = tree.getLabel();
		short parentState = parent.getState();
		int nParentSubStates = numSubStates[parentState];
		switch (children.size()) {
		case 0:
			// This is a leaf (a preterminal node, if we count the words themselves),
			// nothing to do
			break;
		case 1:
			StateSet child = children.get(0).getLabel();
			short childState = child.getState();
			int nChildSubStates = numSubStates[childState];
			UnaryRule urule = new UnaryRule(parentState, childState);
			double[][] oldUScores = old_grammar.getUnaryScore(urule); // rule score
			double[][] ucounts = unaryRuleCounter.getCount(urule);
			if (ucounts==null) ucounts = new double[nChildSubStates][];
			double scalingFactor = Math.pow(GrammarTrainer.SCALE,
					parent.getOScale()+child.getIScale()-tree_scale);
			if (scalingFactor==0){
				System.out.println("p: "+parent.getOScale()+" c: "+child.getIScale()+" t:"+tree_scale);
			}
			for (short i = 0; i < nChildSubStates; i++) {
				if (oldUScores[i]==null) continue;
				double cIS = child.getIScore(i);
				if (cIS==0) continue;
				if (ucounts[i]==null) ucounts[i] = new double[nParentSubStates];
				for (short j = 0; j < nParentSubStates; j++) {
					double pOS = parent.getOScore(j); // Parent outside score
					if (pOS==0) continue;
					double rS = oldUScores[i][j];
					if (rS==0) continue;
					if (tree_score==0)
						tree_score = 1;
					double logRuleCount = (rS * cIS / tree_score) * scalingFactor * pOS;
					ucounts[i][j] += logRuleCount;
				}
			}
			//urule.setScores2(ucounts);
			unaryRuleCounter.setCount(urule, ucounts);
			break;
		case 2:
			StateSet leftChild = children.get(0).getLabel();
			short lChildState = leftChild.getState();
			StateSet rightChild = children.get(1).getLabel();
			short rChildState = rightChild.getState();
			int nLeftChildSubStates = numSubStates[lChildState];
			int nRightChildSubStates = numSubStates[rChildState];
				//new double[nLeftChildSubStates][nRightChildSubStates][];
			BinaryRule brule = new BinaryRule(parentState, lChildState, rChildState);
			double[][][] oldBScores = old_grammar.getBinaryScore(brule); // rule score
			if (oldBScores==null){
				//rule was not in the grammar
				//parent.setIScores(iScores2);
				//break;
				oldBScores=new double[nLeftChildSubStates][nRightChildSubStates][nParentSubStates];
				ArrayUtil.fill(oldBScores,1.0);
			}
			double[][][] bcounts = binaryRuleCounter.getCount(brule);
			if (bcounts==null) bcounts = new double[nLeftChildSubStates][nRightChildSubStates][];
			scalingFactor = Math.pow(GrammarTrainer.SCALE,
					parent.getOScale()+leftChild.getIScale()+rightChild.getIScale()-tree_scale);
			if (scalingFactor==0){
				System.out.println("p: "+parent.getOScale()+" l: "+leftChild.getIScale()+" r:"+rightChild.getIScale()+" t:"+tree_scale);
			}
			for (short i = 0; i < nLeftChildSubStates; i++) {
				double lcIS = leftChild.getIScore(i);
				if (lcIS==0) continue;
				for (short j = 0; j < nRightChildSubStates; j++) {
					if (oldBScores[i][j]==null) continue;
					double rcIS = rightChild.getIScore(j);
					if (rcIS==0) continue;
					// allocate parent array
					if (bcounts[i][j]==null) bcounts[i][j] = new double[nParentSubStates];
					for (short k = 0; k < nParentSubStates; k++) {
						double pOS = parent.getOScore(k); // Parent outside score
						if (pOS==0) continue;
						double rS = oldBScores[i][j][k];
						if (rS==0) continue;
						if (tree_score==0)
							tree_score = 1;
						double logRuleCount = (rS * lcIS / tree_score) * rcIS * scalingFactor * pOS;
						
						bcounts[i][j][k] += logRuleCount;
					}
				}
			}
			binaryRuleCounter.setCount(brule, bcounts);
			break;
		default:
			throw new Error("Malformed tree: more than two children");
		}
		
		for (Tree child : children) {
			tallyStateSetTree(child, tree_score, tree_scale, old_grammar);
		}
	}
	*/
	/*
	 * private UnaryRule makeUnaryRule(Tree tree) { int parent =
	 * tagNumberer.number(tree.getLabel()); int child =
	 * tagNumberer.number(tree.getChildren().get(0).getLabel()); return new
	 * UnaryRule(parent, child); }
	 * 
	 * private BinaryRule makeBinaryRule(Tree tree) { int parent =
	 * tagNumberer.number(tree.getLabel()); int lChild =
	 * tagNumberer.number(tree.getChildren().get(0).getLabel()); int rChild =
	 * tagNumberer.number(tree.getChildren().get(1).getLabel()); return new
	 * BinaryRule(parent, lChild, rChild); }
	 */
	/**
	 * Snapshots the four per-state closed unary rule lists
	 * ({@code closedSumRulesWithParent}, {@code closedSumRulesWithChild},
	 * {@code closedViterbiRulesWithParent}, {@code closedViterbiRulesWithChild})
	 * into the array tables read by the {@code getClosed*UnaryRulesBy*}
	 * accessors. Those accessors call this lazily when their table is null.
	 */
	public void makeCRArrays() {
		final int stateCount = numStates;
		// Allocate all four outer tables first, then fill them slot by slot.
		closedSumRulesWithP = new UnaryRule[stateCount][];
		closedSumRulesWithC = new UnaryRule[stateCount][];
		closedViterbiRulesWithP = new UnaryRule[stateCount][];
		closedViterbiRulesWithC = new UnaryRule[stateCount][];

		int state = 0;
		while (state < stateCount) {
			closedSumRulesWithP[state] =
					(UnaryRule[]) closedSumRulesWithParent[state].toArray(new UnaryRule[0]);
			closedSumRulesWithC[state] =
					(UnaryRule[]) closedSumRulesWithChild[state].toArray(new UnaryRule[0]);
			closedViterbiRulesWithP[state] =
					(UnaryRule[]) closedViterbiRulesWithParent[state].toArray(new UnaryRule[0]);
			closedViterbiRulesWithC[state] =
					(UnaryRule[]) closedViterbiRulesWithChild[state].toArray(new UnaryRule[0]);
			state++;
		}
	}
		/**
	 * Returns the closed sum-product unary rules stored in the by-parent table
	 * for {@code state}, building the lookup tables on first use.
	 *
	 * @param state state index
	 * @return the stored rules, or a new empty array if {@code state} lies
	 *         beyond the end of the table
	 */
	public UnaryRule[] getClosedSumUnaryRulesByParent(int state) {
		if (closedSumRulesWithP == null)
			makeCRArrays(); // lazily build all four lookup tables
		return (state < closedSumRulesWithP.length) ? closedSumRulesWithP[state] : new UnaryRule[0];
	}
	
	/**
	 * Returns the closed sum-product unary rules stored in the by-child table
	 * for {@code state}, building the lookup tables on first use.
	 *
	 * @param state state index
	 * @return the stored rules, or a new empty array if {@code state} lies
	 *         beyond the end of the table
	 */
	public UnaryRule[] getClosedSumUnaryRulesByChild(int state) {
		if (closedSumRulesWithC == null)
			makeCRArrays(); // lazily build all four lookup tables
		return (state < closedSumRulesWithC.length) ? closedSumRulesWithC[state] : new UnaryRule[0];
	}
	
	/**
	 * Returns the closed Viterbi (max-product) unary rules stored in the
	 * by-parent table for {@code state}, building the lookup tables on first use.
	 *
	 * @param state state index
	 * @return the stored rules, or a new empty array if {@code state} lies
	 *         beyond the end of the table
	 */
	public UnaryRule[] getClosedViterbiUnaryRulesByParent(int state) {
		if (closedViterbiRulesWithP == null)
			makeCRArrays(); // lazily build all four lookup tables
		return (state < closedViterbiRulesWithP.length) ? closedViterbiRulesWithP[state] : new UnaryRule[0];
	}
	
	/**
	 * Returns the closed Viterbi (max-product) unary rules stored in the
	 * by-child table for {@code state}, building the lookup tables on first use.
	 *
	 * @param state state index
	 * @return the stored rules, or a new empty array if {@code state} lies
	 *         beyond the end of the table
	 */
	public UnaryRule[] getClosedViterbiUnaryRulesByChild(int state) {
		if (closedViterbiRulesWithC == null)
			makeCRArrays(); // lazily build all four lookup tables
		return (state < closedViterbiRulesWithC.length) ? closedViterbiRulesWithC[state] : new UnaryRule[0];
	}
	
	/**
	 * Replaces both {@code bestSumRulesUnderMax} and
	 * {@code bestViterbiRulesUnderMax} with fresh maps. Each new map holds
	 * exactly the rules keyed in {@code bestSumRulesUnderMax} whose parent state
	 * differs from their child state, and maps each rule to itself.
	 * NOTE(review): the Viterbi map is rebuilt from the *sum* map's keys, not
	 * from its own previous contents; confirm this is intended.
	 */
	@SuppressWarnings("unchecked")
	public void purgeRules() {
		Map keptSum = new HashMap();
		Map keptViterbi = new HashMap();
		for (Object key : bestSumRulesUnderMax.keySet()) {
			UnaryRule rule = (UnaryRule) key;
			if (rule.parentState == rule.childState)
				continue; // drop self-loop rules (X -> X)
			keptSum.put(rule, rule);
			keptViterbi.put(rule, rule);
		}
		bestSumRulesUnderMax = keptSum;
		bestViterbiRulesUnderMax = keptViterbi;
	}
	
	/**
	 * Returns the (state, substate) pairs on the best Viterbi unary path from
	 * {@code (pState, np)} to {@code (cState, cp)}.
	 * <p>
	 * Each list element is a distinct {@code short[2]} holding
	 * {@code {state, substate}}. If {@code findClosedPaths} is false, the path is
	 * just the two endpoints. Otherwise the intermediate states are read from
	 * {@code closedViterbiPaths}, starting at the parent and ending with the
	 * destination pair.
	 *
	 * @param pState parent state
	 * @param np     parent substate
	 * @param cState child state
	 * @param cp     child substate
	 * @return list of {@code short[2]} pairs, from the parent to the child
	 */
	@SuppressWarnings("unchecked")
	public List getBestViterbiPath(short pState, short np, short cState, short cp) {
		ArrayList<short[]> path = new ArrayList<short[]>();
		// if we haven't built the data structure of closed paths, then
		// return the simplest possible path
		if (!findClosedPaths) {
			path.add(new short[] { pState, np });
			path.add(new short[] { cState, cp });
			return path;
		}
		// Self path: return two distinct arrays so callers never see aliased entries.
		if (pState == cState && np == cp) {
			path.add(new short[] { pState, np });
			path.add(new short[] { pState, np });
			return path;
		}
		// Read the best path off closedViterbiPaths. A fresh array is added on every
		// step; reusing and mutating a single array made every entry alias the
		// final pair.
		// NOTE(review): closedViterbiPaths is written elsewhere in this file as
		// [parentState][childState], but is indexed here as [state][substate], and
		// the substate never changes. The loop therefore only terminates when
		// np == cp. Confirm the intended indexing.
		short state = pState;
		short subState = np;
		while (state != cState || subState != cp) {
			path.add(new short[] { state, subState });
			state = (short) closedViterbiPaths[state][subState];
		}
		// add the destination state as well
		path.add(new short[] { state, subState });
		return path;
	}
	
	/**
	 * Extends the closed unary rule sets with chains that pass through {@code ur}.
	 * <p>
	 * Sum pass: for every closed sum rule {@code pr} listed under
	 * {@code closedSumRulesWithChild[ur.parentState]} and every closed sum rule
	 * {@code cr} listed under {@code closedSumRulesWithParent[ur.childState]},
	 * builds a rule {@code pr.parentState -> cr.childState}. For each
	 * (np, cp) substate pair, its score sums
	 * {@code pr.getScore(np, unp) * cr.getScore(ucp, cp) * ur.getScore(unp, ucp)}
	 * over all intermediate substates. The result is passed to {@code relaxSumRule}.
	 * <p>
	 * Viterbi pass: the same construction over the closed Viterbi lists, taking
	 * the maximum instead of the sum and recording the maximizing intermediate
	 * substates. The result is passed to {@code relaxViterbiRule}.
	 * NOTE(review): relaxViterbiRule in this file currently throws
	 * {@code Error("Viterbi closure is broken!")} unconditionally. This method
	 * therefore throws as soon as the Viterbi pass produces a rule; confirm
	 * whether that is intended.
	 *
	 * @param ur the unary rule to splice into existing closed chains
	 */
	@SuppressWarnings("unchecked")
	private void closeRulesUnderMax(UnaryRule ur) {
		short pState = ur.parentState;
		int nPSubStates = numSubStates[pState];
		short cState = ur.childState;
		double[][] uScores = ur.getScores2();
		// do all sum rules
		for (int i = 0; i < closedSumRulesWithChild[pState].size(); i++) {
			UnaryRule pr = (UnaryRule) closedSumRulesWithChild[pState].get(i);
			for (int j = 0; j < closedSumRulesWithParent[cState].size(); j++) {
				short parentState = pr.parentState;
				int nParentSubStates = numSubStates[parentState];
				UnaryRule cr = (UnaryRule) closedSumRulesWithParent[cState].get(j);
				// composite rule: pr's parent -> cr's child, scores indexed [cp][np]
				UnaryRule resultR = new UnaryRule(parentState, cr.getChildState());
				double[][] scores = new double[numSubStates[cr.getChildState()]][nParentSubStates];
				for (int np = 0; np < scores[0].length; np++) {
					for (int cp = 0; cp < scores.length; cp++) {
						// sum over intermediate substates: unp ranges over ur's parent
						// substates, ucp over the first dimension of ur's score array
						double sum = 0;
						for (int unp = 0; unp < nPSubStates; unp++) {
							for (int ucp = 0; ucp < uScores.length; ucp++) {
								sum += pr.getScore(np, unp) * cr.getScore(ucp, cp) * ur.getScore(unp, ucp);
							}
						}
						scores[cp][np] = sum;
					}
				}
				resultR.setScores2(scores);
				//add rule to bestSumRulesUnderMax if it's better
				relaxSumRule(resultR,pState,cState);
			}
		}
		// do viterbi rules also
		for (short i = 0; i < closedViterbiRulesWithChild[pState].size(); i++) {
			UnaryRule pr = (UnaryRule) closedViterbiRulesWithChild[pState].get(i);
			for (short j = 0; j < closedViterbiRulesWithParent[cState].size(); j++) {
				UnaryRule cr = (UnaryRule) closedViterbiRulesWithParent[cState].get(j);
				short parentState = pr.parentState;
				int nParentSubStates = numSubStates[parentState];
				UnaryRule resultR = new UnaryRule(parentState, cr.getChildState());
				double[][] scores = new double[numSubStates[cr.getChildState()]][nParentSubStates];
				// argmax intermediate substates, indexed [np][cp]; stay 0 if no
				// product exceeds the initial max of 0
				short[][] intermediateSubState1 = new short[nParentSubStates][numSubStates[cr.getChildState()]];
				short[][] intermediateSubState2 = new short[nParentSubStates][numSubStates[cr.getChildState()]];
				for (int np = 0; np < scores[0].length; np++) {
					for (int cp = 0; cp < scores.length; cp++) {
						// maximize (not sum) over intermediate substates, remembering the argmax
						double max = 0;
						for (short unp = 0; unp < nPSubStates; unp++) {
							for (short ucp = 0; ucp < uScores.length; ucp++) {
								double score = pr.getScore(np, unp) * cr.getScore(ucp, cp) * ur.getScore(unp, ucp);
								if (score > max) {
									max = score;
									intermediateSubState1[np][cp] = unp; 
									intermediateSubState2[np][cp] = ucp; 
								}
							}
						}
						scores[cp][np] = max;
					}
				}
				resultR.setScores2(scores);
				// hand the max-product rule and its argmax substates to relaxViterbiRule
				// (NOTE(review): relaxViterbiRule currently always throws Error)
				relaxViterbiRule(resultR,pState,intermediateSubState1,cState,intermediateSubState2);
			}
		}
	}
	
	/**
	 * Looks up the intermediate state recorded in {@code closedSumPaths} for the
	 * closed sum unary chain from {@code start} to {@code end}.
	 * NOTE(review): the closure code appears to store -1 when the best
	 * contribution has no intermediate state; confirm before relying on it.
	 *
	 * @param start parent state of the chain
	 * @param end   child state of the chain
	 * @return the recorded intermediate state
	 */
	public int getUnaryIntermediate(short start, short end){
		final int intermediate = closedSumPaths[start][end];
		return intermediate;
	}
	
	
	@SuppressWarnings("unchecked")
	private boolean relaxSumRule(UnaryRule ur, int intState1, int intState2) {
		//TODO: keep track of path
		UnaryRule bestR = (UnaryRule) bestSumRulesUnderMax.get(ur);
		if (bestR == null) {
			bestSumRulesUnderMax.put(ur, ur);
			closedSumRulesWithParent[ur.parentState].add(ur);
			closedSumRulesWithChild[ur.childState].add(ur);
			return true;
		} else {
			boolean change = false;
			for (int i=0; i scoresMax[cp][np]) {
									scoresMax[cp][np] = sum;
									bestMaxIntermed = -1;
								}
							}
						}
						if (total>maxSumScore){
							bestSumIntermed=-1;
							maxSumScore=total;
						}
					}
					else{
						for (int j = 0; j < unaryRulesWithC[childState].size(); j++) {
							UnaryRule cr = (UnaryRule) unaryRulesWithC[childState].get(j);
							if (state!=cr.getParentState()) continue;
							int nMySubStates = numSubStates[state];
							double total = 0;
							for (int np = 0; np < nParentSubStates; np++) {
								for (int cp = 0; cp < nChildSubStates; cp++) {
									// sum over intermediate substates
									double sum = 0;
									double max = 0;
									for (int unp = 0; unp < nMySubStates; unp++) {
										double val  = pr.getScore(np, unp) * cr.getScore(unp, cp);
										sum += val;
										max = Math.max(max, val);
									}
									scoresSum[cp][np] += sum;
									total += sum;
									if (max > scoresMax[cp][np]) {
										scoresMax[cp][np] = max;
										bestMaxIntermed = state;
									}
								}
							}
							if (total>maxSumScore){
								maxSumScore=total;
								bestSumIntermed=state;
							}
						}
					}
				}
				if (maxSumScore>-1){
					resultRsum.setScores2(scoresSum);
					addUnary(resultRsum);
					closedSumRulesWithParent[parentState].add(resultRsum);
					closedSumRulesWithChild[childState].add(resultRsum);
					closedSumPaths[parentState][childState]=bestSumIntermed;
				}
				if (bestMaxIntermed>-2){
					resultRmax.setScores2(scoresMax);
					//addUnary(resultR);
					closedViterbiRulesWithParent[parentState].add(resultRmax);
					closedViterbiRulesWithChild[childState].add(resultRmax);
					closedViterbiPaths[parentState][childState]=bestMaxIntermed;
					/*if (bestMaxIntermed > -1){
						System.out.println("NEW RULE CREATED");
					}*/
				}
			}
		}

	}
	/*
	@SuppressWarnings("unchecked")
	private boolean relaxSumRule(UnaryRule rule) {
		bestSumRulesUnderMax.put(rule, rule);
		closedSumRulesWithParent[rule.parentState].add(rule);
		closedSumRulesWithChild[rule.childState].add(rule);
		return true;		
	}
	*/
	/**
	 * Update the best unary chain probabilities and paths with this new rule.
	 * 
	 * @param ur
	 * @param subStates1
	 * @param subStates2
	 * @return
	 */
	@SuppressWarnings("unchecked")
	private void relaxViterbiRule(UnaryRule ur, short intState1,
			short[][] intSubStates1, short intState2, short[][] intSubStates2) {
		throw new Error("Viterbi closure is broken!");
/*		UnaryRule bestR = (UnaryRule) bestViterbiRulesUnderMax.get(ur);
		boolean isNewRule = (bestR==null);
		if (isNewRule) {
			bestViterbiRulesUnderMax.put(ur, ur);
			closedViterbiRulesWithParent[ur.parentState].add(ur);
			closedViterbiRulesWithChild[ur.childState].add(ur);
			bestR = ur;
		}
		for (int i=0; i[] matrixMultiply(List[] parentRules, List[] childRules) {
  	throw new Error("I'm broken by parent first");
  	/*
    double[][][][] scores = new double[numStates][numStates][][];
    for ( short A=0; A[] result = new List[numStates];
    for ( short A=0; A();
      for ( short C=0; C[] rules1, List[] rules2) {    
  	throw new Error("I'm broken by parent first");
  	/*
    for ( short A=0; A[] matrixUnity() {
  	throw new Error("I'm broken by parent first");
//    List[] result = new List[numStates];
//    for ( short A=0; A();
//      double[][] scores = new double[numSubStates[A]][numSubStates[A]];
//      ArrayUtil.fill(scores, Double.NEGATIVE_INFINITY);
//      for ( int a = 0; a < numSubStates[A]; a++ ) {
//        scores[a][a] = 0;
//      }
//      UnaryRule rule = new UnaryRule(A, A, scores);
//      result[A].add(rule);
//    }
//    return result;
  }
  
  /**
   * Intended to compute the sum-product closure of the unary rule matrix
   * {@code P}, i.e. I + P + P^2 + P^3 + ..., truncated after a fixed power.
   * <p>
   * This method is disabled and always throws. The retired implementation
   * worked as follows:
   * <ol>
   *   <li>It started from the identity matrix ({@code matrixUnity()}) and added {@code P}.</li>
   *   <li>It kept a running power Q = P.</li>
   *   <li>For powers 2 and 3 ({@code maxPower = 3}), it set Q = Q * P via
   *       {@code matrixMultiply} and accumulated it with {@code matrixAdd}.</li>
   * </ol>
   * Those helper methods are themselves disabled in this class.
   *
   * @param P the unary rule matrix (per the sibling methods, presumably in
   *          List[]-indexed-by-parent format; confirm before reviving)
   * @return never returns normally
   * @throws Error always, with the message "I'm broken by parent first"
   */
  private List[] sumProductUnaryClosure(List[] P) {
    // Intentionally disabled; the retired algorithm is described in the Javadoc.
    throw new Error("I'm broken by parent first");
  }

  /**
   * Row-vector / unary-rule-matrix product U = V*M.
   * <p>
   * Currently disabled: the method unconditionally throws an {@link Error}
   * with the message "I'm broken by parent first". The legacy implementation
   * is kept below inside a block comment for reference only; it is not compiled.
   * The legacy code combines values with {@code SloppyMath.logAdd} and uses
   * {@code Double.NEGATIVE_INFINITY} as "empty", i.e. it works in log space.
   * <p>
   * Assumption: A in possibleSt ==> V[A] != null. This property is true of the result as well.
   * The converse is not true because of a workaround for part of speech tags that we must handle
   * here.
   * @param V (considered a row vector, indexed by (state, substate))
   * @param M (a matrix represented in List[] (by parent) format)
   * @param possibleSt (a list of possible states to consider)
   * @return U=V*M (row vector); never reached while the method is disabled
   * @throws Error always
   */
  public double[][] matrixVectorPreMultiply(double[][] V, List[] M, List possibleSt) {
  	throw new Error("I'm broken by parent first");
  	/*
    // Legacy implementation, not compiled.
    // NOTE(review): this code reads rule scores as scores[parentSub][childSub]
    // via getScores(); confirm that matches the current UnaryRule score layout
    // before reviving it.
    double[][] U = new double[numStates][];
    for (int pState : possibleSt){
      U[pState] = new double[numSubStates[pState]];
      Arrays.fill(U[pState], Double.NEGATIVE_INFINITY);
      UnaryRule[] unaries = M[pState].toArray(new UnaryRule[0]);
      for ( UnaryRule ur : unaries ) {
        int cState = ur.childState;
        if ( V[cState] == null ) {
          continue;
        }
        double[][] scores = ur.getScores();  // numSubStates[pState] * numSubStates[cState]
        int nParentStates = numSubStates[pState];
        int nChildStates = numSubStates[cState];
        double[] termsToAdd = new double[nChildStates+1]; // Could be inside the for(np) loop
        for (int np = 0; np < nParentStates; np++) {
          Arrays.fill(termsToAdd, Double.NEGATIVE_INFINITY);
          double currentVal = U[pState][np];
          termsToAdd[termsToAdd.length-1] = currentVal;
          for (int cp = 0; cp < nChildStates; cp++) {
            double iS = V[cState][cp];
            if (iS == Double.NEGATIVE_INFINITY) {
              continue;
            }
            double pS = scores[np][cp];
            termsToAdd[cp] = iS + pS;
          }
          
          double newVal = SloppyMath.logAdd(termsToAdd);
          if (newVal > currentVal) {
            U[pState][np] = newVal;
          }
        }
      }
    }
    return U;
    */
  }
  
  /**
   * Unary-rule-matrix / column-vector product U = M*V.
   * <p>
   * Currently disabled: the method unconditionally throws an {@link Error}
   * with the message "I'm broken by parent first". The legacy implementation
   * is kept below inside a block comment for reference only; it is not compiled,
   * and it contains an indexing bug (flagged inline) that must be fixed before
   * it is revived.
   * <p>
   * Assumption: A in possibleSt ==> V[A] != null. This property is true of the result as well.
   * The converse is not true because of a workaround for part of speech tags that we must handle
   * here.
   * @param M (a matrix represented in List[] (by parent) format)
   * @param V (considered a column vector, indexed by (state, substate))
   * @param possibleSt (a list of possible states to consider)
   * @return U=M*V (column vector); never reached while the method is disabled
   * @throws Error always
   */
  public double[][] matrixVectorPostMultiply(List[] M, double[][] V, List possibleSt) {
  	throw new Error("I'm broken by parent first");
  	/*
    // Legacy implementation, not compiled.
    // NOTE(review): this code reads rule scores as scores[parentSub][childSub]
    // via getScores(); confirm that matches the current UnaryRule score layout
    // before reviving it.
    double[][] U = new double[numStates][];
    for (int cState : possibleSt){
      U[cState] = new double[numSubStates[cState]];
      Arrays.fill(U[cState], Double.NEGATIVE_INFINITY);
    }
    for (int pState : possibleSt){
      UnaryRule[] unaries = M[pState].toArray(new UnaryRule[0]);
      for ( UnaryRule ur : unaries ) {
        int cState = ur.childState;
        if ( U[cState] == null ) {
          continue;
        }
        double[][] scores = ur.getScores();  // numSubStates[pState] * numSubStates[cState]
        int nParentStates = numSubStates[pState];
        int nChildStates = numSubStates[cState];
        double[] termsToAdd = new double[nParentStates+1]; // Could be inside the for(np) loop
        for (int cp = 0; cp < nChildStates; cp++) {
          Arrays.fill(termsToAdd, Double.NEGATIVE_INFINITY);
          double currentVal = U[cState][cp];
          termsToAdd[termsToAdd.length-1] = currentVal;
          for (int np = 0; np < nParentStates; np++) {
            double oS = V[pState][np];
            if (oS == Double.NEGATIVE_INFINITY) {
              continue;
            }
            double pS = scores[np][cp];
            // NOTE(review): BUG in the legacy code. termsToAdd has nParentStates+1
            // slots and this inner loop runs over np, so the index below should be
            // np, not cp.
            // Effects of using cp:
            //  - every np overwrites the same slot, so only the last term survives;
            //  - cp == nParentStates clobbers the currentVal slot;
            //  - cp > nParentStates overruns the array.
            termsToAdd[cp] = oS + pS;
          }
          
          double newVal = SloppyMath.logAdd(termsToAdd);
          if (newVal > currentVal) {
            U[cState][cp] = newVal;
          }
        }
      }
    }
    return U;
    */
  }
  
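	/**
	 * Populates the split-rule accessor arrays ({@code splitRulesWithP},
	 * {@code splitRulesWithLC}, {@code splitRulesWithRC}) from the existing
	 * binary rule lists, one array per state.
	 * <p>
	 * After building the arrays, the method:
	 * <ul>
	 *   <li>discards those lists (sets them to null);</li>
	 *   <li>calls {@code makeCRArrays()}.</li>
	 * </ul>
	 * It returns immediately if {@code binaryRulesWithParent} is already null.
	 * <p>
	 * NOTE(review): an earlier version of this comment stated two filtering rules:
	 * <ul>
	 *   <li>for a synthetic state, the arrays hold all rules for the state;</li>
	 *   <li>for a non-synthetic state, they hold only rules whose children are both
	 *       non-synthetic.</li>
	 * </ul>
	 * This method applies no such filter itself; it copies each list as-is.
	 * Confirm whether the input lists are pre-filtered elsewhere.
	 * <p>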
* This method must be called before the grammar is used, either after * training or deserializing grammar. */ @SuppressWarnings("unchecked") public void splitRules() { // splitRulesWithLC = new BinaryRule[numStates][]; // splitRulesWithRC = new BinaryRule[numStates][]; //makeRulesAccessibleByChild(); if (binaryRulesWithParent==null) return; splitRulesWithP = new BinaryRule[numStates][]; splitRulesWithLC = new BinaryRule[numStates][]; splitRulesWithRC = new BinaryRule[numStates][]; for (int state = 0; state < numStates; state++) { splitRulesWithLC[state] = toBRArray(binaryRulesWithLC[state]); splitRulesWithRC[state] = toBRArray(binaryRulesWithRC[state]); splitRulesWithP[state] = toBRArray(binaryRulesWithParent[state]); } // we don't need the original lists anymore binaryRulesWithParent = null; binaryRulesWithLC = null; binaryRulesWithRC = null; makeCRArrays(); } public BinaryRule[] splitRulesWithLC(int state) { // System.out.println("splitRulesWithLC not supported anymore."); // return null; if (state >= splitRulesWithLC.length) { return new BinaryRule[0]; } return splitRulesWithLC[state]; } public BinaryRule[] splitRulesWithRC(int state) { // System.out.println("splitRulesWithLC not supported anymore."); // return null; if (state >= splitRulesWithRC.length) { return new BinaryRule[0]; } return splitRulesWithRC[state]; } public BinaryRule[] splitRulesWithP(int state) { if (splitRulesWithP==null) splitRules(); if (state >= splitRulesWithP.length) { return new BinaryRule[0]; } return splitRulesWithP[state]; } private BinaryRule[] toBRArray(List list) { // Collections.sort(list, Rule.scoreComparator()); // didn't seem to help BinaryRule[] array = new BinaryRule[list.size()]; for (int i = 0; i < array.length; i++) { array[i] = list.get(i); } return array; } public double[][] getUnaryScore(short pState, short cState) { UnaryRule r = getUnaryRule(pState, cState); if (r != null) return r.getScores2(); if (GrammarTrainer.VERBOSE) System.out.println("The requested rule 
("+uSearchRule+") is not in the grammar!"); double[][] uscores = new double[numSubStates[cState]][numSubStates[pState]]; ArrayUtil.fill(uscores,0.0); return uscores; } /** * @param pState * @param cState * @return */ public UnaryRule getUnaryRule(short pState, short cState) { UnaryRule uRule = new UnaryRule (pState, cState); UnaryRule r = unaryRuleMap.get(uRule); return r; } public double[][] getUnaryScore(UnaryRule rule) { UnaryRule r = unaryRuleMap.get(rule); if (r != null) return r.getScores2(); if (GrammarTrainer.VERBOSE) System.err.println("The requested rule ("+rule+") is not in the grammar!"); double[][] uscores = new double[numSubStates[rule.getChildState()]][numSubStates[rule.getParentState()]]; ArrayUtil.fill(uscores,0.0); return uscores; } public double[][][] getBinaryScore(short pState, short lState, short rState) { BinaryRule r = getBinaryRule(pState, lState, rState); if (r != null) return r.getScores2(); if (GrammarTrainer.VERBOSE) { System.err.println(tagNumberer.object(pState)+"\t"+pState); System.err.println(tagNumberer.object(lState)+"\t"+lState); System.err.println(tagNumberer.object(rState)+"\t"+rState); System.err.println("numSubStates.length:"+"\t"+numSubStates.length); } double[][][] bscores = new double[numSubStates[lState]][numSubStates[rState]][numSubStates[pState]]; ArrayUtil.fill(bscores,0.0); return bscores; } /** * @param pState * @param lState * @param rState * @return */ public BinaryRule getBinaryRule(short pState, short lState, short rState) { BinaryRule bRule = new BinaryRule(pState, lState, rState); BinaryRule r = binaryRuleMap.get(bRule); return r; } public double[][][] getBinaryScore(BinaryRule rule) { BinaryRule r = binaryRuleMap.get(rule); if (r != null) return r.getScores2(); else { if (GrammarTrainer.VERBOSE) System.out.println("The requested rule ("+rule+") is not in the grammar!"); double[][][] bscores = new 
double[numSubStates[rule.getLeftChildState()]][numSubStates[rule.getRightChildState()]][numSubStates[rule.getParentState()]]; ArrayUtil.fill(bscores,0.0); return bscores; } } public void printSymbolCounter(Numberer tagNumberer) { Set set = symbolCounter.keySet(); PriorityQueue pq = new PriorityQueue(set.size()); for (Integer i : set) { pq.add((String) tagNumberer.object(i), symbolCounter.getCount(i, 0)); // System.out.println(i+". "+(String)tagNumberer.object(i)+"\t // "+symbolCounter.getCount(i,0)); } int i = 0; while (pq.hasNext()) { i++; int p = (int) pq.getPriority(); System.out.println(i + ". " + pq.next() + "\t " + p); } } public int getSymbolCount(Integer i) { return (int) symbolCounter.getCount(i, 0); } private void makeRulesAccessibleByChild(){ // first the binaries if (true) return; for (int state=0; state=counts[i]) { // newNumSubStates[i]=numSubStates[i]; // } // else{ newNumSubStates[i] = (short)(numSubStates[i] * 2); // } } boolean doNotNormalize = (mode==1); newNumSubStates[0] = 1; // never split ROOT // create the new grammar Grammar grammar = new Grammar(newNumSubStates, findClosedPaths, smoother, this, threshold); Random random = GrammarTrainer.RANDOM; for (BinaryRule oldRule : binaryRuleMap.keySet()) { BinaryRule newRule = oldRule.splitRule(numSubStates, newNumSubStates, random, randomness, doNotNormalize, mode); grammar.addBinary(newRule); } for (UnaryRule oldRule : unaryRuleMap.keySet()){ UnaryRule newRule = oldRule.splitRule(numSubStates, newNumSubStates, random, randomness, doNotNormalize, mode); grammar.addUnary(newRule); } grammar.isGrammarTag = this.isGrammarTag; grammar.extendSplitTrees(splitTrees, numSubStates); grammar.computePairsOfUnaries(); return grammar; } @SuppressWarnings("unchecked") public void extendSplitTrees(Tree[] trees, short[] oldNumSubStates) { this.splitTrees = new Tree[numStates]; for (int tag=0; tag splitTree = trees[tag].shallowClone(); for (Tree leaf : splitTree.getTerminals()) { List> children = leaf.getChildren(); 
if (numSubStates[tag] > oldNumSubStates[tag]) { children.add(new Tree((short)(2*leaf.getLabel()))); children.add(new Tree((short)(2*leaf.getLabel()+1))); } else { children.add(new Tree(leaf.getLabel())); } } this.splitTrees[tag] = splitTree; } } public int totalSubStates() { int count = 0; for (int i = 0; i < numStates; i++) { count += numSubStates[i]; } return count; } /** * Tally the probability of seeing each substate. This data is needed for * tallyMergeScores. mergeWeights is indexed as [state][substate]. This * data should be normalized before being used by another function. * * @param tree * @param mergeWeights The probability of seeing substate given state. */ public void tallyMergeWeights(Tree tree, double mergeWeights[][]) { if (tree.isLeaf()) return; StateSet label = tree.getLabel(); short state = label.getState(); double probs[] = new double[label.numSubStates()]; double total = 0, tmp; for (short i=0; i child : tree.getChildren()) { tallyMergeWeights(child,mergeWeights); } } /* * normalize merge weights. assumes that the mergeWeights are given * as logs. the normalized weights are returned as probabilities. 
*/ public void normalizeMergeWeights(double[][] mergeWeights){ for (int state=0; state tree, double[][][] deltas, double[][] mergeWeights) { if (tree.isLeaf()) return; StateSet label = tree.getLabel(); short state = label.getState(); double[] separatedScores = new double[label.numSubStates()]; double[] combinedScores = new double[label.numSubStates()]; double combinedScore; // calculate separated scores double separatedScoreSum = 0, tmp; //don't need to deal with scale factor because we divide below for (int i = 0; i < label.numSubStates(); i++) { tmp = label.getIScore(i) * label.getOScore(i); combinedScores[i] = separatedScores[i] = tmp; separatedScoreSum += tmp; } // calculate merged scores for (short i = 0; i < numSubStates[state]; i++) { for (short j=(short)(i+1); j child : tree.getChildren()) { tallyMergeScores(child, deltas, mergeWeights); } } /** * This merges the substate pairs indicated by mergeThesePairs[state][substate pair]. * It requires merge weights calculated by tallyMergeWeights. * * @param mergeThesePairs Which substate pairs to merge. * @param mergeWeights The probability of seeing each substate. 
*/ public Grammar mergeStates(boolean[][][] mergeThesePairs, double[][] mergeWeights) { if (logarithmMode) { throw new Error("Do not merge grammars in logarithm mode!"); } short[] newNumSubStates = new short[numSubStates.length]; short[][] mapping = new short[numSubStates.length][]; //invariant: if partners[state][substate][0] == substate, it's the 1st one short[][][] partners = new short[numSubStates.length][][]; calculateMergeArrays(mergeThesePairs, newNumSubStates, mapping, partners, numSubStates); // create the new grammar Grammar grammar = new Grammar(newNumSubStates, findClosedPaths, smoother, this, threshold); //for (Rule r : allRules) { //if (r instanceof BinaryRule) { for (BinaryRule oldRule : binaryRuleMap.keySet()) { //BinaryRule oldRule = r; short pS = oldRule.getParentState(), lcS = oldRule.getLeftChildState(), rcS = oldRule .getRightChildState(); double[][][] oldScores = oldRule.getScores2(); //merge binary rule double[][][] newScores = new double[newNumSubStates[lcS]][newNumSubStates[rcS]][newNumSubStates[pS]]; for (int i=0; i splitTree = splitTrees[tag]; int maxDepth = splitTree.getDepth(); for (Tree preTerminal : splitTree.getAtDepth(maxDepth-2)) { List> children = preTerminal.getChildren(); ArrayList> newChildren = new ArrayList>(2); for (int i=0; i child = children.get(i); int curLoc = child.getLabel(); if (partners[tag][curLoc][0]==curLoc) { newChildren.add(new Tree(mapping[tag][curLoc])); } } preTerminal.setChildren(newChildren); } } } public static void checkNormalization(Grammar grammar) { double[][] psum = new double[grammar.numSubStates.length][]; for (int pS=0; pS0.001) System.out.println(" state "+pS+" substate "+pi+" gives bad psum: "+psum[pS][pi]); } } } /** * @param mergeThesePairs * @param newNumSubStates * @param mapping * @param partners */ public static void calculateMergeArrays(boolean[][][] mergeThesePairs, short[] newNumSubStates, short[][] mapping, short[][][] partners, short[] numSubStates) { for (short state = 0; state < 
numSubStates.length; state++) { short mergeTarget[] = new short[mergeThesePairs[state].length]; Arrays.fill(mergeTarget,(short)-1); short count = 0; mapping[state] = new short[numSubStates[state]]; partners[state] = new short[numSubStates[state]][]; for (short j=0; j[] a) { if (a!=null) { for (List l : a) { if (l==null) continue; for (BinaryRule r : l) { logarithmModeRule(r); } } } } /** * */ private void logarithmModeURuleListArray(List[] a) { if (a!=null) { for (List l : a) { if (l==null) continue; for (UnaryRule r : l) { logarithmModeRule(r); } } } } /** * */ private void logarithmModeBRuleArrayArray(BinaryRule[][] a) { if (a!=null) { for (BinaryRule[] l : a) { if (l==null) continue; for (BinaryRule r : l) { logarithmModeRule(r); } } } } /** * */ private void logarithmModeURuleArrayArray(UnaryRule[][] a) { if (a!=null) { for (UnaryRule[] l : a) { if (l==null) continue; for (UnaryRule r : l) { logarithmModeRule(r); } } } } /** * @param r */ private static void logarithmModeRule(BinaryRule r) { if (r==null || r.logarithmMode) return; r.logarithmMode = true; double[][][] scores = r.getScores2(); for (int i=0; i0){ double[][][] newScores = new double[1][1][1]; newScores[0][0][0] = newBinaryProbs[lS][rS]; BinaryRule newRule = new BinaryRule((short)0,lS,rS,newScores); //newRule.setScores2(newScores); grammar.addBinary(newRule); } } } for (short cS=0; cS0){ double[][] newScores = new double[1][1]; newScores[0][0] = newUnaryProbs[cS]; UnaryRule newRule = new UnaryRule((short)0,cS,newScores); //newRule.setScores2(newScores); grammar.addUnary(newRule); } } grammar.computePairsOfUnaries(); grammar.makeCRArrays(); grammar.isGrammarTag = this.isGrammarTag; //System.out.println(grammar.toString()); return grammar; } public double[] computeConditionalProbabilities(int[][] fromMapping, int[][] toMapping) { double[][] transitionProbs = computeProductionProbabilities(fromMapping); //System.out.println(ArrayUtil.toString(transitionProbs)); double[] expectedCounts = 
computeExpectedCounts(transitionProbs); //System.out.println(Arrays.toString(expectedCounts)); /*for (int state=0; state 0-bar states // level 0 -> x-bar states // level 1 -> each (state,substate) gets its own index short[] numSubStates = this.numSubStates; int[][] mapping = new int[numSubStates.length+1][]; int k=0; for (int state=0; state=1) mapping[state][substate]=k++; else if (level==-1){ if (this.isGrammarTag(state)) mapping[state][substate] = 0; else mapping[state][substate]=state; } else /*level==0*/ mapping[state][substate] = state; } } mapping[numSubStates.length] = new int[1]; mapping[numSubStates.length][0]= (level<1) ? numSubStates.length : k; //System.out.println("The grammar has "+mapping[numSubStates.length][0]+" substates."); return mapping; } public int[][] computeSubstateMapping(int level) { // level 0 -> merge all substates // level 1 -> merge upto depth 1 -> keep upto 2 substates // level 2 -> merge upto depth 2 -> keep upto 4 substates short[] numSubStates = this.numSubStates; // for (int i=0; i=0){ Arrays.fill(mapping[state],-1); Tree hierarchy = splitTrees[state]; List> subTrees = hierarchy.getAtDepth(level); for (Tree subTree : subTrees){ List leaves = subTree.getYield(); for (Short substate : leaves){ // System.out.println(substate+" "+numSubStates[state]+" "+state); if (substate==numSubStates[state]) System.out.print("Will crash."); mapping[state][substate+1]=k; } k++; } } else {k=1;} mapping[state][0]=k; } return mapping; } public void computeReverseSubstateMapping(int level, int[][] lChildMap, int[][] rChildMap) { // level 1 -> how do the states from depth 1 expand to depth 2 for (int state=0; state hierarchy = splitTrees[state]; List> subTrees = hierarchy.getAtDepth(level); lChildMap[state] = new int[subTrees.size()]; rChildMap[state] = new int[subTrees.size()]; for (Tree subTree : subTrees){ int substate = subTree.getLabel(); if (subTree.isLeaf()){ lChildMap[state][substate] = substate; rChildMap[state][substate] = substate; continue; 
} boolean first = true; int nChildren = subTree.getChildren().size(); for (Tree child : subTree.getChildren()){ if (first) { lChildMap[state][substate] = child.getLabel(); first = false; } else rChildMap[state][substate] = child.getLabel(); if (nChildren==1) rChildMap[state][substate] = child.getLabel(); } } } } private double[] computeExpectedCounts(double[][] transitionProbs) { //System.out.println(ArrayUtil.toString(transitionProbs)); double[] expectedCounts = new double[transitionProbs.length]; double[] tmpCounts = new double[transitionProbs.length]; expectedCounts[0] = 1; tmpCounts[0] = 1; //System.out.print("Computing expected counts"); int iter = 0; double diff = 1; double sum = 1; // 1 for the root while (diff>1.0e-10 && iter<50){ iter++; for (int state=1; state uRules = this.getUnaryRulesByParent(state); for (UnaryRule r : uRules){ int cState = r.childState; if (cState==state) continue; /*if (cState==15){ System.out.println("Found one"); }*/ double[][] scores = r.getScores2(); for (int cS=0; cS(); closedSumRulesWithChild[startState] = new ArrayList(); } // finally create rules and add them to the arrays for (short startState=0; startState0){ scores[endSubState][startSubState]=score; atLeastOneNonZero = true; } } } if (atLeastOneNonZero){ UnaryRule newUnary = new UnaryRule(startState, endState, scores); addUnary(newUnary); closedSumRulesWithParent[startState].add(newUnary); closedSumRulesWithChild[endState].add(newUnary); } } } if (closedSumRulesWithP==null){ closedSumRulesWithP = new UnaryRule[numStates][]; closedSumRulesWithC = new UnaryRule[numStates][]; } for (int i = 0; i < numStates; i++) { closedSumRulesWithP[i] = (UnaryRule[]) closedSumRulesWithParent[i].toArray(new UnaryRule[0]); closedSumRulesWithC[i] = (UnaryRule[]) closedSumRulesWithChild[i].toArray(new UnaryRule[0]); } } /** * @param output */ public void writeSplitTrees(Writer w) { PrintWriter out = new PrintWriter(w); for (int state=1; state





© 2015 - 2025 Weber Informatics LLC | Privacy Policy