All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.evaluation.EvaluationMetricHelper Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    EvaluationMetricHelper.java
 *    Copyright (C) 2014 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.evaluation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Helper routines for extracting metric values from built-in and plugin
 * evaluation metrics.
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 10934 $
 */
public class EvaluationMetricHelper {

  /** The Evaluation object to extract built-in and plugin metrics from */
  protected Evaluation m_eval;

  /**
   * A lookup for built-in metrics: normalised (lower-case) metric name to its
   * index in {@code Evaluation.BUILT_IN_EVAL_METRICS}
   */
  protected Map<String, Integer> m_builtin = new HashMap<String, Integer>();

  /**
   * A lookup for plugin metrics: normalised (lower-case) statistic name to the
   * plugin metric that provides it
   */
  protected Map<String, AbstractEvaluationMetric> m_pluginMetrics =
    new HashMap<String, AbstractEvaluationMetric>();

  /**
   * Construct a new EvaluationMetricHelper
   * 
   * @param eval the Evaluation object to use
   */
  public EvaluationMetricHelper(Evaluation eval) {
    for (int i = 0; i < Evaluation.BUILT_IN_EVAL_METRICS.length; i++) {
      m_builtin.put(normalize(Evaluation.BUILT_IN_EVAL_METRICS[i]), i);
    }

    setEvaluation(eval);
  }

  /**
   * Normalises a metric/statistic name for use as a lookup key. The
   * locale-independent {@code Locale.ROOT} is used so that keys do not depend
   * on the JVM's default locale (e.g. under a Turkish locale an upper-case
   * "I" would otherwise be lower-cased to a dotless "i" and no longer match
   * the plain ASCII spelling).
   * 
   * @param name the metric/statistic name
   * @return the lower-cased name
   */
  private static String normalize(String name) {
    return name.toLowerCase(Locale.ROOT);
  }

  /**
   * Sets the Evaluation object to use and rebuilds the plugin metric lookup
   * from it
   * 
   * @param eval the Evaluation object to use
   */
  public void setEvaluation(Evaluation eval) {
    m_eval = eval;
    initializeWithPluginMetrics();
  }

  /**
   * Initializes the plugin lookup from the plugin metrics of the current
   * Evaluation object. Any previous entries are discarded first.
   */
  protected void initializeWithPluginMetrics() {
    m_pluginMetrics.clear();
    List<AbstractEvaluationMetric> pluginMetrics = m_eval.getPluginMetrics();
    if (pluginMetrics != null && pluginMetrics.size() > 0) {
      for (AbstractEvaluationMetric m : pluginMetrics) {
        List<String> statNames = m.getStatisticNames();
        for (String s : statNames) {
          // every statistic name of a plugin maps back to that plugin
          m_pluginMetrics.put(normalize(s), m);
        }
      }
    }
  }

  /**
   * Get a list of built-in metric names
   * 
   * @return a new, modifiable list of built-in metric names (as spelled in
   *         {@code Evaluation.BUILT_IN_EVAL_METRICS})
   */
  public static List<String> getBuiltInMetricNames() {
    List<String> builtIn = new ArrayList<String>();
    builtIn.addAll(Arrays.asList(Evaluation.BUILT_IN_EVAL_METRICS));

    return builtIn;
  }

  /**
   * Get a list of plugin metric names
   * 
   * @return a new list of the (lower-cased) statistic names of all plugin
   *         metrics; empty if there are no plugin metrics
   */
  public static List<String> getPluginMetricNames() {
    List<String> pluginNames = new ArrayList<String>();
    List<AbstractEvaluationMetric> pluginMetrics =
      AbstractEvaluationMetric.getPluginMetrics();

    if (pluginMetrics != null) {
      for (AbstractEvaluationMetric m : pluginMetrics) {
        List<String> statNames = m.getStatisticNames();
        for (String s : statNames) {
          pluginNames.add(normalize(s));
        }
      }
    }

    return pluginNames;
  }

  /**
   * Get a list of all available evaluation metric names
   * 
   * @return a list of all available evaluation metric names: the built-in
   *         names followed by the plugin names
   */
  public static List<String> getAllMetricNames() {
    List<String> metrics = getBuiltInMetricNames();
    metrics.addAll(getPluginMetricNames());

    return metrics;
  }

  /**
   * Returns true if the specified built-in metric is maximisable
   * 
   * @param metricIndex the index of metric
   * @return true if the metric in question is optimum at a maximal value
   * @throws Exception if the metric is not a known built-in metric
   */
  protected boolean builtInMetricIsMaximisable(int metricIndex)
    throws Exception {
    switch (metricIndex) {
    case 0: // correct
    case 2: // kappa
    case 7: // correlation
    case 10: // SF entropy gain
    case 15: // coverage of cases by predicted regions
    case 17: // TPR
    case 19: // precision
    case 20: // recall
    case 21: // f-measure
    case 22: // Matthews correlation
    case 23: // AUC
    case 24: // AUPRC
      return true;
    case 1: // incorrect
    case 3: // total cost
    case 4: // avg cost
    case 5: // KB relative info
    case 6: // KB info
    case 8: // SF prior entropy
    case 9: // SF scheme entropy
    case 11: // MAE
    case 12: // RMSE
    case 13: // RAE
    case 14: // RRSE
    case 16: // size of predicted regions
    case 18: // FPR
      return false;
    default:
      throw new Exception("Unknown built-in metric");
    }
  }

  /**
   * Gets the value of a built-in metric. For the per-class metrics (indices 17
   * to 24) the value for the given class is returned only if exactly one class
   * value index is supplied; otherwise the class-weighted variant is returned.
   * 
   * @param metricIndex the index of the metric
   * @param classValIndex the optional class value index
   * @return the value of the metric
   * @throws Exception if the metric is not a known built-in metric
   */
  protected double getBuiltinMetricValue(int metricIndex, int... classValIndex)
    throws Exception {

    boolean hasValIndex = classValIndex != null && classValIndex.length == 1;

    switch (metricIndex) {
    case 0:
      return m_eval.correct();
    case 1:
      return m_eval.incorrect();
    case 2:
      return m_eval.kappa();
    case 3:
      return m_eval.totalCost();
    case 4:
      return m_eval.avgCost();
    case 5:
      return m_eval.KBRelativeInformation();
    case 6:
      return m_eval.KBInformation();
    case 7:
      return m_eval.correlationCoefficient();
    case 8:
      return m_eval.SFPriorEntropy();
    case 9:
      return m_eval.SFSchemeEntropy();
    case 10:
      return m_eval.SFEntropyGain();
    case 11:
      return m_eval.meanAbsoluteError();
    case 12:
      return m_eval.rootMeanSquaredError();
    case 13:
      return m_eval.relativeAbsoluteError();
    case 14:
      return m_eval.rootRelativeSquaredError();
    case 15:
      return m_eval.coverageOfTestCasesByPredictedRegions();
    case 16:
      return m_eval.sizeOfPredictedRegions();
    case 17:
      return hasValIndex ? m_eval.truePositiveRate(classValIndex[0]) : m_eval
        .weightedTruePositiveRate();
    case 18:
      return hasValIndex ? m_eval.falsePositiveRate(classValIndex[0]) : m_eval
        .weightedFalsePositiveRate();
    case 19:
      return hasValIndex ? m_eval.precision(classValIndex[0]) : m_eval
        .weightedPrecision();
    case 20:
      return hasValIndex ? m_eval.recall(classValIndex[0]) : m_eval
        .weightedRecall();
    case 21:
      return hasValIndex ? m_eval.fMeasure(classValIndex[0]) : m_eval
        .weightedFMeasure();
    case 22:
      return hasValIndex ? m_eval
        .matthewsCorrelationCoefficient(classValIndex[0]) : m_eval
        .weightedMatthewsCorrelation();
    case 23:
      return hasValIndex ? m_eval.areaUnderROC(classValIndex[0]) : m_eval
        .weightedAreaUnderROC();
    case 24:
      return hasValIndex ? m_eval.areaUnderPRC(classValIndex[0]) : m_eval
        .weightedAreaUnderPRC();
    default:
      throw new Exception("Unknown built-in metric");
    }
  }

  /**
   * Get the value of a plugin metric. For information retrieval metrics the
   * per-class statistic is returned if exactly one class value index is
   * supplied, otherwise the class-weighted average is returned. All other
   * metrics ignore the class value index.
   * 
   * @param m the metric to get the value from
   * @param statName the name of the statistic to get the value of
   * @param classValIndex the optional class value index
   * @return the value of the metric
   * @throws Exception if a problem occurs
   */
  protected double getPluginMetricValue(AbstractEvaluationMetric m,
    String statName, int... classValIndex) throws Exception {

    boolean hasValIndex = classValIndex != null && classValIndex.length == 1;

    if (m instanceof InformationRetrievalEvaluationMetric) {
      InformationRetrievalEvaluationMetric irMetric =
        (InformationRetrievalEvaluationMetric) m;
      return hasValIndex ? irMetric.getStatistic(statName, classValIndex[0])
        : irMetric.getClassWeightedAverageStatistic(statName);
    }

    return m.getStatistic(statName);
  }

  /**
   * Returns true if the named statistic is maximisable
   * 
   * @param m the metric to check
   * @param statName the name of the statistic to check
   * @return true if the metric in question is optimum at a maximal value
   */
  protected boolean pluginMetricIsMaximisable(AbstractEvaluationMetric m,
    String statName) {
    return m.statisticIsMaximisable(statName);
  }

  /**
   * Gets the value of a named metric. For information retrieval metrics if a
   * class value index is not supplied then the class weighted variant is
   * returned. Built-in metric names take precedence over plugin statistic
   * names; the lookup is case-insensitive.
   * 
   * @param statName the name of the metric/statistic to get
   * @param classValIndex the optional class value index
   * @return the value of the metric
   * @throws IllegalArgumentException if more than one class value index is
   *           supplied
   * @throws Exception if the metric/stat is unknown or a problem occurs
   */
  public double getNamedMetric(String statName, int... classValIndex)
    throws Exception {

    if (classValIndex != null && classValIndex.length > 1) {
      throw new IllegalArgumentException(
        "Only one class value index should be supplied");
    }

    String key = normalize(statName);
    Integer builtinIndex = m_builtin.get(key);
    if (builtinIndex != null) {
      return getBuiltinMetricValue(builtinIndex.intValue(), classValIndex);
    }

    AbstractEvaluationMetric m = m_pluginMetrics.get(key);
    if (m == null) {
      throw new Exception("Unknown evaluation metric: " + statName);
    }
    // the plugin is queried with the name exactly as supplied by the caller
    return getPluginMetricValue(m, statName, classValIndex);
  }

  /**
   * Gets the thresholds produced by the metric, if the metric implements
   * ThresholdProducingMetric.
   * 
   * @param statName the name of the metric/statistic to get
   * @return the thresholds, null if metric does not produce any
   * @throws Exception if the metric/stat is unknown or a problem occurs
   */
  public double[] getNamedMetricThresholds(String statName)
    throws Exception {

    String key = normalize(statName);
    if (m_builtin.get(key) != null) {
      return null; // built-in metrics do not produce thresholds
    }

    AbstractEvaluationMetric m = m_pluginMetrics.get(key);
    if (m == null) {
      throw new Exception("Unknown evaluation metric: " + statName);
    }
    if (m instanceof ThresholdProducingMetric) {
      return ((ThresholdProducingMetric) m).getThresholds();
    }
    return null;
  }

  /**
   * Returns true if the named metric is maximisable. Built-in metric names
   * take precedence over plugin statistic names; the lookup is
   * case-insensitive.
   * 
   * @param statName the name of the metric/statistic to check
   * @return true if the metric in question is optimum at a maximal value
   * @throws Exception if the metric/stat is unknown or a problem occurs
   */
  public boolean metricIsMaximisable(String statName) throws Exception {
    String key = normalize(statName);
    Integer builtinIndex = m_builtin.get(key);
    if (builtinIndex != null) {
      return builtInMetricIsMaximisable(builtinIndex.intValue());
    }

    AbstractEvaluationMetric m = m_pluginMetrics.get(key);
    if (m == null) {
      throw new Exception("Unknown evaluation metric: " + statName);
    }
    return pluginMetricIsMaximisable(m, statName);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy