All downloads are free. The search and download functionality uses the official Maven repository.

edu.stanford.nlp.neural.rnn.TopNGramRecord Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.neural.rnn;

import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.SentenceUtils;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;

/**
 * This class stores the best K ngrams for each class for a model.  It
 * does so by keeping priority queues of the best K trees seen,
 * eliminating duplicates.
 * 
* The interface is not too advanced at the moment. To use, first * create it with the number of classes and the counts of the ngrams, * and then add trees one at a time with countTree. The results can * be extracted with toString(). There is no way to directly access * the internal results. * * @author John Bauer */ public class TopNGramRecord { /** how many ngrams to store for each class */ final private int ngramCount; /** How many classes we are storing */ final private int numClasses; /** Longest ngram to keep */ final private int maximumLength; Map>> classToNGrams = Generics.newHashMap(); public TopNGramRecord(int numClasses, int ngramCount, int maximumLength) { this.numClasses = numClasses; this.ngramCount = ngramCount; this.maximumLength = maximumLength; for (int i = 0; i < numClasses; ++i) { Map> innerMap = Generics.newHashMap(); classToNGrams.put(i, innerMap); } } /** * Adds the tree and all its subtrees to the appropriate * PriorityQueues for each predicted class. */ public void countTree(Tree tree) { Tree simplified = simplifyTree(tree); for (int i = 0; i < numClasses; ++i) { countTreeHelper(simplified, i, classToNGrams.get(i)); } } /** * Remove everything but the skeleton, the predictions, and the labels */ private Tree simplifyTree(Tree tree) { CoreLabel newLabel = new CoreLabel(); newLabel.set(RNNCoreAnnotations.Predictions.class, RNNCoreAnnotations.getPredictions(tree)); newLabel.setValue(tree.label().value()); if (tree.isLeaf()) { return tree.treeFactory().newLeaf(newLabel); } List children = Generics.newArrayList(tree.children().length); for (int i = 0; i < tree.children().length; ++i) { children.add(simplifyTree(tree.children()[i])); } return tree.treeFactory().newTreeNode(newLabel, children); } private int countTreeHelper(Tree tree, int prediction, Map> ngrams) { if (tree.isLeaf()) { return 1; } int treeSize = 0; for (Tree child : tree.children()) { treeSize += countTreeHelper(child, prediction, ngrams); } if (maximumLength > 0 && treeSize > 
maximumLength) { return treeSize; } PriorityQueue queue = getPriorityQueue(treeSize, prediction, ngrams); // TODO: should we allow classes which aren't the best possible // class for this tree to be included in the results? if (!queue.contains(tree)) { queue.add(tree); } if (queue.size() > ngramCount) { queue.poll(); } return treeSize; } private PriorityQueue getPriorityQueue(int size, int prediction, Map> ngrams) { PriorityQueue queue = ngrams.get(size); if (queue != null) { return queue; } queue = new PriorityQueue<>(ngramCount + 1, scoreComparator(prediction)); ngrams.put(size, queue); return queue; } private Comparator scoreComparator(final int prediction) { return (tree1, tree2) -> { double score1 = RNNCoreAnnotations.getPredictions(tree1).get(prediction); double score2 = RNNCoreAnnotations.getPredictions(tree2).get(prediction); if (score1 < score2) { return -1; } else if (score1 > score2) { return 1; } else { return 0; } }; } public String toString() { StringBuilder result = new StringBuilder(); for (int prediction = 0; prediction < numClasses; ++prediction) { result.append("Best scores for class " + prediction + "\n"); Map> ngrams = classToNGrams.get(prediction); for (Map.Entry> entry : ngrams.entrySet()) { List trees = Generics.newArrayList(entry.getValue()); Collections.sort(trees, scoreComparator(prediction)); result.append(" Len " + entry.getKey() + "\n"); for (int i = trees.size() - 1; i >= 0; i--) { Tree tree = trees.get(i); result.append(" " + SentenceUtils.listToString(tree.yield()) + " [" + RNNCoreAnnotations.getPredictions(tree).get(prediction) + "]\n"); } } } return result.toString(); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy