All downloads are free. The search and download functionality uses the official Maven repository.

edu.berkeley.nlp.discPCFG.ParsingObjectiveFunction Maven / Gradle / Ivy

Go to download

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).

The newest version!
/**
 * 
 */
package edu.berkeley.nlp.discPCFG;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;

import edu.berkeley.nlp.PCFGLA.ArrayParser;
import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.ConditionalTrainer;
import edu.berkeley.nlp.PCFGLA.ConstrainedHierarchicalTwoChartParser;
import edu.berkeley.nlp.PCFGLA.ConstrainedTwoChartsParser;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;

/**
 * @author petrov
 *
 */
public class ParsingObjectiveFunction implements ObjectiveFunction {
  
	// Regularization modes; the value stored in myRegularization selects which
	// penalty (if any) is applied to the objective and its derivatives.
	public static final int NO_REGULARIZATION = 0;
	public static final int L1_REGULARIZATION = 1;
	public static final int L2_REGULARIZATION = 2;
	
	// Model components; all three are obtained from the linearizer.
	Grammar grammar;
  SimpleLexicon lexicon;
  SpanPredictor spanPredictor;
  // Maps between the flat weight vector and the grammar/lexicon/span predictor.
  Linearizer linearizer;

  // One of the *_REGULARIZATION constants above.
  int myRegularization;
  // Regularization strength; the regularizers use sigma*sigma as the divisor.
  double sigma;

  // Cached results of the most recent evaluation (sign-flipped for a minimizer).
  double lastValue;
  double[] lastDerivative;
  // Copy of the derivative taken before the regularization penalty is applied.
  double[] lastUnregularizedDerivative;
  // Weight vector the cache refers to; the regularizers read it. null means no cached point.
  double[] x;

  // Length of the weight vector, as reported by linearizer.dimension().
  int dimension;
  
  // Sizes of the grammar, lexicon and span segments of the weight vector (in that order).
  int nGrammarWeights, nLexiconWeights, nSpanWeights;

	// Number of Calculator tasks the training trees are split across.
	int nProcesses;
	// Base name of the constraint files ("<base>-<n>.data").
	String consBaseName;
	
	// Training trees, one list per task.
	StateSetTreeList[] trainingTrees;
  // Executor the tasks are submitted to when nProcesses > 1.
  ExecutorService pool;
  Calculator[] tasks;
  // Reset to +infinity by setSigma(); presumably tracks the best objective seen -- confirm in ensureCache.
  double bestObjectiveSoFar;
  String outFileName;

  // Gold span feature counts from spanPredictor.countGoldSpanFeatures(); only set when a span predictor exists.
  double[] spanGoldCounts;
	
  /**
   * Returns the number of weights this objective is defined over.
   *
   * @return the length of the weight vector
   */
  public int dimension() {
    final int size = this.dimension;
    return size;
  }

  /**
   * Evaluates the objective at the given weight vector, reusing the cached
   * result when the cache is already up to date for that point.
   *
   * @param x the weight vector to evaluate at
   * @return the objective value cached for {@code x}
   */
  public double valueAt(double[] x) {
    this.ensureCache(x);
    final double cachedValue = this.lastValue;
    return cachedValue;
  }

  /**
   * Returns the derivative of the objective at the given weight vector,
   * refreshing the cache first if necessary.
   *
   * @param x the weight vector to evaluate at
   * @return the cached derivative array (the internal array, not a copy)
   */
  public double[] derivativeAt(double[] x) {
    this.ensureCache(x);
    final double[] cachedGradient = this.lastDerivative;
    return cachedGradient;
  }

  /**
   * Returns the derivative at the given weight vector without the
   * regularization penalty, refreshing the cache first if necessary.
   *
   * @param x the weight vector to evaluate at
   * @return the cached unregularized derivative array (the internal array, not a copy)
   */
  public double[] unregularizedDerivativeAt(double[] x) {
    this.ensureCache(x);
    final double[] cachedGradient = this.lastUnregularizedDerivative;
    return cachedGradient;
  }

  
  private void ensureCache(double[] proposed_x) {
  	if (requiresUpdate(proposed_x)){ 

  		linearizer.delinearizeWeights(proposed_x);
  		grammar = linearizer.getGrammar();
  		lexicon = linearizer.getLexicon();
  		spanPredictor = linearizer.getSpanPredictor();
  		
      if (this.x == null) this.x = proposed_x.clone(); 
      else{ 
      	for (int xi=0; xi 1) {
				for (int i = 0; i < nProcesses; i++) {
					Future submit = pool.submit(tasks[i]);// execute(tasks[i]);
					submits[i] = submit;
				}
				while (true) {
					boolean done = true;
					for (Future task : submits) {
						done &= task.isDone();
					}
					if (done)
						break;
				}
			}

			// accumulate
			double objective = 0;
			int nUnparasble = 0, nIncorrectLL = 0;
			double[] derivatives = new double[dimension];
			for (int i = 0; i < nProcesses; i++) {
				Counts counts = null;
				if (nProcesses == 1) {
					counts = tasks[0].call();
				} else {
					try {
						counts = (Counts) submits[i].get();
					} catch (ExecutionException e) {
						// TODO Auto-generated catch block
						e.printStackTrace();
						System.out.println(e.getMessage());
						System.out.println(e.getLocalizedMessage());
					} catch (InterruptedException e) {
						// TODO Auto-generated catch block
						e.printStackTrace();
					}
				}
				objective += counts.myObjective;// tasks[i].getMyObjective();
				for (int j = 0; j < dimension; j++) {
					derivatives[j] += counts.myDerivatives[j];
				}
				nUnparasble += counts.unparsableTrees;
				nIncorrectLL += counts.incorrectLLTrees;
			}
			
			if (spanPredictor!=null){
//				System.out.println("donwscaling span derivatives");
				int offset = dimension - spanGoldCounts.length;
				double total = 0;
		  	for (int rule=0; rule 0)
				System.out.println(nUnparasble + " trees were not parsable.");
			if (nIncorrectLL > 0)
				System.out.println(nIncorrectLL+" trees had a higher gold LL than all LL.");
  		
//  		pool.shutdown();
  		System.out.print("\nThe objective was "+objective);
//      double[] derivatives = computeDerivatives(expectedGCounts, expectedCounts);
    	lastUnregularizedDerivative = derivatives.clone();
      
      switch (myRegularization){
	      case L2_REGULARIZATION:
	      	objective = l2_regularize(objective, derivatives);
	    	  System.out.print(" and is "+objective+" after L2 regularization");
	    	  break;
	      case L1_REGULARIZATION:
	      	objective = l1_regularize(objective, derivatives);
	    	  System.out.print(" and is "+objective+" after L1 regularization");
	      default:
	      	break;
      }		    
  	  System.out.print(".\n");

  	  objective *= -1.0; // flip sign since we are working with a minimizer rather than with a maximizer
      for (int index = 0; index < derivatives.length; index++) {
        // 'x' and 'derivatives' have same layout
        derivatives[index] *= -1.0;
        lastUnregularizedDerivative[index] *= -1.0;
      }
      
    	
      lastValue = objective;
      lastDerivative = derivatives;
//      
//      for (int i=0; i<50; i++){
//      	System.out.print(derivatives[derivatives.length-1-i]+" ");
//      }
//      
      if (objective= curBlock.length){
	  			int blockNumber = ((block*nProcesses)+myID);
	  			curBlock = loadData(consName+"-"+blockNumber+".data");
	  			block++;
	  			i = 0;
	  			System.out.print(".");
	  		}
  			if (!doNotProjectConstraints) eParser.projectConstraints(curBlock[i], false);
  			myConstraints[tree] = curBlock[i];
  			i++;
  			if (myConstraints[tree].length!=myTrees.get(tree).getYield().size()){
    			System.out.println("My ID: "+myID+", block: "+block+", sentence: "+i);
    			System.out.println("Sentence length and constraints length do not match!");
    			myConstraints[tree] = null;
  			}
  		}

  	}

	  /**
	   * The most important part of the classifier learning process!  This method determines, for the given weight vector
	   * x, what the (negative) log conditional likelihood of the data is, as well as the derivatives of that likelihood
	   * wrt each weight parameter.
	   */
	  public Counts call() {
	  	double myObjective = 0;
	  	myDerivatives = new double[dimension];
//			double[] myDerivatives = new double[nCounts];
			unparsableTrees = 0;
			incorrectLLTrees = 0;
			
			if (myConstraints==null) loadConstraints();
			
	  	int i = -1;
	  	int block = 0;
	  	double totalBias = 0;
	  	for (Tree stateSetTree : myTrees) {
	  		i++;
	  		List yield = stateSetTree.getYield();
	    	
	  		boolean noSmoothing = false /*true*/, debugOutput = false;
	    	
	  		// parse the sentence
	    	boolean[][][][] cons = null;
	    	if (consName!=null){
	    		cons = myConstraints[i];
	    		if (cons.length != yield.size()){
	    			System.out.println("My ID: "+myID+", block: "+block+", sentence: "+i);
	    			System.out.println("Sentence length ("+yield.size()+") and constraints length ("+cons.length+") do not match!");
	    			System.exit(-1);
	    		}
	    	}
	    	double allLL = eParser.doConstrainedInsideOutsideScores(yield,cons,noSmoothing,null,null,false);


	    	// compute the ll of the gold tree
	    	double goldLL = (ConditionalTrainer.Options.hierarchicalChart) ?
	    			eParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput, eParser.spanScores):
	    			gParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput, eParser.spanScores);
	  		
	    	if (i%500==0) System.out.print(".");
	    	
	    	if (!sanityCheckLLs(goldLL, allLL, stateSetTree)) {
	    		myObjective += -1000;
	    		continue;
	    	}

	    	if (false){ // compute exhaustive iS/oS to get exact expectations and then compute bias
	    		double[] myExpectedCounts = new double[myDerivatives.length];
	    		eParser.incrementExpectedCounts(linearizer, myExpectedCounts, yield);

	    		double[] myExactExpectedCounts = new double[myDerivatives.length];
	    		double exactLL = eParser.doConstrainedInsideOutsideScores(yield,null,noSmoothing,null,null,false);
	    		eParser.incrementExpectedCounts(linearizer, myExactExpectedCounts, yield);
	    		
	    		double bias = 0;
	    		for (int ii=0; ii stateSetTree) {
			if (SloppyMath.isVeryDangerous(allLL) || SloppyMath.isVeryDangerous(goldLL)) {
				unparsableTrees++;
	  		return false;
	  	}

			if (goldLL - allLL > 1.0e-4){
	  		System.out.println("Something is wrong! The gold LL is " + goldLL + " and the all LL is " + allLL);//+"\n"+sentence+"\n"+stateSetTree);
	  		System.out.println(stateSetTree);
	  		incorrectLLTrees++;
	  		return false;
	  	}
			return true;
		}
  }
  
  
	/**
	 * Applies an L2 penalty to the objective and, in place, to its derivatives.
	 * The penalty subtracted from the objective is sum(x_i^2) / (2*sigma^2), and
	 * x_i / sigma^2 is subtracted from each derivative. Any derivative that is
	 * flagged as dangerous after the adjustment is reset to zero.
	 *
	 * @param objective   the unregularized objective value
	 * @param derivatives the derivatives, laid out like the weight vector; modified in place
	 * @return the regularized objective, or {@code objective} unchanged if it is already dangerous
	 */
	public double l2_regularize(double objective, double[] derivatives){
		// Nothing sensible can be added to a dangerous objective value.
		if (SloppyMath.isVeryDangerous(objective)) {
			return objective;
		}

		final double variance = sigma*sigma;
		double sumOfSquares = 0.0;
		for (int k = 0; k < x.length; k++) {
			final double weight = x[k];
			sumOfSquares += weight*weight;

			// 'x' and 'derivatives' share the same layout.
			double adjusted = derivatives[k] - weight/variance;
			if (SloppyMath.isVeryDangerous(adjusted)) {
				System.out.println("Setting regularized derivative to zero because it is Inf.");
				adjusted = 0;
			}
			derivatives[k] = adjusted;
		}

		return objective - sumOfSquares / (2*variance);
	}
  
	/**
	 * Applies an L1 penalty to the objective and, in place, to its derivatives.
	 * The weight vector is treated as three consecutive segments (grammar,
	 * lexicon, span), each with its own squared sigma: grammar and lexicon use
	 * sigma*sigma, the span segment uses 1.
	 *
	 * The penalties are accumulated as doubles. Accumulating them in int
	 * variables would truncate after every addition and division, because
	 * compound assignment narrows the result back to the variable's type.
	 *
	 * NOTE(review): the objective penalty divides by 2*sigma^2 while the
	 * derivative adjustment uses 1/sigma^2; this is kept as found -- confirm
	 * whether the factor of two is intended.
	 *
	 * @param objective   the unregularized objective value
	 * @param derivatives the derivatives, laid out like the weight vector; modified in place
	 * @return the regularized objective, or {@code objective} unchanged if it is already dangerous
	 */
	public double l1_regularize(double objective, double[] derivatives){
		// Incorporate penalty terms (regularization) into the objective and derivatives
		if (SloppyMath.isVeryDangerous(objective)) {
			return objective;
		}
		double sigma2 = sigma*sigma;
		double sigma2span = 1;
		double sigma2lex = sigma2;

		// Segment offsets into the weight vector: grammar, then lexicon, then span.
		int grammarStart = 0;
		int lexiconStart = grammarStart + nGrammarWeights;
		int spanStart = lexiconStart + nLexiconWeights;

		double penaltyGr = l1Penalty(grammarStart, nGrammarWeights, sigma2);
		double penaltyLex = l1Penalty(lexiconStart, nLexiconWeights, sigma2lex);
		double penaltySpan = l1Penalty(spanStart, nSpanWeights, sigma2span);

		objective -= (penaltyGr + penaltyLex + penaltySpan);

		l1AdjustDerivatives(derivatives, grammarStart, nGrammarWeights, sigma2);
		l1AdjustDerivatives(derivatives, lexiconStart, nLexiconWeights, sigma2lex);
		l1AdjustDerivatives(derivatives, spanStart, nSpanWeights, sigma2span);

		return objective;
	}

	/** Returns sum(|x_i|) / (2*segmentSigma2) over the weights in [start, start+count). */
	private double l1Penalty(int start, int count, double segmentSigma2) {
		double absSum = 0.0;
		for (int index = start; index < start + count; index++) {
			absSum += Math.abs(x[index]);
		}
		return absSum / (2*segmentSigma2);
	}

	/** Applies the L1 derivative adjustment in place to the weights in [start, start+count). */
	private void l1AdjustDerivatives(double[] derivatives, int start, int count, double mySigma) {
		for (int index = start; index < start + count; index++) {
			if (x[index] < 0) {
				derivatives[index] -= -1.0/mySigma;
			} else if (x[index] > 0) {
				derivatives[index] -= 1.0/mySigma;
			} else {
				// The weight is exactly zero: keep a derivative only if it exceeds 1/mySigma in magnitude.
				if (derivatives[index] < -1.0/mySigma) {
					derivatives[index] -= 1.0/mySigma;
				} else if (derivatives[index] > 1.0/mySigma) {
					derivatives[index] -= -1.0/mySigma;
				} else {
					// probably already 0
					derivatives[index] = 0;
					lastUnregularizedDerivative[index] = 0;
				}
			}

			if (SloppyMath.isVeryDangerous(derivatives[index])||Math.abs(derivatives[index])>1.0e10){
				System.out.println("Setting regularized derivative to zero because it is "+derivatives[index]);
				derivatives[index] = 0;
				lastUnregularizedDerivative[index] = 0;
			}
		}
	}
  
//  
//  public double[] computeDerivatives(double[] expectedGoldCounts, double[] expectedCounts){
//  	double[] derivatives = new double[dimension()];
//  	
//    int nDangerous = 0;
//		if (spanPredictor!=null){
//			int offset = dimension - spanGoldCounts.length;
//	  	for (int rule=0; rule1.0e10){
//    		nDangerous++;
//    		System.out.println("Setting derivative to zero because it is "+expectedGoldCounts[rule]+" - "+expectedCounts[rule]+" = "+derivatives[rule]);
//    		derivatives[rule] = 0;
//    	}
//    }
//
//  	if (nDangerous>0) System.out.println("Set "+nDangerous+" derivatives to 0 since they were dangerous.");
//  	return derivatives;
//  }
  

  
  /**
   * Creates an uninitialized objective function; every field keeps its
   * default value.
   */
  public ParsingObjectiveFunction() {
    super();
  }

  public ParsingObjectiveFunction(Linearizer linearizer, StateSetTreeList trainTrees,
  		double sigma, int regularization, String consName, int nProc, String outName, 
  		boolean doNotProjectConstraints, boolean combinedLexicon) {
    this.sigma = sigma;
    this.myRegularization = regularization;
		this.grammar = linearizer.getGrammar();//.copyGrammar();
		this.lexicon = linearizer.getLexicon();//.copyLexicon();
		this.spanPredictor = linearizer.getSpanPredictor();
		this.linearizer = linearizer;
		this.outFileName = outName;
		this.dimension = linearizer.dimension();
	  
		
		nGrammarWeights = linearizer.getNGrammarWeights();
		nLexiconWeights = linearizer.getNLexiconWeights();
		nSpanWeights = linearizer.getNSpanWeights();
		
		if (spanPredictor!=null)
			this.spanGoldCounts = spanPredictor.countGoldSpanFeatures(trainTrees);
		
	  int nTreesPerBlock = trainTrees.size()/nProc;
	  this.consBaseName = consName;
	  boolean[][][][][] tmp = edu.berkeley.nlp.PCFGLA.ParserConstrainer.loadData(consName+"-0.data");
	  if (tmp!=null) nTreesPerBlock = tmp.length;
	  
	  // split the trees into chunks
	  this.nProcesses = nProc;
	  trainingTrees = new StateSetTreeList[nProcesses];
//	  allowedStates = new ArrayList[nProcesses];
	  for (int i=0; i();
	  }
	  int block = -1;
	  int inBlock = 0;
	  for (int i=0; i double[] getLogProbabilities(EncodedDatum datum, double[] weights, Encoding encoding, IndexLinearizer indexLinearizer) {
		// TODO Auto-generated method stub
		return null;
	}

	/**
	 * Changes the regularization strength and resets state tied to the old
	 * value: the remembered weight vector is dropped (presumably so the next
	 * evaluation is recomputed -- confirm against requiresUpdate) and the best
	 * objective seen so far is reset to positive infinity.
	 *
	 * @param newSigma the new sigma used by the regularizers
	 */
	public void setSigma(double newSigma) {
		this.sigma = newSigma;
		this.x = null;
		this.bestObjectiveSoFar = Double.POSITIVE_INFINITY;
	}
}







© 2015 - 2025 Weber Informatics LLC | Privacy Policy