// edu.berkeley.nlp.PCFGLA.TreeAnnotations -- source listing from the berkeleyparser artifact
// (Maven / Gradle / Ivy; all versions and documentation are available on the artifact page).
package edu.berkeley.nlp.PCFGLA;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import edu.berkeley.nlp.ling.CollinsHeadFinder;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees;
import edu.berkeley.nlp.util.Filter;
/**
* Class which contains code for annotating and binarizing trees for the parser's use, and debinarizing and unannotating them for scoring.
*/
public class TreeAnnotations implements java.io.Serializable {
// Serialization version identifier; present because the class implements java.io.Serializable.
private static final long serialVersionUID = 1L;
// Shared head finder; headParentBinarizeTree asks it for the head child of an n-ary local tree.
static CollinsHeadFinder headFinder = new CollinsHeadFinder();
/**
 * Annotates a parse tree by adding ancestors to the tags, and then forgetfully binarizes it.
 *
 * <p>The format goes as follows: {@code Tag} becomes {@code Tag^Parent^Grandparent}. Then the
 * tree is binarized, so that
 * {@code Tag^Parent^Grandparent} producing {@code A^Tag^Parent B... C...} becomes
 * <ul>
 *   <li>{@code Tag^Parent^Grandparent} produces {@code A^Tag^Parent} and
 *       {@code @Tag^Parent^Grandparent->_A^Tag^Parent}</li>
 *   <li>{@code @Tag^Parent^Grandparent->_A^Tag^Parent} produces {@code B^Tag^Parent} and
 *       {@code @Tag^Parent^Grandparent->_A^Tag^Parent_B^Tag^Parent}</li>
 * </ul>
 * and finally the excess {@code _*} history is trimmed off to control the amount of
 * horizontal history.
 *
 * <p>This overload delegates to the seven-argument version with
 * {@code annotateUnaryParents = false} and {@code markGrammarSymbols = true}.
 *
 * @param unAnnotatedTree the tree to process
 * @param nVerticalAnnotations amount of vertical (ancestor) annotation; 1, 2 or 3
 * @param nHorizontalAnnotations amount of horizontal history to keep in binarized labels
 * @param binarization the binarization style
 * @param manualAnnotation whether to use the manual vertical annotation when
 *        {@code nVerticalAnnotations} is 2
 * @return the annotated, binarized tree
 */
public static Tree<String> processTree(Tree<String> unAnnotatedTree,
        int nVerticalAnnotations, int nHorizontalAnnotations,
        Binarization binarization, boolean manualAnnotation) {
    return processTree(unAnnotatedTree, nVerticalAnnotations, nHorizontalAnnotations,
            binarization, manualAnnotation, false, true);
}
/**
 * Annotates a tree vertically, binarizes it, and trims the horizontal history of the
 * intermediate labels created by binarization.
 *
 * @param unAnnotatedTree the tree to process
 * @param nVerticalAnnotations 3 adds parent and grandparent to every label, 2 adds the parent
 *        (manually restricted on POS tags when {@code manualAnnotation} is set), 1 adds no
 *        ancestors but may apply the two markers below
 * @param nHorizontalAnnotations passed to the label-forgetting step; -1 keeps labels untouched
 * @param binarization the binarization style
 * @param manualAnnotation only consulted when {@code nVerticalAnnotations} is 2
 * @param annotateUnaryParents only consulted when {@code nVerticalAnnotations} is 1; marks
 *        single-child nodes with {@code ^u}
 * @param markGrammarSymbols only consulted when {@code nVerticalAnnotations} is 1; marks
 *        non-root, non-preterminal nodes with {@code ^g}
 * @return the annotated, binarized tree
 * @throws Error if {@code nVerticalAnnotations} is not 1, 2 or 3
 */
public static Tree<String> processTree(Tree<String> unAnnotatedTree,
        int nVerticalAnnotations, int nHorizontalAnnotations,
        Binarization binarization, boolean manualAnnotation,
        boolean annotateUnaryParents, boolean markGrammarSymbols) {
    Tree<String> verticallyAnnotated = unAnnotatedTree;
    if (nVerticalAnnotations == 3) {
        verticallyAnnotated = annotateVerticallyTwice(unAnnotatedTree, "", "");
    } else if (nVerticalAnnotations == 2) {
        if (manualAnnotation) {
            verticallyAnnotated = annotateManuallyVertically(unAnnotatedTree, "");
        } else {
            verticallyAnnotated = annotateVertically(unAnnotatedTree, "");
        }
    } else if (nVerticalAnnotations == 1) {
        if (markGrammarSymbols) {
            verticallyAnnotated = markGrammarNonterminals(unAnnotatedTree, "");
        }
        // Applied on top of the grammar marking (if any), not on the raw input.
        if (annotateUnaryParents) {
            verticallyAnnotated = markUnaryParents(verticallyAnnotated);
        }
    } else {
        throw new Error("the code does not exist to annotate vertically "+nVerticalAnnotations+" times");
    }
    Tree<String> binarizedTree = binarizeTree(verticallyAnnotated, binarization);
    // The experimental variants (removeUnaryChains, deleteLabels, deletePC) are disabled here.
    return forgetLabels(binarizedTree, nHorizontalAnnotations);
}
/**
 * Binarizes a tree with the given binarization style; e.g. head binarization, left
 * binarization, etc.
 *
 * @param tree the tree to binarize
 * @param binarization the type of binarization used
 * @return the binarized tree, or {@code null} if {@code binarization} is not one of
 *         {@code LEFT}, {@code RIGHT}, {@code PARENT} or {@code HEAD}
 */
public static Tree<String> binarizeTree(Tree<String> tree, Binarization binarization) {
    switch (binarization) {
        case LEFT:
            return leftBinarizeTree(tree);
        case RIGHT:
            return rightBinarizeTree(tree);
        case PARENT:
            return parentBinarizeTree(tree);
        case HEAD:
            return headBinarizeTree(tree);
    }
    // No implementation for any other style; callers receive null.
    return null;
}
private static Tree annotateVerticallyTwice(Tree tree, String parentLabel1, String parentLabel2){
Tree verticallyMarkovizatedTree;
if (tree.isLeaf()) {
verticallyMarkovizatedTree = tree; //new Tree(tree.getLabel());// + parentLabel);
} else {
List> children = new ArrayList>();
for (Tree child : tree.getChildren()) {
// children.add(annotateVerticallyTwice(child, parentLabel2,"^"+tree.getLabel()));
children.add(annotateVerticallyTwice(child, "^"+tree.getLabel(), parentLabel1));
}
verticallyMarkovizatedTree = new Tree(tree.getLabel() + parentLabel1+parentLabel2,children);
}
return verticallyMarkovizatedTree;
}
private static Tree annotateVertically(Tree tree, String parentLabel){
Tree verticallyMarkovizatedTree;
if (tree.isLeaf()){
verticallyMarkovizatedTree = tree;//new Tree(tree.getLabel());// + parentLabel);
}
else {
List> children = new ArrayList>();
for (Tree child : tree.getChildren()) {
children.add(annotateVertically(child, "^"+tree.getLabel()));
}
verticallyMarkovizatedTree = new Tree(tree.getLabel() + parentLabel,children);
}
return verticallyMarkovizatedTree;
}
private static Tree markGrammarNonterminals(Tree tree, String parentLabel){
Tree verticallyMarkovizatedTree;
if (tree.isPreTerminal()){
verticallyMarkovizatedTree = tree;//new Tree(tree.getLabel());// + parentLabel);
}
else {
List> children = new ArrayList>();
for (Tree child : tree.getChildren()) {
children.add(markGrammarNonterminals(child, "^g"));//""));//
}
verticallyMarkovizatedTree = new Tree(tree.getLabel() + parentLabel,children);
}
return verticallyMarkovizatedTree;
}
private static Tree markUnaryParents(Tree tree){
Tree verticallyMarkovizatedTree;
if (tree.isPreTerminal()){
verticallyMarkovizatedTree = tree;//new Tree(tree.getLabel());// + parentLabel);
}
else {
List> children = new ArrayList>();
for (Tree child : tree.getChildren()) {
children.add(markUnaryParents(child));//
}
String add = "";
if (!tree.getLabel().equals("ROOT"))
add = (children.size() == 1 ? "^u" : "");
verticallyMarkovizatedTree = new Tree(tree.getLabel() + add,children);
}
return verticallyMarkovizatedTree;
}
private static Tree annotateManuallyVertically(Tree tree, String parentLabel){
Tree verticallyMarkovizatedTree;
if (tree.isPreTerminal()){
// split only some of the POS tags
// DT, RB, IN, AUX, CC, %
String label = tree.getLabel();
if (label.contains("DT") || label.contains("RB") ||
label.contains("IN") || label.contains("AUX") ||
label.contains("CC") || label.contains("%") ){
verticallyMarkovizatedTree = new Tree(tree.getLabel() + parentLabel,tree.getChildren());
} else {
verticallyMarkovizatedTree = tree;//new Tree(tree.getLabel());// + parentLabel);
}
}
else {
List> children = new ArrayList>();
for (Tree child : tree.getChildren()) {
children.add(annotateManuallyVertically(child, "^"+tree.getLabel()));
}
verticallyMarkovizatedTree = new Tree(tree.getLabel() + parentLabel,children);
}
return verticallyMarkovizatedTree;
}
// replaces labels with three types of labels:
// X, @X=Y and Z
private static Tree deleteLabels(Tree tree, boolean isRoot) {
String label = tree.getLabel();
String newLabel = "";
if (isRoot){
newLabel = label;
}
else if (tree.isPreTerminal()){
newLabel = "Z";
return new Tree(newLabel,tree.getChildren());
}
else if (label.charAt(0)=='@') {
newLabel = "@X";
}
else newLabel = "X";
List> transformedChildren = new ArrayList>();
for (Tree child : tree.getChildren()) {
transformedChildren.add(deleteLabels(child, false));
}
return new Tree(newLabel, transformedChildren);
}
// replaces phrasal categories with
// X, @X=Y but keeps POS-tags
private static Tree deletePC(Tree tree, boolean isRoot) {
String label = tree.getLabel();
String newLabel = "";
if (isRoot){
newLabel = label;
}
else if (tree.isPreTerminal()){
return tree;
}
else if (label.charAt(0)=='@') {
newLabel = "@X";
}
else newLabel = "X";
List> transformedChildren = new ArrayList>();
for (Tree child : tree.getChildren()) {
transformedChildren.add(deletePC(child, false));
}
return new Tree(newLabel, transformedChildren);
}
private static Tree forgetLabels(Tree tree, int nHorizontalAnnotation) {
if (nHorizontalAnnotation==-1) return tree;
String transformedLabel = tree.getLabel();
if (tree.isLeaf()) {
return new Tree(transformedLabel);
}
//the location of the farthest _
int firstCutIndex = transformedLabel.indexOf('_');
int keepBeginning = firstCutIndex;
//will become -1 when the end of the line is reached
int secondCutIndex = transformedLabel.indexOf('_',firstCutIndex+1);
//the location of the second farthest _
int cutIndex = secondCutIndex;
while (secondCutIndex != -1) {
cutIndex = firstCutIndex;
firstCutIndex = secondCutIndex;
secondCutIndex = transformedLabel.indexOf('_',firstCutIndex+1);
}
if (nHorizontalAnnotation == 0) {
cutIndex = transformedLabel.indexOf('>')-1;
if (cutIndex > 0) transformedLabel = transformedLabel.substring(0,cutIndex);
} else if (cutIndex > 0 && !tree.isLeaf()) {
if (nHorizontalAnnotation == 2) {
transformedLabel = transformedLabel.substring(0, keepBeginning) + transformedLabel.substring(cutIndex);
} else if (nHorizontalAnnotation == 1) {
transformedLabel = transformedLabel.substring(0, keepBeginning) + transformedLabel.substring(firstCutIndex);
} else {
throw new Error("code does not exist to horizontally annotate at level "+ nHorizontalAnnotation);
}
}
List> transformedChildren = new ArrayList>();
for (Tree child : tree.getChildren()) {
transformedChildren.add(forgetLabels(child,nHorizontalAnnotation));
}
/*if (!transformedLabel.equals("ROOT")&& transformedLabel.length()>1){
transformedLabel = transformedLabel.substring(0,2);
}*/
/*if (tree.isPreTerminal() && transformedLabel.length()>1){
if (transformedLabel.substring(0,2).equals("NN")){
transformedLabel = "NNX";
}
else if (transformedLabel.equals("VBZ") || transformedLabel.equals("VBP") || transformedLabel.equals("VBD") || transformedLabel.equals("VB") ){
transformedLabel = "VBX";
}
else if (transformedLabel.substring(0,3).equals("PRP")){
transformedLabel = "PRPX";
}
else if (transformedLabel.equals("JJR") || transformedLabel.equals("JJS") ){
transformedLabel = "JJX";
}
else if (transformedLabel.equals("RBR") || transformedLabel.equals("RBS") ){
transformedLabel = "RBX";
}
else if (transformedLabel.equals("WDT") || transformedLabel.equals("WP") || transformedLabel.equals("WP$")){
transformedLabel = "WBX";
}
}*/
return new Tree(transformedLabel, transformedChildren);
}
/**
 * Binarizes a tree left to right: each n-ary local tree (n &gt;= 2) keeps its first child
 * and pushes the remaining children under intermediate nodes labelled
 * {@code @Label->_C1_C2...}, where the suffix lists the children generated so far.
 *
 * @param tree the tree to binarize
 * @return a new binarized tree
 */
static Tree<String> leftBinarizeTree(Tree<String> tree) {
    String label = tree.getLabel();
    List<Tree<String>> children = tree.getChildren();
    if (tree.isLeaf()) {
        return new Tree<String>(label);
    }
    if (children.size() == 1) {
        return new Tree<String>(label, Collections.singletonList(leftBinarizeTree(children.get(0))));
    }
    // Binary-or-more local tree: decompose it into a sequence of binary trees, then give
    // the topmost intermediate node the original label.
    String intermediateLabel = "@" + label + "->";
    Tree<String> intermediateTree = leftBinarizeTreeHelper(tree, 0, intermediateLabel);
    return new Tree<String>(label, intermediateTree.getChildren());
}
private static Tree leftBinarizeTreeHelper(Tree tree, int numChildrenGenerated, String intermediateLabel) {
Tree leftTree = tree.getChildren().get(numChildrenGenerated);
List> children = new ArrayList>(2);
children.add(leftBinarizeTree(leftTree));
if (numChildrenGenerated == tree.getChildren().size() - 2) {
children.add(leftBinarizeTree(tree.getChildren().get(numChildrenGenerated+1)));
} else if (numChildrenGenerated < tree.getChildren().size() - 2) {
Tree rightTree = leftBinarizeTreeHelper(tree, numChildrenGenerated+1, intermediateLabel+"_"+leftTree.getLabel());
children.add(rightTree);
}
return new Tree(intermediateLabel, children);
}
/**
 * Binarizes a tree right to left: each n-ary local tree (n &gt;= 2) keeps its last child and
 * pushes the remaining children under intermediate nodes labelled
 * {@code @Label->_Cn_Cn-1...}, where the suffix lists the children generated so far.
 *
 * @param tree the tree to binarize
 * @return a new binarized tree
 */
static Tree<String> rightBinarizeTree(Tree<String> tree) {
    String label = tree.getLabel();
    List<Tree<String>> children = tree.getChildren();
    if (tree.isLeaf()) {
        return new Tree<String>(label);
    }
    if (children.size() == 1) {
        return new Tree<String>(label, Collections.singletonList(rightBinarizeTree(children.get(0))));
    }
    // Binary-or-more local tree: decompose it into a sequence of binary trees, then give
    // the topmost intermediate node the original label.
    String intermediateLabel = "@" + label + "->";
    Tree<String> intermediateTree = rightBinarizeTreeHelper(tree, children.size() - 1, intermediateLabel);
    return new Tree<String>(label, intermediateTree.getChildren());
}
private static Tree rightBinarizeTreeHelper(Tree tree,
int numChildrenLeft, String intermediateLabel) {
Tree rightTree = tree.getChildren().get(numChildrenLeft);
List> children = new ArrayList>(2);
if (numChildrenLeft == 1) {
children.add(rightBinarizeTree(tree.getChildren()
.get(numChildrenLeft - 1)));
} else if (numChildrenLeft > 1) {
Tree leftTree = rightBinarizeTreeHelper(tree,
numChildrenLeft - 1, intermediateLabel + "_" + rightTree.getLabel());
children.add(leftTree);
}
children.add(rightBinarizeTree(rightTree));
return new Tree(intermediateLabel, children);
}
/**
 * Binarizes a tree around the head symbol. That is, when there is an n-ary rule with
 * n &gt; 2, it is split into a series of binary rules with names like {@code @JJ-R} (if JJ is
 * the head of the rule). The right part of the symbol ({@code -R} or {@code -L}) indicates
 * whether we are producing to the right or to the left of the head symbol. Thus, the head
 * symbol is always the deepest symbol of the chain that is created.
 *
 * @param tree the tree to binarize
 * @return the binarized tree
 */
static Tree<String> headBinarizeTree(Tree<String> tree) {
    return headParentBinarizeTree(Binarization.HEAD, tree);
}
/**
 * Binarizes a tree around the head symbol, but with symbol names derived from the parent
 * symbol rather than the head (as in {@link #headBinarizeTree}).
 *
 * @param tree the tree to binarize
 * @return the binarized tree
 */
static Tree<String> parentBinarizeTree(Tree<String> tree) {
    return headParentBinarizeTree(Binarization.PARENT, tree);
}
/**
* Binarize a tree around its head, with the symbol names derived from either
* the parent or the head (as determined by binarization).
*
* It calls {@link @headParentBinarizeTreeHelper} to do the messy work of
* binarization when that is actually necessary.
*
* @param binarization
* This determines whether the newly-created symbols are based on the
* head symbol or on the parent symbol. It should be either
* headBinarize or parentBinarize.
* @param tree
* @return
*/
private static Tree headParentBinarizeTree(Binarization binarization, Tree tree) {
List> children = tree.getChildren();
if (children.size()==0) {
return tree;
} else if (children.size()==1) {
List> kids = new ArrayList>(1);
kids.add(headParentBinarizeTree(binarization,children.get(0)));
return new Tree(tree.getLabel(),kids);
} else if (children.size()==2) {
List> kids = new ArrayList>(2);
kids.add(headParentBinarizeTree(binarization,children.get(0)));
kids.add(headParentBinarizeTree(binarization,children.get(1)));
return new Tree(tree.getLabel(),kids);
} else {
List> kids = new ArrayList>(1);
kids.add(headParentBinarizeTreeHelper(binarization, tree, 0, children.size() - 1,
headFinder.determineHead(tree), false, ""));
return new Tree(tree.getLabel(),kids);
}
}
/**
* This binarizes a tree into a bunch of binary [at]SYM-R symbols. It assumes
* that this sort of binarization is always necessary, so it is only called by
* {@link headParentBinarizeTree}.
*
* @param binarization
* The type of new symbols to generate, either head or parent.
* @param tree
* @param leftChild
* The index of the leftmost child remaining to be binarized.
* @param rightChild
* The index of the rightmost child remaining to be binarized.
* @param head
* The head symbol of this level of the tree.
* @param right
* This indicates whether we have gotten to the right of the head
* child yet.
* @return
*/
static Tree headParentBinarizeTreeHelper(Binarization binarization,
Tree tree, int leftChild, int rightChild, Tree head,
boolean right, String productionHistory) {
if (head==null)
throw new Error("head is null");
List> children = tree.getChildren();
//test if we've finally come to the head word
if (!right && children.get(leftChild)==head)
right=true;
//prepare the parent label
String label = null;
if (binarization==Binarization.HEAD) {
label = head.getLabel();
} else if (binarization==Binarization.PARENT) {
label = tree.getLabel();
}
String parentLabel = "@"+label+ (right ? "-R" : "-L")+"->"+productionHistory;
//if the left child == the right child, then we only need a unary
if (leftChild==rightChild) {
ArrayList> kids = new ArrayList>(1);
kids.add(headParentBinarizeTree(binarization, children.get(leftChild)));
return new Tree(parentLabel,kids);
}
//if we're to the left of the head word
if (!right) {
ArrayList> kids = new ArrayList>(2);
Tree child = children.get(leftChild);
kids.add(headParentBinarizeTree(binarization, child));
kids.add(headParentBinarizeTreeHelper(binarization, tree,leftChild+1,rightChild,head,right,productionHistory+"_"+child.getLabel()));
return new Tree(parentLabel,kids);
}
//if we're to the right of the head word
else {
ArrayList> kids = new ArrayList>(2);
Tree child = children.get(rightChild);
kids.add(headParentBinarizeTreeHelper(binarization, tree,leftChild,rightChild-1,head,right,productionHistory+"_"+child.getLabel()));
kids.add(headParentBinarizeTree(binarization, child));
return new Tree(parentLabel,kids);
}
}
/**
 * Splices out intermediate nodes whose labels begin with {@code "Y"} and then runs
 * {@code Trees.FunctionNodeStripper} over the result.
 *
 * <p>NOTE(review): per the original comments the stripper removes all material on node labels
 * that follows the base symbol (cutting at the leftmost -, ^, or : character), e.g. reducing
 * {@code NP^S} to {@code NP}; confirm against {@code Trees.FunctionNodeStripper}.
 *
 * @param annotatedTree the annotated tree
 * @return the tree without the spliced nodes and with stripped labels
 */
public static Tree<String> unAnnotateTreeSpecial(Tree<String> annotatedTree) {
    Tree<String> debinarizedTree = Trees.spliceNodes(annotatedTree, new Filter<String>() {
        public boolean accept(String s) {
            return s.startsWith("Y");
        }
    });
    Tree<String> unAnnotatedTree = (new Trees.FunctionNodeStripper()).transformTree(debinarizedTree);
    return unAnnotatedTree;
}
/**
 * Removes the intermediate nodes introduced by binarization: every node whose label begins
 * with {@code "@"} (but is not exactly {@code "@"}) is spliced out, e.g. a node labelled
 * {@code @NP->DT_JJ}. Labels of the remaining nodes are left as they are; use
 * {@code unAnnotateTree} to strip annotations as well.
 *
 * @param annotatedTree the binarized tree
 * @return the tree with the intermediate nodes spliced out
 */
public static Tree<String> debinarizeTree(Tree<String> annotatedTree) {
    Tree<String> debinarizedTree = Trees.spliceNodes(annotatedTree, new Filter<String>() {
        public boolean accept(String s) {
            // A bare "@" is kept: it is a real symbol rather than an intermediate label.
            return s.startsWith("@") && !s.equals("@");
        }
    });
    return debinarizedTree;
}
/**
 * Undoes binarization and, optionally, label annotation. Intermediate nodes (labels beginning
 * with {@code "@"}, other than a bare {@code "@"}) are spliced out; unless
 * {@code keepFunctionLabel} is set, {@code Trees.FunctionNodeStripper} is then run over the
 * result.
 *
 * <p>NOTE(review): per the original comments the stripper removes all material on node labels
 * that follows the base symbol (cutting at the leftmost -, ^, or : character), e.g. reducing
 * {@code NP^S} to {@code NP}; confirm against {@code Trees.FunctionNodeStripper}.
 *
 * @param annotatedTree the annotated, binarized tree
 * @param keepFunctionLabel if {@code true}, only debinarize and keep the labels untouched
 * @return the debinarized (and possibly unannotated) tree
 */
public static Tree<String> unAnnotateTree(Tree<String> annotatedTree, boolean keepFunctionLabel) {
    // Same splice as debinarizeTree, so reuse it instead of duplicating the filter.
    Tree<String> debinarizedTree = debinarizeTree(annotatedTree);
    if (keepFunctionLabel) {
        return debinarizedTree;
    }
    return (new Trees.FunctionNodeStripper()).transformTree(debinarizedTree);
}
/**
 * Small manual test of the binarization: reads a fixed example tree, and for every
 * {@code Binarization} value prints the binarized tree and the tree obtained by undoing the
 * binarization again.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    Trees.PennTreeReader reader = new Trees.PennTreeReader(new StringReader("((S (NP (DT the) (JJ quick) (JJ (AA (BB (CC brown)))) (NN fox)) (VP (VBD jumped) (PP (IN over) (NP (DT the) (JJ lazy) (NN dog)))) (. .)))"));
    Tree<String> tree = reader.next();
    System.out.println("tree");
    System.out.println(Trees.PennTreeRenderer.render(tree));
    for (Binarization binarization : Binarization.values()) {
        System.out.println("binarization type "+binarization.name());
        try {
            // Print the binarization, then the round trip back to the unbinarized tree.
            Tree<String> binarizedTree = binarizeTree(tree, binarization);
            System.out.println(Trees.PennTreeRenderer.render(binarizedTree));
            System.out.println("unbinarized");
            Tree<String> unBinarizedTree = unAnnotateTree(binarizedTree, false);
            System.out.println(Trees.PennTreeRenderer.render(unBinarizedTree));
            System.out.println("------------");
        } catch (Error e) {
            // Deliberate best effort: an Error here is reported as an unsupported style.
            System.out.println("binarization not implemented");
        }
    }
}
public static Tree removeSuperfluousNodes(Tree tree){
if (tree.isPreTerminal()) return tree;
if (tree.isLeaf()) return tree;
List> gChildren = tree.getChildren();
if (gChildren.size()!=1) {
// nothing to do, just recurse
ArrayList> children = new ArrayList>();
for (int i=0; i cChild = removeSuperfluousNodes(tree.getChildren().get(i));
children.add(cChild);
}
tree.setChildren(children);
return tree;
}
Tree result = null;
String parent = tree.getLabel();
HashSet nodesInChain = new HashSet();
tree = tree.getChildren().get(0);
while (!tree.isPreTerminal() && tree.getChildren().size()==1){
if (!nodesInChain.contains(tree.getLabel())){
nodesInChain.add(tree.getLabel());
}
tree = tree.getChildren().get(0);
}
Tree child = removeSuperfluousNodes(tree);
String cLabel = child.getLabel();
ArrayList> childs = new ArrayList>();
childs.add(child);
if (cLabel.equals(parent)) {
result = child;
} else {
result = new Tree(parent,childs);
}
for (String node : nodesInChain){
if (node.equals(parent)||node.equals(cLabel)) continue;
Tree intermediate = new Tree(node, result.getChildren());
childs = new ArrayList>();
childs.add(intermediate);
result.setChildren(childs);
}
return result;
}
/**
 * Prints the unary chains found in a tree to standard output. For every non-preterminal node
 * with exactly one child whose own parent also had exactly one child, a line of the form
 * {@code Unary chain: parent -> label -> child} is printed.
 *
 * @param tree the subtree to inspect
 * @param parent the label of the single-child node directly above {@code tree}, or "" if
 *        there is none (pass "" for the initial call)
 */
public static void displayUnaryChains(Tree<String> tree, String parent) {
    List<Tree<String>> children = tree.getChildren();
    if (children.size() == 1) {
        if (tree.isPreTerminal()) {
            return;
        }
        Tree<String> onlyChild = children.get(0);
        if (!parent.equals("")) {
            System.out.println("Unary chain: "+parent+" -> "+ tree.getLabel() + " -> " +onlyChild.getLabel());
        }
        displayUnaryChains(onlyChild, tree.getLabel());
    } else {
        for (Tree<String> child : children) {
            // A branching node breaks the chain, so the children start without a parent.
            if (!child.isPreTerminal()) {
                displayUnaryChains(child, "");
            }
        }
    }
}
public static void removeUnaryChains(Tree tree){
if (tree.isPreTerminal()) return;
if (tree.getChildren().size()==1 && tree.getChildren().get(0).getChildren().size()==1){
// unary chain
if (tree.getChildren().get(0).isPreTerminal()) return; // if we are just above a preterminal, dont do anything
else {// otherwise remove the intermediate node
ArrayList> newChildren = new ArrayList>();
newChildren.add(tree.getChildren().get(0).getChildren().get(0));
tree.setChildren(newChildren);
}
}
for (Tree child : tree.getChildren()){
removeUnaryChains(child);
}
}
}