All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.parser.lexparser.SplittingGrammarExtractor Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.parser.lexparser; 
import edu.stanford.nlp.util.logging.Redwood;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.MutableDouble;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ThreeDimensionalMap;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.TwoDimensionalMap;


import java.io.*;

/**
 * This class is a reimplementation of Berkeley's state splitting
 * grammar.  This work is experimental and still in progress.  There
 * are several extremely important pieces to implement:
 * <ol>
 *   <li>this code should use log probabilities throughout instead of
 *       multiplying tiny numbers</li>
 *   <li>time efficiency of the training code is awful</li>
 *   <li>there are better ways to extract parses using this grammar than
 *       the method in ExhaustivePCFGParser</li>
 *   <li>we should also implement cascading parsers that let us
 *       short-circuit low quality parses earlier (which could possibly
 *       benefit non-split parsers as well)</li>
 *   <li>when looping, we should short-circuit if we go too many loops</li>
 *   <li>ought to smooth as per page 436</li>
 * </ol>
* * @author John Bauer */ public class SplittingGrammarExtractor { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(SplittingGrammarExtractor.class); static final int MIN_DEBUG_ITERATION=0; static final int MAX_DEBUG_ITERATION=0; static final int MAX_ITERATIONS = Integer.MAX_VALUE; int iteration = 0; boolean DEBUG() { return (iteration >= MIN_DEBUG_ITERATION && iteration < MAX_DEBUG_ITERATION); } Options op; /** * These objects are created and filled in here. The caller can get * the data from the extractor once it is finished. */ Index stateIndex; Index wordIndex; Index tagIndex; /** * This is a list gotten from the list of startSymbols in op.langpack() */ List startSymbols; /** * A combined list of all the trees in the training set. */ List trees = new ArrayList<>(); /** * All of the weights associated with the trees in the training set. * In general, this is just the weight of the original treebank. * Note that this uses an identity hash map to map from tree pointer * to weight. */ Counter treeWeights = new ClassicCounter<>(MapFactory.identityHashMapFactory()); /** * How many total weighted trees we have */ double trainSize; /** * The original states in the trees */ Set originalStates = Generics.newHashSet(); /** * The current number of times a particular state has been split */ IntCounter stateSplitCounts = new IntCounter<>(); /** * The binary betas are weights to go from Ax to By, Cz. This maps * from (A, B, C) to (x, y, z) to beta(Ax, By, Cz). */ ThreeDimensionalMap binaryBetas = new ThreeDimensionalMap<>(); /** * The unary betas are weights to go from Ax to By. This maps * from (A, B) to (x, y) to beta(Ax, By). */ TwoDimensionalMap unaryBetas = new TwoDimensionalMap<>(); /** * The latest lexicon we trained. At the end of the process, this * is the lexicon for the parser. 
*/ Lexicon lex; transient Index tempWordIndex; transient Index tempTagIndex; /** * The lexicon we are in the process of building in each iteration. */ transient Lexicon tempLex; /** * The latest pair of unary and binary grammars we trained. */ Pair bgug; Random random = new Random(87543875943265L); static final double LEX_SMOOTH = 0.0001; static final double STATE_SMOOTH = 0.0; public SplittingGrammarExtractor(Options op) { this.op = op; startSymbols = Arrays.asList(op.langpack().startSymbols()); } double[] neginfDoubles(int size) { double[] result = new double[size]; for (int i = 0; i < size; ++i) { result[i] = Double.NEGATIVE_INFINITY; } return result; } public void outputTransitions(Tree tree, IdentityHashMap unaryTransitions, IdentityHashMap binaryTransitions) { outputTransitions(tree, 0, unaryTransitions, binaryTransitions); } public void outputTransitions(Tree tree, int depth, IdentityHashMap unaryTransitions, IdentityHashMap binaryTransitions) { for (int i = 0; i < depth; ++i) { System.out.print(" "); } if (tree.isLeaf()) { System.out.println(tree.label().value()); return; } if (tree.children().length == 1) { System.out.println(tree.label().value() + " -> " + tree.children()[0].label().value()); if (!tree.isPreTerminal()) { double[][] transitions = unaryTransitions.get(tree); for (int i = 0; i < transitions.length; ++i) { for (int j = 0; j < transitions[0].length; ++j) { for (int z = 0; z < depth; ++z) { System.out.print(" "); } System.out.println(" " + i + "," + j + ": " + transitions[i][j] + " | " + Math.exp(transitions[i][j])); } } } } else { System.out.println(tree.label().value() + " -> " + tree.children()[0].label().value() + " " + tree.children()[1].label().value()); double[][][] transitions = binaryTransitions.get(tree); for (int i = 0; i < transitions.length; ++i) { for (int j = 0; j < transitions[0].length; ++j) { for (int k = 0; k < transitions[0][0].length; ++k) { for (int z = 0; z < depth; ++z) { System.out.print(" "); } System.out.println(" " + 
i + "," + j + "," + k + ": " + transitions[i][j][k] + " | " + Math.exp(transitions[i][j][k])); } } } } if (tree.isPreTerminal()) { return; } for (Tree child : tree.children()) { outputTransitions(child, depth + 1, unaryTransitions, binaryTransitions); } } public void outputBetas() { System.out.println("UNARY:"); for (String parent : unaryBetas.firstKeySet()) { for (String child : unaryBetas.get(parent).keySet()) { System.out.println(" " + parent + "->" + child); double[][] betas = unaryBetas.get(parent).get(child); int parentStates = betas.length; int childStates = betas[0].length; for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < childStates; ++j) { System.out.println(" " + i + "->" + j + " " + betas[i][j] + " | " + Math.exp(betas[i][j])); } } } } System.out.println("BINARY:"); for (String parent : binaryBetas.firstKeySet()) { for (String left : binaryBetas.get(parent).firstKeySet()) { for (String right : binaryBetas.get(parent).get(left).keySet()) { System.out.println(" " + parent + "->" + left + "," + right); double[][][] betas = binaryBetas.get(parent).get(left).get(right); int parentStates = betas.length; int leftStates = betas[0].length; int rightStates = betas[0][0].length; for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { System.out.println(" " + i + "->" + j + "," + k + " " + betas[i][j][k] + " | " + Math.exp(betas[i][j][k])); } } } } } } } public String state(String tag, int i) { if (startSymbols.contains(tag) || tag.equals(Lexicon.BOUNDARY_TAG)) { return tag; } return tag + "^" + i; } public int getStateSplitCount(Tree tree) { return stateSplitCounts.getIntCount(tree.label().value()); } public int getStateSplitCount(String label) { return stateSplitCounts.getIntCount(label); } /** * Count all the internal labels in all the trees, and set their * initial state counts to 1. 
*/ public void countOriginalStates() { originalStates.clear(); for (Tree tree : trees) { countOriginalStates(tree); } for (String state : originalStates) { stateSplitCounts.incrementCount(state, 1); } } /** * Counts the labels in the tree, but not the words themselves. */ private void countOriginalStates(Tree tree) { if (tree.isLeaf()) { return; } originalStates.add(tree.label().value()); for (Tree child : tree.children()) { if (child.isLeaf()) continue; countOriginalStates(child); } } private void initialBetasAndLexicon() { wordIndex = new HashIndex<>(); tagIndex = new HashIndex<>(); lex = op.tlpParams.lex(op, wordIndex, tagIndex); lex.initializeTraining(trainSize); for (Tree tree : trees) { double weight = treeWeights.getCount(tree); lex.incrementTreesRead(weight); initialBetasAndLexicon(tree, 0, weight); } lex.finishTraining(); } private int initialBetasAndLexicon(Tree tree, int position, double weight) { if (tree.isLeaf()) { // should never get here, unless a training tree is just one leaf return position; } if (tree.isPreTerminal()) { // fill in initial lexicon here String tag = tree.label().value(); String word = tree.children()[0].label().value(); TaggedWord tw = new TaggedWord(word, state(tag, 0)); lex.train(tw, position, weight); return (position + 1); } if (tree.children().length == 2) { String label = tree.label().value(); String leftLabel = tree.getChild(0).label().value(); String rightLabel = tree.getChild(1).label().value(); if (!binaryBetas.contains(label, leftLabel, rightLabel)) { double[][][] map = new double[1][1][1]; map[0][0][0] = 0.0; binaryBetas.put(label, leftLabel, rightLabel, map); } } else if (tree.children().length == 1) { String label = tree.label().value(); String childLabel = tree.getChild(0).label().value(); if (!unaryBetas.contains(label, childLabel)) { double[][] map = new double[1][1]; map[0][0] = 0.0; unaryBetas.put(label, childLabel, map); } } else { // should have been binarized throw new RuntimeException("Trees should have been 
binarized, expected 1 or 2 children"); } for (Tree child : tree.children()) { position = initialBetasAndLexicon(child, position, weight); } return position; } /** * Splits the state counts. Root states and the boundary tag do not * get their counts increased, and all others are doubled. Betas * and transition weights are handled later. */ private void splitStateCounts() { // double the count of states... IntCounter newStateSplitCounts = new IntCounter<>(); newStateSplitCounts.addAll(stateSplitCounts); newStateSplitCounts.addAll(stateSplitCounts); // root states should only have 1 for (String root : startSymbols) { if (newStateSplitCounts.getCount(root) > 1) { newStateSplitCounts.setCount(root, 1); } } if (newStateSplitCounts.getCount(Lexicon.BOUNDARY_TAG) > 1) { newStateSplitCounts.setCount(Lexicon.BOUNDARY_TAG, 1); } stateSplitCounts = newStateSplitCounts; } static final double EPSILON = 0.0001; /** * Before each iteration of splitting states, we have tables of * betas which correspond to the transitions between different * substates. When we resplit the states, we duplicate parent * states and then split their transitions 50/50 with some random * variation between child states. 
*/
  public void splitBetas() {
    TwoDimensionalMap tempUnaryBetas = new TwoDimensionalMap<>();
    ThreeDimensionalMap tempBinaryBetas = new ThreeDimensionalMap<>();

    for (String parent : unaryBetas.firstKeySet()) {
      for (String child : unaryBetas.get(parent).keySet()) {
        double[][] betas = unaryBetas.get(parent, child);
        if (isSplittableLabel(parent)) {
          betas = duplicateUnaryParentStates(betas);
        }
        if (isSplittableLabel(child)) {
          betas = splitUnaryChildStates(betas);
        }
        tempUnaryBetas.put(parent, child, betas);
      }
    }

    for (String parent : binaryBetas.firstKeySet()) {
      for (String left : binaryBetas.get(parent).firstKeySet()) {
        for (String right : binaryBetas.get(parent).get(left).keySet()) {
          double[][][] betas = binaryBetas.get(parent, left, right);
          if (isSplittableLabel(parent)) {
            betas = duplicateBinaryParentStates(betas);
          }
          // The left split must draw its random weights before the right
          // split, so the sequence of draws matches earlier behaviour.
          if (isSplittableLabel(left)) {
            betas = splitBinaryLeftStates(betas);
          }
          if (isSplittableLabel(right)) {
            betas = splitBinaryRightStates(betas);
          }
          tempBinaryBetas.put(parent, left, right, betas);
        }
      }
    }

    unaryBetas = tempUnaryBetas;
    binaryBetas = tempBinaryBetas;
  }

  /**
   * Whether a label's substates are doubled on each split.  This uses the
   * same rule as splitStateCounts and state(): start symbols and the
   * boundary tag always keep a single substate.  Applying it to every
   * parent and child keeps the beta dimensions in sync with the split
   * counts.
   */
  private boolean isSplittableLabel(String label) {
    return !startSymbols.contains(label) && !label.equals(Lexicon.BOUNDARY_TAG);
  }

  /** Draws the share of a split given to the first of the two new substates. */
  private double randomSplitWeight() {
    return 0.45 + random.nextDouble() * 0.1;
  }

  /** Returns a copy of {@code betas} with each parent row duplicated (row i becomes rows 2i and 2i+1). */
  private static double[][] duplicateUnaryParentStates(double[][] betas) {
    double[][] result = new double[betas.length * 2][];
    for (int i = 0; i < betas.length; ++i) {
      result[i * 2] = betas[i].clone();
      result[i * 2 + 1] = betas[i].clone();
    }
    return result;
  }

  /** Returns a copy of {@code betas} with each parent slice duplicated (slice i becomes 2i and 2i+1). */
  private static double[][][] duplicateBinaryParentStates(double[][][] betas) {
    double[][][] result = new double[betas.length * 2][][];
    for (int i = 0; i < betas.length; ++i) {
      for (int copy = 0; copy < 2; ++copy) {
        double[][] slice = new double[betas[i].length][];
        for (int j = 0; j < betas[i].length; ++j) {
          slice[j] = betas[i][j].clone();
        }
        result[i * 2 + copy] = slice;
      }
    }
    return result;
  }

  /** Splits each unary child state j into 2j and 2j+1, dividing its log weight by a random share. */
  private double[][] splitUnaryChildStates(double[][] betas) {
    int parentStates = betas.length;
    int childStates = betas[0].length;
    double[][] result = new double[parentStates][childStates * 2];
    for (int i = 0; i < parentStates; ++i) {
      for (int j = 0; j < childStates; ++j) {
        double childWeight = randomSplitWeight();
        result[i][j * 2] = betas[i][j] + Math.log(childWeight);
        result[i][j * 2 + 1] = betas[i][j] + Math.log(1.0 - childWeight);
      }
    }
    return result;
  }

  /** Splits each binary left-child state j into 2j and 2j+1, dividing its log weight by a random share. */
  private double[][][] splitBinaryLeftStates(double[][][] betas) {
    int parentStates = betas.length;
    int leftStates = betas[0].length;
    int rightStates = betas[0][0].length;
    double[][][] result = new double[parentStates][leftStates * 2][rightStates];
    for (int i = 0; i < parentStates; ++i) {
      for (int j = 0; j < leftStates; ++j) {
        for (int k = 0; k < rightStates; ++k) {
          double leftWeight = randomSplitWeight();
          result[i][j * 2][k] = betas[i][j][k] + Math.log(leftWeight);
          result[i][j * 2 + 1][k] = betas[i][j][k] + Math.log(1 - leftWeight);
        }
      }
    }
    return result;
  }

  /** Splits each binary right-child state k into 2k and 2k+1, dividing its log weight by a random share. */
  private double[][][] splitBinaryRightStates(double[][][] betas) {
    int parentStates = betas.length;
    int leftStates = betas[0].length;
    int rightStates = betas[0][0].length;
    double[][][] result = new double[parentStates][leftStates][rightStates * 2];
    for (int i = 0; i < parentStates; ++i) {
      for (int j = 0; j < leftStates; ++j) {
        for (int k = 0; k < rightStates; ++k) {
          double rightWeight = randomSplitWeight();
          result[i][j][k * 2] = betas[i][j][k] + Math.log(rightWeight);
          result[i][j][k * 2 + 1] = betas[i][j][k] + Math.log(1 - rightWeight);
        }
      }
    }
    return result;
  }

  /**
   * Recalculates the betas for all known transitions.  The current
   * betas are used to produce probabilities, which then are used to
   * compute new betas.  If splitStates is true, then the
   * probabilities produced are as if the states were split again from
   * the last time betas were calculated.
   *
* The return value is whether or not the betas have mostly * converged from the last time this method was called. Obviously * if splitStates was true, the betas will be entirely different, so * this is false. Otherwise, the new betas are compared against the * old values, and convergence means they differ by less than * EPSILON. */ public boolean recalculateBetas(boolean splitStates) { if (splitStates) { if (DEBUG()) { System.out.println("Pre-split betas"); outputBetas(); } splitBetas(); if (DEBUG()) { System.out.println("Post-split betas"); outputBetas(); } } TwoDimensionalMap tempUnaryBetas = new TwoDimensionalMap<>(); ThreeDimensionalMap tempBinaryBetas = new ThreeDimensionalMap<>(); recalculateTemporaryBetas(splitStates, null, tempUnaryBetas, tempBinaryBetas); boolean converged = useNewBetas(!splitStates, tempUnaryBetas, tempBinaryBetas); if (DEBUG()) { outputBetas(); } return converged; } public boolean useNewBetas(boolean testConverged, TwoDimensionalMap tempUnaryBetas, ThreeDimensionalMap tempBinaryBetas) { rescaleTemporaryBetas(tempUnaryBetas, tempBinaryBetas); // if we just split states, we have obviously not converged boolean converged = testConverged && testConvergence(tempUnaryBetas, tempBinaryBetas); unaryBetas = tempUnaryBetas; binaryBetas = tempBinaryBetas; wordIndex = tempWordIndex; tagIndex = tempTagIndex; lex = tempLex; if (DEBUG()) { System.out.println("LEXICON"); try { OutputStreamWriter osw = new OutputStreamWriter(System.out, "utf-8"); lex.writeData(osw); osw.flush(); } catch (IOException e) { throw new RuntimeIOException(e); } } tempWordIndex = null; tempTagIndex = null; tempLex = null; return converged; } /** * Creates temporary beta data structures and fills them in by * iterating over the trees. 
*/ public void recalculateTemporaryBetas(boolean splitStates, Map totalStateMass, TwoDimensionalMap tempUnaryBetas, ThreeDimensionalMap tempBinaryBetas) { tempWordIndex = new HashIndex<>(); tempTagIndex = new HashIndex<>(); tempLex = op.tlpParams.lex(op, tempWordIndex, tempTagIndex); tempLex.initializeTraining(trainSize); for (Tree tree : trees) { double weight = treeWeights.getCount(tree); if (DEBUG()) { System.out.println("Incrementing trees read: " + weight); } tempLex.incrementTreesRead(weight); recalculateTemporaryBetas(tree, splitStates, totalStateMass, tempUnaryBetas, tempBinaryBetas); } tempLex.finishTraining(); } public boolean testConvergence(TwoDimensionalMap tempUnaryBetas, ThreeDimensionalMap tempBinaryBetas) { // now, we check each of the new betas to see if it's close to the // old value for the same transition. if not, we have not yet // converged. if all of them are, we have converged. for (String parentLabel : unaryBetas.firstKeySet()) { for (String childLabel : unaryBetas.get(parentLabel).keySet()) { double[][] betas = unaryBetas.get(parentLabel, childLabel); double[][] newBetas = tempUnaryBetas.get(parentLabel, childLabel); int parentStates = betas.length; int childStates = betas[0].length; for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < childStates; ++j) { double oldValue = betas[i][j]; double newValue = newBetas[i][j]; if (Math.abs(newValue - oldValue) > EPSILON) { return false; } } } } } for (String parentLabel : binaryBetas.firstKeySet()) { for (String leftLabel : binaryBetas.get(parentLabel).firstKeySet()) { for (String rightLabel : binaryBetas.get(parentLabel).get(leftLabel).keySet()) { double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel); double[][][] newBetas = tempBinaryBetas.get(parentLabel, leftLabel, rightLabel); int parentStates = betas.length; int leftStates = betas[0].length; int rightStates = betas[0][0].length; for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < leftStates; ++j) { for 
(int k = 0; k < rightStates; ++k) { double oldValue = betas[i][j][k]; double newValue = newBetas[i][j][k]; if (Math.abs(newValue - oldValue) > EPSILON) { return false; } } } } } } } return true; } public void recalculateTemporaryBetas(Tree tree, boolean splitStates, Map totalStateMass, TwoDimensionalMap tempUnaryBetas, ThreeDimensionalMap tempBinaryBetas) { if (DEBUG()) { System.out.println("Recalculating temporary betas for tree " + tree); } double[] stateWeights = { Math.log(treeWeights.getCount(tree)) }; IdentityHashMap unaryTransitions = new IdentityHashMap<>(); IdentityHashMap binaryTransitions = new IdentityHashMap<>(); recountTree(tree, splitStates, unaryTransitions, binaryTransitions); if (DEBUG()) { System.out.println(" Transitions:"); outputTransitions(tree, unaryTransitions, binaryTransitions); } recalculateTemporaryBetas(tree, stateWeights, 0, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas); } public int recalculateTemporaryBetas(Tree tree, double[] stateWeights, int position, IdentityHashMap unaryTransitions, IdentityHashMap binaryTransitions, Map totalStateMass, TwoDimensionalMap tempUnaryBetas, ThreeDimensionalMap tempBinaryBetas) { if (tree.isLeaf()) { // possible to get here if we have a tree with no structure return position; } if (totalStateMass != null) { double[] stateTotal = totalStateMass.get(tree.label().value()); if (stateTotal == null) { stateTotal = new double[stateWeights.length]; totalStateMass.put(tree.label().value(), stateTotal); } for (int i = 0; i < stateWeights.length; ++i) { stateTotal[i] += Math.exp(stateWeights[i]); } } if (tree.isPreTerminal()) { // fill in our new lexicon here. String tag = tree.label().value(); String word = tree.children()[0].label().value(); // We smooth by LEX_SMOOTH, if relevant. We rescale so that sum // of the weights being added to the lexicon stays the same. 
double total = 0.0; for (double stateWeight : stateWeights) { total += Math.exp(stateWeight); } if (total <= 0.0) { return position + 1; } double scale = 1.0 / (1.0 + LEX_SMOOTH); double smoothing = total * LEX_SMOOTH / stateWeights.length; for (int state = 0; state < stateWeights.length; ++state) { // TODO: maybe optimize all this TaggedWord creation TaggedWord tw = new TaggedWord(word, state(tag, state)); tempLex.train(tw, position, (Math.exp(stateWeights[state]) + smoothing) * scale); } return position + 1; } if (tree.children().length == 1) { String parentLabel = tree.label().value(); String childLabel = tree.children()[0].label().value(); double[][] transitions = unaryTransitions.get(tree); int parentStates = transitions.length; int childStates = transitions[0].length; double[][] betas = tempUnaryBetas.get(parentLabel, childLabel); if (betas == null) { betas = new double[parentStates][childStates]; for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < childStates; ++j) { betas[i][j] = Double.NEGATIVE_INFINITY; } } tempUnaryBetas.put(parentLabel, childLabel, betas); } double[] childWeights = neginfDoubles(childStates); for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < childStates; ++j) { double weight = transitions[i][j]; betas[i][j] = SloppyMath.logAdd(betas[i][j], weight + stateWeights[i]); childWeights[j] = SloppyMath.logAdd(childWeights[j], weight + stateWeights[i]); } } position = recalculateTemporaryBetas(tree.children()[0], childWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas); } else { // length == 2 String parentLabel = tree.label().value(); String leftLabel = tree.children()[0].label().value(); String rightLabel = tree.children()[1].label().value(); double[][][] transitions = binaryTransitions.get(tree); int parentStates = transitions.length; int leftStates = transitions[0].length; int rightStates = transitions[0][0].length; double[][][] betas = tempBinaryBetas.get(parentLabel, 
leftLabel, rightLabel); if (betas == null) { betas = new double[parentStates][leftStates][rightStates]; for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { betas[i][j][k] = Double.NEGATIVE_INFINITY; } } } tempBinaryBetas.put(parentLabel, leftLabel, rightLabel, betas); } double[] leftWeights = neginfDoubles(leftStates); double[] rightWeights = neginfDoubles(rightStates); for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { double weight = transitions[i][j][k]; betas[i][j][k] = SloppyMath.logAdd(betas[i][j][k], weight + stateWeights[i]); leftWeights[j] = SloppyMath.logAdd(leftWeights[j], weight + stateWeights[i]); rightWeights[k] = SloppyMath.logAdd(rightWeights[k], weight + stateWeights[i]); } } } position = recalculateTemporaryBetas(tree.children()[0], leftWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas); position = recalculateTemporaryBetas(tree.children()[1], rightWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas); } return position; } public void rescaleTemporaryBetas(TwoDimensionalMap tempUnaryBetas, ThreeDimensionalMap tempBinaryBetas) { for (String parent : tempUnaryBetas.firstKeySet()) { for (String child : tempUnaryBetas.get(parent).keySet()) { double[][] betas = tempUnaryBetas.get(parent).get(child); int parentStates = betas.length; int childStates = betas[0].length; for (int i = 0; i < parentStates; ++i) { double sum = Double.NEGATIVE_INFINITY; for (int j = 0; j < childStates; ++j) { sum = SloppyMath.logAdd(sum, betas[i][j]); } if (Double.isInfinite(sum)) { for (int j = 0; j < childStates; ++j) { betas[i][j] = -Math.log(childStates); } } else { for (int j = 0; j < childStates; ++j) { betas[i][j] -= sum; } } } } } for (String parent : tempBinaryBetas.firstKeySet()) { for (String left : 
tempBinaryBetas.get(parent).firstKeySet()) { for (String right : tempBinaryBetas.get(parent).get(left).keySet()) { double[][][] betas = tempBinaryBetas.get(parent).get(left).get(right); int parentStates = betas.length; int leftStates = betas[0].length; int rightStates = betas[0][0].length; for (int i = 0; i < parentStates; ++i) { double sum = Double.NEGATIVE_INFINITY; for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { sum = SloppyMath.logAdd(sum, betas[i][j][k]); } } if (Double.isInfinite(sum)) { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { betas[i][j][k] = -Math.log(leftStates * rightStates); } } } else { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { betas[i][j][k] -= sum; } } } } } } } } public void recountTree(Tree tree, boolean splitStates, IdentityHashMap unaryTransitions, IdentityHashMap binaryTransitions) { IdentityHashMap probIn = new IdentityHashMap<>(); IdentityHashMap probOut = new IdentityHashMap<>(); recountTree(tree, splitStates, probIn, probOut, unaryTransitions, binaryTransitions); } public void recountTree(Tree tree, boolean splitStates, IdentityHashMap probIn, IdentityHashMap probOut, IdentityHashMap unaryTransitions, IdentityHashMap binaryTransitions) { recountInside(tree, splitStates, 0, probIn); if (DEBUG()) { System.out.println("ROOT PROBABILITY: " + probIn.get(tree)[0]); } recountOutside(tree, probIn, probOut); recountWeights(tree, probIn, probOut, unaryTransitions, binaryTransitions); } public void recountWeights(Tree tree, IdentityHashMap probIn, IdentityHashMap probOut, IdentityHashMap unaryTransitions, IdentityHashMap binaryTransitions) { if (tree.isLeaf() || tree.isPreTerminal()) { return; } if (tree.children().length == 1) { Tree child = tree.children()[0]; String parentLabel = tree.label().value(); String childLabel = child.label().value(); double[][] betas = unaryBetas.get(parentLabel, childLabel); double[] childInside = probIn.get(child); 
double[] parentOutside = probOut.get(tree); int parentStates = betas.length; int childStates = betas[0].length; double[][] transitions = new double[parentStates][childStates]; unaryTransitions.put(tree, transitions); for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < childStates; ++j) { transitions[i][j] = parentOutside[i] + childInside[j] + betas[i][j]; } } // Renormalize. Note that we renormalize to 1, regardless of // the original total. // TODO: smoothing? for (int i = 0; i < parentStates; ++i) { double total = Double.NEGATIVE_INFINITY; for (int j = 0; j < childStates; ++j) { total = SloppyMath.logAdd(total, transitions[i][j]); } // By subtracting off the log total, we make it so the log sum // of the transitions is 0, meaning the sum of the actual // transitions is 1. It works if you do the math... if (Double.isInfinite(total)) { double transition = -Math.log(childStates); for (int j = 0; j < childStates; ++j) { transitions[i][j] = transition; } } else { for (int j = 0; j < childStates; ++j) { transitions[i][j] = transitions[i][j] - total; } } } recountWeights(child, probIn, probOut, unaryTransitions, binaryTransitions); } else { // length == 2 Tree left = tree.children()[0]; Tree right = tree.children()[1]; String parentLabel = tree.label().value(); String leftLabel = left.label().value(); String rightLabel = right.label().value(); double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel); double[] leftInside = probIn.get(left); double[] rightInside = probIn.get(right); double[] parentOutside = probOut.get(tree); int parentStates = betas.length; int leftStates = betas[0].length; int rightStates = betas[0][0].length; double[][][] transitions = new double[parentStates][leftStates][rightStates]; binaryTransitions.put(tree, transitions); for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { transitions[i][j][k] = parentOutside[i] + leftInside[j] + rightInside[k] + 
betas[i][j][k]; } } } // Renormalize. Note that we renormalize to 1, regardless of // the original total. // TODO: smoothing? for (int i = 0; i < parentStates; ++i) { double total = Double.NEGATIVE_INFINITY; for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { total = SloppyMath.logAdd(total, transitions[i][j][k]); } } // By subtracting off the log total, we make it so the log sum // of the transitions is 0, meaning the sum of the actual // transitions is 1. It works if you do the math... if (Double.isInfinite(total)) { double transition = -Math.log(leftStates * rightStates); for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { transitions[i][j][k] = transition; } } } else { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { transitions[i][j][k] = transitions[i][j][k] - total; } } } } recountWeights(left, probIn, probOut, unaryTransitions, binaryTransitions); recountWeights(right, probIn, probOut, unaryTransitions, binaryTransitions); } } public void recountOutside(Tree tree, IdentityHashMap probIn, IdentityHashMap probOut) { double[] rootScores = { 0.0 }; probOut.put(tree, rootScores); recurseOutside(tree, probIn, probOut); } public void recurseOutside(Tree tree, IdentityHashMap probIn, IdentityHashMap probOut) { if (tree.isLeaf() || tree.isPreTerminal()) { return; } if (tree.children().length == 1) { recountOutside(tree.children()[0], tree, probIn, probOut); } else { // length == 2 recountOutside(tree.children()[0], tree.children()[1], tree, probIn, probOut); } } public void recountOutside(Tree child, Tree parent, IdentityHashMap probIn, IdentityHashMap probOut) { String parentLabel = parent.label().value(); String childLabel = child.label().value(); double[] parentScores = probOut.get(parent); double[][] betas = unaryBetas.get(parentLabel, childLabel); int parentStates = betas.length; int childStates = betas[0].length; double[] scores = neginfDoubles(childStates); probOut.put(child, 
scores); for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < childStates; ++j) { // TODO: no inside scores here, right? scores[j] = SloppyMath.logAdd(scores[j], betas[i][j] + parentScores[i]); } } recurseOutside(child, probIn, probOut); } public void recountOutside(Tree left, Tree right, Tree parent, IdentityHashMap probIn, IdentityHashMap probOut) { String parentLabel = parent.label().value(); String leftLabel = left.label().value(); String rightLabel = right.label().value(); double[] leftInsideScores = probIn.get(left); double[] rightInsideScores = probIn.get(right); double[] parentScores = probOut.get(parent); double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel); int parentStates = betas.length; int leftStates = betas[0].length; int rightStates = betas[0][0].length; double[] leftScores = neginfDoubles(leftStates); probOut.put(left, leftScores); double[] rightScores = neginfDoubles(rightStates); probOut.put(right, rightScores); for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { leftScores[j] = SloppyMath.logAdd(leftScores[j], betas[i][j][k] + parentScores[i] + rightInsideScores[k]); rightScores[k] = SloppyMath.logAdd(rightScores[k], betas[i][j][k] + parentScores[i] + leftInsideScores[j]); } } } recurseOutside(left, probIn, probOut); recurseOutside(right, probIn, probOut); } public int recountInside(Tree tree, boolean splitStates, int loc, IdentityHashMap probIn) { if (tree.isLeaf()) { throw new RuntimeException(); } else if (tree.isPreTerminal()) { int stateCount = getStateSplitCount(tree); String word = tree.children()[0].label().value(); String tag = tree.label().value(); double[] scores = new double[stateCount]; probIn.put(tree, scores); if (splitStates && !tag.equals(Lexicon.BOUNDARY_TAG)) { for (int i = 0; i < stateCount / 2; ++i) { IntTaggedWord tw = new IntTaggedWord(word, state(tag, i), wordIndex, tagIndex); double logProb = lex.score(tw, loc, word, null); 
double wordWeight = 0.45 + random.nextDouble() * 0.1; scores[i * 2] = logProb + Math.log(wordWeight); scores[i * 2 + 1] = logProb + Math.log(1.0 - wordWeight); if (DEBUG()) { System.out.println("Lexicon log prob " + state(tag, i) + "-" + word + ": " + logProb); System.out.println(" Log Split -> " + scores[i * 2] + "," + scores[i * 2 + 1]); } } } else { for (int i = 0; i < stateCount; ++i) { IntTaggedWord tw = new IntTaggedWord(word, state(tag, i), wordIndex, tagIndex); double prob = lex.score(tw, loc, word, null); if (DEBUG()) { System.out.println("Lexicon log prob " + state(tag, i) + "-" + word + ": " + prob); } scores[i] = prob; } } loc = loc + 1; } else if (tree.children().length == 1) { loc = recountInside(tree.children()[0], splitStates, loc, probIn); double[] childScores = probIn.get(tree.children()[0]); String parentLabel = tree.label().value(); String childLabel = tree.children()[0].label().value(); double[][] betas = unaryBetas.get(parentLabel, childLabel); int parentStates = betas.length; // size of the first key int childStates = betas[0].length; double[] scores = neginfDoubles(parentStates); probIn.put(tree, scores); for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < childStates; ++j) { scores[i] = SloppyMath.logAdd(scores[i], childScores[j] + betas[i][j]); } } if (DEBUG()) { System.out.println(parentLabel + " -> " + childLabel); for (int i = 0; i < parentStates; ++i) { System.out.println(" " + i + ":" + scores[i]); for (int j = 0; j < childStates; ++j) { System.out.println(" " + i + "," + j + ": " + betas[i][j] + " | " + Math.exp(betas[i][j])); } } } } else { // length == 2 loc = recountInside(tree.children()[0], splitStates, loc, probIn); loc = recountInside(tree.children()[1], splitStates, loc, probIn); double[] leftScores = probIn.get(tree.children()[0]); double[] rightScores = probIn.get(tree.children()[1]); String parentLabel = tree.label().value(); String leftLabel = tree.children()[0].label().value(); String rightLabel = 
tree.children()[1].label().value(); double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel); int parentStates = betas.length; int leftStates = betas[0].length; int rightStates = betas[0][0].length; double[] scores = neginfDoubles(parentStates); probIn.put(tree, scores); for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { scores[i] = SloppyMath.logAdd(scores[i], leftScores[j] + rightScores[k] + betas[i][j][k]); } } } if (DEBUG()) { System.out.println(parentLabel + " -> " + leftLabel + "," + rightLabel); for (int i = 0; i < parentStates; ++i) { System.out.println(" " + i + ":" + scores[i]); for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { System.out.println(" " + i + "," + j + "," + k + ": " + betas[i][j][k] + " | " + Math.exp(betas[i][j][k])); } } } } } return loc; } public void mergeStates() { if (op.trainOptions.splitRecombineRate <= 0.0) { return; } // we go through the machinery to sum up the temporary betas, // counting the total mass TwoDimensionalMap tempUnaryBetas = new TwoDimensionalMap<>(); ThreeDimensionalMap tempBinaryBetas = new ThreeDimensionalMap<>(); Map totalStateMass = Generics.newHashMap(); recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas); // Next, for each tree we count the effect of merging its // annotations. We only consider the most recently split // annotations as candidates for merging. Map deltaAnnotations = Generics.newHashMap(); for (Tree tree : trees) { countMergeEffects(tree, totalStateMass, deltaAnnotations); } // Now we have a map of the (approximate) likelihood loss from // merging each state. 
We merge the ones that provide the least // benefit, up to the splitRecombineRate List> sortedDeltas = new ArrayList<>(); for (String state : deltaAnnotations.keySet()) { double[] scores = deltaAnnotations.get(state); for (int i = 0; i < scores.length; ++i) { sortedDeltas.add(new Triple<>(state, i * 2, scores[i])); } } Collections.sort(sortedDeltas, new Comparator>() { public int compare(Triple first, Triple second) { // The most useful splits will have a large loss in // likelihood if they are merged. Thus, we want those at // the end of the list. This means we make the comparison // "backwards", sorting from high to low. return Double.compare(second.third(), first.third()); } public boolean equals(Object o) { return o == this; } }); // for (Triple delta : sortedDeltas) { // System.out.println(delta.first() + "-" + delta.second() + ": " + delta.third()); // } // System.out.println("-------------"); // Only merge a fraction of the splits based on what the user // originally asked for int splitsToMerge = (int) (sortedDeltas.size() * op.trainOptions.splitRecombineRate); splitsToMerge = Math.max(0, splitsToMerge); splitsToMerge = Math.min(sortedDeltas.size() - 1, splitsToMerge); sortedDeltas = sortedDeltas.subList(0, splitsToMerge); System.out.println(); System.out.println(sortedDeltas); Map mergeCorrespondence = buildMergeCorrespondence(sortedDeltas); recalculateMergedBetas(mergeCorrespondence); for (Triple delta : sortedDeltas) { stateSplitCounts.decrementCount(delta.first(), 1); } } public void recalculateMergedBetas(Map mergeCorrespondence) { TwoDimensionalMap tempUnaryBetas = new TwoDimensionalMap<>(); ThreeDimensionalMap tempBinaryBetas = new ThreeDimensionalMap<>(); tempWordIndex = new HashIndex<>(); tempTagIndex = new HashIndex<>(); tempLex = op.tlpParams.lex(op, tempWordIndex, tempTagIndex); tempLex.initializeTraining(trainSize); for (Tree tree : trees) { double treeWeight = treeWeights.getCount(tree); double[] stateWeights = { Math.log(treeWeight) }; 
tempLex.incrementTreesRead(treeWeight); IdentityHashMap oldUnaryTransitions = new IdentityHashMap<>(); IdentityHashMap oldBinaryTransitions = new IdentityHashMap<>(); recountTree(tree, false, oldUnaryTransitions, oldBinaryTransitions); IdentityHashMap unaryTransitions = new IdentityHashMap<>(); IdentityHashMap binaryTransitions = new IdentityHashMap<>(); mergeTransitions(tree, oldUnaryTransitions, oldBinaryTransitions, unaryTransitions, binaryTransitions, stateWeights, mergeCorrespondence); recalculateTemporaryBetas(tree, stateWeights, 0, unaryTransitions, binaryTransitions, null, tempUnaryBetas, tempBinaryBetas); } tempLex.finishTraining(); useNewBetas(false, tempUnaryBetas, tempBinaryBetas); } /** * Given a tree and the original set of transition probabilities * from one state to the next in the tree, along with a list of the * weights in the tree and a count of the mass in each substate at * the current node, this method merges the probabilities as * necessary. The results go into newUnaryTransitions and * newBinaryTransitions. 
*/ public void mergeTransitions(Tree parent, IdentityHashMap oldUnaryTransitions, IdentityHashMap oldBinaryTransitions, IdentityHashMap newUnaryTransitions, IdentityHashMap newBinaryTransitions, double[] stateWeights, Map mergeCorrespondence) { if (parent.isPreTerminal() || parent.isLeaf()) { return; } if (parent.children().length == 1) { double[][] oldTransitions = oldUnaryTransitions.get(parent); String parentLabel = parent.label().value(); int[] parentCorrespondence = mergeCorrespondence.get(parentLabel); int parentStates = parentCorrespondence[parentCorrespondence.length - 1] + 1; String childLabel = parent.children()[0].label().value(); int[] childCorrespondence = mergeCorrespondence.get(childLabel); int childStates = childCorrespondence[childCorrespondence.length - 1] + 1; // System.out.println("P: " + parentLabel + " " + parentStates + // " C: " + childLabel + " " + childStates); // Add up the probabilities of transitioning to each state, // scaled by the probability of being in a given state to begin // with. This accounts for when two states in the parent are // collapsed into one state. 
double[][] newTransitions = new double[parentStates][childStates]; for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < childStates; ++j) { newTransitions[i][j] = Double.NEGATIVE_INFINITY; } } newUnaryTransitions.put(parent, newTransitions); for (int i = 0; i < oldTransitions.length; ++i) { int ti = parentCorrespondence[i]; for (int j = 0; j < oldTransitions[0].length; ++j) { int tj = childCorrespondence[j]; // System.out.println(i + " " + ti + " " + j + " " + tj); newTransitions[ti][tj] = SloppyMath.logAdd(newTransitions[ti][tj], oldTransitions[i][j] + stateWeights[i]); } } // renormalize for (int i = 0; i < parentStates; ++i) { double total = Double.NEGATIVE_INFINITY; for (int j = 0; j < childStates; ++j) { total = SloppyMath.logAdd(total, newTransitions[i][j]); } if (Double.isInfinite(total)) { for (int j = 0; j < childStates; ++j) { newTransitions[i][j] = -Math.log(childStates); } } else { for (int j = 0; j < childStates; ++j) { newTransitions[i][j] -= total; } } } double[] childWeights = neginfDoubles(oldTransitions[0].length); for (int i = 0; i < oldTransitions.length; ++i) { for (int j = 0; j < oldTransitions[0].length; ++j) { double weight = oldTransitions[i][j]; childWeights[j] = SloppyMath.logAdd(childWeights[j], weight + stateWeights[i]); } } mergeTransitions(parent.children()[0], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, childWeights, mergeCorrespondence); } else { double[][][] oldTransitions = oldBinaryTransitions.get(parent); String parentLabel = parent.label().value(); int[] parentCorrespondence = mergeCorrespondence.get(parentLabel); int parentStates = parentCorrespondence[parentCorrespondence.length - 1] + 1; String leftLabel = parent.children()[0].label().value(); int[] leftCorrespondence = mergeCorrespondence.get(leftLabel); int leftStates = leftCorrespondence[leftCorrespondence.length - 1] + 1; String rightLabel = parent.children()[1].label().value(); int[] rightCorrespondence = 
mergeCorrespondence.get(rightLabel); int rightStates = rightCorrespondence[rightCorrespondence.length - 1] + 1; // System.out.println("P: " + parentLabel + " " + parentStates + // " L: " + leftLabel + " " + leftStates + // " R: " + rightLabel + " " + rightStates); double[][][] newTransitions = new double[parentStates][leftStates][rightStates]; for (int i = 0; i < parentStates; ++i) { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { newTransitions[i][j][k] = Double.NEGATIVE_INFINITY; } } } newBinaryTransitions.put(parent, newTransitions); for (int i = 0; i < oldTransitions.length; ++i) { int ti = parentCorrespondence[i]; for (int j = 0; j < oldTransitions[0].length; ++j) { int tj = leftCorrespondence[j]; for (int k = 0; k < oldTransitions[0][0].length; ++k) { int tk = rightCorrespondence[k]; // System.out.println(i + " " + ti + " " + j + " " + tj + " " + k + " " + tk); newTransitions[ti][tj][tk] = SloppyMath.logAdd(newTransitions[ti][tj][tk], oldTransitions[i][j][k] + stateWeights[i]); } } } // renormalize for (int i = 0; i < parentStates; ++i) { double total = Double.NEGATIVE_INFINITY; for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { total = SloppyMath.logAdd(total, newTransitions[i][j][k]); } } if (Double.isInfinite(total)) { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { newTransitions[i][j][k] = -Math.log(leftStates * rightStates); } } } else { for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { newTransitions[i][j][k] -= total; } } } } double[] leftWeights = neginfDoubles(oldTransitions[0].length); double[] rightWeights = neginfDoubles(oldTransitions[0][0].length); for (int i = 0; i < oldTransitions.length; ++i) { for (int j = 0; j < oldTransitions[0].length; ++j) { for (int k = 0; k < oldTransitions[0][0].length; ++k) { double weight = oldTransitions[i][j][k]; leftWeights[j] = SloppyMath.logAdd(leftWeights[j], weight + stateWeights[i]); 
rightWeights[k] = SloppyMath.logAdd(rightWeights[k], weight + stateWeights[i]); } } } mergeTransitions(parent.children()[0], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, leftWeights, mergeCorrespondence); mergeTransitions(parent.children()[1], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, rightWeights, mergeCorrespondence); } } Map buildMergeCorrespondence(List> deltas) { Map mergeCorrespondence = Generics.newHashMap(); for (String state : originalStates) { int states = getStateSplitCount(state); int[] correspondence = new int[states]; for (int i = 0; i < states; ++i) { correspondence[i] = i; } mergeCorrespondence.put(state, correspondence); } for (Triple merge : deltas) { int states = getStateSplitCount(merge.first()); int split = merge.second(); int[] correspondence = mergeCorrespondence.get(merge.first()); for (int i = split + 1; i < states; ++i) { correspondence[i] = correspondence[i] - 1; } } return mergeCorrespondence; } public void countMergeEffects(Tree tree, Map totalStateMass, Map deltaAnnotations) { IdentityHashMap probIn = new IdentityHashMap<>(); IdentityHashMap probOut = new IdentityHashMap<>(); IdentityHashMap unaryTransitions = new IdentityHashMap<>(); IdentityHashMap binaryTransitions = new IdentityHashMap<>(); recountTree(tree, false, probIn, probOut, unaryTransitions, binaryTransitions); // no need to count the root for (Tree child : tree.children()) { countMergeEffects(child, totalStateMass, deltaAnnotations, probIn, probOut); } } public void countMergeEffects(Tree tree, Map totalStateMass, Map deltaAnnotations, IdentityHashMap probIn, IdentityHashMap probOut) { if (tree.isLeaf()) { return; } if (tree.label().value().equals(Lexicon.BOUNDARY_TAG)) { return; } String label = tree.label().value(); double totalMass = 0.0; double[] stateMass = totalStateMass.get(label); for (double mass : stateMass) { totalMass += mass; } double[] nodeProbIn = probIn.get(tree); double[] 
nodeProbOut = probOut.get(tree); double[] nodeDelta = deltaAnnotations.get(label); if (nodeDelta == null) { nodeDelta = new double[nodeProbIn.length / 2]; deltaAnnotations.put(label, nodeDelta); } for (int i = 0; i < nodeProbIn.length / 2; ++i) { double probInMerged = SloppyMath.logAdd(Math.log(stateMass[i * 2] / totalMass) + nodeProbIn[i * 2], Math.log(stateMass[i * 2 + 1] / totalMass) + nodeProbIn[i * 2 + 1]); double probOutMerged = SloppyMath.logAdd(nodeProbOut[i * 2], nodeProbOut[i * 2 + 1]); double probMerged = probInMerged + probOutMerged; double probUnmerged = SloppyMath.logAdd(nodeProbIn[i * 2] + nodeProbOut[i * 2], nodeProbIn[i * 2 + 1] + nodeProbOut[i * 2 + 1]); nodeDelta[i] = nodeDelta[i] + probMerged - probUnmerged; } if (tree.isPreTerminal()) { return; } for (Tree child : tree.children()) { countMergeEffects(child, totalStateMass, deltaAnnotations, probIn, probOut); } } public void buildStateIndex() { stateIndex = new HashIndex<>(); for (String key : stateSplitCounts.keySet()) { for (int i = 0; i < stateSplitCounts.getIntCount(key); ++i) { stateIndex.addToIndex(state(key, i)); } } } public void buildGrammars() { // In order to build the grammars, we first need to fill in the // temp betas with the sums of the transitions from Ax to By or Ax // to By,Cz. We also need the sum total of the mass in each state // Ax over all the trees. // we go through the machinery to sum up the temporary betas, // counting the total mass... TwoDimensionalMap tempUnaryBetas = new TwoDimensionalMap<>(); ThreeDimensionalMap tempBinaryBetas = new ThreeDimensionalMap<>(); Map totalStateMass = Generics.newHashMap(); recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas); // ... but note we don't actually rescale the betas. // instead we use the temporary betas and the total mass in each // state to calculate the grammars // First build up a BinaryGrammar. 
// The score for each rule will be the Beta scores found earlier, // scaled by the total weight of a transition between unsplit states BinaryGrammar bg = new BinaryGrammar(stateIndex); for (String parent : tempBinaryBetas.firstKeySet()) { int parentStates = getStateSplitCount(parent); double[] stateTotal = totalStateMass.get(parent); for (String left : tempBinaryBetas.get(parent).firstKeySet()) { int leftStates = getStateSplitCount(left); for (String right : tempBinaryBetas.get(parent).get(left).keySet()) { int rightStates = getStateSplitCount(right); double[][][] betas = tempBinaryBetas.get(parent, left, right); for (int i = 0; i < parentStates; ++i) { if (stateTotal[i] < EPSILON) { continue; } for (int j = 0; j < leftStates; ++j) { for (int k = 0; k < rightStates; ++k) { int parentIndex = stateIndex.indexOf(state(parent, i)); int leftIndex = stateIndex.indexOf(state(left, j)); int rightIndex = stateIndex.indexOf(state(right, k)); double score = betas[i][j][k] - Math.log(stateTotal[i]); BinaryRule br = new BinaryRule(parentIndex, leftIndex, rightIndex, score); bg.addRule(br); } } } } } } // Now build up a UnaryGrammar UnaryGrammar ug = new UnaryGrammar(stateIndex); for (String parent : tempUnaryBetas.firstKeySet()) { int parentStates = getStateSplitCount(parent); double[] stateTotal = totalStateMass.get(parent); for (String child : tempUnaryBetas.get(parent).keySet()) { int childStates = getStateSplitCount(child); double[][] betas = tempUnaryBetas.get(parent, child); for (int i = 0; i < parentStates; ++i) { if (stateTotal[i] < EPSILON) { continue; } for (int j = 0; j < childStates; ++j) { int parentIndex = stateIndex.indexOf(state(parent, i)); int childIndex = stateIndex.indexOf(state(child, j)); double score = betas[i][j] - Math.log(stateTotal[i]); UnaryRule ur = new UnaryRule(parentIndex, childIndex, score); ug.addRule(ur); } } } } bgug = new Pair<>(ug, bg); } public void saveTrees(Collection trees1, double weight1, Collection trees2, double weight2) { trainSize 
= 0.0; int treeCount = 0; trees.clear(); treeWeights.clear(); for (Tree tree : trees1) { trees.add(tree); treeWeights.incrementCount(tree, weight1); trainSize += weight1; } treeCount += trees1.size(); if (trees2 != null && weight2 >= 0.0) { for (Tree tree : trees2) { trees.add(tree); treeWeights.incrementCount(tree, weight2); trainSize += weight2; } treeCount += trees2.size(); } log.info("Found " + treeCount + " trees with total weight " + trainSize); } public void extract(Collection treeList) { extract(treeList, 1.0, null, 0.0); } /** * First, we do a few setup steps. We read in all the trees, which * is necessary because we continually reprocess them and use the * object pointers as hash keys rather than hashing the trees * themselves. We then count the initial states in the treebank. *
* Having done that, we then assign initial probabilities to the * trees. At first, each state has 1.0 of the probability mass for * each Ax-ByCz and Ax-By transition. We then split the number of * states and the probabilities on each tree. *
* We then repeatedly recalculate the betas and reannotate the * weights, going until we converge, which is defined as no betas * move more then epsilon. *
* java -mx4g edu.stanford.nlp.parser.lexparser.LexicalizedParser -PCFG -saveToSerializedFile englishSplit.ser.gz -saveToTextFile englishSplit.txt -maxLength 40 -train ../data/wsj/wsjtwentytrees.mrg -testTreebank ../data/wsj/wsjtwentytrees.mrg -evals "factDA,tsv" -uwm 0 -hMarkov 0 -vMarkov 0 -simpleBinarizedLabels -noRebinarization -predictSplits -splitTrainingThreads 1 -splitCount 1 -splitRecombineRate 0.5 *
* may also need *
* -smoothTagsThresh 0 *
* java -mx8g edu.stanford.nlp.parser.lexparser.LexicalizedParser -evals "factDA,tsv" -PCFG -vMarkov 0 -hMarkov 0 -uwm 0 -saveToSerializedFile wsjS1.ser.gz -maxLength 40 -train /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 200-2199 -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 2200-2219 -compactGrammar 0 -simpleBinarizedLabels -predictSplits -smoothTagsThresh 0 -splitCount 1 -noRebinarization */ public void extract(Collection trees1, double weight1, Collection trees2, double weight2) { saveTrees(trees1, weight1, trees2, weight2); countOriginalStates(); // Initial betas will be 1 for all possible unary and binary // transitions in our treebank initialBetasAndLexicon(); for (int cycle = 0; cycle < op.trainOptions.splitCount; ++cycle) { // All states except the root state get split into 2 splitStateCounts(); // first, recalculate the betas and the lexicon for having split // the transitions recalculateBetas(true); // now, loop until we converge while recalculating betas // TODO: add a loop counter, stop after X iterations iteration = 0; boolean converged = false; while (!converged && iteration < MAX_ITERATIONS) { if (DEBUG()) { System.out.println(); System.out.println(); System.out.println("-------------------"); System.out.println("Iteration " + iteration); } converged = recalculateBetas(false); ++iteration; } log.info("Converged for cycle " + cycle + " in " + iteration + " iterations"); mergeStates(); } // Build up the state index. The BG & UG both expect a set count // of states. buildStateIndex(); buildGrammars(); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy