edu.berkeley.nlp.discPCFG.ParsingObjectiveFunction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of berkeleyparser Show documentation
Show all versions of berkeleyparser Show documentation
The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
The newest version!
/**
*
*/
package edu.berkeley.nlp.discPCFG;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;
import edu.berkeley.nlp.PCFGLA.ArrayParser;
import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.ConditionalTrainer;
import edu.berkeley.nlp.PCFGLA.ConstrainedHierarchicalTwoChartParser;
import edu.berkeley.nlp.PCFGLA.ConstrainedTwoChartsParser;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;
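/**
 * Objective function for discriminative (conditional) training of a PCFG whose parameters
 * (grammar, lexicon and optional span-predictor weights) are flattened by a {@link Linearizer}.
 * For a given weight vector it computes the log conditional likelihood of the training trees
 * and its gradient, splitting the work across nProcesses worker tasks. It optionally applies
 * L1 or L2 regularization, then negates value and gradient because the optimizer is a minimizer.
 *
 * @author petrov
 *
 */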
public class ParsingObjectiveFunction implements ObjectiveFunction {
public static final int NO_REGULARIZATION = 0;
public static final int L1_REGULARIZATION = 1;
public static final int L2_REGULARIZATION = 2;
Grammar grammar;
SimpleLexicon lexicon;
SpanPredictor spanPredictor;
Linearizer linearizer;
int myRegularization;
double sigma;
double lastValue;
double[] lastDerivative;
double[] lastUnregularizedDerivative;
double[] x;
int dimension;
int nGrammarWeights, nLexiconWeights, nSpanWeights;
int nProcesses;
String consBaseName;
StateSetTreeList[] trainingTrees;
ExecutorService pool;
Calculator[] tasks;
double bestObjectiveSoFar;
String outFileName;
double[] spanGoldCounts;
/**
 * Reports the length of the flattened parameter vector this objective is defined over.
 * The full constructor sets it from {@code linearizer.dimension()}.
 *
 * @return number of weights (grammar + lexicon + span)
 */
public int dimension() {
	return this.dimension;
}
/**
 * Evaluates the objective at the given point.
 * <p>
 * Delegates to {@code ensureCache}, which recomputes the cached value and gradients only when
 * {@code requiresUpdate(x)} is true; otherwise the previously cached value is returned.
 *
 * @param x point in parameter space (same layout as the linearized weights)
 * @return the cached objective, i.e. the sign-flipped (and possibly regularized) log
 *         conditional likelihood, because the optimizer minimizes
 */
public double valueAt(double[] x) {
	this.ensureCache(x);
	final double cachedValue = this.lastValue;
	return cachedValue;
}
/**
 * Returns the (regularized, sign-flipped) gradient of the objective at the given point,
 * recomputing it through {@code ensureCache} when needed.
 * <p>
 * The returned array is the internally cached array itself, not a copy.
 *
 * @param x point in parameter space
 * @return cached gradient, same layout as {@code x}
 */
public double[] derivativeAt(double[] x) {
	this.ensureCache(x);
	final double[] cachedGradient = this.lastDerivative;
	return cachedGradient;
}
/**
 * Returns the gradient of the objective at the given point before the regularization term
 * was added (it is still sign-flipped for minimization). Recomputes it through
 * {@code ensureCache} when needed.
 * <p>
 * The returned array is the internally cached array itself, not a copy.
 *
 * @param x point in parameter space
 * @return cached unregularized gradient, same layout as {@code x}
 */
public double[] unregularizedDerivativeAt(double[] x) {
	this.ensureCache(x);
	final double[] cachedRawGradient = this.lastUnregularizedDerivative;
	return cachedRawGradient;
}
private void ensureCache(double[] proposed_x) {
if (requiresUpdate(proposed_x)){
linearizer.delinearizeWeights(proposed_x);
grammar = linearizer.getGrammar();
lexicon = linearizer.getLexicon();
spanPredictor = linearizer.getSpanPredictor();
if (this.x == null) this.x = proposed_x.clone();
else{
for (int xi=0; xi 1) {
for (int i = 0; i < nProcesses; i++) {
Future submit = pool.submit(tasks[i]);// execute(tasks[i]);
submits[i] = submit;
}
while (true) {
boolean done = true;
for (Future task : submits) {
done &= task.isDone();
}
if (done)
break;
}
}
// accumulate
double objective = 0;
int nUnparasble = 0, nIncorrectLL = 0;
double[] derivatives = new double[dimension];
for (int i = 0; i < nProcesses; i++) {
Counts counts = null;
if (nProcesses == 1) {
counts = tasks[0].call();
} else {
try {
counts = (Counts) submits[i].get();
} catch (ExecutionException e) {
// TODO Auto-generated catch block
e.printStackTrace();
System.out.println(e.getMessage());
System.out.println(e.getLocalizedMessage());
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
objective += counts.myObjective;// tasks[i].getMyObjective();
for (int j = 0; j < dimension; j++) {
derivatives[j] += counts.myDerivatives[j];
}
nUnparasble += counts.unparsableTrees;
nIncorrectLL += counts.incorrectLLTrees;
}
if (spanPredictor!=null){
// System.out.println("donwscaling span derivatives");
int offset = dimension - spanGoldCounts.length;
double total = 0;
for (int rule=0; rule 0)
System.out.println(nUnparasble + " trees were not parsable.");
if (nIncorrectLL > 0)
System.out.println(nIncorrectLL+" trees had a higher gold LL than all LL.");
// pool.shutdown();
System.out.print("\nThe objective was "+objective);
// double[] derivatives = computeDerivatives(expectedGCounts, expectedCounts);
lastUnregularizedDerivative = derivatives.clone();
switch (myRegularization){
case L2_REGULARIZATION:
objective = l2_regularize(objective, derivatives);
System.out.print(" and is "+objective+" after L2 regularization");
break;
case L1_REGULARIZATION:
objective = l1_regularize(objective, derivatives);
System.out.print(" and is "+objective+" after L1 regularization");
default:
break;
}
System.out.print(".\n");
objective *= -1.0; // flip sign since we are working with a minimizer rather than with a maximizer
for (int index = 0; index < derivatives.length; index++) {
// 'x' and 'derivatives' have same layout
derivatives[index] *= -1.0;
lastUnregularizedDerivative[index] *= -1.0;
}
lastValue = objective;
lastDerivative = derivatives;
//
// for (int i=0; i<50; i++){
// System.out.print(derivatives[derivatives.length-1-i]+" ");
// }
//
if (objective= curBlock.length){
int blockNumber = ((block*nProcesses)+myID);
curBlock = loadData(consName+"-"+blockNumber+".data");
block++;
i = 0;
System.out.print(".");
}
if (!doNotProjectConstraints) eParser.projectConstraints(curBlock[i], false);
myConstraints[tree] = curBlock[i];
i++;
if (myConstraints[tree].length!=myTrees.get(tree).getYield().size()){
System.out.println("My ID: "+myID+", block: "+block+", sentence: "+i);
System.out.println("Sentence length and constraints length do not match!");
myConstraints[tree] = null;
}
}
}
/**
* The most important part of the classifier learning process! This method determines, for the given weight vector
* x, what the (negative) log conditional likelihood of the data is, as well as the derivatives of that likelihood
* wrt each weight parameter.
*/
public Counts call() {
double myObjective = 0;
myDerivatives = new double[dimension];
// double[] myDerivatives = new double[nCounts];
unparsableTrees = 0;
incorrectLLTrees = 0;
if (myConstraints==null) loadConstraints();
int i = -1;
int block = 0;
double totalBias = 0;
for (Tree stateSetTree : myTrees) {
i++;
List yield = stateSetTree.getYield();
boolean noSmoothing = false /*true*/, debugOutput = false;
// parse the sentence
boolean[][][][] cons = null;
if (consName!=null){
cons = myConstraints[i];
if (cons.length != yield.size()){
System.out.println("My ID: "+myID+", block: "+block+", sentence: "+i);
System.out.println("Sentence length ("+yield.size()+") and constraints length ("+cons.length+") do not match!");
System.exit(-1);
}
}
double allLL = eParser.doConstrainedInsideOutsideScores(yield,cons,noSmoothing,null,null,false);
// compute the ll of the gold tree
double goldLL = (ConditionalTrainer.Options.hierarchicalChart) ?
eParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput, eParser.spanScores):
gParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput, eParser.spanScores);
if (i%500==0) System.out.print(".");
if (!sanityCheckLLs(goldLL, allLL, stateSetTree)) {
myObjective += -1000;
continue;
}
if (false){ // compute exhaustive iS/oS to get exact expectations and then compute bias
double[] myExpectedCounts = new double[myDerivatives.length];
eParser.incrementExpectedCounts(linearizer, myExpectedCounts, yield);
double[] myExactExpectedCounts = new double[myDerivatives.length];
double exactLL = eParser.doConstrainedInsideOutsideScores(yield,null,noSmoothing,null,null,false);
eParser.incrementExpectedCounts(linearizer, myExactExpectedCounts, yield);
double bias = 0;
for (int ii=0; ii stateSetTree) {
if (SloppyMath.isVeryDangerous(allLL) || SloppyMath.isVeryDangerous(goldLL)) {
unparsableTrees++;
return false;
}
if (goldLL - allLL > 1.0e-4){
System.out.println("Something is wrong! The gold LL is " + goldLL + " and the all LL is " + allLL);//+"\n"+sentence+"\n"+stateSetTree);
System.out.println(stateSetTree);
incorrectLLTrees++;
return false;
}
return true;
}
}
/**
 * Adds a Gaussian (L2) penalty on the current weights {@code x} to an objective that is
 * being maximized.
 * <p>
 * The returned objective is {@code objective - ||x||^2 / (2 sigma^2)}. Each
 * {@code derivatives[i]} is reduced in place by {@code x[i] / sigma^2}. Any derivative that
 * {@code SloppyMath.isVeryDangerous} flags afterwards is reset to 0.
 *
 * @param objective   unregularized objective; returned untouched (and derivatives left as-is)
 *                    if SloppyMath flags it as dangerous
 * @param derivatives gradient of {@code objective}, same layout as {@code x}; modified in place
 * @return the regularized objective
 */
public double l2_regularize(double objective, double[] derivatives){
	if (SloppyMath.isVeryDangerous(objective)) {
		return objective;
	}
	final double variance = sigma * sigma;

	double squaredNorm = 0.0;
	for (double w : x) {
		squaredNorm += w * w;
	}
	final double regularized = objective - squaredNorm / (2 * variance);

	int i = 0;
	while (i < x.length) {
		derivatives[i] -= x[i] / variance;
		if (SloppyMath.isVeryDangerous(derivatives[i])) {
			System.out.println("Setting regularized derivative to zero because it is Inf.");
			derivatives[i] = 0;
		}
		i++;
	}
	return regularized;
}
/**
 * Adds an L1 (Laplacian-prior) penalty on the current weights {@code x} to an objective that
 * is being maximized, and adjusts its gradient in place.
 * <p>
 * {@code x} is laid out as {@code nGrammarWeights} grammar weights, then
 * {@code nLexiconWeights} lexicon weights, then {@code nSpanWeights} span-predictor weights.
 * Grammar and lexicon weights use scale {@code sigma^2}; span weights use a fixed scale of 1.
 * After adjustment, any derivative that SloppyMath flags as dangerous, or that exceeds 1e10
 * in magnitude, is zeroed together with the matching {@code lastUnregularizedDerivative} entry.
 *
 * @param objective   unregularized objective; returned untouched (and derivatives left as-is)
 *                    if SloppyMath flags it as dangerous
 * @param derivatives gradient of {@code objective}, same layout as {@code x}; modified in place
 * @return the regularized objective
 */
public double l1_regularize(double objective, double[] derivatives){
	if (SloppyMath.isVeryDangerous(objective)) return objective;
	double sigma2 = sigma*sigma;
	double sigma2span = 1;//(sigma-2)*(sigma-2);
	double sigma2lex = sigma2;

	int grammarStart = 0;
	int lexiconStart = nGrammarWeights;
	int spanStart = nGrammarWeights + nLexiconWeights;

	// Accumulate in double. These used to be int, and Java's compound assignment silently
	// narrowed the running sum to an int on every "+=" (and again on "/="), which discarded
	// most of the penalty.
	// NOTE(review): the penalty uses |x|/(2*sigma^2) while the derivative below uses
	// sign(x)/sigma^2, so objective and gradient differ by a factor of 2. This is kept as-is
	// to avoid changing the optimization trajectory; confirm which scale is intended.
	double penaltyGr = l1AbsSum(grammarStart, nGrammarWeights) / (2*sigma2);
	double penaltyLex = l1AbsSum(lexiconStart, nLexiconWeights) / (2*sigma2lex);
	double penaltySpan = l1AbsSum(spanStart, nSpanWeights) / (2*sigma2span);
	objective -= (penaltyGr + penaltyLex + penaltySpan);

	l1AdjustDerivatives(derivatives, grammarStart, nGrammarWeights, sigma2);
	l1AdjustDerivatives(derivatives, lexiconStart, nLexiconWeights, sigma2lex);
	l1AdjustDerivatives(derivatives, spanStart, nSpanWeights, sigma2span);
	return objective;
}

/** Returns the sum of |x[i]| over the index range [start, start + count). */
private double l1AbsSum(int start, int count) {
	double sum = 0.0;
	for (int i = start; i < start + count; i++) {
		sum += Math.abs(x[i]);
	}
	return sum;
}

/**
 * Subtracts the L1 subgradient with scale 1/mySigma from derivatives[start, start + count),
 * then zeroes entries (and their lastUnregularizedDerivative counterparts) that SloppyMath
 * flags as dangerous or whose magnitude exceeds 1e10.
 */
private void l1AdjustDerivatives(double[] derivatives, int start, int count, double mySigma) {
	for (int index = start; index < start + count; index++) {
		if (x[index] < 0) {
			derivatives[index] -= -1.0/mySigma;
		} else if (x[index] > 0) {
			derivatives[index] -= 1.0/mySigma;
		} else {
			// NOTE(review): at x == 0 this pushes the derivative further from zero (d - 1/s when
			// d < -1/s, d + 1/s when d > 1/s). The usual OWL-QN pseudo-gradient moves it toward
			// zero instead. Confirm against the minimizer before changing.
			if (derivatives[index] < -1.0/mySigma) {
				derivatives[index] -= 1.0/mySigma;
			} else if (derivatives[index] > 1.0/mySigma) {
				derivatives[index] -= -1.0/mySigma;
			} else {
				// Inside the subgradient interval: the weight should stay at zero.
				derivatives[index] = 0;
				lastUnregularizedDerivative[index] = 0;
			}
		}
		if (SloppyMath.isVeryDangerous(derivatives[index]) || Math.abs(derivatives[index]) > 1.0e10) {
			System.out.println("Setting regularized derivative to zero because it is "+derivatives[index]);
			derivatives[index] = 0;
			lastUnregularizedDerivative[index] = 0;
		}
	}
}
//
// public double[] computeDerivatives(double[] expectedGoldCounts, double[] expectedCounts){
// double[] derivatives = new double[dimension()];
//
// int nDangerous = 0;
// if (spanPredictor!=null){
// int offset = dimension - spanGoldCounts.length;
// for (int rule=0; rule1.0e10){
// nDangerous++;
// System.out.println("Setting derivative to zero because it is "+expectedGoldCounts[rule]+" - "+expectedCounts[rule]+" = "+derivatives[rule]);
// derivatives[rule] = 0;
// }
// }
//
// if (nDangerous>0) System.out.println("Set "+nDangerous+" derivatives to 0 since they were dangerous.");
// return derivatives;
// }
/**
 * Creates an uninitialized objective. No grammar, linearizer or training data is attached,
 * and every field keeps its Java default value (null / 0).
 */
public ParsingObjectiveFunction() {
	super();
}
public ParsingObjectiveFunction(Linearizer linearizer, StateSetTreeList trainTrees,
double sigma, int regularization, String consName, int nProc, String outName,
boolean doNotProjectConstraints, boolean combinedLexicon) {
this.sigma = sigma;
this.myRegularization = regularization;
this.grammar = linearizer.getGrammar();//.copyGrammar();
this.lexicon = linearizer.getLexicon();//.copyLexicon();
this.spanPredictor = linearizer.getSpanPredictor();
this.linearizer = linearizer;
this.outFileName = outName;
this.dimension = linearizer.dimension();
nGrammarWeights = linearizer.getNGrammarWeights();
nLexiconWeights = linearizer.getNLexiconWeights();
nSpanWeights = linearizer.getNSpanWeights();
if (spanPredictor!=null)
this.spanGoldCounts = spanPredictor.countGoldSpanFeatures(trainTrees);
int nTreesPerBlock = trainTrees.size()/nProc;
this.consBaseName = consName;
boolean[][][][][] tmp = edu.berkeley.nlp.PCFGLA.ParserConstrainer.loadData(consName+"-0.data");
if (tmp!=null) nTreesPerBlock = tmp.length;
// split the trees into chunks
this.nProcesses = nProc;
trainingTrees = new StateSetTreeList[nProcesses];
// allowedStates = new ArrayList[nProcesses];
for (int i=0; i();
}
int block = -1;
int inBlock = 0;
for (int i=0; i double[] getLogProbabilities(EncodedDatum datum, double[] weights, Encoding encoding, IndexLinearizer indexLinearizer) {
// TODO Auto-generated method stub
return null;
}
/**
 * Changes the regularization strength and discards state tied to the old value.
 * <p>
 * Clears the cached evaluation point {@code x}; this is presumably meant to force the next
 * evaluation to recompute, but {@code requiresUpdate} is not shown here, so confirm. It also
 * resets the best objective seen so far to +infinity.
 *
 * @param newSigma new sigma; its square is the denominator of the L1/L2 penalty terms
 */
public void setSigma(double newSigma) {
	this.bestObjectiveSoFar = Double.POSITIVE_INFINITY;
	this.x = null;
	this.sigma = newSigma;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy