All Downloads are FREE. Search and download functionality uses the official Maven repository.

edu.berkeley.nlp.PCFGLA.SpanPredictor Maven / Gradle / Ivy

Go to download

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).

The newest version!
/**
 * 
 */
package edu.berkeley.nlp.PCFGLA;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;

import edu.berkeley.nlp.discPCFG.WordInSentence;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.Pair;
import edu.berkeley.nlp.util.PriorityQueue;

/**
 * @author petrov
 *
 */
public class SpanPredictor implements Serializable{

	private static final long serialVersionUID = 1L;

	public final boolean useFirstAndLast;
	public final boolean usePreviousAndNext; // can only be on if useFirstAndLast is on
	public final boolean useBeginAndEndPairs;
	public final boolean useSyntheticClass;
	
	public final boolean usePunctuation;
	Indexer punctuationSignatures;
	boolean[] isPunctuation;
	
	public final boolean useOnlyWords = true;
	
//	public final boolean useCapitalization;
	public final int minFeatureFrequency;
	
	public final int minSpanLength = 3;
	
	public double[][] firstWordScore;
	public double[][] lastWordScore;
	public double[][] previousWordScore;
	public double[][] nextWordScore;
	
	public double[][] beginPairScore;
	public double[][] endPairScore;
	private HashMap,Integer> beginMap;
	private HashMap,Integer> endMap;
	
	public double[][] punctuationScores;
	
	public int nWords;
	public int nFeatures;
	private int[] stateClass;
	private int nClasses;
	private Indexer wordIndexer;
	
//	public int startIndexPrevious, startIndexBegin;
	
	/**
	 * Builds a span predictor from the training trees.
	 * <p>
	 * The feature switches and {@code minFeatureFrequency} are copied from
	 * {@code ConditionalTrainer.Options}. The visible part of the body then walks the yield of
	 * every training tree, counts adjacent word pairs for the begin-pair and end-pair features,
	 * drops every pair whose count is below {@code minFeatureFrequency}, re-numbers the surviving
	 * pairs densely, allocates {@code beginPairScore}/{@code endPairScore} and adds
	 * {@code (beginPairs + endPairs) * nClasses} to {@code nFeatures}.
	 *
	 * @param nWords stored in {@link #nWords}
	 * @param trainTrees trees whose yields supply the word pairs
	 * @param tagNumberer its {@code total()} sizes {@code stateClass} when synthetic classes are enabled
	 * @param wordIndexer stored in the {@code wordIndexer} field
	 */
	public SpanPredictor(int nWords, StateSetTreeList trainTrees, Numberer tagNumberer, Indexer wordIndexer){
		this.useFirstAndLast = ConditionalTrainer.Options.useFirstAndLast;
		this.usePreviousAndNext = ConditionalTrainer.Options.usePreviousAndNext;
		this.useBeginAndEndPairs = ConditionalTrainer.Options.useBeginAndEndPairs;
		this.useSyntheticClass = ConditionalTrainer.Options.useSyntheticClass;
		this.usePunctuation = ConditionalTrainer.Options.usePunctuation;
		this.minFeatureFrequency = ConditionalTrainer.Options.minFeatureFrequency;

		this.wordIndexer = wordIndexer;
		this.nWords = nWords;
		this.nFeatures = 0;
		if (useSyntheticClass){
			System.out.println("Distinguishing between real and synthetic classes.");
			stateClass = new int[tagNumberer.total()];
			// NOTE(review): the next line appears truncated in this copy of the file (the loop over
			// states and everything up to the beginMap allocation is missing) - restore from upstream.
			for (int i=0; i, Integer>();
		endMap = new HashMap, Integer>();
		Counter> beginPairCounter = new Counter>();
		Counter> endPairCounter = new Counter>();
		int beginPairs = 0, endPairs = 0;
		for (Tree tree : trainTrees){
			List words = tree.getYield();
			StateSet stateSet = words.get(0);
			// use the signature index when there is one, otherwise the word index ...
			int prevIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
			// ... unless useOnlyWords forces the plain word index
			if (useOnlyWords) prevIndex = stateSet.wordIndex;
			int currIndex = -1;
			// begin pairs: (word i-1, word i) for i = 1 .. size - minSpanLength
			for (int i=1; i<=words.size()-minSpanLength; i++){
				stateSet = words.get(i);
				currIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
				if (useOnlyWords) currIndex = stateSet.wordIndex;
				Pair pair = new Pair(prevIndex,currIndex);
				beginPairCounter.incrementCount(pair, 1.0);
				// first sighting of a pair gets the next provisional index
				if (!beginMap.containsKey(pair))
					beginMap.put(pair,beginPairs++);
				prevIndex = currIndex;
			}
			if (words.size() < minSpanLength) continue;
			stateSet = words.get(minSpanLength-1);
			prevIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
			// NOTE(review): this assigns currIndex although prevIndex was just computed; the parallel
			// code for begin pairs above overrides prevIndex here - presumably a copy/paste slip, confirm.
			if (useOnlyWords) currIndex = stateSet.wordIndex;
			// NOTE(review): the next line appears truncated in this copy (loop condition and the
			// computation of currIndex are missing) - restore from upstream.
			for (int i=minSpanLength; i pair = new Pair(prevIndex,currIndex);
				endPairCounter.incrementCount(pair, 1.0);
				if (!endMap.containsKey(pair))
					endMap.put(pair,endPairs++);
				prevIndex = currIndex;
			}
		}
		HashMap, Integer> newBeginMap = new HashMap, Integer>();
		HashMap, Integer> newEndMap = new HashMap, Integer>();
		
		// keep only sufficiently frequent begin pairs and give them fresh dense indices
		int newBeginPairs = 0;
		for (Pair pair : beginMap.keySet()){
			if (beginPairCounter.getCount(pair) >= minFeatureFrequency){
				newBeginMap.put(pair, newBeginPairs++);
			}
		}
		beginMap = newBeginMap;
		beginPairs = newBeginPairs;
		
		// same pruning for end pairs
		int newEndPairs = 0;
		for (Pair pair : endMap.keySet()){
			if (endPairCounter.getCount(pair) >= minFeatureFrequency){
				newEndMap.put(pair, newEndPairs++);
			}
		}
		endMap = newEndMap;
		endPairs = newEndPairs;

		// one score per (pair, class)
		beginPairScore = new double[beginPairs][nClasses];
		endPairScore = new double[endPairs][nClasses];
		nFeatures += (beginPairs + endPairs)*nClasses;
		System.out.println("There were "+beginPairs+" begin-pair types and "+endPairs+" end-pair types.");
	}

	/**
	 * Computes one multiplicative score per class for a span, from the indices of the words at
	 * and around its boundaries.
	 * <p>
	 * Every entry starts at 1. When {@code firstIndex} or {@code lastIndex} is negative the
	 * all-ones array is returned as is. Otherwise the enabled feature scores are multiplied in
	 * per class; a value flagged by {@code SloppyMath.isDangerous} is reported on standard output
	 * and reset to 1.
	 *
	 * @param previousIndex presumably the index of the word preceding the span; features that use it are skipped when negative
	 * @param firstIndex presumably the index of the span's first word; negative short-circuits to all ones
	 * @param lastIndex presumably the index of the span's last word; negative short-circuits to all ones
	 * @param followingIndex presumably the index of the word following the span; features that use it are skipped when negative
	 * @return a newly allocated array of length {@code nClasses}
	 */
	public double[] scoreSpan(int previousIndex, int firstIndex, int lastIndex, int followingIndex){
		double[] result = new double[nClasses];
		Arrays.fill(result, 1);
		if (firstIndex<0||lastIndex<0) {
//			System.out.println("unseen index when scoring span: "+firstIndex+" "+lastIndex);
			return result;
		}
		// NOTE(review): the next line appears truncated in this copy of the file (the loop header
		// over classes and the code ahead of the previous-word check are missing) - restore from upstream.
		for (int c=0; c=0) result[c] *= previousWordScore[previousIndex][c];
				if (followingIndex>=0) result[c] *= nextWordScore[followingIndex][c];
			}
			if (useBeginAndEndPairs){
				if (previousIndex>=0) {
					// -1 means the pair is not in the begin-pair map, so it contributes nothing
					int index = getBeginIndex(previousIndex, firstIndex);
					if (index>=0) result[c] *= beginPairScore[index][c];
				}
				if (followingIndex>=0) {
					// -1 means the pair is not in the end-pair map, so it contributes nothing
					int index = getEndIndex(lastIndex, followingIndex);
					if (index>=0) result[c] *= endPairScore[index][c];
				}
			}
			if (SloppyMath.isDangerous(result[c])){
				// NOTE(review): concatenating the array `result` prints its default Object string,
				// not the offending value - presumably result[c] was intended; confirm.
				System.out.println("Dangerous span prediction set to 1, since it was "+result);
				result[c] = 1;
			}
		}
		return result;
	}
	
  public double[][][] predictSpans(List sentence) {
  	int previousIndex=-1, firstIndex, lastIndex, followingIndex=-1;
  	int length = sentence.size();
  	double[][][] spanScores = new double[length][length+1][nClasses]; 
		// all spans of size <=minSpanLength are ok
  	for (int start = 0; start < length; start++) {
			for (int end = start + 1; end < start+minSpanLength && end<=length; end++) {
				for (int clas=0; clas < nClasses; clas++){
					spanScores[start][end][clas] = 1;
				}
			}
		}
		for (int start = 0; start <= length-minSpanLength; start++) {
			StateSet stateSet = sentence.get(start); 
			firstIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
			if (useOnlyWords) firstIndex = stateSet.wordIndex;
			for (int end = start + minSpanLength; end <= length; end++) {
				stateSet = sentence.get(end-1); 
				lastIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
				if (useOnlyWords) lastIndex = stateSet.wordIndex;
				if (end tree : trainTrees){
  		List words = tree.getYield();
  		if (usePunctuation) punctuationSig = getPunctuationSignatures(words);
  		countGoldSpanFeaturesHelper(tree, words, firstWordCount, lastWordCount, 
  				previousWordCount, nextWordCount, beginPairsCount, endPairsCount, 
  				punctuationCount, punctuationSig);
  	}
  	
  	double[] res =  new double[nFeatures];
  	int index = 0;
  	if (useFirstAndLast){
  		int firstSum = 0, lastSum = 0;
  		for (int c=0; c tree, List words, 
  		int[][] firstWordCount, int[][] lastWordCount, int[][] previousWordCount, int[][] nextWordCount,
			int[][] beginPairsCount, int[][] endPairsCount, int[][] punctuationCount, int[][] punctuationSignatures) {
		StateSet node = tree.getLabel();
		if (node.to - node.from < minSpanLength) return;
		
		short state = node.getState();
		int thisClass = stateClass[state];

		StateSet stateSet = words.get(node.from);
		int firstWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
		if (useOnlyWords) firstWord = stateSet.wordIndex;
		stateSet = words.get(node.to-1);
		int lastWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
		if (useOnlyWords) lastWord = stateSet.wordIndex;

		int previousWord = 0, nextWord = 0;
		if (node.from > 0) {
			stateSet = words.get(node.from-1);
			previousWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
			if (useOnlyWords) previousWord = stateSet.wordIndex;
		}
		if (node.to < words.size()) {
			stateSet = words.get(node.to);
			nextWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
			if (useOnlyWords) nextWord = stateSet.wordIndex;
		}
		
		if (useFirstAndLast){
			firstWordCount[firstWord][thisClass]++;
			lastWordCount[lastWord][thisClass]++;
		}
		if (usePreviousAndNext){
			if (node.from > 0) previousWordCount[previousWord][thisClass]++;
			if (node.to < words.size()) nextWordCount[nextWord][thisClass]++;
		}
		if (useBeginAndEndPairs){
			if (node.from > 0) {
				int beginIndex = getBeginIndex(previousWord, firstWord);
				if (beginIndex>=0)
					beginPairsCount[beginIndex][thisClass]++;
			}
			if (node.to < words.size()) {
				int endIndex = getEndIndex(lastWord, nextWord);
				if (endIndex>=0)
					endPairsCount[endIndex][thisClass]++;
			}
		}
		if (usePunctuation){
			int punctSig = punctuationSignatures[node.from][node.to];
			if (punctSig>=0)
				punctuationCount[punctSig][thisClass]++;
		}
		
		
		for (Tree child : tree.getChildren()){
			countGoldSpanFeaturesHelper(child, words, firstWordCount, lastWordCount, previousWordCount, nextWordCount,
					beginPairsCount, endPairsCount, punctuationCount, punctuationSignatures);
		}
		
	}
  
  
  /**
   * Sets up the punctuation-signature features from the training trees.
   * <p>
   * Starts with a fresh {@code punctuationSignatures} indexer and an {@code isPunctuation} flag
   * array with one slot per word, collects signature counts by calling
   * {@link #getPunctuationSignatures(List, boolean, Counter)} in update mode on tree yields, and
   * then rebuilds the indexer from only those signatures whose count reaches
   * {@code minFeatureFrequency}. Finally {@code punctuationScores} is allocated with one row per
   * surviving signature, filled with 1, and {@code nFeatures} grows by
   * {@code nClasses * punctuationScores.length}.
   *
   * @param trainTrees the training trees whose yields are scanned
   */
  private void initPunctuations(StateSetTreeList trainTrees){
  	punctuationSignatures = new Indexer();
  	isPunctuation = new boolean[nWords];
  	Counter punctSigCounter = new Counter();
  	// NOTE(review): the next line appears truncated in this copy of the file (the per-word loop
  	// body and the header of the loop over trainTrees are fused) - restore from upstream.
  	for (int word=0; word tree : trainTrees){
  		getPunctuationSignatures(tree.getYield(), true, punctSigCounter);
  	}
  	
  	// prune rare signatures by re-indexing only the frequent ones
  	Indexer newPunctuationSignatures = new Indexer();
  	for (String sig : punctSigCounter.keySet()){
  		if (punctSigCounter.getCount(sig) >= minFeatureFrequency)
  			newPunctuationSignatures.add(sig);
  	}
  	punctuationSignatures = newPunctuationSignatures;
  	punctuationScores = new double[punctuationSignatures.size()][nClasses];
  	ArrayUtil.fill(punctuationScores,1);
  	nFeatures += nClasses*punctuationScores.length;
  }
  
  /**
   * Decides whether a token counts as punctuation.
   * <p>
   * A token qualifies when it is one or two characters long and none of its characters is a
   * letter or digit according to {@link Character#isLetterOrDigit(char)}.
   *
   * @param word the token to classify; may be {@code null} or empty
   * @return {@code true} for a one- or two-character token without letters or digits;
   *         {@code false} otherwise, including for {@code null} and the empty string
   */
  private boolean isPunctuation(String word){
  	// Guard first: charAt(0) would throw on an empty token and any access would on null.
  	if (word == null || word.length() == 0 || word.length() > 2) return false;
  	for (int i = 0; i < word.length(); i++){
  		if (Character.isLetterOrDigit(word.charAt(i))) return false;
  	}
  	return true;
  }
  
  /**
   * Appends one masked token to a punctuation signature under construction and returns the
   * updated length of the current run of placeholder words.
   * <p>
   * A token other than the placeholder {@code X} is appended verbatim and ends the run. The
   * first placeholder of a run is written as {@code "x"}, the second as {@code "+"}, and any
   * further placeholder adds nothing, so runs of two or more words all read {@code "x+"}.
   * <p>
   * The placeholder is recognised by value ({@code equals}) rather than by reference, so the
   * check does not depend on both strings being the same object.
   *
   * @param sb the signature being built; appended to in place
   * @param maskedWord the masked token to add
   * @param nWordsBefore number of placeholders already emitted for the current run
   * @return 0 after a non-placeholder token; otherwise the run length after this token, which
   *         stops growing once it has reached 2
   */
  private int appendItem(StringBuilder sb, String maskedWord, int nWordsBefore){
  	if (!X.equals(maskedWord)) {
  		// kept token: copy it through and start a new run
  		sb.append(maskedWord);
  		return 0;
  	}
  	if (nWordsBefore == 0){
  		sb.append("x");
  		return 1;
  	}
  	if (nWordsBefore == 1){
  		sb.append("+");
  		return 2;
  	}
  	// later words of a run are already covered by the "x+" written so far
  	return nWordsBefore;
  }
  
  /**
   * Looks up the punctuation signature index of the spans of a sentence without recording
   * anything: delegates to the three-argument overload with {@code update = false} and no counter.
   *
   * @param sentence the sentence to analyse
   * @return whatever the three-argument overload returns for this sentence in lookup mode
   */
  public int[][] getPunctuationSignatures(List sentence){
  	int[][] signatureIndices = getPunctuationSignatures(sentence, false, null);
  	return signatureIndices;
  }
  
  private final String X = "x".intern();
  /**
   * Computes, for every span of at least {@code minSpanLength} words, the index of its
   * punctuation signature.
   * <p>
   * Replace words with x and leave only punctuation, collapse xx,xxx,xxxx,... to x+.
   * The signature of a span is built as a string from the masked tokens: the token before the
   * span (when there is one), a {@code "["}, the masked tokens of the span, and a {@code "]"}.
   * The result cell {@code [start][end]} receives {@code punctuationSignatures.indexOf(sig)};
   * all other cells keep the initial value -1.
   * <p>
   * When {@code update} is true each signature is also added to {@code punctuationSignatures}
   * and its count in {@code punctSigCounter} is incremented by 1.
   * <p>
   * NOTE(review): two lines of this method appear truncated in this copy of the file - the
   * masking loop ({@code for (int i=0; ...}) and the handling after {@code sb.append("]")}
   * ({@code if (end...}). Restore them from upstream before building.
   * NOTE(review): {@code sb.append("")} for {@code start<=1} appends nothing as written; the
   * literal may have lost its content in this copy - confirm.
   *
   * @param sentence the sentence to analyse
   * @param update whether to register and count the signatures encountered
   * @param punctSigCounter counter that is incremented in update mode; not touched otherwise
   * @return an array of size {@code [length][length+1]} indexed by span start and end
   */
  public int[][] getPunctuationSignatures(List sentence, boolean update, Counter punctSigCounter){
  	int length = sentence.size();
  	String[] masked = new String[length];
  	for (int i=0; i0&&isPunctuation[thisStateSet.wordIndex]) ? thisStateSet.getWord() : X;
  	} 
  	
  	int[][] result = new int[length][length+1];
  	ArrayUtil.fill(result, -1);
		for (int start = 0; start <= length-minSpanLength; start++) {
			StringBuilder sb = new StringBuilder();
			String prev = "";
			if (start<=1) sb.append("");
  		int nWordsBefore = 0;
  		if (start>0){
  			appendItem(sb, masked[start-1], nWordsBefore);
  		}
  		sb.append("[");
			nWordsBefore = appendItem(sb, masked[start], 0);
			for (int end = start + minSpanLength; end <= length; end++) {
				nWordsBefore = appendItem(sb, masked[end-1], nWordsBefore);
				prev = sb.toString();
				sb.append("]");
				if (end");
				}
				String sig = sb.toString();
				if (update) {
					punctuationSignatures.add(sig);
					punctSigCounter.incrementCount(sig, 1.0);
				}
	  		result[start][end] = punctuationSignatures.indexOf(sig);
	  		sb = new StringBuilder(prev);
  		}
  	}
  	return result;
  }
  
  
  /**
   * Renders this predictor as text by delegating to {@code toString(Indexer)} with a
   * {@code null} indexer.
   *
   * @return the string produced by the one-argument overload
   */
  @Override
  public String toString(){
  	final String rendered = toString(null);
  	return rendered;
  }
  
  
  public String toString(Indexer wordIndexer){
  	StringBuffer sb = new StringBuffer();
  	if (useFirstAndLast||usePreviousAndNext){
  	sb.append("word");
		if (useFirstAndLast) sb.append("\tfirst\t\tlast\t");
		if (usePreviousAndNext) sb.append("\tprevious\tfollowing");
		sb.append("\n");
		
		for (int word=0; word pQf = new PriorityQueue();
  		PriorityQueue pQl = new PriorityQueue();
  		PriorityQueue pQp = null;
  		PriorityQueue pQn = null;
  		if (usePreviousAndNext){
    		pQp = new PriorityQueue();
    		pQn = new PriorityQueue();
    	}
    	for (int word=0; word pQb = new PriorityQueue();
  		PriorityQueue pQe = new PriorityQueue();
    	for (Pair p : beginMap.keySet()){
  			String w1 = wordIndexer.get((Integer)p.getFirst());
  			String w2 = wordIndexer.get((Integer)p.getSecond());
  			pQb.add("("+w1+" | "+w2+"),", beginPairScore[beginMap.get(p)][0]);
    	}
    	for (Pair p : endMap.keySet()){
  			String w1 = wordIndexer.get((Integer)p.getFirst());
  			String w2 = wordIndexer.get((Integer)p.getSecond());
  			pQe.add("("+w1+" | "+w2+"),", endPairScore[endMap.get(p)][0]);
    	}
  		while (pQb.hasNext()||pQe.hasNext()){
  			double weight = 0;
  			if (pQb.hasNext()){
  				weight = pQb.getPriority();
  				sb.append(pQb.next()+" "+weight+"\t");
  			} else sb.append("\t\t\t\t");
  			if (pQe.hasNext()){
	  			weight = pQe.getPriority();
	  			sb.append(pQe.next()+" "+weight+"\n");
  			} else sb.append("\n");
      }
  	}
  	if (usePunctuation){
  		sb.append("Punctuation features:\n");
  		PriorityQueue pQp = new PriorityQueue();
    	for (int f=0; f pair = new Pair(previousIndex, currIndex);
		if (!beginMap.containsKey(pair)) return -1;
		return beginMap.get(pair);
	}
  
	/**
	 * Looks up the feature index of an end pair.
	 *
	 * @param previousIndex first component of the pair used as the lookup key
	 * @param currIndex second component of the pair used as the lookup key
	 * @return the index stored in the end-pair map for this pair, or -1 when the pair is not a key
	 */
	public int getEndIndex(int previousIndex, int currIndex) {
		final Pair key = new Pair(previousIndex, currIndex);
		return endMap.containsKey(key) ? endMap.get(key) : -1;
	}

	/**
	 * Returns the array that maps a state to its class.
	 * <p>
	 * The array held by this object is handed out directly, not a copy.
	 *
	 * @return the stateClass
	 */
	public int[] getStateClass() {
		return this.stateClass;
	}

	/**
	 * Returns the number of classes, which is also the length of the per-class score arrays
	 * this predictor allocates.
	 *
	 * @return the nClasses
	 */
	public final int getNClasses() {
		return this.nClasses;
	}
  
//  public class FeatureBundle{
//  	public int firstWord;
//  	public int lastWord;
//  	public int previousWord;
//  	public int nextWord;
//  	
//  	public int beginPair;
//  	public int endPair;
//  	
//
//  }
  
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy