All downloads are free. The search and download features use the official Maven repository.

org.wikibrain.sr.evaluation.SimilarityEvaluator Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.wikibrain.sr.evaluation;

import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.utils.KnownSim;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link Evaluator} that scores a similarity metric against datasets of
 * known similarities ({@link KnownSim}).
 *
 * <p>For every split, a fresh {@link SRMetric} is created and trained on the
 * split's training data, then queried for each test pair. Successful results
 * and failures are recorded in a {@link SimilarityEvaluationLog}; details of
 * each failure (message and stack trace) are appended to a separate error file.
 *
 * @see Evaluator
 * @author Shilad Sen
 */
// NOTE(review): Evaluator is extended as a raw type here; if it is declared
// generic, the type argument should be restored — confirm against Evaluator.
public class SimilarityEvaluator extends Evaluator {
    private static final Logger LOG = LoggerFactory.getLogger(SimilarityEvaluator.class);

    /**
     * Creates an evaluator registered under the name {@code "local-similarity"}.
     *
     * @param outputDir directory handed to the parent {@link Evaluator}
     */
    public SimilarityEvaluator(File outputDir) {
        super(outputDir, "local-similarity");
    }

    /**
     * Adds a crossfold validation of a particular dataset.
     * The group of the split is set to the name of the dataset.
     *
     * <p>Each fold is used once as the test set, and the remaining folds are
     * merged into the training set. Splits are named
     * {@code <dataset name>-fold-<index>}.
     *
     * @param ds       the dataset to divide into folds
     * @param numFolds the number of folds requested from {@link Dataset#split}
     */
    @Override
    public void addCrossfolds(Dataset ds, int numFolds) {
        List<Dataset> folds = ds.split(numFolds);
        for (int i = 0; i < folds.size(); i++) {
            Dataset test = folds.get(i);
            // Copy so that removing the held-out fold does not disturb the original list.
            List<Dataset> trains = new ArrayList<Dataset>(folds);
            trains.remove(i);
            addSplit(new Split(ds.getName() + "-fold-" + i, ds.getName(), new Dataset(trains), test));
        }
    }

    /**
     * Creates an evaluation log backed by the given path.
     *
     * @param path file passed to the {@link SimilarityEvaluationLog} constructor
     * @return a new log for {@code path}
     * @throws IOException if the log cannot be created
     */
    @Override
    public SimilarityEvaluationLog createResults(File path) throws IOException {
        return new SimilarityEvaluationLog(path);
    }

    /**
     * Lists the field names reported in a summary, in output order.
     *
     * @return a fixed-size list of summary field names
     */
    @Override
    public List<String> getSummaryFields() {
        return Arrays.asList(
                "date",
                "runNumber",
                "lang",
                "metricName",
                "dataset",
                "successful",
                "missing",
                "failed",
                "pearsons",
                "spearmans",
                "resolvePhrases",
                "metricConfig",
                "disambigConfig"
        );
    }

    /**
     * Trains a metric on the split's training data and evaluates it on every
     * test pair.
     *
     * <p>A failure on an individual pair is recorded in the returned log,
     * written to {@code err}, and does not stop the evaluation. Both the log
     * and the error file are closed before this method returns, including when
     * it exits with an exception.
     *
     * @param factory creates the metric under evaluation
     * @param split   provides the training and test datasets
     * @param log     file the evaluation log writes to
     * @param err     file that receives details of failed pairs
     * @param config  configuration passed to the evaluation log
     *                (NOTE(review): raw type kept as-is; key/value types are
     *                not visible here — confirm against Evaluator)
     * @return the (already closed) log holding the recorded results
     * @throws DaoException if the metric cannot be created or trained
     * @throws IOException  if the log or the error file cannot be written
     */
    @Override
    protected SimilarityEvaluationLog evaluateSplit(MonolingualSRFactory factory, Split split, File log, File err, Map config) throws DaoException, IOException {
        SRMetric metric = factory.create();
        metric.trainSimilarity(split.getTrain());

        SimilarityEvaluationLog splitEval = null;
        BufferedWriter errFile = null;
        try {
            splitEval = new SimilarityEvaluationLog(config, log);
            errFile = new BufferedWriter(new FileWriter(err));
            for (KnownSim ks : split.getTest().getData()) {
                try {
                    SRResult result;
                    if (shouldResolvePhrases()) {
                        result = metric.similarity(ks.wpId1, ks.wpId2, false);
                    } else {
                        result = metric.similarity(ks.phrase1, ks.phrase2, false);
                    }
                    splitEval.record(ks, result);
                } catch (Exception e) {
                    // Best effort: one bad pair must not abort the whole split.
                    LOG.warn("Similarity of {} failed. Logging error to {}", ks, err);
                    splitEval.recordFailed(ks);
                    writeFailure(errFile, ks, e);
                }
            }
        } finally {
            // Close in finally so an I/O failure above cannot leak either handle.
            // closeQuietly tolerates null, which covers a failed constructor.
            IOUtils.closeQuietly(splitEval);
            IOUtils.closeQuietly(errFile);
        }
        return splitEval;
    }

    /**
     * Appends a description of a failed pair (message and stack frames) to the
     * error file and flushes it immediately.
     */
    private static void writeFailure(BufferedWriter errFile, KnownSim ks, Exception e) throws IOException {
        errFile.write("KnownSim failed: " + ks + "\n");
        errFile.write("\t" + e.getMessage() + "\n");
        for (String frame : ExceptionUtils.getStackFrames(e)) {
            errFile.write("\t" + frame + "\n");
        }
        errFile.write("\n");
        errFile.flush();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy