edu.stanford.nlp.parser.lexparser.ChineseSimWordAvgDepGrammar Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of stanford-corenlp Show documentation
Show all versions of stanford-corenlp Show documentation
Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.
package edu.stanford.nlp.parser.lexparser;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.text.NumberFormat;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
/**
* A Dependency grammar that smoothes by averaging over similar words.
*
* @author Galen Andrew
* @author Pi-Chuan Chang
*/
@SuppressWarnings("deprecation")
public class ChineseSimWordAvgDepGrammar extends MLEDependencyGrammar {
/** A logger for this class */
private static Redwood.RedwoodChannels log = Redwood.channels(ChineseSimWordAvgDepGrammar.class);
private static final long serialVersionUID = -1845503582705055342L;
private static final double simSmooth = 10.0;
private static final String argHeadFile = "simWords/ArgHead.5";
private static final String headArgFile = "simWords/HeadArg.5";
private Map, List>> simArgMap;
private Map, List>> simHeadMap;
private static final boolean debug = true;
private static final boolean verbose = false;
//private static final double MIN_PROBABILITY = Math.exp(-100.0);
public ChineseSimWordAvgDepGrammar(TreebankLangParserParams tlpParams, boolean directional, boolean distance, boolean coarseDistance, boolean basicCategoryTagsInDependencyGrammar, Options op, Index wordIndex, Index tagIndex) {
super(tlpParams, directional, distance, coarseDistance, basicCategoryTagsInDependencyGrammar, op, wordIndex, tagIndex);
simHeadMap = getMap(headArgFile);
simArgMap = getMap(argHeadFile);
}
public Map, List>> getMap(String filename) {
Map, List>> hashMap = Generics.newHashMap();
try {
BufferedReader wordMapBReader = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "UTF-8"));
String wordMapLine;
Pattern linePattern = Pattern.compile("sim\\((.+)/(.+):(.+)/(.+)\\)=(.+)");
while ((wordMapLine = wordMapBReader.readLine()) != null) {
Matcher m = linePattern.matcher(wordMapLine);
if (!m.matches()) {
log.info("Ill-formed line in similar word map file: " + wordMapLine);
continue;
}
Pair iTW = new Pair<>(wordIndex.addToIndex(m.group(1)), m.group(2));
double score = Double.parseDouble(m.group(5));
List> tripleList = hashMap.get(iTW);
if (tripleList == null) {
tripleList = new ArrayList<>();
hashMap.put(iTW, tripleList);
}
tripleList.add(new Triple<>(wordIndex.addToIndex(m.group(3)), m.group(4), score));
}
} catch (IOException e) {
throw new RuntimeException("Problem reading similar words file!");
}
return hashMap;
}
@Override
public double scoreTB(IntDependency dependency) {
//return op.testOptions.depWeight * Math.log(probSimilarWordAvg(dependency));
return op.testOptions.depWeight * Math.log(probTBwithSimWords(dependency));
}
public void setLex(Lexicon lex) {
this.lex = lex;
}
private ClassicCounter statsCounter = new ClassicCounter<>();
public void dumpSimWordAvgStats() {
log.info("SimWordAvg stats:");
log.info(statsCounter);
}
/*
** An alternative kind of smoothing.
** The first one is "probSimilarWordAvg" implemented by Galen
** This one is trying to modify "probTB" in MLEDependencyGrammar using the simWords list we have
** -pichuan
*/
private double probTBwithSimWords(IntDependency dependency) {
boolean leftHeaded = dependency.leftHeaded && directional;
IntTaggedWord unknownHead = new IntTaggedWord(-1, dependency.head.tag);
IntTaggedWord unknownArg = new IntTaggedWord(-1, dependency.arg.tag);
if (verbose) {
System.out.println("Generating " + dependency);
}
short distance = dependency.distance;
// int hW = dependency.head.word;
// int aW = dependency.arg.word;
IntTaggedWord aTW = dependency.arg;
// IntTaggedWord hTW = dependency.head;
double pb_stop_hTWds = getStopProb(dependency);
boolean isRoot = rootTW(dependency.head);
if (dependency.arg.word == -2) {
// did we generate stop?
if (isRoot) {
return 0.0;
}
return pb_stop_hTWds;
}
double pb_go_hTWds = 1.0 - pb_stop_hTWds;
if (isRoot) {
pb_go_hTWds = 1.0;
}
// generate the argument
int valenceBinDistance = valenceBin(distance);
// KEY:
// c_ count of
// p_ MLE prob of
// pb_ MAP prob of
// a arg
// h head
// T tag
// W word
// d direction
// ds distance
IntDependency temp = new IntDependency(dependency.head, dependency.arg, leftHeaded, valenceBinDistance);
double c_aTW_hTWd = argCounter.getCount(temp);
temp = new IntDependency(dependency.head, unknownArg, leftHeaded, valenceBinDistance);
double c_aT_hTWd = argCounter.getCount(temp);
temp = new IntDependency(dependency.head, wildTW, leftHeaded, valenceBinDistance);
double c_hTWd = argCounter.getCount(temp);
temp = new IntDependency(unknownHead, dependency.arg, leftHeaded, valenceBinDistance);
double c_aTW_hTd = argCounter.getCount(temp);
temp = new IntDependency(unknownHead, unknownArg, leftHeaded, valenceBinDistance);
double c_aT_hTd = argCounter.getCount(temp);
temp = new IntDependency(unknownHead, wildTW, leftHeaded, valenceBinDistance);
double c_hTd = argCounter.getCount(temp);
temp = new IntDependency(wildTW, dependency.arg, false, -1);
double c_aTW = argCounter.getCount(temp);
temp = new IntDependency(wildTW, unknownArg, false, -1);
double c_aT = argCounter.getCount(temp);
// do the magic
double p_aTW_hTd = (c_hTd > 0.0 ? c_aTW_hTd / c_hTd : 0.0);
double p_aT_hTd = (c_hTd > 0.0 ? c_aT_hTd / c_hTd : 0.0);
double p_aTW_aT = (c_aTW > 0.0 ? c_aTW / c_aT : 1.0);
double pb_aTW_hTWd; // = (c_aTW_hTWd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smooth_aTW_hTWd);
double pb_aT_hTWd = (c_aT_hTWd + smooth_aT_hTWd * p_aT_hTd) / (c_hTWd + smooth_aT_hTWd);
double score; // = (interp * pb_aTW_hTWd + (1.0 - interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds;
/* smooth by simWords -pichuan */
List> sim2arg = simArgMap.get(new Pair<>(dependency.arg.word, stringBasicCategory(dependency.arg.tag)));
List> sim2head = simHeadMap.get(new Pair<>(dependency.head.word, stringBasicCategory(dependency.head.tag)));
List simArg = new ArrayList<>();
List simHead= new ArrayList<>();
if (sim2arg != null) {
for (Triple t : sim2arg) {
simArg.add(t.first);
}
}
if (sim2head != null) {
for (Triple t : sim2head) {
simHead.add(t.first);
}
}
double cSim_aTW_hTd = 0;
double cSim_hTd = 0;
for (int h : simHead) {
IntTaggedWord hWord = new IntTaggedWord(h, dependency.head.tag);
temp = new IntDependency(hWord, dependency.arg, dependency.leftHeaded, dependency.distance);
cSim_aTW_hTd += argCounter.getCount(temp);
temp = new IntDependency(hWord, wildTW, dependency.leftHeaded, dependency.distance);
cSim_hTd += argCounter.getCount(temp);
}
double pSim_aTW_hTd = (cSim_hTd > 0.0 ? cSim_aTW_hTd / cSim_hTd : 0.0); // P(Wa,Ta|Th)
if (debug) {
//if (simHead.size() > 0 && cSim_hTd == 0.0) {
if (pSim_aTW_hTd > 0.0) {
//System.out.println("# simHead("+dependency.head.word+"-"+wordNumberer.object(dependency.head.word)+") =\t"+cSim_hTd);
System.out.println(dependency+"\t"+pSim_aTW_hTd);
//System.out.println(wordNumberer);
}
}
//pb_aTW_hTWd = (c_aTW_hTWd + smooth_aTW_hTWd * pSim_aTW_hTd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smooth_aTW_hTWd + smooth_aTW_hTWd);
//if (pSim_aTW_hTd > 0.0) {
double smoothSim_aTW_hTWd = 17.7;
double smooth_aTW_hTWd = 17.7*2;
//smooth_aTW_hTWd = smooth_aTW_hTWd*2;
pb_aTW_hTWd = (c_aTW_hTWd + smoothSim_aTW_hTWd * pSim_aTW_hTd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smoothSim_aTW_hTWd + smooth_aTW_hTWd);
System.out.println(dependency);
System.out.println(c_aTW_hTWd+" + "+ smoothSim_aTW_hTWd+" * "+pSim_aTW_hTd+" + "+smooth_aTW_hTWd+" * "+p_aTW_hTd);
System.out.println("-------------------------------- = "+pb_aTW_hTWd);
System.out.println(c_hTWd+" + "+ smoothSim_aTW_hTWd+" + "+smooth_aTW_hTWd);
System.out.println();
//}
//pb_aT_hTWd = (c_aT_hTWd + smooth_aT_hTWd * p_aT_hTd) / (c_hTWd + smooth_aT_hTWd);
score = (interp * pb_aTW_hTWd + (1.0 - interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds;
if (verbose) {
NumberFormat nf = NumberFormat.getNumberInstance();
nf.setMaximumFractionDigits(2);
System.out.println(" c_aTW_hTWd: " + c_aTW_hTWd + "; c_aT_hTWd: " + c_aT_hTWd + "; c_hTWd: " + c_hTWd);
System.out.println(" c_aTW_hTd: " + c_aTW_hTd + "; c_aT_hTd: " + c_aT_hTd + "; c_hTd: " + c_hTd);
System.out.println(" Generated with pb_go_hTWds: " + nf.format(pb_go_hTWds) + " pb_aTW_hTWd: " + nf.format(pb_aTW_hTWd) + " p_aTW_aT: " + nf.format(p_aTW_aT) + " pb_aT_hTWd: " + nf.format(pb_aT_hTWd));
System.out.println(" NoDist score: " + score);
}
if (op.testOptions.prunePunc && pruneTW(aTW)) {
return 1.0;
}
if (Double.isNaN(score)) {
score = 0.0;
}
//if (op.testOptions.rightBonus && ! dependency.leftHeaded)
// score -= 0.2;
if (score < MIN_PROBABILITY) {
score = 0.0;
}
return score;
}
private double probSimilarWordAvg(IntDependency dep) {
double regProb = probTB(dep);
statsCounter.incrementCount("total");
List> sim2arg = simArgMap.get(new Pair<>(dep.arg.word, stringBasicCategory(dep.arg.tag)));
List> sim2head = simHeadMap.get(new Pair<>(dep.head.word, stringBasicCategory(dep.head.tag)));
if (sim2head == null && sim2arg == null) {
return regProb;
}
double sumScores = 0, sumWeights = 0;
if (sim2head == null) {
statsCounter.incrementCount("aSim");
for (Triple simArg : sim2arg) {
//double weight = 1 - simArg.third;
double weight = Math.exp(-50*simArg.third);
for (int tag = 0, numT = tagIndex.size(); tag < numT; tag++) {
if (!stringBasicCategory(tag).equals(simArg.second)) {
continue;
}
IntTaggedWord tempArg = new IntTaggedWord(simArg.first, tag);
IntDependency tempDep = new IntDependency(dep.head, tempArg, dep.leftHeaded, dep.distance);
double probArg = Math.exp(lex.score(tempArg, 0, wordIndex.get(tempArg.word), null));
if (probArg == 0.0) {
continue;
}
sumScores += probTB(tempDep) * weight / probArg;
sumWeights += weight;
}
}
} else if (sim2arg == null) {
statsCounter.incrementCount("hSim");
for (Triple simHead : sim2head) {
//double weight = 1 - simHead.third;
double weight = Math.exp(-50*simHead.third);
for (int tag = 0, numT = tagIndex.size(); tag < numT; tag++) {
if (!stringBasicCategory(tag).equals(simHead.second)) {
continue;
}
IntTaggedWord tempHead = new IntTaggedWord(simHead.first, tag);
IntDependency tempDep = new IntDependency(tempHead, dep.arg, dep.leftHeaded, dep.distance);
sumScores += probTB(tempDep) * weight;
sumWeights += weight;
}
}
} else {
statsCounter.incrementCount("hSim");
statsCounter.incrementCount("aSim");
statsCounter.incrementCount("aSim&hSim");
for (Triple simArg : sim2arg) {
for (int aTag = 0, numT = tagIndex.size(); aTag < numT; aTag++) {
if (!stringBasicCategory(aTag).equals(simArg.second)) {
continue;
}
IntTaggedWord tempArg = new IntTaggedWord(simArg.first, aTag);
double probArg = Math.exp(lex.score(tempArg, 0, wordIndex.get(tempArg.word), null));
if (probArg == 0.0) {
continue;
}
for (Triple simHead : sim2head) {
for (int hTag = 0; hTag < numT; hTag++) {
if (!stringBasicCategory(hTag).equals(simHead.second)) {
continue;
}
IntTaggedWord tempHead = new IntTaggedWord(simHead.first, aTag);
IntDependency tempDep = new IntDependency(tempHead, tempArg, dep.leftHeaded, dep.distance);
//double weight = (1-simHead.third) * (1-simArg.third);
double weight = Math.exp(-50*simHead.third) * Math.exp(-50*simArg.third);
sumScores += probTB(tempDep) * weight / probArg;
sumWeights += weight;
}
}
}
}
}
IntDependency temp = new IntDependency(dep.head, wildTW, dep.leftHeaded, dep.distance);
double countHead = argCounter.getCount(temp);
double simProb;
if (sim2arg == null) {
simProb = sumScores / sumWeights;
} else {
double probArg = Math.exp(lex.score(dep.arg, 0, wordIndex.get(dep.arg.word), null));
simProb = probArg * sumScores / sumWeights;
}
if (simProb == 0) {
statsCounter.incrementCount("simProbZero");
}
if (regProb == 0) {
// log.info("zero reg prob");
statsCounter.incrementCount("regProbZero");
}
double smoothProb = (countHead * regProb + simSmooth * simProb) / (countHead + simSmooth);
if (smoothProb == 0) {
// log.info("zero smooth prob");
statsCounter.incrementCount("smoothProbZero");
}
return smoothProb;
}
private String stringBasicCategory(int tag) {
return tlp.basicCategory(tagIndex.get(tag));
}
}