edu.berkeley.nlp.PCFGLA.ConditionalTrainer

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
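
For reference, a typical training invocation might look like the following (the jar name and corpus path are placeholders; the flags are defined by the Options class in the source below):

    java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.ConditionalTrainer \
        -path /path/to/wsj -out out.grammar -treebank WSJ \
        -doConditional -regularize 2 -sigma 1.0 -nProcess 2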

package edu.berkeley.nlp.PCFGLA;


import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import edu.berkeley.nlp.PCFGLA.Corpus.TreeBankType;
import edu.berkeley.nlp.PCFGLA.smoothing.NoSmoothing;
import edu.berkeley.nlp.PCFGLA.smoothing.SmoothAcrossParentBits;
import edu.berkeley.nlp.PCFGLA.smoothing.Smoother;
import edu.berkeley.nlp.discPCFG.ConditionalMerger;
import edu.berkeley.nlp.discPCFG.DefaultLinearizer;
import edu.berkeley.nlp.discPCFG.HiearchicalAdaptiveLinearizer;
import edu.berkeley.nlp.discPCFG.HierarchicalLinearizer;
import edu.berkeley.nlp.discPCFG.Linearizer;
import edu.berkeley.nlp.discPCFG.ParsingObjectiveFunction;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.math.OW_LBFGSMinimizer;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;

/**
 * Reads in the Penn Treebank and trains grammars, either generatively via EM
 * or discriminatively via conditional (L-BFGS) training.
 *
 * @author Slav Petrov
 */
public class ConditionalTrainer {

	/**
	 * @author adampauls
	 *
	 */
	public static interface ParsingObjectFunctionFactory {

		/**
		 * Creates the objective function that the trainer will minimize.
		 *
		 * @param opts the command-line options
		 * @param outFileName output file name for the grammar
		 * @param linearizer maps the grammar/lexicon weights to a flat vector
		 * @param trainStateSetTrees the annotated training trees
		 * @param regularize 0 = no regularization, 1 = L1, 2 = L2
		 * @param newSigma the regularization coefficient
		 * @return the objective function to optimize
		 */
		public ParsingObjectiveFunction newParsingObjectiveFunction(Options opts,
				String outFileName, Linearizer linearizer,
				StateSetTreeList trainStateSetTrees, int regularize, double newSigma);

	}
	public static class Options {

		@Option(name = "-out", usage = "Output File for Grammar")
		public String outFileName;

		@Option(name = "-outDir", usage = "Output Directory for Grammar")
		public String outDir;

		@Option(name = "-path", usage = "Path to Corpus")
		public String path = null;

		@Option(name = "-SMcycles", usage = "The number of split&merge iterations (Default: 6)")
		public int numSplits = 6;

		@Option(name = "-mergingPercentage", usage = "Merging percentage (Default: 0.0)")
		public double mergingPercentage = 0;

		@Option(name = "-baseline", usage = "Just read of the MLE baseline grammar")
		public boolean baseline = false;

		@Option(name = "-treebank", usage = "Language:  WSJ, CHNINESE, GERMAN, CONLL, SINGLEFILE (Default: ENGLISH)")
		public TreeBankType treebank = TreeBankType.WSJ;

		@Option(name = "-splitMaxIt", usage = "Maximum number of EM iterations after splitting (Default: 50)")
		public int splitMaxIterations = 100;

		@Option(name = "-splitMinIt", usage = "Minimum number of EM iterations after splitting (Default: 50)")
		public int splitMinIterations = 50;

		@Option(name = "-mergeMaxIt", usage = "Maximum number of EM iterations after merging (Default: 20)")
		public int mergeMaxIterations = 20;

		@Option(name = "-mergeMinIt", usage = "Minimum number of EM iterations after merging (Default: 20)")
		public int mergeMinIterations = 20;

		@Option(name = "-di", usage = "The number of allowed iterations in which the validation likelihood drops. (Default: 6)")
		public int di = 6;

		@Option(name = "-trfr", usage = "The fraction of the training corpus to keep (Default: 1.0)\n")
		public double trainingFractionToKeep = 1.0;

		@Option(name = "-filter", usage = "Filter rules with prob below this threshold (Default: 1.0e-30)")
		public double filter = 1.0e-30;

		@Option(name = "-smooth", usage = "Type of grammar smoothing used.")
		public String smooth = "NoSmoothing";
		
		@Option(name = "-b", usage = "LEFT/RIGHT Binarization (Default: RIGHT)")
		public Binarization binarization = Binarization.RIGHT;

 		@Option(name = "-noSplit", usage = "Don't split - just load and continue training an existing grammar (true/false) (Default:false)")
		public boolean noSplit = false;

 		@Option(name = "-initializeZero", usage = "Initialize conditional weights with zero")
		public boolean initializeZero = false;

 		@Option(name = "-in", usage = "Input File for Grammar")
		public String inFile = null;

		@Option(name = "-randSeed", usage = "Seed for random number generator")
		public int randSeed = 8;

		@Option(name = "-sep", usage = "Set merging threshold for grammar and lexicon separately (Default: false)")
		public boolean separateMergingThreshold = false;

		@Option(name = "-hor", usage = "Horizontal Markovization (Default: 0)")
		public int horizontalMarkovization = 0;

		@Option(name = "-sub", usage = "Number of substates to split (Default: 1)")
		public int nSubStates = 1;

		@Option(name = "-ver", usage = "Vertical Markovization (Default: 1)")
		public int verticalMarkovization = 1;

		@Option(name = "-v", usage = "Verbose/Quiet (Default: Quiet)\n")
		public boolean verbose = false;

		@Option(name = "-r", usage = "Level of Randomness at init (Default: 1)\n")
		public double randomization = 1.0;

		@Option(name = "-sm1", usage = "Lexicon smoothing parameter 1")
		public double smoothingParameter1 = 0.5;

		@Option(name = "-sm2", usage = "Lexicon smoothing parameter 2)")
		public double smoothingParameter2 = 0.1;
		
		@Option(name = "-rare", usage = "Rare word threshold (Default 4)")
		public int rare = 4;

		@Option(name = "-spath", usage = "Whether or not to store the best path info (true/false) (Default: true)")
		public boolean findClosedUnaryPaths = true;

		@Option(name = "-unkT", usage = "Threshold for unknown words (Default: 5)")
		public int unkThresh = 5;
		
		@Option(name = "-doConditional", usage = "Do conditional training")
		public boolean doConditional = false;
		 
		@Option(name = "-regularize", usage = "Regularize during optimization: 0-no regularization, 1-l1, 2-l2")
		public int regularize = 0;

		@Option(name = "-onlyMerge", usage = "Do only a conditional merge")
		public boolean onlyMerge = false;

		@Option(name = "-sigma", usage = "Regularization coefficient")
		public double sigma = 1.0;
		
		@Option(name = "-cons", usage = "File with constraints")
		public String cons = null; 

		@Option(name = "-nProcess", usage = "Distribute on that many cores")
		public int nProcess = 1;

		@Option(name = "-doNOTprojectConstraints", usage = "Do NOT project constraints")
		public boolean doNOTprojectConstraints = false;
		
		@Option(name = "-section", usage = "Which section of the corpus to process.")
		public String section = "train";

		@Option(name = "-outputLog", usage = "Print output to this file rather than STDOUT.")
		public String outputLog = null;

		@Option(name = "-maxL", usage = "Skip sentences which are longer than this.")
		public int maxL = 10000;

		@Option(name = "-nChunks", usage = "Store constraints in that many files.")
		public int nChunks = 1;

		@Option(name = "-logT", usage = "Log threshold for pruning")
		public double logT = -10;
		
		@Option(name = "-lasso", usage="Start of by regularizing less and make the regularization stronger with time")
		public boolean lasso = false;

		@Option(name = "-hierarchical", usage="Use hierarchical rules")
		public boolean hierarchical = false;

		@Option(name = "-keepGoldTreeAlive", usage="Don't prune the gold train when computing constraints")
		public boolean keepGoldTreeAlive = false;

		@Option(name = "-flattenParameters", usage="Flatten parameters to reduce overconfidence")
		public double flattenParameters = 1.0;

		@Option(name = "-usePosteriorTraining", usage="Adam's new objective function")
		public boolean usePosteriorTraining = false;
		
		@Option(name = "-dontLoad", usage="Don't load anything from the pipeline")
		public boolean dontLoad = false;

		@Option(name = "-predefinedMaxSplit", usage="Use predifined number of subcategories")
		public boolean predefinedMaxSplit = false;

		@Option(name = "-collapseUnaries", usage="Dont throw away trees with unaries, just collapse the unary chains")
		public boolean collapseUnaries = false;

		@Option(name = "-connectedLexicon", usage="Score each word with the sum of its score and its signature score")
		public boolean connectedLexicon = false;

		@Option(name = "-adaptive", usage="Use adpatively refined rules")
		public boolean adaptive = false;

		@Option(name = "-checkDerivative", usage="Check the derivative of the objective function against an estimate with finite difference")
		public boolean checkDerivative = false;

		@Option(name = "-initRandomness", usage="Amount of randomness to initialize the grammar with")
		public double initRandomness = 1.0;
		
		@Option(name = "-markUnaryParents", usage="Filter all training trees with any unaries (other than lexical and ROOT productions)")
		public boolean markUnaryParents = false;
		
		@Option(name = "-filterAllUnaries", usage="Mark any unary parent with a ^u")
		public boolean filterAllUnaries = false;
		
		@Option(name = "-filterStupidFrickinWHNP", usage="Temp hack!")
		public boolean filterStupidFrickinWHNP = false;
		
		@Option(name = "-initializeDir", usage="Temp hack!")
		public String initializeDir = null;
		
		@Option(name = "-allPosteriorsWeight", usage="Weight for the all posteriors regularizer")
		public double allPosteriorsWeight = 0.0;
		
		@Option(name="-dontSaveGrammarsAfterEachIteration")
		public static boolean dontSaveGrammarsAfterEachIteration = false;
		
		@Option(name="-hierarchicalChart")
		public static boolean hierarchicalChart = false;
	
		@Option(name="-testAll", usage="Test grammars after each iteration, proceed by splitting the best")
		public boolean testAll = false;

		@Option(name="-lockGrammar", usage="Lock grammar weights, learn only span feature weights")
		public static boolean lockGrammar = false;
		@Option(name="-featurizedLexicon", usage="Use featurized lexicon (no fixed signature classes")
		public boolean featurizedLexicon = false;

		@Option(name = "-spanFeatures", usage="Use span features")
		public boolean spanFeatures = false;

	

		@Option(name="-useFirstAndLast", usage="Use first and last span words as span features")
		public static boolean useFirstAndLast = false;
		@Option(name="-usePreviousAndNext", usage="Use previous and next span words as span features")
		public static boolean usePreviousAndNext = false;
		@Option(name="-useBeginAndEndPairs", usage="Use begin and end word-pairs as span features")
		public static boolean useBeginAndEndPairs = false;
		@Option(name="-useSyntheticClass", usage="Distiguish between real and synthetic constituents")
		public static boolean useSyntheticClass = false;
		@Option(name="-usePunctuation", usage="Use punctuation cues")
		public static boolean usePunctuation = false;
		@Option(name="-minFeatureFrequency", usage="Use punctuation cues")
		public static int minFeatureFrequency = 0;
		@Option(name = "-lbfgsHistorySize", usage = "Max size of L-BFGS history (use -1 for defaults)")
		public int lbfgsHistorySize = -1;
		
		//-spanFeatures -usePunctuation -useSyntheticClass -useFirstAndLast -usePreviousAndNext -useBeginAndEndPairs
	}

	private static ParsingObjectFunctionFactory parsingObjectFunctionFactory = new ParsingObjectFunctionFactory()
	{

		public ParsingObjectiveFunction newParsingObjectiveFunction(Options opts,
				String outFileName, Linearizer linearizer,
				StateSetTreeList trainStateSetTrees, int regularize, double newSigma) {
			return ConditionalTrainer.newParsingObjectiveFunction(opts, outFileName,
					linearizer,
					trainStateSetTrees, regularize, newSigma);
		}
		
	};
	
	public static void setParsingObjectiveFunctionFactory(
			ParsingObjectFunctionFactory fact) {
		parsingObjectFunctionFactory = fact;
	}
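
	// Illustrative use of this hook (a sketch; MyObjectiveFunction is a
	// hypothetical subclass of ParsingObjectiveFunction). A caller can swap in
	// its own objective before invoking main():
	//
	//   ConditionalTrainer.setParsingObjectiveFunctionFactory(
	//       new ParsingObjectFunctionFactory() {
	//           public ParsingObjectiveFunction newParsingObjectiveFunction(
	//                   Options opts, String outFileName, Linearizer linearizer,
	//                   StateSetTreeList trainStateSetTrees, int regularize,
	//                   double newSigma) {
	//               return new MyObjectiveFunction(linearizer, trainStateSetTrees,
	//                       newSigma, regularize);
	//           }
	//       });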
	
	public static void main(String[] args) {
		
		OptionParser optParser = new OptionParser(Options.class);
		Options opts = (Options) optParser.parse(args, false);
		// provide feedback on command-line arguments
		System.out.println("Calling ConditionalTrainer with " + optParser.getPassedInOptions());

    
    String path = opts.path;
//    int lang = opts.lang;
    System.out.println("Loading trees from "+path+" and using language "+opts.treebank);
           
    double trainingFractionToKeep = opts.trainingFractionToKeep;
    
    int maxSentenceLength = opts.maxL;
    System.out.println("Will remove sentences with more than "+maxSentenceLength+" words.");
    
    Binarization binarization = opts.binarization; 
    System.out.println("Using "+ binarization.name() + " binarization.");// and "+annotateString+".");

    double randomness = opts.randomization;
    System.out.println("Using a randomness value of "+randomness);
    
    String outFileName = opts.outFileName;
    if (outFileName==null) {
    	System.out.println("Output File name is required.");
    	System.exit(-1);
    }
    else System.out.println("Using grammar output file "+outFileName+".");
    
    GrammarTrainer.VERBOSE = opts.verbose;
    GrammarTrainer.RANDOM = new Random(opts.randSeed);
    System.out.println("Random number generator seeded at "+opts.randSeed+".");

    boolean manualAnnotation = false;
    boolean baseline = opts.baseline;
    boolean noSplit = opts.noSplit;
    int numSplitTimes = opts.numSplits;
    if (baseline) numSplitTimes = 0;
    String splitGrammarFile = opts.inFile;
    int allowedDroppingIters = opts.di;

    int maxIterations = opts.splitMaxIterations;
    int minIterations = opts.splitMinIterations;
    if (minIterations>0)
    	System.out.println("I will do at least "+minIterations+" iterations.");

    double[] smoothParams = {opts.smoothingParameter1,opts.smoothingParameter2};
    System.out.println("Using smoothing parameters "+smoothParams[0]+" and "+smoothParams[1]);
    
    if (opts.connectedLexicon) System.out.println("Using connected lexicon.");
    if (opts.featurizedLexicon) System.out.println("Using featurized lexicon.");
    
//    boolean allowMoreSubstatesThanCounts = false;
    boolean findClosedUnaryPaths = opts.findClosedUnaryPaths;

    Corpus corpus = new Corpus(path,opts.treebank,trainingFractionToKeep,false);
    List<Tree<String>> trainTrees = Corpus.binarizeAndFilterTrees(corpus
				.getTrainTrees(), opts.verticalMarkovization,
				opts.horizontalMarkovization, maxSentenceLength, binarization, manualAnnotation,GrammarTrainer.VERBOSE, opts.markUnaryParents);
		List<Tree<String>> validationTrees = Corpus.binarizeAndFilterTrees(corpus
				.getValidationTrees(), opts.verticalMarkovization,
				opts.horizontalMarkovization, maxSentenceLength, binarization, manualAnnotation,GrammarTrainer.VERBOSE, opts.markUnaryParents);
    Numberer tagNumberer =  Numberer.getGlobalNumberer("tags");
    
    if (opts.collapseUnaries) System.out.println("Collapsing unary chains.");
    if (trainTrees!=null)trainTrees = Corpus.filterTreesForConditional(trainTrees, opts.filterAllUnaries,opts.filterStupidFrickinWHNP,opts.collapseUnaries);
    if (validationTrees!=null) validationTrees = Corpus.filterTreesForConditional(validationTrees,opts.filterAllUnaries,opts.filterStupidFrickinWHNP,opts.collapseUnaries);
    int nTrees = trainTrees.size();
    System.out.println("There are "+nTrees+" trees in the training set.");

//	List<Tree<String>> devTrees = Corpus.binarizeAndFilterTrees(corpus
//			.getDevTestingTrees(), opts.verticalMarkovization,
//			opts.horizontalMarkovization, maxSentenceLength, binarization, manualAnnotation,GrammarTrainer.VERBOSE, opts.markUnaryParents);
//
//    		for (Tree<String> t : devTrees){
//	System.out.println(t);
//}

		double filter = opts.filter;

    short nSubstates = (short)opts.nSubStates;
    short[] numSubStatesArray = initializeSubStateArray(trainTrees, validationTrees, tagNumberer, nSubstates);
    if (baseline) {
    	short one = 1;
    	Arrays.fill(numSubStatesArray, one);
    	System.out.println("Training just the baseline grammar (1 substate for all states)");
    	randomness = 0.0f;
    }
      	
    if (GrammarTrainer.VERBOSE){
	    for (int i=0; i<numSubStatesArray.length; i++){
	    	System.out.println("Tag "+tagNumberer.object(i)+" ("+i+") has "+numSubStatesArray[i]+" substates.");
	    }
    }

    double mergingPercentage = opts.mergingPercentage;
    boolean separateMergingThreshold = opts.separateMergingThreshold;
    if (mergingPercentage>0){
    	System.out.println("Will merge "+(int)(mergingPercentage*100)+"% of the splits in each round.");
    	System.out.println("The threshold for merging lexical and phrasal categories will be set separately: "+separateMergingThreshold);
    }
    
    StateSetTreeList trainStateSetTrees = new StateSetTreeList(trainTrees, numSubStatesArray, false, tagNumberer);
    StateSetTreeList validationStateSetTrees = new StateSetTreeList(validationTrees, numSubStatesArray, false, tagNumberer);//deletePC);

    
    // replaces rare words with their signatures
    if (!(opts.connectedLexicon)||!opts.doConditional||opts.unkThresh<0){
    	System.out.println("Replacing words which have been seen less than "+opts.unkThresh+" times with their signature.");
    	Corpus.replaceRareWords(trainStateSetTrees,new SimpleLexicon(numSubStatesArray,-1), Math.abs(opts.unkThresh));
    }


    Grammar grammar = null, maxGrammar = null;
    SimpleLexicon lexicon = null, maxLexicon = null;
    SpanPredictor spanPredictor = null;
    Linearizer linearizer = null;
    double maxLikelihood = Double.NEGATIVE_INFINITY;

    // If an input grammar was given, load it and continue training from it.
    if (splitGrammarFile!=null) {
    	System.out.println("Loading old grammar from "+splitGrammarFile+".");
    	ParserData pData = ParserData.Load(splitGrammarFile);
    	if (pData==null) {
    		System.out.println("Failed to load grammar from file "+splitGrammarFile+".");
    		System.exit(-1);
    	}
    	maxGrammar = pData.getGrammar();
    	maxLexicon = (SimpleLexicon) pData.getLexicon();
    	Numberer.setNumberers(pData.getNumbs());
    	maxLexicon.labelTrees(trainStateSetTrees);
    	grammar = maxGrammar;
    	lexicon = maxLexicon;
    	linearizer = new DefaultLinearizer(maxGrammar, maxLexicon, spanPredictor);
    }

  	if (splitGrammarFile!=null && spanPredictor==null && opts.spanFeatures){
  		System.out.println("Adding a span predictor since there was none!");
    	spanPredictor = new SpanPredictor(maxLexicon.nWords, trainStateSetTrees, tagNumberer, maxLexicon.wordIndexer);
			linearizer = new HiearchicalAdaptiveLinearizer(maxGrammar, maxLexicon, spanPredictor, maxGrammar.finalLevel);
  	}

    
    // get rid of the old trees
    trainTrees = null;
    validationTrees = null;
    corpus = null;
    System.gc();

    
    // If we're training without loading a split grammar, then we run once without splitting.
    if (splitGrammarFile==null) {
			int n = 0;
	  	grammar = new Grammar(numSubStatesArray, findClosedUnaryPaths, new NoSmoothing(), null, filter);
	    lexicon = new SimpleLexicon(numSubStatesArray,-1,smoothParams, new NoSmoothing(),filter, trainStateSetTrees);
	    
			boolean secondHalf = false;
			for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
				secondHalf = (n++>nTrees/2.0); 
				lexicon.trainTree(stateSetTree, randomness, null, secondHalf,false,opts.rare);
				grammar.tallyUninitializedStateSetTree(stateSetTree);
			}
			lexicon.optimize();
			grammar.optimize(randomness);
			// System.out.println(grammar);
			boolean noUnaryChains = true;
			Grammar grammar2 = grammar.copyGrammar(noUnaryChains);
			SimpleLexicon lexicon2 = lexicon.copyLexicon();
    	System.out.println("Known word cut-off at "+opts.unkThresh+" occurences.");


    	if (opts.adaptive){
    		System.out.println("Using hierarchical adaptive grammar and lexicon.");
      	grammar2 = new HierarchicalAdaptiveGrammar(grammar2);
    		lexicon2 = (opts.featurizedLexicon) ?
    				new HierarchicalFullyConnectedAdaptiveLexiconWithFeatures(numSubStatesArray,-1,smoothParams, new NoSmoothing(), trainStateSetTrees, opts.unkThresh):
    					new HierarchicalFullyConnectedAdaptiveLexicon(numSubStatesArray,-1,smoothParams, new NoSmoothing(), trainStateSetTrees, opts.unkThresh);
		    if (opts.spanFeatures)
		    	spanPredictor = new SpanPredictor(lexicon2.nWords, trainStateSetTrees, tagNumberer, lexicon2.wordIndexer);

				linearizer = new HiearchicalAdaptiveLinearizer(grammar2, lexicon2, spanPredictor, 0);
    	} else if (opts.connectedLexicon&&opts.doConditional) {
    		lexicon2 = new HierarchicalFullyConnectedLexicon(numSubStatesArray,-1,smoothParams, new NoSmoothing(), trainStateSetTrees, opts.unkThresh);
    		linearizer = new DefaultLinearizer(grammar2, lexicon2, spanPredictor);
    	} else {
    		linearizer = new DefaultLinearizer(grammar2, lexicon2, spanPredictor);
    	}
    		
			if (opts.initializeZero) System.out.println("Initializing weigths with zero!");
			
			Random rand = GrammarTrainer.RANDOM;
  	  double[] init = linearizer.getLinearizedWeights();
  	  if (opts.initializeZero) {
//  	  	Arrays.fill(init, 0);
  	  	for (int i=0; i<init.length; i++){
  	  		init[i] = 0;
  	  	}
  	  	linearizer.delinearizeWeights(init);
  	  }
  	  maxGrammar = grammar2;
  	  maxLexicon = lexicon2;
    }

    if (opts.doConditional) {
	  int regularize = opts.regularize;
	  double sigma = opts.sigma;
	  if (regularize>0){
	  	System.out.println("Regularizing with sigma="+sigma);
	  }

	  LBFGSMinimizer minimizer = null;
    

	  int maxIter = (opts.noSplit) ? 2 : 4;
  	ParsingObjectiveFunction objective = null;
  	for (int it=1; it<=maxIter; it++){
  		// with L1 regularization use the orthant-wise variant of L-BFGS
  		minimizer = (regularize==1) ? new OW_LBFGSMinimizer() : new LBFGSMinimizer();
				if (opts.lbfgsHistorySize >= 0)
					minimizer.setMaxHistorySize(opts.lbfgsHistorySize);
    	double newSigma = sigma;
    	if (opts.lasso && !opts.noSplit){
    		newSigma = sigma + 3 - it;
    		System.out.println("The regularization parameter for this round will be: "+newSigma);
    	}
    	if (it==1) {
    		
				objective = parsingObjectFunctionFactory.newParsingObjectiveFunction(opts, outFileName, linearizer, trainStateSetTrees, regularize, newSigma);
    		minimizer.setMinIteratons(15);
    	} else {
  		  minimizer.setMinIteratons(5);
    	}
    	objective.setSigma(newSigma);
    	
    	double[] weights = objective.getCurrentWeights();
    	
    	
    	if (it == 1 && opts.checkDerivative)
    	{
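    		// One-sided finite-difference check: perturb one coordinate by h and
    		// compare (f(w+h*e_i) - f(w))/h against the analytic derivative.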
    		System.out.print("\nChecking derivative: ");
    		double f = objective.valueAt(weights);
    		double[] deriv = objective.derivativeAt(weights);
    		double[] fDif = deriv.clone();
    		final double h = 1e-4;
    		for (int i = 0; i < 1; ++i)
    		{
    			double[] newWeights = weights.clone();
    			newWeights[i] += h;
    			double fplush = objective.valueAt(newWeights);
    			double finiteDif = (fplush - f) / h;
    			if (finiteDif - deriv[i] > 0.1)
    			{
    				System.out.println("Derivative is whack!");
    			}
    			fDif[i] = finiteDif;
    		}
    		System.out.println("done");
    	}
    	System.out.print("\nChecking weights: ");
    	int invalid = 0;
      for (int i=0; i<weights.length; i++){
      	if (SloppyMath.isVeryDangerous(weights[i])) invalid++;
      }
      System.out.println(invalid+" of "+weights.length+" weights are invalid.");

      // run this round's optimization; the optimized weights are written back
      // into the grammar and lexicon via the linearizer
      double[] newWeights = minimizer.minimize(objective, weights, 1e-4);
      linearizer.delinearizeWeights(newWeights);
  	} // end for it
    } // end if (opts.doConditional)
    else {
    	// Generative EM training: iterate until the validation likelihood
    	// drops for too many iterations in a row.
    	Grammar previousGrammar = grammar;
    	SimpleLexicon previousLexicon = lexicon;
    	int iter = 0;
    	int droppingIter = 0;
    	do {
    		if (maxIterations>0){
	    		iter += 1;
	    		System.out.println("Beginning iteration "+(iter-1)+":");
	
	  			// 1) Compute the validation likelihood of the previous iteration
	  			System.out.print("Calculating validation likelihood...");
	  			double validationLikelihood = calculateLogLikelihood(previousGrammar, previousLexicon, validationStateSetTrees);  // The validation LL of previousGrammar/previousLexicon
	  			System.out.println("done: "+validationLikelihood);
	  			
	  			// 2) Perform the E step while computing the training likelihood of the previous iteration
	  			System.out.print("Calculating training likelihood...");
	  			grammar = new Grammar(grammar.numSubStates, grammar.findClosedPaths, grammar.smoother, grammar, grammar.threshold);
	//  			lexicon = new SimpleLexicon(grammar.numSubStates,	SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, null, new NoSmoothing(), opts.unkThresh);
	  			lexicon = maxLexicon.copyLexicon();
	  			boolean updateOnlyLexicon = false;
	  			double trainingLikelihood = doOneEStep(previousGrammar,previousLexicon,grammar,lexicon,trainStateSetTrees,updateOnlyLexicon,opts.rare);  // The training LL of previousGrammar/previousLexicon
	  			System.out.println("done: "+trainingLikelihood);
	
	  			// 3) Perform the M-Step
	  			lexicon.optimize(); // M Step   
	  			grammar.optimize(0); // M Step
	  			
	  			// 4) Check whether previousGrammar/previousLexicon was in fact better than the best
	  			if (iter<minIterations || validationLikelihood >= maxLikelihood) {
	  				maxLikelihood = validationLikelihood;
	  				maxGrammar = previousGrammar;
	  				maxLexicon = previousLexicon;
	  				droppingIter = 0;
	  			} else { droppingIter++; }
	
	  			// 5) advance the 'pointers'
	    		previousGrammar = grammar;
	     		previousLexicon = lexicon;
    		}
    	} while ((droppingIter < allowedDroppingIters) && (!baseline) && (iter<maxIterations));

    	// The last grammar/lexicon pair may have been the best one; check once more.
    	System.out.print("Calculating last validation likelihood...");
    	double validationLikelihood = calculateLogLikelihood(previousGrammar, previousLexicon, validationStateSetTrees);
    	System.out.println("done: "+validationLikelihood);
    	if (validationLikelihood > maxLikelihood) {
    		maxLikelihood = validationLikelihood;
    		maxGrammar = previousGrammar;
    		maxLexicon = previousLexicon;
    	}
    }

//    System.out.println(lexicon);
//    System.out.println(grammar);
    
    ParserData pData = new ParserData(maxLexicon, maxGrammar, null, Numberer.getNumberers(), numSubStatesArray, opts.verticalMarkovization, opts.horizontalMarkovization, binarization);
    System.out.println("Saving grammar to "+outFileName+".");
    System.out.println("It gives a validation data log likelihood of: "+maxLikelihood);
    if (pData.Save(outFileName)) System.out.println("Saving successful.");
    else System.out.println("Saving failed!");
    
    //System.exit(0);
  }


	/**
	 * @param opts
	 * @param outFileName
	 * @param linearizer
	 * @param trainStateSetTrees
	 * @param regularize
	 * @param sigma
	 * @return the objective function used for conditional training
	 */
	private static ParsingObjectiveFunction newParsingObjectiveFunction(
			Options opts, String outFileName, Linearizer linearizer, StateSetTreeList trainStateSetTrees, int regularize, double sigma) {
		return /*opts.usePosteriorTraining? new PosteriorTrainingObjectiveFunction(linearizer, trainStateSetTrees, sigma, 
				regularize, opts.boostIncorrect, opts.cons, opts.nProcess, outFileName, 
				opts.doGEM, opts.doNOTprojectConstraints, opts.allPosteriorsWeight): */new ParsingObjectiveFunction(linearizer, trainStateSetTrees, sigma, 
				regularize, opts.cons, opts.nProcess, outFileName, opts.doNOTprojectConstraints, opts.connectedLexicon);
	}


	/**
	 * @param previousGrammar
	 * @param previousLexicon
	 * @param grammar
	 * @param lexicon
	 * @param trainStateSetTrees
	 * @param updateOnlyLexicon
	 * @param unkThreshold
	 * @return the training log likelihood of previousGrammar/previousLexicon
	 */
	public static double doOneEStep(Grammar previousGrammar, Lexicon previousLexicon, Grammar grammar, Lexicon lexicon, StateSetTreeList trainStateSetTrees,
			boolean updateOnlyLexicon, int unkThreshold) {
		boolean secondHalf = false;
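		// secondHalf marks trees from the second half of the corpus; the lexicon
		// treats the two halves differently in trainTree (presumably to hold out
		// data for its unknown-word smoothing).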
		ArrayParser parser = new ArrayParser(previousGrammar,previousLexicon);
		double trainingLikelihood = 0;
		int n = 0;
		int nTrees = trainStateSetTrees.size();
		for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
			secondHalf = (n++>nTrees/2.0); 
			boolean noSmoothing = true, debugOutput = false;
			parser.doInsideOutsideScores(stateSetTree,noSmoothing,debugOutput);                    // E Step
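			// Inside scores are stored in scaled form to avoid underflow: the true
			// score is iScore * exp(100 * iScale), hence the +100*iScale term below
			// when converting to a log likelihood.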
			double ll = stateSetTree.getLabel().getIScore(0);
			ll = Math.log(ll) + (100*stateSetTree.getLabel().getIScale());//System.out.println(stateSetTree);
			if ((Double.isInfinite(ll) || Double.isNaN(ll))) {
				if (GrammarTrainer.VERBOSE){
					System.out.println("Training sentence "+n+" is given "+ll+" log likelihood!");
					System.out.println("Root iScore "+ stateSetTree.getLabel().getIScore(0)+" scale "+stateSetTree.getLabel().getIScale());
				}
			}
			else {
			  lexicon.trainTree(stateSetTree, -1, previousLexicon, secondHalf,noSmoothing,unkThreshold);
				if (!updateOnlyLexicon) grammar.tallyStateSetTree(stateSetTree, previousGrammar);      // E Step
				trainingLikelihood  += ll;  // there are for some reason some sentences that are unparsable 
			}
		}
		return trainingLikelihood;
	}


	/**
	 * @param maxGrammar
	 * @param maxLexicon
	 * @param validationStateSetTrees
	 * @return the log likelihood of the validation trees
	 */
	public static double calculateLogLikelihood(Grammar maxGrammar, Lexicon maxLexicon, StateSetTreeList validationStateSetTrees) {
		ArrayParser parser = new ArrayParser(maxGrammar, maxLexicon);
		int unparsable = 0;
		double maxLikelihood = 0;
		for (Tree<StateSet> stateSetTree : validationStateSetTrees) {
			parser.doInsideScores(stateSetTree,false,false,null);  // Only inside scores are needed here
			double ll = stateSetTree.getLabel().getIScore(0);
			ll = Math.log(ll) + (100*stateSetTree.getLabel().getIScale());
			if (Double.isInfinite(ll) || Double.isNaN(ll)) { 
				unparsable++;
				//printBadLLReason(stateSetTree, lexicon);
			}
			else maxLikelihood += ll;  // there are for some reason some sentences that are unparsable 
		}
//		if (unparsable>0) System.out.print("Number of unparsable trees: "+unparsable+".");
		return maxLikelihood;
	}


	/**
	 * @param stateSetTree
	 * @param lexicon
	 */
	public static void printBadLLReason(Tree<StateSet> stateSetTree, SophisticatedLexicon lexicon) {
		System.out.println(stateSetTree.toString());
		boolean lexiconProblem = false;
		List<StateSet> words = stateSetTree.getYield();
		Iterator<StateSet> wordIterator = words.iterator();
		for (StateSet stateSet : stateSetTree.getPreTerminalYield()) {
			String word = wordIterator.next().getWord();
			boolean lexiconProblemHere = true;
			for (int i = 0; i < stateSet.numSubStates(); i++) {
				double score = stateSet.getIScore(i);
				if (!(Double.isInfinite(score) || Double.isNaN(score))) {
					lexiconProblemHere = false;
				}
			}
			if (lexiconProblemHere) {
				System.out.println("LEXICON PROBLEM ON STATE " + stateSet.getState()+" word "+word);
				System.out.println("  word "+lexicon.wordCounter.getCount(stateSet.getWord()));
			for (int i=0; i<stateSet.numSubStates(); i++){
					System.out.println("  score "+i+" "+stateSet.getIScore(i));
				}
			}
			lexiconProblem = lexiconProblem || lexiconProblemHere;
		}
		if (lexiconProblem)
			System.out.println("  the likelihood is bad because of the lexicon");
		else
			System.out.println("  the likelihood is bad because of the grammar");
	}


	/**
	 * Sum the log likelihoods of a list of trees whose inside scores have
	 * already been computed.
	 *
	 * @param trees a list of scored StateSet trees
	 * @param verbose print the likelihood of each tree
	 * @return the total log likelihood
	 */
  public static double logLikelihood(List<Tree<StateSet>> trees, boolean verbose) {
    double likelihood = 0, l=0;
    for (Tree tree : trees) {
      l = tree.getLabel().getIScore(0);
      if (verbose) System.out.println("LL is "+l+".");
      if (Double.isInfinite(l) || Double.isNaN(l)){
        System.out.println("LL is not finite.");
      }
      else {
        likelihood += l;
      }
    }
    return likelihood;
  }
  
  
  /**
   * This updates the inside-outside probabilities for the list of trees using the parser's
   * doInsideScores and doOutsideScores methods.
   * 
   * @param trees A list of binarized, annotated StateSet Trees.
   * @param parser The parser to score the trees.
   */
  public static void updateStateSetTrees (List<Tree<StateSet>> trees, ArrayParser parser) {
    for (Tree<StateSet> tree : trees) {
      parser.doInsideOutsideScores(tree,false,false);
    }
  }


  /**
   * Determine how many substates each tag should start out with.
   *
   * @param trainTrees the binarized training trees
   * @param validationTrees the binarized validation trees
   * @param tagNumberer the numberer mapping tags to indices
   * @param nSubStates the default number of substates per tag
   * @return an array giving the initial number of substates for each tag
   */
  
  public static short[] initializeSubStateArray(List<Tree<String>> trainTrees,
			List<Tree<String>> validationTrees, Numberer tagNumberer, short nSubStates){
//			boolean dontSplitTags) {
		// first generate unsplit grammar and lexicon
    short[] nSub = new short[2];
    nSub[0] = 1;
    nSub[1] = nSubStates;

    // do the validation set so that the numberer sees all tags and we can
		// allocate big enough arrays
		// note: although this variable is never read, this constructor adds the
		// validation trees into the tagNumberer as a side effect, which is
		// important
    StateSetTreeList trainStateSetTrees = new StateSetTreeList(trainTrees, nSub, true, tagNumberer);
    @SuppressWarnings("unused")
		StateSetTreeList validationStateSetTrees = new StateSetTreeList(validationTrees, nSub, true, tagNumberer);

    StateSetTreeList.initializeTagNumberer(trainTrees, tagNumberer);
    StateSetTreeList.initializeTagNumberer(validationTrees, tagNumberer);
    
    short numStates = (short)tagNumberer.total();
    short[] nSubStateArray = new short[numStates];
  	Arrays.fill(nSubStateArray, nSubStates);
  	//System.out.println("Everything is split in two except for the root.");
  	nSubStateArray[0] = 1; // that's the ROOT
    return nSubStateArray;
  }



  public static boolean[][][][][] loadDataNoZip(String fileName) {
  	boolean[][][][][] data = null;
    try {
      FileInputStream fis = new FileInputStream(fileName); // Load from file
//      GZIPInputStream gzis = new GZIPInputStream(fis); // Compressed
      ObjectInputStream in = new ObjectInputStream(fis); // Load objects
      data = (boolean[][][][][])in.readObject(); // Read the mix of grammars
      in.close(); // And close the stream.
    } catch (IOException e) {
      System.out.println("IOException\n"+e);
      return null;
    } catch (ClassNotFoundException e) {
      System.out.println("Class not found!");
      return null;
    }
    return data;
  }
  
	public static boolean saveDataNoZip(boolean[][][][][] data, String fileName){
    try {
      // Serialize the data to file; a gzipped variant is commented out below.
      // (Adapted from the example at
      //  http://www.ecst.csuchico.edu/~amk/foo/advjava/notes/serial.html)
      FileOutputStream fos = new FileOutputStream(fileName); // Save to file
//      GZIPOutputStream gzos = new GZIPOutputStream(fos); // Compressed
      ObjectOutputStream out = new ObjectOutputStream(fos); // Save objects
      out.writeObject(data); // Write the mix of grammars
      out.flush(); // Always flush the output.
      out.close(); // And close the stream.
    } catch (IOException e) {
      System.out.println("IOException: "+e);
      return false;
    }
    return true;
  }
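
  // Example round trip with these helpers (illustrative; the file name is a
  // placeholder):
  //
  //   boolean[][][][][] constraints = ...;
  //   saveDataNoZip(constraints, "constraints.data");
  //   boolean[][][][][] reloaded = loadDataNoZip("constraints.data");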

  private static final double TOL = 1e-5;
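	// Relative-difference comparison: x and y "match" if they agree to within
	// TOL, e.g. matches(1.0, 1.0 + 1e-6) is true, while matches(1.0, 1.01) is false.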
	protected static boolean matches(double x, double y) {
    return (Math.abs(x - y) / (Math.abs(x) + Math.abs(y) + 1e-10) < TOL);
  }

}



