eu.fbk.twm.classifier.OneExamplePerSenseBowClassifier Maven / Gradle / Ivy
The newest version!
/*
* Copyright (2013) Fondazione Bruno Kessler (http://www.fbk.eu/)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package eu.fbk.twm.classifier;
import eu.fbk.twm.index.OneExamplePerSenseSearcher;
import eu.fbk.twm.utils.Defaults;
import eu.fbk.twm.utils.StringTable;
import eu.fbk.twm.utils.WikipediaExtractor;
import eu.fbk.twm.utils.analysis.HardTokenizer;
import eu.fbk.twm.utils.analysis.Token;
import eu.fbk.twm.utils.analysis.Tokenizer;
import org.apache.commons.cli.*;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import eu.fbk.utils.lsa.LSI;
import eu.fbk.utils.lsa.BOW;
import eu.fbk.utils.math.Node;
import org.xerial.snappy.SnappyInputStream;
import java.io.*;
import java.text.DecimalFormat;
import java.util.*;
import java.util.regex.Pattern;
/**
 * Disambiguates a surface form by ranking its candidate senses (pages) according to the
 * similarity between the bag-of-words of the form's context and the bag-of-words vector that a
 * {@link OneExamplePerSenseSearcher} stores for each candidate sense.
 * <p>
 * The context is mapped with {@link LSI#mapDocument(BOW)} and compared with each candidate's
 * stored vector through {@link Node#dot}. The resulting {@link ContextualSense}s are returned
 * sorted by decreasing {@code getCombo()} score. The class can also be run from the command line
 * to evaluate a file of instances or to classify contexts typed interactively.
 * <p>
 * Not thread-safe: the static {@link DecimalFormat}s used for logging are shared.
 */
public class OneExamplePerSenseBowClassifier {

    /** Logger for this class. */
    static final Logger logger = Logger.getLogger(OneExamplePerSenseBowClassifier.class.getName());

    /** Latent semantic index used to map a context bag-of-words to a vector. */
    protected LSI lsi;

    /** Searcher returning the candidate senses of a form, each with its stored BOW vector. */
    protected OneExamplePerSenseSearcher oneExamplePerSenseSearcher;

    protected static final DecimalFormat rf = new DecimalFormat("###,###,##0.000000");

    protected static final Pattern tabPattern = Pattern.compile(StringTable.HORIZONTAL_TABULATION);

    protected static final DecimalFormat df = new DecimalFormat("###,###,###,###");

    protected static final DecimalFormat tf = new DecimalFormat("000,000,000.#");

    protected static final DecimalFormat mf = new DecimalFormat("#.000");

    /** Always set to {@code true} by the constructor; not read inside this class. */
    protected boolean normalized;

    /** Number of top-ranked senses logged per instance by {@link #classify(File, boolean)}. */
    private static final int LOGGED_SENSES = 3;

    /**
     * Orders senses by decreasing {@code getCombo()}. NaN scores sort after every number, so
     * the ordering is a valid total order and {@link Arrays#sort} cannot reject it.
     */
    private static final Comparator<Sense> BY_COMBO_DESCENDING = new Comparator<Sense>() {
        @Override
        public int compare(Sense sense, Sense sense2) {
            double x = sense.getCombo();
            double y = sense2.getCombo();
            if (x > y) {
                return -1;
            }
            if (x < y) {
                return 1;
            }
            // equal, or at least one NaN: NaN goes last
            boolean xNaN = Double.isNaN(x);
            boolean yNaN = Double.isNaN(y);
            if (xNaN == yNaN) {
                return 0;
            }
            return xNaN ? 1 : -1;
        }
    };

    /**
     * Creates a classifier.
     *
     * @param lsi                        index used to map context bags-of-words to vectors
     * @param oneExamplePerSenseSearcher source of the candidate senses of a form
     */
    public OneExamplePerSenseBowClassifier(LSI lsi, OneExamplePerSenseSearcher oneExamplePerSenseSearcher) {
        this.lsi = lsi;
        this.oneExamplePerSenseSearcher = oneExamplePerSenseSearcher;
        normalized = true;
    }

    /**
     * Classifies every instance of a tab-separated file and logs how often the top-ranked
     * sense equals the gold page (true/false positives, false negatives, precision, recall, F1).
     * <p>
     * Column layout, as read by this class: 0 = gold page, 1 = logged only, 2 = left context,
     * 3 = form, 4 (optional) = right context. Lines with fewer than four columns are skipped
     * with a warning.
     *
     * @param f        the instance file, UTF-8 encoded
     * @param compress {@code true} if the file is Snappy-compressed
     * @throws IOException if the file cannot be opened or read
     */
    public void classify(File f, boolean compress) throws IOException {
        logger.info("classifying " + f);
        long begin = System.currentTimeMillis();
        logger.info("totalFreq\tsize\ttime (ms)\tdate");
        int tot = 0;
        int tp = 0;
        int fp = 0;
        int fn = 0;
        LineNumberReader lnr = openReader(f, compress);
        try {
            String line;
            while ((line = lnr.readLine()) != null) {
                String[] s = tabPattern.split(line);
                if (s.length < 4) {
                    logger.warn("skipping malformed line " + lnr.getLineNumber() + ": " + line);
                    continue;
                }
                Sense[] senses = classify(s);
                String page = "";
                if (senses.length > 0) {
                    page = senses[0].getPage();
                    logTopSenses(senses, LOGGED_SENSES);
                    if (s[0].equals(page)) {
                        tp++;
                    }
                    else {
                        // a wrong answer is both a false positive and a false negative
                        fp++;
                        fn++;
                    }
                }
                else {
                    // no answer at all only misses the gold sense
                    fn++;
                }
                logger.debug(tot + "\t" + tp + "\t" + fp + "\t" + fn + "\t" + s[0] + "\t" + s[1] + "\t" + page);
                tot++;
            }
        } finally {
            lnr.close();
        }
        long end = System.currentTimeMillis();
        double precision = safeDivide(tp, tp + fp);
        double recall = safeDivide(tp, tp + fn);
        double f1 = safeDivide(2 * precision * recall, precision + recall);
        logger.debug(tot + "\t" + tp + "\t" + fp + "\t" + fn + "\t" + rf.format(precision) + "\t" + rf.format(recall) + "\t" + rf.format(f1));
        logger.info(df.format(tot) + "\t" + df.format(end - begin) + "\t" + new Date());
        logger.info("ending the process " + new Date() + "...");
    }

    /**
     * Opens {@code f} as a UTF-8 line reader, optionally through Snappy decompression.
     * Closes the underlying file stream if any wrapper fails to construct.
     */
    private static LineNumberReader openReader(File f, boolean compress) throws IOException {
        InputStream in = new FileInputStream(f);
        boolean opened = false;
        try {
            if (compress) {
                in = new SnappyInputStream(in);
            }
            LineNumberReader reader = new LineNumberReader(new InputStreamReader(in, "UTF-8"));
            opened = true;
            return reader;
        } finally {
            if (!opened) {
                in.close();
            }
        }
    }

    /** Logs prior, bow, ls and combo scores and the page of the first {@code k} senses. */
    private static void logTopSenses(Sense[] senses, int k) {
        logger.info("i\tprior\tbow\tls\tcombo\tpage");
        for (int i = 0; i < senses.length && i < k; i++) {
            Sense sense = senses[i];
            logger.info(i + "\t" + rf.format(sense.getPrior()) + "\t" + rf.format(sense.getBow()) + "\t" + rf.format(sense.getLs()) + "\t" + rf.format(sense.getCombo()) + "\t" + sense.getPage());
        }
    }

    /** Returns {@code num / den}, or 0 when {@code den} is 0 (avoids NaN metrics on empty input). */
    private static double safeDivide(double num, double den) {
        return den == 0 ? 0 : num / den;
    }

    /**
     * Ranks the senses of {@code form} given a tokenized context.
     *
     * @param s    the context tokens; their forms are lower-cased into a bag-of-words
     * @param form the surface form to disambiguate
     * @return the candidate senses sorted by decreasing combo score (possibly empty)
     */
    public Sense[] classify(Token[] s, String form) {
        return classify(createBow(s), form);
    }

    /** Ranks the senses of the form in column 3 using the contexts in columns 2 and 4. */
    private Sense[] classify(String[] s) {
        return classify(createBow(s), s[3]);
    }

    /**
     * Returns the cosine similarity of two sparse vectors, or 0 when either vector has zero
     * norm (instead of the NaN a plain division would produce).
     */
    protected double dot(Node[] n1, Node[] n2) {
        double norms = Math.sqrt(Node.dot(n1, n1) * Node.dot(n2, n2));
        if (norms == 0) {
            return 0;
        }
        return Node.dot(n1, n2) / norms;
    }

    /**
     * Computes the dot product of two sparse vectors and logs, at debug level, every matching
     * term and its contribution.
     * <p>
     * The merge walk assumes both arrays are sorted by ascending {@code index}.
     *
     * @param x       first sparse vector
     * @param y       second sparse vector
     * @param termMap map from term index to term, used only for logging
     * @param form    surface form, used only for logging
     * @param page    candidate page, used only for logging
     * @return the sum of products of the values sharing the same index
     */
    public static double dot(Node[] x, Node[] y, Map<?, ?> termMap, String form, String page) {
        // building the log strings is costly; only do it when they will be emitted
        boolean debug = logger.isDebugEnabled();
        double sum = 0;
        int xlen = x.length;
        int ylen = y.length;
        int i = 0;
        int j = 0;
        while (i < xlen && j < ylen) {
            if (x[i].index == y[j].index) {
                if (debug) {
                    logger.debug(form + "/" + page + "\t" + x[i].index + "\t" + termMap.get(x[i].index) + "\t" + mf.format(x[i].value) + "*" + mf.format(y[j].value) + "=" + mf.format(x[i].value * y[j].value));
                }
                sum += x[i++].value * y[j++].value;
            }
            else if (x[i].index > y[j].index) {
                ++j;
            }
            else {
                ++i;
            }
        }
        if (debug) {
            logger.debug(form + "/" + page + "\t\t\t" + mf.format(sum));
        }
        return sum;
    }

    /**
     * Reads a term map from lines of the form {@code id<TAB>term}. Lines starting with
     * {@code #} and lines without exactly two columns are ignored. Lines whose id is not an
     * integer are skipped with a warning. The reader is always closed.
     *
     * @param in the source of the term map
     * @return a map from term id to term
     * @throws IOException if reading fails
     */
    public static Map<Integer, String> read(Reader in) throws IOException {
        long begin = System.currentTimeMillis();
        logger.info("\n\nreading index - term...");
        Map<Integer, String> termMap = new HashMap<Integer, String>();
        LineNumberReader lnr = new LineNumberReader(in);
        try {
            String line;
            while ((line = lnr.readLine()) != null) {
                line = line.trim();
                if (line.startsWith("#")) {
                    continue;
                }
                String[] s = tabPattern.split(line);
                if (s.length != 2) {
                    continue;
                }
                try {
                    termMap.put(Integer.valueOf(s[0]), s[1]);
                } catch (NumberFormatException e) {
                    logger.warn("skipping line " + lnr.getLineNumber() + ": invalid term id " + s[0]);
                }
            }
        } finally {
            lnr.close();
        }
        long end = System.currentTimeMillis();
        logger.debug(termMap.size() + " terms read in " + tf.format(end - begin));
        return termMap;
    }

    /**
     * Ranks the candidate senses of {@code form} against a context bag-of-words.
     * <p>
     * Each candidate is scored with the dot product between the LSI mapping of {@code bow} and
     * the candidate's stored BOW vector. The result is a {@link ContextualSense} built from
     * (page, freq, bow score, 0).
     *
     * @param bow  the context bag-of-words
     * @param form the surface form to disambiguate
     * @return the candidate senses sorted by decreasing combo score; empty if the searcher
     *         returns no candidates
     */
    public Sense[] classify(BOW bow, String form) {
        OneExamplePerSenseSearcher.Entry[] entries = oneExamplePerSenseSearcher.search(form);
        if (entries == null || entries.length == 0) {
            return new ContextualSense[0];
        }
        Node[] bowVector = lsi.mapDocument(bow);
        Sense[] senses = new ContextualSense[entries.length];
        boolean trace = logger.isTraceEnabled();
        for (int i = 0; i < entries.length; i++) {
            OneExamplePerSenseSearcher.Entry entry = entries[i];
            double bowKernel = Node.dot(bowVector, entry.getBowVector());
            if (trace) {
                logger.trace(i + "\t" + entry.getPage() + "\t" + rf.format(entry.getFreq()) + "\t" + rf.format(bowKernel) + "\t" + rf.format(0));
            }
            // the latent-semantic score is not computed by this classifier, hence 0
            senses[i] = new ContextualSense(entry.getPage(), entry.getFreq(), bowKernel, 0);
        }
        Arrays.sort(senses, BY_COMBO_DESCENDING);
        return senses;
    }

    /** Builds a bag-of-words from the lower-cased forms of the given tokens. */
    private BOW createBow(Token[] tokenArray) {
        BOW bow = new BOW();
        for (int i = 0; i < tokenArray.length; i++) {
            bow.add(tokenArray[i].getForm().toLowerCase());
        }
        return bow;
    }

    /**
     * Builds a bag-of-words from the lower-cased tokens of column 2 (left context) and, when
     * the line has exactly five columns, column 4 (right context).
     */
    private BOW createBow(String[] s) {
        Tokenizer tokenizer = HardTokenizer.getInstance();
        BOW bow = new BOW();
        String[] left = tokenizer.stringArray(s[2].toLowerCase());
        bow.addAll(left);
        if (s.length == 5) {
            String[] right = tokenizer.stringArray(s[4].toLowerCase());
            bow.addAll(right);
        }
        return bow;
    }

    /**
     * Reads queries of the form {@code left-context<TAB>form[<TAB>right-context]} from standard
     * input and logs the ranked senses of each one. Returns when standard input reaches EOF.
     * Queries without a tab are reported and skipped.
     *
     * @throws Exception if reading or classification fails
     */
    public void interactive() throws Exception {
        // a single reader: re-creating a BufferedReader per query would drop read-ahead input
        BufferedReader myInput = new BufferedReader(new InputStreamReader(System.in));
        HardTokenizer hardTokenizer = new HardTokenizer();
        while (true) {
            System.out.println("\nPlease write a key and type to continue (CTRL C to exit):");
            String query = myInput.readLine();
            if (query == null) {
                return;
            }
            String[] s = query.split("\t");
            if (s.length < 2) {
                System.out.println("expected: left-context<TAB>form[<TAB>right-context]");
                continue;
            }
            String context = s[0];
            if (s.length > 2) {
                context += " " + s[2];
            }
            String form = s[1];
            Token[] tokens = hardTokenizer.tokenArray(context);
            Sense[] sense = classify(tokens, form);
            logger.info("i\tpage\tprior\tbow\tls\tcombo");
            for (int i = 0; i < sense.length; i++) {
                logger.info(i + "\t" + sense[i]);
            }
        }
    }

    /**
     * Command-line entry point. Loads the LSI model and the sense index, then evaluates an
     * instance file ({@code --instance-file}) and/or enters interactive mode
     * ({@code --interactive-mode}). The log4j configuration is read from the
     * {@code log-config} system property, defaulting to
     * {@code configuration/log-config.txt}.
     *
     * @param args command-line arguments
     * @throws Exception if loading the models or classifying fails
     */
    public static void main(String[] args) throws Exception {
        String logConfig = System.getProperty("log-config");
        if (logConfig == null) {
            logConfig = "configuration/log-config.txt";
        }
        PropertyConfigurator.configure(logConfig);
        Options options = new Options();
        try {
            Option indexNameOpt = OptionBuilder.withArgName("dir").hasArg()
                    .withDescription("open an index with the specified name").isRequired()
                    .withLongOpt("index").create("i");
            Option interactiveModeOpt = OptionBuilder.withDescription("enter in the interactive mode")
                    .withLongOpt("interactive-mode").create("t");
            Option instanceFileOpt = OptionBuilder.withArgName("file").hasArg()
                    .withDescription("read the instances to classify from the specified file")
                    .withLongOpt("instance-file").create("f");
            Option lsmDirOpt = OptionBuilder.withArgName("dir").hasArg().withDescription("lsi dir")
                    .isRequired().withLongOpt("lsi").create("l");
            Option lsmDimOpt = OptionBuilder.withArgName("int").hasArg().withDescription("lsi dim")
                    .withLongOpt("dim").create("d");
            Option normalizedOpt = OptionBuilder
                    .withDescription("normalize vectors (default is " + WikipediaExtractor.DEFAULT_NORMALIZE + ")")
                    .withLongOpt("normalized").create();
            // previously read below but never registered, so it could not be set
            Option notificationPointOpt = OptionBuilder.withArgName("int").hasArg()
                    .withDescription("notification point of the searcher (default is " + Defaults.DEFAULT_NOTIFICATION_POINT + ")")
                    .withLongOpt("notification-point").create();
            options.addOption("h", "help", false, "print this message");
            options.addOption("v", "version", false, "output version information and exit");
            options.addOption(indexNameOpt);
            options.addOption(interactiveModeOpt);
            options.addOption(instanceFileOpt);
            options.addOption(lsmDirOpt);
            options.addOption(lsmDimOpt);
            options.addOption(normalizedOpt);
            options.addOption(notificationPointOpt);
            CommandLineParser parser = new PosixParser();
            CommandLine line = parser.parse(options, args);
            if (line.hasOption("help") || line.hasOption("version")) {
                throw new ParseException("");
            }
            int notificationPoint = Defaults.DEFAULT_NOTIFICATION_POINT;
            if (line.hasOption("notification-point")) {
                notificationPoint = Integer.parseInt(line.getOptionValue("notification-point"));
            }
            String lsmDirName = line.getOptionValue("lsi");
            if (!lsmDirName.endsWith(File.separator)) {
                lsmDirName += File.separator;
            }
            boolean normalized = WikipediaExtractor.DEFAULT_NORMALIZE;
            if (line.hasOption("normalized")) {
                normalized = true;
            }
            File fileUt = new File(lsmDirName + "X-Ut");
            File fileSk = new File(lsmDirName + "X-S");
            File fileR = new File(lsmDirName + "X-row");
            File fileC = new File(lsmDirName + "X-col");
            File fileDf = new File(lsmDirName + "X-df");
            int dim = 100;
            if (line.hasOption("dim")) {
                dim = Integer.parseInt(line.getOptionValue("dim"));
            }
            logger.debug(line.getOptionValue("lsi") + "\t" + line.getOptionValue("dim"));
            LSI lsi = new LSI(fileUt, fileSk, fileR, fileC, fileDf, dim, true, normalized);
            OneExamplePerSenseSearcher oneExamplePerSenseSearcher = new OneExamplePerSenseSearcher(line.getOptionValue("index"));
            oneExamplePerSenseSearcher.setNotificationPoint(notificationPoint);
            OneExamplePerSenseBowClassifier oneExamplePerSenseClassifier = new OneExamplePerSenseBowClassifier(lsi, oneExamplePerSenseSearcher);
            if (line.hasOption("instance-file")) {
                oneExamplePerSenseClassifier.classify(new File(line.getOptionValue("instance-file")), false);
            }
            if (line.hasOption("interactive-mode")) {
                oneExamplePerSenseClassifier.interactive();
            }
        } catch (ParseException e) {
            // oops, something went wrong
            String message = e.getMessage();
            if (message != null && message.length() > 0) {
                System.out.println("Parsing failed: " + message + "\n");
            }
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp(400, "java -cp dist/thewikimachine.jar eu.fbk.twm.classifier.OneExamplePerSenseBowClassifier", "\n", options, "\n", true);
        }
    }
}