![JAR search and dependency download from the Maven repository](/logo.png)
edu.berkeley.nlp.PCFGLA.CoarseToFineMaxRuleParser Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of berkeleyparser Show documentation
The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
The newest version!
package edu.berkeley.nlp.PCFGLA;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.zip.GZIPOutputStream;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.ScalingTools;
/**
*
* @author Slav Petrov
*
* SHOULD BE CLEANED UP!!!
* AND PROBABLY ALSO RENAMED SINCE IT CAN COMPUTE VITERBI PARSES AS WELL
*
* An extension of ConstrainedArrayParser that computes the scores P(w_{i:j}|A), whose
* computation involves a sum, rather than the Viterbi scores, which involve a max.
* This is used by the Labelled Recall parser (maximizes the expected number of correct
* symbols) and the Max-Rule parser (maximizes the expected number of correct rules, i.e.,
* all three symbols correct).
*
*/
public class CoarseToFineMaxRuleParser extends ConstrainedArrayParser{
// Pruning masks produced by the coarse-to-fine passes:
// allowedSubStates[start][end][state][substate] -> is this substate still allowed
boolean[][][][] allowedSubStates;
// allowedStates[start][end][state] -> is this category still allowed
boolean[][][] allowedStates;
// vAllowedStates[start][end] -> is the span allowed at all; written by the
// pruning pass for levels < 0 and consulted by the level-0 pruning pass
boolean[][] vAllowedStates;
// Only referenced from commented-out code in the visible source.
double[][] spanMass;
// One grammar/lexicon per cascade level, indexed by (level - startLevel);
// the final grammar/lexicon live at index (endLevel - startLevel + 1).
Grammar[] grammarCascade;
Lexicon[] lexiconCascade;
// Substate child maps between cascade levels, filled in by the constructor.
int[][][] lChildMap;
int[][][] rChildMap;
// First and last level of the coarse-to-fine cascade.
int startLevel;
int endLevel;
// protected double[][][][] iScore;
// outside scores; start idx, end idx, state -> logProb
// protected double[][][][] oScore;
// protected short[] numSubStatesArray;
double[] maxThresholds;
// Root log-likelihood recorded by the last call to getBestConstrainedParse.
double logLikelihood;
// Tree returned by the last call to getBestConstrainedParse.
Tree bestTree;
// When true, pre-parsing also runs at endLevel and the final parse uses level 1.
boolean isBaseline;
// When true, max-rule scores are normalized by the parent's posterior mass.
protected final boolean doVariational;
// Viterbi inside (viScore) and outside (voScore) scores for the coarse levels
// (< 1), which carry no substates.
protected double[][][] viScore; // start idx, end idx, state -> logProb
protected double[][][] voScore; // start idx, end idx, state -> logProb
// Max-rule score of the root span from the last max-rule parse.
protected double savedScore;
// maxcScore does not have substate information since these are marginalized out
protected double[][][] maxcScore; // start, end, state --> logProb
protected double[][][] maxsScore; // start, end, state --> logProb
// Back-pointers used to extract the max-rule parse (-1 = none):
protected int[][][] maxcSplit; // start, end, state -> split position
protected int[][][] maxcChild; // start, end, state -> unary child (if any)
protected int[][][] maxcLeftChild; // start, end, state -> left child
protected int[][][] maxcRightChild; // start, end, state -> right child
// Set by the constructor.
protected double unaryPenalty;
int nLevels;
// grammarTags[state]: whether the state is considered on spans longer than one
// word; also the initial allowedStates mask of every span.
final boolean[] grammarTags;
// If true, extract Viterbi parses instead of max-rule parses.
final boolean viterbiParse;
final boolean outputSub;
final boolean outputScore;
Numberer wordNumberer = Numberer.getGlobalNumberer("words");
// Selects the accurate (instead of fast) pruning thresholds.
final boolean accurate;
// Use supplied gold POS tags when initializing the chart.
final boolean useGoldPOS;
// Scratch buffer sized to maxNSubStates by the constructor.
double[] unscaledScoresToAdd;
// Lazily created parser used by getLogLikelihood(Tree).
ArrayParser llParser;
List posteriorsToDump;
// double edgesTouched;
// int sentencesParsed;
/**
 * Package-private no-argument constructor.
 *
 * Every final configuration flag is set to false and grammarTags to null;
 * no grammar, lexicon or cascade is set up here.
 */
CoarseToFineMaxRuleParser() {
    this.doVariational = false;
    this.useGoldPOS = false;
    this.accurate = false;
    this.outputScore = false;
    this.outputSub = false;
    this.viterbiParse = false;
    this.grammarTags = null;
}
/**
 * Builds a coarse-to-fine parser around the final grammar {@code gr} and
 * lexicon {@code lex}.
 *
 * Stores the configuration flags in their final fields, resets the unary and
 * rule counters, sizes the per-substate scratch buffers, and fills
 * grammarCascade/lexiconCascade with one grammar/lexicon per level (switched
 * to log mode for the levels parsed with log scores).
 *
 * NOTE(review): this text appears damaged by HTML extraction -- text between
 * a '<' and the next '>' seems to be missing on several lines (e.g. the loop
 * headers that should declare level, tmpGrammar and tmpLexicon). Restore from
 * the upstream source before editing the logic.
 *
 * @param gr final (finest) grammar
 * @param lex final lexicon
 * @param unaryPenalty stored in {@link #unaryPenalty}
 * @param endL presumably the last cascade level -- TODO confirm (use not visible)
 * @param viterbi if true, extract Viterbi parses instead of max-rule parses
 * @param sub stored in outputSub
 * @param score stored in outputScore
 * @param accurate selects the accurate (vs. fast) pruning thresholds
 * @param variational enables variational normalization of max-rule scores
 * @param useGoldPOS use supplied gold POS tags when initializing the chart
 * @param initializeCascade use not visible in the damaged text -- TODO confirm
 */
public CoarseToFineMaxRuleParser(Grammar gr, Lexicon lex, double unaryPenalty, int endL,
boolean viterbi, boolean sub, boolean score, boolean accurate, boolean variational,
boolean useGoldPOS, boolean initializeCascade) {
grammar=gr;
lexicon=lex;
// this.numSubStatesArray = gr.numSubStates.clone();
//System.out.println("The unary penalty for parsing is "+unaryPenalty+".");
// Configuration flags.
this.unaryPenalty = unaryPenalty;
this.accurate = accurate;
this.viterbiParse = viterbi;
this.outputScore = score;
this.outputSub = sub;
this.doVariational = variational;
this.useGoldPOS = useGoldPOS;
// Reset usage/rule statistics (these fields are not declared in this view).
totalUsedUnaries=0;
nTimesRestoredUnaries=0;
nRules=0;
nRulesInf=0;
this.tagNumberer = Numberer.getGlobalNumberer("tags");
this.numStates = gr.numStates;
// Scratch buffers, sized by maxSubStates(gr) (presumably the largest substate count).
this.maxNSubStates = maxSubStates(gr);
this.idxC = new int[maxNSubStates];
this.scoresToAdd = new double[maxNSubStates];
this.unscaledScoresToAdd = new double[maxNSubStates];
this.grammarTags = new boolean[numStates];
// NOTE(review): the next line is damaged -- presumably the grammarTags fill
// loop and the start of the per-level cascade loop were lost.
for (int i=0; iendLevel){
tmpGrammar = gr;
tmpLexicon = lex;
}
else /*if (level>0&& level0) {
lChildMap[level+startLevel] = curLChildMap;
rChildMap[level+startLevel] = curRChildMap;
gr.computeReverseSubstateMapping(level,curLChildMap,curRChildMap);
}
}
// Split the rules, optionally drop very unlikely rules/tags, then switch to
// log mode for the levels that are parsed with log scores.
tmpGrammar.splitRules();
double filter = 1.0e-4;
if (level>=0 && level=endLevel){
tmpGrammar.removeUnlikelyRules(1.0e-10,1.0);
tmpLexicon.removeUnlikelyTags(1.0e-10,1.0);
}
//System.out.println(baseGrammar.toString());
// DumpGrammar.dumpGrammar("wsj_"+level+".gr", tmpGrammar, (SophisticatedLexicon)tmpLexicon);
if (level<=endLevel || viterbiParse){
tmpGrammar.logarithmMode();
tmpLexicon.logarithmMode();
}
// Store this level's grammar/lexicon in the cascade.
grammarCascade[level-startLevel]=tmpGrammar;
lexiconCascade[level-startLevel]=tmpLexicon;
}
}
// Fills iScore with inside scores for every span, bottom-up by span length,
// restricted to the states/substates still allowed by the pruning masks.
// viterbi: combine alternatives with max instead of sum.
// logScores: scores are log-probabilities (combined with +) rather than
// probabilities (combined with *). Summing in log space is rejected with an
// Error because it would need log-adds.
// Every state that receives a score also widens the narrow/wide L/R extent
// arrays, which later spans use to bound their split points.
// NOTE(review): parts of this method (the state/rule loop headers of the
// binary and unary sections) were lost in extraction; restore from upstream
// before changing the logic.
void doConstrainedInsideScores(Grammar grammar, boolean viterbi, boolean logScores) {
if (!viterbi && logScores) throw new Error("This would require logAdds and is slow. Exponentiate the scores instead.");
short[] numSubStatesArray = grammar.numSubStates;
double initVal = (logScores) ? Double.NEGATIVE_INFINITY : 0;
for (int diff = 1; diff <= length; diff++) {
for (int start = 0; start < (length - diff + 1); start++) {
int end = start + diff;
// Binary rules (loop header partially lost in extraction).
for (int pState=0; pState= narrowR); // can this right constituent fit next to the left constituent?
if (!iPossibleR) { continue; }
int min1 = narrowR;
int min2 = wideLExtent[end][rState];
int min = (min1 > min2 ? min1 : min2); // can this right constituent stretch far enough to reach the left constituent?
if (min > narrowL) { continue; }
int max1 = wideRExtent[start][lState];
int max2 = narrowL;
final int max = (max1 < max2 ? max1 : max2); // can this left constituent stretch far enough to reach the right constituent?
if (min > max) { continue; }
// TODO switch order of loops for efficiency
double[][][] scores = r.getScores2();
final int nLeftChildStates = numSubStatesArray[lState];
final int nRightChildStates = numSubStatesArray[rState];
for (int split = min; split <= max; split++) {
if (!allowedStates[start][split][lState]) continue;
if (!allowedStates[split][end][rState]) continue;
for (int lp = 0; lp < nLeftChildStates; lp++) {
//if (iScore[start][split][lState] == null) continue;
//if (!allowedSubStates[start][split][lState][lp]) continue;
double lS = iScore[start][split][lState][lp];
if (lS == initVal) continue;
for (int rp = 0; rp < nRightChildStates; rp++) {
if (scores[lp][rp]==null) continue;
double rS = iScore[split][end][rState][rp];
if (rS == initVal) continue;
for (int np = 0; np < nParentStates; np++) {
if (!allowedSubStates[start][end][pState][np]) continue;
// if (level==endLevel-1) edgesTouched++;
double pS = scores[lp][rp][np];
if (pS==initVal) continue;
double thisRound = (logScores) ? pS+lS+rS : pS*lS*rS;
if (viterbi) scoresToAdd[np] = Math.max(thisRound, scoresToAdd[np]);
else scoresToAdd[np] += thisRound;
somethingChanged = true;
}
}
}
}
}
if (!somethingChanged) continue;
for (int np = 0; np < nParentStates; np++) {
if (scoresToAdd[np] > initVal) {
iScore[start][end][pState][np] = scoresToAdd[np];
}
}
// Extent bookkeeping is unconditional (the firstTime guard is commented out).
if (true){//firstTime) {
if (start > narrowLExtent[end][pState]) {
narrowLExtent[end][pState] = start;
wideLExtent[end][pState] = start;
} else {
if (start < wideLExtent[end][pState]) {
wideLExtent[end][pState] = start;
}
}
if (end < narrowRExtent[start][pState]) {
narrowRExtent[start][pState] = end;
wideRExtent[start][pState] = end;
} else {
if (end > wideRExtent[start][pState]) {
wideRExtent[start][pState] = end;
}
}
}
}
// Unary rules (loop body partially lost in extraction).
double[][] scoresAfterUnaries = new double[numStates][];
boolean somethingChanged = false;
for (int pState=0; pState initVal) {
if (viterbi) iScore[start][end][pState][np] = Math.max(iScore[start][end][pState][np], thisCell[np]);
else iScore[start][end][pState][np] = iScore[start][end][pState][np] + thisCell[np];
}
}
if (true){
if (start > narrowLExtent[end][pState]) {
narrowLExtent[end][pState] = start;
wideLExtent[end][pState] = start;
} else {
if (start < wideLExtent[end][pState]) {
wideLExtent[end][pState] = start;
}
}
if (end < narrowRExtent[start][pState]) {
narrowRExtent[start][pState] = end;
wideRExtent[start][pState] = end;
} else {
if (end > wideRExtent[start][pState]) {
wideRExtent[start][pState] = end;
}
}
}
}
}
}
}
/**
 * Fills in the oScore array top-down, longest spans first, visiting every
 * span down to length 1. Only cells that were allocated (non-null) are
 * touched, and unary children are restricted to allowed substates.
 *
 * For each span, outside scores are first pushed through closed unary rules
 * (parent to child), then through binary rules to the left and right children
 * over the feasible split points.
 *
 * @param grammar grammar supplying closed unary rules and binary rule scores
 * @param viterbi combine alternatives with max (Viterbi outside scores)
 *                instead of sum (posterior outside scores)
 * @param logScores scores are log-probabilities (combined with +) instead of
 *                  probabilities (combined by multiplication)
 *
 * NOTE(review): parts of this method (state/rule loop headers) were lost in
 * extraction; restore from upstream before changing the logic.
 */
void doConstrainedOutsideScores(Grammar grammar, boolean viterbi, boolean logScores) {
short[] numSubStatesArray = grammar.numSubStates;
double initVal = (logScores) ? Double.NEGATIVE_INFINITY : 0.0;
for (int diff = length; diff >= 1; diff--) {
for (int start = 0; start + diff <= length; start++) {
int end = start + diff;
// do unaries
double[][] scoresAfterUnaries = new double[numStates][];
boolean somethingChanged = false;
for (int cState=0; cState1 && !grammar.isGrammarTag[cState]) continue;
if (oScore[start][end][cState] == null) { continue; }
UnaryRule[] rules = null;
if (viterbi) rules = grammar.getClosedViterbiUnaryRulesByChild(cState);
else rules = grammar.getClosedSumUnaryRulesByChild(cState);
final int nChildStates = numSubStatesArray[cState];
final int numRules = rules.length;
for (int r = 0; r < numRules; r++) {
UnaryRule ur = rules[r];
int pState = ur.parentState;
if ((pState == cState)) continue;// && (np == cp))continue;
if (oScore[start][end][pState] == null) { continue; }
double[][] scores = ur.getScores2();
final int nParentStates = numSubStatesArray[pState];
for (int cp = 0; cp < nChildStates; cp++) {
if (scores[cp]==null) continue;
if (!allowedSubStates[start][end][cState][cp]) continue;
for (int np = 0; np < nParentStates; np++) {
double pS = scores[cp][np];
if (pS == initVal) continue;
double oS = oScore[start][end][pState][np];
if (oS == initVal) continue;
double thisRound = (logScores) ? oS+pS : oS*pS;
// Lazily allocate the per-child accumulator (initVal-filled for max).
if (scoresAfterUnaries[cState]==null){
scoresAfterUnaries[cState] = new double[numSubStatesArray[cState]];
if (viterbi) Arrays.fill(scoresAfterUnaries[cState], initVal);
}
if (viterbi) scoresAfterUnaries[cState][cp] = Math.max(thisRound, scoresAfterUnaries[cState][cp]);
else scoresAfterUnaries[cState][cp] += thisRound;
somethingChanged = true;
}
}
}
}
// Merge the unary contributions into oScore (loop header partially lost).
if (somethingChanged){
for (int cState=0; cState initVal){
if (viterbi) oScore[start][end][cState][cp] = Math.max(oScore[start][end][cState][cp], thisCell[cp]);
else oScore[start][end][cState][cp] += thisCell[cp];
}
}
}
}
// do binaries
for (int pState=0; pState 2) {
int min2 = wideLExtent[end][rState];
min = (min1 > min2 ? min1 : min2);
if (max1 < min) { continue; }
int max2 = wideRExtent[start][lState];
max = (max1 < max2 ? max1 : max2);
if (max < min) { continue; }
}
double[][][] scores = br.getScores2();
final int nLeftChildStates = numSubStatesArray[lState];
final int nRightChildStates = numSubStatesArray[rState];
for (int split = min; split <= max; split++) {
if (oScore[start][split][lState] == null) continue;
if (oScore[split][end][rState] == null) continue;
//if (!allowedStates[start][split][lState]) continue;
//if (!allowedStates[split][end][rState]) continue;
// scoresToAdd accumulates the left child, rightScores the right child.
double[] rightScores = new double[nRightChildStates];
if (viterbi) Arrays.fill(rightScores,initVal);
Arrays.fill(scoresToAdd,initVal);
somethingChanged = false;
for (int lp=0; lp initVal){
if (viterbi) oScore[start][split][lState][cp] = Math.max(oScore[start][split][lState][cp], scoresToAdd[cp]);
else oScore[start][split][lState][cp] += scoresToAdd[cp];
}
}
for (int cp=0; cp initVal) {
if (viterbi) oScore[split][end][rState][cp] = Math.max(oScore[split][end][rState][cp], rightScores[cp]);
else oScore[split][end][rState][cp] += rightScores[cp];
}
}
}
}
}
}
}
}
/**
 * Seeds the chart with per-word scores. After the damaged region, the span
 * also holds the tail of the chart-allocation routine, which (re)allocates the
 * score and pruning arrays for a new sentence/level and resets the span-extent
 * arrays.
 *
 * NOTE(review): extraction appears to have merged two methods here. The rest
 * of initializeChart after "for (int tag=0; tag" and the header of the next
 * method were lost (text between '<' and '>' seems stripped). The lines from
 * the viScore allocation onwards appear to belong to createArrays(firstTime,
 * numStates, numSubStatesArray, level, initVal, justInit); those parameter
 * names are inferred from the body, so confirm against callers and upstream.
 */
void initializeChart(List sentence, Lexicon lexicon,boolean noSubstates,
boolean noSmoothing,List posTags, boolean scale) {
int start = 0;
int end = start+1;
for (String word : sentence) {
end = start+1;
int goldTag = -1;
// With useGoldPOS, look up the gold tag index for this position.
if (useGoldPOS && posTags!=null) {
goldTag = tagNumberer.number(posTags.get(start));
}
// NOTE(review): extraction damage on the next line (see method comment).
for (int tag=0; tag start, etc.)
// System.out.println("initializing iScore arrays with length " + length + " and numStates " + numStates);
//if (logProbs){
viScore = new double[length][length + 1][];
voScore = new double[length][length + 1][];
//} else{
iScore = new double[length][length + 1][][];
oScore = new double[length][length + 1][][];
//iScale = new int[length][length + 1][];
//oScale = new int[length][length + 1][];
//}
allowedSubStates = new boolean[length][length+1][][];
allowedStates = new boolean[length][length+1][];
vAllowedStates = new boolean[length][length+1];
}
// Allocate (first time) or refresh each span's cells; after the first time,
// only states still allowed by the pruning mask are refreshed.
for (int start = 0; start < length; start++) {
for (int end = start + 1; end <= length; end++) {
if (firstTime){
viScore[start][end] = new double[numStates];
voScore[start][end] = new double[numStates];
iScore[start][end] = new double[numStates][];
oScore[start][end] = new double[numStates][];
//iScale[start][end] = new int[numStates];
//oScale[start][end] = new int[numStates];
allowedSubStates[start][end] = new boolean[numStates][];
allowedStates[start][end] = grammarTags.clone();
//Arrays.fill(allowedStates[start][end], true);
vAllowedStates[start][end] = true;
}
for (int state=0; state1 && !grammarTags[state]) continue;
/*if (refreshOnly){
if (allowedStates[start][end][state]){
Arrays.fill(iScore[start][end][state], 0);
Arrays.fill(oScore[start][end][state], 0);
}
continue;
}*/
if (firstTime || allowedStates[start][end][state]){
// Coarse levels (< 1) keep one Viterbi score per state; finer levels
// allocate one score per substate.
if (level<1){
viScore[start][end][state] = Double.NEGATIVE_INFINITY;
voScore[start][end][state] = Double.NEGATIVE_INFINITY;
} else{
iScore[start][end][state] = new double[numSubStatesArray[state]];
oScore[start][end][state] = new double[numSubStatesArray[state]];
Arrays.fill(iScore[start][end][state], initVal);
Arrays.fill(oScore[start][end][state], initVal);
//Arrays.fill(iScale[start][end], Integer.MIN_VALUE);
//Arrays.fill(oScale[start][end], Integer.MIN_VALUE);
boolean[] newAllowedSubStates = new boolean[numSubStatesArray[state]];
// A null substate mask (e.g. reset by ensureGoldTreeSurvives) becomes all-true.
if (allowedSubStates[start][end][state]==null){
Arrays.fill(newAllowedSubStates,true);
allowedSubStates[start][end][state] = newAllowedSubStates;
} else{
if (!justInit){
int[][] curLChildMap = lChildMap[level-2];
int[][] curRChildMap = rChildMap[level-2];
// NOTE(review): extraction damage on the next line.
for (int i=0; i0 && start==0 && end==length ) {
if (iScore[start][end][0]==null)
System.out.println("ROOT does not span the entire tree!");
}
}
}
// Reset span extents: narrow* start on the "impossible" side and wide* on
// the opposite side, so the first scored span initializes both.
narrowRExtent = new int[length + 1][numStates];
wideRExtent = new int[length + 1][numStates];
narrowLExtent = new int[length + 1][numStates];
wideLExtent = new int[length + 1][numStates];
for (int loc = 0; loc <= length; loc++) {
Arrays.fill(narrowLExtent[loc], -1); // the rightmost left with state s ending at i that we can get is the beginning
Arrays.fill(wideLExtent[loc], length + 1); // the leftmost left with state s ending at i that we can get is the end
Arrays.fill(narrowRExtent[loc], length + 1); // the leftmost right with state s starting at i that we can get is the end
Arrays.fill(wideRExtent[loc], -1); // the rightmost right with state s starting at i that we can get is the beginning
}
// Scaling arrays are dropped; presumably setupScaling re-creates them when needed -- TODO confirm.
iScale = null;
oScale = null;
}
/**
 * Releases the per-sentence charts before a new sentence is parsed by nulling
 * out the score charts, the substate/span pruning masks and the span-extent
 * arrays. The allowedStates mask is not cleared here.
 */
protected void clearArrays() {
    // Inside/outside score charts (substate and Viterbi variants).
    iScore = null;
    oScore = null;
    viScore = null;
    voScore = null;
    // Pruning masks.
    allowedSubStates = null;
    vAllowedStates = null;
    // Span-extent bookkeeping.
    narrowRExtent = null;
    narrowLExtent = null;
    wideRExtent = null;
    wideLExtent = null;
}
/**
 * Prunes chart cells whose posterior (inside + outside - sentence score, all
 * in log space) does not exceed {@code threshold}.
 *
 * For level < 1 the decision is made per state from viScore/voScore. It is
 * recorded in allowedStates at level 0 and in vAllowedStates for level < 0.
 * For level >= 1 the decision is made per substate from iScore/oScore and
 * recorded in allowedSubStates; a state whose substates are all pruned is
 * disallowed in allowedStates. State 0 is always re-allowed.
 *
 * @param threshold log-posterior threshold a cell must exceed to survive
 * @param numSubStatesArray number of substates per state at this level
 * @param level cascade level being pruned
 *
 * NOTE(review): totalStates, previouslyPossible and nowPossible are updated
 * but never read. With level < 0 only state 0 is visited, and it is always
 * re-allowed, so the vAllowedStates updates below look unreachable.
 * NOTE(review): extraction lost this method's closing brace together with the
 * header of the method that follows. That method is apparently doPreParses,
 * called with (sentence, tree, noSmoothing, posTags) from
 * getBestConstrainedParse. The second half of this span is that method's
 * body. Restore from upstream.
 */
protected void pruneChart(double threshold, short[] numSubStatesArray, int level){
int totalStates = 0, previouslyPossible = 0, nowPossible = 0;
//threshold = Double.NEGATIVE_INFINITY;
double sentenceProb = (level<1) ? viScore[0][length][0] : iScore[0][length][0][0];
//double sentenceScale = iScale[0][length][0];//+1.0 for oScale
if (level<1) nowPossible=totalStates=previouslyPossible=length;
int startDiff = (level<0) ? 2 : 1;
for (int diff = startDiff; diff <= length; diff++) {
for (int start = 0; start < (length - diff + 1); start++) {
int end = start + diff;
int lastState = (level<0) ? 1 : numSubStatesArray.length;
for (int state = 0; state < lastState; state++) {
if (diff>1&&!grammarTags[state]) continue;
//boolean allFalse = true;
// State 0 is never pruned.
if (state==0){
allowedStates[start][end][state]=true;
// if (level>1){
// allowedSubStates[start][end][state] = new boolean[1];
// allowedSubStates[start][end][state][0] = true;
// }
continue;
}
// Cells already pruned by a coarser pass stay pruned.
if (level==0){
if (!vAllowedStates[start][end]) {
allowedStates[start][end][state]=false;
totalStates++;
continue;
}
} else if (level>0){
if (!allowedStates[start][end][state]) {
totalStates+=numSubStatesArray[state];
continue;
}
}
if (level<1){
totalStates++;
previouslyPossible++;
double iS = viScore[start][end][state];
double oS = voScore[start][end][state];
if (iS==Double.NEGATIVE_INFINITY||oS==Double.NEGATIVE_INFINITY) {
if (level==0) allowedStates[start][end][state] = false;
else /*level==-1*/ vAllowedStates[start][end]=false;
continue;
}
double posterior = iS + oS - sentenceProb;
if (posterior > threshold) {
// spanMass[start][end]+=Math.exp(posterior);
if (level==0) allowedStates[start][end][state]=true;
else vAllowedStates[start][end]=true;
nowPossible++;
} else {
if (level==0) allowedStates[start][end][state] = false;
else vAllowedStates[start][end]=false;
}
continue;
}
// level >= 1 -> iterate over substates
boolean nonePossible = true;
for (int substate = 0; substate < numSubStatesArray[state]; substate++) {
totalStates++;
if (!allowedSubStates[start][end][state][substate]) continue;
previouslyPossible++;
double iS = iScore[start][end][state][substate];
double oS = oScore[start][end][state][substate];
if (iS==Double.NEGATIVE_INFINITY||oS==Double.NEGATIVE_INFINITY) {
allowedSubStates[start][end][state][substate] = false;
continue;
}
double posterior = iS + oS - sentenceProb;
if (posterior > threshold) {
allowedSubStates[start][end][state][substate]=true;
nowPossible++;
// spanMass[start][end]+=Math.exp(posterior);
nonePossible=false;
} else {
allowedSubStates[start][end][state][substate] = false;
}
/*if (thisScale>sentenceScale){
posterior *= Math.pow(GrammarTrainer.SCALE,thisScale-sentenceScale);
}*/
//}
//allowedStates[start][end][state][0] = !allFalse;
//int thisScale = iScale[start][end][state]+oScale[start][end][state];
/*if (sentenceScale>thisScale){
// too small anyways
allowedStates[start][end][state][0] = false;
continue;
}*/
}
if (nonePossible) allowedStates[start][end][state]=false;
}
}
}
// System.out.print("[");
// for(int st=0; st sentence, Tree tree,boolean noSmoothing,List posTags){
// ---- Body of the pre-parse driver (its header was lost in extraction). ----
// Runs each coarse level from startLevel up to endLevel. Level -1 is always
// skipped, and endLevel is skipped unless isBaseline. For each level it
// allocates and seeds the chart, computes Viterbi log-space inside and
// outside scores, and prunes with that level's threshold. If a gold tree is
// given, its constituents are re-allowed after each pruning step.
boolean keepGoldAlive = (tree!=null); // we are given the gold tree -> make sure we don't prune it away
clearArrays();
length = (short)sentence.size();
double score = 0;
Grammar curGrammar = null;
Lexicon curLexicon = null;
// Per-level log-posterior pruning thresholds, indexed by (level + 1).
double[] accurateThresholds = {-8,-12,-12,-11,-12,-12,-14,-14};
// double[] accurateThresholds = {-10,-14,-14,-14,-14,-14,-16,-16};
double[] fastThresholds = {-8,-9.75,-10,-9.6,-9.66,-8.01,-7.4,-10,-10};
// double[] accurateThresholds = {-8,-9,-9,-9,-9,-9,-10};
// double[] fastThresholds = {-2,-8,-9,-8,-8,-7.5,-7,-8};
double[] pruningThreshold = null;
if (accurate)
pruningThreshold = accurateThresholds;
else
pruningThreshold = fastThresholds;
//int startLevel = -1;
for (level=startLevel; level<=endLevel; level++){
if (level==-1) continue; // don't do the pre-pre parse
if (!isBaseline && level==endLevel) continue;//
curGrammar = grammarCascade[level-startLevel];
curLexicon = lexiconCascade[level-startLevel];
// createArrays(level==startLevel,curGrammar.numStates,curGrammar.numSubStates,level,Double.NEGATIVE_INFINITY,false);
createArrays(level==0,curGrammar.numStates,curGrammar.numSubStates,level,Double.NEGATIVE_INFINITY,false);
initializeChart(sentence,curLexicon,level<1,noSmoothing,posTags,false);
final boolean viterbi = true, logScores = true;
if (level<1){
doConstrainedViterbiInsideScores(curGrammar,level==startLevel);
score = viScore[0][length][0];
} else {
doConstrainedInsideScores(curGrammar,viterbi,logScores);
score = iScore[0][length][0][0];
}
// No parse at this level: skip its outside pass and pruning.
if (score==Double.NEGATIVE_INFINITY) continue;
// System.out.println("\nFound a parse for sentence with length "+length+". The LL is "+score+".");
if (level<1){
voScore[0][length][0] = 0.0;
doConstrainedViterbiOutsideScores(curGrammar,level==startLevel);
} else {
oScore[0][length][0][0] = 0.0;
doConstrainedOutsideScores(curGrammar,viterbi,logScores);
}
pruneChart(/*Double.NEGATIVE_INFINITY*/pruningThreshold[level+1], curGrammar.numSubStates, level);
if (keepGoldAlive) ensureGoldTreeSurvives(tree, level);
}
}
/**
 * Re-allows every constituent of the gold tree at the given level, so that
 * pruning at that level cannot remove it.
 *
 * Recurses into all non-leaf children first, then marks the node's own span.
 * For level < 0 the whole span is re-allowed in vAllowedStates. Otherwise the
 * node's state is re-allowed in allowedStates, and its substate mask is reset
 * to null so that chart allocation re-creates it with every substate allowed.
 *
 * @param tree  gold tree whose labels carry the state and the span (from, to)
 * @param level cascade level that was just pruned
 */
protected void ensureGoldTreeSurvives(Tree<StateSet> tree, int level) {
    for (Tree<StateSet> child : tree.getChildren()) {
        if (!child.isLeaf()) {
            ensureGoldTreeSurvives(child, level);
        }
    }
    StateSet node = tree.getLabel();
    short state = node.getState();
    if (level < 0) {
        vAllowedStates[node.from][node.to] = true;
    } else {
        int start = node.from, end = node.to;
        allowedStates[start][end][state] = true;
        if (allowedSubStates[start][end] == null) {
            allowedSubStates[start][end] = new boolean[numStates][];
        }
        // A null mask is re-created (all substates allowed) when the chart is allocated.
        allowedSubStates[start][end][state] = null;
    }
}
/**
 * Sets the first-substate inside and outside scores of {@code tree}'s root,
 * and of every non-leaf descendant, to 1.0 (pre-order traversal).
 *
 * @param tree tree whose labels carry the state and the span (from, to)
 */
private void setGoldTreeCountsToOne(Tree<StateSet> tree) {
    StateSet node = tree.getLabel();
    short state = node.getState();
    iScore[node.from][node.to][state][0] = 1.0;
    oScore[node.from][node.to][state][0] = 1.0;
    for (Tree<StateSet> child : tree.getChildren()) {
        if (!child.isLeaf()) {
            setGoldTreeCountsToOne(child);
        }
    }
}
/**
 * Installs a new final grammar and lexicon in the last cascade slot and
 * clears the slot just before it (pre-parsing does not use the final grammar).
 *
 * @param grammar new final grammar
 * @param lexicon new final lexicon
 */
public void updateFinalGrammarAndLexicon(Grammar grammar, Lexicon lexicon) {
    final int preFinalSlot = endLevel - startLevel;
    final int finalSlot = preFinalSlot + 1;
    grammarCascade[finalSlot] = grammar;
    lexiconCascade[finalSlot] = lexicon;
    // NOTE(review): these log-mode copies are never stored; they are kept
    // only to preserve any side effects of copyGrammar/copyLexicon/logarithmMode.
    Grammar logGrammar = grammar.copyGrammar(false);
    logGrammar.logarithmMode();
    Lexicon logLexicon = lexicon.copyLexicon();
    logLexicon.logarithmMode();
    grammarCascade[preFinalSlot] = null;
    lexiconCascade[preFinalSlot] = null;
}
/**
 * Parses {@code sentence} with the full coarse-to-fine cascade (pre-parses
 * enabled, no POS tag constraints).
 *
 * @param sentence the words to parse
 * @return the best parse, or a bare ROOT tree if none is found
 */
public Tree getBestParse(List sentence) {
    final boolean skipPreParses = false;
    return getBestConstrainedParse(sentence, null, skipPreParses);
}
/**
 * @return the log-likelihood stored by the most recent parse, without any
 *         recomputation
 */
public double getLogInsideScore() {
    return this.logLikelihood;
}
/**
 * Parses {@code sentence} under an externally supplied substate mask.
 *
 * When {@code allowedS} is null this runs the normal cascade. Otherwise the
 * chart is allocated for the final grammar, {@code allowedS} is installed via
 * setConstraints, and the final parse runs without coarse pre-parses.
 *
 * @param sentence the words to parse
 * @param posTags  optional gold POS tags (may be null)
 * @param allowedS substate mask indexed [start][end][state][substate], or null
 * @return the best parse
 */
public Tree getBestConstrainedParse(List sentence, List posTags, boolean[][][][] allowedS) {
    if (allowedS == null) {
        return getBestConstrainedParse(sentence, posTags, false);
    }
    clearArrays();
    length = (short) sentence.size();
    Grammar curGrammar = grammarCascade[endLevel - startLevel + 1];
    double initVal = viterbiParse ? Double.NEGATIVE_INFINITY : 0;
    int level = isBaseline ? 1 : endLevel;
    allowedSubStates = allowedS;
    createArrays(true, curGrammar.numStates, curGrammar.numSubStates, level, initVal, false);
    setConstraints(allowedS);
    return getBestConstrainedParse(sentence, posTags, true);
}
/**
 * Installs an externally supplied substate mask as allowedSubStates.
 *
 * NOTE(review): the rest of this method (its per-span/per-state loop body)
 * and the header of the following getBestConstrainedParse(List, List,
 * boolean) were lost in extraction; restore from upstream.
 *
 * @param allowedS substate mask indexed [start][end][state][substate]
 */
private void setConstraints(boolean[][][][] allowedS) {
allowedSubStates = allowedS;
for (int start = 0; start < length; start++) {
for (int end = start + 1; end <= length; end++) {
for (int state=0; state getBestConstrainedParse(List sentence, List posTags, boolean noPreparse) {
// ---- getBestConstrainedParse(sentence, posTags, noPreparse) ----
// Runs the coarse pre-parses unless noPreparse is set, then parses with the
// final grammar. Inside scores use max/log if viterbiParse, otherwise
// sum/probability. Max-rule parsing additionally computes outside and
// max-rule scores. If the unscaled root score is -infinity (e.g. because of
// underflow), the parse is retried with the scaled routines.
// NOTE(review): on the scaled path logLikelihood keeps the -infinity stored
// before the retry, so getLogInsideScore/getLogLikelihood/getConfidence
// report -infinity for those sentences. The scaled score also uses a
// literal 100 where getLogLikelihood() uses ScalingTools.LOGSCALE.
if (sentence.size()==0) return new Tree("ROOT");
if (!noPreparse) doPreParses(sentence,null,false,posTags);
bestTree = new Tree("ROOT");
double score = 0;
Grammar curGrammar = grammarCascade[endLevel-startLevel+1];
Lexicon curLexicon = lexiconCascade[endLevel-startLevel+1];
//numSubStatesArray = grammar.numSubStates;
//clearArrays();
double initVal = (viterbiParse) ? Double.NEGATIVE_INFINITY : 0;
int level = isBaseline ? 1 : endLevel;
createArrays(false,curGrammar.numStates,curGrammar.numSubStates,level,initVal,false);
initializeChart(sentence,curLexicon,false,false,posTags,false);
doConstrainedInsideScores(curGrammar,viterbiParse,viterbiParse);
score = iScore[0][length][0][0];
if (!viterbiParse) score = Math.log(score);// + (100*iScale[0][length][0]);
logLikelihood = score;
if (score != Double.NEGATIVE_INFINITY) {
// System.out.println("\nFinally found a parse for sentence with length "+length+". The LL is "+score+".");
if (!viterbiParse) {
oScore[0][length][0][0] = 1.0;
doConstrainedOutsideScores(curGrammar,viterbiParse,false);
doConstrainedMaxCScores(sentence,curGrammar,curLexicon,false);
}
}
else {
// Scaled retry.
// System.err.println("Using scaling code for sentence with length "+length+".");
setupScaling();
initializeChart(sentence,curLexicon,false,false,posTags,true);
doScaledConstrainedInsideScores(curGrammar);
score = iScore[0][length][0][0];
if (!viterbiParse) score = Math.log(score) + (100*iScale[0][length][0]);
// System.out.println("Finally found a parse for sentence with length "+length+". The LL is "+score+".");
// System.out.println("Scale: "+iScale[0][length][0]);
oScore[0][length][0][0] = 1.0;
oScale[0][length][0] = 0;
doScaledConstrainedOutsideScores(curGrammar);
doConstrainedMaxCScores(sentence,curGrammar,curLexicon,true);
score = iScore[0][length][0][0];
if (!viterbiParse) score = Math.log(score);// + (100*iScale[0][length][0]);
}
grammar = curGrammar;
lexicon = curLexicon;
if (score != Double.NEGATIVE_INFINITY) {
if (viterbiParse) bestTree = extractBestViterbiParse(0, 0, 0, length, sentence);
else {
bestTree = extractBestMaxRuleParse(0, length, sentence);
savedScore = maxcScore[0][length][0];
}
}
// sentencesParsed++;
// System.out.println("For parsing "+sentencesParsed+" I hat to touch "+edgesTouched/((double)sentencesParsed)+" on average.");
return bestTree;
}
/**
 * Score of the most recent parse: the stored log-likelihood for Viterbi
 * parsing, otherwise the max-rule score saved for the root span.
 *
 * @param parsedTree not used
 * @return the model score of the last parse
 */
public double getModelScore(Tree parsedTree) {
    return viterbiParse ? logLikelihood : savedScore;
}
/**
 * Confidence in {@code tree}: its log-likelihood minus the sentence
 * log-likelihood, both as computed by getLogLikelihood. Returns negative
 * infinity if the stored log-likelihood is negative infinity.
 *
 * @param tree tree to evaluate
 * @return log-likelihood ratio of tree to sentence
 */
public double getConfidence(Tree tree) {
    if (logLikelihood == Double.NEGATIVE_INFINITY) {
        return logLikelihood;
    }
    // Order matters: getLogLikelihood() may overwrite logLikelihood, which
    // getLogLikelihood(tree) checks first.
    final double treeLogProb = getLogLikelihood(tree);
    return treeLogProb - getLogLikelihood();
}
public double getLogLikelihood(Tree tree){
if (logLikelihood == Double.NEGATIVE_INFINITY) return logLikelihood;
if (viterbiParse) return logLikelihood;
ArrayList> resultList = new ArrayList>();
Tree newTree = TreeAnnotations.processTree(tree,1,0,binarization,false);
resultList.add(newTree);
StateSetTreeList resultStateSetTrees = new StateSetTreeList(resultList, grammar.numSubStates, false, tagNumberer);
if (llParser==null) llParser = new ArrayParser(grammar, lexicon);
for (Tree t : resultStateSetTrees){
llParser.doInsideScores(t,false,false,null); // Only inside scores are needed here
double ll = Math.log(t.getLabel().getIScore(0));
ll += 100*t.getLabel().getIScale();
return ll;
}
return Double.NEGATIVE_INFINITY;
}
/**
 * Sentence log-likelihood from the last parse: the log of the root inside
 * score, plus the scaling offset when scaled charts are present. The result
 * is also stored in logLikelihood. For Viterbi parses, or when the stored
 * value is negative infinity, the stored value is returned unchanged.
 *
 * @return sentence log-likelihood
 */
public double getLogLikelihood() {
    final boolean useStored = logLikelihood == Double.NEGATIVE_INFINITY || viterbiParse;
    if (useStored) {
        return logLikelihood;
    }
    final double rootLogInside = Math.log(iScore[0][length][0][0]);
    logLikelihood = rootLogInside;
    if (iScale != null) {
        logLikelihood = rootLogInside + ScalingTools.LOGSCALE * iScale[0][length][0];
    }
    return logLikelihood;
}
/**
 * Fills the maxc* charts used for max-rule decoding. Spans are processed
 * bottom-up by length. For each state the method records the best log score
 * of any derivation rooted there, together with the back-pointers needed by
 * extractBestMaxRuleParse (split point, left/right child state, unary child).
 *
 * A binary rule's score is the sum over substates of
 * outside(parent) times rule times inside(left) times inside(right), divided
 * by the sentence inside score. With doVariational it is further divided by
 * the parent's posterior mass. A derivation score adds the log of its rule
 * scores, plus a log scaling correction when {@code scale} is true.
 *
 * Assumes inside and outside scores (sum version, not Viterbi, in probability
 * space) have already been computed. In the visible code the narrow/wide
 * extent arrays are only read, never updated.
 *
 * NOTE(review): logNormalizer holds the plain (non-log) root inside score,
 * despite its name. Parts of the binary state/rule loop headers and of the
 * length-1 / unary section were lost in extraction; restore from upstream.
 *
 * @param sentence words of the sentence (not used in the visible, damaged text)
 * @param grammar  final grammar providing rule scores and substate counts
 * @param lexicon  final lexicon (presumably used for tag scores in the damaged section)
 * @param scale    whether the iScale/oScale scaling factors must be applied
 */
void doConstrainedMaxCScores(List sentence, Grammar grammar, Lexicon lexicon, final boolean scale) {
short[] numSubStatesArray = grammar.numSubStates;
double initVal = Double.NEGATIVE_INFINITY;
maxcScore = new double[length][length + 1][numStates];
maxcSplit = new int[length][length + 1][numStates];
maxcChild = new int[length][length + 1][numStates];
maxcLeftChild = new int[length][length + 1][numStates];
maxcRightChild = new int[length][length + 1][numStates];
ArrayUtil.fill(maxcScore, Double.NEGATIVE_INFINITY);
double logNormalizer = iScore[0][length][0][0];
// double thresh2 = threshold*logNormalizer;
for (int diff = 1; diff <= length; diff++) {
//System.out.print(diff + " ");
for (int start = 0; start < (length - diff + 1); start++) {
int end = start + diff;
// -1 marks "no back-pointer".
Arrays.fill(maxcSplit[start][end], -1);
Arrays.fill(maxcChild[start][end], -1);
Arrays.fill(maxcLeftChild[start][end], -1);
Arrays.fill(maxcRightChild[start][end], -1);
if (diff > 1) {
// diff > 1: Try binary rules
for (int pState=0; pState= narrowR); // can this right constituent fit next to the left constituent?
if (!iPossibleR) { continue; }
int min1 = narrowR;
int min2 = wideLExtent[end][rState];
int min = (min1 > min2 ? min1 : min2); // can this right constituent stretch far enough to reach the left constituent?
if (min > narrowL) { continue; }
int max1 = wideRExtent[start][lState];
int max2 = narrowL;
int max = (max1 < max2 ? max1 : max2); // can this left constituent stretch far enough to reach the right constituent?
if (min > max) { continue; }
double[][][] scores = r.getScores2();
int nLeftChildStates = numSubStatesArray[lState]; // == scores.length;
int nRightChildStates = numSubStatesArray[rState]; // == scores[0].length;
double scoreToBeat = maxcScore[start][end][pState];
for (int split = min; split <= max; split++) {
double ruleScore = 0;
if (!allowedStates[start][split][lState]) continue;
if (!allowedStates[split][end][rState]) continue;
double leftChildScore = maxcScore[start][split][lState];
double rightChildScore = maxcScore[split][end][rState];
if (leftChildScore==initVal||rightChildScore==initVal) continue;
double scalingFactor = 0.0;
if (scale) scalingFactor = Math.log(ScalingTools.calcScaleFactor(
oScale[start][end][pState]+iScale[start][split][lState]+
iScale[split][end][rState]-iScale[0][length][0]));
double gScore = leftChildScore + scalingFactor + rightChildScore;
if (gScore < scoreToBeat) continue; // no chance of finding a better derivation
// Sum the rule posterior over all substate combinations.
for (int lp = 0; lp < nLeftChildStates; lp++) {
double lIS = iScore[start][split][lState][lp];
if (lIS == 0) continue;
// if (lIS < thresh2) continue;
//if (!allowedSubStates[start][split][lState][lp]) continue;
for (int rp = 0; rp < nRightChildStates; rp++) {
if (scores[lp][rp]==null) continue;
double rIS = iScore[split][end][rState][rp];
if (rIS == 0) continue;
// if (rIS < thresh2) continue;
//if (!allowedSubStates[split][end][rState][rp]) continue;
for (int np = 0; np < nParentStates; np++) {
//if (!allowedSubStates[start][end][pState][np]) continue;
double pOS = oScore[start][end][pState][np];
if (pOS == 0) continue;
// if (pOS < thresh2) continue;
double ruleS = scores[lp][rp][np];
if (ruleS == 0) continue;
ruleScore += (pOS * ruleS * lIS * rIS) / logNormalizer;
}
}
}
if (ruleScore==0) continue;
if (doVariational){
double norm = 0;
for (int np = 0; np < nParentStates; np++) {
norm += oScore[start][end][pState][np]/logNormalizer*iScore[start][end][pState][np];
}
ruleScore /= norm;
}
// double gScore = ruleScore * leftChildScore * scalingFactor * rightChildScore;
gScore += Math.log(ruleScore);
if (gScore > scoreToBeat) {
scoreToBeat = gScore;
maxcScore[start][end][pState] = gScore;
maxcSplit[start][end][pState] = split;
maxcLeftChild[start][end][pState] = lState;
maxcRightChild[start][end][pState] = rState;
}
}
}
}
} else { // diff == 1
// We treat TAG --> word exactly as if it was a unary rule, except the score of the rule is
// given by the lexicon rather than the grammar and that we allow another unary on top of it.
//for (int tag : lexicon.getAllTags()){
// NOTE(review): extraction damage on the next line (tag loop and unary section partially lost).
for (int tag=0; tag maxcScoreStartEnd[pState]) {
maxcScoreStartEnd[pState] = gScore;
maxcChild[start][end][pState] = cState;
}
}
}
// for (int i = 0; i < numStates; i++) {
// if (maxcScore[start][end][i]+(1-unaryBonus[i]) > maxcScoreStartEnd[i]){
// maxcScore[start][end][i]+=(1-unaryBonus[i]);
// } else {
// maxcScore[start][end][i] = maxcScoreStartEnd[i];
// maxcChild[start][end][i] = unaryChild[i];
// }
// }
if (foundOne&&doVariational) maxcScoreStartEnd = closeVariationalRules(ruleScores,start,end);
maxcScore[start][end] = maxcScoreStartEnd;
}
}
}
/**
 * Extracts the highest-scoring max-rule tree over the span [start, end), rooted at
 * state 0, which this class treats as the root state. The root may be the parent of
 * a unary rule.
 * <p>
 * Requires the maxc* back-pointer arrays (maxcChild, maxcSplit, maxcLeftChild,
 * maxcRightChild) to have been filled for the current sentence.
 *
 * @param start    index of the first word of the span (inclusive)
 * @param end      index one past the last word of the span (exclusive)
 * @param sentence the words of the sentence; tree leaves are taken from it
 * @return the reconstructed tree
 */
public Tree extractBestMaxRuleParse(int start, int end, List sentence ) {
    final int rootState = 0;
    final Tree best = extractBestMaxRuleParse1(start, end, rootState, sentence);
    return best;
}
/**
* Returns the best max-rule parse for state "state" over the span [start, end),
* potentially starting with a unary rule.
*
* maxcChild[start][end][state] == -1 means no unary is used on top, and the work is
* delegated to extractBestMaxRuleParse2. Otherwise a unary node state -> cState is
* built over the subtree extracted for cState. If grammar.getUnaryIntermediate returns
* a positive state for this pair, that state is re-inserted between parent and child.
* A trailing "^g" marker is stripped from the labels.
*
* Side effects: increments totalUsedUnaries for every unary emitted and
* nTimesRestoredUnaries for every intermediate node restored.
*
* NOTE(review): generic type arguments appear to have been stripped from this copy of
* the source ("List> child = new ArrayList>();" is not valid Java as shown). Restore
* them from the upstream file before compiling.
*/
public Tree extractBestMaxRuleParse1(int start, int end, int state, List sentence ) {
//System.out.println(start+", "+end+";");
// Back-pointer to the unary child chosen for this cell, or -1 if none.
int cState = maxcChild[start][end][state];
if (cState == -1) {
return extractBestMaxRuleParse2(start, end, state, sentence);
} else {
// The child subtree is extracted with extractBestMaxRuleParse2, so it cannot start with another unary.
List> child = new ArrayList>();
child.add( extractBestMaxRuleParse2(start, end, cState, sentence) );
String stateStr = (String) tagNumberer.object(state);
if (stateStr.endsWith("^g")) stateStr = stateStr.substring(0,stateStr.length()-2);
totalUsedUnaries++;
//System.out.println("Adding a unary spanning from "+start+" to "+end+". P: "+stateStr+" C: "+child.get(0).getLabel());
// A positive value is treated as an intermediate state to restore between state and cState.
int intermediateNode = grammar.getUnaryIntermediate((short)state,(short)cState);
// if (intermediateNode==0){
// System.out.println("Added a bad unary from "+start+" to "+end+". P: "+stateStr+" C: "+child.get(0).getLabel());
// }
if (intermediateNode>0){
List> restoredChild = new ArrayList>();
nTimesRestoredUnaries++;
String stateStr2 = (String)tagNumberer.object(intermediateNode);
if (stateStr2.endsWith("^g")) stateStr2 = stateStr2.substring(0,stateStr2.length()-2);
restoredChild.add(new Tree(stateStr2, child));
//System.out.println("Restored a unary from "+start+" to "+end+": "+stateStr+" -> "+stateStr2+" -> "+child.get(0).getLabel());
return new Tree(stateStr,restoredChild);
}
return new Tree(stateStr, child);
}
}
/**
* Returns the best max-rule parse for state "state" over [start, end) whose root is
* not the parent of a unary rule.
*
* For a one-word span, the word sentence.get(start) becomes the leaf. If
* grammar.isGrammarTag(state) is true, a node labelled with the state in
* maxcChild[start][end][state] is inserted between "state" and the word.
* For longer spans, the binary back-pointers maxcSplit, maxcLeftChild and
* maxcRightChild are followed. Each child is extracted with
* extractBestMaxRuleParse1, so children may start with a unary.
*
* If no split was recorded (maxcSplit == -1), diagnostics are printed to stderr and a
* bare "ROOT" tree is returned in place of this subtree.
*
* NOTE(review): generic type arguments appear stripped in this copy ("List> children").
*/
public Tree extractBestMaxRuleParse2(int start, int end, int state, List sentence ) {
List> children = new ArrayList>();
String stateStr = (String)tagNumberer.object(state);//+""+start+""+end;
if (stateStr.endsWith("^g")) stateStr = stateStr.substring(0,stateStr.length()-2);
boolean posLevel = (end - start == 1);
if (posLevel) {
if (grammar.isGrammarTag(state)){
List> childs = new ArrayList>();
childs.add(new Tree(sentence.get(start)));
// NOTE(review): unlike stateStr, this label is not stripped of a trailing "^g".
String stateStr2 = (String)tagNumberer.object(maxcChild[start][end][state]);//+""+start+""+end;
children.add(new Tree(stateStr2,childs));
}
else children.add(new Tree(sentence.get(start)));
} else {
int split = maxcSplit[start][end][state];
if (split == -1) {
System.err.println("Warning: no symbol can generate the span from "+ start+ " to "+end+".");
System.err.println("The score is "+maxcScore[start][end][state]+" and the state is supposed to be "+stateStr);
System.err.println("The insideScores are "+Arrays.toString(iScore[start][end][state])+" and the outsideScores are " +Arrays.toString(oScore[start][end][state]));
System.err.println("The maxcScore is "+maxcScore[start][end][state]);
//return extractBestMaxRuleParse2(start, end, maxcChild[start][end][state], sentence);
return new Tree("ROOT");
}
int lState = maxcLeftChild[start][end][state];
int rState = maxcRightChild[start][end][state];
Tree leftChildTree = extractBestMaxRuleParse1(start, split, lState, sentence);
Tree rightChildTree = extractBestMaxRuleParse1(split, end, rState, sentence);
children.add(leftChildTree);
children.add(rightChildTree);
}
return new Tree(stateStr, children);
}
/** Fills in the Viterbi (max, not sum) inside scores viScore[start][end][state] for
* every span, shortest first (diff = 1 .. length).
*
* Only the first substate entry of each binary rule score (scores[0][0][0]) is
* consulted, and scores are combined additively. Split points whose two halves are
* not permitted by vAllowedStates are skipped. When a state becomes reachable over
* [start, end) for the first time (its old score was NEGATIVE_INFINITY), the
* narrow/wide L/R extent tables are updated. These tables bound the split points
* considered for later, longer spans.
*
* @param grammar        grammar whose rules are scored
* @param level0grammar  when true, lastState is 1 (and lastStateU is 1 for spans longer
*                       than one word); these presumably bound the parent-state loops
*
* NOTE(review): in this copy of the source, the two parent-state loop headers and the
* start of their bodies (rule lookup, narrowR/narrowL, bestIScore/oldIScore setup, the
* unary loop) appear truncated. See the two lines beginning "for (int pState=0; pState".
*/
void doConstrainedViterbiInsideScores(Grammar grammar, boolean level0grammar) {
short[] numSubStatesArray = grammar.numSubStates;
//double[] oldIScores = new double[maxNSubStates];
//int smallestScale = 10, largestScale = -10;
for (int diff = 1; diff <= length; diff++) {
for (int start = 0; start < (length - diff + 1); start++) {
int end = start + diff;
final int lastState = (level0grammar) ? 1 : numSubStatesArray.length;
// NOTE(review): truncated line. The binary parent loop header through the narrowR/narrowL computation is missing.
for (int pState=0; pState= narrowR); // can this right constituent fit next to the left constituent?
if (!iPossibleR) { continue; }
int min1 = narrowR;
int min2 = wideLExtent[end][rState];
int min = (min1 > min2 ? min1 : min2); // can this right constituent stretch far enough to reach the left constituent?
if (min > narrowL) { continue; }
int max1 = wideRExtent[start][lState];
int max2 = narrowL;
int max = (max1 < max2 ? max1 : max2); // can this left constituent stretch far enough to reach the right constituent?
if (min > max) { continue; }
// new: loop over all substates
double[][][] scores = r.getScores2();
double pS = Double.NEGATIVE_INFINITY;
if (scores[0][0]!=null) pS = scores[0][0][0];
if (pS == Double.NEGATIVE_INFINITY) continue;
for (int split = min; split <= max; split++) {
if (!vAllowedStates[start][split]) continue;
if (!vAllowedStates[split][end]) continue;
double lS = viScore[start][split][lState];
if (lS == Double.NEGATIVE_INFINITY) continue;
double rS = viScore[split][end][rState];
if (rS == Double.NEGATIVE_INFINITY) continue;
// Additive combination (log-domain style); keep the best over all rules and splits.
double tot = pS + lS + rS;
if (tot >= bestIScore) { bestIScore = tot;}
}
}
if (bestIScore > oldIScore) { // this way of making "parentState" is better
// than previous
viScore[start][end][pState] = bestIScore;
// First time this state is built over [start, end): record it in the extent tables.
if (oldIScore == Double.NEGATIVE_INFINITY) {
if (start > narrowLExtent[end][pState]) {
narrowLExtent[end][pState] = start;
wideLExtent[end][pState] = start;
} else {
if (start < wideLExtent[end][pState]) {
wideLExtent[end][pState] = start;
}
}
if (end < narrowRExtent[start][pState]) {
narrowRExtent[start][pState] = end;
wideRExtent[start][pState] = end;
} else {
if (end > wideRExtent[start][pState]) {
wideRExtent[start][pState] = end;
}
}
}
}
}
final int lastStateU = (level0grammar&&diff>1) ? 1 : numSubStatesArray.length;
// NOTE(review): truncated line. The unary parent loop header and most of its body are missing.
for (int pState=0; pState= bestIScore) { bestIScore = tot; }
}
if (bestIScore > oldIScore) {
viScore[start][end][pState] = bestIScore;
// First time this state is built over [start, end): record it in the extent tables.
if (oldIScore == Double.NEGATIVE_INFINITY) {
if (start > narrowLExtent[end][pState]) {
narrowLExtent[end][pState] = start;
wideLExtent[end][pState] = start;
} else {
if (start < wideLExtent[end][pState]) {
wideLExtent[end][pState] = start;
}
}
if (end < narrowRExtent[start][pState]) {
narrowRExtent[start][pState] = end;
wideRExtent[start][pState] = end;
} else {
if (end > wideRExtent[start][pState]) {
wideRExtent[start][pState] = end;
}
}
}
// }
//}
}
}
}
}
}
// void doConstrainedViterbiSubstateInsideScores(Grammar grammar) {
// numSubStatesArray = grammar.numSubStates;
//
// for (int diff = 1; diff <= length; diff++) {
// for (int start = 0; start < (length - diff + 1); start++) {
// int end = start + diff;
// final int lastState = numSubStatesArray.length;
// for (int pState=0; pState= narrowR); // can this right constituent fit next to the left constituent?
// if (!iPossibleR) { continue; }
//
// int min1 = narrowR;
// int min2 = wideLExtent[end][rState];
// int min = (min1 > min2 ? min1 : min2); // can this right constituent stretch far enough to reach the left constituent?
// if (min > narrowL) { continue; }
//
// int max1 = wideRExtent[start][lState];
// int max2 = narrowL;
// int max = (max1 < max2 ? max1 : max2); // can this left constituent stretch far enough to reach the right constituent?
// if (min > max) { continue; }
//
// // new: loop over all substates
// double[][][] scores = r.getScores2();
// for (int np = 0; np < nParentSubStates; np++) {
// if (!allowedSubStates[start][end][pState][np]) continue;
// for (int split = min; split <= max; split++) {
// if (!allowedStates[start][split][lState]) continue;
// if (!allowedStates[split][end][rState]) continue;
//
// for (int lp = 0; lp < scores.length; lp++) {
// //if (!allowedSubStates[start][split][lState][lp]) continue;
// double lS = iScore[start][split][lState][lp];
// if (lS == Double.NEGATIVE_INFINITY) continue;
//
// for (int rp = 0; rp < scores[0].length; rp++) {
// //if (!allowedSubStates[split][end][rState][rp]) continue;
// double pS = Double.NEGATIVE_INFINITY;
// if (scores[lp][rp]!=null) pS = scores[lp][rp][np];
// if (pS==Double.NEGATIVE_INFINITY){
// continue;
// //System.out.println("s "+start+" sp "+split+" e "+end+" pS "+pS+" rS "+rS);
// }
// double rS = iScore[split][end][rState][rp];
// if (rS == Double.NEGATIVE_INFINITY) continue;
//
// double tot = pS + lS + rS;
// if (tot >= bestIScore[np]) { bestIScore[np] = tot;}
// }
// }
// }
// }
// }
// boolean firstTime = true;
// for (int s=0; s oldIScore[s]) { // this way of making "parentState" is better
// // than previous
// iScore[start][end][pState][s] = bestIScore[s];
// if (firstTime && oldIScore[s] == Double.NEGATIVE_INFINITY) {
// firstTime = false;
// if (start > narrowLExtent[end][pState]) {
// narrowLExtent[end][pState] = start;
// wideLExtent[end][pState] = start;
// } else {
// if (start < wideLExtent[end][pState]) {
// wideLExtent[end][pState] = start;
// }
// }
// if (end < narrowRExtent[start][pState]) {
// narrowRExtent[start][pState] = end;
// wideRExtent[start][pState] = end;
// } else {
// if (end > wideRExtent[start][pState]) {
// wideRExtent[start][pState] = end;
// }
// }
// }
// }
// }
// }
// final int lastStateU = numSubStatesArray.length;
// for (int pState=0; pState= bestIScore[np]) { bestIScore[np] = tot; }
// }
// }
// }
// boolean firstTime = true;
// for (int s=0; s oldIScore[s]) {
// iScore[start][end][pState][s] = bestIScore[s];
// if (firstTime && oldIScore[s] == Double.NEGATIVE_INFINITY) {
// firstTime = false;
// if (start > narrowLExtent[end][pState]) {
// narrowLExtent[end][pState] = start;
// wideLExtent[end][pState] = start;
// } else {
// if (start < wideLExtent[end][pState]) {
// wideLExtent[end][pState] = start;
// }
// }
// if (end < narrowRExtent[start][pState]) {
// narrowRExtent[start][pState] = end;
// wideRExtent[start][pState] = end;
// } else {
// if (end > wideRExtent[start][pState]) {
// wideRExtent[start][pState] = end;
// }
// }
// }
// }
// }
// }
// }
// }
// }
/**
* Fills in the Viterbi (max) outside scores voScore[start][end][state], longest spans
* first (diff = length .. 1), using the Viterbi inside scores in viScore.
*
* Unaries are handled first for each span. A child state with a finite inside score
* takes the max, over closed Viterbi unary rules parent -> child, of
* voScore[parent] + rule score. Binary rules then push outside score down to both
* children:
*   out(left)  = max(out(left),  rule + in(right) + out(parent))
*   out(right) = max(out(right), rule + in(left)  + out(parent))
* Only the first substate entry of each rule score is used. Split points whose halves
* are not permitted by vAllowedStates are skipped.
*
* @param grammar        grammar whose rules are scored
* @param level0grammar  when true, lastState is 1 (presumably limiting the child-state
*                       loop to state 0)
*
* NOTE(review): two loop headers in this copy appear truncated (the lines beginning
* "for (int cState=0; cState1" and "for (int pState=0; pState 2)").
*/
void doConstrainedViterbiOutsideScores(Grammar grammar, boolean level0grammar) {
for (int diff = length; diff >= 1; diff--) {
for (int start = 0; start + diff <= length; start++) {
int end = start + diff;
final int lastState = (level0grammar) ? 1 : numStates;
// NOTE(review): truncated line. The child-state loop header and the start of the following condition are missing.
for (int cState=0; cState1 && !grammar.isGrammarTag[cState]) continue;
if (!vAllowedStates[start][end]) continue;
double iS = viScore[start][end][cState];
if (iS == Double.NEGATIVE_INFINITY) { continue; }
double oldOScore = voScore[start][end][cState];
double bestOScore = oldOScore;
UnaryRule[] rules = grammar.getClosedViterbiUnaryRulesByChild(cState);
for (int r = 0; r < rules.length; r++) {
UnaryRule ur = rules[r];
int pState = ur.parentState;
if (cState == pState) continue;
double oS = voScore[start][end][pState];
if (oS == Double.NEGATIVE_INFINITY) { continue; }
double[][] scores = ur.getScores2();
double pS = scores[0][0];
double tot = oS + pS;
if (tot > bestOScore) {
bestOScore = tot;
}
}
if (bestOScore > oldOScore) {
voScore[start][end][cState] = bestOScore;
}
}
// NOTE(review): truncated line. The binary parent loop header, rule lookup and split bounds setup are missing.
for (int pState=0; pState 2) {
int min2 = wideLExtent[end][rState];
min = (min1 > min2 ? min1 : min2);
if (max1 < min) { continue; }
int max2 = wideRExtent[start][lState];
max = (max1 < max2 ? max1 : max2);
if (max < min) { continue; }
}
double[][][] scores = br.getScores2();
double pS = Double.NEGATIVE_INFINITY;//scores[0][0][0];
if (scores[0][0]!=null) pS = scores[0][0][0];
if (pS == Double.NEGATIVE_INFINITY) { continue; }
for (int split = min; split <= max; split++) {
if (!vAllowedStates[start][split]) continue;
if (!vAllowedStates[split][end]) continue;
double lS = viScore[start][split][lState];
if (lS == Double.NEGATIVE_INFINITY) { continue; }
double rS = viScore[split][end][rState];
if (rS == Double.NEGATIVE_INFINITY) { continue; }
// Outside score of the left child: rule + sibling inside + parent outside.
double totL = pS + rS + oS;
if (totL > voScore[start][split][lState]) {
voScore[start][split][lState] = totL;
}
// Symmetric update for the right child.
double totR = pS + lS + oS;
if (totR > voScore[split][end][rState]) {
voScore[split][end][rState] = totR;
}
}
}
}
}
}
}
// void doConstrainedViterbiSubstateOutsideScores(Grammar grammar) {
// for (int diff = length; diff >= 1; diff--) {
// for (int start = 0; start + diff <= length; start++) {
// int end = start + diff;
// final int lastState = numSubStatesArray.length;
// for (int pState=0; pState bestOScore[cp]) {
// bestOScore[cp] = tot;
// }
// }
// }
// for (int s=0; s oldOScore[s]) {
// oScore[start][end][cState][s] = bestOScore[s];
// }
// }
// }
// }
// for (int pState=0; pState 2) {
// int min2 = wideLExtent[end][rState];
// min = (min1 > min2 ? min1 : min2);
// if (max1 < min) { continue; }
// int max2 = wideRExtent[start][lState];
// max = (max1 < max2 ? max1 : max2);
// if (max < min) { continue; }
// }
//
// double[][][] scores = br.getScores2();
// for (int split = min; split <= max; split++) {
// if (!allowedStates[start][split][lState]) continue;
// if (!allowedStates[split][end][rState]) continue;
//
// for (int lp=0; lp bestLOScore) {
// bestLOScore = totL;
// }
// double totR = pS + lS + oS;
// if (totR > bestROScore) {
// bestROScore = totR;
// }
// }
// if (bestLOScore > oldLOScore) {
// oScore[start][split][lState][lp] = bestLOScore;
// }
// if (bestROScore > oldROScore) {
// oScore[split][end][rState][rp] = bestROScore;
// }
// }
// }
// }
// }
// }
// }
// }
// }
/**
 * Prints this parser's unary bookkeeping counters to standard output, one per line:
 * the touchedRules counter, the number of unaries emitted while extracting max-rule
 * parses, and the number of collapsed unary chains whose intermediate node was restored.
 */
public void printUnaryStats(){
    final String[] report = {
        "Touched " + touchedRules + " rules.",
        "Used a total of " + totalUsedUnaries + " unaries.",
        "Restored " + nTimesRestoredUnaries + " unary chains."
    };
    for (final String line : report) {
        System.out.println(line);
    }
}
/**
* Return the single best (Viterbi) parse for substate gp of state gState over
* [start, end).
* Note that the returned tree may be missing intermediate nodes in
* a unary chain because it parses with a unary-closed grammar.
*
* No back-pointers are stored. The derivation is re-discovered top-down from the inside
* scores in iScore, with binaries tried before unaries. The rule-matching conditions
* are truncated in this copy; presumably they look for a rule whose combined score
* reproduces bestScore. Labels have a trailing "^g" removed. Depending on outputSub and
* outputScore, they also get a "-substate" and/or " score" suffix.
*
* If no derivation can be matched, a warning is printed to stderr and a bare "ROOT"
* tree is returned for this subtree.
*
* NOTE(review): several lines in this copy appear truncated (the substate loops
* beginning "for (int cp=0; cp" and "for (int lp=0; lp"), and generic type arguments
* appear stripped ("List> child").
*/
public Tree extractBestViterbiParse(int gState, int gp, int start, int end, List sentence ) {
// find sources of inside score
// no backtraces so we can speed up the parsing for its primary use
double bestScore = iScore[start][end][gState][gp];
String goalStr = (String)tagNumberer.object(gState);
if (goalStr.endsWith("^g")) goalStr = goalStr.substring(0,goalStr.length()-2);
if (outputSub) goalStr = goalStr + "-" + gp;
if (outputScore) goalStr = goalStr + " " + bestScore;
//System.out.println("Looking for "+goalStr+" from "+start+" to "+end+" with score "+ bestScore+".");
if (end - start == 1) {
// if the goal state is a preterminal state, then it can't transform into
// anything but the word below it
if (!grammarTags[gState]) {
List> child = new ArrayList>();
child.add(new Tree(sentence.get(start)));
return new Tree(goalStr, child);
}
// if the goal state is not a preterminal state, then find a way to
// transform it into one
else {
double veryBestScore = Double.NEGATIVE_INFINITY;
int newIndex = -1;
int newCp = -1;
UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(gState);
double childScore = bestScore;
for (int r = 0; r < unaries.length; r++) {
UnaryRule ur = unaries[r];
int cState = ur.childState;
if (cState == gState) continue;
if (grammarTags[cState]) continue;
if (!allowedStates[start][end][cState]) continue;
double[][] scores = ur.getScores2();
// NOTE(review): truncated line. The substate loop header and the ruleScore computation are missing.
for (int cp=0; cp= veryBestScore) {
childScore = iScore[start][end][cState][cp];
veryBestScore = ruleScore;
newIndex = cState;
newCp = cp;
}
}
}
List> child1 = new ArrayList>();
child1.add(new Tree(sentence.get(start)));
String goalStr1 = (String) tagNumberer.object(newIndex);
if (outputSub) goalStr1 = goalStr1 + "-" + newCp;
if (outputScore) goalStr1 = goalStr1 + " " + childScore;
// NOTE(review): this null check runs after the optional suffixes were appended. With
// outputSub or outputScore enabled, a null label has already become "null-..." / "null ..."
// and is not reported here.
if (goalStr1==null)
System.out.println("goalStr1==null with newIndex=="+newIndex+" goalStr=="+goalStr);
List> child = new ArrayList>();
child.add(new Tree(goalStr1, child1));
return new Tree(goalStr, child);
}
}
// check binaries first
BinaryRule[] parentRules = grammar.splitRulesWithP(gState);
for (int split = start + 1; split < end; split++) {
//for (Iterator binaryI = grammar.bRuleIteratorByParent(gState, gp); binaryI.hasNext();) {
//BinaryRule br = (BinaryRule) binaryI.next();
for (int i = 0; i < parentRules.length; i++) {
BinaryRule br = parentRules[i];
int lState = br.leftChildState;
if (iScore[start][split][lState]==null) continue;
int rState = br.rightChildState;
if (iScore[split][end][rState]==null) continue;
//new: iterate over substates
double[][][] scores = br.getScores2();
// NOTE(review): truncated line. The left/right substate loops and the score-matching test are missing.
for (int lp=0; lp leftChildTree = extractBestViterbiParse(lState, lp, start, split, sentence);
Tree rightChildTree = extractBestViterbiParse(rState, rp, split, end, sentence);
List> children = new ArrayList>();
children.add(leftChildTree);
children.add(rightChildTree);
Tree result = new Tree(goalStr, children);
//System.out.println("Binary node: "+result);
//result.setScore(score);
return result;
}
}
}
}
}
// check unaries
//for (Iterator unaryI = grammar.uRuleIteratorByParent(gState, gp); unaryI.hasNext();) {
//UnaryRule ur = (UnaryRule) unaryI.next();
UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(gState);
for (int r = 0; r < unaries.length; r++) {
UnaryRule ur = unaries[r];
int cState = ur.childState;
if (cState == gState) continue;
if (iScore[start][end][cState]==null) continue;
//new: iterate over substates
double[][] scores = ur.getScores2();
// NOTE(review): truncated line. The child substate loop and the score-matching test are missing.
for (int cp=0; cp childTree = extractBestViterbiParse(cState, cp, start, end, sentence);
List> children = new ArrayList>();
children.add(childTree);
// short intermediateNode = grammar.getUnaryIntermediate((short)gState,(short)cState);
// if (intermediateNode>0){
// List> restoredChild = new ArrayList>();
// nTimesRestoredUnaries++;
// String stateStr2 = (String)tagNumberer.object(intermediateNode);
// if (stateStr2.endsWith("^g")) stateStr2 = stateStr2.substring(0,stateStr2.length()-2);
// if (outputSub) stateStr2 = stateStr2 + "-" + 0;
// if (outputScore) stateStr2 = stateStr2 + " " + childScore;
//
// restoredChild.add(new Tree(stateStr2, children));
// //System.out.println("Restored a unary from "+start+" to "+end+": "+stateStr+" -> "+stateStr2+" -> "+child.get(0).getLabel());
// return new Tree(goalStr,restoredChild);
// }
// else {
Tree result = new Tree(goalStr, children);
return result;
// }
}
}
}
System.err.println("Warning: could not find the optimal way to build state "+goalStr+" spanning from "+ start+ " to "+end+".");
return new Tree("ROOT");
}
/**
* Diagnostic helper for the pruning thresholds of the coarse-to-fine cascade.
*
* For each level starting at startLevel (the loop header is truncated in this copy),
* it calls getTightestThrehold on the max-rule tree over the whole sentence. It prints
* the result and lowers maxThresholds[level] to it. This is presumably the loosest
* threshold that would still keep every node of that tree. The chart is then pruned
* with the fixed pruningThreshold[level+1].
*
* @param sentence the words of the sentence; its size sets the field "length"
* @return -1.0 on completion, or -20 when getTightestThrehold yields
*         Double.NEGATIVE_INFINITY ("Something is wrong.")
*
* NOTE(review): the level loop header and the per-level parsing code appear truncated in
* this copy (see the line beginning "for (int level=startLevel; level=0)").
* NOTE(review): the chart is pruned with pruningThreshold[level+1] (all -16), not with
* the computed minThresh. pruningThreshold has 8 entries, so level+1 must stay below 8.
* The locals score and curLexicon are not read in the visible code.
*/
public double computeTightThresholds(List sentence) {
clearArrays();
length = (short)sentence.size();
double score = 0;
Grammar curGrammar = null;
Lexicon curLexicon = null;
// double[] pruningThreshold = {-6,-12,-14,-14,-14,-14,-14,-14};//Double.NEGATIVE_INFINITY;//Math.log(1.0e-10);
// double[] pruningThreshold = {-6,-10,-10,-10,-10,-10,-10,-10};//Double.NEGATIVE_INFINITY;//Math.log(1.0e-10);
// double[] pruningThreshold = {-6,-9.75,-10,-9.6,-9.66,-8.01,-7.4,-10};//Double.NEGATIVE_INFINITY;//Math.log(1.0e-10);
double[] pruningThreshold = {-16,-16,-16,-16,-16,-16,-16,-16};
//int startLevel = -1;
// NOTE(review): truncated line. The level loop header and the parsing code for this level are missing.
for (int level=startLevel; level=0){
minThresh = getTightestThrehold(0,length,0, true, level);
if (minThresh == Double.NEGATIVE_INFINITY) {
System.out.println("Something is wrong.");
return -20;
}
System.out.println("Can set the threshold for level "+level+" to "+minThresh);
maxThresholds[level] = Math.min(maxThresholds[level],minThresh);
}
// pruneChart(minThresh-1, curGrammar.numSubStates, level);
pruneChart(pruningThreshold[level+1], curGrammar.numSubStates, level);
}
return -1.0;
}
/**
 * Walks the max-rule tree below (start, end, state) by following the maxc* back-pointers.
 * Returns the smallest, over the visited multi-word nodes, of each node's best posterior
 * score. A node's score is inside + outside - (score of the whole sentence), maximized
 * over the node's substates. Scores are combined additively, i.e. treated as log-domain
 * values.
 * <p>
 * Single-word spans contribute a fixed -2. When the node is the parent of a unary
 * (and canStartWithUnary is set), only the unary child is evaluated; the child is not
 * allowed to start with another unary.
 *
 * @param start             first word of the span (inclusive)
 * @param end               one past the last word of the span (exclusive)
 * @param state             state at the root of this subtree
 * @param canStartWithUnary whether a unary back-pointer may be followed at this node
 * @param level             cascade level; below 1 the Viterbi charts viScore/voScore are
 *                          used, otherwise iScore/oScore indexed by substate
 * @return the minimum posterior along the tree, or NEGATIVE_INFINITY if some visited
 *         node has no substate with both inside and outside scores
 */
private double getTightestThrehold(int start, int end, int state, boolean canStartWithUnary, int level) {
    if (end - start == 1) {
        return -2;
    }
    if (canStartWithUnary) {
        final int unaryChild = maxcChild[start][end][state];
        if (unaryChild != -1) {
            // Descend through the unary; the child itself may not start with another one.
            return getTightestThrehold(start, end, unaryChild, false, level);
        }
    }
    final int splitPoint = maxcSplit[start][end][state];
    final double leftBound = getTightestThrehold(start, splitPoint, maxcLeftChild[start][end][state], true, level);
    final double rightBound = getTightestThrehold(splitPoint, end, maxcRightChild[start][end][state], true, level);
    final double childrenBound = Math.min(leftBound, rightBound);

    final boolean useViterbiCharts = level < 1;
    final double sentenceScore = useViterbiCharts ? viScore[0][length][0] : iScore[0][length][0][0];
    double bestPosterior = Double.NEGATIVE_INFINITY;
    int substate = 0;
    while (substate < grammar.numSubStates[state]) {
        final double inside = useViterbiCharts ? viScore[start][end][state] : iScore[start][end][state][substate];
        final double outside = useViterbiCharts ? voScore[start][end][state] : oScore[start][end][state][substate];
        final boolean bothReachable = inside != Double.NEGATIVE_INFINITY && outside != Double.NEGATIVE_INFINITY;
        if (bothReachable) {
            final double posterior = inside + outside - sentenceScore;
            // Explicit comparison (not Math.max): NaN candidates are ignored, ties keep the old value.
            if (posterior > bestPosterior) {
                bestPosterior = posterior;
            }
        }
        substate++;
    }
    return Math.min(bestPosterior, childrenBound);
}
/**
 * Prepares the charts for scoring a gold tree with the grammar and lexicon at cascade
 * index endLevel - startLevel + 1.
 * <p>
 * Resets allowedStates to an all-false [length][length + 1][numStates] mask and then
 * calls ensureGoldTreeSurvives (presumably re-enabling the gold tree's spans; confirm).
 * It then allocates the chart arrays and initializes the chart from the sentence.
 * <p>
 * NOTE: despite its name, this method currently performs no inside/outside passes; those
 * calls were disabled in the original implementation.
 *
 * @param tree     the gold tree whose constituents must survive pruning
 * @param sentence the words of the sentence
 */
public void doGoldInsideOutsideScores(Tree tree, List sentence) {
    final int cascadeIndex = endLevel - startLevel + 1;
    final Grammar cascadeGrammar = grammarCascade[cascadeIndex];
    final Lexicon cascadeLexicon = lexiconCascade[cascadeIndex];

    allowedStates = new boolean[length][length + 1][numStates];
    ensureGoldTreeSurvives(tree, endLevel);

    final int chartLevel = isBaseline ? 1 : endLevel;
    createArrays(false, cascadeGrammar.numStates, cascadeGrammar.numSubStates, chartLevel, 0.0, false);
    initializeChart(sentence, cascadeLexicon, false, true, null, false);
}
/**
* Returns a copy of the tree in which every label is cut at its first '*' character;
* the '*' and everything after it are dropped.
*
* Recursion stops at pre-terminals. Their relabelled copy reuses the original children
* list (the words), which is neither copied nor transformed.
*
* NOTE(review): generic type arguments appear stripped in this copy
* ("List> transformedChildren").
*/
public Tree removeStars(Tree tree) {
String transformedLabel = tree.getLabel();
int starIndex = transformedLabel.indexOf("*");
if (starIndex != -1) transformedLabel = transformedLabel.substring(0,starIndex);
if (tree.isPreTerminal()) {
return new Tree(transformedLabel,tree.getChildren());
}
List> transformedChildren = new ArrayList>();
for (Tree child : tree.getChildren()) {
transformedChildren.add(removeStars(child));
}
return new Tree(transformedLabel, transformedChildren);
}
/**
* Relaxes the per-state scores of the cell [start, end) using ruleScores, for the
* variational decoder.
*
* Starts from a copy of maxcScore[start][end] and runs the relaxation for
* length = 1 .. 9. Whenever a better score for parentState via childState is found, it
* raises closedScores[parentState] and sets maxcChild[start][end][parentState] to
* childState. Because it updates the unary back-pointer, this is presumably a closure
* over unary chains. The caller stores the returned array in maxcScore[start][end].
*
* NOTE(review): the inner loops and the computation of newScore are truncated in this copy
* (the line beginning "for (int startState=0; startState").
* NOTE(review): the round counter is named "length", which shadows the sentence-length
* field of the same name inside the loop.
*/
private double[] closeVariationalRules(double[][] ruleScores, int start, int end) {
double[] closedScores = new double[numStates];
for (int i = 0; i < numStates; i++) {
closedScores[i] = maxcScore[start][end][i];
}
for (int length=1; length<10; length++){
// NOTE(review): truncated line. The state loops and the newScore computation are missing.
for (int startState=0; startState closedScores[parentState]){
closedScores[parentState] = newScore;
maxcChild[start][end][parentState] = childState;
}
}
}
return closedScores;
}
/**
* Inside pass in the probability (not log) domain with per-cell scaling. It only visits
* cells and substates permitted by the allowedStates and allowedSubStates masks.
*
* Each iScore[start][end][state] array holds substate scores that share one integer
* scale stored in iScale[start][end][state]. Integer.MIN_VALUE means the cell has not
* been built yet. Empty entries are 0 (initVal), and scores are combined by
* multiplication.
*
* For each split, binary contributions are accumulated in unscaledScoresToAdd. They are
* rescaled with ScalingTools.scaleArray, starting from the sum of the two children's
* scales. They are then added into the parent cell after both sides are brought to the
* larger of the two scales. Unary contributions are collected in scoresAfterUnaries,
* and only their positive entries are added to iScore. The narrow/wide extent tables are
* updated for parent states that received a binary contribution.
*
* NOTE(review): several lines appear truncated in this copy: the binary parent-state loop
* header and rule lookup, the unary loop, and the declarations of nParentStates and
* somethingChanged.
*/
void doScaledConstrainedInsideScores(Grammar grammar) {
double initVal = 0;
short[] numSubStatesArray = grammar.numSubStates;
//int smallestScale = 10, largestScale = -10;
for (int diff = 1; diff <= length; diff++) {
//smallestScale = 10; largestScale = -10;
//System.out.print(diff + " ");
for (int start = 0; start < (length - diff + 1); start++) {
int end = start + diff;
// NOTE(review): truncated line. The parent-state loop header through the narrowR/narrowL computation is missing.
for (int pState=0; pState= narrowR); // can this right constituent fit next to the left constituent?
if (!iPossibleR) { continue; }
int min1 = narrowR;
int min2 = wideLExtent[end][rState];
int min = (min1 > min2 ? min1 : min2); // can this right constituent stretch far enough to reach the left constituent?
if (min > narrowL) { continue; }
int max1 = wideRExtent[start][lState];
int max2 = narrowL;
int max = (max1 < max2 ? max1 : max2); // can this left constituent stretch far enough to reach the right constituent?
if (min > max) { continue; }
// TODO switch order of loops for efficiency
double[][][] scores = r.getScores2();
int nLeftChildStates = numSubStatesArray[lState];
int nRightChildStates = numSubStatesArray[rState];
for (int split = min; split <= max; split++) {
boolean changeThisRound = false;
if (allowedStates[start][split][lState] == false) continue;
if (allowedStates[split][end][rState] == false) continue;
for (int lp = 0; lp < nLeftChildStates; lp++) {
double lS = iScore[start][split][lState][lp];
if (lS == initVal) continue;
for (int rp = 0; rp < nRightChildStates; rp++) {
if (scores[lp][rp]==null) continue;
double rS = iScore[split][end][rState][rp];
if (rS == initVal) continue;
for (int np = 0; np < nParentStates; np++) {
if (!allowedSubStates[start][end][pState][np]) continue;
double pS = scores[lp][rp][np];
if (pS == initVal) continue;
double thisRound = pS*lS*rS;
unscaledScoresToAdd[np] += thisRound;
somethingChanged = true;
changeThisRound = true;
}
}
}
if (!changeThisRound) continue;
//boolean firstTime = false;
int parentScale = iScale[start][end][pState];
int currentScale = iScale[start][split][lState]+iScale[split][end][rState];
currentScale = ScalingTools.scaleArray(unscaledScoresToAdd,currentScale);
// Bring the new contribution and the existing parent cell to a common scale before adding.
if (parentScale!=currentScale) {
if (parentScale==Integer.MIN_VALUE){ // first time to build this span
iScale[start][end][pState] = currentScale;
} else {
int newScale = Math.max(currentScale,parentScale);
ScalingTools.scaleArrayToScale(unscaledScoresToAdd,currentScale,newScale);
ScalingTools.scaleArrayToScale(iScore[start][end][pState],parentScale,newScale);
iScale[start][end][pState] = newScale;
}
}
for (int np = 0; np < nParentStates; np++) {
iScore[start][end][pState][np] += unscaledScoresToAdd[np];
}
// Reset the shared scratch buffer for the next split.
Arrays.fill(unscaledScoresToAdd,0);
}
}
if (somethingChanged) {
if (start > narrowLExtent[end][pState]) {
narrowLExtent[end][pState] = start;
wideLExtent[end][pState] = start;
} else {
if (start < wideLExtent[end][pState]) {
wideLExtent[end][pState] = start;
}
}
if (end < narrowRExtent[start][pState]) {
narrowRExtent[start][pState] = end;
wideRExtent[start][pState] = end;
} else {
if (end > wideRExtent[start][pState]) {
wideRExtent[start][pState] = end;
}
}
}
}
// now do the unaries
double[][] scoresAfterUnaries = new double[numStates][];
// NOTE(review): truncated line. The unary loop that fills scoresAfterUnaries is missing.
for (int pState=0; pState narrowLExtent[end][pState]) {
narrowLExtent[end][pState] = start;
wideLExtent[end][pState] = start;
} else {
if (start < wideLExtent[end][pState]) {
wideLExtent[end][pState] = start;
}
}
if (end < narrowRExtent[start][pState]) {
narrowRExtent[start][pState] = end;
wideRExtent[start][pState] = end;
} else {
if (end > wideRExtent[start][pState]) {
wideRExtent[start][pState] = end;
}
}
}
// in any case copy/add the scores from before
for (int np = 0; np < nParentStates; np++) {
if (scoresAfterUnaries[pState]==null) continue;
double val = scoresAfterUnaries[pState][np];
if (val>0) {
iScore[start][end][pState][np] += val;
}
}
}
}
}
}
/**
* Outside pass in the probability domain, counterpart of the scaled inside pass. Outside
* scores go in oScore, with one integer scale per cell in oScale
* (Integer.MIN_VALUE = not built yet). Spans are processed longest first.
*
* For each span, unary contributions are collected in scoresAfterUnaries, and their
* positive entries are added to oScore. For spans of two or more words, binary rules
* then push outside mass down:
* - to the left child, through the scratch buffer scoresToAdd;
* - to the right child, through unscaledScoresToAdd. That contribution is rescaled
*   starting from oScale[parent] + iScale[left child], and merged with the existing
*   right cell at the larger of the two scales.
* Split points whose halves are not permitted by allowedStates are skipped.
*
* NOTE(review): several lines appear truncated in this copy, including the unary loop,
* the binary parent-state loop header and rule lookup, and the substate loops for both
* children. The left child's rescaling code is not visible.
*/
void doScaledConstrainedOutsideScores(Grammar grammar) {
double initVal = 0;
short[] numSubStatesArray = grammar.numSubStates;
// Arrays.fill(scoresToAdd,initVal);
for (int diff = length; diff >= 1; diff--) {
for (int start = 0; start + diff <= length; start++) {
int end = start + diff;
// do unaries
double[][] scoresAfterUnaries = new double[numStates][];
// NOTE(review): truncated line. The unary loop that fills scoresAfterUnaries and reads val is missing.
for (int cState=0; cState0) {
oScore[start][end][cState][cp] += val;
}
}
}
// do binaries
if (diff==1) continue; // there is no space for a binary
// NOTE(review): truncated line. The binary parent loop header, rule lookup and split bounds setup are missing.
for (int pState=0; pState 2) {
int min2 = wideLExtent[end][rState];
min = (min1 > min2 ? min1 : min2);
if (max1 < min) { continue; }
int max2 = wideRExtent[start][lState];
max = (max1 < max2 ? max1 : max2);
if (max < min) { continue; }
}
double[][][] scores = br.getScores2();
int nLeftChildStates = numSubStatesArray[lState];
int nRightChildStates = numSubStatesArray[rState];
for (int split = min; split <= max; split++) {
if (allowedStates[start][split][lState] == false) continue;
if (allowedStates[split][end][rState] == false) continue;
boolean somethingChanged = false;
// NOTE(review): truncated line. The substate loops and the left child's rescaling are missing.
for (int lp=0; lp initVal){
oScore[start][split][lState][cp] += scoresToAdd[cp];
}
}
Arrays.fill(scoresToAdd, 0);
}
if (DoubleArrays.max(unscaledScoresToAdd)!=0){//oScale[start][end][pState]!=Integer.MIN_VALUE && iScale[start][split][lState]!=Integer.MIN_VALUE){
int rightScale = oScale[split][end][rState];
int currentScale = oScale[start][end][pState]+iScale[start][split][lState];
currentScale = ScalingTools.scaleArray(unscaledScoresToAdd,currentScale);
// Bring the new contribution and the existing right-child cell to a common scale before adding.
if (rightScale!=currentScale) {
if (rightScale==Integer.MIN_VALUE){ // first time to build this span
oScale[split][end][rState] = currentScale;
} else {
int newScale = Math.max(currentScale,rightScale);
ScalingTools.scaleArrayToScale(unscaledScoresToAdd,currentScale,newScale);
ScalingTools.scaleArrayToScale(oScore[split][end][rState],rightScale,newScale);
oScale[split][end][rState] = newScale;
}
}
// NOTE(review): truncated line. The loop header over the right child's substates is missing.
for (int cp=0; cp initVal) {
oScore[split][end][rState][cp] += unscaledScoresToAdd[cp];
}
}
// Reset the shared scratch buffer for the next split.
Arrays.fill(unscaledScoresToAdd, 0);
}
}
}
}
}
}
}
/**
* Allocates the scaling tables iScale and oScale with shape [length][length + 1][numStates].
* Every cell is filled with Integer.MIN_VALUE, which the scaled inside/outside passes
* treat as "this span/state has not been built yet".
*
* NOTE(review): this copy of the source is truncated inside the "scrub the iScores array"
* loop, so the rest of setupScaling and the header of the following method are missing.
* The lines from "parse = getBestParse(nextSentence)" on belong to that other method. It
* parses nextSentence, clears it, wraps the parse in a list, and adds it to queue with
* the second argument -nextSentenceID (presumably a priority that keeps output in
* sentence order). The add happens while holding the queue's monitor.
*/
protected void setupScaling(){
// create arrays for scaling coefficients
iScale = new int[length][length + 1][];
oScale = new int[length][length + 1][];
for (int start = 0; start < length; start++) {
for (int end = start + 1; end <= length; end++) {
iScale[start][end] = new int[numStates];
oScale[start][end] = new int[numStates];
Arrays.fill(iScale[start][end], Integer.MIN_VALUE);
Arrays.fill(oScale[start][end], Integer.MIN_VALUE);
}
}
// scrub the iScores array
for (int start = 0; start < length; start++) {
for (int end = start + 1; end <= length; end++) {
// NOTE(review): truncated line. The end of setupScaling and the start of the next method are missing here.
for (int state=0; state parse = getBestParse(nextSentence);
nextSentence = null;
ArrayList> result = new ArrayList>();
result.add(parse);
// Publish under the queue's monitor and wake any waiting consumer.
synchronized(queue) {
queue.add(result,-nextSentenceID);
queue.notifyAll();
}
return null;
}
/**
 * Creates a new parser configured like this one. It shares the same grammar, lexicon,
 * unary penalty, end level, and decoding/output flags. Its coarse-to-fine cascade is
 * initialized from this parser via initCascade.
 *
 * @return the newly created, cascade-initialized parser
 */
public CoarseToFineMaxRuleParser newInstance(){
    final CoarseToFineMaxRuleParser sibling = new CoarseToFineMaxRuleParser(
            grammar, lexicon, unaryPenalty, endLevel,
            viterbiParse, outputSub, outputScore, accurate,
            doVariational, useGoldPOS,
            false);
    sibling.initCascade(this);
    return sibling;
}
/**
* Returns a score for the span [start, end). The visible code starts from score = 0 and,
* when sumScores is true, loops over parent states; the rest of the body is truncated.
*
* NOTE(review): this copy of the source is truncated inside getSentenceProbability, so
* the rest of that method and the header of the following method are missing. The
* remaining lines belong to a method that buffers Posterior objects (built from iScore,
* oScore, iScale, oScale and allowedStates) in posteriorsToDump. Once blockSize
* posteriors are buffered, or immediately when blockSize == -1, it writes the buffer
* with Java serialization to a gzip file named fileName + "." + block number. With
* blockSize == -1 it returns right after writing, without adding the current chart.
* NOTE(review): the stream is not closed if writeObject throws (consider
* try-with-resources), and the IOException is only printed. fileName is reassigned; if
* it is a field rather than a parameter, the block suffixes accumulate across calls.
*/
public double getSentenceProbability(int start, int end, boolean sumScores){
// System.out.println((allowedStates[start][end][0]));
// System.out.println((allowedSubStates[start][end][0][0]));
// System.out.println(Arrays.toString(iScore[start][end][0]));
double score = 0;
if (sumScores){
// NOTE(review): truncated line. The end of getSentenceProbability and the start of the next method are missing here.
for (int pState=0; pState 0) {
posteriorsToDump = new ArrayList(blockSize);
}
if (posteriorsToDump.size() == blockSize || blockSize == -1) {
fileName = fileName + "." + nThBlock++;
try {
ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(fileName)));
out.writeObject(posteriorsToDump);
out.flush();
out.close();
} catch (IOException e) {
System.out.println("IOException: "+e);
}
if (blockSize==-1) return;
posteriorsToDump = new ArrayList(blockSize);
}
Posterior posterior = new Posterior(iScore, oScore, iScale, oScale, allowedStates);
posteriorsToDump.add(posterior);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy