All Downloads are FREE. Search and download functionalities are using the official Maven repository.

eu.fbk.twm.classifier.OneExamplePerSenseBowClassifier Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (2013) Fondazione Bruno Kessler (http://www.fbk.eu/)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package eu.fbk.twm.classifier;

import eu.fbk.twm.index.OneExamplePerSenseSearcher;
import eu.fbk.twm.utils.Defaults;
import eu.fbk.twm.utils.StringTable;
import eu.fbk.twm.utils.WikipediaExtractor;
import eu.fbk.twm.utils.analysis.HardTokenizer;
import eu.fbk.twm.utils.analysis.Token;
import eu.fbk.twm.utils.analysis.Tokenizer;
import org.apache.commons.cli.*;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import eu.fbk.utils.lsa.LSI;
import eu.fbk.utils.lsa.BOW;
import eu.fbk.utils.math.Node;
import org.xerial.snappy.SnappyInputStream;

import java.io.*;
import java.text.DecimalFormat;
import java.util.*;
import java.util.regex.Pattern;

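/**
 * Ranks the candidate pages (senses) of a surface form. The context of the form is turned
 * into a bag of words, mapped to a vector through an {@link LSI} model and compared with the
 * vectors of the entries returned by an {@link OneExamplePerSenseSearcher}; the candidates
 * are returned sorted by decreasing combo score.
 * <p>
 * User: giuliano
 * Date: 2/6/13
 * Time: 11:52 AM
 */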
public class OneExamplePerSenseBowClassifier {
	/**
	 * Define a static logger variable so that it references the
	 * Logger instance named OneExamplePerSenseBowClassifier.
	 */
	static Logger logger = Logger.getLogger(OneExamplePerSenseBowClassifier.class.getName());

	/**
	 * Model whose <code>mapDocument</code> turns a bag of words into a vector
	 * (see {@link #classify(BOW, String)}).
	 */
	protected LSI lsi;

	/**
	 * Searcher queried with a form to obtain the candidate entries
	 * (page, frequency and bag-of-words vector) to rank.
	 */
	protected OneExamplePerSenseSearcher oneExamplePerSenseSearcher;

	/**
	 * Number format with grouping and exactly six fraction digits; used when logging scores.
	 */
	protected static DecimalFormat rf = new DecimalFormat("###,###,##0.000000");

	/**
	 * Pattern compiled from <code>StringTable.HORIZONTAL_TABULATION</code>; used to split
	 * input lines into fields.
	 */
	protected static Pattern tabPattern = Pattern.compile(StringTable.HORIZONTAL_TABULATION);

	/**
	 * Integer format with grouping; used when logging counts and elapsed milliseconds.
	 */
	protected static DecimalFormat df = new DecimalFormat("###,###,###,###");

	/**
	 * Zero-padded format with at most one fraction digit; used by {@link #read(Reader)}
	 * to log the elapsed time.
	 */
	protected static DecimalFormat tf = new DecimalFormat("000,000,000.#");

	/**
	 * Format with exactly three fraction digits; used by the static <code>dot</code>
	 * method in its debug output.
	 */
	protected static DecimalFormat mf = new DecimalFormat("#.000");

	/**
	 * Set to <code>true</code> by the constructor; it is not read anywhere in this class.
	 */
	protected boolean normalized;


	/**
	 * Builds a classifier on top of an already loaded model and an already opened index.
	 * The <code>normalized</code> flag is always initialised to <code>true</code>.
	 *
	 * @param lsi                        the model used to map bags of words to vectors
	 * @param oneExamplePerSenseSearcher the searcher used to retrieve the candidate entries of a form
	 */
	public OneExamplePerSenseBowClassifier(LSI lsi, OneExamplePerSenseSearcher oneExamplePerSenseSearcher) {
		this.normalized = true;
		this.oneExamplePerSenseSearcher = oneExamplePerSenseSearcher;
		this.lsi = lsi;
	}

	/**
	 * Classifies every instance stored in a tab-separated file and logs the ranking of each
	 * instance together with the final precision, recall and F1.
	 * <p>
	 * Each line is split on {@link #tabPattern}. Field 0 is the expected page, field 1 is only
	 * logged, field 3 is the form to classify, field 2 and (when the line has exactly five
	 * fields) field 4 provide the context. An instance counts as a true positive when the
	 * first-ranked page equals field 0; otherwise both the false-positive and the
	 * false-negative counters are incremented. Lines with fewer than four fields cannot be
	 * classified and are skipped with a warning.
	 *
	 * @param f        the file containing the instances
	 * @param compress if <code>true</code> the file is read through a <code>SnappyInputStream</code>
	 * @throws IOException if the file cannot be opened or read
	 */
	public void classify(File f, boolean compress) throws IOException {
		logger.info("classifying " + f);
		long begin = System.currentTimeMillis();

		int tot = 0;
		int tp = 0, fp = 0, fn = 0;

		InputStream in = new FileInputStream(f);
		LineNumberReader lnr = null;
		try {
			if (compress) {
				in = new SnappyInputStream(in);
			}
			lnr = new LineNumberReader(new InputStreamReader(in, "UTF-8"));

			logger.info("totalFreq\tsize\ttime (ms)\tdate");
			String line;
			while ((line = lnr.readLine()) != null) {
				String[] s = tabPattern.split(line);
				if (s.length < 4) {
					// fields 0..3 are all dereferenced below; without them the instance is unusable
					logger.warn("skipping malformed line " + lnr.getLineNumber() + ": expected at least 4 tab-separated fields, found " + s.length);
					continue;
				}

				Sense[] senses = classify(s);
				String page = "";
				if (senses.length > 0) {
					page = senses[0].getPage();
					logger.info("i\tprior\tbow\tls\tcombo\tpage");
					// NOTE(review): the header announces prior/bow/ls/combo but each row prints
					// getPrior() twice followed by getCombo(). Left unchanged because the Sense
					// accessors for the bow/ls scores are not visible here -- confirm and fix.
					for (int i = 0; i < senses.length && i < 3; i++) {
						logger.info(i + "\t" + rf.format(senses[i].getPrior()) + "\t" + rf.format(senses[i].getPrior()) + "\t" + rf.format(senses[i].getCombo()) + "\t" + senses[i].getPage());
					}
				}

				if (s[0].equals(page)) {
					tp++;
				}
				else {
					fp++;
					fn++;
				}
				logger.debug(tot + "\t" + tp + "\t" + fp + "\t" + fn + "\t" + s[0] + "\t" + s[1] + "\t" + page);

				tot++;
			}
		}
		finally {
			// closing the outermost reader closes the whole chain; fall back to the raw
			// stream when the reader could not be created
			if (lnr != null) {
				lnr.close();
			}
			else {
				in.close();
			}
		}
		long end = System.currentTimeMillis();

		// an empty file (or no correct answer at all) would otherwise yield NaN
		double precision = (tp + fp) > 0 ? (double) tp / (tp + fp) : 0.0;
		double recall = (tp + fn) > 0 ? (double) tp / (tp + fn) : 0.0;
		double f1 = (precision + recall) > 0 ? (2 * precision * recall) / (precision + recall) : 0.0;
		logger.debug(tot + "\t" + tp + "\t" + fp + "\t" + fn + "\t" + rf.format(precision) + "\t" + rf.format(recall) + "\t" + rf.format(f1));
		logger.info(df.format(tot) + "\t" + df.format(end - begin) + "\t" + new Date());

		logger.info("ending the process " + new Date() + "...");
	}

	/**
	 * Ranks the candidate senses of a form given its context as an array of tokens.
	 *
	 * @param s    the context tokens; they are turned into a bag of words
	 * @param form the form whose candidate senses are ranked
	 * @return the senses produced by {@link #classify(BOW, String)}
	 */
	public Sense[] classify(Token[] s, String form) {
		BOW context = createBow(s);
		return classify(context, form);
	}


	/**
	 * Ranks the candidate senses of one instance given as the fields of a tab-separated line.
	 * The bag of words is built from the context fields, the form is read from field 3.
	 *
	 * @param s the fields of the instance
	 * @return the senses produced by {@link #classify(BOW, String)}
	 */
	private Sense[] classify(String[] s) {
		BOW context = createBow(s);
		String form = s[3];
		return classify(context, form);
	}

	/**
	 * Returns <code>Node.dot(n1, n2)</code> divided by the square root of
	 * <code>Node.dot(n1, n1) * Node.dot(n2, n2)</code>; assuming <code>Node.dot</code> is the
	 * inner product, this is the cosine of the two vectors despite the method name.
	 * <p>
	 * When the denominator is zero (at least one vector has no weight) the method returns 0
	 * instead of propagating the NaN produced by the division.
	 *
	 * @param n1 the first vector
	 * @param n2 the second vector
	 * @return the normalised product, or 0 when it is undefined
	 */
	protected double dot(Node[] n1, Node[] n2) {
		double dot12 = Node.dot(n1, n2);
		double dot11 = Node.dot(n1, n1);
		double dot22 = Node.dot(n2, n2);
		double norm = Math.sqrt(dot11 * dot22);
		if (norm == 0) {
			// 0/0 would be NaN and would poison any later comparison of scores
			return 0;
		}
		return dot12 / norm;
	}

	/**
	 * Sums the products of the values of the nodes that have the same index in both arrays,
	 * logging every matching component at debug level.
	 * <p>
	 * The arrays are walked in parallel, always advancing the side with the smaller index, so
	 * both are expected to be ordered by increasing index. The debug lines are only built when
	 * the debug level is enabled, because formatting them for every match is costly.
	 *
	 * @param x       the first vector
	 * @param y       the second vector
	 * @param termMap map from node index to term, used only to label the debug output
	 * @param form    label written in the debug output
	 * @param page    label written in the debug output
	 * @return the sum of the products of the matching components
	 */
	public static double dot(Node[] x, Node[] y, Map<?, ?> termMap, String form, String page) {
		final boolean debug = logger.isDebugEnabled();
		double sum = 0;
		int xlen = x.length;
		int ylen = y.length;
		int i = 0;
		int j = 0;
		while (i < xlen && j < ylen) {
			if (x[i].index == y[j].index) {
				if (debug) {
					logger.debug(form + "/" + page + "\t" + x[i].index + "\t" + termMap.get(x[i].index) + "\t" + mf.format(x[i].value) + "*" + mf.format(y[j].value) + "=" + mf.format(x[i].value * y[j].value));
				}
				sum += x[i++].value * y[j++].value;
			}
			else if (x[i].index > y[j].index) {
				++j;
			}
			else {
				++i;
			}
		}
		if (debug) {
			logger.debug(form + "/" + page + "\t\t\t" + mf.format(sum));
		}
		return sum;
	}

	/**
	 * Reads an index-to-term table. Every line is trimmed; lines starting with <code>#</code>
	 * are ignored, as are lines that do not split into exactly two fields on
	 * {@link #tabPattern}. The first field is parsed as the integer index, the second is the term.
	 * <p>
	 * The reader is always closed, also when reading or parsing fails.
	 *
	 * @param in the source of the table; closed by this method
	 * @return a map from index to term
	 * @throws IOException           if the source cannot be read
	 * @throws NumberFormatException if the first field of a two-field line is not an integer
	 */
	public static Map<Integer, String> read(Reader in) throws IOException {
		long begin = System.currentTimeMillis();
		logger.info("\n\nreading index - term...");
		Map<Integer, String> termMap = new HashMap<Integer, String>();
		LineNumberReader lnr = new LineNumberReader(in);
		try {
			String line;
			while ((line = lnr.readLine()) != null) {
				line = line.trim();
				if (line.startsWith("#")) {
					continue;
				}
				String[] s = tabPattern.split(line);
				if (s.length == 2) {
					termMap.put(Integer.valueOf(s[0]), s[1]);
				}
			}
		}
		finally {
			lnr.close();
		}

		long end = System.currentTimeMillis();
		logger.debug(termMap.size() + " terms read in " + tf.format(end - begin));
		return termMap;
	}

	//best strategy
	/**
	 * Ranks the candidate senses of a form against a context.
	 * <p>
	 * The entries of the form are retrieved from the searcher, the context is mapped to a
	 * vector with <code>lsi.mapDocument</code>, and one <code>ContextualSense</code> is built
	 * per entry from its page, its frequency, the value of <code>Node.dot</code> between the
	 * context vector and the entry vector, and a constant 0. The senses are then sorted by
	 * decreasing combo score; senses whose score is NaN are placed last so that the ordering
	 * is always consistent.
	 *
	 * @param bow  the bag of words describing the context
	 * @param form the form whose candidate senses are ranked
	 * @return the candidate senses, best first; empty when the searcher returns no entry
	 */
	public Sense[] classify(BOW bow, String form) {
		OneExamplePerSenseSearcher.Entry[] entries = oneExamplePerSenseSearcher.search(form);
		Node[] bowVector = lsi.mapDocument(bow);

		// the trace line is expensive to format, build it only when it will be written
		final boolean trace = logger.isTraceEnabled();
		Sense[] senses = new ContextualSense[entries.length];
		for (int i = 0; i < entries.length; i++) {
			double bowKernel = Node.dot(bowVector, entries[i].getBowVector());
			if (trace) {
				logger.trace(i + "\t" + entries[i].getPage() + "\t" + rf.format(entries[i].getFreq()) + "\t" + rf.format(bowKernel) + "\t" + rf.format(0));
			}
			senses[i] = new ContextualSense(entries[i].getPage(), entries[i].getFreq(), bowKernel, 0);
		}

		Arrays.sort(senses, new Comparator<Sense>() {
			@Override
			public int compare(Sense sense, Sense sense2) {
				double combo = sense.getCombo();
				double combo2 = sense2.getCombo();
				boolean nan = Double.isNaN(combo);
				boolean nan2 = Double.isNaN(combo2);
				if (nan || nan2) {
					// undefined scores go to the end of the ranking
					return nan ? (nan2 ? 0 : 1) : -1;
				}
				// arguments swapped: higher score first
				return Double.compare(combo2, combo);
			}
		}
		);

		return senses;
	}


	/**
	 * Builds a bag of words containing the lower-cased form of every token.
	 *
	 * @param tokenArray the tokens to add
	 * @return the resulting bag of words
	 */
	private BOW createBow(Token[] tokenArray) {
		BOW bag = new BOW();
		for (Token token : tokenArray) {
			bag.add(token.getForm().toLowerCase());
		}
		return bag;
	}

	/**
	 * Builds a bag of words from the context fields of an instance: field 2 is always
	 * tokenized and added; field 4 is added as well when the instance has exactly five fields.
	 * Both fields are lower-cased before tokenization.
	 *
	 * @param s the fields of the instance
	 * @return the resulting bag of words
	 */
	private BOW createBow(String[] s) {
		Tokenizer tokenizer = HardTokenizer.getInstance();
		BOW bag = new BOW();
		bag.addAll(tokenizer.stringArray(s[2].toLowerCase()));
		boolean hasRightContext = s.length == 5;
		if (hasRightContext) {
			bag.addAll(tokenizer.stringArray(s[4].toLowerCase()));
		}
		return bag;
	}

	/**
	 * Reads queries from standard input and logs the ranked senses of each one until the
	 * input ends.
	 * <p>
	 * A query is a tab-separated line: field 0 is the context before the form, field 1 is the
	 * form, and the optional field 2 is appended to the context. Queries without a form are
	 * rejected with a message instead of failing.
	 *
	 * @throws Exception if reading the input or classifying a query fails
	 */
	public void interactive() throws Exception {
		// a single reader for the whole session: a BufferedReader reads ahead, so creating a
		// new one for every query would drop the input already buffered by the previous one
		BufferedReader myInput = new BufferedReader(new InputStreamReader(System.in));
		while (true) {
			System.out.println("\nPlease write a key and type <return> to continue (CTRL C to exit):");

			String query = myInput.readLine();
			if (query == null) {
				// end of input (e.g. piped input exhausted)
				break;
			}
			String[] s = query.split("\t");
			if (s.length < 2) {
				System.out.println("Expected: context<TAB>form[<TAB>context]");
				continue;
			}

			String context = s[0];
			if (s.length > 2) {
				context += " " + s[2];
			}

			String form = s[1];
			HardTokenizer hardTokenizer = new HardTokenizer();
			Token[] tokens = hardTokenizer.tokenArray(context);
			Sense[] sense = classify(tokens, form);
			logger.info("i\tpage\tprior\tbow\tls\tcombo");
			for (int i = 0; i < sense.length; i++) {
				logger.info(i + "\t" + sense[i]);
			}
		}
	}


	/**
	 * Command line entry point.
	 * <p>
	 * log4j is configured from the file named by the <code>log-config</code> system property
	 * (default <code>configuration/log-config.txt</code>). Options:
	 * <ul>
	 * <li><code>-i, --index</code> (required): index opened by the searcher;</li>
	 * <li><code>-l, --lsi</code> (required): directory containing the X-Ut, X-S, X-row, X-col and X-df files;</li>
	 * <li><code>-d, --dim</code>: dimension passed to the model (default 100);</li>
	 * <li><code>--normalized</code>: forces the normalized flag passed to the model to true;</li>
	 * <li><code>--notification-point</code>: value passed to the searcher's <code>setNotificationPoint</code>;</li>
	 * <li><code>-f, --instance-file</code>: classify the instances of the given (uncompressed) file;</li>
	 * <li><code>-t, --interactive-mode</code>: read queries from standard input;</li>
	 * <li><code>-h, --help</code> and <code>-v, --version</code>: print the usage.</li>
	 * </ul>
	 * Any parsing problem, including a non-numeric value for a numeric option, prints the
	 * reason followed by the usage.
	 *
	 * @param args the command line arguments
	 * @throws Exception if loading the resources or classifying fails
	 */
	public static void main(String[] args) throws Exception {
		String logConfig = System.getProperty("log-config");
		if (logConfig == null) {
			logConfig = "configuration/log-config.txt";
		}

		PropertyConfigurator.configure(logConfig);
		Options options = new Options();
		try {
			Option indexNameOpt = OptionBuilder.withArgName("dir").hasArg().withDescription("open an index with the specified name").isRequired().withLongOpt("index").create("i");
			Option interactiveModeOpt = OptionBuilder.withDescription("enter in the interactive mode").withLongOpt("interactive-mode").create("t");
			Option instanceFileOpt = OptionBuilder.withArgName("file").hasArg().withDescription("read the instances to classify from the specified file").withLongOpt("instance-file").create("f");
			Option lsmDirOpt = OptionBuilder.withArgName("dir").hasArg().withDescription("lsi dir").isRequired().withLongOpt("lsi").create("l");
			Option lsmDimOpt = OptionBuilder.withArgName("int").hasArg().withDescription("lsi dim").withLongOpt("dim").create("d");
			Option normalizedOpt = OptionBuilder.withDescription("normalize vectors (default is " + WikipediaExtractor.DEFAULT_NORMALIZE + ")").withLongOpt("normalized").create();
			// this option was read below but never registered, so it could not be given
			Option notificationPointOpt = OptionBuilder.withArgName("int").hasArg().withDescription("notification point passed to the searcher (default is " + Defaults.DEFAULT_NOTIFICATION_POINT + ")").withLongOpt("notification-point").create();

			options.addOption("h", "help", false, "print this message");
			options.addOption("v", "version", false, "output version information and exit");

			options.addOption(indexNameOpt);
			options.addOption(interactiveModeOpt);
			options.addOption(instanceFileOpt);
			options.addOption(lsmDirOpt);
			options.addOption(lsmDimOpt);
			options.addOption(normalizedOpt);
			options.addOption(notificationPointOpt);

			CommandLineParser parser = new PosixParser();
			CommandLine line = parser.parse(options, args);

			if (line.hasOption("help") || line.hasOption("version")) {
				// an empty message makes the handler below print the usage only
				throw new ParseException("");
			}

			int notificationPoint = Defaults.DEFAULT_NOTIFICATION_POINT;
			if (line.hasOption("notification-point")) {
				notificationPoint = parseIntOption(line, "notification-point");
			}

			String lsmDirName = line.getOptionValue("lsi");
			if (!lsmDirName.endsWith(File.separator)) {
				lsmDirName += File.separator;
			}

			boolean normalized = WikipediaExtractor.DEFAULT_NORMALIZE;
			if (line.hasOption("normalized")) {
				normalized = true;
			}

			File fileUt = new File(lsmDirName + "X-Ut");
			File fileSk = new File(lsmDirName + "X-S");
			File fileR = new File(lsmDirName + "X-row");
			File fileC = new File(lsmDirName + "X-col");
			File fileDf = new File(lsmDirName + "X-df");
			int dim = 100;
			if (line.hasOption("dim")) {
				dim = parseIntOption(line, "dim");
			}
			logger.debug(line.getOptionValue("lsi") + "\t" + line.getOptionValue("dim"));
			try {
				// the table itself is not used here; the call is kept for the log lines it writes
				read(new InputStreamReader(new FileInputStream(fileR), "UTF-8"));
			} catch (IOException e) {
				logger.error("cannot read " + fileR, e);
			}

			LSI lsi = new LSI(fileUt, fileSk, fileR, fileC, fileDf, dim, true, normalized);
			OneExamplePerSenseSearcher oneExamplePerSenseSearcher = new OneExamplePerSenseSearcher(line.getOptionValue("index"));
			oneExamplePerSenseSearcher.setNotificationPoint(notificationPoint);


			if (line.hasOption("instance-file")) {
				OneExamplePerSenseBowClassifier oneExamplePerSenseClassifier = new OneExamplePerSenseBowClassifier(lsi, oneExamplePerSenseSearcher);
				oneExamplePerSenseClassifier.classify(new File(line.getOptionValue("instance-file")), false);
			}

			if (line.hasOption("interactive-mode")) {
				OneExamplePerSenseBowClassifier oneExamplePerSenseClassifier = new OneExamplePerSenseBowClassifier(lsi, oneExamplePerSenseSearcher);
				oneExamplePerSenseClassifier.interactive();
			}
		} catch (ParseException e) {
			// oops, something went wrong
			String message = e.getMessage();
			if (message != null && message.length() > 0) {
				System.out.println("Parsing failed: " + message + "\n");
			}
			HelpFormatter formatter = new HelpFormatter();
			formatter.printHelp(400, "java -cp dist/thewikimachine.jar eu.fbk.twm.classifier.OneExamplePerSenseBowClassifier", "\n", options, "\n", true);
		}
	}

	/**
	 * Parses the value of an integer option, reporting a malformed value as a
	 * <code>ParseException</code> so that it is handled like any other command line error.
	 */
	private static int parseIntOption(CommandLine line, String option) throws ParseException {
		String value = line.getOptionValue(option);
		try {
			return Integer.parseInt(value);
		} catch (NumberFormatException e) {
			ParseException failure = new ParseException("option --" + option + " expects an integer, found '" + value + "'");
			failure.initCause(e);
			throw failure;
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy