All downloads are free. Search and download functionality uses the official Maven repository.

org.wikibrain.sr.phrasesim.PhraseSimEvaluator Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.wikibrain.sr.phrasesim;

import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.cli.*;
import org.apache.commons.io.FileUtils;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.DefaultOptionBuilder;
import org.wikibrain.core.cmd.Env;
import org.wikibrain.core.cmd.EnvBuilder;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;

import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @author Shilad Sen
 */
public class PhraseSimEvaluator {
    private boolean debug = true;
    private static List> TEST_BUNDLES = Arrays.asList(
            makeSet("jazz music blues"),
            makeSet("music math statistics"),
            makeSet("music brain"),
            makeSet("brain mind"),
            makeSet("brain statistics algorithm")
    );
    private int k = 10;

    private final Env env;

    public PhraseSimEvaluator(Env env) {
        this.env = env;
    }

    public void evaluate(final List> bundles) throws ConfigurationException, IOException {
        String lc = env.getDefaultLanguage().getLangCode();
        File dir = FileUtils.getFile(env.getBaseDir(), "dat/sr/known-phrase/en");
        FileUtils.deleteQuietly(dir);

        final KnownPhraseSim sim = (KnownPhraseSim) env.getConfigurator().get(SRMetric.class, "known-phrase", "language", lc);
        if (!sim.getDataDir().equals(dir)) {
            throw new IllegalStateException("expected dir " + dir + ", found " + sim.getDataDir());
        }
        final Map ids = new ConcurrentHashMap();
        ParallelForEach.loop(bundles, new Procedure>() {
            @Override
            public void call(List bundle) throws Exception {
                for (String phrase : bundle) {
                    String s = sim.normalize(phrase);
                    if (!ids.containsKey(s)) {
                        ids.put(s, ids.size());
                    }
                    int id = ids.get(s);
                    sim.addPhrase(phrase, id);
                }
            }
        });

        sim.flushCosimilarity();
        sim.trainNormalizer();

        int numSamples = 0;
        int numSampleHits = 0;
        int numRecommended = 0;
        int numRecommendedHits = 0;
        int possible = 0;
        int numErrors = 0;

        long before = System.currentTimeMillis();
        Random rand = new Random();
        for (int i = 0; i < 1000; i++) {
            // Select a random bundle
            List bundle = bundles.get(rand.nextInt(bundles.size()));
            if (bundle.isEmpty()) {
                continue;
            }
            numSamples++;
            TIntSet bundleIds = new TIntHashSet();
            for (String p : bundle) {
                bundleIds.add(ids.get(sim.normalize(p)));
            }
            String target = bundle.iterator().next();
            int targetId = ids.get(sim.normalize(target));
            int j = 0;
            boolean hasHit = false;
            StringBuffer line = new StringBuffer(target).append(": ");
            SRResultList neighbors = sim.mostSimilar(target, k + 1);
            if (neighbors == null) {
                numErrors++;
                continue;
            }
            for (SRResult r : neighbors) {
                if (r.getId() != targetId) {
                    if (this.debug) line.append(
                            String.format("%s %.3f, ",
                                    sim.getPhrase(r.getId()), r.getScore()));
                    if (bundleIds.contains(r.getId())) {
                        hasHit = true;
                        numRecommendedHits++;
                    }
                    numRecommended++;
                    if (++j >= k) {
                        break;
                    }
                }
            }
            if (this.debug) System.out.println(line);
            possible += bundleIds.size();
            if (bundleIds.contains(targetId)) {
                possible--;
            }
            if (hasHit) {
                numSampleHits++;
            }
        }
        long after = System.currentTimeMillis();

        System.out.println("for " + bundles.size() + ", top " + k);
        System.out.println("Total samples: " + numSamples);
        System.out.println("Total errors: " + numErrors);
        System.out.println("Total seconds: " + ((after - before) / 1000.0));
        System.out.println("Total samples with hits: " + numSampleHits);
        System.out.println("Total related items: " + numRecommended);
        System.out.println("Total related items with hits: " + numRecommendedHits);
        System.out.println("Precision: " + (1.0 * numRecommendedHits / numRecommended));
        System.out.println("Recall: " + (1.0 * numRecommendedHits / possible));
    }

    public void setTopK(int k) {
        this.k = k;
    }

    static private List makeSet(String line) {
        return new ArrayList(Arrays.asList(line.split(" ")));
    }

    public static List> readBundles(File f) throws IOException {
        List> bundles = new ArrayList>();
        for (String line : FileUtils.readLines(f)) {
            List bundle = new ArrayList();
            for (String token : line.split("\t")) {
                bundle.add(token.trim());
            }
            if (bundle.size() >= 2) {
                bundles.add(bundle);
            }
        }
        return bundles;
    }

    public static void main(String args[]) throws ConfigurationException, IOException {
        Options options = new Options();
        options.addOption(
                new DefaultOptionBuilder()
                        .withLongOpt("bundles")
                        .withDescription("bundle file with tab separated phrases")
                        .hasArg()
                        .create("b"));
        options.addOption(
                new DefaultOptionBuilder()
                        .withLongOpt("topk")
                        .withDescription("number neighbors per phrase")
                        .hasArg()
                        .create("k"));
        EnvBuilder.addStandardOptions(options);

        CommandLineParser parser = new PosixParser();
        CommandLine cmd;
        try {
            cmd = parser.parse(options, args);
        } catch (ParseException e) {
            System.err.println( "Invalid option usage: " + e.getMessage());
            new HelpFormatter().printHelp("PhraseSimEvaluator", options);
            return;
        }

        Env env = new EnvBuilder(cmd).build();

        PhraseSimEvaluator eval = new PhraseSimEvaluator(env);
        List> bundles;

        if (cmd.hasOption("b")) {
            bundles = readBundles(new File(cmd.getOptionValue("b")));
        } else {
            bundles = TEST_BUNDLES;
        }

        if (cmd.hasOption("k")) {
            eval.setTopK(Integer.parseInt(cmd.getOptionValue("k")));
        }
        eval.evaluate(bundles);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy