/*
 * All downloads are free. Search and download functionality uses the official Maven repository.
 *
 * org.wikibrain.sr.vector.ESAGenerator (Maven / Gradle / Ivy)
 *
 * There is a newer version: 0.9.1
 * Show newest version
 */
package org.wikibrain.sr.vector;

import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.lucene.LuceneSearcher;
import org.wikibrain.lucene.QueryBuilder;
import org.wikibrain.lucene.WikiBrainScoreDoc;
import org.wikibrain.lucene.WpIdFilter;
import org.wikibrain.sr.Explanation;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.utils.SimUtils;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author Shilad Sen
 */
public class ESAGenerator implements SparseVectorGenerator {

    private static final Logger LOG = LoggerFactory.getLogger(ESAGenerator.class);

    private final LuceneSearcher searcher;
    private final Language language;
    private final LocalPageDao pageDao;

    private WpIdFilter conceptFilter = null;
    private TIntSet blackListSet;
    private final String blackListFilePath;

    public ESAGenerator(Language language, LocalPageDao pageDao, LuceneSearcher searcher, String blackListFilePath) {
        this.language = language;
        this.pageDao = pageDao;
        this.searcher = searcher;
        this.blackListFilePath = blackListFilePath;
        try{
            createBlackListSet();
        } catch (Exception e){
            LOG.info("Could not create Blacklist Set");
        }
    }

    private void createBlackListSet() throws FileNotFoundException {
        blackListSet = new TIntHashSet();
        if(blackListFilePath == null || blackListFilePath.equals("")) {
            LOG.info("Skipping blacklist creation; no blacklist file specified.");
            return;
        }

        File file = new File(blackListFilePath);
        Scanner scanner = new Scanner(file);
        while(scanner.hasNext()){
            blackListSet.add(scanner.nextInt());
        }

        scanner.close();
    }


    @Override
    public TIntFloatMap getVector(int pageId) throws DaoException {
        int luceneId = searcher.getDocIdFromLocalId(pageId, language);
        if (luceneId < 0) {
            LOG.warn("Unindexed document " + pageId + " in " + language.getEnLangName());
            return new TIntFloatHashMap();
        }
        WikiBrainScoreDoc[] wikibrainScoreDocs =  getQueryBuilder()
                .setMoreLikeThisQuery(luceneId)
                .search();
        wikibrainScoreDocs = pruneSimilar(wikibrainScoreDocs);
        return SimUtils.normalizeVector(expandScores(wikibrainScoreDocs));

    }

    @Override
    public TIntFloatMap getVector(String phrase) {
        QueryBuilder builder = getQueryBuilder().setPhraseQuery(phrase);
        if (builder.hasQuery()) {
            WikiBrainScoreDoc[] scoreDocs = builder.search();
            scoreDocs = SimUtils.pruneSimilar(scoreDocs);
            return SimUtils.normalizeVector(expandScores(scoreDocs));
        } else {
            LOG.warn("Phrase cannot be parsed to get a query. "+phrase);
            return null;
        }
    }

    public void setConcepts(File file) throws IOException {
        conceptFilter = null;
        if (!file.isFile()) {
            LOG.warn("concept path " + file + " not a file; defaulting to all concepts");
            return;
        }
        TIntSet ids = new TIntHashSet();
        for (String wpId : FileUtils.readLines(file)) {
            int wpLocalIDNumb= Integer.valueOf(wpId);
            if(!isBlacklisted(wpLocalIDNumb)) {
                ids.add(wpLocalIDNumb);
            }
        }
        conceptFilter = new WpIdFilter(ids.toArray());
        LOG.warn("installed " + ids.size() + " concepts for " + language);
    }

    private boolean isBlacklisted(int wpLocalIDNumb) {
        return blackListSet.contains(wpLocalIDNumb);
    }

    @Override
    public List getExplanations(int pageID1, int pageID2, TIntFloatMap vector1, TIntFloatMap vector2, SRResult result) throws DaoException {
        LocalPage page1=pageDao.getById(language,pageID1);
        LocalPage page2=pageDao.getById(language,pageID2);
        Leaderboard lb = new Leaderboard(5);    // TODO: make 5 configurable
        for (int id : vector1.keys()) {
            if (vector2.containsKey(id)) {
                lb.tallyScore(id, vector1.get(id) * vector2.get(id));
            }
        }
        SRResultList top = lb.getTop();
        if (top.numDocs() == 0) {
            return Arrays.asList(new Explanation("? and ? share no links", page1, page2));
        }

        List explanations = new ArrayList();
        for (int i = 0; i < top.numDocs(); i++) {
            LocalPage p = pageDao.getById(language, top.getId(i));
            if (p != null) {
                explanations.add(new Explanation("Both ? and ? have similar text to ?", page1, page2, p));
            }
        }
        return explanations;
    }

    @Override
    public List getExplanations(String phrase1, String phrase2, TIntFloatMap vector1, TIntFloatMap vector2, SRResult result) throws DaoException {
        Leaderboard lb = new Leaderboard(5);    // TODO: make 5 configurable
        for (int id : vector1.keys()) {
            if (vector2.containsKey(id)) {
                lb.tallyScore(id, vector1.get(id) * vector2.get(id));
            }
        }
        SRResultList top = lb.getTop();
        if (top.numDocs() == 0) {
            return Arrays.asList(new Explanation("? and ? share no tags", phrase1, phrase2));
        }

        List explanations = new ArrayList();
        for (int i = 0; i < top.numDocs(); i++) {
            LocalPage p = pageDao.getById(language, searcher.getLocalIdFromDocId(top.getId(i), language));
            if (p != null) {
                explanations.add(new Explanation("Both ? and ? have similar text to ?", phrase1, phrase2, p));
            }
        }
        return explanations;
    }

    private QueryBuilder getQueryBuilder() {
        QueryBuilder builder = searcher.getQueryBuilderByLanguage(language);
        builder.setResolveWikipediaIds(false);
        if (conceptFilter != null) {
            builder.addFilter(conceptFilter);
        }
        return builder;
    }

    /**
     * Put data in a scoreDoc into a TIntDoubleHashMap
     *
     * @param wikibrainScoreDocs
     * @return
     */
    private TIntFloatMap expandScores(WikiBrainScoreDoc[] wikibrainScoreDocs) {
        TIntFloatMap expanded = new TIntFloatHashMap();
        for (WikiBrainScoreDoc wikibrainScoreDoc : wikibrainScoreDocs) {
            expanded.put(wikibrainScoreDoc.luceneId, wikibrainScoreDoc.score);
        }
        return expanded;
    }

    /**
     * Prune a WikiBrainScoreDoc array.
     * @param wikibrainScoreDocs array of WikiBrainScoreDoc
     */
    private WikiBrainScoreDoc[] pruneSimilar(WikiBrainScoreDoc[] wikibrainScoreDocs) {
        if (wikibrainScoreDocs.length == 0) {
            return wikibrainScoreDocs;
        }
        int cutoff = wikibrainScoreDocs.length;
        double threshold = 0.005 * wikibrainScoreDocs[0].score;
        for (int i = 0, j = 100; j < wikibrainScoreDocs.length; i++, j++) {
            float delta = wikibrainScoreDocs[i].score - wikibrainScoreDocs[j].score;
            if (delta < threshold) {
                cutoff = j;
                break;
            }
        }
        if (cutoff < wikibrainScoreDocs.length) {
            wikibrainScoreDocs = ArrayUtils.subarray(wikibrainScoreDocs, 0, cutoff);
        }
        return wikibrainScoreDocs;
    }

    public static class Provider extends org.wikibrain.conf.Provider {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return SparseVectorGenerator.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.sparsegenerator";
        }

        @Override
        public SparseVectorGenerator get(String name, Config config, Map runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("esa")) {
                return null;
            }
            if (!runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            ESAGenerator generator = new ESAGenerator(
                    language,
                    getConfigurator().get(LocalPageDao.class),
                    getConfigurator().get(LuceneSearcher.class, config.getString("luceneSearcher")),
                    getConfig().get().getString("sr.blacklist.path")
            );
            if (config.hasPath("concepts")) {
                try {
                    generator.setConcepts(FileUtils.getFile(
                            config.getString("concepts"),
                            language.getLangCode() + ".txt"));
                } catch (IOException e) {
                    throw new ConfigurationException(e);
                }
            }
            return generator;
        }
    }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy