All Downloads are FREE. Search and download functionalities are using the official Maven repository.

eu.fbk.twm.classifier.NGramOneExamplePerSenseExtractor Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (2014) Fondazione Bruno Kessler (http://www.fbk.eu/)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package eu.fbk.twm.classifier;

import eu.fbk.utils.core.core.HashMultiSet;
import eu.fbk.utils.core.core.MultiSet;
import eu.fbk.utils.lsa.BOW;
import eu.fbk.utils.math.Node;
import org.apache.commons.cli.*;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import eu.fbk.twm.utils.analysis.HardTokenizer;
import eu.fbk.twm.utils.analysis.Tokenizer;
import eu.fbk.twm.utils.CharacterTable;

import java.io.*;
import java.text.DecimalFormat;
import java.util.*;

/**
 * Created with IntelliJ IDEA.
 * User: giuliano
 * Date: 1/6/14
 * Time: 7:59 AM
 * To change this template use File | Settings | File Templates.
 */
public class NGramOneExamplePerSenseExtractor extends OneExamplePerSenseExtractor {

    /**
     * Define a static logger variable so that it references the
     * Logger instance named NGramOneExamplePerSenseExtractor.
     */
    static Logger logger = Logger.getLogger(NGramOneExamplePerSenseExtractor.class.getName());

    DecimalFormat vf = new DecimalFormat("###0");

    DecimalFormat sf = new DecimalFormat(".000");

    public static final int DEFAULT_N_GRAM_LENGTH = 3;

    public static final int DEFAULT_N_GRAM_SIZE = 1000000;

    private static DecimalFormat df = new DecimalFormat("#.0000");

    private NGramModel nGramModel;

    protected int totalFreq;

    public NGramOneExamplePerSenseExtractor(String outputFileName, int numThreads, NGramModel nGramModel)
            throws IOException {
        this(new File(outputFileName), numThreads, nGramModel);
    }

    public NGramOneExamplePerSenseExtractor(File outputFile, int numThreads, NGramModel nGramModel) throws IOException {
        super(outputFile, numThreads);
        this.nGramModel = nGramModel;
        totalFreq = 0;
    }

    public void interactive() {
        InputStreamReader reader = null;
        BufferedReader myInput = null;
        while (true) {
            System.out.println("\nPlease write a key and type  to continue (CTRL C to exit):");

            reader = new InputStreamReader(System.in);
            myInput = new BufferedReader(reader);
            //String query = null;

            try {
                String query = myInput.readLine().toString();
                String[] s = query.split("\t");
                Sense[] senses = classify(s);
                if (s.length == 5) {
                    String answerPage = "null";
                    if (senses != null && senses.length > 0) {
                        answerPage = senses[0].getPage();
                    }
                    if (answerPage.equals(s[0])) {
                        logger.info(answerPage + " = " + s[0]);
                    } else {
                        logger.warn(answerPage + " != " + s[0]);
                    }
                }

            } catch (IOException e) {
                logger.error(e);
                e.printStackTrace();
            }
        }

    }

    public void eval(String name) throws IOException {
        LineNumberReader lnr = new LineNumberReader(new InputStreamReader(new FileInputStream(name), "UTF-8"));
        double tp = 0, fp = 0, fn = 0, tn = 0, tot = 0;
        double p = 0, r = 0, f1 = 0, acc = 0, correct = 0;

        String line;

        while ((line = lnr.readLine()) != null) {
            String[] s = line.split("\t");

            if (s.length == 5) {
                long b = System.currentTimeMillis();
                Sense[] senses = classify(s);
                long e = System.currentTimeMillis();
                String answerPage = "null";
                if (senses != null && senses.length > 0) {
                    answerPage = senses[0].getPage();
                }
                if (answerPage.equals(s[0])) {
                    logger.info(answerPage + " = " + s[0]);
                } else {
                    logger.warn(answerPage + " != " + s[0]);
                }

                if (answerPage == null) {
                    if (s[0].equals("Null_result")) {
                        tp++;
                        correct++;
                    } else {
                        fn++;
                    }
                } else {

                    if (answerPage.equals(s[0])) {
                        tp++;
                        correct++;
                    } else {
                        fn++;
                        fp++;
                    }
                }
                tot++;
                p = tp / (tp + fp);
                r = tp / (tp + fn);
                f1 = (2 * p * r) / (p + r);
                System.out.println(
                        s[1] + "\t" + vf.format(tp) + "\t" + vf.format(fp) + "\t" + vf.format(fn) + "\t" + sf.format(p)
                                + "\t" + sf.format(r) + "\t" + sf.format(f1) + "\t" + answerPage + "\t" + s[0] + "\t"
                                + s[3] + "\t" + df.format(e - b));
            }
        }

    }

    private Sense[] classify(String[] s) {
        String[] e = new String[9];

        if (s.length == 2) {
            e[3] = "null";
            e[2] = s[1];
            e[7] = s[0];
            e[8] = ".";
        } else if (s.length == 3) {
            e[3] = "null";
            e[2] = s[1];
            e[7] = s[0];
            e[8] = s[2];
        } else if (s.length == 5) {
            e[3] = s[0];
            e[2] = s[3];
            e[7] = s[2];
            e[8] = s[4];
        } else {
            logger.error(s.length);
            return null;
        }
        Tokenizer tokenizer = new HardTokenizer();
        e[2] = tokenizer.tokenizedString(e[2]);
        e[7] = tokenizer.tokenizedString(e[7]);
        e[8] = tokenizer.tokenizedString(e[8]);
        logger.debug("2='" + e[2] + "'");
        logger.debug("7='" + e[7] + "'");
        logger.debug("8='" + e[8] + "'");
        logger.debug(Arrays.toString(s));
        logger.debug(Arrays.toString(e));
        List list = new ArrayList();
        list.add(e);
        Example example = new Example(e[3], list, e[2]);
        logger.trace(Arrays.toString(example.getLocalContextVector()));
        logger.trace(Arrays.toString(example.getBowVector()));
        Example[] examples = map.get(e[2]);
        Sense[] senses = null;
        if (examples != null) {
            senses = new ContextualSense[examples.length];
            for (int i = 0; i < examples.length; i++) {
                int freq = examples[i].getFreq();
                if (freq > 2) {
                    logger.trace(i + "\t" + freq + "\t" + Arrays.toString(examples[i].getLocalContextVector()));
                    logger.trace(i + "\t" + freq + "\t" + Arrays.toString(examples[i].getBowVector()));
                    double localDot = 0;//Node.dot(example.getLocalContextVector(), examples[i].getLocalContextVector());
                    double bowDot = Node.dot(example.getBowVector(), examples[i].getBowVector());
                    logger.info(
                            i + "\t" + freq + "\t" + df.format(localDot) + "\t" + df.format(bowDot) + "\t" + examples[i]
                                    .getPage());
                    senses[i] = new ContextualSense(examples[i].getPage(), examples[i].getFreq(), localDot, bowDot);
                } else {
                    senses[i] = new ContextualSense(examples[i].getPage(), examples[i].getFreq(), 0, 0);
                }
            }
            Arrays.sort(senses, new Comparator() {

                        @Override
                        public int compare(Sense sense, Sense sense2) {
                            double diff = sense.getCombo() - sense2.getCombo();
                            if (diff > 0) {
                                return -1;
                            } else if (diff < 0) {
                                return 1;
                            }
                            return 0;
                        }
                    }
            );

            for (int i = 0; i < senses.length && i < 5; i++) {
                logger.info(i + "\t" + senses[i].getPage() + "\t" + senses[i].getCombo());
            }

        }
        return senses;
    }

    @Override
    public void end() {
        //interactive();
    }

    Map map = new HashMap();

    @Override
    public void buildExamples(Map> senseMap, String form) {

        Example[] examples = new Example[senseMap.size()];
        Iterator it = senseMap.keySet().iterator();
        String page;
        List list;
        for (int i = 0; it.hasNext(); i++) {
            page = it.next();
            list = senseMap.get(page);
            //logger.debug(form + "\t" + page + "\t" + list.size());

            examples[i] = new Example(page, list, form);
        }
        map.put(form, examples);
    }

    class Example implements Comparable {

        private BOW bow;

        private eu.fbk.utils.math.Node[] bowVector;

        private eu.fbk.utils.math.Node[] localContextVector;

        private int freq;

        private String page;
        private String form;

        Example(String page, List list, String form) {
            this.page = page;
            this.form = form;
            freq = list.size();
            totalFreq += freq;
            StringBuilder sb = new StringBuilder();
            bow = buildBOW(list);

            //logger.debug(bow);
            bowVector = buildBowVector(bow);
            //bowVector = new Node[0];
            localContextVector = buildLocalContext(list);
            //logger.debug(Arrays.toString(localContextVector));

            if (normalized) {
                Node.normalize(bowVector);
                Node.normalize(localContextVector);
            }

            //logger.debug(toString());

        }

        private Node[] buildLocalContext(List list) {
            //Map nodes = new TreeSet();
            MultiSet multiSet = new HashMultiSet();
            for (int i = 0; i < list.size(); i++) {
                String[] s = list.get(i);
                String leftContext = s[eu.fbk.twm.index.csv.OneExamplePerSenseExtractor.LEFT_CONTEXT_INDEX];
                String form = s[eu.fbk.twm.index.csv.OneExamplePerSenseExtractor.FORM_INDEX];
                int pos = leftContext.lastIndexOf(' ');
                if (pos != -1) {
                    String t = leftContext.substring(pos + 1, leftContext.length());
                    //logger.trace(i + "\t" + t + "\t" + form + "\t" + page);
                    multiSet.add(t);
                } else {
                    multiSet.add(leftContext);
                }
            }
            //logger.debug(multiSet.toSortedMap());
            SortedSet nodes = new TreeSet();
            Iterator it = multiSet.iterator();
            for (int i = 0; it.hasNext(); i++) {
                String t = it.next();
                Integer index = nGramModel.getIndex(t);
                if (index != null) {
                    int value = multiSet.getFrequency(t);
                    //logger.trace(i + "\t'" + t + "'\t" + index + ":" + value + "\t" + form + "\t" + page);
                    nodes.add(new Node(index, value));
                }
            }

            return nodes.toArray(new Node[nodes.size()]);
        }

        private boolean isCapitalized(String w) {
            if (w.length() == 0) {
                return false;
            }
            return Character.isUpperCase(w.charAt(0));
        }

        private boolean isUpperCase(String w) {
            if (w.length() == 0) {
                return false;
            }
            for (int i = 0; i < w.length(); i++) {
                if (Character.isLowerCase(w.charAt(i))) {
                    return false;
                }
            }
            return true;
        }

        Node[] getLocalContextVector() {
            return localContextVector;
        }

        private BOW buildBOW(List list) {
            bow = new BOW(tfType);
            String[] s;
            String[] leftContext;
            String[] rightContext;
            for (int i = 0; i < list.size(); i++) {
                try {
                    s = list.get(i);
                    //todo: add toLowerCase
                    //leftContext = spacePattern.split(s[WikipediaExampleExtractor.LEFT_CONTEXT_INDEX].toLowerCase());
                    leftContext = spacePattern.split(s[eu.fbk.twm.index.csv.OneExamplePerSenseExtractor.LEFT_CONTEXT_INDEX]);
                    //rightContext = spacePattern.split(s[WikipediaExampleExtractor.RIGHT_CONTEXT_INDEX].toLowerCase());
                    rightContext = spacePattern.split(s[eu.fbk.twm.index.csv.OneExamplePerSenseExtractor.RIGHT_CONTEXT_INDEX]);
                    //logger.debug(i + "\t" + s[WikipediaExampleExtractor.LEFT_CONTEXT_INDEX] + "\t<"+form+">\t" + s[WikipediaExampleExtractor.RIGHT_CONTEXT_INDEX]);

                    extract(leftContext, bow);
                    extract(rightContext, bow);

                } catch (Exception e) {
                    logger.error("Error at " + i);
                    logger.error(e);
                }
            }
            return bow;
        }

        private Node[] buildBowVector(BOW bow) {
            SortedSet nodes = new TreeSet();
            Iterator it = bow.iterator();
            for (int i = 0; it.hasNext(); i++) {
                String form = it.next();

                Integer index = nGramModel.getIndex(form);
                Double value = nGramModel.getValue(form);

                if (index != null && value != null) {
                    double tf = bow.tf(form);
                    //logger.debug(i + "\t" + form + "\t"+page+ "\t" + index + ":" + value + " * " + tf);
                    nodes.add(new Node(index, value * tf));
                }
            }
            return nodes.toArray(new Node[nodes.size()]);
        }

        private String tokenizedForm(String[] tokenArray, int start, int end) {
            StringBuilder sb = new StringBuilder();
            sb.append(tokenArray[start]);
            for (int i = start + 1; i <= end; i++) {
                sb.append(CharacterTable.SPACE);
                sb.append(tokenArray[i]);
            }
            return sb.toString();
        }

        public void extract(String[] tokenArray, BOW bow) {
            int m = 0;
            String tokenizedForm;
            for (int i = 0; i < tokenArray.length; i++) {
                m = i + nGramModel.getLength();// + 1;
                if (m > tokenArray.length) {
                    m = tokenArray.length;
                }
                for (int j = i; j < m; j++) {
                    tokenizedForm = tokenizedForm(tokenArray, i, j);
                    bow.add(tokenizedForm);
                }
            }

        }

        String getForm() {
            return form;
        }

        public String getPage() {
            return page;
        }

        public BOW getBow() {
            return bow;
        }

        public Node[] getBowVector() {
            return bowVector;
        }

        public int getFreq() {
            return freq;
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(page);
            sb.append(CharacterTable.HORIZONTAL_TABULATION);
            sb.append((double) freq / totalFreq);
            sb.append(CharacterTable.HORIZONTAL_TABULATION);
            sb.append(Node.toString(localContextVector));
            sb.append(CharacterTable.HORIZONTAL_TABULATION);
            sb.append(Node.toString(bowVector));

            return sb.toString();
        }

        @Override
        public int compareTo(Example example) {
            return example.getFreq() - freq;
        }
    }

    private void writeExampleArray(Example[] examples) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < examples.length; i++) {
            sb.append(examples[i].getForm());
            sb.append(CharacterTable.HORIZONTAL_TABULATION);
            sb.append(examples[i]);
            sb.append(CharacterTable.LINE_FEED);
        }
        synchronized (this) {
            senseWriter.print(sb.toString());
        }
    }

    public static void main(String[] args) {
        // java com.ml.test.net.HttpServerDemo
        String logConfig = System.getProperty("log-config");
        if (logConfig == null) {
            logConfig = "configuration/log-config.txt";
        }

        //PropertyConfigurator.configure(logConfig);

        Options options = new Options();
        try {
            Option inputFileNameOpt = OptionBuilder.withArgName("file").hasArg()
                    .withDescription("sorted form/page file").isRequired().withLongOpt("input").create("i");
            Option outputFileNameOpt = OptionBuilder.withArgName("file").hasArg()
                    .withDescription("one sense per example file").isRequired().withLongOpt("output").create("o");
            Option tfOpt = OptionBuilder.withArgName("FUNC").hasArg().withDescription(
                    "term frequency function; FUNC is " + BOW.RAW_TERM_FREQUENCY + "=`"
                            + BOW.labels[BOW.RAW_TERM_FREQUENCY] + BOW.BOOLEAN_TERM_FREQUENCY + "=`"
                            + BOW.labels[BOW.BOOLEAN_TERM_FREQUENCY] + "'," + BOW.LOGARITHMIC_TERM_FREQUENCY + "=`"
                            + BOW.labels[BOW.LOGARITHMIC_TERM_FREQUENCY] + "'," + BOW.AUGMENTED_TERM_FREQUENCY + "=`"
                            + BOW.labels[BOW.AUGMENTED_TERM_FREQUENCY] + " (default is "
                            + BOW.DEFAULT_TERM_FREQUENCY_TYPE + ")").withLongOpt("tf").create();
            Option stopwordsFileNameOpt = OptionBuilder.withArgName("file").hasArg().withDescription("stopwords")
                    .isRequired().withLongOpt("stopwords").create();
            Option numFormOpt = OptionBuilder.withArgName("int").hasArg()
                    .withDescription("maximum number of forms to process (default is all)").withLongOpt("num-forms")
                    .create("f");
            Option numThreadOpt = OptionBuilder.withArgName("int").hasArg()
                    .withDescription("number of threads (default " + DEFAULT_THREADS_NUMBER + ")")
                    .withLongOpt("num-threads").create("t");
            Option formIdFileNameOpt = OptionBuilder.withArgName("file").hasArg().withDescription("form id mapping")
                    .isRequired().withLongOpt("form-id").create();
            Option ngramFileNameOpt = OptionBuilder.withArgName("file").hasArg().withDescription("form idf mapping")
                    .isRequired().withLongOpt("form-idf").create();
            Option nGramLengthOpt = OptionBuilder.withArgName("int").hasArg()
                    .withDescription("n-gram length (default is " + DEFAULT_N_GRAM_LENGTH + ")")
                    .withLongOpt("ngram-length").create("l");
            Option nGramSizeOpt = OptionBuilder.withArgName("int").hasArg()
                    .withDescription("n-gram length (default is " + DEFAULT_N_GRAM_SIZE + ")").withLongOpt("ngram-size")
                    .create("s");
            Option traceOpt = OptionBuilder.withDescription("trace mode").withLongOpt("trace").create();
            Option debugOpt = OptionBuilder.withDescription("debug mode").withLongOpt("debug").create();
            options.addOption(
                    OptionBuilder.withDescription("enter in the interactive mode").withLongOpt("interactive-mode")
                            .create());
            options.addOption(OptionBuilder.withArgName("file").hasArg()
                    .withDescription("evaluation file in tsv format (page\\tid\\tleft\\tterm\\tright)").isRequired()
                    .withLongOpt("eval-file").create());
            Option normalizedOpt = OptionBuilder
                    .withDescription("normalize vectors (default is " + DEFAULT_NORMALIZE + ")")
                    .withLongOpt("normalized").create("n");
            Option notificationPointOpt = OptionBuilder.withArgName("int").hasArg().withDescription(
                    "receive notification every n pages (default is " + DEFAULT_NOTIFICATION_POINT + ")")
                    .withLongOpt("notification-point").create("b");

            options.addOption("h", "help", false, "print this message");
            options.addOption("v", "version", false, "output version information and exit");

            options.addOption(inputFileNameOpt);
            options.addOption(tfOpt);
            options.addOption(outputFileNameOpt);
            options.addOption(numThreadOpt);
            options.addOption(notificationPointOpt);
            options.addOption(numFormOpt);
            options.addOption(formIdFileNameOpt);
            options.addOption(ngramFileNameOpt);
            options.addOption(nGramLengthOpt);
            options.addOption(nGramSizeOpt);
            options.addOption(normalizedOpt);
            options.addOption(stopwordsFileNameOpt);
            options.addOption(traceOpt);
            options.addOption(debugOpt);

            CommandLineParser parser = new PosixParser();
            CommandLine line = parser.parse(options, args);

            Properties defaultProps = new Properties();
            defaultProps.load(new InputStreamReader(new FileInputStream(logConfig), "UTF-8"));
            //defaultProps.setProperty("log4j.rootLogger", "info,stdout");
            if (line.hasOption("trace")) {
                defaultProps.setProperty("log4j.rootLogger", "trace,stdout");
            } else if (line.hasOption("debug")) {
                defaultProps.setProperty("log4j.rootLogger", "debug,stdout");
            } else {
                defaultProps.setProperty("log4j.rootLogger", "info,stdout");
            }
            PropertyConfigurator.configure(defaultProps);

            logger.debug(options);

            logger.debug(line.getOptionValue("output") + "\t" + line.getOptionValue("input") + "\t" + line
                    .getOptionValue("lsm"));

            boolean normalized = false;
            if (line.hasOption("normalized")) {
                normalized = true;
            }

            int nGramSize = DEFAULT_N_GRAM_SIZE;
            if (line.hasOption("ngram-size")) {
                nGramSize = Integer.parseInt(line.getOptionValue("ngram-size"));
            }

            int nGramLength = DEFAULT_N_GRAM_LENGTH;
            if (line.hasOption("ngram-length")) {
                nGramLength = Integer.parseInt(line.getOptionValue("ngram-length"));
            }
            NGramModel nGramModel = new NGramModel(line.getOptionValue("form-id"), line.getOptionValue("form-idf"),
                    line.getOptionValue("stopwords"));

            int numThreads = DEFAULT_THREADS_NUMBER;
            if (line.hasOption("num-threads")) {
                numThreads = Integer.parseInt(line.getOptionValue("num-threads"));
            }

            int minimumFormFreq = DEFAULT_MINIMUM_FORM_FREQ;
            if (line.hasOption("min-freq")) {
                minimumFormFreq = Integer.parseInt(line.getOptionValue("min-freq"));
            }

            int minimumPageFreq = DEFAULT_MINIMUM_PAGE_FREQ;
            if (line.hasOption("min-page")) {
                minimumPageFreq = Integer.parseInt(line.getOptionValue("min-page"));
            }

            int numForms = DEFAULT_NUM_FORMS;
            if (line.hasOption("num-forms")) {
                numForms = Integer.parseInt(line.getOptionValue("num-forms"));
            }

            int notificationPoint = DEFAULT_NOTIFICATION_POINT;
            if (line.hasOption("notification-point")) {
                notificationPoint = Integer.parseInt(line.getOptionValue("notification-point"));
            }

            int tfType = BOW.DEFAULT_TERM_FREQUENCY_TYPE;
            if (line.hasOption("tf")) {
                tfType = Integer.parseInt(line.getOptionValue("tf"));
            }

            logger.info("extracting one example per sense using " + numThreads + " threads");
            NGramOneExamplePerSenseExtractor oneExamplePerSenseExtractor = new NGramOneExamplePerSenseExtractor(
                    line.getOptionValue("output"), numThreads, nGramModel);
            oneExamplePerSenseExtractor.setNormalized(normalized);
            oneExamplePerSenseExtractor.setTfType(tfType);
            oneExamplePerSenseExtractor.setNotificationPoint(notificationPoint);
            oneExamplePerSenseExtractor.setNumForms(numForms);
            oneExamplePerSenseExtractor.extract(line.getOptionValue("input"));
            if (line.hasOption("interactive-mode")) {
                oneExamplePerSenseExtractor.interactive();
            }
            if (line.hasOption("eval-file")) {
                oneExamplePerSenseExtractor.eval(line.getOptionValue("eval-file"));
            }

        } catch (ParseException e) {
            // oops, something went wrong
            System.err.println("Parsing failed: " + e.getMessage() + "\n");
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp(200,
                    "java -cp dist/thewikimachine.jar eu.fbk.twm.classifier.NGramOneExamplePerSenseExtractor",
                    "\n", options, "\n", true);
        } catch (Exception e) {
            logger.error(e);
            e.printStackTrace();
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy