All Downloads are FREE. Search and download functionalities are using the official Maven repository.

it.unipi.di.acube.batframework.utils.RunExperiments Maven / Gradle / Ivy

/**
 * (C) Copyright 2012-2013 A-cube lab - Università di Pisa - Dipartimento di Informatica. 
 * BAT-Framework is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
 * BAT-Framework is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 * You should have received a copy of the GNU General Public License along with BAT-Framework.  If not, see <http://www.gnu.org/licenses/>.
 */

package it.unipi.di.acube.batframework.utils;

import it.unipi.di.acube.batframework.cache.BenchmarkCache;
import it.unipi.di.acube.batframework.data.*;
import it.unipi.di.acube.batframework.metrics.*;
import it.unipi.di.acube.batframework.problems.*;

import java.util.*;

/**
 * Static methods to run the experiments. A set of annotators is run on a set
 * of datasets, and the metrics are computed according to a set of match
 * relations. The results are stored in nested hash tables, keyed by match
 * relation, annotator, dataset and score threshold.
 * 
 */
public class RunExperiments {

	private static double THRESHOLD_STEP = 1. / 128.;

	public static void computeMetricsA2WFakeReductionToSa2W(
			MatchRelation m,
			A2WSystem tagger,
			A2WDataset ds,
			String precisionFilename,
			String recallFilename,
			String F1Filename,
			WikipediaApiInterface api,
			HashMap>>> results)
			throws Exception {
		Metrics metrics = new Metrics();
		float threshold = 0;
		System.out.print("Doing annotations... ");
		List> computedAnnotations = BenchmarkCache
				.doA2WAnnotations(tagger, ds);
		System.out.println("Done.");
		for (threshold = 0; threshold <= 1; threshold += THRESHOLD_STEP) {
			MetricsResultSet rs = metrics.getResult(computedAnnotations,
					ds.getA2WGoldStandardList(), m);
			updateThresholdRecords(results, m.getName(), tagger.getName(),
					ds.getName(), (float) threshold, rs);
		}
	}

	public static void computeMetricsA2WReducedFromSa2W(
			MatchRelation m,
			Sa2WSystem tagger,
			A2WDataset ds,
			String precisionFilename,
			String recallFilename,
			String F1Filename,
			WikipediaApiInterface api,
			HashMap>>> results)
			throws Exception {
		Metrics metrics = new Metrics();
		System.out.println("Doing annotations... ");
		List> computedAnnotations = BenchmarkCache
				.doSa2WAnnotations(tagger, ds, new AnnotatingCallback() {
					public void run(long msec, int doneDocs, int totalDocs,
							int foundTags) {
						System.out
								.printf("Done %d/%d documents. Found %d annotations/tags so far.%n",
										doneDocs, totalDocs, foundTags);
					}
				}, 60000);
		System.out.println("Done with all documents.");
		for (double threshold = 0; threshold <= 1; threshold += THRESHOLD_STEP) {
			System.out.println("Testing with tagger: " + tagger.getName()
					+ " dataset: " + ds.getName() + " score threshold: "
					+ threshold);
			List> reducedTags = ProblemReduction
					.Sa2WToA2WList(computedAnnotations, (float) threshold);
			MetricsResultSet rs = metrics.getResult(reducedTags,
					ds.getA2WGoldStandardList(), m);
			updateThresholdRecords(results, m.getName(), tagger.getName(),
					ds.getName(), (float) threshold, rs);
		}
	}

	public static void computeMetricsC2WReducedFromSa2W(
			MatchRelation m,
			Sa2WSystem tagger,
			C2WDataset ds,
			WikipediaApiInterface api,
			HashMap>>> results)
			throws Exception {
		Metrics metrics = new Metrics();
		System.out.println("Doing annotations... ");
		List> computedAnnotations = BenchmarkCache
				.doSa2WAnnotations(tagger, ds, new AnnotatingCallback() {
					public void run(long msec, int doneDocs, int totalDocs,
							int foundTags) {
						System.out
								.printf("Done %d/%d documents. Found %d annotations/tags so far.%n",
										doneDocs, totalDocs, foundTags);
					}
				}, 60000);
		System.out.println("Done with all documents.");
		for (double threshold = 0; threshold <= 1; threshold += THRESHOLD_STEP) {
			System.out.println("Testing with tagger: " + tagger.getName()
					+ " dataset: " + ds.getName() + " score threshold: "
					+ threshold);
			List> reducedAnnotations = ProblemReduction
					.Sa2WToA2WList(computedAnnotations, (float) threshold);
			List> reducedTags = ProblemReduction
					.A2WToC2WList(reducedAnnotations);
			List> reducedGs = ds.getC2WGoldStandardList();
			MetricsResultSet rs = metrics.getResult(reducedTags, reducedGs, m);
			updateThresholdRecords(results, m.getName(), tagger.getName(),
					ds.getName(), (float) threshold, rs);
		}
	}

	public static void computeMetricsC2WReducedFromSc2W(
			MatchRelation m,
			Sc2WSystem tagger,
			C2WDataset ds,
			WikipediaApiInterface api,
			HashMap>>> results)
			throws Exception {
		Metrics metrics = new Metrics();
		double threshold = 0;
		System.out.print("Doing annotations... ");
		List> computedAnnotations = BenchmarkCache
				.doSc2WTags(tagger, ds);
		System.out.println("Done.");
		for (threshold = 0; threshold <= 1; threshold += THRESHOLD_STEP) {
			System.out.println("Testing with tagger: " + tagger.getName()
					+ " dataset: " + ds.getName() + " score threshold: "
					+ threshold);
			List> reducedAnnotations = ProblemReduction
					.Sc2WToC2WList(computedAnnotations, (float) threshold);
			List> reducedGs = ds.getC2WGoldStandardList();
			MetricsResultSet rs = metrics.getResult(reducedAnnotations,
					reducedGs, m);
			updateThresholdRecords(results, m.getName(), tagger.getName(),
					ds.getName(), (float) threshold, rs);
		}
	}

	public static void computeMetricsC2W(
			MatchRelation m,
			C2WSystem tagger,
			C2WDataset ds,
			WikipediaApiInterface api,
			HashMap>>> results)
			throws Exception {
		Metrics metrics = new Metrics();
		double threshold = 0;
		System.out.print("Doing annotations... ");
		List> computedAnnotations = BenchmarkCache.doC2WTags(
				tagger, ds);
		System.out.println("Done.");
		System.out.println("Testing with tagger: " + tagger.getName()
				+ " dataset: " + ds.getName() + " (no score thr.)");
		MetricsResultSet rs = metrics.getResult(computedAnnotations,
				ds.getC2WGoldStandardList(), m);
		for (threshold = 0; threshold <= 1; threshold += THRESHOLD_STEP) {
			updateThresholdRecords(results, m.getName(), tagger.getName(),
					ds.getName(), (float) threshold, rs);
		}
	}

	public static void computeMetricsD2WFakeReductionToSa2W(
			D2WSystem tagger,
			D2WDataset ds,
			String precisionFilename,
			String recallFilename,
			String F1Filename,
			WikipediaApiInterface api,
			HashMap>>> results)
			throws Exception {
		Metrics metrics = new Metrics();
		StrongAnnotationMatch m = new StrongAnnotationMatch(api);
		float threshold = 0;
		System.out.print("Doing native D2W annotations... ");
		List> computedAnnotations = BenchmarkCache
				.doD2WAnnotations(tagger, ds, new AnnotatingCallback() {
					public void run(long msec, int doneDocs, int totalDocs,
							int foundTags) {
						System.out
								.printf("Done %d/%d documents. Found %d annotations so far.%n",
										doneDocs, totalDocs, foundTags);
					}
				}, 60000);
		System.out.println("Done with all documents.");
		for (threshold = 0; threshold <= 1; threshold += THRESHOLD_STEP) {
			MetricsResultSet rs = metrics.getResult(computedAnnotations,
					ds.getD2WGoldStandardList(), m);
			updateThresholdRecords(results, m.getName(), tagger.getName(),
					ds.getName(), (float) threshold, rs);
		}
	}

	public static void computeMetricsD2WReducedFromSa2W(
			Sa2WSystem tagger,
			D2WDataset ds,
			String precisionFilename,
			String recallFilename,
			String F1Filename,
			WikipediaApiInterface api,
			HashMap>>> results)
			throws Exception {
		Metrics metrics = new Metrics();
		StrongAnnotationMatch m = new StrongAnnotationMatch(api);
		System.out.println("Doing annotations... ");
		List> computedAnnotations = BenchmarkCache
				.doSa2WAnnotations(tagger, ds, new AnnotatingCallback() {
					public void run(long msec, int doneDocs, int totalDocs,
							int foundTags) {
						System.out
								.printf("Done %d/%d documents. Found %d annotations/tags so far.%n",
										doneDocs, totalDocs, foundTags);
					}
				}, 60000);
		System.out.println("Done with all documents.");
		System.out
				.printf("Testing with tagger: %s, dataset: %s, for values of the score threshold in [0,1].%n",
						tagger.getName(), ds.getName());
		for (double threshold = 0; threshold <= 1; threshold += THRESHOLD_STEP) {

			List> reducedAnns = ProblemReduction
					.Sa2WToD2WList(computedAnnotations,
							ds.getMentionsInstanceList(), (float) threshold);
			MetricsResultSet rs = metrics.getResult(reducedAnns,
					ds.getD2WGoldStandardList(), m);
			updateThresholdRecords(results, m.getName(), tagger.getName(),
					ds.getName(), (float) threshold, rs);
		}
	}

	public static HashMap>>> performC2WExpVarThreshold(
			Vector> matchRels,
			Vector a2wAnnotators, Vector sa2wAnnotators,
			Vector sc2wTaggers,
			Vector c2wTaggers, Vector dss,
			WikipediaApiInterface api) throws Exception {
		HashMap>>> result = new HashMap>>>();
		for (MatchRelation m : matchRels)
			for (C2WDataset ds : dss) {
				System.out.println("Testing " + ds.getName()
						+ " with score threshold parameter...");

				if (sa2wAnnotators != null)
					for (Sa2WSystem t : sa2wAnnotators) {
						computeMetricsC2WReducedFromSa2W(m, t, ds, api, result);
						BenchmarkCache.flush();
					}

				if (sc2wTaggers != null)
					for (Sc2WSystem t : sc2wTaggers) {
						computeMetricsC2WReducedFromSc2W(m, t, ds, api, result);
						BenchmarkCache.flush();
					}
				if (c2wTaggers != null)
					for (C2WSystem t : c2wTaggers) {
						computeMetricsC2W(m, t, ds, api, result);
						BenchmarkCache.flush();
					}

				System.out.println("Flushing Wikipedia API cache...");
				api.flush();
			}
		return result;
	}

	public static HashMap>>> performA2WExpVarThreshold(
			Vector> metrics,
			Vector a2wTaggers, Vector sa2wTaggers,
			Vector dss, WikipediaApiInterface api) throws Exception {
		HashMap>>> result = new HashMap>>>();
		for (MatchRelation metric : metrics) {
			for (A2WDataset ds : dss) {
				if (sa2wTaggers != null)
					for (Sa2WSystem t : sa2wTaggers) {
						System.out.println("Testing " + ds.getName() + " on "
								+ t.getName()
								+ " with score threshold parameter...");
						String prefix = metric.getName()
								.replaceAll("[^a-zA-Z0-9]", "").toLowerCase();
						String suffix = t.getName()
								.replaceAll("[^a-zA-Z0-9]", "").toLowerCase()
								+ "_"
								+ ds.getName().replaceAll("[^a-zA-Z0-9]", "")
										.toLowerCase() + ".dat";
						computeMetricsA2WReducedFromSa2W(metric, t, ds, prefix
								+ "_precision_threshold_" + suffix, prefix
								+ "_recall_threshold_" + suffix, prefix
								+ "_f1_threshold_" + suffix, api, result);
						BenchmarkCache.flush();
					}

				if (a2wTaggers != null)
					for (A2WSystem t : a2wTaggers) {
						System.out.println("Testing " + ds.getName() + " on "
								+ t.getName()
								+ " with score threshold parameter...");
						String prefix = metric.getName()
								.replaceAll("[^a-zA-Z0-9]", "").toLowerCase();
						String suffix = t.getName()
								.replaceAll("[^a-zA-Z0-9]", "").toLowerCase()
								+ "_"
								+ ds.getName().replaceAll("[^a-zA-Z0-9]", "")
										.toLowerCase() + ".dat";
						computeMetricsA2WFakeReductionToSa2W(metric, t, ds,
								prefix + "_precision_threshold_" + suffix,
								prefix + "_recall_threshold_" + suffix, prefix
										+ "_f1_threshold_" + suffix, api,
								result);
						BenchmarkCache.flush();
					}

				System.out.println("Flushing Wikipedia API cache...");
				api.flush();
			}
		}
		return result;
	}

	public static HashMap>>> performD2WExpVarThreshold(
			Vector d2wAnnotators, Vector sa2wAnnotators,
			Vector dss, WikipediaApiInterface api) throws Exception {
		HashMap>>> result = new HashMap>>>();
		MatchRelation sam = new StrongAnnotationMatch(api);
		for (D2WDataset ds : dss) {
			if (sa2wAnnotators != null)
				for (Sa2WSystem t : sa2wAnnotators) {
					System.out.println("Testing " + ds.getName() + " on "
							+ t.getName()
							+ " with score threshold parameter...");
					String prefix = sam.getName()
							.replaceAll("[^a-zA-Z0-9]", "").toLowerCase();
					String suffix = t.getName().replaceAll("[^a-zA-Z0-9]", "")
							.toLowerCase()
							+ "_"
							+ ds.getName().replaceAll("[^a-zA-Z0-9]", "")
									.toLowerCase() + ".dat";
					computeMetricsD2WReducedFromSa2W(t, ds, prefix
							+ "_precision_threshold_" + suffix, prefix
							+ "_recall_threshold_" + suffix, prefix
							+ "_f1_threshold_" + suffix, api, result);
					BenchmarkCache.flush();
				}
			if (d2wAnnotators != null)
				for (D2WSystem t : d2wAnnotators) {
					System.out.println("Testing " + ds.getName() + " on "
							+ t.getName()
							+ " with score threshold parameter...");
					String prefix = sam.getName()
							.replaceAll("[^a-zA-Z0-9]", "").toLowerCase();
					String suffix = t.getName().replaceAll("[^a-zA-Z0-9]", "")
							.toLowerCase()
							+ "_"
							+ ds.getName().replaceAll("[^a-zA-Z0-9]", "")
									.toLowerCase() + ".dat";
					computeMetricsD2WFakeReductionToSa2W(t, ds, prefix
							+ "_precision_threshold_" + suffix, prefix
							+ "_recall_threshold_" + suffix, prefix
							+ "_f1_threshold_" + suffix, api, result);
					BenchmarkCache.flush();
				}

			System.out.println("Flushing Wikipedia API cache...");
			api.flush();
		}
		return result;
	}

	private static void updateThresholdRecords(
			HashMap>>> threshRecords,
			String metricsName, String taggerName, String datasetName,
			float threshold, MetricsResultSet rs) {
		HashMap>> bestThreshold;
		if (!threshRecords.containsKey(metricsName))
			threshRecords
					.put(metricsName,
							new HashMap>>());
		bestThreshold = threshRecords.get(metricsName);

		HashMap> firstLevel;
		if (!bestThreshold.containsKey(taggerName))
			bestThreshold.put(taggerName,
					new HashMap>());
		firstLevel = bestThreshold.get(taggerName);

		HashMap secondLevel;
		if (!firstLevel.containsKey(datasetName))
			firstLevel.put(datasetName, new HashMap());
		secondLevel = firstLevel.get(datasetName);

		// populate the hash table with the new record.
		secondLevel.put(threshold, rs);
	}

	public static Pair getBestRecord(
			HashMap>>> threshResults,
			String metricsName, String taggerName, String datasetName) {
		HashMap records = threshResults
				.get(metricsName).get(taggerName).get(datasetName);
		List thresholds = new Vector(records.keySet());
		Collections.sort(thresholds);
		Pair bestRecord = null;
		for (Float t : thresholds)
			if (bestRecord == null
					|| records.get(t).getMacroF1() > bestRecord.second
							.getMacroF1())
				bestRecord = new Pair(t,
						records.get(t));
		return bestRecord;
	}

	public static HashMap getRecords(
			HashMap>>> threshResults,
			String metricsName, String taggerName, String datasetName) {
		HashMap records = threshResults
				.get(metricsName).get(taggerName).get(datasetName);
		return records;

	}

	public static MetricsResultSet performMentionSpottingExp(
			MentionSpotter spotter, D2WDataset ds) throws Exception {
		List> output = BenchmarkCache.doSpotMentions(spotter,
				ds);
		Metrics metrics = new Metrics();
		return metrics.getResult(output, ds.getMentionsInstanceList(),
				new MentionMatch());
	}

	public static void performMentionSpottingExp(
			List mentionSpotters, List dss)
			throws Exception {
		for (MentionSpotter spotter : mentionSpotters) {
			for (D2WDataset ds : dss) {
				System.out.println("Testing Spotter " + spotter.getName()
						+ " on dataset " + ds.getName());
				System.out.println("Doing spotting... ");
				MetricsResultSet rs = performMentionSpottingExp(spotter, ds);
				System.out.println("Done with all documents.");

				System.out.printf("%s / %s%n%s%n%n", spotter.getName(),
						ds.getName(), rs);
			}
		}

	}

	public static MetricsResultSet performCandidateSpottingExp(
			CandidatesSpotter spotter, D2WDataset dss, WikipediaApiInterface api)
			throws Exception {
		Metrics metrics = new Metrics();

		List> gold = annotationToMulti(dss
				.getD2WGoldStandardList());

		List> output = new Vector>();
		for (String text : dss.getTextInstanceList())
			output.add(spotter.getSpottedCandidates(text));

		// Filter system annotations so that only those contained in the dataset
		// AND in the output are taken into account.
		output = mentionSubstraction(output, gold);
		gold = mentionSubstraction(gold, output);

		return metrics.getResult(output, gold, new MultiEntityMatch(api));

	}

	public static Integer[] candidateCoverageDistributionExp(
			CandidatesSpotter spotter, D2WDataset dss, WikipediaApiInterface api)
			throws Exception {
		List> gold = annotationToMulti(dss
				.getD2WGoldStandardList());

		List> output = new Vector>();
		for (String text : dss.getTextInstanceList())
			output.add(spotter.getSpottedCandidates(text));

		output = mentionSubstraction(output, gold);
		gold = mentionSubstraction(gold, output);

		Vector positions = new Vector<>();
		for (int i = 0; i < output.size(); i++) {
			HashSet outI = output.get(i);
			HashSet goldI = gold.get(i);
			for (MultipleAnnotation outAnn : outI)
				for (MultipleAnnotation goldAnn : goldI)
					if (outAnn.overlaps(goldAnn)) {
						int goldCand = goldAnn.getCandidates()[0];
						int candIdx = 0;
						for (; candIdx < outAnn.getCandidates().length; candIdx++)
							if (outAnn.getCandidates()[candIdx] == goldCand) {
								positions.add(candIdx);
								break;
							}
						if (candIdx == outAnn.getCandidates().length)
							positions.add(-1);
					}
		}

		return positions.toArray(new Integer[positions.size()]);

	}

	private static  List> mentionSubstraction(
			List> list1, List> list2) {
		List> list1filtered = new Vector>();
		for (int i = 0; i < list1.size(); i++) {
			HashSet filtered1 = new HashSet();
			list1filtered.add(filtered1);
			for (T a : list1.get(i)) {
				boolean found = false;
				for (T goldA : list2.get(i))
					if (a.getPosition() == goldA.getPosition()
							&& a.getLength() == goldA.getLength()) {
						found = true;
						break;
					}
				if (found)
					filtered1.add(a);
			}
		}
		return list1filtered;
	}

	private static List> annotationToMulti(
			List> d2wGoldStandardList) {
		List> res = new Vector>();
		for (HashSet annSet : d2wGoldStandardList) {
			HashSet multiAnn = new HashSet();
			res.add(multiAnn);
			for (Annotation a : annSet)
				multiAnn.add(new MultipleAnnotation(a.getPosition(), a
						.getLength(), new int[] { a.getConcept() }));
		}
		return res;
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy