All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.parser.nndep.Util Maven / Gradle / Ivy

Go to download

Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.

There is a newer version: 3.9.2
Show newest version
package edu.stanford.nlp.parser.nndep;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.process.CoreLabelTokenFactory;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.logging.Redwood;

import java.util.*;
import java.io.*;


/**
 *
 *  Some utility functions for the neural network dependency parser.
 *
 *  @author Danqi Chen
 *  @author Jon Gauthier
 */

public class Util  {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(Util.class);

  private Util() {} // static methods

  private static Random random;

  /**
   * Normalize word embeddings by setting mean = rMean, std = rStd
   */
  public static double[][] scaling(double[][] A, double rMean, double rStd) {
    int count = 0;
    double mean = 0.0;
    double std = 0.0;
    for (double[] aA : A)
      for (double v : aA) {
        count += 1;
        mean += v;
        std += v * v;
      }
    mean = mean / count;
    std = Math.sqrt(std / count - mean * mean);

    log.info("Scaling word embeddings:");
    log.info(String.format("(mean = %.2f, std = %.2f) -> (mean = %.2f, std = %.2f)", mean, std, rMean, rStd));

    double[][] rA = new double[A.length][A[0].length];
    for (int i = 0; i < rA.length; ++ i)
      for (int j = 0; j < rA[i].length; ++ j)
        rA[i][j] = (A[i][j] - mean) * rStd / std + rMean;
    return rA;
  }

  /**
   *  Normalize word embeddings by setting mean = 0, std = 1
   */
  public static double[][] scaling(double[][] A) {
    return scaling(A, 0.0, 1.0);
  }

  // return strings sorted by frequency, and filter out those with freq. less than cutOff.

  /**
   * Build a dictionary of words collected from a corpus.
   * 

* Filters out words with a frequency below the given {@code cutOff}. * * @return Words sorted by decreasing frequency, filtered to remove * any words with a frequency below {@code cutOff} */ public static List generateDict(List str, int cutOff) { Counter freq = new IntCounter<>(); for (String aStr : str) freq.incrementCount(aStr); List keys = Counters.toSortedList(freq, false); List dict = new ArrayList<>(); for (String word : keys) { if (freq.getCount(word) >= cutOff) dict.add(word); } return dict; } public static List generateDict(List str) { return generateDict(str, 1); } /** * @return Shared random generator used in this package */ static Random getRandom() { if (random != null) return random; else return getRandom(System.currentTimeMillis()); } /** * Set up shared random generator to use the given seed. * * @return Shared random generator object */ static Random getRandom(long seed) { random = new Random(seed); log.info(String.format("Random generator initialized with seed %d%n", seed)); return random; } public static List getRandomSubList(List input, int subsetSize) { int inputSize = input.size(); if (subsetSize > inputSize) subsetSize = inputSize; Random random = getRandom(); for (int i = 0; i < subsetSize; i++) { int indexToSwap = i + random.nextInt(inputSize - i); T temp = input.get(i); input.set(i, input.get(indexToSwap)); input.set(indexToSwap, temp); } return input.subList(0, subsetSize); } // TODO replace with GrammaticalStructure#readCoNLLGrammaticalStructureCollection public static void loadConllFile(String inFile, List sents, List trees, boolean unlabeled, boolean cPOS) { CoreLabelTokenFactory tf = new CoreLabelTokenFactory(false); try (BufferedReader reader = IOUtils.readerFromString(inFile)) { List sentenceTokens = new ArrayList<>(); DependencyTree tree = new DependencyTree(); for (String line : IOUtils.getLineIterable(reader, false)) { String[] splits = line.split("\t"); if (splits.length < 10) { if (sentenceTokens.size() > 0) { trees.add(tree); 
CoreMap sentence = new CoreLabel(); sentence.set(CoreAnnotations.TokensAnnotation.class, sentenceTokens); sents.add(sentence); tree = new DependencyTree(); sentenceTokens = new ArrayList<>(); } } else { String word = splits[1], pos = cPOS ? splits[3] : splits[4], depType = splits[7]; int head = -1; try { head = Integer.parseInt(splits[6]); } catch (NumberFormatException e) { continue; } CoreLabel token = tf.makeToken(word, 0, 0); token.setTag(pos); token.set(CoreAnnotations.CoNLLDepParentIndexAnnotation.class, head); token.set(CoreAnnotations.CoNLLDepTypeAnnotation.class, depType); sentenceTokens.add(token); if (!unlabeled) tree.add(head, depType); else tree.add(head, Config.UNKNOWN); } } } catch (IOException e) { throw new RuntimeIOException(e); } } public static void loadConllFile(String inFile, List sents, List trees) { loadConllFile(inFile, sents, trees, false, false); } public static void writeConllFile(String outFile, List sentences, List trees) { try { PrintWriter output = IOUtils.getPrintWriter(outFile); for (int i = 0; i < sentences.size(); i++) { CoreMap sentence = sentences.get(i); DependencyTree tree = trees.get(i); List tokens = sentence.get(CoreAnnotations.TokensAnnotation.class); for (int j = 1, size = tokens.size(); j <= size; ++j) { CoreLabel token = tokens.get(j - 1); output.printf("%d\t%s\t_\t%s\t%s\t_\t%d\t%s\t_\t_%n", j, token.word(), token.tag(), token.tag(), tree.getHead(j), tree.getLabel(j)); } output.println(); } output.close(); } catch (Exception e) { throw new RuntimeIOException(e); } } public static void printTreeStats(String str, List trees) { log.info(Config.SEPARATOR + ' ' + str); int nTrees = trees.size(); int nonTree = 0; int multiRoot = 0; int nonProjective = 0; for (DependencyTree tree : trees) { if (!tree.isTree()) ++nonTree; else { if (!tree.isProjective()) ++nonProjective; if (!tree.isSingleRoot()) ++multiRoot; } } log.info(String.format("#Trees: %d%n", nTrees)); log.info(String.format("%d tree(s) are illegal (%.2f%%).%n", 
nonTree, nonTree * 100.0 / nTrees)); log.info(String.format("%d tree(s) are legal but have multiple roots (%.2f%%).%n", multiRoot, multiRoot * 100.0 / nTrees)); log.info(String.format("%d tree(s) are legal but not projective (%.2f%%).%n", nonProjective, nonProjective * 100.0 / nTrees)); } public static void printTreeStats(List trees) { printTreeStats("", trees); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy