All downloads are free. The search and download functionality uses the official Maven repository.

mulan.evaluation.Evaluator Maven / Gradle / Ivy

Go to download

Mulan is an open-source Java library for learning from multi-label datasets.

The newest version!
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    Evaluator.java
 *    Copyright (C) 2009-2012 Aristotle University of Thessaloniki, Greece
 */
package mulan.evaluation;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.*;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Evaluator - responsible for generating evaluation data.
 * <p>
 * Offers single-pass evaluation of an already built learner on a test set
 * ({@code evaluate}) and k-fold cross-validation ({@code crossValidate}).
 * The shuffling applied before cross-validation is driven by a seed that
 * can be changed through {@link #setSeed(int)}.
 *
 * @author rofr
 * @author Grigorios Tsoumakas
 * @version 2011.09.06
 */
public class Evaluator {

    private static final Logger LOGGER = Logger.getLogger(Evaluator.class.getName());

    // seed for reproduction of cross-validation results
    private int seed = 1;

    /**
     * Sets the seed for reproduction of cross-validation results
     *
     * @param aSeed seed for reproduction of cross-validation results
     */
    public void setSeed(int aSeed) {
        seed = aSeed;
    }

    /**
     * Evaluates a {@link MultiLabelLearner} on given test data set using specified evaluation measures.
     * <p>
     * Every measure is reset before the evaluation starts. Instances for which
     * {@code data.hasMissingLabels(instance)} is true are skipped. For every
     * other instance, a copy with all label attributes set to missing is
     * handed to the learner, so the true label values are not visible during
     * prediction. A measure whose {@code update} call throws is logged once
     * and not updated for the remaining instances.
     *
     * @param learner the learner to be evaluated
     * @param data the data set for evaluation
     * @param measures the evaluation measures to compute
     * @return an Evaluation object
     * @throws IllegalArgumentException if an input parameter is null
     * @throws Exception if the learner fails to make a prediction
     */
    public Evaluation evaluate(MultiLabelLearner learner, MultiLabelInstances data, List<Measure> measures) throws IllegalArgumentException, Exception {
        checkLearner(learner);
        checkData(data);
        checkMeasures(measures);

        // reset measures
        for (Measure m : measures) {
            m.reset();
        }

        int numLabels = data.getNumLabels();
        int[] labelIndices = data.getLabelIndices();
        // measures that already threw during update; they are skipped from then on
        Set<Measure> failed = new HashSet<Measure>();
        Instances testData = data.getDataSet();
        int numInstances = testData.numInstances();
        for (int instanceIndex = 0; instanceIndex < numInstances; instanceIndex++) {
            Instance instance = testData.instance(instanceIndex);
            if (data.hasMissingLabels(instance)) {
                continue;
            }
            // hide the true labels from the learner
            Instance labelsMissing = (Instance) instance.copy();
            labelsMissing.setDataset(instance.dataset());
            for (int i = 0; i < numLabels; i++) {
                labelsMissing.setMissing(labelIndices[i]);
            }
            MultiLabelOutput output = learner.makePrediction(labelsMissing);
            boolean[] trueLabels = getTrueLabels(instance, numLabels, labelIndices);
            for (Measure m : measures) {
                if (failed.contains(m)) {
                    continue;
                }
                try {
                    m.update(output, trueLabels);
                } catch (Exception ex) {
                    // best effort: keep evaluating the other measures
                    failed.add(m);
                    LOGGER.log(Level.WARNING, "Measure " + m.getClass().getName()
                            + " failed to update and is skipped for the remaining instances.", ex);
                }
            }
        }

        return new Evaluation(measures, data);
    }

    /** Rejects a null learner. */
    private void checkLearner(MultiLabelLearner learner) {
        if (learner == null) {
            throw new IllegalArgumentException("Learner to be evaluated is null.");
        }
    }

    /** Rejects a null data set. */
    private void checkData(MultiLabelInstances data) {
        if (data == null) {
            throw new IllegalArgumentException("Evaluation data object is null.");
        }
    }

    /** Rejects a null list of measures. */
    private void checkMeasures(List<Measure> measures) {
        if (measures == null) {
            throw new IllegalArgumentException("List of evaluation measures to compute is null.");
        }
    }

    /** Rejects a fold count smaller than two. */
    private void checkFolds(int someFolds) {
        if (someFolds < 2) {
            throw new IllegalArgumentException("Number of folds must be at least two or higher.");
        }
    }

    /**
     * Evaluates a {@link MultiLabelLearner} on given test data set.
     * <p>
     * The measures to compute are chosen automatically, based on which kinds
     * of output (bipartition, ranking, confidences) a prediction of the
     * learner provides and on whether the labels form a hierarchy.
     *
     * @param learner the learner to be evaluated
     * @param data the data set for evaluation
     * @return the evaluation result
     * @throws IllegalArgumentException if either of input parameters is null.
     * @throws Exception if the learner fails to make a prediction
     */
    public Evaluation evaluate(MultiLabelLearner learner, MultiLabelInstances data) throws IllegalArgumentException, Exception {
        checkLearner(learner);
        checkData(data);

        List<Measure> measures = prepareMeasures(learner, data);

        return evaluate(learner, data, measures);
    }

    /**
     * Builds the default list of measures for the given learner and data.
     * <p>
     * A copy of the learner predicts the first instance of the data set; the
     * kinds of output present in that prediction decide which groups of
     * measures are added. If anything throws along the way (for example when
     * the data set has no instances), the exception is logged and the
     * measures collected so far are returned.
     *
     * @param learner the learner whose output types are probed
     * @param data the data set supplying the probe instance and label information
     * @return the measures applicable to the learner's output; possibly empty
     */
    private List<Measure> prepareMeasures(MultiLabelLearner learner, MultiLabelInstances data) {
        List<Measure> measures = new ArrayList<Measure>();

        MultiLabelOutput prediction;
        try {
            MultiLabelLearner copyOfLearner = learner.makeCopy();
            prediction = copyOfLearner.makePrediction(data.getDataSet().instance(0));
            int numOfLabels = data.getNumLabels();
            // add bipartition-based measures if applicable
            if (prediction.hasBipartition()) {
                // add example-based measures
                measures.add(new HammingLoss());
                measures.add(new SubsetAccuracy());
                measures.add(new ExampleBasedPrecision());
                measures.add(new ExampleBasedRecall());
                measures.add(new ExampleBasedFMeasure());
                measures.add(new ExampleBasedAccuracy());
                measures.add(new ExampleBasedSpecificity());
                // add label-based measures
                measures.add(new MicroPrecision(numOfLabels));
                measures.add(new MicroRecall(numOfLabels));
                measures.add(new MicroFMeasure(numOfLabels));
                measures.add(new MicroSpecificity(numOfLabels));
                measures.add(new MacroPrecision(numOfLabels));
                measures.add(new MacroRecall(numOfLabels));
                measures.add(new MacroFMeasure(numOfLabels));
                measures.add(new MacroSpecificity(numOfLabels));
            }
            // add ranking-based measures if applicable
            if (prediction.hasRanking()) {
                measures.add(new AveragePrecision());
                measures.add(new Coverage());
                measures.add(new OneError());
                measures.add(new IsError());
                measures.add(new ErrorSetSize());
                measures.add(new RankingLoss());
            }
            // add confidence measures if applicable
            if (prediction.hasConfidences()) {
                measures.add(new MeanAveragePrecision(numOfLabels));
                measures.add(new GeometricMeanAveragePrecision(numOfLabels));
                measures.add(new MeanAverageInterpolatedPrecision(numOfLabels, 10));
                measures.add(new GeometricMeanAverageInterpolatedPrecision(numOfLabels, 10));
                measures.add(new MicroAUC(numOfLabels));
                measures.add(new MacroAUC(numOfLabels));
            }
            // add hierarchical measures if applicable
            if (data.getLabelsMetaData().isHierarchy()) {
                measures.add(new HierarchicalLoss(data));
            }
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, "Could not prepare the default evaluation measures.", ex);
        }

        return measures;
    }

    /**
     * Reads the ground-truth label values of an instance.
     *
     * @param instance the instance to read the labels from
     * @param numLabels the number of labels
     * @param labelIndices the attribute indices of the labels
     * @return an array with {@code true} at every position whose label
     *         attribute holds the nominal value "1"
     */
    private boolean[] getTrueLabels(Instance instance, int numLabels, int[] labelIndices) {

        boolean[] trueLabels = new boolean[numLabels];
        for (int counter = 0; counter < numLabels; counter++) {
            int classIdx = labelIndices[counter];
            String classValue = instance.attribute(classIdx).value((int) instance.value(classIdx));
            trueLabels[counter] = classValue.equals("1");
        }

        return trueLabels;
    }

    /**
     * Evaluates a {@link MultiLabelLearner} via cross-validation on given data set with defined number of folds.
     * The data is shuffled using the seed set via {@link #setSeed(int)}, and the
     * measures are chosen automatically for every fold.
     *
     * @param learner the learner to be evaluated via cross-validation
     * @param data the multi-label data set for cross-validation
     * @param someFolds the number of folds; must be at least two
     * @return a {@link MultipleEvaluation} object holding the results
     * @throws IllegalArgumentException if learner or data is null, or if someFolds is less than two
     */
    public MultipleEvaluation crossValidate(MultiLabelLearner learner, MultiLabelInstances data, int someFolds) {
        checkLearner(learner);
        checkData(data);
        checkFolds(someFolds);

        return innerCrossValidate(learner, data, false, null, someFolds);
    }

    /**
     * Evaluates a {@link MultiLabelLearner} via cross-validation on given data set using given evaluation measures with defined number of folds.
     * The data is shuffled using the seed set via {@link #setSeed(int)}.
     *
     * @param learner the learner to be evaluated via cross-validation
     * @param data the multi-label data set for cross-validation
     * @param measures the evaluation measures to compute
     * @param someFolds the number of folds; must be at least two
     * @return a {@link MultipleEvaluation} object holding the results
     * @throws IllegalArgumentException if learner, data or measures is null, or if someFolds is less than two
     */
    public MultipleEvaluation crossValidate(MultiLabelLearner learner, MultiLabelInstances data, List<Measure> measures, int someFolds) {
        checkLearner(learner);
        checkData(data);
        checkMeasures(measures);
        // same validation as the overload without explicit measures
        checkFolds(someFolds);

        return innerCrossValidate(learner, data, true, measures, someFolds);
    }

    /**
     * Runs the cross-validation loop shared by both {@code crossValidate} overloads.
     * <p>
     * A copy of the data is shuffled with {@code new Random(seed)}. For each
     * fold a fresh copy of the learner is built on the training part and
     * evaluated on the test part. If a fold throws, the exception is logged
     * and the corresponding entry of the evaluation array stays {@code null}.
     *
     * @param learner the learner to copy and build for every fold
     * @param data the full multi-label data set
     * @param hasMeasures whether {@code measures} should be used; when false
     *        the measures are chosen automatically
     * @param measures the measures to compute; ignored when {@code hasMeasures} is false
     * @param someFolds the number of folds
     * @return the aggregated results, after {@code calculateStatistics()} has been called
     */
    private MultipleEvaluation innerCrossValidate(MultiLabelLearner learner, MultiLabelInstances data, boolean hasMeasures, List<Measure> measures, int someFolds) {
        Evaluation[] evaluation = new Evaluation[someFolds];

        Instances workingSet = new Instances(data.getDataSet());
        workingSet.randomize(new Random(seed));
        for (int i = 0; i < someFolds; i++) {
            System.out.println("Fold " + (i + 1) + "/" + someFolds);
            try {
                Instances train = workingSet.trainCV(someFolds, i);
                Instances test = workingSet.testCV(someFolds, i);
                MultiLabelInstances mlTrain = new MultiLabelInstances(train, data.getLabelsMetaData());
                MultiLabelInstances mlTest = new MultiLabelInstances(test, data.getLabelsMetaData());
                MultiLabelLearner clone = learner.makeCopy();
                clone.build(mlTrain);
                if (hasMeasures) {
                    evaluation[i] = evaluate(clone, mlTest, measures);
                } else {
                    evaluation[i] = evaluate(clone, mlTest);
                }
            } catch (Exception ex) {
                LOGGER.log(Level.SEVERE, "Evaluation of fold " + (i + 1) + "/" + someFolds + " failed.", ex);
            }
        }
        MultipleEvaluation me = new MultipleEvaluation(evaluation, data);
        me.calculateStatistics();
        return me;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy