![JAR search and dependency download from the Maven repository](/logo.png)
eu.fbk.twm.classifier.NGramOneExamplePerSenseExtractor Maven / Gradle / Ivy
The newest version!
/*
* Copyright (2014) Fondazione Bruno Kessler (http://www.fbk.eu/)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package eu.fbk.twm.classifier;
import eu.fbk.utils.core.core.HashMultiSet;
import eu.fbk.utils.core.core.MultiSet;
import eu.fbk.utils.lsa.BOW;
import eu.fbk.utils.math.Node;
import org.apache.commons.cli.*;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import eu.fbk.twm.utils.analysis.HardTokenizer;
import eu.fbk.twm.utils.analysis.Tokenizer;
import eu.fbk.twm.utils.CharacterTable;
import java.io.*;
import java.text.DecimalFormat;
import java.util.*;
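/**
 * One-example-per-sense extractor based on word n-grams.
 * <p>
 * For every form, one {@code Example} is built per candidate page. Each example holds a
 * bag of n-grams taken from the left and right contexts, turned into a sparse vector through
 * an {@link NGramModel}. A query context can then be compared with the stored examples of
 * its form in order to rank the candidate pages.
 * <p>
 * User: giuliano
 * Date: 1/6/14
 * Time: 7:59 AM
 */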
public class NGramOneExamplePerSenseExtractor extends OneExamplePerSenseExtractor {
/**
 * Logger for this class; it references the Logger instance named
 * <code>NGramOneExamplePerSenseExtractor</code>.
 */
static Logger logger = Logger.getLogger(NGramOneExamplePerSenseExtractor.class.getName());
// Formats the tp/fp/fn counters printed by eval().
DecimalFormat vf = new DecimalFormat("###0");
// Formats precision, recall and F1 printed by eval().
DecimalFormat sf = new DecimalFormat(".000");
// Default maximum n-gram length, shown in the help of the --ngram-length option.
public static final int DEFAULT_N_GRAM_LENGTH = 3;
// Default value shown in the help of the --ngram-size option.
public static final int DEFAULT_N_GRAM_SIZE = 1000000;
// Formats the dot products logged by classify() and the elapsed time printed by eval().
private static DecimalFormat df = new DecimalFormat("#.0000");
// Supplies the index and the weight of each n-gram, and the maximum n-gram length.
private NGramModel nGramModel;
// Sum of the frequencies of all the Example instances created so far; used by Example.toString().
protected int totalFreq;
/**
 * Convenience constructor taking the output location as a file name.
 *
 * @param outputFileName name of the output file, wrapped into a {@link File}
 * @param numThreads     number of threads, handed over to the superclass
 * @param nGramModel     model used to map n-grams to vector components
 * @throws IOException if thrown by the {@code File}-based constructor
 */
public NGramOneExamplePerSenseExtractor(String outputFileName, int numThreads, NGramModel nGramModel)
        throws IOException {
    this(new File(outputFileName), numThreads, nGramModel);
}

/**
 * Creates the extractor and resets the global frequency counter.
 *
 * @param outputFile file handed over to the superclass constructor
 * @param numThreads number of threads, handed over to the superclass constructor
 * @param nGramModel model used to map n-grams to vector components
 * @throws IOException if thrown by the superclass constructor
 */
public NGramOneExamplePerSenseExtractor(File outputFile, int numThreads, NGramModel nGramModel) throws IOException {
    super(outputFile, numThreads);
    this.totalFreq = 0;
    this.nGramModel = nGramModel;
}
/**
 * Reads tab-separated queries from standard input and classifies them one by one.
 * <p>
 * Each line is split on tabs and passed to {@code classify}. When the line has five
 * columns, the first column is taken as the expected page and the outcome of the
 * comparison with the top-ranked page is logged. The loop ends when standard input
 * is exhausted (or when the process is interrupted with CTRL-C).
 */
public void interactive() {
    // A single reader is kept for the whole session: wrapping System.in again at every
    // iteration can lose the input already buffered by the previous reader.
    BufferedReader input = new BufferedReader(new InputStreamReader(System.in));
    while (true) {
        System.out.println("\nPlease write a key and type to continue (CTRL C to exit):");
        try {
            String query = input.readLine();
            if (query == null) {
                // End of stream: readLine() returns null, there is nothing more to classify.
                return;
            }
            String[] s = query.split("\t");
            Sense[] senses = classify(s);
            if (s.length == 5) {
                String answerPage = "null";
                if (senses != null && senses.length > 0) {
                    answerPage = senses[0].getPage();
                }
                if (answerPage.equals(s[0])) {
                    logger.info(answerPage + " = " + s[0]);
                } else {
                    logger.warn(answerPage + " != " + s[0]);
                }
            }
        } catch (IOException e) {
            // Passing the exception as second argument keeps the stack trace in the log.
            logger.error("Unable to read the query from standard input", e);
        }
    }
}
/**
 * Evaluates the classifier on a tab-separated file and prints running scores.
 * <p>
 * Only the lines with exactly five columns are used; the first column is the expected
 * page. For each of them one line is written to standard output with: the second input
 * column, the running true positives, false positives and false negatives, the running
 * precision, recall and F1, the returned page ("null" when nothing was returned), the
 * expected page, the fourth input column and the classification time in milliseconds.
 * <p>
 * A missing answer counts as a true positive when the expected page is
 * {@code Null_result} and as a false negative otherwise; a wrong answer counts both as
 * a false negative and as a false positive.
 *
 * @param name path of the evaluation file, read as UTF-8
 * @throws IOException if the file cannot be opened or read
 */
public void eval(String name) throws IOException {
    LineNumberReader lnr = new LineNumberReader(new InputStreamReader(new FileInputStream(name), "UTF-8"));
    try {
        double tp = 0, fp = 0, fn = 0;
        String line;
        while ((line = lnr.readLine()) != null) {
            String[] s = line.split("\t");
            if (s.length != 5) {
                continue;
            }
            long b = System.currentTimeMillis();
            Sense[] senses = classify(s);
            long e = System.currentTimeMillis();

            // "No answer" is tracked with a flag: the page is rendered as the string "null",
            // so a comparison of answerPage with the null reference could never succeed.
            boolean answered = senses != null && senses.length > 0;
            String answerPage = answered ? senses[0].getPage() : "null";
            boolean matches = answerPage.equals(s[0]);
            if (matches) {
                logger.info(answerPage + " = " + s[0]);
            } else {
                logger.warn(answerPage + " != " + s[0]);
            }

            if (!answered) {
                if (s[0].equals("Null_result")) {
                    tp++;
                } else {
                    fn++;
                }
            } else if (matches) {
                tp++;
            } else {
                fn++;
                fp++;
            }

            // Guard the ratios so that the first lines do not print NaN.
            double p = (tp + fp) > 0 ? tp / (tp + fp) : 0;
            double r = (tp + fn) > 0 ? tp / (tp + fn) : 0;
            double f1 = (p + r) > 0 ? (2 * p * r) / (p + r) : 0;
            System.out.println(
                    s[1] + "\t" + vf.format(tp) + "\t" + vf.format(fp) + "\t" + vf.format(fn) + "\t" + sf.format(p)
                            + "\t" + sf.format(r) + "\t" + sf.format(f1) + "\t" + answerPage + "\t" + s[0] + "\t"
                            + s[3] + "\t" + df.format(e - b));
        }
    } finally {
        // The reader was previously left open.
        lnr.close();
    }
}
/**
 * Ranks the candidate senses of a form given its context.
 * <p>
 * The input columns are copied into a 9-slot record, where slot 3 is the page, slot 2 the
 * form, and slots 7 and 8 the left and right contexts (presumably matching the column
 * indexes expected by {@code Example}; verify against the index constants):
 * <ul>
 * <li>2 columns: left context, form (the right context is set to ".");</li>
 * <li>3 columns: left context, form, right context;</li>
 * <li>5 columns: page, id, left context, form, right context.</li>
 * </ul>
 * Form and contexts are tokenized, a query {@code Example} is built from the record and
 * compared with the examples stored for the tokenized form. Stored examples with a
 * frequency of 2 or less get zero scores.
 *
 * @param s the tab-separated columns of the query
 * @return the senses sorted by decreasing {@code getCombo()}, or {@code null} if the number
 * of columns is not supported or no example is stored for the form
 */
private Sense[] classify(String[] s) {
    String[] e = new String[9];
    if (s.length == 2 || s.length == 3) {
        e[3] = "null";
        e[2] = s[1];
        e[7] = s[0];
        e[8] = s.length == 3 ? s[2] : ".";
    } else if (s.length == 5) {
        e[3] = s[0];
        e[2] = s[3];
        e[7] = s[2];
        e[8] = s[4];
    } else {
        logger.error(s.length);
        return null;
    }

    Tokenizer tokenizer = new HardTokenizer();
    e[2] = tokenizer.tokenizedString(e[2]);
    e[7] = tokenizer.tokenizedString(e[7]);
    e[8] = tokenizer.tokenizedString(e[8]);
    logger.debug("2='" + e[2] + "'");
    logger.debug("7='" + e[7] + "'");
    logger.debug("8='" + e[8] + "'");
    logger.debug(Arrays.toString(s));
    logger.debug(Arrays.toString(e));

    // The query is wrapped into a one-element list so that it can be turned into an Example.
    List<String[]> list = new ArrayList<String[]>();
    list.add(e);
    Example example = new Example(e[3], list, e[2]);
    logger.trace(Arrays.toString(example.getLocalContextVector()));
    logger.trace(Arrays.toString(example.getBowVector()));

    // The cast keeps this method valid whether or not the map field declares its type arguments.
    Example[] examples = (Example[]) map.get(e[2]);
    Sense[] senses = null;
    if (examples != null) {
        senses = new ContextualSense[examples.length];
        for (int i = 0; i < examples.length; i++) {
            int freq = examples[i].getFreq();
            if (freq > 2) {
                logger.trace(i + "\t" + freq + "\t" + Arrays.toString(examples[i].getLocalContextVector()));
                logger.trace(i + "\t" + freq + "\t" + Arrays.toString(examples[i].getBowVector()));
                // The local context score is currently disabled and always 0.
                double localDot = 0;//Node.dot(example.getLocalContextVector(), examples[i].getLocalContextVector());
                double bowDot = Node.dot(example.getBowVector(), examples[i].getBowVector());
                logger.info(
                        i + "\t" + freq + "\t" + df.format(localDot) + "\t" + df.format(bowDot) + "\t" + examples[i]
                                .getPage());
                senses[i] = new ContextualSense(examples[i].getPage(), examples[i].getFreq(), localDot, bowDot);
            } else {
                senses[i] = new ContextualSense(examples[i].getPage(), examples[i].getFreq(), 0, 0);
            }
        }

        // Highest combo score first.
        Arrays.sort(senses, new Comparator<Sense>() {
            @Override
            public int compare(Sense sense, Sense sense2) {
                double diff = sense.getCombo() - sense2.getCombo();
                if (diff > 0) {
                    return -1;
                } else if (diff < 0) {
                    return 1;
                }
                return 0;
            }
        });

        for (int i = 0; i < senses.length && i < 5; i++) {
            logger.info(i + "\t" + senses[i].getPage() + "\t" + senses[i].getCombo());
        }
    }
    return senses;
}
/**
 * Intentionally does nothing: the call to interactive() is disabled here, the
 * interactive mode is started by main when the interactive-mode option is given.
 */
@Override
public void end() {
//interactive();
}
/**
 * Examples built so far, indexed by form: one array per form, with one example per page.
 * NOTE(review): plain HashMap; confirm that buildExamples is never invoked concurrently.
 */
Map<String, Example[]> map = new HashMap<String, Example[]>();
@Override
public void buildExamples(Map> senseMap, String form) {
Example[] examples = new Example[senseMap.size()];
Iterator it = senseMap.keySet().iterator();
String page;
List list;
for (int i = 0; it.hasNext(); i++) {
page = it.next();
list = senseMap.get(page);
//logger.debug(form + "\t" + page + "\t" + list.size());
examples[i] = new Example(page, list, form);
}
map.put(form, examples);
}
class Example implements Comparable {
private BOW bow;
private eu.fbk.utils.math.Node[] bowVector;
private eu.fbk.utils.math.Node[] localContextVector;
private int freq;
private String page;
private String form;
Example(String page, List list, String form) {
this.page = page;
this.form = form;
freq = list.size();
totalFreq += freq;
StringBuilder sb = new StringBuilder();
bow = buildBOW(list);
//logger.debug(bow);
bowVector = buildBowVector(bow);
//bowVector = new Node[0];
localContextVector = buildLocalContext(list);
//logger.debug(Arrays.toString(localContextVector));
if (normalized) {
Node.normalize(bowVector);
Node.normalize(localContextVector);
}
//logger.debug(toString());
}
private Node[] buildLocalContext(List list) {
//Map nodes = new TreeSet();
MultiSet multiSet = new HashMultiSet();
for (int i = 0; i < list.size(); i++) {
String[] s = list.get(i);
String leftContext = s[eu.fbk.twm.index.csv.OneExamplePerSenseExtractor.LEFT_CONTEXT_INDEX];
String form = s[eu.fbk.twm.index.csv.OneExamplePerSenseExtractor.FORM_INDEX];
int pos = leftContext.lastIndexOf(' ');
if (pos != -1) {
String t = leftContext.substring(pos + 1, leftContext.length());
//logger.trace(i + "\t" + t + "\t" + form + "\t" + page);
multiSet.add(t);
} else {
multiSet.add(leftContext);
}
}
//logger.debug(multiSet.toSortedMap());
SortedSet nodes = new TreeSet();
Iterator it = multiSet.iterator();
for (int i = 0; it.hasNext(); i++) {
String t = it.next();
Integer index = nGramModel.getIndex(t);
if (index != null) {
int value = multiSet.getFrequency(t);
//logger.trace(i + "\t'" + t + "'\t" + index + ":" + value + "\t" + form + "\t" + page);
nodes.add(new Node(index, value));
}
}
return nodes.toArray(new Node[nodes.size()]);
}
private boolean isCapitalized(String w) {
if (w.length() == 0) {
return false;
}
return Character.isUpperCase(w.charAt(0));
}
private boolean isUpperCase(String w) {
if (w.length() == 0) {
return false;
}
for (int i = 0; i < w.length(); i++) {
if (Character.isLowerCase(w.charAt(i))) {
return false;
}
}
return true;
}
Node[] getLocalContextVector() {
return localContextVector;
}
private BOW buildBOW(List list) {
bow = new BOW(tfType);
String[] s;
String[] leftContext;
String[] rightContext;
for (int i = 0; i < list.size(); i++) {
try {
s = list.get(i);
//todo: add toLowerCase
//leftContext = spacePattern.split(s[WikipediaExampleExtractor.LEFT_CONTEXT_INDEX].toLowerCase());
leftContext = spacePattern.split(s[eu.fbk.twm.index.csv.OneExamplePerSenseExtractor.LEFT_CONTEXT_INDEX]);
//rightContext = spacePattern.split(s[WikipediaExampleExtractor.RIGHT_CONTEXT_INDEX].toLowerCase());
rightContext = spacePattern.split(s[eu.fbk.twm.index.csv.OneExamplePerSenseExtractor.RIGHT_CONTEXT_INDEX]);
//logger.debug(i + "\t" + s[WikipediaExampleExtractor.LEFT_CONTEXT_INDEX] + "\t<"+form+">\t" + s[WikipediaExampleExtractor.RIGHT_CONTEXT_INDEX]);
extract(leftContext, bow);
extract(rightContext, bow);
} catch (Exception e) {
logger.error("Error at " + i);
logger.error(e);
}
}
return bow;
}
private Node[] buildBowVector(BOW bow) {
SortedSet nodes = new TreeSet();
Iterator it = bow.iterator();
for (int i = 0; it.hasNext(); i++) {
String form = it.next();
Integer index = nGramModel.getIndex(form);
Double value = nGramModel.getValue(form);
if (index != null && value != null) {
double tf = bow.tf(form);
//logger.debug(i + "\t" + form + "\t"+page+ "\t" + index + ":" + value + " * " + tf);
nodes.add(new Node(index, value * tf));
}
}
return nodes.toArray(new Node[nodes.size()]);
}
private String tokenizedForm(String[] tokenArray, int start, int end) {
StringBuilder sb = new StringBuilder();
sb.append(tokenArray[start]);
for (int i = start + 1; i <= end; i++) {
sb.append(CharacterTable.SPACE);
sb.append(tokenArray[i]);
}
return sb.toString();
}
public void extract(String[] tokenArray, BOW bow) {
int m = 0;
String tokenizedForm;
for (int i = 0; i < tokenArray.length; i++) {
m = i + nGramModel.getLength();// + 1;
if (m > tokenArray.length) {
m = tokenArray.length;
}
for (int j = i; j < m; j++) {
tokenizedForm = tokenizedForm(tokenArray, i, j);
bow.add(tokenizedForm);
}
}
}
String getForm() {
return form;
}
public String getPage() {
return page;
}
public BOW getBow() {
return bow;
}
public Node[] getBowVector() {
return bowVector;
}
public int getFreq() {
return freq;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(page);
sb.append(CharacterTable.HORIZONTAL_TABULATION);
sb.append((double) freq / totalFreq);
sb.append(CharacterTable.HORIZONTAL_TABULATION);
sb.append(Node.toString(localContextVector));
sb.append(CharacterTable.HORIZONTAL_TABULATION);
sb.append(Node.toString(bowVector));
return sb.toString();
}
@Override
public int compareTo(Example example) {
return example.getFreq() - freq;
}
}
/**
 * Writes the given examples through {@code senseWriter}, one line per example made of
 * the form, a tab and the textual representation of the example. The text is assembled
 * first, so that the lock on this extractor is held only while printing.
 */
private void writeExampleArray(Example[] examples) {
    StringBuilder buffer = new StringBuilder();
    for (Example example : examples) {
        buffer.append(example.getForm());
        buffer.append(CharacterTable.HORIZONTAL_TABULATION);
        buffer.append(example);
        buffer.append(CharacterTable.LINE_FEED);
    }
    String text = buffer.toString();
    synchronized (this) {
        senseWriter.print(text);
    }
}
/**
 * Command line entry point: parses the options, configures log4j, builds the n-gram model,
 * runs the extraction on the input file and then, if requested, the interactive mode and
 * the evaluation.
 * <p>
 * The log4j configuration is read from the file named by the {@code log-config} system
 * property (default {@code configuration/log-config.txt}); the root logger level is then
 * set according to the trace and debug options.
 *
 * @param args the command line arguments
 */
public static void main(String[] args) {
    // java com.ml.test.net.HttpServerDemo
    String logConfig = System.getProperty("log-config");
    if (logConfig == null) {
        logConfig = "configuration/log-config.txt";
    }
    //PropertyConfigurator.configure(logConfig);
    Options options = new Options();
    try {
        Option inputFileNameOpt = OptionBuilder.withArgName("file").hasArg()
                .withDescription("sorted form/page file").isRequired().withLongOpt("input").create("i");
        Option outputFileNameOpt = OptionBuilder.withArgName("file").hasArg()
                .withDescription("one sense per example file").isRequired().withLongOpt("output").create("o");
        // Every label is rendered as N=`label'; the quote and comma after the first label
        // and the quote after the last one were missing.
        Option tfOpt = OptionBuilder.withArgName("FUNC").hasArg().withDescription(
                "term frequency function; FUNC is " + BOW.RAW_TERM_FREQUENCY + "=`"
                        + BOW.labels[BOW.RAW_TERM_FREQUENCY] + "'," + BOW.BOOLEAN_TERM_FREQUENCY + "=`"
                        + BOW.labels[BOW.BOOLEAN_TERM_FREQUENCY] + "'," + BOW.LOGARITHMIC_TERM_FREQUENCY + "=`"
                        + BOW.labels[BOW.LOGARITHMIC_TERM_FREQUENCY] + "'," + BOW.AUGMENTED_TERM_FREQUENCY + "=`"
                        + BOW.labels[BOW.AUGMENTED_TERM_FREQUENCY] + "' (default is "
                        + BOW.DEFAULT_TERM_FREQUENCY_TYPE + ")").withLongOpt("tf").create();
        Option stopwordsFileNameOpt = OptionBuilder.withArgName("file").hasArg().withDescription("stopwords")
                .isRequired().withLongOpt("stopwords").create();
        Option numFormOpt = OptionBuilder.withArgName("int").hasArg()
                .withDescription("maximum number of forms to process (default is all)").withLongOpt("num-forms")
                .create("f");
        Option numThreadOpt = OptionBuilder.withArgName("int").hasArg()
                .withDescription("number of threads (default " + DEFAULT_THREADS_NUMBER + ")")
                .withLongOpt("num-threads").create("t");
        Option formIdFileNameOpt = OptionBuilder.withArgName("file").hasArg().withDescription("form id mapping")
                .isRequired().withLongOpt("form-id").create();
        Option ngramFileNameOpt = OptionBuilder.withArgName("file").hasArg().withDescription("form idf mapping")
                .isRequired().withLongOpt("form-idf").create();
        Option nGramLengthOpt = OptionBuilder.withArgName("int").hasArg()
                .withDescription("n-gram length (default is " + DEFAULT_N_GRAM_LENGTH + ")")
                .withLongOpt("ngram-length").create("l");
        Option nGramSizeOpt = OptionBuilder.withArgName("int").hasArg()
                .withDescription("n-gram size (default is " + DEFAULT_N_GRAM_SIZE + ")").withLongOpt("ngram-size")
                .create("s");
        Option traceOpt = OptionBuilder.withDescription("trace mode").withLongOpt("trace").create();
        Option debugOpt = OptionBuilder.withDescription("debug mode").withLongOpt("debug").create();
        options.addOption(
                OptionBuilder.withDescription("enter in the interactive mode").withLongOpt("interactive-mode")
                        .create());
        // NOTE(review): eval-file is declared as required although its use below is
        // conditional; confirm that this is intended.
        options.addOption(OptionBuilder.withArgName("file").hasArg()
                .withDescription("evaluation file in tsv format (page\\tid\\tleft\\tterm\\tright)").isRequired()
                .withLongOpt("eval-file").create());
        Option normalizedOpt = OptionBuilder
                .withDescription("normalize vectors (default is " + DEFAULT_NORMALIZE + ")")
                .withLongOpt("normalized").create("n");
        Option notificationPointOpt = OptionBuilder.withArgName("int").hasArg().withDescription(
                "receive notification every n pages (default is " + DEFAULT_NOTIFICATION_POINT + ")")
                .withLongOpt("notification-point").create("b");

        options.addOption("h", "help", false, "print this message");
        options.addOption("v", "version", false, "output version information and exit");
        options.addOption(inputFileNameOpt);
        options.addOption(tfOpt);
        options.addOption(outputFileNameOpt);
        options.addOption(numThreadOpt);
        options.addOption(notificationPointOpt);
        options.addOption(numFormOpt);
        options.addOption(formIdFileNameOpt);
        options.addOption(ngramFileNameOpt);
        options.addOption(nGramLengthOpt);
        options.addOption(nGramSizeOpt);
        options.addOption(normalizedOpt);
        options.addOption(stopwordsFileNameOpt);
        options.addOption(traceOpt);
        options.addOption(debugOpt);

        CommandLineParser parser = new PosixParser();
        CommandLine line = parser.parse(options, args);

        Properties defaultProps = new Properties();
        // The configuration stream is closed as soon as the properties are loaded.
        Reader logConfigReader = new InputStreamReader(new FileInputStream(logConfig), "UTF-8");
        try {
            defaultProps.load(logConfigReader);
        } finally {
            logConfigReader.close();
        }
        //defaultProps.setProperty("log4j.rootLogger", "info,stdout");
        if (line.hasOption("trace")) {
            defaultProps.setProperty("log4j.rootLogger", "trace,stdout");
        } else if (line.hasOption("debug")) {
            defaultProps.setProperty("log4j.rootLogger", "debug,stdout");
        } else {
            defaultProps.setProperty("log4j.rootLogger", "info,stdout");
        }
        PropertyConfigurator.configure(defaultProps);

        logger.debug(options);
        // NOTE(review): no option named "lsm" is registered above.
        logger.debug(line.getOptionValue("output") + "\t" + line.getOptionValue("input") + "\t" + line
                .getOptionValue("lsm"));

        boolean normalized = false;
        if (line.hasOption("normalized")) {
            normalized = true;
        }

        // NOTE(review): ngram-size and ngram-length are parsed (and therefore validated) but
        // the values are not handed over to the model or to the extractor.
        int nGramSize = DEFAULT_N_GRAM_SIZE;
        if (line.hasOption("ngram-size")) {
            nGramSize = Integer.parseInt(line.getOptionValue("ngram-size"));
        }
        int nGramLength = DEFAULT_N_GRAM_LENGTH;
        if (line.hasOption("ngram-length")) {
            nGramLength = Integer.parseInt(line.getOptionValue("ngram-length"));
        }

        NGramModel nGramModel = new NGramModel(line.getOptionValue("form-id"), line.getOptionValue("form-idf"),
                line.getOptionValue("stopwords"));

        int numThreads = DEFAULT_THREADS_NUMBER;
        if (line.hasOption("num-threads")) {
            numThreads = Integer.parseInt(line.getOptionValue("num-threads"));
        }
        int numForms = DEFAULT_NUM_FORMS;
        if (line.hasOption("num-forms")) {
            numForms = Integer.parseInt(line.getOptionValue("num-forms"));
        }
        int notificationPoint = DEFAULT_NOTIFICATION_POINT;
        if (line.hasOption("notification-point")) {
            notificationPoint = Integer.parseInt(line.getOptionValue("notification-point"));
        }
        int tfType = BOW.DEFAULT_TERM_FREQUENCY_TYPE;
        if (line.hasOption("tf")) {
            tfType = Integer.parseInt(line.getOptionValue("tf"));
        }

        logger.info("extracting one example per sense using " + numThreads + " threads");
        NGramOneExamplePerSenseExtractor oneExamplePerSenseExtractor = new NGramOneExamplePerSenseExtractor(
                line.getOptionValue("output"), numThreads, nGramModel);
        oneExamplePerSenseExtractor.setNormalized(normalized);
        oneExamplePerSenseExtractor.setTfType(tfType);
        oneExamplePerSenseExtractor.setNotificationPoint(notificationPoint);
        oneExamplePerSenseExtractor.setNumForms(numForms);
        oneExamplePerSenseExtractor.extract(line.getOptionValue("input"));

        if (line.hasOption("interactive-mode")) {
            oneExamplePerSenseExtractor.interactive();
        }
        if (line.hasOption("eval-file")) {
            oneExamplePerSenseExtractor.eval(line.getOptionValue("eval-file"));
        }
    } catch (ParseException e) {
        // oops, something went wrong
        System.err.println("Parsing failed: " + e.getMessage() + "\n");
        HelpFormatter formatter = new HelpFormatter();
        formatter.printHelp(200,
                "java -cp dist/thewikimachine.jar eu.fbk.twm.classifier.NGramOneExamplePerSenseExtractor",
                "\n", options, "\n", true);
    } catch (Exception e) {
        // log4j may not be configured yet at this point, hence the stack trace on stderr too.
        logger.error(e);
        e.printStackTrace();
    }
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy