All downloads are FREE. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.parser.lexparser.ChineseSimWordAvgDepGrammar Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.parser.lexparser;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.NumberFormat;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;


/**
 * A Dependency grammar that smoothes by averaging over similar words.
 *
 * @author Galen Andrew
 * @author Pi-Chuan Chang
 */
@SuppressWarnings("deprecation")
public class ChineseSimWordAvgDepGrammar extends MLEDependencyGrammar  {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(ChineseSimWordAvgDepGrammar.class);

  private static final long serialVersionUID = -1845503582705055342L;

  private static final double simSmooth = 10.0;

  private static final String argHeadFile = "simWords/ArgHead.5";
  private static final String headArgFile = "simWords/HeadArg.5";
  private Map, List>> simArgMap;
  private Map, List>> simHeadMap;

  private static final boolean debug = true;

  private static final boolean verbose = false;
  //private static final double MIN_PROBABILITY = Math.exp(-100.0);

  public ChineseSimWordAvgDepGrammar(TreebankLangParserParams tlpParams, boolean directional, boolean distance, boolean coarseDistance, boolean basicCategoryTagsInDependencyGrammar, Options op, Index wordIndex, Index tagIndex) {
    super(tlpParams, directional, distance, coarseDistance, basicCategoryTagsInDependencyGrammar, op, wordIndex, tagIndex);

    simHeadMap = getMap(headArgFile);
    simArgMap = getMap(argHeadFile);
  }

  public Map, List>> getMap(String filename) {
    Map, List>> hashMap = Generics.newHashMap();
    try {
      BufferedReader wordMapBReader = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "UTF-8"));

      String wordMapLine;
      Pattern linePattern = Pattern.compile("sim\\((.+)/(.+):(.+)/(.+)\\)=(.+)");
      while ((wordMapLine = wordMapBReader.readLine()) != null) {
        Matcher m = linePattern.matcher(wordMapLine);
        if (!m.matches()) {
          log.info("Ill-formed line in similar word map file: " + wordMapLine);
          continue;
        }

        Pair iTW = new Pair<>(wordIndex.addToIndex(m.group(1)), m.group(2));
        double score = Double.parseDouble(m.group(5));

        List> tripleList = hashMap.get(iTW);
        if (tripleList == null) {
          tripleList = new ArrayList<>();
          hashMap.put(iTW, tripleList);
        }

        tripleList.add(new Triple<>(wordIndex.addToIndex(m.group(3)), m.group(4), score));
      }
    } catch (IOException e) {
      throw new RuntimeException("Problem reading similar words file!");
    }

    return hashMap;
  }

  @Override
  public double scoreTB(IntDependency dependency) {
    //return op.testOptions.depWeight * Math.log(probSimilarWordAvg(dependency));
    return op.testOptions.depWeight * Math.log(probTBwithSimWords(dependency));
  }

  public void setLex(Lexicon lex) {
    this.lex = lex;
  }

  private ClassicCounter statsCounter = new ClassicCounter<>();

  public void dumpSimWordAvgStats() {
    log.info("SimWordAvg stats:");
    log.info(statsCounter);
  }

  /*
  ** An alternative kind of smoothing.
  ** The first one is "probSimilarWordAvg" implemented by Galen
  ** This one is trying to modify "probTB" in MLEDependencyGrammar using the simWords list we have
  ** -pichuan
  */

  private double probTBwithSimWords(IntDependency dependency) {
    boolean leftHeaded = dependency.leftHeaded && directional;
    IntTaggedWord unknownHead = new IntTaggedWord(-1, dependency.head.tag);
    IntTaggedWord unknownArg = new IntTaggedWord(-1, dependency.arg.tag);
    if (verbose) {
      System.out.println("Generating " + dependency);
    }

    short distance = dependency.distance;
    // int hW = dependency.head.word;
    // int aW = dependency.arg.word;
    IntTaggedWord aTW = dependency.arg;
    // IntTaggedWord hTW = dependency.head;

    double pb_stop_hTWds = getStopProb(dependency);

    boolean isRoot = rootTW(dependency.head);
    if (dependency.arg.word == -2) {
      // did we generate stop?
      if (isRoot) {
        return 0.0;
      }
      return pb_stop_hTWds;
    }

    double pb_go_hTWds = 1.0 - pb_stop_hTWds;

    if (isRoot) {
      pb_go_hTWds = 1.0;
    }

    // generate the argument

    int valenceBinDistance = valenceBin(distance);

    // KEY:
    // c_     count of
    // p_     MLE prob of
    // pb_    MAP prob of
    // a      arg
    // h      head
    // T      tag
    // W      word
    // d      direction
    // ds     distance

    IntDependency temp = new IntDependency(dependency.head, dependency.arg, leftHeaded, valenceBinDistance);
    double c_aTW_hTWd = argCounter.getCount(temp);
    temp = new IntDependency(dependency.head, unknownArg, leftHeaded, valenceBinDistance);
    double c_aT_hTWd = argCounter.getCount(temp);
    temp = new IntDependency(dependency.head, wildTW, leftHeaded, valenceBinDistance);
    double c_hTWd = argCounter.getCount(temp);
    temp = new IntDependency(unknownHead, dependency.arg, leftHeaded, valenceBinDistance);
    double c_aTW_hTd = argCounter.getCount(temp);
    temp = new IntDependency(unknownHead, unknownArg, leftHeaded, valenceBinDistance);
    double c_aT_hTd = argCounter.getCount(temp);
    temp = new IntDependency(unknownHead, wildTW, leftHeaded, valenceBinDistance);
    double c_hTd = argCounter.getCount(temp);
    temp = new IntDependency(wildTW, dependency.arg, false, -1);
    double c_aTW = argCounter.getCount(temp);
    temp = new IntDependency(wildTW, unknownArg, false, -1);
    double c_aT = argCounter.getCount(temp);

    // do the magic
    double p_aTW_hTd = (c_hTd > 0.0 ? c_aTW_hTd / c_hTd : 0.0);
    double p_aT_hTd = (c_hTd > 0.0 ? c_aT_hTd / c_hTd : 0.0);
    double p_aTW_aT = (c_aTW > 0.0 ? c_aTW / c_aT : 1.0);

    double pb_aTW_hTWd; // = (c_aTW_hTWd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smooth_aTW_hTWd);
    double pb_aT_hTWd = (c_aT_hTWd + smooth_aT_hTWd * p_aT_hTd) / (c_hTWd + smooth_aT_hTWd);

    double score; // = (interp * pb_aTW_hTWd + (1.0 - interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds;


    /* smooth by simWords -pichuan */
    List> sim2arg = simArgMap.get(new Pair<>(dependency.arg.word, stringBasicCategory(dependency.arg.tag)));
    List> sim2head = simHeadMap.get(new Pair<>(dependency.head.word, stringBasicCategory(dependency.head.tag)));

    List simArg = new ArrayList<>();
    List simHead= new ArrayList<>();

    if (sim2arg != null) {
      for (Triple t : sim2arg) {
        simArg.add(t.first);
      }
    }

    if (sim2head != null) {
      for (Triple t : sim2head) {
        simHead.add(t.first);
      }
    }

    double cSim_aTW_hTd = 0;
    double cSim_hTd = 0;
    for (int h : simHead) {
      IntTaggedWord hWord = new IntTaggedWord(h, dependency.head.tag);
      temp = new IntDependency(hWord, dependency.arg, dependency.leftHeaded, dependency.distance);
      cSim_aTW_hTd += argCounter.getCount(temp);

      temp = new IntDependency(hWord, wildTW, dependency.leftHeaded, dependency.distance);
      cSim_hTd += argCounter.getCount(temp);
    }
    double pSim_aTW_hTd = (cSim_hTd > 0.0 ? cSim_aTW_hTd / cSim_hTd : 0.0);  // P(Wa,Ta|Th)

    if (debug) {
      //if (simHead.size() > 0 && cSim_hTd == 0.0) {
        if (pSim_aTW_hTd > 0.0) {
        //System.out.println("# simHead("+dependency.head.word+"-"+wordNumberer.object(dependency.head.word)+") =\t"+cSim_hTd);
          System.out.println(dependency+"\t"+pSim_aTW_hTd);
          //System.out.println(wordNumberer);
        }
    }


    //pb_aTW_hTWd = (c_aTW_hTWd + smooth_aTW_hTWd * pSim_aTW_hTd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smooth_aTW_hTWd + smooth_aTW_hTWd);

    //if (pSim_aTW_hTd > 0.0) {
    double smoothSim_aTW_hTWd = 17.7;
    double smooth_aTW_hTWd = 17.7*2;

    //smooth_aTW_hTWd = smooth_aTW_hTWd*2;
    pb_aTW_hTWd = (c_aTW_hTWd + smoothSim_aTW_hTWd * pSim_aTW_hTd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smoothSim_aTW_hTWd + smooth_aTW_hTWd);
    System.out.println(dependency);
    System.out.println(c_aTW_hTWd+" + "+ smoothSim_aTW_hTWd+" * "+pSim_aTW_hTd+" + "+smooth_aTW_hTWd+" * "+p_aTW_hTd);
    System.out.println("--------------------------------  = "+pb_aTW_hTWd);
    System.out.println(c_hTWd+" + "+ smoothSim_aTW_hTWd+" + "+smooth_aTW_hTWd);
    System.out.println();
    //}

    //pb_aT_hTWd = (c_aT_hTWd + smooth_aT_hTWd * p_aT_hTd) / (c_hTWd + smooth_aT_hTWd);

    score = (interp * pb_aTW_hTWd + (1.0 - interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds;

    if (verbose) {
      NumberFormat nf = NumberFormat.getNumberInstance();
      nf.setMaximumFractionDigits(2);
      System.out.println("  c_aTW_hTWd: " + c_aTW_hTWd + "; c_aT_hTWd: " + c_aT_hTWd + "; c_hTWd: " + c_hTWd);
      System.out.println("  c_aTW_hTd: " + c_aTW_hTd + "; c_aT_hTd: " + c_aT_hTd + "; c_hTd: " + c_hTd);
      System.out.println("  Generated with pb_go_hTWds: " + nf.format(pb_go_hTWds) + " pb_aTW_hTWd: " + nf.format(pb_aTW_hTWd) + " p_aTW_aT: " + nf.format(p_aTW_aT) + " pb_aT_hTWd: " + nf.format(pb_aT_hTWd));
      System.out.println("  NoDist score: " + score);
    }

    if (op.testOptions.prunePunc && pruneTW(aTW)) {
      return 1.0;
    }

    if (Double.isNaN(score)) {
      score = 0.0;
    }

    //if (op.testOptions.rightBonus && ! dependency.leftHeaded)
    //  score -= 0.2;

    if (score < MIN_PROBABILITY) {
      score = 0.0;
    }

    return score;
  }



  private double probSimilarWordAvg(IntDependency dep) {
    double regProb = probTB(dep);
    statsCounter.incrementCount("total");

    List> sim2arg = simArgMap.get(new Pair<>(dep.arg.word, stringBasicCategory(dep.arg.tag)));
    List> sim2head = simHeadMap.get(new Pair<>(dep.head.word, stringBasicCategory(dep.head.tag)));

    if (sim2head == null && sim2arg == null) {
      return regProb;
    }

    double sumScores = 0, sumWeights = 0;

    if (sim2head == null) {
      statsCounter.incrementCount("aSim");
      for (Triple simArg : sim2arg) {
        //double weight = 1 - simArg.third;
        double weight = Math.exp(-50*simArg.third);
        for (int tag = 0, numT = tagIndex.size(); tag < numT; tag++) {
          if (!stringBasicCategory(tag).equals(simArg.second)) {
            continue;
          }
          IntTaggedWord tempArg = new IntTaggedWord(simArg.first, tag);
          IntDependency tempDep = new IntDependency(dep.head, tempArg, dep.leftHeaded, dep.distance);
          double probArg = Math.exp(lex.score(tempArg, 0, wordIndex.get(tempArg.word), null));
          if (probArg == 0.0) {
            continue;
          }
          sumScores += probTB(tempDep) * weight / probArg;
          sumWeights += weight;
        }
      }
    } else if (sim2arg == null) {
      statsCounter.incrementCount("hSim");
      for (Triple simHead : sim2head) {
        //double weight = 1 - simHead.third;
        double weight = Math.exp(-50*simHead.third);
        for (int tag = 0, numT = tagIndex.size(); tag < numT; tag++) {
          if (!stringBasicCategory(tag).equals(simHead.second)) {
            continue;
          }
          IntTaggedWord tempHead = new IntTaggedWord(simHead.first, tag);
          IntDependency tempDep = new IntDependency(tempHead, dep.arg, dep.leftHeaded, dep.distance);
          sumScores += probTB(tempDep) * weight;
          sumWeights += weight;
        }
      }
    } else {
      statsCounter.incrementCount("hSim");
      statsCounter.incrementCount("aSim");
      statsCounter.incrementCount("aSim&hSim");
      for (Triple simArg : sim2arg) {
        for (int aTag = 0, numT = tagIndex.size(); aTag < numT; aTag++) {
          if (!stringBasicCategory(aTag).equals(simArg.second)) {
            continue;
          }
          IntTaggedWord tempArg = new IntTaggedWord(simArg.first, aTag);
          double probArg = Math.exp(lex.score(tempArg, 0, wordIndex.get(tempArg.word), null));
          if (probArg == 0.0) {
            continue;
          }
          for (Triple simHead : sim2head) {
            for (int hTag = 0; hTag < numT; hTag++) {
              if (!stringBasicCategory(hTag).equals(simHead.second)) {
                continue;
              }
              IntTaggedWord tempHead = new IntTaggedWord(simHead.first, aTag);
              IntDependency tempDep = new IntDependency(tempHead, tempArg, dep.leftHeaded, dep.distance);
              //double weight = (1-simHead.third) * (1-simArg.third);
              double weight = Math.exp(-50*simHead.third) * Math.exp(-50*simArg.third);
              sumScores += probTB(tempDep) * weight / probArg;
              sumWeights += weight;
            }
          }
        }
      }
    }

    IntDependency temp = new IntDependency(dep.head, wildTW, dep.leftHeaded, dep.distance);
    double countHead = argCounter.getCount(temp);

    double simProb;
    if (sim2arg == null) {
      simProb = sumScores / sumWeights;
    } else {
      double probArg = Math.exp(lex.score(dep.arg, 0, wordIndex.get(dep.arg.word), null));
      simProb = probArg * sumScores / sumWeights;
    }

    if (simProb == 0) {
      statsCounter.incrementCount("simProbZero");
    }
    if (regProb == 0) {
      //      log.info("zero reg prob");
      statsCounter.incrementCount("regProbZero");
    }
    double smoothProb = (countHead * regProb + simSmooth * simProb) / (countHead + simSmooth);
    if (smoothProb == 0) {
      //      log.info("zero smooth prob");
      statsCounter.incrementCount("smoothProbZero");
    }

    return smoothProb;
  }

  private String stringBasicCategory(int tag) {
    return tlp.basicCategory(tagIndex.get(tag));
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy