All downloads are free. Search and download functionality uses the official Maven repository.

org.wikibrain.sr.milnewitten.MilneWittenDisambiguator Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.wikibrain.sr.milnewitten;

import com.google.common.collect.Maps;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigValueFactory;
import org.apache.commons.collections.CollectionUtils;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.DaoFilter;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.lang.LocalString;
import org.wikibrain.core.model.NameSpace;
import org.wikibrain.phrases.PhraseAnalyzer;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.utils.WbMathUtils;
import org.wikibrain.utils.WpCollectionUtils;

import java.util.*;

/**
 * @author Shilad Sen
 *
 * A literal implementation of the disambiguator described in
 * http://www.aaai.org/Papers/Workshops/2008/WS-08-15/WS08-15-005.pdf
 *
 * This is naively extended to support multiple strings and contexts.
 */
public class MilneWittenDisambiguator extends Disambiguator {
    private final Language language;
    private final LocalPageDao pageDao;
    private final PhraseAnalyzer analyzer;
    private final SRMetric metric;
    private final int numPages;

    /**
     * Construct a new disambiguator that uses a particular metric.
     *
     * @param pageDao
     * @param analyzer
     * @param metric
     * @throws DaoException
     */
    public MilneWittenDisambiguator(LocalPageDao pageDao, PhraseAnalyzer analyzer, SRMetric metric) throws DaoException {
        this.language = metric.getLanguage();
        this.pageDao = pageDao;
        this.analyzer = analyzer;
        this.metric = metric;
        this.numPages = pageDao.getCount(
                new DaoFilter().setLanguages(language)
                        .setNameSpaces(NameSpace.ARTICLE)
                        .setRedirect(false)
                        .setDisambig(false));
    }

    /**
     * Given
     * @param phrases   The target phrases being disambiguated.
     * @param context  Other phrases (in the same language as the target phrase)
     *                 related to the target phrase being disambiguated that may
     *                 aid disambiguation.
     * @return
     * @throws DaoException
     */
    @Override
    public List> disambiguate(List phrases, Set context) throws DaoException {

        // Collect all phrases
        List allPhrases = new ArrayList(
                (context == null) ? phrases : CollectionUtils.union(phrases, context));

        // Step 1: calculate most frequent candidate senses for each phrase
        Map> candidates = Maps.newHashMap();
        for (LocalString s : allPhrases) {
            if (!s.getLanguage().equals(language)) {
                throw new IllegalArgumentException("Disambiguator only supports language " + language);
            }
            candidates.put(s, analyzer.resolve(s.getLanguage(), s.getString(), 100));
        }

        // Step 1.5: Build mapping from local id to phrases that are candidates
        Map> idsToPhrases = new HashMap>();
        for (LocalString s : candidates.keySet()) {
            for (LocalId lid : candidates.get(s).keySet()) {
                if (!idsToPhrases.containsKey(lid)) {
                    idsToPhrases.put(lid, new HashSet());
                }
                idsToPhrases.get(lid).add(s);
            }
        }

        // Step 2: calculate the sum of cosimilarities for each page
        Map pageSims = getCosimilaritySums(candidates);

        // Step 3: Choose the best options for each phrase
        List> result = new ArrayList>();
        for (LocalString p : phrases) {
            result.add(
                    disambiguateOnePhrase(p, candidates.get(p), idsToPhrases, pageSims));
        }
        return result;
    }

    private LinkedHashMap disambiguateOnePhrase(LocalString phrase, LinkedHashMap candidates, Map> idsToPhrases, Map pageSims) throws DaoException {
        // Identify the highest similarity for each page
        float maxSimilarity = Float.NEGATIVE_INFINITY;
        for (LocalId lid : candidates.keySet()) {
            maxSimilarity = Math.max(maxSimilarity, pageSims.get(lid));
        }

        // Identify the most popular senses within 40% of the top similarity
        Map scores = new HashMap();
        double scoreSum = 0.0;
        for (LocalId lid : candidates.keySet()) {
            double sim = pageSims.get(lid);
            if (sim < 0.4 * maxSimilarity) {
                continue;
            }
            double pop = candidates.get(lid);

            double phraseBonus = 0.0;
            /*

            TODO: figure out phraseBonus for multiple phrases

            int numPhrases = 0;
            int sumCounts = 0;
            for (LocalString ls2 : idsToPhrases.get(lid)) {
                numPhrases++;
                sumCounts += getPhraseCount(phrase.getString() + " " + ls2.getString());
                sumCounts += getPhraseCount(ls2.getString() + " " + phrase.getString());
            }

            if (sumCounts > 0) {
                int maxExpectedValue = numPhrases * numPages / 50;
                phraseBonus = Math.log(sumCounts + 1) / Math.log(maxExpectedValue);
                phraseBonus = Math.min(0.5, phraseBonus);
            }

            System.err.println("phrase bonus for " + phrase + ", " + idsToPhrases.get(lid) + " is " + phraseBonus);
                         */

            double score = pop + phraseBonus;
            scores.put(lid, (float) score);
            scoreSum += score;

        }
        LinkedHashMap pageResult = new LinkedHashMap();
        for (LocalId key : WpCollectionUtils.sortMapKeys(scores, true)) {
            pageResult.put(key, (float) (scores.get(key) / scoreSum));
        }
        return pageResult;
    }

    /*
    private int getPhraseCount(String phrase) throws DaoException {
        PrunedCounts pages = analyzer.getDao().getPhraseCounts(language, phrase, 1);
        if (pages == null) {
            return 0;
        } else {
            return pages.getTotal();
        }
    }
    */


    /**
     * Return the sum of cosimilarity scores for all unique pages among the candidates.
     * @param candidates
     * @return
     * @throws DaoException
     */
    private Map getCosimilaritySums(Map> candidates) throws DaoException {

        // Step 1: compute the page cosimilarity matrix
        Set uniques = new HashSet();
        for (LinkedHashMap prob : candidates.values()) {
            uniques.addAll(prob.keySet());
        }
        List pages = new ArrayList(uniques);
        double[][] cosim;

        if (pages.isEmpty()){
            cosim = new double[0][0];
        } else {
            int[] pageIds = new int[pages.size()];
            for (int i=0; i pageSims = new HashMap();
        for (int i = 0; i < pages.size(); i++) {
            double sum = 0.0;
            for (int j = 0; j < pages.size(); j++) {
                if (i != j && WbMathUtils.isReal(cosim[i][j])) {
                    sum += Math.max(0, cosim[i][j]);    // Hack: no negative numbers
                }
            }
            // add 0.0001 to give every candidate a tiny chance and avoid divide by zero errors when there are no good options
            pageSims.put(pages.get(i), (float)(sum + 0.0001));
        }
        return pageSims;
    }

    public static class Provider extends org.wikibrain.conf.Provider{
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator,config);
        }

        @Override
        public Class getType(){
            return Disambiguator.class;
        }

        @Override
        public String getPath(){
            return "sr.disambig";
        }

        @Override
        public Disambiguator get(String name, Config config, Map runtimeParams) throws ConfigurationException{
            if (!config.getString("type").equals("milnewitten")){
                return null;
            }

            if (runtimeParams == null || !runtimeParams.containsKey("language")){
                throw new IllegalArgumentException("SimpleMilneWitten requires 'language' runtime parameter.");
            }
            Language lang = Language.getByLangCode(runtimeParams.get("language"));
            PhraseAnalyzer pa = getConfigurator().get(PhraseAnalyzer.class, config.getString("phraseAnalyzer"));
            LocalPageDao pageDao = getConfigurator().get(LocalPageDao.class);

            // Create override config for sr metric and load it.
            String srName = config.getString("metric");
            Config newConfig = getConfig().get().getConfig("sr.metric.local." + srName)
                    .withValue("disambiguator", ConfigValueFactory.fromAnyRef("topResult"));
            Map srRuntimeParams = new HashMap();
            srRuntimeParams.put("language", lang.getLangCode());
            SRMetric sr = getConfigurator().construct(SRMetric.class, srName, newConfig, srRuntimeParams);

            try {
                return new MilneWittenDisambiguator(pageDao, pa, sr);
            } catch (DaoException e) {
                throw new ConfigurationException(e);
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy