// All downloads are free. Search and download functionality uses the official Maven repository.
//
// org.wikibrain.sr.synrank.SynRank (Maven / Gradle / Ivy)
//
// Note: there is a newer version of this artifact: 0.9.1
// (see the newest version).
package org.wikibrain.sr.synrank;

import com.typesafe.config.Config;
import gnu.trove.list.TIntList;
import gnu.trove.list.linked.TIntLinkedList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.*;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalLink;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.core.model.NameSpace;
import org.wikibrain.sr.BaseSRMetric;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.disambig.Disambiguator;

import java.util.Arrays;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An implementation of the SynRank SR metric as described in
 *
 * G. Quercini and H. Samet. Uncovering the spatial relatedness in Wikipedia.
 * In Y. Huang, M. Gertz, J. C. Krumm, J. Sankaranarayanan, and M. Schneider, editors,
 * Proceedings of SIGSPATIAL 2014.
 *
 * Soon to be linked at http://www.cs.umd.edu/~hjs/hjsyear.html
 *
 * <p>For two distinct, non-disambiguation pages the score computed here is
 * {@code pmi * boost / graphDistance}, where:
 * <ul>
 *   <li>{@code pmi = numArticles * |shared inlinks| / (|inlinks(a)| * |inlinks(b)|)},</li>
 *   <li>{@code boost = log10(|shared inlinks|)}, and</li>
 *   <li>{@code graphDistance} is 1 if either page links directly to the other, otherwise 2.</li>
 * </ul>
 *
 * @author Shilad Sen
 */
public class SynRank extends BaseSRMetric {
    private static final Logger LOG = LoggerFactory.getLogger(SynRank.class);

    /** Titles of the category pages tried, in order, to locate the disambiguation category. */
    private static final String[] DAB_CATEGORY_TITLES = {
            "Category:Disambiguation pages",
            "Category:Disambiguation",
    };

    private final LocalLinkDao linkDao;

    /** Number of pages matched by {@code DaoFilter.normalPageFilter(language)} at construction time. */
    private final int numArticles;

    /** Ids of the members of the disambiguation category; {@link #similarity} returns null for these. */
    private final TIntHashSet dabs;

    /**
     * Creates the metric, eagerly counting the articles and loading the ids of all
     * disambiguation pages for the language.
     *
     * @param name     name of the metric, passed to the base class
     * @param language language the metric operates in
     * @param pageDao  used to count articles and to look up the disambiguation category
     * @param dab      disambiguator passed to the base class
     * @param linkDao  used to retrieve the inlinks of pages
     * @param catDao   used to list the members of the disambiguation category
     * @throws DaoException             if any of the DAO lookups fail
     * @throws IllegalArgumentException if none of the known disambiguation categories exists
     */
    public SynRank(String name, Language language, LocalPageDao pageDao, Disambiguator dab, LocalLinkDao linkDao, LocalCategoryMemberDao catDao) throws DaoException {
        super(name, language, pageDao, dab);
        this.linkDao = linkDao;
        LOG.info("calculating number of articles");
        this.numArticles = pageDao.getCount(DaoFilter.normalPageFilter(language));
        LOG.info("found {} articles", this.numArticles);
        this.dabs = new TIntHashSet();

        // TODO: fix dabs!
        LocalPage dabCat = null;
        for (String title : DAB_CATEGORY_TITLES) {
            dabCat = pageDao.getByTitle(getLanguage(), NameSpace.CATEGORY, title);
            if (dabCat != null) {
                break;
            }
        }
        if (dabCat == null) {
            throw new IllegalArgumentException(
                    "No disambiguation category found for language " + language
                            + "; tried " + Arrays.asList(DAB_CATEGORY_TITLES));
        }
        for (int id : catDao.getCategoryMemberIds(dabCat)) {
            dabs.add(id);
        }
        LOG.info("identified {} disambiguation pages", dabs.size());
    }

    /**
     * Returns the configuration for this metric, with a score range of 0 to 1.1.
     */
    @Override
    public SRConfig getConfig() {
        SRConfig config = new SRConfig();
        config.maxScore = 1.1f;
        config.minScore = 0;
        return config;
    }

    /**
     * Computes the SynRank relatedness of two pages from their shared inlinks.
     *
     * @param pageId1      id of the first page
     * @param pageId2      id of the second page
     * @param explanations ignored; this metric does not produce explanations
     * @return a result of 1.0 if the ids are equal, {@code null} if either page is a
     *         disambiguation page, a result of 0.0 if the pages share no inlinks, and
     *         otherwise a result holding {@code pmi * boost / graphDistance}
     * @throws DaoException if the links cannot be retrieved
     */
    @Override
    public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException {
        if (pageId1 == pageId2) {
            return new SRResult(1.0);
        }
        if (dabs.contains(pageId1) || dabs.contains(pageId2)) {
            return null;
        }

        TIntSet ids1 = getLinksTo(pageId1);
        TIntSet ids2 = getLinksTo(pageId2);
        TIntSet both = new TIntHashSet(ids1);

        both.retainAll(ids2);
        if (both.isEmpty()) {
            return new SRResult(0.0);
        }

        // The denominator is computed in double: the int product of two large inlink
        // counts can exceed Integer.MAX_VALUE and silently wrap around.
        // Both sets are non-empty here (they share at least one id), so it is never zero.
        double pmi = 1.0 * numArticles * both.size() / ((double) ids1.size() * ids2.size());
        double boost = Math.log10(both.size());
        double graphDistance = graphDistance(pageId1, pageId2, ids1, ids2);

        return new SRResult(pmi * boost / graphDistance);
    }

    /**
     * Returns the ids of all pages that link to the given page in this metric's language.
     *
     * @param pageId id of the link destination
     * @return a new set holding the source id of every link whose destination is {@code pageId}
     * @throws DaoException if the links cannot be retrieved
     */
    private TIntSet getLinksTo(int pageId) throws DaoException {
        TIntSet ids = new TIntHashSet();
        for (LocalLink ll : linkDao.get(new DaoFilter().setDestIds(pageId).setLanguages(getLanguage()))) {
            ids.add(ll.getSourceId());
        }
        return ids;
    }

    /**
     * Approximates the link-graph distance between two pages using their inlink sets.
     *
     * The sets already fetched by the caller are reused instead of querying the DAO again:
     * page 2 links to page 1 exactly when page 2 is among page 1's inlinks, and
     * page 1 links to page 2 exactly when page 1 is among page 2's inlinks.
     *
     * @param pageId1  id of the first page
     * @param pageId2  id of the second page
     * @param linksTo1 ids of the pages linking to {@code pageId1}
     * @param linksTo2 ids of the pages linking to {@code pageId2}
     * @return 1 if either page links directly to the other, otherwise 2
     */
    private static int graphDistance(int pageId1, int pageId2, TIntSet linksTo1, TIntSet linksTo2) {
        boolean directlyLinked = linksTo1.contains(pageId2) || linksTo2.contains(pageId1);
        return directlyLinked ? 1 : 2;
    }

    /**
     * Not supported by this metric.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException {
        throw new UnsupportedOperationException("SynRank does not support mostSimilar");
    }


    /**
     * Configuration provider that builds {@link SynRank} instances for entries under
     * {@code sr.metric.local} whose {@code type} is {@code synrank}.
     */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class<SRMetric> getType() {
            return SRMetric.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Builds a SynRank metric from the given configuration.
         *
         * @param name          name of the metric
         * @param config        configuration for the metric; must define {@code type} and {@code disambiguator}
         * @param runtimeParams runtime parameters; must contain a {@code language} entry holding a language code
         * @return the configured metric, or {@code null} if {@code type} is not {@code synrank}
         * @throws ConfigurationException   if a dependency cannot be configured or a DAO call fails
         * @throws IllegalArgumentException if the {@code language} runtime parameter is missing
         */
        @Override
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("synrank")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            Disambiguator dab = getConfigurator().get(Disambiguator.class, config.getString("disambiguator"), "language", language.getLangCode());
            try {
                SynRank sr = new SynRank(
                        name,
                        language,
                        getConfigurator().get(LocalPageDao.class),
                        dab,
                        getConfigurator().get(LocalLinkDao.class),
                        getConfigurator().get(LocalCategoryMemberDao.class)
                    );
                configureBase(getConfigurator(), sr, config);
                return sr;
            } catch (DaoException e) {
                throw new ConfigurationException(e);
            }
        }
    }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy