
com.googlecode.clearnlp.dependency.srl.SRLabeler Maven / Gradle / Ivy
/**
* Copyright (c) 2009-2012, Regents of the University of Colorado
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package com.googlecode.clearnlp.dependency.srl;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import com.carrotsearch.hppc.IntOpenHashSet;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.classification.vector.StringFeatureVector;
import com.googlecode.clearnlp.dependency.DEPArc;
import com.googlecode.clearnlp.dependency.DEPLib;
import com.googlecode.clearnlp.dependency.DEPNode;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.feature.xml.FtrToken;
import com.googlecode.clearnlp.feature.xml.SRLFtrXml;
import com.googlecode.clearnlp.util.UTOutput;
import com.googlecode.clearnlp.util.map.Prob1DMap;
import com.googlecode.clearnlp.util.pair.IntIntPair;
import com.googlecode.clearnlp.util.pair.StringIntPair;
/**
* @since 1.0.0
* @author Jinho D. Choi ({@code [email protected]})
*/
public class SRLabeler extends AbstractSRLabeler
{
// Model indices: getDirIndex() picks MODEL_LEFT when the argument ID is smaller
// than the predicate ID and MODEL_RIGHT otherwise.
static public final int MODEL_SIZE = 2;
static public final int MODEL_LEFT = 0;
static public final int MODEL_RIGHT = 1;
// Direction codes passed to getPath().
static protected final int PATH_ALL = 0;
static protected final int PATH_UP = 1;
static protected final int PATH_DOWN = 2;
// Direction codes passed to getSubcat().
static protected final int SUBCAT_ALL = 0;
static protected final int SUBCAT_LEFT = 1;
static protected final int SUBCAT_RIGHT = 2;
/** Label used when a candidate is not an argument of the current predicate. */
static protected final String LB_NO_ARG = "N";
/** Feature template handed to getFeatureVector(). */
protected SRLFtrXml f_xml;
/** Training spaces, indexed by MODEL_LEFT / MODEL_RIGHT. */
protected StringTrainSpace[] s_spaces;
/** Statistical models, indexed by MODEL_LEFT / MODEL_RIGHT. */
protected StringModel[] s_models;
/** Dependency tree currently being processed. */
protected DEPTree d_tree;
/** ID of the current predicate. */
protected int i_pred;
/** ID of the current argument candidate. */
protected int i_arg;
/** Number of predicates processed in the current tree. */
protected int n_preds;
/** Transition counters: i1 is incremented per predicate, i2 per argument candidate. */
protected IntIntPair n_trans;
/** Stream that receives the transition trace in demo mode. */
protected PrintStream f_trans;
/** Gold semantic heads read from the tree in init(), indexed by node ID. */
protected StringIntPair[][] g_heads;
/** Per-node arrays allocated in initArcs(); names suggest leftmost/rightmost dependents - TODO confirm. */
protected DEPNode[] lm_deps, rm_deps;
/** Per-node arrays allocated in initArcs(); names suggest left/right nearest siblings - TODO confirm. */
protected DEPNode[] ln_sibs, rn_sibs;
/** Node (the predicate or one of its ancestors) currently used as the top of the search. */
protected DEPNode d_lca;
/** IDs of nodes that must not be visited (again) for the current predicate. */
protected IntOpenHashSet s_skip;
/** Numbered-argument labels assigned so far for the current predicate. */
protected List<String> l_argns;
/** Counters of downward / upward dependency paths gathered in lexica-collection mode. */
protected Prob1DMap m_down, m_up;
/** Dependency-path sets supplied at construction; s_down gates the descent in labelDown(). */
protected Set<String> s_down, s_up;
/** Creates a labeler in lexica-collection mode with empty down/up path counters. */
public SRLabeler()
{
    super(FLAG_LEXICA);

    this.m_up   = new Prob1DMap();
    this.m_down = new Prob1DMap();
}
/**
 * Creates a labeler in training mode.
 * @param xml the feature template.
 * @param spaces the training spaces that receive the generated instances.
 * @param sDown the set of downward dependency paths.
 * @param sUp the set of upward dependency paths.
 */
public SRLabeler(SRLFtrXml xml, StringTrainSpace[] spaces, Set sDown, Set sUp)
{
    super(FLAG_TRAIN);

    this.s_up     = sUp;
    this.s_down   = sDown;
    this.s_spaces = spaces;
    this.f_xml    = xml;
}
/**
 * Creates a labeler in prediction mode.
 * @param xml the feature template.
 * @param models the models used to predict argument labels.
 * @param sDown the set of downward dependency paths.
 * @param sUp the set of upward dependency paths.
 */
public SRLabeler(SRLFtrXml xml, StringModel[] models, Set sDown, Set sUp)
{
    super(FLAG_PREDICT);

    this.s_up     = sUp;
    this.s_down   = sDown;
    this.s_models = models;
    this.f_xml    = xml;
}
/**
 * Creates a labeler in bootstrapping mode, which both predicts with the given
 * models and adds gold-labeled instances to the given training spaces.
 * @param xml the feature template.
 * @param models the models used to predict argument labels.
 * @param spaces the training spaces that receive the generated instances.
 * @param sDown the set of downward dependency paths.
 * @param sUp the set of upward dependency paths.
 */
public SRLabeler(SRLFtrXml xml, StringModel[] models, StringTrainSpace[] spaces, Set sDown, Set sUp)
{
    super(FLAG_BOOST);

    this.s_up     = sUp;
    this.s_down   = sDown;
    this.s_spaces = spaces;
    this.s_models = models;
    this.f_xml    = xml;
}
/**
 * Creates a labeler in demonstration mode.
 * @param fout the stream that receives the transition trace.
 */
public SRLabeler(PrintStream fout)
{
    super(FLAG_DEMO);

    this.f_trans = fout;
}
/** Does nothing; use {@link #saveModel(PrintStream, int)} to write a specific model. */
public void saveModel(PrintStream fout)
{
    // intentionally empty
}
/**
 * Writes one of the semantic role labeling models to the given stream.
 * @param fout the output stream.
 * @param idx the index of the model to save (e.g., {@link #MODEL_LEFT} or {@link #MODEL_RIGHT}).
 */
public void saveModel(PrintStream fout, int idx)
{
    StringModel model = s_models[idx];
    model.save(fout);
}
/** Prints the set of downward paths to the given stream via {@code UTOutput.printSet}. */
public void saveDownSet(PrintStream fout)
{
    UTOutput.printSet(fout, this.s_down);
}
/** Prints the set of upward paths to the given stream via {@code UTOutput.printSet}. */
public void saveUpSet(PrintStream fout)
{
    UTOutput.printSet(fout, this.s_up);
}
/**
 * Returns the model array held by this labeler (the array itself, not a copy).
 * @return the semantic role labeling models.
 */
public StringModel[] getModels()
{
    return this.s_models;
}
/**
 * Resets the per-tree state of this labeler for the given dependency tree.
 * Unless the labeler is in prediction mode, the semantic heads currently on
 * the tree are kept as gold heads; the tree's semantic heads are then cleared.
 * @param tree the dependency tree to be labeled.
 */
public void init(DEPTree tree)
{
    d_tree  = tree;

    // fresh counters and scratch collections for this tree
    n_preds = 0;
    n_trans = new IntIntPair(0, 0);
    l_argns = new ArrayList<String>();
    s_skip  = new IntOpenHashSet();
    i_pred  = getNextPredId(0);

    // the gold heads must be captured before they are cleared below
    if (i_flag != FLAG_PREDICT)
    {
        g_heads = tree.getSHeads();
    }

    initArcs();
    tree.clearSHeads();
}
/**
 * Looks up the predicate following the given node ID.
 * @param prevId the ID to search from.
 * @return the ID of the next predicate, or the tree size if there is none.
 */
private int getNextPredId(int prevId)
{
    DEPNode next = d_tree.getNextPredicate(prevId);

    if (next == null)
    {
        return d_tree.size();
    }

    return next.id;
}
/** Initializes dependency arcs of all nodes. */
protected void initArcs()
{
int i, j, len, size = d_tree.size();
DEPNode curr, prev, next;
List deps;
DEPArc lmd, rmd;
lm_deps = new DEPNode[size];
rm_deps = new DEPNode[size];
ln_sibs = new DEPNode[size];
rn_sibs = new DEPNode[size];
d_tree.setDependents();
for (i=1; i i) rm_deps[i] = rmd.getNode();
for (j=1; j next.id)
rn_sibs[curr.id] = next;
}
}
}
/**
 * Gathers downward and upward dependency paths for every predicate in the tree.
 * Downward collection starts from the predicate's grand-dependents; upward
 * collection starts from the head of the predicate's head.
 */
private void collect(DEPTree tree)
{
    DEPNode pred = tree.getNextPredicate(0);
    tree.setDependents();

    for (; pred != null; pred = tree.getNextPredicate(pred.id))
    {
        for (DEPArc grand : pred.getGrandDependents())
        {
            collectDown(pred, grand.getNode());
        }

        DEPNode parent = pred.getHead();

        if (parent != null)
        {
            collectUp(pred, parent.getHead());
        }
    }
}
/**
 * Recursively visits {@code arg} and all of its descendants; whenever a visited
 * node is an argument of {@code pred}, the paths from the predicate down to
 * that node's head are added to the downward-path counter.
 */
private void collectDown(DEPNode pred, DEPNode arg)
{
    boolean isArgument = arg.isArgumentOf(pred);

    if (isArgument)
    {
        DEPNode parent = arg.getHead();

        for (String path : getDUPathList(pred, parent))
        {
            m_down.add(path);
        }
    }

    for (DEPArc child : arg.getDependents())
    {
        collectDown(pred, child.getNode());
    }
}
/**
 * Walks from {@code head} up through its ancestors; for every ancestor that has
 * at least one dependent that is an argument of {@code pred}, the paths between
 * that ancestor and the predicate are added to the upward-path counter (once
 * per ancestor).
 */
private void collectUp(DEPNode pred, DEPNode head)
{
    for (DEPNode ancestor = head; ancestor != null; ancestor = ancestor.getHead())
    {
        for (DEPArc arc : ancestor.getDependents())
        {
            if (!arc.getNode().isArgumentOf(pred))
            {
                continue;
            }

            for (String path : getDUPathList(ancestor, pred))
            {
                m_up.add(path);
            }

            // one argument under this ancestor is enough
            break;
        }
    }
}
/** Builds the dependency-label path between {@code bottom} and {@code top}, joined with the downward delimiter. */
private String getDUPath(DEPNode top, DEPNode bottom)
{
    String path = getPathAux(top, bottom, SRLFtrXml.F_DEPREL, SRLLib.DELIM_PATH_DOWN, true);
    return path;
}
/**
 * Collects the dependency-label paths from {@code top} to {@code bottom} and to
 * each node between them, moving one head at a time from {@code bottom} upward.
 * The list is empty when {@code bottom == top}.
 * Callers must make sure {@code top} is reachable from {@code bottom} by
 * following heads.
 * @return the list of path strings, in bottom-to-top order.
 */
private List<String> getDUPathList(DEPNode top, DEPNode bottom)
{
    List<String> paths = new ArrayList<String>();

    for (DEPNode node = bottom; node != top; node = node.getHead())
    {
        paths.add(getDUPath(top, node));
    }

    return paths;
}
/**
 * Converts the downward-path counter to a set using the given cutoff.
 * The counter is only created by the no-argument (collecting) constructor.
 */
public Set getDownSet(int cutoff)
{
    return this.m_down.toSet(cutoff);
}
/**
 * Converts the upward-path counter to a set using the given cutoff.
 * The counter is only created by the no-argument (collecting) constructor.
 */
public Set getUpSet(int cutoff)
{
    return this.m_up.toSet(cutoff);
}
/**
 * Returns the transition counters for the current dependency tree.
 * @return a pair whose {@code i1} counts predicates visited and whose {@code i2}
 *         counts argument candidates labeled.
 */
public IntIntPair getNumTransitions()
{
    return this.n_trans;
}
/** @return the number of predicates processed in the current dependency tree. */
public int getNumPredicates()
{
    return this.n_preds;
}
/**
 * Processes the dependency tree according to the labeler's mode: in
 * lexica-collection mode only path statistics are gathered; otherwise the tree
 * is initialized and labeled, and in demo mode a blank line is printed after
 * the transition trace.
 */
public void label(DEPTree tree)
{
    if (i_flag != FLAG_LEXICA)
    {
        init(tree);
        labelAux();

        if (i_flag == FLAG_DEMO)
        {
            f_trans.println();
        }
    }
    else
    {
        collect(tree);
    }
}
/**
 * Iterates over the predicates of the current tree. For each predicate the
 * per-predicate state is reset, and candidates are labeled starting at the
 * predicate itself and then at each of its ancestors up to the top of the tree.
 */
private void labelAux()
{
    final int size = d_tree.size();

    for (; i_pred < size; i_pred = getNextPredId(i_pred))
    {
        DEPNode pred = d_tree.get(i_pred);

        n_trans.i1++;
        l_argns.clear();

        // the predicate itself and the root are never argument candidates
        s_skip.clear();
        s_skip.add(i_pred);
        s_skip.add(DEPLib.ROOT_ID);

        d_lca = pred;

        do
        {
            labelAux(pred, d_lca);
            d_lca = d_lca.getHead();
        }
        while (d_lca != null);

        n_preds++;
    }
}
/**
 * Labels {@code head} itself (unless it is in the skip set) and then the
 * dependents of {@code head}. Called by {@link SRLabeler#labelAux()}.
 */
private void labelAux(DEPNode pred, DEPNode head)
{
    boolean skipped = s_skip.contains(head.id);

    if (!skipped)
    {
        i_arg = head.id;
        String label = getLabel(getDirIndex());
        addArgument(label);
    }

    labelDown(pred, head.getDependents());
}
/**
 * Labels every node reached through {@code arcs} that is not in the skip set.
 * The search descends further into a node's own dependents only when the
 * current top node is the predicate itself and the dependency path from the
 * predicate to that node is contained in the downward-path set.
 * When no downward-path set was supplied (the demo constructor leaves it
 * unset), the search never descends instead of failing with a
 * NullPointerException.
 * Called by {@link SRLabeler#labelAux(DEPNode, DEPNode)}.
 * @param pred the current predicate.
 * @param arcs the dependency arcs whose nodes are candidate arguments.
 */
private void labelDown(DEPNode pred, List<DEPArc> arcs)
{
    for (DEPArc arc : arcs)
    {
        DEPNode arg = arc.getNode();

        if (s_skip.contains(arg.id))
        {
            continue;
        }

        i_arg = arg.id;
        addArgument(getLabel(getDirIndex()));

        if (i_pred == d_lca.id && s_down != null && s_down.contains(getDUPath(pred, arg)))
        {
            labelDown(pred, arg.getDependents());
        }
    }
}
/** @return {@link #MODEL_LEFT} if the current argument precedes the current predicate, otherwise {@link #MODEL_RIGHT}. */
private int getDirIndex()
{
    if (i_arg < i_pred)
    {
        return MODEL_LEFT;
    }

    return MODEL_RIGHT;
}
/**
 * Produces the label for the current predicate/argument pair according to the
 * labeler's mode:
 * training returns the gold label and records an instance,
 * prediction returns the model's label,
 * bootstrapping records a gold instance but returns the model's label,
 * and any other mode (demo) returns the gold label.
 * @param idx the index of the directional model / training space to use.
 */
private String getLabel(int idx)
{
    StringFeatureVector vector = null;

    // no feature template is available in demo mode
    if (i_flag != FLAG_DEMO)
    {
        vector = getFeatureVector(f_xml);
    }

    if (i_flag == FLAG_TRAIN)
    {
        String gold = getGoldArgLabel();
        s_spaces[idx].addInstance(gold, vector);
        return gold;
    }

    if (i_flag == FLAG_PREDICT)
    {
        return getAutoLabel(idx, vector);
    }

    if (i_flag == FLAG_BOOST)
    {
        s_spaces[idx].addInstance(getGoldArgLabel(), vector);
        return getAutoLabel(idx, vector);
    }

    return getGoldArgLabel();
}
/**
 * Looks for the current predicate among the gold semantic heads of the current
 * argument. Called by {@link SRLabeler#getLabel(int)}.
 * @return the gold label, or {@link #LB_NO_ARG} if the predicate is not a gold head.
 */
private String getGoldArgLabel()
{
    StringIntPair[] heads = g_heads[i_arg];

    for (int k = 0; k < heads.length; k++)
    {
        StringIntPair head = heads[k];

        if (head.i == i_pred)
        {
            return head.s;
        }
    }

    return LB_NO_ARG;
}
/**
 * Returns the best label predicted by the model at the given index.
 * Called by {@link SRLabeler#getLabel(int)}.
 */
private String getAutoLabel(int idx, StringFeatureVector vector)
{
    StringModel model = s_models[idx];
    return model.predictBest(vector).label;
}
/**
 * Records the decision for the current argument candidate: the candidate is
 * added to the skip set, the transition counter is incremented, and (in demo
 * mode) the transition is printed. Unless the label is {@link #LB_NO_ARG}, the
 * predicate is added as a semantic head of the argument, and numbered-argument
 * labels are remembered for later feature extraction.
 */
private void addArgument(String label)
{
    n_trans.i2++;
    s_skip.add(i_arg);

    if (i_flag == FLAG_DEMO)
    {
        printState(label);
    }

    if (label.equals(LB_NO_ARG))
    {
        return;
    }

    DEPNode pred = d_tree.get(i_pred);
    d_tree.get(i_arg).addSHead(pred, label);

    if (SRLLib.isNumberedArgument(label))
    {
        l_argns.add(label);
    }
}
/** Prints one transition in the form {@code <predId> -<label>-> <argId>} to the demo stream. */
private void printState(String label)
{
    String line = i_pred + " -" + label + "-> " + i_arg;
    f_trans.println(line);
}
/**
 * Extracts the single-valued feature described by the token.
 * Plain fields (form, lemma, POS, dependency label, distance) are tested first,
 * followed by the pattern-based fields (previous numbered argument, path,
 * sub-categorization, extra feature, boolean) in that order.
 * @return the feature value, or {@code null} if the token's node does not
 *         exist or the field does not apply.
 */
protected String getField(FtrToken token)
{
    DEPNode node = getNode(token);

    if (node == null)
    {
        return null;
    }

    if (token.isField(SRLFtrXml.F_FORM))     return node.form;
    if (token.isField(SRLFtrXml.F_LEMMA))    return node.lemma;
    if (token.isField(SRLFtrXml.F_POS))      return node.pos;
    if (token.isField(SRLFtrXml.F_DEPREL))   return node.getLabel();
    if (token.isField(SRLFtrXml.F_DISTANCE)) return getDistance(node);

    Matcher m = SRLFtrXml.P_ARGN.matcher(token.field);

    if (m.find())
    {
        // the captured number counts back from the most recent numbered argument
        int idx = l_argns.size() - Integer.parseInt(m.group(1)) - 1;
        return (idx >= 0) ? l_argns.get(idx) : null;
    }

    m = SRLFtrXml.P_PATH.matcher(token.field);

    if (m.find())
    {
        return getPath(m.group(1), Integer.parseInt(m.group(2)));
    }

    m = SRLFtrXml.P_SUBCAT.matcher(token.field);

    if (m.find())
    {
        return getSubcat(node, m.group(1), Integer.parseInt(m.group(2)));
    }

    m = SRLFtrXml.P_FEAT.matcher(token.field);

    if (m.find())
    {
        return node.getFeat(m.group(1));
    }

    m = SRLFtrXml.P_BOOLEAN.matcher(token.field);

    if (m.find())
    {
        DEPNode pred = d_tree.get(i_pred);
        int field = Integer.parseInt(m.group(1));

        // a boolean feature is emitted as its own field name when it holds
        if (field == 0) return node.isDependentOf(pred)  ? token.field : null;
        if (field == 1) return pred.isDependentOf(node)  ? token.field : null;
        if (field == 2) return pred.isDependentOf(d_lca) ? token.field : null;
        if (field == 3) return (pred == d_lca)           ? token.field : null;
        if (field == 4) return (node == d_lca)           ? token.field : null;
    }

    return null;
}
/**
 * Extracts the multi-valued feature described by the token: the set of
 * dependency labels of the node's dependents or of its grand-dependents.
 * @return the labels, or {@code null} if the node does not exist, the field is
 *         not one of the two set fields, or there are no such dependents.
 */
protected String[] getFields(FtrToken token)
{
    DEPNode node = getNode(token);

    if (node == null)
    {
        return null;
    }

    if (token.isField(SRLFtrXml.F_DEPREL_SET))
    {
        return getDeprelSet(node.getDependents());
    }

    if (token.isField(SRLFtrXml.F_GRAND_DEPREL_SET))
    {
        return getDeprelSet(node.getGrandDependents());
    }

    return null;
}
/**
 * Returns the distinct labels of the given dependency arcs.
 * @param deps the dependency arcs.
 * @return the distinct labels in no particular order, or {@code null} if
 *         {@code deps} is empty.
 */
private String[] getDeprelSet(List<DEPArc> deps)
{
    if (deps.isEmpty())
    {
        return null;
    }

    Set<String> labels = new HashSet<String>();

    for (DEPArc arc : deps)
    {
        labels.add(arc.getLabel());
    }

    return labels.toArray(new String[labels.size()]);
}
/**
 * Buckets the absolute ID distance between the current predicate and the node:
 * {@code "0"} for at most 5, {@code "1"} for at most 10, {@code "2"} for at
 * most 15, and {@code "3"} otherwise.
 */
private String getDistance(DEPNode node)
{
    final int gap = Math.abs(i_pred - node.id);

    if (gap <= 5)  return "0";
    if (gap <= 10) return "1";
    if (gap <= 15) return "2";

    return "3";
}
/**
 * Builds the path feature between the current predicate and argument.
 * {@link #PATH_UP} gives the predicate-to-top segment and {@link #PATH_DOWN}
 * the top-to-argument segment (each {@code null} when the respective node is
 * the top itself); any other direction gives the full path.
 * @param type the kind of values making up the path (POS, dependency label, or distance).
 * @param dir the direction code.
 */
private String getPath(String type, int dir)
{
    DEPNode pred = d_tree.get(i_pred);
    DEPNode arg  = d_tree.get(i_arg);

    switch (dir)
    {
    case PATH_UP:
        return (d_lca != pred) ? getPathAux(d_lca, pred, type, SRLLib.DELIM_PATH_UP, true) : null;

    case PATH_DOWN:
        return (d_lca != arg) ? getPathAux(d_lca, arg, type, SRLLib.DELIM_PATH_DOWN, true) : null;

    default:
        if (pred == d_lca)
        {
            return getPathAux(pred, arg, type, SRLLib.DELIM_PATH_DOWN, true);
        }

        if (pred.isDescendentOf(arg))
        {
            return getPathAux(arg, pred, type, SRLLib.DELIM_PATH_UP, true);
        }

        // up from the predicate to the top node, then down to the argument;
        // the top node is included only once
        String up   = getPathAux(d_lca, pred, type, SRLLib.DELIM_PATH_UP, true);
        String down = getPathAux(d_lca, arg, type, SRLLib.DELIM_PATH_DOWN, false);
        return up + down;
    }
}
/**
 * Builds a path string by walking from {@code bottom} up through its heads
 * until {@code top} is reached ({@code top} itself is not visited by the walk).
 * For POS paths each visited node contributes {@code delim + pos}, and
 * {@code top}'s POS is appended when {@code includeTop} is set. For
 * dependency-label paths each visited node contributes {@code delim + label}.
 * For distance paths the result is {@code delim} followed by the number of
 * visited nodes.
 * @return the path, or {@code null} if nothing was appended.
 */
private String getPathAux(DEPNode top, DEPNode bottom, String type, String delim, boolean includeTop)
{
    final boolean isPos      = type.equals(SRLFtrXml.F_POS);
    final boolean isDeprel   = type.equals(SRLFtrXml.F_DEPREL);
    final boolean isDistance = type.equals(SRLFtrXml.F_DISTANCE);

    StringBuilder path = new StringBuilder();
    DEPNode curr = bottom;
    int steps = 0;

    do
    {
        if (isPos)
        {
            path.append(delim).append(curr.pos);
        }
        else if (isDeprel)
        {
            path.append(delim).append(curr.getLabel());
        }
        else if (isDistance)
        {
            steps++;
        }

        curr = curr.getHead();
    }
    while (curr != top);

    if (isPos)
    {
        if (includeTop)
        {
            path.append(delim).append(top.pos);
        }
    }
    else if (isDistance)
    {
        path.append(delim).append(steps);
    }

    if (path.length() == 0)
    {
        return null;
    }

    return path.toString();
}
private String getSubcat(DEPNode node, String type, int dir)
{
List deps = node.getDependents();
StringBuilder build = new StringBuilder();
int i, size = deps.size();
DEPNode dep;
if (dir == SUBCAT_LEFT)
{
for (i=0; i node.id) break;
getSubcatAux(build, dep, type);
}
}
else if (dir == SUBCAT_RIGHT)
{
for (i=size-1; i>=0; i--)
{
dep = deps.get(i).getNode();
if (dep.id < node.id) break;
getSubcatAux(build, dep, type);
}
}
else
{
for (i=0; i m_coverage;
public IntIntPair getArgCoverage(StringIntPair[][] gHeads)
{
IntIntPair p = new IntIntPair(0, 0);
int argId, size = gHeads.length;
StringIntPair[] preds;
for (argId=1; argId
© 2015 - 2025 Weber Informatics LLC | Privacy Policy