![JAR search and dependency download from the Maven repository](/logo.png)
edu.berkeley.nlp.PCFGLA.SpanPredictor Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of berkeleyparser Show documentation
Show all versions of berkeleyparser Show documentation
The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
The newest version!
/**
*
*/
package edu.berkeley.nlp.PCFGLA;
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import edu.berkeley.nlp.discPCFG.WordInSentence;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.Pair;
import edu.berkeley.nlp.util.PriorityQueue;
/**
* @author petrov
*
*/
public class SpanPredictor implements Serializable{
private static final long serialVersionUID = 1L;
public final boolean useFirstAndLast;
public final boolean usePreviousAndNext; // can only be on if useFirstAndLast is on
public final boolean useBeginAndEndPairs;
public final boolean useSyntheticClass;
public final boolean usePunctuation;
Indexer punctuationSignatures;
boolean[] isPunctuation;
public final boolean useOnlyWords = true;
// public final boolean useCapitalization;
public final int minFeatureFrequency;
public final int minSpanLength = 3;
public double[][] firstWordScore;
public double[][] lastWordScore;
public double[][] previousWordScore;
public double[][] nextWordScore;
public double[][] beginPairScore;
public double[][] endPairScore;
private HashMap,Integer> beginMap;
private HashMap,Integer> endMap;
public double[][] punctuationScores;
public int nWords;
public int nFeatures;
private int[] stateClass;
private int nClasses;
private Indexer wordIndexer;
// public int startIndexPrevious, startIndexBegin;
public SpanPredictor(int nWords, StateSetTreeList trainTrees, Numberer tagNumberer, Indexer wordIndexer){
this.useFirstAndLast = ConditionalTrainer.Options.useFirstAndLast;
this.usePreviousAndNext = ConditionalTrainer.Options.usePreviousAndNext;
this.useBeginAndEndPairs = ConditionalTrainer.Options.useBeginAndEndPairs;
this.useSyntheticClass = ConditionalTrainer.Options.useSyntheticClass;
this.usePunctuation = ConditionalTrainer.Options.usePunctuation;
this.minFeatureFrequency = ConditionalTrainer.Options.minFeatureFrequency;
this.wordIndexer = wordIndexer;
this.nWords = nWords;
this.nFeatures = 0;
if (useSyntheticClass){
System.out.println("Distinguishing between real and synthetic classes.");
stateClass = new int[tagNumberer.total()];
for (int i=0; i, Integer>();
endMap = new HashMap, Integer>();
Counter> beginPairCounter = new Counter>();
Counter> endPairCounter = new Counter>();
int beginPairs = 0, endPairs = 0;
for (Tree tree : trainTrees){
List words = tree.getYield();
StateSet stateSet = words.get(0);
int prevIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) prevIndex = stateSet.wordIndex;
int currIndex = -1;
for (int i=1; i<=words.size()-minSpanLength; i++){
stateSet = words.get(i);
currIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) currIndex = stateSet.wordIndex;
Pair pair = new Pair(prevIndex,currIndex);
beginPairCounter.incrementCount(pair, 1.0);
if (!beginMap.containsKey(pair))
beginMap.put(pair,beginPairs++);
prevIndex = currIndex;
}
if (words.size() < minSpanLength) continue;
stateSet = words.get(minSpanLength-1);
prevIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) currIndex = stateSet.wordIndex;
for (int i=minSpanLength; i pair = new Pair(prevIndex,currIndex);
endPairCounter.incrementCount(pair, 1.0);
if (!endMap.containsKey(pair))
endMap.put(pair,endPairs++);
prevIndex = currIndex;
}
}
HashMap, Integer> newBeginMap = new HashMap, Integer>();
HashMap, Integer> newEndMap = new HashMap, Integer>();
int newBeginPairs = 0;
for (Pair pair : beginMap.keySet()){
if (beginPairCounter.getCount(pair) >= minFeatureFrequency){
newBeginMap.put(pair, newBeginPairs++);
}
}
beginMap = newBeginMap;
beginPairs = newBeginPairs;
int newEndPairs = 0;
for (Pair pair : endMap.keySet()){
if (endPairCounter.getCount(pair) >= minFeatureFrequency){
newEndMap.put(pair, newEndPairs++);
}
}
endMap = newEndMap;
endPairs = newEndPairs;
beginPairScore = new double[beginPairs][nClasses];
endPairScore = new double[endPairs][nClasses];
nFeatures += (beginPairs + endPairs)*nClasses;
System.out.println("There were "+beginPairs+" begin-pair types and "+endPairs+" end-pair types.");
}
/**
 * Computes a multiplicative score per state class for one candidate span.
 * <p>
 * The result holds {@code nClasses} entries, all starting at 1. If
 * {@code firstIndex} or {@code lastIndex} is negative, that all-ones array is
 * returned as is. Otherwise each entry is multiplied by the feature scores that
 * apply, and an entry for which {@code SloppyMath.isDangerous} returns true is
 * reset to 1.
 * <p>
 * NOTE(review): the header of the per-class loop is damaged in this copy of the
 * file (text is missing between {@code c} and {@code =0)}), so which factors are
 * applied at the top of the loop cannot be confirmed from here.
 *
 * @param previousIndex index used for {@code previousWordScore} and as the first
 *        half of the begin pair; the begin-pair lookup is skipped when negative
 * @param firstIndex index of the span's first word; also the second half of the begin pair
 * @param lastIndex index of the span's last word; also the first half of the end pair
 * @param followingIndex index used for {@code nextWordScore} and as the second
 *        half of the end pair; both are skipped when negative
 * @return an array of length {@code nClasses} with one score per class
 */
public double[] scoreSpan(int previousIndex, int firstIndex, int lastIndex, int followingIndex){
double[] result = new double[nClasses];
Arrays.fill(result, 1);
// Without both boundary words there is nothing to score: return the neutral all-ones array.
if (firstIndex<0||lastIndex<0) {
// System.out.println("unseen index when scoring span: "+firstIndex+" "+lastIndex);
return result;
}
for (int c=0; c=0) result[c] *= previousWordScore[previousIndex][c];
if (followingIndex>=0) result[c] *= nextWordScore[followingIndex][c];
}
if (useBeginAndEndPairs){
if (previousIndex>=0) {
// getBeginIndex returns -1 for pairs that have no feature; those contribute nothing.
int index = getBeginIndex(previousIndex, firstIndex);
if (index>=0) result[c] *= beginPairScore[index][c];
}
if (followingIndex>=0) {
// Same convention as above: -1 means the end pair has no feature.
int index = getEndIndex(lastIndex, followingIndex);
if (index>=0) result[c] *= endPairScore[index][c];
}
}
if (SloppyMath.isDangerous(result[c])){
// NOTE(review): this concatenates the array reference, so the message shows the
// array's default toString() rather than the offending value result[c].
System.out.println("Dangerous span prediction set to 1, since it was "+result);
result[c] = 1;
}
}
return result;
}
public double[][][] predictSpans(List sentence) {
int previousIndex=-1, firstIndex, lastIndex, followingIndex=-1;
int length = sentence.size();
double[][][] spanScores = new double[length][length+1][nClasses];
// all spans of size <=minSpanLength are ok
for (int start = 0; start < length; start++) {
for (int end = start + 1; end < start+minSpanLength && end<=length; end++) {
for (int clas=0; clas < nClasses; clas++){
spanScores[start][end][clas] = 1;
}
}
}
for (int start = 0; start <= length-minSpanLength; start++) {
StateSet stateSet = sentence.get(start);
firstIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) firstIndex = stateSet.wordIndex;
for (int end = start + minSpanLength; end <= length; end++) {
stateSet = sentence.get(end-1);
lastIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) lastIndex = stateSet.wordIndex;
if (end tree : trainTrees){
List words = tree.getYield();
if (usePunctuation) punctuationSig = getPunctuationSignatures(words);
countGoldSpanFeaturesHelper(tree, words, firstWordCount, lastWordCount,
previousWordCount, nextWordCount, beginPairsCount, endPairsCount,
punctuationCount, punctuationSig);
}
double[] res = new double[nFeatures];
int index = 0;
if (useFirstAndLast){
int firstSum = 0, lastSum = 0;
for (int c=0; c tree, List words,
int[][] firstWordCount, int[][] lastWordCount, int[][] previousWordCount, int[][] nextWordCount,
int[][] beginPairsCount, int[][] endPairsCount, int[][] punctuationCount, int[][] punctuationSignatures) {
StateSet node = tree.getLabel();
if (node.to - node.from < minSpanLength) return;
short state = node.getState();
int thisClass = stateClass[state];
StateSet stateSet = words.get(node.from);
int firstWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) firstWord = stateSet.wordIndex;
stateSet = words.get(node.to-1);
int lastWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) lastWord = stateSet.wordIndex;
int previousWord = 0, nextWord = 0;
if (node.from > 0) {
stateSet = words.get(node.from-1);
previousWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) previousWord = stateSet.wordIndex;
}
if (node.to < words.size()) {
stateSet = words.get(node.to);
nextWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) nextWord = stateSet.wordIndex;
}
if (useFirstAndLast){
firstWordCount[firstWord][thisClass]++;
lastWordCount[lastWord][thisClass]++;
}
if (usePreviousAndNext){
if (node.from > 0) previousWordCount[previousWord][thisClass]++;
if (node.to < words.size()) nextWordCount[nextWord][thisClass]++;
}
if (useBeginAndEndPairs){
if (node.from > 0) {
int beginIndex = getBeginIndex(previousWord, firstWord);
if (beginIndex>=0)
beginPairsCount[beginIndex][thisClass]++;
}
if (node.to < words.size()) {
int endIndex = getEndIndex(lastWord, nextWord);
if (endIndex>=0)
endPairsCount[endIndex][thisClass]++;
}
}
if (usePunctuation){
int punctSig = punctuationSignatures[node.from][node.to];
if (punctSig>=0)
punctuationCount[punctSig][thisClass]++;
}
for (Tree child : tree.getChildren()){
countGoldSpanFeaturesHelper(child, words, firstWordCount, lastWordCount, previousWordCount, nextWordCount,
beginPairsCount, endPairsCount, punctuationCount, punctuationSignatures);
}
}
/**
 * Builds the punctuation-signature feature set from the training trees.
 * <p>
 * Allocates {@code isPunctuation} with one flag per word, collects signature
 * counts by calling {@code getPunctuationSignatures} in update mode on every
 * tree's yield, keeps only signatures whose count is at least
 * {@code minFeatureFrequency}, and then allocates {@code punctuationScores}
 * (one row per kept signature, one column per class) filled with 1.
 * {@code nFeatures} grows by the number of cells in that table.
 * <p>
 * NOTE(review): the line that starts the {@code word} loop is damaged in this
 * copy of the file; how {@code isPunctuation} gets populated is not visible here.
 *
 * @param trainTrees training trees whose yields supply the signatures
 */
private void initPunctuations(StateSetTreeList trainTrees){
punctuationSignatures = new Indexer();
isPunctuation = new boolean[nWords];
Counter punctSigCounter = new Counter();
for (int word=0; word tree : trainTrees){
// update == true: every signature seen is added to the indexer and counted.
getPunctuationSignatures(tree.getYield(), true, punctSigCounter);
}
// Rebuild the indexer so that only sufficiently frequent signatures keep an index.
Indexer newPunctuationSignatures = new Indexer();
for (String sig : punctSigCounter.keySet()){
if (punctSigCounter.getCount(sig) >= minFeatureFrequency)
newPunctuationSignatures.add(sig);
}
punctuationSignatures = newPunctuationSignatures;
punctuationScores = new double[punctuationSignatures.size()][nClasses];
ArrayUtil.fill(punctuationScores,1);
nFeatures += nClasses*punctuationScores.length;
}
/**
 * Decides whether a token is treated as punctuation.
 * <p>
 * A token qualifies when it is one or two characters long and none of its
 * characters is a letter or digit according to
 * {@link Character#isLetterOrDigit(char)}. A {@code null} or empty token is
 * never punctuation.
 *
 * @param word the token text to classify; may be {@code null} or empty
 * @return {@code true} if the token has one or two characters and none of them
 *         is a letter or digit, {@code false} otherwise
 */
private boolean isPunctuation(String word){
    // Guard first so that the charAt calls below are always in range.
    if (word == null || word.isEmpty()) return false;
    if (word.length() > 2) return false;
    if (Character.isLetterOrDigit(word.charAt(0))) return false;
    if (word.length() == 1) return true;
    return !Character.isLetterOrDigit(word.charAt(1));
}
/**
 * Appends one token of a punctuation signature to {@code sb}, collapsing runs
 * of masked words.
 * <p>
 * A token that differs from the mask placeholder {@code X} is appended verbatim
 * and resets the run counter to 0. For the placeholder, the first occurrence in
 * a run appends {@code "x"}, the second appends {@code "+"}, and any further
 * occurrence appends nothing, so a run of masked words becomes {@code "x"} or
 * {@code "x+"}.
 *
 * @param sb builder that receives the signature text
 * @param maskedWord the token to append: either kept text or the placeholder
 * @param nWordsBefore how many placeholders of the current run were already
 *        handled; it stops growing once it reaches 2
 * @return the updated run counter to pass to the next call
 */
private int appendItem(StringBuilder sb, String maskedWord, int nWordsBefore){
    // Compare by value: a reference comparison is only correct while every
    // placeholder is the very same String instance as X.
    if (!X.equals(maskedWord)) {
        sb.append(maskedWord);
        nWordsBefore = 0;
    } else if (nWordsBefore == 0) {
        sb.append("x");
        nWordsBefore++;
    } else if (nWordsBefore == 1) {
        sb.append("+");
        nWordsBefore++;
    }
    return nWordsBefore;
}
/**
 * Looks up the punctuation signatures of a sentence in read-only mode.
 * <p>
 * Delegates to the three-argument overload with {@code update} set to
 * {@code false} and no counter, so nothing is added to the signature indexer.
 *
 * @param sentence the sentence whose spans are looked up
 * @return the table produced by the three-argument overload for this sentence
 */
public int[][] getPunctuationSignatures(List sentence){
    final int[][] signatureIds = getPunctuationSignatures(sentence, false, null);
    return signatureIds;
}
// Placeholder stored in the masked sentence for every token that is not kept as punctuation text.
private final String X = "x".intern();
/**
 * Computes, for every span of at least {@code minSpanLength} tokens, the index
 * of its punctuation signature.
 * <p>
 * Words are replaced with x and only punctuation is left; runs such as
 * xx, xxx, xxxx, ... collapse to x+. The signature of a span is built from the
 * masked token before the span (if any), a {@code "["}, masked tokens of the
 * span, and a {@code "]"}.
 * <p>
 * NOTE(review): three lines of this method are damaged in this copy of the file
 * (the masking loop header, the empty literal appended when {@code start<=1},
 * and the statement after {@code "]"} is appended). What follows the closing
 * bracket in a signature cannot be confirmed from here.
 *
 * @param sentence the sentence to process
 * @param update when {@code true}, each signature is added to
 *        {@code punctuationSignatures} and counted in {@code punctSigCounter}
 * @param punctSigCounter counter incremented once per signature occurrence;
 *        only used when {@code update} is {@code true}
 * @return a {@code length x (length+1)} table initialised to -1 in which
 *         {@code [start][end]} holds {@code punctuationSignatures.indexOf(sig)}
 *         for each span that was processed
 */
public int[][] getPunctuationSignatures(List sentence, boolean update, Counter punctSigCounter){
int length = sentence.size();
String[] masked = new String[length];
for (int i=0; i0&&isPunctuation[thisStateSet.wordIndex]) ? thisStateSet.getWord() : X;
}
int[][] result = new int[length][length+1];
ArrayUtil.fill(result, -1);
for (int start = 0; start <= length-minSpanLength; start++) {
StringBuilder sb = new StringBuilder();
String prev = "";
if (start<=1) sb.append("");
int nWordsBefore = 0;
if (start>0){
// The returned counter is dropped here; the next appendItem call passes 0 explicitly.
appendItem(sb, masked[start-1], nWordsBefore);
}
sb.append("[");
nWordsBefore = appendItem(sb, masked[start], 0);
// NOTE(review): the first iteration appends masked[start+minSpanLength-1] directly
// after masked[start]; the tokens in between are not appended anywhere in the
// visible code. Confirm this is intended.
for (int end = start + minSpanLength; end <= length; end++) {
nWordsBefore = appendItem(sb, masked[end-1], nWordsBefore);
// Remember the signature up to and including the last span token, so the next,
// longer span can continue from it without the closing part.
prev = sb.toString();
sb.append("]");
if (end");
}
String sig = sb.toString();
if (update) {
punctuationSignatures.add(sig);
punctSigCounter.incrementCount(sig, 1.0);
}
result[start][end] = punctuationSignatures.indexOf(sig);
sb = new StringBuilder(prev);
}
}
return result;
}
/**
 * Renders this predictor as text by delegating to {@code toString(Indexer)}
 * with a {@code null} indexer.
 *
 * @return the text produced by {@code toString(Indexer)} for a {@code null} argument
 */
@Override
public String toString(){
    final String rendered = toString(null);
    return rendered;
}
public String toString(Indexer wordIndexer){
StringBuffer sb = new StringBuffer();
if (useFirstAndLast||usePreviousAndNext){
sb.append("word");
if (useFirstAndLast) sb.append("\tfirst\t\tlast\t");
if (usePreviousAndNext) sb.append("\tprevious\tfollowing");
sb.append("\n");
for (int word=0; word pQf = new PriorityQueue();
PriorityQueue pQl = new PriorityQueue();
PriorityQueue pQp = null;
PriorityQueue pQn = null;
if (usePreviousAndNext){
pQp = new PriorityQueue();
pQn = new PriorityQueue();
}
for (int word=0; word pQb = new PriorityQueue();
PriorityQueue pQe = new PriorityQueue();
for (Pair p : beginMap.keySet()){
String w1 = wordIndexer.get((Integer)p.getFirst());
String w2 = wordIndexer.get((Integer)p.getSecond());
pQb.add("("+w1+" | "+w2+"),", beginPairScore[beginMap.get(p)][0]);
}
for (Pair p : endMap.keySet()){
String w1 = wordIndexer.get((Integer)p.getFirst());
String w2 = wordIndexer.get((Integer)p.getSecond());
pQe.add("("+w1+" | "+w2+"),", endPairScore[endMap.get(p)][0]);
}
while (pQb.hasNext()||pQe.hasNext()){
double weight = 0;
if (pQb.hasNext()){
weight = pQb.getPriority();
sb.append(pQb.next()+" "+weight+"\t");
} else sb.append("\t\t\t\t");
if (pQe.hasNext()){
weight = pQe.getPriority();
sb.append(pQe.next()+" "+weight+"\n");
} else sb.append("\n");
}
}
if (usePunctuation){
sb.append("Punctuation features:\n");
PriorityQueue pQp = new PriorityQueue();
for (int f=0; f pair = new Pair(previousIndex, currIndex);
if (!beginMap.containsKey(pair)) return -1;
return beginMap.get(pair);
}
/**
 * Looks up the feature index of an end pair.
 *
 * @param previousIndex first element of the pair
 * @param currIndex second element of the pair
 * @return the index stored in {@code endMap} for the pair, or -1 if the pair
 *         has no entry
 */
public int getEndIndex(int previousIndex, int currIndex) {
    // One lookup instead of containsKey + get; the map never stores null values.
    final Integer slot = endMap.get(new Pair(previousIndex, currIndex));
    if (slot == null) {
        return -1;
    }
    return slot;
}
/**
 * Exposes the state-to-class table.
 *
 * @return the {@code stateClass} array itself, not a copy
 */
public int[] getStateClass() {
    final int[] classes = this.stateClass;
    return classes;
}
/**
 * Reports how many state classes this predictor distinguishes.
 *
 * @return the current value of {@code nClasses}
 */
public final int getNClasses() {
    final int classCount = this.nClasses;
    return classCount;
}
// public class FeatureBundle{
// public int firstWord;
// public int lastWord;
// public int previousWord;
// public int nextWord;
//
// public int beginPair;
// public int endPair;
//
//
// }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy