
edu.berkeley.nlp.PCFGLA.CoarseToFineNBestParser


The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
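
A minimal usage sketch (not part of this source file): it follows the pattern of the bundled BerkeleyParser command-line class, and the grammar file name, option values, and example sentence are placeholder assumptions rather than anything defined in this listing.

// Hypothetical driver class; assumes a serialized Berkeley grammar file is available locally.
import java.util.Arrays;
import java.util.List;

import edu.berkeley.nlp.PCFGLA.CoarseToFineNBestParser;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;

public class NBestDemo {
	public static void main(String[] args) {
		ParserData pData = ParserData.Load("eng_sm6.gr"); // placeholder grammar file name
		Numberer.setNumberers(pData.getNumbs());
		Grammar grammar = pData.getGrammar();
		Lexicon lexicon = pData.getLexicon();
		grammar.splitRules();

		int k = 10;                // number of parses to extract
		double unaryPenalty = 1.0; // assumed default, as in the command-line tool
		CoarseToFineNBestParser parser = new CoarseToFineNBestParser(
				grammar, lexicon, k, unaryPenalty, -1 /* endL */,
				false /* viterbi */, false /* sub */, false /* score */, false /* accurate */,
				false /* variational */, false /* useGoldPOS */, true /* initCascade */);

		List<String> sentence = Arrays.asList("The", "parser", "returns", "several", "trees", ".");
		for (Tree<String> tree : parser.getKBestConstrainedParses(sentence, null, k)) {
			System.out.println(tree);
		}
	}
}

The returned trees are still in the grammar's internal binarized form; the command-line tool passes them through TreeAnnotations.unAnnotateTree before printing.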

/**
 * 
 */
package edu.berkeley.nlp.PCFGLA;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.ScalingTools;

/**
 * @author petrov
 *
 */
public class CoarseToFineNBestParser extends CoarseToFineMaxRuleParser{
	LazyList[][][] chartBeforeU; // per span and state: lazy k-best list of derivations before the topmost unary rule
	LazyList[][][] chartAfterU;  // per span and state: lazy k-best list of derivations after at most one unary rule on top
	int k;                       // number of parses to extract
	List<Double> maxRuleScores;  // max-rule scores of the extracted parses, in rank order
	int tmp_k;                   // index of the next score returned by getModelScore()
	
	
	/**
	 * @param gr the grammar
	 * @param lex the lexicon
	 * @param k the number of best parses to extract
	 * @param unaryPenalty penalty applied to unary rules
	 * @param endL last level of the coarse-to-fine cascade
	 * @param viterbi extract Viterbi derivations instead of max-rule parses
	 * @param sub output split (annotated) categories
	 * @param score output scores
	 * @param accurate use the more accurate scoring variant
	 * @param variational use variational rule scores
	 * @param useGoldPOS constrain the parse to the provided POS tags
	 * @param initCascade initialize the grammar/lexicon cascade
	 */
	public CoarseToFineNBestParser(Grammar gr, Lexicon lex, int k, double unaryPenalty, int endL, boolean viterbi, 
			boolean sub, boolean score, boolean accurate, boolean variational, boolean useGoldPOS, boolean initCascade) {
		super(gr, lex, unaryPenalty, endL, viterbi, sub, score, accurate, variational,useGoldPOS, initCascade);
		this.k = k;
	}
	
  /** Assumes that inside and outside scores (sum version, not viterbi) have been computed.
   *  In particular, the narrowRExtent and other arrays need not be updated.
   */
  void doConstrainedMaxCScores(List<String> sentence, Grammar grammar, Lexicon lexicon, final boolean scale) {
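  	// For every span [start,end) and coarse state we keep two lazy k-best lists:
  	// chartBeforeU holds derivations whose top rule is a binary rule (or the tag -> word
  	// rule at the POS level), and chartAfterU holds derivations with at most one unary
  	// rule applied on top of those. Only the 1-best entries are computed here; further
  	// entries are produced on demand by updateConstrainedMaxCScores().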
  	numSubStatesArray = grammar.numSubStates;
    double initVal = Double.NEGATIVE_INFINITY;
    chartBeforeU = new LazyList[length][length + 1][numStates];
    chartAfterU = new LazyList[length][length + 1][numStates];

    
    double logNormalizer = iScore[0][length][0][0];
//    double thresh2 = threshold*logNormalizer;
    for (int diff = 1; diff <= length; diff++) {
      //System.out.print(diff + " ");
      for (int start = 0; start < (length - diff + 1); start++) {
        int end = start + diff;
        if (diff > 1) {
          // diff > 1: Try binary rules
          for (int pState=0; pState<numStates; pState++){
            if (!allowedStates[start][end][pState]) continue;
            BinaryRule[] parentRules = grammar.splitRulesWithP(pState);
            int nParentStates = numSubStatesArray[pState];
            chartBeforeU[start][end][pState] = new LazyList();
            double bestScore = Double.NEGATIVE_INFINITY;
            HyperEdge bestElement = null;
            for (int i = 0; i < parentRules.length; i++) {
              BinaryRule r = parentRules[i];
              int lState = r.leftChildState;
              int rState = r.rightChildState;

              int narrowR = narrowRExtent[start][lState];
              boolean iPossibleL = (narrowR < end); // can this left constituent leave space for a right constituent?
              if (!iPossibleL) { continue; }

              int narrowL = narrowLExtent[end][rState];
              boolean iPossibleR = (narrowL >= narrowR); // can this right constituent fit next to the left constituent?
              if (!iPossibleR) { continue; }
              
              int min1 = narrowR;
              int min2 = wideLExtent[end][rState];
              int min = (min1 > min2 ? min1 : min2); // can this right constituent stretch far enough to reach the left constituent?
              if (min > narrowL) { continue; }
              
              int max1 = wideRExtent[start][lState];
              int max2 = narrowL;
              int max = (max1 < max2 ? max1 : max2); // can this left constituent stretch far enough to reach the right constituent?
              if (min > max) { continue; }

              double[][][] scores = r.getScores2();
              int nLeftChildStates = numSubStatesArray[lState]; // == scores.length;
              int nRightChildStates = numSubStatesArray[rState]; // == scores[0].length;
              for (int split = min; split <= max; split++) {
                double ruleScore = 0;
  	      			if (!allowedStates[start][split][lState]) continue;
  	      			if (!allowedStates[split][end][rState]) continue;
                
  	      			HyperEdge bestLeft = chartAfterU[start][split][lState].getKbest(0);
                double leftChildScore = (bestLeft==null) ? Double.NEGATIVE_INFINITY : bestLeft.score; 
                
                HyperEdge bestRight = chartAfterU[split][end][rState].getKbest(0);
                double rightChildScore = (bestRight==null) ? Double.NEGATIVE_INFINITY : bestRight.score; 
                
//  	      			double leftChildScore = maxcScore[start][split][lState];
//                double rightChildScore = maxcScore[split][end][rState];
                if (leftChildScore==initVal||rightChildScore==initVal) continue;

               	double scalingFactor = 0.0;
               	if (scale) scalingFactor = Math.log(ScalingTools.calcScaleFactor(
                		oScale[start][end][pState]+iScale[start][split][lState]+
                		iScale[split][end][rState]-iScale[0][length][0]));
               	double gScore = leftChildScore + scalingFactor + rightChildScore;
               	
               	if (gScore == Double.NEGATIVE_INFINITY) continue; // no chance of finding a better derivation
               	
               	for (int lp = 0; lp < nLeftChildStates; lp++) {
                  double lIS = iScore[start][split][lState][lp];
                  if (lIS == 0) continue;
//                  if (lIS < thresh2) continue;
                  //if (!allowedSubStates[start][split][lState][lp]) continue;

                  for (int rp = 0; rp < nRightChildStates; rp++) {
                    if (scores[lp][rp]==null) continue;
                    double rIS = iScore[split][end][rState][rp];
                    if (rIS == 0) continue;
//                    if (rIS < thresh2) continue;
                    //if (!allowedSubStates[split][end][rState][rp]) continue;
                    for (int np = 0; np < nParentStates; np++) {
                      //if (!allowedSubStates[start][end][pState][np]) continue;
                      double pOS = oScore[start][end][pState][np];
                      if (pOS == 0) continue;
//                      if (pOS < thresh2) continue;

                      double ruleS = scores[lp][rp][np];
                      if (ruleS == 0) continue;
                      ruleScore += (pOS * ruleS * lIS * rIS) / logNormalizer;
                    }
                  }
                }
                if (ruleScore==0) continue;

                ruleScore = Math.log(ruleScore);
               	gScore += ruleScore;


                if (gScore > Double.NEGATIVE_INFINITY) {
                	HyperEdge newElement = new HyperEdge(pState,lState,rState,0,0,0,start,split,end,gScore,ruleScore);
                	if (gScore>bestScore){
                		bestScore = gScore;
                		bestElement = newElement;
                	}
                	if (diff>2) chartBeforeU[start][end][pState].addToFringe(newElement);
                }
              }
            }
            if (diff==2&&bestElement!=null) chartBeforeU[start][end][pState].addToFringe(bestElement);
//            chartBeforeU[start][end][pState].expandNextBest();
          }
        } else { // diff == 1
          // We treat TAG --> word exactly as if it was a unary rule, except the score of the rule is
          // given by the lexicon rather than the grammar and that we allow another unary on top of it.
          //for (int tag : lexicon.getAllTags()){
          for (int tag=0; tag<numSubStatesArray.length; tag++){
            if (!allowedStates[start][end][tag]) continue;
            int nTagStates = numSubStatesArray[tag];
            double[] lexiconScoreArray = lexicon.score(sentence.get(start),(short)tag,start,false,false);
            double lexiconScores = 0;
            for (int tp = 0; tp < nTagStates; tp++) {
              double pOS = oScore[start][end][tag][tp];
              if (pOS == 0) continue;
              double ruleS = lexiconScoreArray[tp];
              lexiconScores += (pOS * ruleS) / logNormalizer; // The inside score of a word is 0.0f
            }
            if (lexiconScores == 0) continue;
            double scalingFactor = 0.0;
            if (scale) scalingFactor = Math.log(ScalingTools.calcScaleFactor(
                oScale[start][end][tag]-iScale[0][length][0]));
            double lexScore = Math.log(lexiconScores) + scalingFactor;
            // The tag -> word edge is the only "before unary" derivation at the POS level.
            HyperEdge lexEdge = new HyperEdge(tag,tag,0,0,start,end,lexScore,lexScore);
            chartBeforeU[start][end][tag] = new LazyList();
            chartBeforeU[start][end][tag].addToFringe(lexEdge);
          }
        }
        // Try unary rules on top of the binary (or, at the POS level, lexical) derivations.
        for (int pState=0; pState<numStates; pState++){
          if (!allowedStates[start][end][pState]) continue;
          chartAfterU[start][end][pState] = new LazyList();
          int nParentStates = numSubStatesArray[pState];
          UnaryRule[] unaries = grammar.getClosedSumUnaryRulesByParent(pState);
          double bestScore = Double.NEGATIVE_INFINITY;
          HyperEdge bestElement = null;
          for (int r = 0; r < unaries.length; r++) {
            UnaryRule ur = unaries[r];
            int cState = ur.childState;
            if (pState == cState) continue;
            if (chartBeforeU[start][end][cState] == null) continue;
            if (iScore[start][end][cState] == null) continue;

            HyperEdge bestChild = chartBeforeU[start][end][cState].getKbest(0);
            double childScore = (bestChild==null) ? Double.NEGATIVE_INFINITY : bestChild.score;

            double scalingFactor = 0.0;
            if (scale) scalingFactor = Math.log(ScalingTools.calcScaleFactor(
                oScale[start][end][pState]+iScale[start][end][cState]-iScale[0][length][0]));
            double gScore = childScore + scalingFactor;
            if (gScore == Double.NEGATIVE_INFINITY) continue;

            double[][] scores = ur.getScores2();
            int nChildStates = numSubStatesArray[cState];
            double ruleScore = 0;
            for (int cp = 0; cp < nChildStates; cp++) {
              if (scores[cp]==null) continue;
              double cIS = iScore[start][end][cState][cp];
              if (cIS == 0) continue;
              for (int np = 0; np < nParentStates; np++) {
                double pOS = oScore[start][end][pState][np];
                if (pOS == 0) continue;
                double ruleS = scores[cp][np];
                if (ruleS == 0) continue;
                ruleScore += (pOS * ruleS * cIS) / logNormalizer;
              }
            }
            if (ruleScore==0) continue;
            ruleScore = Math.log(ruleScore);
            gScore += ruleScore;

            if (gScore > Double.NEGATIVE_INFINITY) {
            	HyperEdge newElement = new HyperEdge(pState,cState,0,0,start,end,gScore, ruleScore);
            	if (gScore>bestScore){
            		bestScore = gScore;
            		bestElement = newElement;
            	}
              if (diff>1) chartAfterU[start][end][pState].addToFringe(newElement);
            }
          }
          if (diff==1&&bestElement!=null) chartAfterU[start][end][pState].addToFringe(bestElement);
          if (chartBeforeU[start][end][pState]!=null){
		          HyperEdge bestSelf = chartBeforeU[start][end][pState].getKbest(0);
		        if (bestSelf != null){
		        	HyperEdge selfRule = new HyperEdge(pState,pState,0,0,start,end,bestSelf.score,0);
		        	chartAfterU[start][end][pState].addToFringe(selfRule);
		        }
          }

//          chartAfterU[start][end][pState].expandNextBest();
        }
//      	maxcScore[start][end] = maxcScoreStartEnd;
      }
    }
  }
  
  /**
   * Returns the best parse, the one with maximum expected labelled recall.
   * Assumes that the maxc* arrays have been filled.
   */
  public Tree<String> extractBestMaxRuleParse(int start, int end, List<String> sentence) {
    return extractBestMaxRuleParse1(start, end, 0, 0, sentence);
//  	System.out.println(extractBestMaxRuleParse1(start, end, 0, 0, sentence));
//  	System.out.println(extractBestMaxRuleParse1(start, end, 0, 1, sentence));
//  	System.out.println(extractBestMaxRuleParse1(start, end, 0, 2, sentence));
//  	return extractBestMaxRuleParse1(start, end, 0, 3, sentence);
  }

  /**
   * Extracts up to k best max-rule parses for the given span (best first) and records
   * their scores so that getModelScore() can return them in the same order.
   */
  public List<Tree<String>> extractKBestMaxRuleParses(int start, int end, List<String> sentence, int k) {
  	List<Tree<String>> list = new ArrayList<Tree<String>>(k);
  	maxRuleScores = new ArrayList<Double>(k);
  	tmp_k = 0;
  	for (int i=0; i<k; i++){
    	Tree<String> tmp = extractBestMaxRuleParse1(start, end, 0, i, sentence);
    	if (tmp!=null){
    		maxRuleScores.add(chartAfterU[0][length][0].getKbest(i).score);
    	}
//      HyperEdge parentNode = chartAfterU[start][end][0].getKbest(i);
//      if (parentNode!=null) System.out.println(parentNode.score+" ");
    	if (tmp!=null) list.add(tmp);
    	else break;
    }
  	return list;

  }

  
	public double getModelScore(Tree<String> parsedTree) {
		return maxRuleScores.get(tmp_k++);
	}

  
  /**
   * Returns the (suboptimalities+1)-best parse rooted in state "state" over the given span,
   * potentially starting with a unary rule.
   */
  public Tree<String> extractBestMaxRuleParse1(int start, int end, int state, int suboptimalities, List<String> sentence ) {
  	//System.out.println(start+", "+end+";");

    HyperEdge parentNode = chartAfterU[start][end][state].getKbest(suboptimalities);
    if (parentNode==null){
    	System.err.println("Don't have a "+(suboptimalities+1)+"-best tree.");
    	return null;
    }
    int cState = parentNode.childState;
    Tree<String> result = null;
    
    HyperEdge childNode = chartBeforeU[start][end][cState].getKbest(parentNode.childBest);
    
    List<Tree<String>> children = new ArrayList<Tree<String>>();
    String stateStr = (String)tagNumberer.object(cState);//+""+start+""+end;
    if (stateStr.endsWith("^g")) stateStr = stateStr.substring(0,stateStr.length()-2);

  	boolean posLevel = (end - start == 1);
    if (posLevel) {
    //        List> childs = new ArrayList>();
//        childs.add(new Tree(sentence.get(start)));
//        String stateStr2 = (String)tagNumberer.object(childNode.parentState);//+""+start+""+end;
//        children.add(new Tree(stateStr2,childs));
//    	}
//    	else { 
    		children.add(new Tree<String>(sentence.get(start)));
//    	}
    } else {
      int split = childNode.split;
      if (split == -1) {
        System.err.println("Warning: no symbol can generate the span from "+ start+ " to "+end+".");
        System.err.println("The score is "+maxcScore[start][end][state]+" and the state is supposed to be "+stateStr);
        System.err.println("The insideScores are "+Arrays.toString(iScore[start][end][state])+" and the outsideScores are " +Arrays.toString(oScore[start][end][state]));
        System.err.println("The maxcScore is "+maxcScore[start][end][state]);
        //return  extractBestMaxRuleParse2(start, end, maxcChild[start][end][state], sentence);
        return new Tree<String>("ROOT");
      }
      int lState = childNode.lChildState;
      int rState = childNode.rChildState;
      Tree leftChildTree = extractBestMaxRuleParse1(start, split, lState, childNode.lChildBest, sentence);
      Tree rightChildTree = extractBestMaxRuleParse1(split, end, rState, childNode.rChildBest, sentence);
      children.add(leftChildTree);
      children.add(rightChildTree);
    }
    
    boolean scale = false;
    updateConstrainedMaxCScores(sentence, scale, childNode);

    result = new Tree<String>(stateStr, children);
    if (cState != state){ // unaryRule
      stateStr = (String)tagNumberer.object(state);//+""+start+""+end;
      if (stateStr.endsWith("^g")) stateStr = stateStr.substring(0,stateStr.length()-2);

    	 int intermediateNode = grammar.getUnaryIntermediate((short)state,(short)cState);
	    if (intermediateNode>0){
	      List<Tree<String>> restoredChild = new ArrayList<Tree<String>>();
	      String stateStr2 = (String)tagNumberer.object(intermediateNode);
	      if (stateStr2.endsWith("^g")) stateStr2 = stateStr2.substring(0,stateStr2.length()-2);
	      restoredChild.add(result);
	    	result = new Tree<String>(stateStr2,restoredChild);
	    }
	    List<Tree<String>> childs = new ArrayList<Tree<String>>();
	    childs.add(result);
	    result = new Tree<String>(stateStr,childs);
    }
    updateConstrainedMaxCScores(sentence, scale, parentNode);
    
    return result; 
  }


  void updateConstrainedMaxCScores(List<String> sentence, final boolean scale, HyperEdge parent) {
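  	// Lazy k-best expansion: once the hyperedge "parent" has been used in a derivation,
  	// push its successor hyperedges (the same rule application with one child's rank
  	// increased by one) onto the fringe of the corresponding LazyList, so that the
  	// next-best derivation for this span and state can be generated on demand.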

    int start = parent.start;
    int end = parent.end;
    int pState = parent.parentState;
    int suboptimalities = parent.parentBest + 1;
    double ruleScore = parent.ruleScore;
    
		if (parent.alreadyExpanded) 
			return;
		
    if (!parent.isUnary) {
//      if (chartBeforeU[start][end][pState].sortedListSize() >= suboptimalities) return; // already have enough derivations

      int lState = parent.lChildState;
      int rState = parent.rChildState;
      int split = parent.split;

      HyperEdge newParentL = null, newParentR = null;
     	if (split-start>1) { // left is not a POS
     		int lBest = parent.lChildBest+1;
     		HyperEdge lChild = chartAfterU[start][split][lState].getKbest(lBest);
     		if (lChild!=null){
       		int rBest = parent.rChildBest;
       		HyperEdge rChild = chartAfterU[split][end][rState].getKbest(rBest);
       		double newScore = lChild.score + rChild.score + ruleScore;
       		newParentL = new HyperEdge(pState,lState,rState,suboptimalities,lBest,rBest,start,split,end,newScore,ruleScore);
//     			chartBeforeU[start][end][pState].addToFringe(newParentL); 
     		}
     	}
     	if (end-split>1){
     		int rBest = parent.rChildBest+1;
     		HyperEdge rChild = chartAfterU[split][end][rState].getKbest(rBest);
     		if (rChild!=null){
       		int lBest = parent.lChildBest;
       		HyperEdge lChild = chartAfterU[start][split][lState].getKbest(lBest);
       		double newScore = lChild.score + rChild.score + ruleScore;
     			newParentR = new HyperEdge(pState,lState,rState,suboptimalities,lBest,rBest,start,split,end,newScore,ruleScore);
//     			chartBeforeU[start][end][pState].addToFringe(newParentR);
     		}
     	}
     	
 			if (newParentL!=null && newParentR!=null && newParentL.score > newParentR.score) chartBeforeU[start][end][pState].addToFringe(newParentL);
 			else if (newParentL!=null && newParentR!=null) chartBeforeU[start][end][pState].addToFringe(newParentR);
 			else if (newParentL!=null || newParentR!=null){
 				if (newParentL!=null) chartBeforeU[start][end][pState].addToFringe(newParentL);
 				else /*newParentR!=null*/ chartBeforeU[start][end][pState].addToFringe(newParentR);
 			}
 			parent.alreadyExpanded = true;
 				
//      chartBeforeU[start][end][pState].expandNextBest();
    }
    else { // unary
//      if (chartAfterU[start][end][pState].sortedListSize() >= suboptimalities) return; // already have enough derivations
      
      int cState = parent.childState;
   		int cBest = parent.childBest+1;

   		if (end-start>1){
	   		HyperEdge child = chartBeforeU[start][end][cState].getKbest(cBest);
	   		if (child!=null){
	     		double newScore = child.score + ruleScore;
	   			HyperEdge newParent = new HyperEdge(pState,cState,suboptimalities,cBest,start,end,newScore,ruleScore);
//         	if (newScore>=parent.score)
//         		System.out.println("ullala");
       		chartAfterU[start][end][pState].addToFringe(newParent);
	   		}
	   		parent.alreadyExpanded = true;
//	   		chartAfterU[start][end][pState].expandNextBest();
   		}
    }
  }

  
  /**
   * Runs the coarse-to-fine cascade on the sentence and returns the k best max-rule parses.
   */
  public List<Tree<String>> getKBestConstrainedParses(List<String> sentence, List<String> posTags, int k) {
  	if (sentence.size()==0) {
  		ArrayList<Tree<String>> result = new ArrayList<Tree<String>>();
  		result.add(new Tree<String>("ROOT"));
  		return result;
  	}
  	doPreParses(sentence,null,false,posTags);
  	List<Tree<String>> bestTrees = null;
  	double score = 0;
  	//bestTree = extractBestViterbiParse(0, 0, 0, length, sentence);
  	//score = viScore[0][length][0];
    if (true){//score != Double.NEGATIVE_INFINITY) {
    	//score = Math.log(score) + (100*iScale[0][length][0]);
      //System.out.println("\nFound a parse for sentence with length "+length+". The LL is "+score+".");
      
      //voScore[0][length][0] = 0.0;
      //doConstrainedViterbiOutsideScores(baseGrammar); 
      
      //pruneChart(pruningThreshold, baseGrammar.numSubStates, grammar.numSubStates, true);
	  	Grammar curGrammar = grammarCascade[endLevel-startLevel+1];
	  	Lexicon curLexicon = lexiconCascade[endLevel-startLevel+1];
  		//numSubStatesArray = grammar.numSubStates;
      //clearArrays();
	    double initVal = (viterbiParse) ? Double.NEGATIVE_INFINITY : 0;
	    int level = isBaseline ? 1 : endLevel;
			createArrays(false,curGrammar.numStates,curGrammar.numSubStates,level,initVal,false);
			initializeChart(sentence,curLexicon,false,false,posTags,false);
			doConstrainedInsideScores(curGrammar,viterbiParse,viterbiParse);

	    score = iScore[0][length][0][0];
    	if (!viterbiParse) score = Math.log(score);// + (100*iScale[0][length][0]);
    	logLikelihood = score;
      if (score != Double.NEGATIVE_INFINITY) {
//      	System.out.println("\nFinally found a parse for sentence with length "+length+". The LL is "+score+".");
      	      
	      if (!viterbiParse) {
	      	oScore[0][length][0][0] = 1.0;
	      	doConstrainedOutsideScores(curGrammar,viterbiParse,false); 
		    	doConstrainedMaxCScores(sentence,curGrammar,curLexicon,false);
		    	
		    }

	      //Tree withoutRoot = extractBestMaxRuleParse(0, length, sentence);
	      // add the root
	      //ArrayList> rootChild = new ArrayList>();
	      //rootChild.add(withoutRoot);
	      //bestTree = new Tree("ROOT",rootChild);
	      
	      //System.out.print(bestTree);
      }
      else {
//      	System.out.println("Using scaling code for sentence with length "+length+".");
      	setupScaling();
				initializeChart(sentence,curLexicon,false,false,posTags,true);
				doScaledConstrainedInsideScores(curGrammar);
				score = iScore[0][length][0][0];
				if (!viterbiParse) score = Math.log(score) + (100*iScale[0][length][0]);
//				System.out.println("Finally found a parse for sentence with length "+length+". The LL is "+score+".");
//				System.out.println("Scale: "+iScale[0][length][0]);
      	oScore[0][length][0][0] = 1.0;
      	oScale[0][length][0] = 0;
      	doScaledConstrainedOutsideScores(curGrammar); 
	    	doConstrainedMaxCScores(sentence,curGrammar,curLexicon,true);
      }
      
      grammar = curGrammar;
      lexicon = curLexicon;
      bestTrees = extractKBestMaxRuleParses(0, length, sentence, k);
    } 
    return bestTrees;
  }
  
  public CoarseToFineNBestParser newInstance(){
  	CoarseToFineNBestParser newParser = new CoarseToFineNBestParser(grammar, lexicon, k, unaryPenalty, endLevel, viterbiParse, outputSub, outputScore, accurate, this.doVariational,useGoldPOS, false);
  	newParser.initCascade(this);
  	return newParser;
  }

  public synchronized Object call() {
  	List<Tree<String>> result = getKBestConstrainedParses(nextSentence, null, k);
		nextSentence = null;
		synchronized(queue) {
      queue.add(result,-nextSentenceID);
      queue.notifyAll();
		}
		return null;
	}

}