All downloads are free. Search and download functionality uses the official Maven repository.

edu.berkeley.nlp.PCFGLA.ParserConstrainer Maven / Gradle / Ivy

Go to download

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).

The newest version!
package edu.berkeley.nlp.PCFGLA;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import edu.berkeley.nlp.syntax.SpanTree;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;

public class ParserConstrainer implements Callable {

	/** The chunk of trees that {@link #call()} computes constraints for. */
	StateSetTreeList stateSetTrees;

	/** Grammar handed to the constrained parser. */
	Grammar grammar;

	/** Lexicon handed to the constrained parser. */
	Lexicon lexicon;

	/** Span predictor handed to the constrained parser. */
	SpanPredictor spanPredictor;

	/** Output base name; {@code main} reports the constraints are written under it. */
	String outBaseName;

	/** Posterior threshold passed to the parser; states below it are pruned. */
	double threshold;

	/**
	 * Base name of constraints from an earlier pass, or {@code null} for none.
	 * This worker reads {@code consName + "-" + myID + ".data"}.
	 */
	String consName;

	/** When set, the gold tree is passed to the parser so that it survives pruning. */
	boolean keepGoldTreeAlive;

	/**
	 * NOTE(review): stored but not read anywhere in the visible code; call()
	 * chooses the parser from the grammar's runtime class instead.
	 */
	boolean useHierarchicalParser;

	/** Maximum trees per chunk; assigned by {@code main} before any worker is created. */
	static int treesPerBlock;

	/** 0-based index of this worker's chunk. */
	int myID;

	/**
	 * Creates a worker that constrains one chunk of trees.
	 *
	 * @param stateSetTrees the trees of this chunk
	 * @param grammar grammar for the constrained parser
	 * @param lexicon lexicon for the constrained parser
	 * @param spanPredictor span predictor for the constrained parser
	 * @param outBaseName output base name for the constraints
	 * @param threshold posterior pruning threshold
	 * @param keepGoldTreeAlive whether the gold tree must survive pruning
	 * @param myID 0-based chunk index
	 * @param cons base name of previously computed constraints, or {@code null}
	 * @param useHierarchicalParser stored flag (see field note)
	 */
	public ParserConstrainer(StateSetTreeList stateSetTrees, Grammar grammar,
			Lexicon lexicon, SpanPredictor spanPredictor, String outBaseName,
			double threshold, boolean keepGoldTreeAlive, int myID, String cons,
			boolean useHierarchicalParser) {
		// identity of the chunk
		this.myID = myID;
		this.stateSetTrees = stateSetTrees;
		// models used for parsing
		this.grammar = grammar;
		this.lexicon = lexicon;
		this.spanPredictor = spanPredictor;
		this.useHierarchicalParser = useHierarchicalParser;
		// pruning behaviour and file names
		this.threshold = threshold;
		this.keepGoldTreeAlive = keepGoldTreeAlive;
		this.consName = cons;
		this.outBaseName = outBaseName;
	}

	/**
	 * Command-line entry point.
	 * <p>
	 * Loads the requested treebank section ({@code dev}, {@code final} or
	 * {@code train}) and a trained grammar. It splits the trees into
	 * {@code opts.nChunks} consecutive chunks and computes constraints for each
	 * chunk, in parallel when there is more than one. The per-sentence logs are
	 * printed in chunk order to {@code opts.outputLog}, or to standard output
	 * when no log file is given.
	 *
	 * @param args command-line arguments, parsed into
	 *          {@link ConditionalTrainer.Options}
	 */
	public static void main(String[] args) {
		OptionParser optParser = new OptionParser(ConditionalTrainer.Options.class);
		ConditionalTrainer.Options opts = (ConditionalTrainer.Options) optParser
				.parse(args, false);

		// provide feedback on command-line arguments
		System.out.println("Calling Constrainer with "
				+ optParser.getPassedInOptions());

		String path = opts.path;
		System.out.println("Loading trees from " + path + " and using language "
				+ opts.treebank);
		String testSetString = opts.section;
		boolean devTestSet = testSetString.equals("dev");
		boolean finalTestSet = testSetString.equals("final");
		boolean trainTestSet = testSetString.equals("train");
		System.out.println(" using " + testSetString + " test set");

		// Fail early instead of crashing later on a null tree list.
		if (!devTestSet && !finalTestSet && !trainTestSet) {
			System.out.println("Unknown section " + testSetString
					+ "; expected dev, final or train.");
			System.exit(1);
		}
		int nChunks = opts.nChunks;
		if (nChunks < 1) {
			System.out.println("The number of chunks must be at least 1, but was "
					+ nChunks + ".");
			System.exit(1);
		}

		Corpus corpus = new Corpus(path, opts.treebank,
				opts.trainingFractionToKeep, !trainTestSet);
		List<Tree<String>> testTrees;
		if (devTestSet) {
			testTrees = corpus.getDevTestingTrees();
		} else if (finalTestSet) {
			testTrees = corpus.getFinalTestingTrees();
		} else {
			testTrees = corpus.getTrainTrees();
		}

		boolean manualAnnotation = false;
		testTrees = Corpus.binarizeAndFilterTrees(testTrees,
				opts.verticalMarkovization, opts.horizontalMarkovization, opts.maxL,
				opts.binarization, manualAnnotation, GrammarTrainer.VERBOSE,
				opts.markUnaryParents);

		// Unary chains are never collapsed on the dev set.
		boolean collapseUnaries = !devTestSet && opts.collapseUnaries;
		if (collapseUnaries) {
			System.out.println("Collapsing unary chains.");
		}
		testTrees = Corpus.filterTreesForConditional(testTrees,
				opts.filterAllUnaries, opts.filterStupidFrickinWHNP, collapseUnaries);

		// On the training set the gold tree is always kept alive.
		boolean keepGoldAlive = opts.keepGoldTreeAlive || trainTestSet;

		String inFileName = opts.inFile;
		System.out.println("Loading grammar from " + inFileName + ".");
		ParserData pData = ParserData.Load(inFileName);
		if (pData == null) {
			System.out
					.println("Failed to load grammar from file " + inFileName + ".");
			System.exit(1);
		}
		Grammar grammar = pData.getGrammar();
		grammar.splitRules();
		Lexicon lexicon = pData.getLexicon();
		lexicon.explicitlyComputeScores(grammar.finalLevel);
		SpanPredictor spanPredictor = pData.getSpanPredictor();

		if (opts.flattenParameters != 1.0) {
			System.out.println("Flattening parameters with exponent "
					+ opts.flattenParameters + " to reduce overconfidence.");
			grammar.removeUnlikelyRules(0, opts.flattenParameters);
			lexicon.removeUnlikelyTags(0, opts.flattenParameters);
		}

		Numberer.setNumberers(pData.getNumbs());
		Numberer tagNumberer = Numberer.getGlobalNumberer("tags");

		StateSetTreeList stateSetTrees = new StateSetTreeList(testTrees,
				grammar.numSubStates, false, tagNumberer);
		// Only the state-set trees are needed from here on.
		testTrees = null;

		String outBaseName = opts.outFileName;
		double threshold = Math.exp(opts.logT);

		int nTrees = stateSetTrees.size();
		System.out.println("There are " + nTrees + " trees in this set.");
		// Shared with call(): sizes each worker's history and numbers sentences.
		treesPerBlock = (int) Math.ceil(nTrees / (double) nChunks);
		System.out.println("Will store " + treesPerBlock
				+ " constraints per file, in " + nChunks + " files.");

		System.out.println("All states with posterior probability below "
				+ threshold + " will be pruned.");
		if (keepGoldAlive) {
			System.out.println("But the gold tree will survive!");
		}
		System.out.println("The constraints will be written to " + outBaseName
				+ ".");

		StateSetTreeList[] trainingTrees = splitIntoChunks(stateSetTrees, nChunks,
				treesPerBlock);
		for (int i = 0; i < nChunks; i++) {
			System.out.println("Process " + i + " has " + trainingTrees[i].size()
					+ " trees.");
		}
		stateSetTrees = null;

		ExecutorService pool = null;
		Future[] submits = new Future[nChunks];
		ParserConstrainer thisThreadConstrainer = null;
		if (nChunks == 1) {
			thisThreadConstrainer = new ParserConstrainer(trainingTrees[0], grammar,
					lexicon, spanPredictor, outBaseName, threshold, keepGoldAlive, 0,
					opts.cons, opts.hierarchicalChart);
		} else {
			pool = Executors.newFixedThreadPool(nChunks);
			for (int i = 0; i < nChunks; i++) {
				ParserConstrainer constrainer = new ParserConstrainer(trainingTrees[i],
						grammar, lexicon, spanPredictor, outBaseName, threshold,
						keepGoldAlive, i, opts.cons, opts.hierarchicalChart);
				submits[i] = pool.submit(constrainer);
			}
			// Accept no further work. The submitted tasks still run to completion,
			// and the (non-daemon) pool threads then exit so the JVM can terminate.
			// Future.get() below blocks until each result is ready, so no polling.
			pool.shutdown();
		}

		PrintWriter outputData = null;
		try {
			outputData = (opts.outputLog == null) ? new PrintWriter(
					new OutputStreamWriter(System.out))
					: new PrintWriter(new OutputStreamWriter(new FileOutputStream(
							opts.outputLog), "UTF-8"), true);

			for (int i = 0; i < nChunks; i++) {
				StringBuilder sb;
				if (nChunks == 1) {
					sb = thisThreadConstrainer.call();
				} else {
					sb = (StringBuilder) submits[i].get();
				}
				outputData.print(sb.toString());
			}
		} catch (ExecutionException e) {
			e.printStackTrace();
		} catch (InterruptedException e) {
			// Preserve the interrupt status for any code that runs after us.
			Thread.currentThread().interrupt();
			e.printStackTrace();
		} catch (UnsupportedEncodingException e) {
			e.printStackTrace();
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		} finally {
			if (outputData != null) {
				// The System.out writer is not auto-flushing: always flush it.
				outputData.flush();
				// Close only a real log file, never the wrapper around System.out.
				if (opts.outputLog != null) {
					outputData.close();
				}
			}
		}

		System.out.println("Done computing constraints.");
	}

	/**
	 * Splits {@code trees} into {@code nChunks} consecutive chunks of at most
	 * {@code perChunk} trees each. Trailing chunks may be empty.
	 *
	 * @param trees the trees to distribute, in order
	 * @param nChunks number of chunks to create (at least 1)
	 * @param perChunk capacity of each chunk; at least 1 whenever {@code trees}
	 *          is non-empty
	 * @return {@code nChunks} non-null lists
	 */
	private static StateSetTreeList[] splitIntoChunks(StateSetTreeList trees,
			int nChunks, int perChunk) {
		StateSetTreeList[] chunks = new StateSetTreeList[nChunks];
		for (int c = 0; c < nChunks; c++) {
			chunks[c] = new StateSetTreeList();
		}
		for (int i = 0; i < trees.size(); i++) {
			chunks[i / perChunk].add(trees.get(i));
		}
		return chunks;
	}

	/**
	 * Computes pruning constraints for every tree in this worker's chunk
	 * ({@link #stateSetTrees}).
	 * <p>
	 * The parser is a {@link ConstrainedHierarchicalTwoChartParser} when the
	 * grammar is a {@link HierarchicalAdaptiveGrammar}; otherwise it is a plain
	 * {@link ConstrainedTwoChartsParser}. For each tree, the words of its yield
	 * are passed to {@code getPossibleStates} together with several inputs:
	 * <ul>
	 * <li>{@link #threshold};</li>
	 * <li>any previously computed constraints, read from
	 * {@code consName + "-" + myID + ".data"} when {@link #consName} is
	 * non-null;</li>
	 * <li>the gold tree itself, if {@link #keepGoldTreeAlive} is set.</li>
	 * </ul>
	 * The returned state chart is kept in {@code recentHistory}, indexed by the
	 * tree's position within the chunk.
	 * <p>
	 * NOTE(review): the tail of this method and the header of the following
	 * {@code contains} helper are missing from this copy of the file. The text
	 * appears to have been cut where a '<' character stood. As a result, neither
	 * how {@code recentHistory} is persisted nor the exact return statement can
	 * be confirmed here. Restore this region from the upstream source.
	 *
	 * @return the per-sentence log, which {@code main} prints. It holds
	 *         {@code "N. Length L"} entries plus anything
	 *         {@code getPossibleStates} writes into the builder.
	 */

	public StringBuilder call() {
		ConstrainedTwoChartsParser parser = (grammar instanceof HierarchicalAdaptiveGrammar) ? new ConstrainedHierarchicalTwoChartParser(
				grammar, lexicon, spanPredictor, grammar.finalLevel)
				: new ConstrainedTwoChartsParser(grammar, lexicon, spanPredictor);

		StringBuilder sb = new StringBuilder();
		int recentHistoryIndex = 0;
		// int sentenceNumber = 1;

		// One slot per tree of this chunk. This relies on main never assigning
		// more than treesPerBlock trees to a single chunk.
		boolean[][][][][] recentHistory = new boolean[treesPerBlock][][][][];
		boolean[][][][][] myConstraints = null;
		boolean useCons = consName != null;

		if (useCons)
			myConstraints = loadData(consName + "-" + myID + ".data");
		boolean[][][][] cons = null;

		for (Tree testTree : stateSetTrees) {
			List yield = testTree.getYield();
			List testSentence = new ArrayList(yield.size());

			for (StateSet el : yield) {
				testSentence.add(el.getWord());
			}
			// Sentence numbers are 1-based and global across all chunks.
			sb.append("\n" + (myID * treesPerBlock + recentHistoryIndex + 1)
					+ ". Length " + testSentence.size());

		
			if (useCons) {
				parser.projectConstraints(myConstraints[recentHistoryIndex], false);
				cons = myConstraints[recentHistoryIndex];
			}

			Tree sTree = null;
			if (keepGoldTreeAlive) {
				// System.out.println("keeping gold tree alive");
				sTree = testTree;
			}
			boolean[][][][] possibleStates = parser.getPossibleStates(testSentence,
					sTree, threshold, cons, sb);
			// Checked only with -ea: every gold node must still be allowed.
			assert sTree == null || contains(possibleStates, sTree);

			// Release this sentence's input constraints once used (presumably to
			// save memory).
			if (useCons)
				myConstraints[recentHistoryIndex] = null;
			recentHistory[recentHistoryIndex++] = possibleStates;

			// Progress indicator: one dot per 1000 sentences.
			if (recentHistoryIndex % 1000 == 0)
				System.out.print(".");
			// sentenceNumber++;
			// if (recentHistoryIndex>0 && (recentHistoryIndex % treesPerBlock == 0))
			// {
			// String fileName = outBaseName+"-"+blockIndex+".data";
			// saveData(recentHistory, fileName);
			// blockIndex++;
			// if (useCons && sentenceNumber tree) {
			// NOTE(review): source text is missing at this point. The rest of call()
			// and the signature of contains(possibleStates, tree) were lost. The
			// lines below are the body of contains(), which returns true iff every
			// node's (from, to, state) cell in possibleStates has a true entry.
		boolean[] bs = possibleStates[tree.getLabel().from][tree.getLabel().to][tree
				.getLabel().getState()];
		
		// NOTE(review): the lookup above runs before this leaf check, so leaf
		// labels must also index possibleStates validly.
		if (tree.isLeaf())
			return true;
		// "assert false" fails loudly under -ea; otherwise false is reported.
		if (bs == null) {
			assert false;
			return false;
		}
		boolean hasTrue = false;
		for (boolean b : bs)
			hasTrue |= b;
		if (!hasTrue) {
			assert false;
			return false;
		}
		boolean allThere = true;
		for (Tree child : tree.getChildren()) {
			allThere &= contains(possibleStates, child);
		}
		return allThere;
	}

	/**
	 * Serializes {@code data} with Java object serialization into a
	 * gzip-compressed file.
	 * <p>
	 * All streams are closed even when writing fails part-way, so a failed save
	 * does not leak the file handle.
	 *
	 * @param data the (possibly jagged) array to store
	 * @param fileName path of the file to create or overwrite
	 * @return {@code true} on success; {@code false} if an I/O error occurred,
	 *         in which case the error is printed to standard output
	 */
	public static boolean saveData(boolean[][][][][] data, String fileName) {
		FileOutputStream fos = null;
		ObjectOutputStream out = null;
		try {
			fos = new FileOutputStream(fileName);
			out = new ObjectOutputStream(new GZIPOutputStream(fos));
			out.writeObject(data);
			out.flush();
			// close() finishes the gzip stream (writes its trailer) and closes fos.
			// If it fails the file is incomplete, so the error must be reported.
			out.close();
			out = null;
			fos = null;
		} catch (IOException e) {
			System.out.println("IOException: " + e);
			return false;
		} finally {
			// These are non-null only when an exception interrupted the writes.
			if (out != null) {
				try {
					out.close();
				} catch (IOException ignored) {
					// the original failure has already been reported
				}
			}
			if (fos != null) {
				try {
					fos.close();
				} catch (IOException ignored) {
					// the original failure has already been reported
				}
			}
		}
		return true;
	}

	/**
	 * Checks whether every node of {@code gold} is allowed by the per-span state
	 * lists.
	 * <p>
	 * A node is allowed when {@code possibleStates[start][end]} contains the
	 * number that {@code tagNumberer} assigns to its label. The tree is checked
	 * in pre-order and the check stops at the first disallowed node. A message
	 * naming that node is printed; its ancestors are not reported.
	 *
	 * @param gold the gold tree; its start/end indices address
	 *          {@code possibleStates}
	 * @param possibleStates per-span lists, checked for the label's number
	 * @param tagNumberer maps labels to the numbers stored in the lists
	 * @return {@code true} iff this node and all of its descendants are allowed
	 */
	public static <L> boolean isGoldReachable(SpanTree<L> gold,
			List[][] possibleStates, Numberer tagNumberer) {
		if (!possibleStates[gold.getStart()][gold.getEnd()]
				.contains(tagNumberer.number(gold.getLabel()))) {
			System.out.println("Cannot reach state " + gold.getLabel()
					+ " spanning from " + gold.getStart() + " to " + gold.getEnd() + ".");
			return false;
		}
		if (gold.isLeaf()) {
			return true;
		}
		// A failing descendant has already printed its own message.
		for (SpanTree<L> child : gold.getChildren()) {
			if (!isGoldReachable(child, possibleStates, tagNumberer)) {
				return false;
			}
		}
		return true;
	}

	public static SpanTree convertToSpanTree(Tree tree) {

		if (tree.isPreTerminal()) {

			return new SpanTree(tree.getLabel());

		}

		if (tree.getChildren().size() > 2)
			System.out.println("Binarize properly first!");

		SpanTree spanTree = new SpanTree(tree.getLabel());

		List> spanChildren = new ArrayList>();
		for (Tree child : tree.getChildren()) {

			SpanTree spanChild = convertToSpanTree(child);

			spanChildren.add(spanChild);

		}

		spanTree.setChildren(spanChildren);

		return spanTree;

	}

	/**
	 * Reads an array written by {@link #saveData(boolean[][][][][], String)}:
	 * a gzip-compressed, Java-serialized {@code boolean[][][][][]}.
	 * <p>
	 * NOTE(review): this uses Java deserialization, so it must only be given
	 * trusted files produced by this tool, never externally supplied data.
	 *
	 * @param fileName path of the file to read
	 * @return the stored array, or {@code null} if the file cannot be read or
	 *         references an unknown class (the problem is printed to standard
	 *         output)
	 * @throws ClassCastException if the file holds an object of another type
	 */
	public static boolean[][][][][] loadData(String fileName) {
		FileInputStream fis = null;
		ObjectInputStream in = null;
		try {
			fis = new FileInputStream(fileName);
			in = new ObjectInputStream(new GZIPInputStream(fis));
			return (boolean[][][][][]) in.readObject();
		} catch (IOException e) {
			System.out.println("IOException\n" + e);
			return null;
		} catch (ClassNotFoundException e) {
			System.out.println("Class not found!");
			return null;
		} finally {
			// Closing the outermost stream closes the ones it wraps. fis is also
			// closed directly in case a wrapper could not be constructed.
			if (in != null) {
				try {
					in.close();
				} catch (IOException ignored) {
					// read-only stream: nothing is lost if closing fails
				}
			}
			if (fis != null) {
				try {
					fis.close();
				} catch (IOException ignored) {
					// read-only stream: nothing is lost if closing fails
				}
			}
		}
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy