All downloads are free. The search and download features use the official Maven repository.

edu.stanford.nlp.parser.lexparser.LinearGrammarSmoother Maven / Gradle / Ivy

Go to download

Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.

There is a newer version: 3.9.2
Show newest version
package edu.stanford.nlp.parser.lexparser;

import java.util.Arrays;
import java.util.Set;

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import java.util.function.Function;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;

/**
 * Implements linear rule smoothing a la Petrov et al. (2006).
 * 
 * @author Spence Green
 *
 */
public class LinearGrammarSmoother implements Function, Pair> {

  private static final boolean DEBUG = false;
  
  private double ALPHA = 0.01;

//  private static final String SYNTH_NODE_MARK = "@";
//  
//  private static final Pattern pContext = Pattern.compile("(\\|.+)$");
  
  // Do not include @ in this list! @ marks synthetic nodes!
  // Stole these from PennTreebankLanguagePack
  private final String[] annotationIntroducingChars = {"-", "=", "|", "#", "^", "~", "_"};
  private final Set annoteChars = Generics.newHashSet(Arrays.asList(annotationIntroducingChars));

  private final TrainOptions trainOptions;

  private final Index stateIndex;
  private final Index tagIndex;

  public LinearGrammarSmoother(TrainOptions trainOptions, Index stateIndex, Index tagIndex) {
    this.trainOptions = trainOptions;
    this.stateIndex = stateIndex;
    this.tagIndex = tagIndex;
  }

  /**
   * Destructively modifies the input and returns it as a convenience.
   */
  public Pair apply(Pair bgug) {
    
    ALPHA = trainOptions.ruleSmoothingAlpha;
    Counter symWeights = new ClassicCounter<>();
    Counter symCounts = new ClassicCounter<>();

    //Tally unary rules
    for (UnaryRule rule : bgug.first()) {
      if ( ! tagIndex.contains(rule.parent)) {
        updateCounters(rule, symWeights,symCounts);  
      }
    }

    //Tally binary rules
    for (BinaryRule rule : bgug.second()) {
      updateCounters(rule, symWeights, symCounts);  
    }

    //Compute smoothed rule scores, unary
    for (UnaryRule rule : bgug.first()) {
      if ( ! tagIndex.contains(rule.parent)) {
        rule.score = smoothRuleWeight(rule,symWeights,symCounts);
      }
    }

    //Compute smoothed rule scores, binary
    for (BinaryRule rule : bgug.second()) {
      rule.score = smoothRuleWeight(rule,symWeights,symCounts);  
    }
    
    if(DEBUG) {
      System.err.printf("%s: %d basic symbols in the grammar%n",this.getClass().getName(),symWeights.keySet().size());
      for(String s : symWeights.keySet())
        System.err.print(s + ",");
      System.err.println();
    }
    
    return bgug;
  }


  private void updateCounters(Rule rule, Counter symWeights,
      Counter symCounts) {
    String label = stateIndex.get(rule.parent());
    String basicCat = basicCategory(label);
    symWeights.incrementCount(basicCat, Math.exp(rule.score()));
    symCounts.incrementCount(basicCat);
  }


  private float smoothRuleWeight(Rule rule, Counter symWeights, Counter symCounts) {  
    String label = stateIndex.get(rule.parent());
    String basicCat = basicCategory(label);

    double pSum = symWeights.getCount(basicCat);
    double n = symCounts.getCount(basicCat);

    double pRule = Math.exp(rule.score());

    double pSmooth = (1.0 - ALPHA)*pRule;
    pSmooth += ALPHA*(pSum / n);
    pSmooth = Math.log(pSmooth);
    
    if(DEBUG)
      System.err.printf("%s\t%.4f%n", rule.toString(),pSmooth);

    return (float) pSmooth;
  }


  private int postBasicCategoryIndex(String category) {
    boolean sawAtZero = false;
    String seenAtZero = "\u0000";
    int i;
    for (i = 0; i < category.length(); i++) {
      String ch = category.substring(i,i+1);
      if (annoteChars.contains(ch)) {
        if (i == 0) {
          sawAtZero = true;
          seenAtZero = ch;
        } else if (sawAtZero && ch == seenAtZero) {
          sawAtZero = false;
        } else {
          break;
        }
      }
    }
    return i;
  }

  public String basicCategory(String category) {
    if (category == null) {
      return null;

    } else {
      String basicCat = category.substring(0, postBasicCategoryIndex(category));
      
//wsg2011: Tried adding the context of synthetic nodes to the basic category, but this lowered F1.
//      if(String.valueOf(category.charAt(0)).equals(SYNTH_NODE_MARK)) {
//        Matcher m = pContext.matcher(category);
//        if(m.find()) {
//          String context = m.group(1);
//          basicCat = basicCat + context;
//        
//        } else {
//          throw new RuntimeException(String.format("%s: Synthetic label lacks context: %s",this.getClass().getName(),category));
//        }
//      } 
      
      return basicCat;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy