All downloads are FREE. Search and download functionality uses the official Maven repository.

weka.classifiers.meta.multisearch.DefaultEvaluationTask Maven / Gradle / Ivy

There is a newer version: 2021.2.17
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * DefaultEvaluationTask.java
 * Copyright (C) 2015-2018 University of Waikato, Hamilton, NZ
 */

package weka.classifiers.meta.multisearch;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.SetupGenerator;
import weka.core.setupgenerator.Point;

import java.io.Serializable;
import java.util.Random;

/**
 * Default evaluation task: applies this task's setup values to the owner's
 * base classifier, evaluates the resulting classifier (cross-validation,
 * training set or test set) and adds the resulting performance to the
 * owner's search algorithm.
 */
public class DefaultEvaluationTask
  extends AbstractEvaluationTask {

  /**
   * Initializes the task.
   *
   * @param owner		the owning MultiSearch classifier
   * @param train		the training data
   * @param test		the test data, can be null
   * @param generator		the generator to use
   * @param values		the setup values
   * @param folds		the number of cross-validation folds
   * @param eval		the type of evaluation
   * @param classLabel		the class label index (0-based; if applicable)
   */
  public DefaultEvaluationTask(
    MultiSearchCapable owner, Instances train, Instances test,
    SetupGenerator generator, Point values, int folds, int eval, int classLabel) {
    super(owner, train, test, generator, values, folds, eval, classLabel);
  }

  /**
   * Returns whether predictions can be discarded (depends on selected measure).
   * Predictions are kept when AUC or PRC is the owner's selected measure,
   * presumably because those measures are computed from the predictions.
   *
   * @return		true if the Evaluation may discard its predictions
   */
  protected boolean canDiscardPredictions() {
    // NOTE(review): if DefaultEvaluationMetrics also offers weighted AUC/PRC
    // measures, they most likely need the predictions as well - confirm.
    switch (m_Owner.getEvaluation().getSelectedTag().getID()) {
      case DefaultEvaluationMetrics.EVALUATION_AUC:
      case DefaultEvaluationMetrics.EVALUATION_PRC:
        return false;
      default:
        return true;
    }
  }

  /**
   * Performs the evaluation: configures the classifier with this task's setup
   * values, evaluates it and adds the resulting performance to the owner's
   * algorithm.
   * <br>
   * Evaluation scheme:
   * <ul>
   *   <li>no test set, at least 2 folds: cross-validation on the training data</li>
   *   <li>no test set, fewer than 2 folds: train and evaluate on the training data</li>
   *   <li>test set available: train on the training data, evaluate on the test set</li>
   * </ul>
   * If configuring the classifier fails, no performance is recorded. If the
   * evaluation fails, a performance is still recorded, but its evaluation
   * wrapper is created from a null Evaluation.
   *
   * @return		true if the evaluation completed, false if configuring
   * 			or evaluating the classifier failed (the cause is
   * 			stored in m_Exception)
   * @throws Exception	if an error occurs outside the guarded setup and
   * 			evaluation steps, e.g., while recording the performance
   */
  protected Boolean doRun() throws Exception {
    Point<Object>	evals;
    Evaluation 		eval;
    Classifier 		classifier;
    Performance		performance;
    boolean		completed;

    // setup
    try {
      evals = m_Generator.evaluate(m_Values);
      classifier = newConfiguredClassifier(evals);
    }
    catch (Exception e) {
      m_Exception = e;
      System.err.println("Failed to configure classifier!");
      e.printStackTrace();
      return false;
    }

    // evaluate
    try {
      eval = new Evaluation(m_Train);
      eval.setDiscardPredictions(canDiscardPredictions());
      if (m_Test == null) {
        if (m_Folds >= 2) {
          eval.crossValidateModel(classifier, m_Train, m_Folds, new Random(m_Owner.getSeed()));
        }
        else {
          // no test set and no cross-validation: evaluate on the training data
          classifier.buildClassifier(m_Train);
          eval.evaluateModel(classifier, m_Train);
        }
      }
      else {
        classifier.buildClassifier(m_Train);
        eval.evaluateModel(classifier, m_Test);
      }
      completed = true;
    }
    catch (Exception e) {
      // the failed setup is still recorded below, with a null evaluation
      eval = null;
      m_Exception = e;
      System.err.println("Encountered exception while evaluating classifier, skipping!");
      System.err.println("- Classifier: " + m_Owner.getCommandline(classifier));
      e.printStackTrace();
      completed = false;
    }

    // store performance; it receives a freshly configured classifier rather
    // than 'classifier', which may already have been built on the training
    // data above (presumably so that no trained model is retained)
    performance = new Performance(
      m_Values,
      m_Owner.getFactory().newWrapper(eval),
      m_Evaluation,
      m_ClassLabel,
      newConfiguredClassifier(evals));
    m_Owner.getAlgorithm().addPerformance(performance, m_Folds);

    // log
    m_Owner.log(performance + ": cached=false");

    return completed;
  }

  /**
   * Returns the owner's base classifier as configured by the setup generator
   * with the given setup values.
   *
   * @param evals	the evaluated setup values
   * @return		the configured classifier
   * @throws Exception	if the generator fails to set up the classifier
   */
  private Classifier newConfiguredClassifier(Point<Object> evals) throws Exception {
    return (Classifier) m_Generator.setup((Serializable) m_Owner.getClassifier(), evals);
  }
}