All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.wikibrain.sr.MetricTrainer Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.wikibrain.sr;

import org.apache.commons.cli.*;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.io.FileUtils;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.conf.DefaultOptionBuilder;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.cmd.Env;
import org.wikibrain.core.cmd.EnvBuilder;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LanguageSet;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.dataset.DatasetDao;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

/**
 * Command-line entry point that trains a semantic relatedness metric
 * ({@link SRMetric}) on one or more gold standard datasets.
 *
 * <p>Options registered here:
 * <ul>
 *   <li>{@code -m / --metric}: name of the metric to train</li>
 *   <li>{@code -g / --gold}: one or more dataset names; when absent the
 *       configuration key {@code sr.dataset.defaultsets} is used</li>
 *   <li>{@code -r / --max-results}: integer; when absent the configuration key
 *       {@code sr.normalizer.defaultmaxresults} is used</li>
 * </ul>
 * plus whatever {@link EnvBuilder#addStandardOptions(Options)} adds.
 *
 * @author Matt Lesciko
 * @author Ben Hillmann
 */
public class MetricTrainer {

    /** Program name shown in the usage/help output. */
    private static final String PROGRAM_NAME = "MetricTrainer";

    /**
     * Parses the command line, loads the requested datasets and trains the
     * requested metric on them.
     *
     * <p>Usage errors (unparseable options, no metric given) print the help
     * text and return. An invalid {@code -r} value, a dataset whose language
     * cannot be resolved unambiguously, datasets in differing languages, or an
     * empty dataset list terminate the JVM with exit status 1.
     *
     * @param args raw command-line arguments
     * @throws ConfigurationException propagated from the environment/configurator
     * @throws DaoException propagated from the dataset DAO
     * @throws IOException propagated from deleting the normalizer directory or metric I/O
     * @throws WikiBrainException propagated from the WikiBrain calls made here
     */
    public static void main(String[] args) throws ConfigurationException, DaoException, IOException, WikiBrainException {
        Options options = buildOptions();

        CommandLineParser parser = new PosixParser();
        CommandLine cmd;
        try {
            cmd = parser.parse(options, args);
        } catch (ParseException e) {
            System.err.println("Invalid option usage: " + e.getMessage());
            new HelpFormatter().printHelp(PROGRAM_NAME, options);
            return;
        }

        // Validate arguments before building the environment so that usage
        // mistakes are reported without first loading the configuration.
        // NOTE(review): no "u" option is registered in this class; presumably it
        // comes from EnvBuilder.addStandardOptions -- confirm. When only -u is
        // given, no metric is created or trained below.
        if (!cmd.hasOption("m") && !cmd.hasOption("u")) {
            System.err.println("Must specify a metric to train using -m or -u.");
            new HelpFormatter().printHelp(PROGRAM_NAME, options);
            return;
        }

        // Parse -r eagerly so a malformed value yields a readable message
        // instead of an uncaught NumberFormatException.
        Integer requestedMaxResults = null;
        if (cmd.hasOption("r")) {
            String rawMaxResults = cmd.getOptionValue("r");
            try {
                requestedMaxResults = Integer.valueOf(rawMaxResults);
            } catch (NumberFormatException e) {
                System.err.println("Invalid value for -r/--max-results (expected an integer): " + rawMaxResults);
                new HelpFormatter().printHelp(PROGRAM_NAME, options);
                System.exit(1);
            }
        }

        Env env = new EnvBuilder(cmd)
                        .setProperty("sr.metric.training", true)
                        .build();
        Configurator c = env.getConfigurator();

        int maxResults = (requestedMaxResults != null)
                ? requestedMaxResults.intValue()
                : c.getConf().get().getInt("sr.normalizer.defaultmaxresults");

        String path = c.getConf().get().getString("sr.metric.path");
        LanguageSet allLangs = env.getLanguages();

        DatasetDao datasetDao = env.getConfigurator().get(DatasetDao.class);

        // Dataset names come from -g when given, otherwise from the configuration.
        List<String> datasetNames;
        if (cmd.hasOption("g")) {
            datasetNames = Arrays.asList(cmd.getOptionValues("g"));
        } else {
            datasetNames = c.getConf().get().getStringList("sr.dataset.defaultsets");
        }

        List<Dataset> datasets = new ArrayList<Dataset>();
        for (String name : datasetNames) {
            DatasetDao.Info info = datasetDao.getInfo(name);

            // CollectionUtils.intersection returns a raw Collection; both inputs
            // are collections of Language, so the elements are Language too.
            @SuppressWarnings("unchecked")
            Collection<Language> possibleLang = CollectionUtils.intersection(
                    info.getLanguages().getLanguages(),
                    allLangs.getLanguages());

            // Exactly one language must be shared between the dataset and the
            // environment; zero or several are both fatal.
            if (possibleLang.isEmpty()) {
                System.err.println("dataset " + name + " is a language other than " + allLangs);
                System.exit(1);
            }
            if (possibleLang.size() > 1) {
                System.err.println("dataset " + name + " supports more than one language of " + allLangs + " please specify");
                System.exit(1);
            }
            Language lang = possibleLang.iterator().next();

            // All datasets must share the language of the first one loaded.
            if (!datasets.isEmpty() && !lang.equals(datasets.get(0).getLanguage())) {
                System.err.println("Language mismatch in datasets " + name + " and " + datasets.get(0).getName());
                System.exit(1);
            }
            datasets.add(datasetDao.get(lang, name));
        }

        SRMetric sr = null;
        if (cmd.hasOption("m")) {
            // The metric's language is taken from the first dataset, so at
            // least one dataset is required.
            if (datasets.isEmpty()) {
                System.err.println("No datasets to train on; specify some with -g or configure sr.dataset.defaultsets.");
                System.exit(1);
            }
            String metricName = cmd.getOptionValue("m");
            Language language = datasets.get(0).getLanguage();

            // Remove the metric's existing normalizer directory before the metric
            // is created. Assumes sr.metric.path ends with a path separator,
            // since the name is appended directly -- TODO confirm.
            FileUtils.deleteDirectory(new File(path + metricName + "/" + "normalizer/"));
            sr = c.get(SRMetric.class, metricName, "language", language.getLangCode());
        }

        Dataset dataset = new Dataset(datasets);
        if (sr != null) {
            sr.trainMostSimilar(dataset, maxResults, null);
            sr.trainSimilarity(dataset);
            sr.write();
            // NOTE(review): the metric is read back immediately after being
            // written; presumably a sanity check of the persisted state -- confirm.
            sr.read();
        }
    }

    /**
     * Builds the command-line options understood by this program: the three
     * options specific to metric training plus the standard environment
     * options added by {@link EnvBuilder#addStandardOptions(Options)}.
     *
     * @return the populated option set
     */
    private static Options buildOptions() {
        Options options = new Options();

        // Number of max results (otherwise taken from the configuration)
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("max-results")
                        .withDescription("maximum number of results")
                        .create("r"));
        // Gold standard datasets to train on
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArgs()
                        .withLongOpt("gold")
                        .withDescription("the set of gold standard datasets to train on")
                        .create("g"));
        // Metric to train
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("metric")
                        .withDescription("set a local metric")
                        .create("m"));

        EnvBuilder.addStandardOptions(options);
        return options;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy