All downloads are free. The search and download functionalities use the official Maven repository.

org.wikibrain.spatial.cookbook.tflevaluate.KNNEvaluator Maven / Gradle / Ivy

The newest version!
package org.wikibrain.spatial.cookbook.tflevaluate;

import au.com.bytecode.opencsv.CSVWriter;
import com.vividsolutions.jts.geom.Geometry;
import com.vividsolutions.jts.geom.Point;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.cmd.Env;
import org.wikibrain.core.cmd.EnvBuilder;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.dao.UniversalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LanguageSet;
import org.wikibrain.core.model.Title;
import org.wikibrain.core.model.UniversalPage;
import org.wikibrain.spatial.dao.SpatialDataDao;
import org.wikibrain.spatial.dao.SpatialNeighborDao;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;

import java.io.FileWriter;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Created by toby on 5/17/14.
 */
public class KNNEvaluator {


    private static int WIKIDATA_CONCEPTS = 1;


    private static final Logger LOG = LoggerFactory.getLogger(KNNEvaluator.class);

    private Random random = new Random();

    private final SpatialDataDao sdDao;
    private final LocalPageDao lpDao;
    private final UniversalPageDao upDao;
    private final SpatialNeighborDao snDao;
    private final List langs;
    private final Map metrics;
    private final DistanceMetrics distanceMetrics;

    private final List concepts = new ArrayList();
    private final Map locations = new HashMap();
    private final Env env;
    private CSVWriter output;
    private String layerName = "wikidata";



    public KNNEvaluator(Env env, LanguageSet languages) throws ConfigurationException {
        this.env = env;
        //this.langs = new ArrayList(env.getLanguages().getLanguages());
        langs = new ArrayList();
        for(Language lang : languages.getLanguages())
            langs.add(lang);

        // Get data access objects
        Configurator c = env.getConfigurator();
        this.sdDao = c.get(SpatialDataDao.class);
        this.lpDao = c.get(LocalPageDao.class);
        this.upDao = c.get(UniversalPageDao.class);
        this.snDao = c.get(SpatialNeighborDao.class);

        this.distanceMetrics = new DistanceMetrics(env, c, snDao);

        // build SR metrics
        this.metrics = new HashMap();
        for(Language lang : langs){
            SRMetric m = c.get(SRMetric.class, "ensemble", "language", lang.getLangCode());
            metrics.put(lang, m);
        }

    }

    public static  List getRandomSubList(List input, int subsetSize)
    {
        if(subsetSize > input.size())
            subsetSize = input.size();
        Random r = new Random();
        int inputSize = input.size();
        for (int i = 0; i < subsetSize; i++)
        {
            int indexToSwap = i + r.nextInt(inputSize - i);
            T temp = input.get(i);
            input.set(i, input.get(indexToSwap));
            input.set(indexToSwap, temp);
        }
        return input.subList(0, subsetSize);
    }

    public static T getRandomElement(List input){
        return getRandomSubList(input, 1).get(0);

    }

    private void writeHeader() throws IOException {
        String[] headerEntries = new String[5 + langs.size()];
        headerEntries[0] = "ITEM_NAME_1";
        headerEntries[1] = "ITEM_ID_1";
        headerEntries[2] = "ITEM_NAME_2";
        headerEntries[3] = "ITEM_ID_2";
        headerEntries[4] = "KNN_DISTANCE";
        int counter = 0;
        for (Language lang : langs) {
            headerEntries[5 + counter] = lang.getLangCode() + "_SR";
            counter ++;
        }
        output.writeNext(headerEntries);
        output.flush();
    }

    private void writeRow(UniversalPage c1, UniversalPage c2, Integer KNNDistance, List results) throws WikiBrainException, IOException {

        Title t1 = c1.getBestEnglishTitle(lpDao, true);
        Title t2 = c2.getBestEnglishTitle(lpDao, true);

        String[] rowEntries = new String[5 + langs.size()];
        rowEntries[0] = t1.getCanonicalTitle();
        rowEntries[1] = String.valueOf(c1.getUnivId());
        rowEntries[2] = t2.getCanonicalTitle();
        rowEntries[3] = String.valueOf(c2.getUnivId());
        rowEntries[4] = String.valueOf(KNNDistance);
        int counter = 0;
        for (SRResult result : results) {
            rowEntries[5 + counter] = String.valueOf(result.getScore());
            counter ++;
        }
        output.writeNext(rowEntries);
        output.flush();
        //if(CSVRowCounter % 1000 == 0
        //    LOG.info("Finished writing to CSV Row " + CSVRowCounter);
        //}

    }

    /**
     *
     * @param originId Origins to start
     * @param k K as in "K-nearest neighbors"
     * @param limitPerLevel Number of samples to pick from each "K-nearest neighbors" to evaluate
     * @param limitBranch not used
     * @param maxDist Max distance (depth of search)
     * @param outputPath The path for the output CSV file
     * @throws DaoException
     * @throws IOException
     */


    public void evaluate(Iterable originId, final Integer k, final Integer limitPerLevel, Integer limitBranch, final Integer maxDist, String outputPath) throws DaoException, IOException{
        //TODO: parallel this process...originId.size() should definitely be larger than the number of available threads
        this.output = new CSVWriter(new FileWriter(outputPath), ',');

        writeHeader();
        ParallelForEach.iterate(originId.iterator(), new Procedure() {
            @Override
            public void call(Integer arg) throws Exception {
                evaluateForOne(arg, sdDao.getGeometry(arg, layerName, "earth"), k, limitPerLevel, maxDist);
            }
        });
    }

    //TODO: return only a limited number of pairs for each recursion

    public void evaluateForOne(Integer originId, Geometry originGeom, Integer k, Integer limitPerLevel, Integer maxDist) throws DaoException{
        Integer CSVRowCounter = 0;
        Set excludeIds = new HashSet();
        Map evalResult = new HashMap();
        List nodeToDiscover = new LinkedList();
        Map geometryMap = new HashMap();
        evalResult.put(originId, 0);
        geometryMap.put(originId, originGeom);
        excludeIds.add(originId);
        nodeToDiscover.add(originId);
        int counter = 0;
        while(counter < maxDist){
            counter ++;
            Integer nodeToExpand = getRandomElement(nodeToDiscover);
            Map thisLevel;
            //No need to lock as they are all Read-Read
            // synchronized (this){
                thisLevel = snDao.getKNNeighbors(geometryMap.get(nodeToExpand), k, layerName, "earth", excludeIds);
            //}
            if(thisLevel == null || thisLevel.size() == 0)
                break;
            excludeIds.addAll(thisLevel.keySet());



            List nodesToPutInTheCSV = getRandomSubList(new ArrayList(thisLevel.keySet()), limitPerLevel);

            for(Integer i : nodesToPutInTheCSV){
                evalResult.put(i, counter);
            }


            Integer nodeToAdd = getRandomElement(new ArrayList(thisLevel.keySet()));
            nodeToDiscover.add(nodeToAdd);
            geometryMap.put(nodeToAdd, thisLevel.get(nodeToAdd));
        }

        for(Integer x : evalResult.keySet()){
            for(Integer y : evalResult.keySet()){
                try {
                    List results = new ArrayList();
                    //synchronized (this){
                        for (Language lang : langs) {
                            SRMetric sr = metrics.get(lang);
                            results.add(sr.similarity(upDao.getById(x).getLocalId(lang), upDao.getById(y).getLocalId(lang), false));
                        }
                        writeRow(upDao.getById(x), upDao.getById(y), Math.abs(evalResult.get(x) - evalResult.get(y)), results);
                    //}
                    CSVRowCounter++;
                    if(CSVRowCounter % 5000 == 0)
                        LOG.info("Thread " + Thread.currentThread().getId() + " Now printing " + CSVRowCounter + " From " + x + " To " + y + " at level " + Math.abs(evalResult.get(x) - evalResult.get(y)));
                }
                catch (Exception e){
                    //do nothing
                }
            }
        }

    }


/*
    public Map evaluateRecursive(Integer originId, Geometry originGeom, Integer k, Integer limitPerLevel, Integer limitBranch, Integer maxDist) throws DaoException{
        if (maxDist == 0){
            return new HashMap();
        }
        if (maxDist < currentLevel){
            currentLevel = maxDist;
            LOG.info("reached level " + currentLevel);
        }
        excludeIds.add(originId);
        Map thisLevel = snDao.getKNNeighbors(originGeom, k, layerName, "earth", excludeIds);
        excludeIds.addAll(thisLevel.keySet());
        Map thisLevelRes = new HashMap();

        if(limitBranch > thisLevel.size())
            limitBranch = thisLevel.size();

        List candidateList = getRandomSubList(new LinkedList(thisLevel.keySet()), limitBranch);


        for(Integer i : candidateList){
            thisLevelRes.put(i, 1);
            Map childLevelRes = evaluateRecursive(i, thisLevel.get(i), k, limitPerLevel, limitBranch, maxDist - 1);
            for(Integer q: childLevelRes.keySet()){
                thisLevelRes.put(q, childLevelRes.get(q) + 1);
            }
        }
        UniversalPage originPage = upDao.getById(originId, WIKIDATA_CONCEPTS);
        for(Integer i: thisLevelRes.keySet()){
            try {
                List results = new ArrayList();
                for (Language lang : langs) {
                    MonolingualSRMetric sr = metrics.get(lang);
                    results.add(sr.similarity(originPage.getLocalId(lang), upDao.getById(i, WIKIDATA_CONCEPTS).getLocalId(lang), false));
                }
                LOG.info("Now printing " + CSVRowCounter + " From " + originId + " To " + i + " at level " + maxDist);
                writeRow(originPage, upDao.getById(i, WIKIDATA_CONCEPTS), thisLevelRes.get(i), results);
            }
            catch (Exception e){
                //do nothing
            }

        }

        if(limitPerLevel > thisLevelRes.size())
            limitPerLevel = thisLevelRes.size();

        List returnList = getRandomSubList(new LinkedList(thisLevelRes.keySet()), limitPerLevel);
        Map returnMap = new HashMap();
        for(Integer i : returnList){
            returnMap.put(i, thisLevelRes.get(i));
        }

        return returnMap;
    }
 */

    public static void main(String[] args) throws Exception {

        Env env = EnvBuilder.envFromArgs(args);
        Configurator conf = env.getConfigurator();
        KNNEvaluator evaluator = new KNNEvaluator(env, new LanguageSet("simple"));
        SpatialDataDao sdDao = conf.get(SpatialDataDao.class);
        Set originSet = new HashSet();
        originSet.add(36091);originSet.add(956);originSet.add(64);originSet.add(258);originSet.add(60);originSet.add(65);originSet.add(90);originSet.add(84);originSet.add(1490);
        evaluator.evaluate(originSet, 100, 5, 1, 30, "test-topo.csv");
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy