All downloads are free. Search and download functionality uses the official Maven repository.

org.wikibrain.sr.evaluation.MostSimilarDataset Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.wikibrain.sr.evaluation;

import org.apache.commons.lang3.StringUtils;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.utils.KnownSim;

import java.util.*;

/**
 * Groups similarities for a particular phrases together to form ranked most similar lists.
 *
 * @author Shilad Sen
 */
public class MostSimilarDataset {
    private static final double DEFAULT_THRESHOLD = Double.NEGATIVE_INFINITY;

    private final String name;
    private final Language language;
    private final Map data;

    private MostSimilarDataset(Language language, String name) {
        this.language = language;
        this.name = name;
        this.data = new HashMap();
    }

    /**
     * @see #MostSimilarDataset(java.util.List)
     * @param dataset
     */
    public MostSimilarDataset(Dataset dataset) {
        this(Arrays.asList(dataset));
    }

    /**
     * Creates a new most similar dataset based on some input datasets.
     * KnownSims with similarity less than DEFAULT_THRESHOLD are ignored.
     *
     * @param datasets
     */
    public MostSimilarDataset(List datasets) {
        this(datasets, DEFAULT_THRESHOLD);
    }

    /**
     * Creates a new most similar dataset based on some input datasets.
     * KnownSims with similarity less than threshold are ignored.
     *
     * @param datasets
     */
    public MostSimilarDataset(List datasets, double threshold) {
        if (datasets.isEmpty()) {
            throw new IllegalArgumentException();
        }
        this.language = datasets.get(0).getLanguage();
        Map> sims = new HashMap>();
        List names = new ArrayList();
        for (Dataset ds : datasets) {
            ds.normalize(); // just to be safe
            if (ds.getLanguage() != language) {
                throw new IllegalArgumentException("All datasets must be the same language");
            }
            for (KnownSim ks : ds.getData()) {
                addToMap(sims, ks);
                addToMap(sims, ks.getReversed());
            }
            names.add(ds.getName());
        }
        name = StringUtils.join(names, ",") +
                ((threshold == DEFAULT_THRESHOLD) ? "" : ("+threshold="+threshold));
        data = new HashMap();
        for (String phrase : sims.keySet()) {
            KnownMostSim mostSim = new KnownMostSim(sims.get(phrase), threshold);
            if (mostSim.getMostSimilar().size() > 0) {
                data.put(phrase, mostSim);
            }
        }
    }

    public Set getPhrases() {
        return data.keySet();
    }

    public KnownMostSim getSimilarities(String phrase) {
        return data.get(phrase);
    }

    /**
     * Returns a new dataset that only contains phrases with at least n KnownSim entries.
     * @param n Minimum number of phrases
     * @return
     */
    public MostSimilarDataset pruneSmallLists(int n) {
        MostSimilarDataset pruned = new MostSimilarDataset(language, name + "+pruned=" + n);
        for (String phrase : data.keySet()) {
            if (data.get(phrase).getMostSimilar().size() >= n) {
                pruned.data.put(phrase, data.get(phrase));
            }
        }
        return pruned;
    }

    private void addToMap(Map> sims, KnownSim ks) {
        if (!sims.containsKey(ks.phrase1)) {
            sims.put(ks.phrase1, new ArrayList());
        }
        sims.get(ks.phrase1).add(ks);
    }

    public String getName() {
        return name;
    }

    public Language getLanguage() {
        return language;
    }

    /**
     * Converts the most similar dataset back to a "normal" dataset.
     * @return
     */
    public Dataset toDataset() {
        List sims = new ArrayList();
        for (KnownMostSim kms : data.values()) {
            sims.addAll(kms.getMostSimilar());
        }
        return new Dataset(name, language, sims);
    }

    /**
     * Returns a list of suitable test cross-validation sets.
     * The splits occur along phrases, so all entries for a particular phrase stay in the
     * same cross-validation split.
     * @param n
     * @return
     */
    public List split(int n) {
        List phrases = new ArrayList(data.keySet());
        Collections.shuffle(phrases);
        List result = new ArrayList();
        for (int i = 0; i < n; i++) {
            result.add(new MostSimilarDataset(language, name + "+split-" + i));
        }
        for (int i = 0; i < phrases.size(); i++) {
            String p = phrases.get(i);
            result.get(i % n).data.put(p, data.get(p));
        }
        return result;
    }

    /**
     * @see #split(int)
     * @see #toDataset()
     * @param n
     * @return
     */
    public List splitIntoDatasets(int n) {
        List result = new ArrayList();
        for (MostSimilarDataset msd : split(n)) {
            result.add(msd.toDataset());
        }
        return result;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy