![JAR search and dependency download from the Maven repository](/logo.png)
edu.berkeley.nlp.PCFGLA.GrammarMerger Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of berkeleyparser Show documentation
Show all versions of berkeleyparser Show documentation
The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
The newest version!
package edu.berkeley.nlp.PCFGLA;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import edu.berkeley.nlp.PCFGLA.Corpus.TreeBankType;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.CommandLineUtils;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.PriorityQueue;
public class GrammarMerger {
/**
* @param args
*/
public static void main(String[] args) {
if (args.length < 1) {
System.out.println("usage: java GrammarMerger \n" +
"\t\t -i Input File for Grammar (Required)\n" +
"\t\t -o Output File for Merged Grammar (Required)\n"+
"\t\t -p Merging percentage (Default: 0.5)\n" +
"\t\t -2p Merging percentage for non-siblings (Default: 0.0)\n" +
"\t\t -top Keep top N substates, overrides -p!" +
" -path Path to Corpus (Default: null)\n" +
// " -lang Language: 1-ENG, 2-CHN, 3-GER, 4-ARB (Default: 1)\n" +
"\t\t -chsh If this is enabled, then we train on a short segment of\n" +
"\t\t the Chinese treebank (Default: false)" +
"\t\t -trfr The fraction of the training corpus to keep (Default: 1.0)\n" +
"\t\t -maxIt Maximum number of EM iterations (Default: 100)"+
"\t\t -minIt Minimum number of EM iterations (Default: 5)"+
"\t\t -f Filter rules with prob under f (Default: -1)"+
"\t\t -dL Delete labels? (true/false) (Default: false)"+
"\t\t -ent Use Entropic prior (Default: false)"+
"\t\t -maxL Maximum sentence length (Default: 10000)"+
"\t\t -sep Set merging threshold for grammar and lexicon separately (Default: false)"
);
System.exit(2);
}
// provide feedback on command-line arguments
System.out.print("Running with arguments: ");
for (String arg : args) {
System.out.print(" '"+arg+"'");
}
System.out.println("");
// parse the input arguments
Map input = CommandLineUtils.simpleCommandLineParser(args);
double mergingPercentage = Double.parseDouble(CommandLineUtils.getValueOrUseDefault(input, "-p", "0.5"));
double mergingPercentage2 = Double.parseDouble(CommandLineUtils.getValueOrUseDefault(input, "-2p", "0.0"));
String outFileName = CommandLineUtils.getValueOrUseDefault(input, "-o", null);
String inFileName = CommandLineUtils.getValueOrUseDefault(input, "-i", null);
System.out.println("Loading grammar from " + inFileName + ".");
ParserData pData = ParserData.Load(inFileName);
if (pData == null) {
System.out.println("Failed to load grammar from file" + inFileName + ".");
System.exit(1);
}
int minIterations = Integer.parseInt(CommandLineUtils.getValueOrUseDefault(input,"-minIt","0"));
if (minIterations>0)
System.out.println("I will do at least "+minIterations+" iterations.");
boolean separateMerge = CommandLineUtils.getValueOrUseDefault(input, "-sep", "").equals("true");
int maxIterations = Integer.parseInt(CommandLineUtils.getValueOrUseDefault(input,"-maxIt","100"));
if (maxIterations>0)
System.out.println("But at most "+maxIterations+" iterations.");
boolean deleteLabels = CommandLineUtils.getValueOrUseDefault(input, "-dL", "").equals("true");
boolean useEntropicPrior = CommandLineUtils.getValueOrUseDefault(input, "-ent", "").equals("true");
int maxSentenceLength = Integer.parseInt(CommandLineUtils.getValueOrUseDefault(input, "-maxL", "10000"));
System.out.println("Will remove sentences with more than "+maxSentenceLength+" words.");
String path = CommandLineUtils.getValueOrUseDefault(input, "-path", null);
// int lang = Integer.parseInt(CommandLineUtils.getValueOrUseDefault(input, "-lang", "1"));
// System.out.println("Loading trees from "+path+" and using language "+lang);
boolean chineseShort = Boolean.parseBoolean(CommandLineUtils
.getValueOrUseDefault(input, "-chsh", "false"));
double trainingFractionToKeep = Double.parseDouble(CommandLineUtils
.getValueOrUseDefault(input, "-trfr", "1.0"));
Grammar grammar = pData.getGrammar();
Lexicon lexicon = pData.getLexicon();
Numberer.setNumberers(pData.getNumbs());
int h_markov = pData.h_markov;
int v_markov = pData.v_markov;
Binarization bin = pData.bin;
short[] numSubStatesArray = pData.numSubStatesArray;
Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
double filter = Double.parseDouble(CommandLineUtils.getValueOrUseDefault(input, "-f","-1"));
if(filter>0) System.out.println("Will remove rules with prob under "+filter);
Corpus corpus = new Corpus(path,TreeBankType.WSJ,trainingFractionToKeep,false);
//int nTrees = corpus.getTrainTrees().size();
//binarize trees
List> trainTrees = Corpus.binarizeAndFilterTrees(corpus
.getTrainTrees(), v_markov, h_markov, maxSentenceLength, bin, false, false);
List> validationTrees = Corpus.binarizeAndFilterTrees(corpus
.getValidationTrees(), v_markov, h_markov, maxSentenceLength, bin, false, false);
int nTrees = trainTrees.size();
System.out.println("There are "+nTrees+" trees in the training set.");
StateSetTreeList trainStateSetTrees = new StateSetTreeList(trainTrees, numSubStatesArray, false, tagNumberer);
StateSetTreeList validationStateSetTrees = new StateSetTreeList(validationTrees, numSubStatesArray, false, tagNumberer);
// get rid of the old trees
// trainTrees = null;
// validationTrees = null;
// corpus = null;
// System.gc();
// System.out.println("before merging, we have split trees:");
// for (int i=0; i stateSetTree : trainStateSetTrees) {
// boolean secondHalf = (n++>nTrees/2.0);
// newParser.doInsideOutsideScores(stateSetTree,noSmoothing,debugOutput); // E Step
// double ll = stateSetTree.getLabel().getIScore(0);
// ll = Math.log(ll) + (100*stateSetTree.getLabel().getIScale());//System.out.println(stateSetTree);
// if (Double.isInfinite(ll)||Double.isNaN(ll)) {
// System.out.println("Training sentence "+n+" is given "+ll+" log likelihood!");
// GrammarTrainer.printBadLLReason(stateSetTree,lexicon);
// }
// else {
// //System.out.println("Training sentence "+n+" is good.");
// trainingLikelihood += ll;
// newLexicon.trainTree(stateSetTree, -1, lexicon, secondHalf,true);
// }
// }
System.out.println("The training LL is "+trainingLikelihood);
newLexicon.optimize();//Grammar.RandomInitializationType.INITIALIZE_WITH_SMALL_RANDOMIZATION); // M Step
// do 5 iterations of EM to clean things up
SophisticatedLexicon previousLexicon = null;
Grammar previousGrammar = null;
System.out.println("Doing some iterations of EM to clean things up...");
double maxLikelihood = Double.NEGATIVE_INFINITY;
int droppingIter = 0;
int iter = 0;
while ((droppingIter < 2)&& (iter stateSetTree : trainStateSetTrees) {
boolean secondHalf = (n++>nTrees/2.0);
newParser.doInsideOutsideScores(stateSetTree,noSmoothing,debugOutput); // E Step
double ll = stateSetTree.getLabel().getIScore(0);
ll = Math.log(ll) + (100*stateSetTree.getLabel().getIScale());//System.out.println(stateSetTree);
if (Double.isInfinite(ll)||Double.isNaN(ll)) {
System.out.println("Training sentence "+n+" is given "+ll+" log likelihood!");
GrammarTrainer.printBadLLReason(stateSetTree,previousLexicon);
}
else {
trainingLikelihood += ll;
newGrammar.tallyStateSetTree(stateSetTree, previousGrammar); // E Step
newLexicon.trainTree(stateSetTree, -1, previousLexicon, secondHalf,false,4 /*opts.rare*/);
}
}
System.out.println("The training LL is "+trainingLikelihood);
newLexicon.optimize();//Grammar.RandomInitializationType.INITIALIZE_WITH_SMALL_RANDOMIZATION); // M Step
newGrammar.optimize(0);// Grammar.RandomInitializationType.INITIALIZE_WITH_SMALL_RANDOMIZATION); // M Step
newParser = new ArrayParser(newGrammar, newLexicon);
//System.out.println("Evaluating new grammar");
double validationLikelihood = 0;
n = 0;
for (Tree stateSetTree : validationStateSetTrees) {
n++;
newParser.doInsideScores(stateSetTree,false,false, null); // E Step
double ll = stateSetTree.getLabel().getIScore(0);
ll = Math.log(ll) + (100*stateSetTree.getLabel().getIScale());//System.out.println(stateSetTree);
if (Double.isInfinite(ll)||Double.isNaN(ll)) {
System.out.println("Validation sentence "+n+" is given -inf log likelihood!");
}
else validationLikelihood += ll; // there are for some reason some sentences that are unparsable
}
System.out.println("The validation LL after merging and "+(iter+1)+" iterations is "+validationLikelihood);
if (iter maxLikelihood) {
maxLikelihood = validationLikelihood;
grammar = newGrammar;
lexicon = newLexicon;
droppingIter = 0;
}
else { droppingIter++; }
if (iter>0 && iter%5==0){
pData = new ParserData(newLexicon, newGrammar, null, Numberer.getNumberers(), newNumSubStatesArray, v_markov, h_markov, bin);
System.out.println("Saving grammar to "+outFileName+"-it-"+iter+".");
System.out.println("It gives a validation data log likelihood of: "+maxLikelihood);
if (pData.Save(outFileName+"-it-"+iter)) System.out.println("Saving successful");
else System.out.println("Saving failed!");
pData = null;
}
}
System.out.println("Saving grammar to "+outFileName+".");
System.out.println("It gives a validation data log likelihood of: "+maxLikelihood);
// for (int i=0; i lexiconStates = new PriorityQueue();
PriorityQueue grammarStates = new PriorityQueue();
short[] numSubStatesArray = grammar.numSubStates;
short[] newNumSubStatesArray = newGrammar.numSubStates;
Numberer tagNumberer = grammar.tagNumberer;
for (short state=0; state stateSetTree : trainStateSetTrees) {
parser.doInsideOutsideScores(stateSetTree,noSmoothing,debugOutput); // E Step
double ll = stateSetTree.getLabel().getIScore(0);
ll = Math.log(ll) + (100*stateSetTree.getLabel().getIScale());//System.out.println(stateSetTree);
if (!Double.isInfinite(ll))
grammar.tallyMergeScores(stateSetTree, deltas, mergeWeights);
}
return deltas;
}
/**
 * Accumulates merge weights over the training trees and returns them after
 * normalization.
 *
 * <p>An {@code ArrayParser} is built from {@code grammar} and {@code lexicon}
 * and {@code doInsideOutsideScores} is run on every tree (smoothing and debug
 * flags both {@code false}). The log likelihood of a tree is computed as
 * {@code Math.log(iScore(0)) + 100 * iScale}, read from the tree's label.
 * Trees whose log likelihood is infinite or NaN are reported on standard
 * output and skipped; every other tree is added to the running training
 * likelihood and handed to {@code grammar.tallyMergeWeights}. The summed
 * likelihood is printed, and the weights are passed through
 * {@code grammar.normalizeMergeWeights} before being returned.
 *
 * @param grammar grammar whose {@code numSubStates} array sizes the result and
 *        which tallies and normalizes the weights
 * @param lexicon lexicon used together with {@code grammar} to build the parser
 * @param trainStateSetTrees training trees to score and tally
 * @return a {@code [grammar.numSubStates.length][max(grammar.numSubStates)]}
 *         array of merge weights, after {@code normalizeMergeWeights}
 */
public static double[][] computeMergeWeights(Grammar grammar, Lexicon lexicon, StateSetTreeList trainStateSetTrees) {
    // One row per state, each wide enough for the state with the most substates.
    double[][] mergeWeights = new double[grammar.numSubStates.length][(int)ArrayUtil.max(grammar.numSubStates)];
    double trainingLikelihood = 0;
    ArrayParser parser = new ArrayParser(grammar, lexicon);
    final boolean noSmoothing = false;
    final boolean debugOutput = false;
    int n = 0;
    // NOTE(review): the element type is the raw `Tree` as it appears in this file;
    // it was presumably `Tree<StateSet>` before the type argument was lost -- confirm.
    for (Tree stateSetTree : trainStateSetTrees) {
        parser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput); // E Step
        double ll = stateSetTree.getLabel().getIScore(0);
        // NOTE(review): 100 presumably matches the parser's score-scaling exponent -- confirm.
        ll = Math.log(ll) + (100 * stateSetTree.getLabel().getIScale());
        // Reject NaN as well as +/-infinity: a NaN would otherwise poison both the
        // likelihood sum and the tallied weights. Mirrors the check used in main().
        if (Double.isInfinite(ll) || Double.isNaN(ll)) {
            System.out.println("Training sentence "+n+" is given "+ll+" log likelihood!");
        }
        else {
            trainingLikelihood += ll; // there are for some reason some sentences that are unparsable
            grammar.tallyMergeWeights(stateSetTree, mergeWeights);
        }
        n++;
    }
    System.out.println("The trainings LL before merging is "+trainingLikelihood);
    // normalize the weights
    grammar.normalizeMergeWeights(mergeWeights);
    return mergeWeights;
}
/**
* @param deltas
* @return
*/
public static boolean[][][] determineMergePairs(double[][][] deltas, boolean separateMerge, double mergingPercentage, Grammar grammar) {
boolean[][][] mergeThesePairs = new boolean[grammar.numSubStates.length][][];
short[] numSubStatesArray = grammar.numSubStates;
// set the threshold so that p percent of the splits are merged again.
ArrayList deltaSiblings = new ArrayList();
ArrayList deltaPairs = new ArrayList();
ArrayList deltaLexicon = new ArrayList();
ArrayList deltaGrammar = new ArrayList();
int nSiblings = 0, nPairs = 0, nSiblingsGr = 0, nSiblingsLex=0;
for (int state=0; state2 && mergingPercentage2>0) threshold2 = deltaPairs.get((int)(nPairs*mergingPercentage2));
// } else {
// int top = Integer.parseInt(topNmerge);
// System.out.println("Keeping the top "+top+" substates.");
// threshold = deltaSiblings.get(nPairs-top);
// }
System.out.println("Setting the threshold for siblings to "+threshold+".");
}
// if (maxSubStates>2 && mergingPercentage2>0) System.out.println("Setting the threshold for other pairs to "+threshold2);
int mergePair = 0, mergeSiblings = 0;
for (int state=0; state0) {
// for (int j=i+1; j
© 2015 - 2025 Weber Informatics LLC | Privacy Policy