All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.evaluation.Evaluation Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

The newest version!
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    Evaluation.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.evaluation;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.ConditionalDensityEstimator;
import weka.classifiers.CostMatrix;
import weka.classifiers.IntervalEstimator;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.Sourcable;
import weka.classifiers.UpdateableBatchProcessor;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
import weka.classifiers.evaluation.output.prediction.PlainText;
import weka.classifiers.misc.InputMappedClassifier;
import weka.core.BatchPredictor;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Summarizable;
import weka.core.Utils;
import weka.core.Version;
import weka.core.converters.ConverterUtils.DataSink;
import weka.core.converters.ConverterUtils.DataSource;
import weka.estimators.UnivariateKernelEstimator;

import java.beans.BeanInfo;
import java.beans.Introspector;
import java.beans.MethodDescriptor;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Reader;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/**
 * Class for evaluating machine learning models.
 * 

* * ------------------------------------------------------------------- *

* * General options when evaluating a learning scheme from the command-line: *

* * -t filename
* Name of the file with the training data. (required) *

* * -T filename
* Name of the file with the test data. If missing a cross-validation is * performed. *

* * -c index
* Index of the class attribute (1, 2, ...; default: last). *

* * -x number
* The number of folds for the cross-validation (default: 10). *

* * -no-cv
* No cross validation. If no test file is provided, no evaluation is done. *

* * -split-percentage percentage
* Sets the percentage for the train/test set split, e.g., 66. *

* * -preserve-order
* Preserves the order in the percentage split instead of randomizing the data * first with the seed value ('-s'). *

* * -s seed
* Random number seed for the cross-validation and percentage split (default: * 1). *

* * -m filename
* The name of a file containing a cost matrix. *

* * -disable list
* A comma separated list of metric names not to include in the output. *

* * -l filename
* Loads classifier from the given file. In case the filename ends with ".xml", * a PMML file is loaded or, if that fails, options are loaded from XML. *

* * -d filename
* Saves classifier built from the training data into the given file. In case * the filename ends with ".xml" the options are saved XML, not the model. *

* * -v
* Outputs no statistics for the training data. *

* * -o
* Outputs statistics only, not the classifier. *

* * -output-models-for-training-splits
* Output models for training splits if cross-validation or percentage-split evaluation is used. *

* * -do-not-output-per-class-statistics
* Do not output statistics per class. *

* * -k
* Outputs information-theoretic statistics. *

* * -classifications * "weka.classifiers.evaluation.output.prediction.AbstractOutput + options"
* Uses the specified class for generating the classification output. E.g.: * weka.classifiers.evaluation.output.prediction.PlainText or * weka.classifiers.evaluation.output.prediction.CSV *

* * -p range
* Outputs predictions for test instances (or the train instances if no test * instances provided and -no-cv is used), along with the attributes in the * specified range (and nothing else). Use '-p 0' if no attributes are desired. *

* Deprecated: use "-classifications ..." instead. *

* * -distribution
* Outputs the distribution instead of only the prediction in conjunction with * the '-p' option (only nominal classes). *

* Deprecated: use "-classifications ..." instead. *

* * -no-predictions
* Turns off the collection of predictions in order to conserve memory. *

* * -r
* Outputs cumulative margin distribution (and nothing else). *

* * -g
* Only for classifiers that implement "Graphable." Outputs the graph * representation of the classifier (and nothing else). *

* * -xml filename | xml-string
* Retrieves the options from the XML-data instead of the command line. *

* * -threshold-file file
* The file to save the threshold data to. The format is determined by the * extensions, e.g., '.arff' for ARFF format or '.csv' for CSV. *

* * -threshold-label label
* The class label to determine the threshold data for (default is the first * label) *

* * ------------------------------------------------------------------- *

* * Example usage as the main of a classifier (called FunkyClassifier): *

 * public static void main(String [] args) {
 *   runClassifier(new FunkyClassifier(), args);
 * }
 * 
*

* * ------------------------------------------------------------------ *

* * Example usage from within an application:

 * Instances trainInstances = ... instances got from somewhere
 * Instances testInstances = ... instances got from somewhere
 * Classifier scheme = ... scheme got from somewhere
 *
 * Evaluation evaluation = new Evaluation(trainInstances);
 * evaluation.evaluateModel(scheme, testInstances);
 * System.out.println(evaluation.toSummaryString());
 * 
*
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 15278 $
 */
public class Evaluation implements Summarizable, RevisionHandler, Serializable {

  /**
   * For serialization
   */
  private static final long serialVersionUID = -7010314486866816271L;

  /**
   * The number of classes.
   */
  protected int m_NumClasses;

  /**
   * The number of folds for a cross-validation.
   */
  protected int m_NumFolds;

  /**
   * The weight of all incorrectly classified instances.
   */
  protected double m_Incorrect;

  /**
   * The weight of all correctly classified instances.
   */
  protected double m_Correct;

  /**
   * The weight of all unclassified instances.
   */
  protected double m_Unclassified;

  /**
   * The weight of all instances that had no class assigned to them.
   */
  protected double m_MissingClass;

  /**
   * The weight of all instances that had a class assigned to them.
   */
  protected double m_WithClass;

  /**
   * Array for storing the confusion matrix.
   */
  protected double[][] m_ConfusionMatrix;

  /**
   * The names of the classes.
   */
  protected String[] m_ClassNames;

  /**
   * Is the class nominal or numeric?
   */
  protected boolean m_ClassIsNominal;

  /**
   * The prior probabilities of the classes.
   */
  protected double[] m_ClassPriors;

  /**
   * The sum of counts for priors.
   */
  protected double m_ClassPriorsSum;

  /**
   * The cost matrix (if given).
   */
  protected CostMatrix m_CostMatrix;

  /**
   * The total cost of predictions (includes instance weights).
   */
  protected double m_TotalCost;

  /**
   * Sum of errors.
   */
  protected double m_SumErr;

  /**
   * Sum of absolute errors.
   */
  protected double m_SumAbsErr;

  /**
   * Sum of squared errors.
   */
  protected double m_SumSqrErr;

  /**
   * Sum of class values.
   */
  protected double m_SumClass;

  /**
   * Sum of squared class values.
   */
  protected double m_SumSqrClass;

  /**
   * Sum of predicted values.
   */
  protected double m_SumPredicted;

  /**
   * Sum of squared predicted values.
   */
  protected double m_SumSqrPredicted;

  /**
   * Sum of predicted * class values.
*/
  protected double m_SumClassPredicted;

  /**
   * Sum of absolute errors of the prior.
   */
  protected double m_SumPriorAbsErr;

  /**
   * Sum of squared errors of the prior.
   */
  protected double m_SumPriorSqrErr;

  /**
   * Total Kononenko &amp; Bratko Information.
   */
  protected double m_SumKBInfo;

  /**
   * Resolution of the margin histogram.
   */
  protected static int k_MarginResolution = 500;

  /**
   * Cumulative margin distribution.
   */
  protected double m_MarginCounts[];

  /**
   * Number of non-missing class training instances seen.
   */
  protected int m_NumTrainClassVals;

  /**
   * Array containing all numeric training class values seen.
   */
  protected double[] m_TrainClassVals;

  /**
   * Array containing all numeric training class weights.
   */
  protected double[] m_TrainClassWeights;

  /**
   * Numeric class estimator for prior.
   */
  protected UnivariateKernelEstimator m_PriorEstimator;

  /**
   * Whether complexity statistics are available.
   */
  protected boolean m_ComplexityStatisticsAvailable = true;

  /**
   * The minimum probability accepted from an estimator to avoid taking log(0)
   * in Sf calculations.
   */
  protected static final double MIN_SF_PROB = Double.MIN_VALUE;

  /**
   * Total entropy of prior predictions.
   */
  protected double m_SumPriorEntropy;

  /**
   * Total entropy of scheme predictions.
   */
  protected double m_SumSchemeEntropy;

  /**
   * Whether coverage statistics are available.
   */
  protected boolean m_CoverageStatisticsAvailable = true;

  /**
   * The confidence level used for coverage statistics.
   */
  protected double m_ConfLevel = 0.95;

  /**
   * Total size of predicted regions at the given confidence level.
   */
  protected double m_TotalSizeOfRegions;

  /**
   * Total coverage of test cases at the given confidence level.
   */
  protected double m_TotalCoverage;

  /**
   * Minimum target value.
   */
  protected double m_MinTarget;

  /**
   * Maximum target value.
   */
  protected double m_MaxTarget;

  /**
   * The list of predictions that have been generated (for computing AUC).
*/ protected ArrayList m_Predictions; /** * enables/disables the use of priors, e.g., if no training set is present in * case of de-serialized schemes. */ protected boolean m_NoPriors = false; /** * The header of the training set. */ protected Instances m_Header; /** * whether to discard predictions (and save memory). */ protected boolean m_DiscardPredictions; /** * The list of metrics to display in the output */ protected List m_metricsToDisplay = new ArrayList(); public static final String[] BUILT_IN_EVAL_METRICS = {"Correct", "Incorrect", "Kappa", "Total cost", "Average cost", "KB relative", "KB information", "Correlation", "Complexity 0", "Complexity scheme", "Complexity improvement", "MAE", "RMSE", "RAE", "RRSE", "Coverage", "Region size", "TP rate", "FP rate", "Precision", "Recall", "F-measure", "MCC", "ROC area", "PRC area"}; /** * Utility method to get a list of the names of all built-in and plugin * evaluation metrics * * @return the complete list of available evaluation metrics */ public static List getAllEvaluationMetricNames() { List allEvals = new ArrayList(); for (String s : Evaluation.BUILT_IN_EVAL_METRICS) { allEvals.add(s); } return allEvals; } /** * Initializes all the counters for the evaluation. Use * useNoPriors() if the dataset is the test set and you can't * initialize with the priors from the training set via * setPriors(Instances). * * @param data set of training instances, to get some header information and * prior class distribution information * @throws Exception if the class is not defined * @see #useNoPriors() * @see #setPriors(Instances) */ public Evaluation(Instances data) throws Exception { this(data, null); } /** * Initializes all the counters for the evaluation and also takes a cost * matrix as parameter. Use useNoPriors() if the dataset is the * test set and you can't initialize with the priors from the training set via * setPriors(Instances). 
* * @param data set of training instances, to get some header information and * prior class distribution information * @param costMatrix the cost matrix---if null, default costs will be used * @throws Exception if cost matrix is not compatible with data, the class is * not defined or the class is numeric * @see #useNoPriors() * @see #setPriors(Instances) */ public Evaluation(Instances data, CostMatrix costMatrix) throws Exception { m_Header = new Instances(data, 0); m_NumClasses = data.numClasses(); m_NumFolds = 1; m_ClassIsNominal = data.classAttribute().isNominal(); if (m_ClassIsNominal) { m_ConfusionMatrix = new double[m_NumClasses][m_NumClasses]; m_ClassNames = new String[m_NumClasses]; for (int i = 0; i < m_NumClasses; i++) { m_ClassNames[i] = data.classAttribute().value(i); } } m_CostMatrix = costMatrix; if (m_CostMatrix != null) { if (!m_ClassIsNominal) { throw new Exception("Class has to be nominal if cost matrix given!"); } if (m_CostMatrix.size() != m_NumClasses) { throw new Exception("Cost matrix not compatible with data!"); } } m_ClassPriors = new double[m_NumClasses]; setPriors(data); m_MarginCounts = new double[k_MarginResolution + 1]; for (String s : BUILT_IN_EVAL_METRICS) { if (!s.equalsIgnoreCase("Coverage") && !s.equalsIgnoreCase("Region size")) { m_metricsToDisplay.add(s.toLowerCase()); } } } /** * Returns the header of the underlying dataset. * * @return the header information */ public Instances getHeader() { return m_Header; } /** * Sets whether to discard predictions, ie, not storing them for future * reference via predictions() method in order to conserve memory. * * @param value true if to discard the predictions * @see #predictions() */ public void setDiscardPredictions(boolean value) { m_DiscardPredictions = value; if (m_DiscardPredictions) { m_Predictions = null; } } /** * Returns whether predictions are not recorded at all, in order to conserve * memory. 
* * @return true if predictions are not recorded * @see #predictions() */ public boolean getDiscardPredictions() { return m_DiscardPredictions; } /** * Set a list of the names of metrics to have appear in the output. The * default is to display all built in metrics and plugin metrics that haven't * been globally disabled. * * @param display a list of metric names to have appear in the output */ public void setMetricsToDisplay(List display) { // make sure all metric names are lower case for matching m_metricsToDisplay.clear(); for (String s : display) { m_metricsToDisplay.add(s.trim().toLowerCase()); } } /** * Get a list of the names of metrics to have appear in the output The default * is to display all built in metrics and plugin metrics that haven't been * globally disabled. * * @return a list of metric names to have appear in the output */ public List getMetricsToDisplay() { return m_metricsToDisplay; } /** * Toggle the output of the metrics specified in the supplied list. * * @param metricsToToggle a list of metrics to toggle */ public void toggleEvalMetrics(List metricsToToggle) { for (String s : metricsToToggle) { if (m_metricsToDisplay.contains(s.toLowerCase())) { m_metricsToDisplay.remove(s.toLowerCase()); } else { m_metricsToDisplay.add(s.toLowerCase()); } } } /** * Returns the area under ROC for those predictions that have been collected * in the evaluateClassifier(Classifier, Instances) method. Returns * Utils.missingValue() if the area is not available. * * @param classIndex the index of the class to consider as "positive" * @return the area under the ROC curve or not a number */ public double areaUnderROC(int classIndex) { // Check if any predictions have been collected if (m_Predictions == null) { return Utils.missingValue(); } else { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_Predictions, classIndex); return ThresholdCurve.getROCArea(result); } } /** * Calculates the weighted (by class size) AUROC. 
* * Unclassified instances are not included in the calculation. * * @return the weighted AUC. */ public double weightedAreaUnderROC() { double[] trueClassCounts = trueClassCounts(); double aucTotal = 0; for (int i = 0; i < m_NumClasses; i++) { if (trueClassCounts[i] > 0) { // If AUROC is NaN, we want the sum to also be NaN if count > 0 aucTotal += areaUnderROC(i) * trueClassCounts[i]; } } return aucTotal / numClassified(); } /** * Returns the area under precision-recall curve (AUPRC) for those predictions * that have been collected in the evaluateClassifier(Classifier, Instances) * method. Returns Utils.missingValue() if the area is not available. * * @param classIndex the index of the class to consider as "positive" * @return the area under the precision-recall curve or not a number */ public double areaUnderPRC(int classIndex) { // Check if any predictions have been collected if (m_Predictions == null) { return Utils.missingValue(); } else { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_Predictions, classIndex); return ThresholdCurve.getPRCArea(result); } } /** * Calculates the weighted (by class size) AUPRC. * * Unclassified instances are not included in the calculation. * * @return the weighted AUPRC. */ public double weightedAreaUnderPRC() { double[] trueClassCounts = trueClassCounts(); double aucTotal = 0; for (int i = 0; i < m_NumClasses; i++) { if (trueClassCounts[i] > 0) { // If AUPRC is NaN, we want the sum to also be NaN if count > 0 aucTotal += areaUnderPRC(i) * trueClassCounts[i]; } } return aucTotal / numClassified(); } /** * Returns a copy of the confusion matrix. 
*
   * @return a copy of the confusion matrix as a two-dimensional array
   */
  public double[][] confusionMatrix() {
    final int numRows = m_ConfusionMatrix.length;
    double[][] copy = new double[numRows][];
    // Clone row by row so the caller cannot modify the internal matrix.
    for (int row = 0; row < numRows; row++) {
      copy[row] = m_ConfusionMatrix[row].clone();
    }
    return copy;
  }

  /**
   * Runs a cross-validation (stratified if the class is nominal) for a
   * classifier on a set of instances, without producing any prediction or
   * model output. Delegates to the variant that takes additional output
   * arguments, passing none.
   *
   * @param classifier the classifier with any options set.
   * @param data the data on which the cross-validation is to be performed
   * @param numFolds the number of folds for the cross-validation
   * @param random random number generator for randomization
   * @throws Exception if a classifier could not be generated successfully or
   *           the class is not defined
   */
  public void crossValidateModel(Classifier classifier, Instances data,
    int numFolds, Random random) throws Exception {
    Object[] noOutputArguments = new Object[0];
    crossValidateModel(classifier, data, numFolds, random, noOutputArguments);
  }

  /**
   * Performs a (stratified if class is nominal) cross-validation for a
   * classifier on a set of instances. Performs a deep copy of the
   * classifier before each call to buildClassifier() (just in case the
   * classifier is not initialized properly).
   *
   * @param classifier the classifier with any options set.
* @param data the data on which the cross-validation is to be performed * @param numFolds the number of folds for the cross-validation * @param random random number generator for randomization * @param forPrinting varargs parameter that, if supplied, is * expected to hold a * weka.classifiers.evaluation.output.prediction.AbstractOutput * object or a StringBuffer for model output * @throws Exception if a classifier could not be generated successfully or * the class is not defined */ public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random, Object... forPrinting) throws Exception { // Make a copy of the data we can reorder data = new Instances(data); data.randomize(random); if (data.classAttribute().isNominal()) { data.stratify(numFolds); } // We assume that the first element is a // weka.classifiers.evaluation.output.prediction.AbstractOutput object AbstractOutput classificationOutput = null; if (forPrinting.length > 0 && forPrinting[0] instanceof AbstractOutput) { // print the header first classificationOutput = (AbstractOutput) forPrinting[0]; classificationOutput.setHeader(data); classificationOutput.printHeader(); } // Do the folds for (int i = 0; i < numFolds; i++) { Instances train = data.trainCV(numFolds, i, random); setPriors(train); Classifier copiedClassifier = AbstractClassifier.makeCopy(classifier); copiedClassifier.buildClassifier(train); if (classificationOutput == null && forPrinting.length > 0) { ((StringBuffer)forPrinting[0]).append("\n=== Classifier model (training fold " + (i + 1) +") ===\n\n" + copiedClassifier); } Instances test = data.testCV(numFolds, i); if (classificationOutput != null){ evaluateModel(copiedClassifier, test, forPrinting); } else { evaluateModel(copiedClassifier, test); } } m_NumFolds = numFolds; if (classificationOutput != null) { classificationOutput.printFooter(); } } /** * Performs a (stratified if class is nominal) cross-validation for a * classifier on a set of instances. 
* * @param classifierString a string naming the class of the classifier * @param data the data on which the cross-validation is to be performed * @param numFolds the number of folds for the cross-validation * @param options the options to the classifier. Any options * @param random the random number generator for randomizing the data accepted * by the classifier will be removed from this array. * @throws Exception if a classifier could not be generated successfully or * the class is not defined */ public void crossValidateModel(String classifierString, Instances data, int numFolds, String[] options, Random random) throws Exception { crossValidateModel(AbstractClassifier.forName(classifierString, options), data, numFolds, random); } /** * Evaluates a classifier with the options given in an array of strings. *

*

* Valid options are: *

*

* -t filename
* Name of the file with the training data. (required) *

*

* -T filename
* Name of the file with the test data. If missing a cross-validation is * performed. *

*

* -c index
* Index of the class attribute (1, 2, ...; default: last). *

*

* -x number
* The number of folds for the cross-validation (default: 10). *

*

* -no-cv
* No cross validation. If no test file is provided, no evaluation is done. *

*

* -split-percentage percentage
* Sets the percentage for the train/test set split, e.g., 66. *

*

* -preserve-order
* Preserves the order in the percentage split instead of randomizing the data * first with the seed value ('-s'). *

*

* -s seed
* Random number seed for the cross-validation and percentage split (default: * 1). *

*

* -m filename
* The name of a file containing a cost matrix. *

*

* -l filename
* Loads classifier from the given file. In case the filename ends with * ".xml",a PMML file is loaded or, if that fails, options are loaded from * XML. *

*

* -d filename
* Saves classifier built from the training data into the given file. In case * the filename ends with ".xml" the options are saved XML, not the model. *

*

* -v
* Outputs no statistics for the training data. *

*

* -o
* Outputs statistics only, not the classifier. *

*

* -output-models-for-training-splits
* Output models for training splits if cross-validation or percentage-split evaluation is used. *

*

* -do-not-output-per-class-statistics
* Do not output statistics per class. *

*

* -k
* Outputs information-theoretic statistics. *

*

* -classifications * "weka.classifiers.evaluation.output.prediction.AbstractOutput + options"
* Uses the specified class for generating the classification output. E.g.: * weka.classifiers.evaluation.output.prediction.PlainText or : * weka.classifiers.evaluation.output.prediction.CSV *

* -p range
* Outputs predictions for test instances (or the train instances if no test * instances provided and -no-cv is used), along with the attributes in the * specified range (and nothing else). Use '-p 0' if no attributes are * desired. *

* Deprecated: use "-classifications ..." instead. *

*

* -distribution
* Outputs the distribution instead of only the prediction in conjunction with * the '-p' option (only nominal classes). *

* Deprecated: use "-classifications ..." instead. *

*

* -no-predictions
* Turns off the collection of predictions in order to conserve memory. *

*

* -r
* Outputs cumulative margin distribution (and nothing else). *

*

* -g
* Only for classifiers that implement "Graphable." Outputs the graph * representation of the classifier (and nothing else). *

*

* -xml filename | xml-string
* Retrieves the options from the XML-data instead of the command line. *

*

* -threshold-file file
* The file to save the threshold data to. The format is determined by the * extensions, e.g., '.arff' for ARFF format or '.csv' for CSV. *

*

* -threshold-label label
* The class label to determine the threshold data for (default is the first * label) *

* * @param classifierString class of machine learning classifier as a string * @param options the array of string containing the options * @return a string describing the results * @throws Exception if model could not be evaluated successfully */ public static String evaluateModel(String classifierString, String[] options) throws Exception { Classifier classifier; // Create classifier try { classifier = AbstractClassifier.forName(classifierString, null); } catch (Exception e) { throw new Exception("Can't find class with name " + classifierString + '.'); } return evaluateModel(classifier, options); } /** * A test method for this class. Just extracts the first command line argument * as a classifier class name and calls evaluateModel. * * @param args an array of command line arguments, the first of which must be * the class name of a classifier. */ public static void main(String[] args) { try { if (args.length == 0) { throw new Exception("The first argument must be the class name of a classifier"); } String classifier = args[0]; args[0] = ""; System.out.println(evaluateModel(classifier, args)); } catch (Exception ex) { ex.printStackTrace(); System.err.println(ex.getMessage()); } } /** * Tries to get the classifier from the provided model file * * @param modelFileName the name of the model file * @param template the template header to compare the saved data header to * @return the classifier */ protected static Classifier getModelFromFile(String modelFileName, Instances template) throws Exception { Classifier classifier = null; // Try to load (gzipped) serialized Java objects or KOML InputStream is = new FileInputStream(modelFileName); if (modelFileName.endsWith(".gz")) { is = new GZIPInputStream(is); } ObjectInputStream objectInputStream = new ObjectInputStream(is); classifier = (Classifier) objectInputStream.readObject(); // try and read a header (if present) Instances savedStructure = null; try { savedStructure = (Instances) objectInputStream.readObject(); } catch 
(Exception ex) { // don't make a fuss } if (savedStructure != null) { // test for compatibility with template if (!(classifier instanceof InputMappedClassifier) && !template.equalHeaders(savedStructure)) { throw new Exception("training and test set are not compatible\n" + template.equalHeadersMsg(savedStructure)); } } objectInputStream.close(); if (classifier == null) { throw new IllegalArgumentException("Failed to classifier from model file " + modelFileName); } return classifier; } /** * Saves the given classifier, along with the template Instances object (if appropriate) to the given file. * * @param classifier the classifier * @param template the template * @param objectOutputFileName the file name */ protected static void saveClassifier(Classifier classifier, Instances template, String objectOutputFileName) throws Exception { File f = new File(objectOutputFileName).getAbsoluteFile(); OutputStream os = new FileOutputStream(f); if (objectOutputFileName.endsWith(".gz")) { os = new GZIPOutputStream(os); } ObjectOutputStream objectOutputStream = new ObjectOutputStream(os); objectOutputStream.writeObject(classifier); if (template != null) { objectOutputStream.writeObject(template); } objectOutputStream.flush(); objectOutputStream.close(); } /** * Evaluates a classifier with the options given in an array of strings. *

* * Valid options are: *

* * -t name of training file
* Name of the file with the training data. (required) *

* * -T name of test file
* Name of the file with the test data. If missing a cross-validation is * performed. *

* * -c class index
* Index of the class attribute (1, 2, ...; default: last). *

* * -x number of folds
* The number of folds for the cross-validation (default: 10). *

* * -no-cv
* No cross validation. If no test file is provided, no evaluation is done. *

* * -split-percentage percentage
* Sets the percentage for the train/test set split, e.g., 66. *

* * -preserve-order
* Preserves the order in the percentage split instead of randomizing the data * first with the seed value ('-s'). *

* * -s seed
* Random number seed for the cross-validation and percentage split (default: * 1). *

* * -m file with cost matrix
* The name of a file containing a cost matrix. *

* * -l filename
* Loads classifier from the given file. In case the filename ends with * ".xml",a PMML file is loaded or, if that fails, options are loaded from * XML. *

* * -d filename
* Saves classifier built from the training data into the given file. In case * the filename ends with ".xml" the options are saved XML, not the model. *

* * -v
* Outputs no statistics for the training data. *

* * -o
* Outputs statistics only, not the classifier. *

* * -output-models-for-training-splits
* Output models for training splits if cross-validation or percentage-split evaluation is used. *

* * -do-not-output-per-class-statistics
* Do not output statistics per class. *

* * -k
* Outputs information-theoretic statistics. *

* * -classifications * "weka.classifiers.evaluation.output.prediction.AbstractOutput + options"
* Uses the specified class for generating the classification output. E.g.: * weka.classifiers.evaluation.output.prediction.PlainText or * weka.classifiers.evaluation.output.prediction.CSV *

* * -p range
* Outputs predictions for test instances (or the train instances if no test * instances provided and -no-cv is used), along with the attributes in the * specified range (and nothing else). Use '-p 0' if no attributes are * desired. *

* Deprecated: use "-classifications ..." instead. *

* * -distribution
* Outputs the distribution instead of only the prediction in conjunction with * the '-p' option (only nominal classes). *

* Deprecated: use "-classifications ..." instead. *

* * -no-predictions
* Turns off the collection of predictions in order to conserve memory. *

* * -r
* Outputs cumulative margin distribution (and nothing else). *

* * -g
* Only for classifiers that implement "Graphable." Outputs the graph * representation of the classifier (and nothing else). *

* * -xml filename | xml-string
* Retrieves the options from the XML-data instead of the command line. *

* * @param classifier machine learning classifier * @param options the array of string containing the options * @throws Exception if model could not be evaluated successfully * @return a string describing the results */ public static String evaluateModel(Classifier classifier, String[] options) throws Exception { StringBuffer schemeOptionsText = null; long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0; // help requested? if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) { // global info requested as well? boolean globalInfo = Utils.getFlag("synopsis", options) || Utils.getFlag("info", options); throw new Exception("\nHelp requested." + makeOptionString(classifier, globalInfo)); } // Store settings for (almost all) general options boolean noCrossValidation = Utils.getFlag("no-cv", options); String classIndexString = Utils.getOption('c', options); String trainFileName = Utils.getOption('t', options); String objectInputFileName = Utils.getOption('l', options); String objectOutputFileName = Utils.getOption('d', options); String testFileName = Utils.getOption('T', options); String foldsString = Utils.getOption('x', options); String seedString = Utils.getOption('s', options); boolean outputModelsForTrainingSplits = Utils.getFlag("output-models-for-training-splits", options); boolean classStatistics = !Utils.getFlag("do-not-output-per-class-statistics", options); boolean noOutput = Utils.getFlag('o', options); boolean trainStatistics = !Utils.getFlag('v', options); boolean printComplexityStatistics = Utils.getFlag('k', options); boolean printMargins = Utils.getFlag('r', options); boolean printGraph = Utils.getFlag('g', options); String sourceClass = Utils.getOption('z', options); boolean printSource = (sourceClass.length() != 0); String thresholdFile = Utils.getOption("threshold-file", options); String thresholdLabel = Utils.getOption("threshold-label", options); boolean forceBatchTraining = 
Utils.getFlag("force-batch-training", options); String classifications = Utils.getOption("classifications", options); String classificationsOld = Utils.getOption("p", options); String splitPercentageString = Utils.getOption("split-percentage", options); boolean preserveOrder = Utils.getFlag("preserve-order", options); boolean discardPredictions = Utils.getFlag("no-predictions", options); String metricsToToggle = Utils.getOption("toggle", options); boolean continueIteratingIterative = Utils.getFlag("continue-iterating", options); boolean cleanUpIterative = Utils.getFlag("clean-up", options); // Some other variables that we might set later. CostMatrix costMatrix = null; double splitPercentage = -1; int classIndex = -1, actualClassIndex = -1; int seed = 1, folds = 10; Instances train = null, test = null, template = null; AbstractOutput classificationOutput = null; List toggleList = new ArrayList(); int labelIndex = 0; // We need to output help if something goes wrong with the option settings try { if (metricsToToggle.length() > 0) { String[] parts = metricsToToggle.split(","); for (String p : parts) { toggleList.add(p.trim().toLowerCase()); } } // Basic checking for global parameter settings if (trainFileName.length() == 0) { if (objectInputFileName.length() == 0) { throw new IllegalArgumentException("No training file and no object input file given."); } if (testFileName.length() == 0) { throw new IllegalArgumentException("No training file and no test file given."); } } else if ((objectInputFileName.length() != 0) && ((!(classifier instanceof UpdateableClassifier) && !(classifier instanceof IterativeClassifier) || forceBatchTraining) || (testFileName.length() == 0))) { throw new IllegalArgumentException("Classifier not incremental or batch training forced, or no test file provided: can't use model file."); } if ((objectInputFileName.length() != 0) && ((splitPercentageString.length() != 0) || (foldsString.length() != 0))) { throw new IllegalArgumentException("Cannot 
perform percentage split or cross-validation when model provided."); } if (splitPercentageString.length() != 0) { if (foldsString.length() != 0) { throw new IllegalArgumentException("Percentage split cannot be used in conjunction with cross-validation ('-x')."); } splitPercentage = Double.parseDouble(splitPercentageString); if ((splitPercentage <= 0) || (splitPercentage >= 100)) { throw new IllegalArgumentException("Split percentage needs to be >0 and <100."); } } if ((preserveOrder) && (splitPercentage == -1)) { throw new IllegalArgumentException("Split percentage is missing."); } if (discardPredictions && (classifications.length() > 0 || classificationsOld.length() > 0)) { throw new IllegalArgumentException("Cannot both discard and output predictions!"); } if (thresholdFile.length() > 0 && (classifications.length() > 0 || classificationsOld.length() > 0)) { throw new IllegalArgumentException("Cannot output predictions and also write threshold file!"); } if (thresholdFile.length() > 0 && (!trainStatistics && noCrossValidation && splitPercentageString.length() <= 0 && testFileName.length() <= 0)) { throw new IllegalArgumentException("Can only write a threshold file when performance statistics are computed!"); } if (printMargins && (!trainStatistics && noCrossValidation && splitPercentageString.length() <= 0 && testFileName.length() <= 0)) { throw new IllegalArgumentException("Can only print margins when performance statistics are computed!"); } if ((trainFileName.length() == 0) && (printComplexityStatistics)) { // if no training file given, we don't have any priors throw new IllegalArgumentException("Cannot print complexity statistics without training file!"); } if (printGraph && !(classifier instanceof Drawable)) { throw new IllegalArgumentException("Can only print graph if classifier implements Drawable interface!"); } if (printSource && !(classifier instanceof Sourcable)) { throw new IllegalArgumentException("Can only print source if classifier implements 
Sourcable interface!"); } if (printGraph && !(trainFileName.length() > 0) && !(objectInputFileName.length() > 0)) { throw new IllegalArgumentException("Can only print graph if training file or model file is provided!"); } if (printSource && !(trainFileName.length() > 0) && !(objectInputFileName.length() > 0)) { throw new IllegalArgumentException("Can only print source if training file or model file is provided!"); } if (objectInputFileName.length() > 0 && (trainFileName.length() > 0) && (!(classifier instanceof UpdateableClassifier) && !(classifier instanceof IterativeClassifier) || forceBatchTraining)) { throw new IllegalArgumentException("Can't use batch training when updating/continue iterating with an existing classifier!"); } if (noCrossValidation && testFileName.length() != 0) { throw new IllegalArgumentException("Attempt to turn off cross-validation when explicit test file is provided!"); } if (splitPercentageString.length() > 0 && testFileName.length() != 0) { throw new IllegalArgumentException("Cannot perform percentage split when explicit test file is provided!"); } if ((thresholdFile.length() != 0) && discardPredictions) { throw new IllegalArgumentException("Can only output to threshold file when predictions are not discarded!"); } if (outputModelsForTrainingSplits && (testFileName.length() > 0 || ((splitPercentageString.length() == 0) && noCrossValidation))) { throw new IllegalArgumentException("Can only output models for training splits if cross-validation or " + "percentage split evaluation is performed!"); } // Set seed, number of folds, and class index if required if (seedString.length() != 0) { seed = Integer.parseInt(seedString); } if (foldsString.length() != 0) { folds = Integer.parseInt(foldsString); } if (classIndexString.length() != 0) { if (classIndexString.equals("first")) { classIndex = 1; } else if (classIndexString.equals("last")) { classIndex = -1; } else { classIndex = Integer.parseInt(classIndexString); } } // Try to open training 
and/or test file if (testFileName.length() != 0) { try { template = test = new DataSource(testFileName).getStructure(); if (classIndex != -1) { test.setClassIndex(classIndex - 1); } else { if ((test.classIndex() == -1) || (classIndexString.length() != 0)) { test.setClassIndex(test.numAttributes() - 1); } } actualClassIndex = test.classIndex(); } catch(Exception e){ throw new Exception("Can't open file " + testFileName + '.'); } } if (trainFileName.length() != 0) { try { template = train = new DataSource(trainFileName).getStructure(); if (classIndex != -1) { train.setClassIndex(classIndex - 1); } else { if ((train.classIndex() == -1) || (classIndexString.length() != 0)) { train.setClassIndex(train.numAttributes() - 1); } } actualClassIndex = train.classIndex(); } catch (Exception e) { throw new Exception("Can't open file " + trainFileName + '.'); } } // Need to check whether train and test file are compatible if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) { if ((test != null) && (train != null) && !test.equalHeaders(train)) { throw new IllegalArgumentException("Train and test file not compatible!\n" + test.equalHeadersMsg(train)); } } // Need to check whether output of threshold file is possible if desired by user if ((thresholdFile.length() != 0) && !template.classAttribute().isNominal()) { throw new IllegalArgumentException("Can only output to threshold file when class attribute is nominal!"); } // Need to check whether output of margins is possible if desired by user if (printMargins && !template.classAttribute().isNominal()) { throw new IllegalArgumentException("Can only print margins when class is nominal!"); } // Read model file if appropriate (which may just hold scheme-specific options, and not a built model) if (objectInputFileName.length() != 0) { Classifier backedUpClassifier = classifier; classifier = getModelFromFile(objectInputFileName, template); if (!classifier.getClass().equals(backedUpClassifier.getClass())) { throw new 
IllegalArgumentException("Loaded classifier is " + classifier.getClass().getCanonicalName() + ", not " + backedUpClassifier.getClass().getCanonicalName() + "!"); } // set options if IterativeClassifier && continue iterating if (continueIteratingIterative && classifier instanceof IterativeClassifier && classifier instanceof OptionHandler) { if (train == null) { throw new IllegalArgumentException("IterativeClassifiers require training " + "data to continue iterating"); } ((OptionHandler) classifier).setOptions(options); ((IterativeClassifier) classifier).setResume(!cleanUpIterative); } } // Check for cost matrix costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses()); // Need to check whether use of cost matrix is possible if desired by user if ((costMatrix != null) && !template.classAttribute().isNominal()) { throw new IllegalArgumentException("Can only use cost matrix when class attribute is nominal!"); } // Determine if predictions are to be output if (classifications.length() > 0) { classificationOutput = AbstractOutput.fromCommandline(classifications); if (classificationOutput == null) { throw new IllegalArgumentException("Failed to instantiate class for classification output: " + classifications); } classificationOutput.setHeader(template); } else if (classificationsOld.length() > 0) { // backwards compatible with old "-p range" and "-distribution" options classificationOutput = new PlainText(); classificationOutput.setHeader(template); if (!classificationsOld.equals("0")) { classificationOutput.setAttributes(classificationsOld); } classificationOutput.setOutputDistribution(Utils.getFlag("distribution", options)); } else { if (Utils.getFlag("distribution", options)) { // -distribution flag needs -p option throw new Exception("Cannot print distribution without '-p' option!"); } } if (thresholdLabel.length() != 0) { labelIndex = template.classAttribute().indexOfValue(thresholdLabel); } if (labelIndex == -1) { throw new 
IllegalArgumentException("Class label '" + thresholdLabel + "' is unknown!"); } // If a model file is given, we shouldn't process scheme-specific options // (with the exception of IterativeClassifiers) if (objectInputFileName.length() == 0) { if (classifier instanceof OptionHandler) { for (String option : options) { if (option.length() != 0) { if (schemeOptionsText == null) { schemeOptionsText = new StringBuffer(); } if (option.indexOf(' ') != -1) { schemeOptionsText.append('"' + option + "\" "); } else { schemeOptionsText.append(option + " "); } } } ((OptionHandler) classifier).setOptions(options); } } Utils.checkForRemainingOptions(options); } catch (Exception ex) { throw new Exception("\nWeka exception: " + ex.getMessage() + makeOptionString(classifier, false)); } // Build classifier on full training set if necessary Classifier classifierBackup = null; if (objectInputFileName.length() == 0) { classifierBackup = AbstractClassifier.makeCopy(classifier); // Back up configured classifier } if (trainFileName.length() > 0) { if (!noOutput || trainStatistics || printGraph || printSource || objectOutputFileName.length() > 0 || (testFileName.length() > 0) || (classificationOutput != null && noCrossValidation && splitPercentage == -1)) { if ((classifier instanceof UpdateableClassifier) && !forceBatchTraining) { // Build classifier incrementally trainTimeStart = System.currentTimeMillis(); DataSource trainSource = new DataSource(trainFileName); trainSource.getStructure(); // Need to advance in the file to get to the data if (objectInputFileName.length() <= 0) { // Only need to initialize classifier if we haven't loaded one classifier.buildClassifier(new Instances(train, 0)); } while (trainSource.hasMoreElements(train)) { ((UpdateableClassifier) classifier).updateClassifier(trainSource.nextElement(train)); } if (classifier instanceof UpdateableBatchProcessor) { ((UpdateableBatchProcessor) classifier).batchFinished(); } trainTimeElapsed = System.currentTimeMillis() - 
trainTimeStart; } else if (classifier instanceof IterativeClassifier && continueIteratingIterative) { IterativeClassifier iClassifier = (IterativeClassifier)classifier; Instances tempTrain = new DataSource(trainFileName).getDataSet(actualClassIndex); iClassifier.initializeClassifier(tempTrain); while (iClassifier.next()){ } iClassifier.done(); } else { // Build classifier in one go Instances tempTrain = new DataSource(trainFileName).getDataSet(actualClassIndex); if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { Instances mappedClassifierDataset = ((weka.classifiers.misc.InputMappedClassifier) classifier) .getModelHeader(new Instances(template, 0)); if (!mappedClassifierDataset.equalHeaders(tempTrain)) { for (int zz = 0; zz < tempTrain.numInstances(); zz++) { Instance mapped = ((weka.classifiers.misc.InputMappedClassifier) classifier) .constructMappedInstance(tempTrain.instance(zz)); mappedClassifierDataset.add(mapped); } tempTrain = mappedClassifierDataset; } } trainTimeStart = System.currentTimeMillis(); classifier.buildClassifier(tempTrain); trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } } } // If classifier is drawable output string describing graph if (printGraph) { return ((Drawable) classifier).graph(); } // Output the classifier as equivalent source if (printSource) { return wekaStaticWrapper((Sourcable) classifier, sourceClass); } // Save classifier if appropriate if (objectOutputFileName.length() > 0) { saveClassifier(classifier, template, objectOutputFileName); } // Output model StringBuffer text = new StringBuffer(); if (!(noOutput || printMargins || classificationOutput != null)) { if (classifier instanceof OptionHandler) { if (schemeOptionsText != null) { text.append("\nOptions: " + schemeOptionsText + "\n"); } } text.append("\n=== Classifier model (full training set) ===\n\n" + classifier.toString() + "\n"); text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " 
seconds\n"); } // Stop here if no output of performance statistics or predictions is required and no threshold data is required if (!trainStatistics && noCrossValidation && splitPercentage != -1 && testFileName.length() <= 0 && classificationOutput == null) { if (noOutput) { return ""; } else { return text.toString(); } } if (!printMargins && (costMatrix != null)) { text.append("\n=== Evaluation Cost Matrix ===\n\n"); text.append(costMatrix.toString()); } // Do we need a mapped classifier header? Instances mappedClassifierHeader = null; if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { mappedClassifierHeader = ((weka.classifiers.misc.InputMappedClassifier) classifier) .getModelHeader(new Instances(template, 0)); } // Do we just want to output predictions? if (classificationOutput != null) { // =============================================== // Code path for when predictions are to be output // =============================================== // Set up appropriate header for input mapped classifier if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { classificationOutput.setHeader(mappedClassifierHeader); } // Set up buffer StringBuffer predsBuff = new StringBuffer(); classificationOutput.setBuffer(predsBuff); if (testFileName.length() > 0) { // CASE 1: SEPARATE TEST SET predsBuff.append("\n=== Predictions on test data ===\n\n"); classificationOutput.print(classifier, new DataSource(testFileName)); } else if (splitPercentage > 0) { // CASE 2: PERCENTAGE SPLIT Instances tmpInst = new DataSource(trainFileName).getDataSet(actualClassIndex); if (!preserveOrder) { tmpInst.randomize(new Random(seed)); } int trainSize = (int) Math.round(tmpInst.numInstances() * splitPercentage / 100); int testSize = tmpInst.numInstances() - trainSize; Instances trainInst = new Instances(tmpInst, 0, trainSize); classifier = AbstractClassifier.makeCopy(classifierBackup); classifier.buildClassifier(trainInst); trainInst = null; Instances testInst = 
new Instances(tmpInst, trainSize, testSize); predsBuff.append("\n=== Predictions on test split ===\n\n"); classificationOutput.print(classifier, testInst); } else if (!noCrossValidation) { // CASE 3: CROSS-VALIDATION Random random = new Random(seed); Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix); if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { testingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix); } testingEvaluation.toggleEvalMetrics(toggleList); classifier = AbstractClassifier.makeCopy(classifierBackup); predsBuff.append("\n=== Predictions under cross-validation ===\n\n"); testingEvaluation.crossValidateModel(classifier, new DataSource(trainFileName).getDataSet(actualClassIndex), folds, random, classificationOutput); } else { predsBuff.append("\n=== Predictions on training data ===\n\n"); classificationOutput.print(classifier, new DataSource(trainFileName)); } text.append("\n" + predsBuff); } else { // ================================================ // Code path for when performance is to be computed // ================================================ Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix); if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { testingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix); } testingEvaluation.setDiscardPredictions(discardPredictions); testingEvaluation.toggleEvalMetrics(toggleList); // CASE 1: SEPARATE TEST SET if (testFileName.length() > 0) { // Evaluation on the training data required? 
if (train != null && trainStatistics && !printMargins) { Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix); if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { trainingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix); } trainingEvaluation.setDiscardPredictions(discardPredictions); trainingEvaluation.toggleEvalMetrics(toggleList); trainingEvaluation.setPriors(train); testingEvaluation.setPriors(train); DataSource trainSource = new DataSource(trainFileName); trainSource.getStructure(); // We already know the structure but need to advance to the data section while (trainSource.hasMoreElements(train)) { Instance trainInst = trainSource.nextElement(train); trainingEvaluation.updatePriors(trainInst); testingEvaluation.updatePriors(trainInst); } if ((classifier instanceof BatchPredictor) && (((BatchPredictor) classifier).implementsMoreEfficientBatchPrediction())) { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, new DataSource(trainFileName).getDataSet(actualClassIndex)); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } else { trainSource = new DataSource(trainFileName); trainSource.getStructure(); // We already know the structure but need to advance to the data section testTimeStart = System.currentTimeMillis(); while (trainSource.hasMoreElements(train)) { trainingEvaluation.evaluateModelOnceAndRecordPrediction(classifier, trainSource.nextElement(train)); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; } text.append("\nTime taken to test model on training data: "); text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training data ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } text.append("\n\n" + 
trainingEvaluation.toMatrixString()); } } // Evaluate on test data if (train == null) { testingEvaluation.useNoPriors(); } else if (!trainStatistics || printMargins) { // We may still need to set the priors based on the training data! testingEvaluation.setPriors(train); DataSource trainSource = new DataSource(trainFileName); trainSource.getStructure(); // We already know the structure but need to advance to the data section while (trainSource.hasMoreElements(train)) { testingEvaluation.updatePriors(trainSource.nextElement(train)); } } if (classifier instanceof BatchPredictor && ((BatchPredictor) classifier).implementsMoreEfficientBatchPrediction()) { testTimeStart = System.currentTimeMillis(); testingEvaluation.evaluateModel(classifier, new DataSource(testFileName).getDataSet(test.classIndex())); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } else { DataSource testSource = new DataSource(testFileName); testSource.getStructure(); // We already know the structure but need to advance to the data section testTimeStart = System.currentTimeMillis(); while (testSource.hasMoreElements(test)) { testingEvaluation.evaluateModelOnceAndRecordPrediction(classifier, testSource.nextElement(test)); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; } if (printMargins) { return testingEvaluation.toCumulativeMarginDistributionString(); } text.append("\nTime taken to test model on test data: "); text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); text.append(testingEvaluation.toSummaryString("\n\n=== Error on test data ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } text.append("\n\n" + testingEvaluation.toMatrixString()); } } else if (splitPercentage > 0) { // CASE 2: PERCENTAGE SPLIT // Evaluation on the training data required? 
if (train != null && trainStatistics && !printMargins) { Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix); if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { trainingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix); } trainingEvaluation.setDiscardPredictions(discardPredictions); trainingEvaluation.toggleEvalMetrics(toggleList); DataSource trainSource = new DataSource(trainFileName); trainSource.getStructure(); // We already know the structure but need to advance to the data section trainingEvaluation.setPriors(train); while (trainSource.hasMoreElements(train)) { trainingEvaluation.updatePriors(trainSource.nextElement(train)); } if ((classifier instanceof BatchPredictor) && (((BatchPredictor) classifier).implementsMoreEfficientBatchPrediction())) { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, new DataSource(trainFileName).getDataSet(actualClassIndex)); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } else { trainSource = new DataSource(trainFileName); trainSource.getStructure(); // We already know the structure but need to advance to the data section testTimeStart = System.currentTimeMillis(); while (trainSource.hasMoreElements(train)) { trainingEvaluation.evaluateModelOnceAndRecordPrediction(classifier, trainSource.nextElement(train)); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; } text.append("\nTime taken to test model on training data: "); text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training data ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } text.append("\n\n" + trainingEvaluation.toMatrixString()); } } Instances tmpInst = new 
DataSource(trainFileName).getDataSet(actualClassIndex); if (!preserveOrder) { tmpInst.randomize(new Random(seed)); } int trainSize = (int) Math.round(tmpInst.numInstances() * splitPercentage / 100); int testSize = tmpInst.numInstances() - trainSize; Instances trainInst = new Instances(tmpInst, 0, trainSize); classifier = AbstractClassifier.makeCopy(classifierBackup); classifier.buildClassifier(trainInst); if (outputModelsForTrainingSplits) { text.append("\n=== Classifier model (training split) ===\n\n" + classifier.toString() + "\n"); } testingEvaluation.setPriors(trainInst); trainInst = null; Instances testInst = new Instances(tmpInst, trainSize, testSize); testTimeStart = System.currentTimeMillis(); testingEvaluation.evaluateModel(classifier, testInst); testTimeElapsed = System.currentTimeMillis() - testTimeStart; if (printMargins) { return testingEvaluation.toCumulativeMarginDistributionString(); } text.append("\nTime taken to test model on test split: "); text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); text.append(testingEvaluation.toSummaryString("\n\n=== Error on test split ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } text.append("\n\n" + testingEvaluation.toMatrixString()); } } else if (!noCrossValidation) { // CASE 3: CROSS-VALIDATION // Evaluation on the training data required? 
if (train != null && trainStatistics && !printMargins) { Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix); if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { trainingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix); } trainingEvaluation.setDiscardPredictions(discardPredictions); trainingEvaluation.toggleEvalMetrics(toggleList); DataSource trainSource = new DataSource(trainFileName); trainSource.getStructure(); // We already know the structure but need to advance to the data section trainingEvaluation.setPriors(train); while (trainSource.hasMoreElements(train)) { trainingEvaluation.updatePriors(trainSource.nextElement(train)); } if ((classifier instanceof BatchPredictor) && (((BatchPredictor) classifier).implementsMoreEfficientBatchPrediction())) { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, new DataSource(trainFileName).getDataSet(actualClassIndex)); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } else { trainSource = new DataSource(trainFileName); trainSource.getStructure(); // We already know the structure but need to advance to the data section testTimeStart = System.currentTimeMillis(); while (trainSource.hasMoreElements(train)) { trainingEvaluation.evaluateModelOnceAndRecordPrediction(classifier, trainSource.nextElement(train)); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; } text.append("\nTime taken to test model on training data: "); text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training data ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } text.append("\n\n" + trainingEvaluation.toMatrixString()); } } Random random = new Random(seed); // use untrained (!) 
classifier for cross-validation classifier = AbstractClassifier.makeCopy(classifierBackup); testTimeStart = System.currentTimeMillis(); if (!outputModelsForTrainingSplits) { testingEvaluation.crossValidateModel(classifier, new DataSource(trainFileName).getDataSet(actualClassIndex), folds, random); } else { testingEvaluation.crossValidateModel(classifier, new DataSource(trainFileName).getDataSet(actualClassIndex), folds, random, text); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; if (printMargins) { return testingEvaluation.toCumulativeMarginDistributionString(); } text.append("\nTime taken to perform cross-validation: "); text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); if (template.classAttribute().isNumeric()) { text.append("\n\n\n" + testingEvaluation.toSummaryString("=== Cross-validation ===\n", printComplexityStatistics)); } else { text.append("\n\n\n" + testingEvaluation.toSummaryString("=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } text.append("\n\n" + testingEvaluation.toMatrixString()); } } else if (trainStatistics) { // CASE 4: Only evaluate on the training set // Evaluation on the training data required? 
if (train != null) { Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix); if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { trainingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix); } trainingEvaluation.setDiscardPredictions(discardPredictions); trainingEvaluation.toggleEvalMetrics(toggleList); DataSource trainSource = new DataSource(trainFileName); trainSource.getStructure(); // We already know the structure but need to advance to the data section trainingEvaluation.setPriors(train); while (trainSource.hasMoreElements(train)) { trainingEvaluation.updatePriors(trainSource.nextElement(train)); } if ((classifier instanceof BatchPredictor) && (((BatchPredictor) classifier).implementsMoreEfficientBatchPrediction())) { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, new DataSource(trainFileName).getDataSet(actualClassIndex)); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } else { trainSource = new DataSource(trainFileName); trainSource.getStructure(); // We already know the structure but need to advance to the data section testTimeStart = System.currentTimeMillis(); while (trainSource.hasMoreElements(train)) { trainingEvaluation.evaluateModelOnceAndRecordPrediction(classifier, trainSource.nextElement(train)); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; } if (printMargins) { return trainingEvaluation.toCumulativeMarginDistributionString(); } text.append("\nTime taken to test model on training data: "); text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training data ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } text.append("\n\n" + trainingEvaluation.toMatrixString()); } testingEvaluation = 
trainingEvaluation; } } // Output threshold file if (thresholdFile.length() != 0) { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(testingEvaluation.predictions(), labelIndex); DataSink.write(thresholdFile, result); } } return text.toString(); } /** * Attempts to load a cost matrix. * * @param costFileName the filename of the cost matrix * @param numClasses the number of classes that should be in the cost matrix * (only used if the cost file is in old format). * @return a CostMatrix value, or null if costFileName is empty * @throws Exception if an error occurs. */ protected static CostMatrix handleCostOption(String costFileName, int numClasses) throws Exception { if ((costFileName != null) && (costFileName.length() != 0)) { Reader costReader = null; try { costReader = new BufferedReader(new FileReader(costFileName)); } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } try { // First try as a proper cost matrix format return new CostMatrix(costReader); } catch (Exception ex) { try { // Now try as the poxy old format :-) // System.err.println("Attempting to read old format cost file"); try { costReader.close(); // Close the old one costReader = new BufferedReader(new FileReader(costFileName)); } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } CostMatrix costMatrix = new CostMatrix(numClasses); // System.err.println("Created default cost matrix"); costMatrix.readOldFormat(costReader); return costMatrix; // System.err.println("Read old format"); } catch (Exception e2) { // re-throw the original exception // System.err.println("Re-throwing original exception"); throw ex; } } } else { return null; } } /** * Evaluates the classifier on a given set of instances. Note that the data * must have exactly the same format (e.g. order of attributes) as the data * used to train the classifier! Otherwise the results will generally be * meaningless. 
* * @param classifier machine learning classifier * @param data set of test instances for evaluation * @param forPredictionsPrinting varargs parameter that, if supplied, is * expected to hold a * weka.classifiers.evaluation.output.prediction.AbstractOutput * object * @return the predictions * @throws Exception if model could not be evaluated successfully */ public double[] evaluateModel(Classifier classifier, Instances data, Object... forPredictionsPrinting) throws Exception { // for predictions printing AbstractOutput classificationOutput = null; double predictions[] = new double[data.numInstances()]; if (forPredictionsPrinting.length > 0) { classificationOutput = (AbstractOutput) forPredictionsPrinting[0]; } if (classifier instanceof BatchPredictor && ((BatchPredictor) classifier).implementsMoreEfficientBatchPrediction()) { // make a copy and set the class to missing Instances dataPred = new Instances(data); for (int i = 0; i < data.numInstances(); i++) { dataPred.instance(i).setClassMissing(); } double[][] preds = ((BatchPredictor) classifier).distributionsForInstances(dataPred); for (int i = 0; i < data.numInstances(); i++) { double[] p = preds[i]; predictions[i] = evaluationForSingleInstance(p, data.instance(i), true); if (classificationOutput != null) { classificationOutput.printClassification(p, data.instance(i), i); } } } else { // Need to be able to collect predictions if appropriate (for AUC) for (int i = 0; i < data.numInstances(); i++) { predictions[i] = evaluateModelOnceAndRecordPrediction(classifier, data.instance(i)); if (classificationOutput != null) { classificationOutput.printClassification(classifier, data.instance(i), i); } } } return predictions; } /** * Evaluates the supplied distribution on a single instance. 
* * @param dist the supplied distribution * @param instance the test instance to be classified * @param storePredictions whether to store predictions for nominal classifier * @return the prediction * @throws Exception if model could not be evaluated successfully */ public double evaluationForSingleInstance(double[] dist, Instance instance, boolean storePredictions) throws Exception { double pred; if (m_ClassIsNominal) { pred = Utils.maxIndex(dist); if (dist[(int) pred] <= 0) { pred = Utils.missingValue(); } updateStatsForClassifier(dist, instance); if (storePredictions && !m_DiscardPredictions) { if (m_Predictions == null) { m_Predictions = new ArrayList(); } m_Predictions.add(new NominalPrediction(instance.classValue(), dist, instance.weight())); } } else { pred = dist[0]; updateStatsForPredictor(pred, instance); if (storePredictions && !m_DiscardPredictions) { if (m_Predictions == null) { m_Predictions = new ArrayList(); } m_Predictions.add(new NumericPrediction(instance.classValue(), pred, instance.weight())); } } return pred; } /** * Evaluates the classifier on a single instance and records the prediction. 
*
   * The classifier is asked for a distribution on a copy of the instance whose
   * class value has been set to missing; that distribution is then scored
   * against the instance by the distribution-based overload. For a numeric
   * class, interval and conditional density statistics are updated as well
   * when the classifier supports them.
   *
   * @param classifier machine learning classifier
   * @param instance the test instance to be classified
   * @param storePredictions whether to store the prediction (passed through to
   *          the distribution-based overload)
   * @return the prediction made by the classifier
   * @throws Exception if model could not be evaluated successfully or the data
   *           contains string attributes
   */
  protected double evaluationForSingleInstance(Classifier classifier,
    Instance instance, boolean storePredictions) throws Exception {

    // copy that is handed to the classifier; it keeps the dataset reference
    // of the original instance
    Instance classMissing = (Instance) instance.copy();
    classMissing.setDataset(instance.dataset());

    if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
      // statistics are collected on the mapped version of the instance,
      // built from a copy so the caller's instance is left alone
      instance = (Instance) instance.copy();
      instance =
        ((weka.classifiers.misc.InputMappedClassifier) classifier)
          .constructMappedInstance(instance);
      // System.out.println("Mapped instance " + instance);
      // blank out the attribute at the mapped class index in the copy that
      // goes to the classifier
      int mappedClass =
        ((weka.classifiers.misc.InputMappedClassifier) classifier)
          .getMappedClassIndex();
      classMissing.setMissing(mappedClass);
    } else {
      classMissing.setClassMissing();
    }

    // System.out.println("instance (to predict)" + classMissing);
    // predict on the class-less copy, score against the instance that still
    // carries its class value
    double pred =
      evaluationForSingleInstance(
        classifier.distributionForInstance(classMissing), instance,
        storePredictions);

    // We don't need to do the following if the class is nominal because in that
    // case entropy and coverage statistics are always computed.
    if (!m_ClassIsNominal) {
      // only instances with a known class and an actual prediction contribute
      if (!instance.classIsMissing() && !Utils.isMissingValue(pred)) {
        if (classifier instanceof IntervalEstimator) {
          updateStatsForIntervalEstimator((IntervalEstimator) classifier,
            classMissing, instance.classValue());
        } else {
          // classifier cannot produce intervals: coverage stats are disabled
          m_CoverageStatisticsAvailable = false;
        }
        if (classifier instanceof ConditionalDensityEstimator) {
          updateStatsForConditionalDensityEstimator(
            (ConditionalDensityEstimator) classifier, classMissing,
            instance.classValue());
        } else {
          // classifier cannot produce densities: complexity stats are disabled
          m_ComplexityStatisticsAvailable = false;
        }
      }
    }
    return pred;
  }

  /**
   * Evaluates the classifier on a single instance and records the prediction.
* * @param classifier machine learning classifier * @param instance the test instance to be classified * @return the prediction made by the clasifier * @throws Exception if model could not be evaluated successfully or the data * contains string attributes */ public double evaluateModelOnceAndRecordPrediction(Classifier classifier, Instance instance) throws Exception { return evaluationForSingleInstance(classifier, instance, true); } /** * Evaluates the classifier on a single instance. * * @param classifier machine learning classifier * @param instance the test instance to be classified * @return the prediction made by the clasifier * @throws Exception if model could not be evaluated successfully or the data * contains string attributes */ public double evaluateModelOnce(Classifier classifier, Instance instance) throws Exception { return evaluationForSingleInstance(classifier, instance, false); } /** * Evaluates the supplied distribution on a single instance. * * @param dist the supplied distribution * @param instance the test instance to be classified * @return the prediction * @throws Exception if model could not be evaluated successfully */ public double evaluateModelOnce(double[] dist, Instance instance) throws Exception { return evaluationForSingleInstance(dist, instance, false); } /** * Evaluates the supplied distribution on a single instance. * * @param dist the supplied distribution * @param instance the test instance to be classified * @return the prediction * @throws Exception if model could not be evaluated successfully */ public double evaluateModelOnceAndRecordPrediction(double[] dist, Instance instance) throws Exception { return evaluationForSingleInstance(dist, instance, true); } /** * Evaluates the supplied prediction on a single instance. 
*
   * The prediction is first turned into a distribution via
   * makeDistribution() and then evaluated without being stored.
   *
   * @param prediction the supplied prediction
   * @param instance the test instance to be classified
   * @throws Exception if model could not be evaluated successfully
   */
  public void evaluateModelOnce(double prediction, Instance instance)
    throws Exception {

    evaluateModelOnce(makeDistribution(prediction), instance);
  }

  /**
   * Returns the predictions that have been collected.
   *
   * @return a reference to the list containing the predictions that have
   *         been collected. This is always null when predictions are being
   *         discarded, and should be null if no predictions have been
   *         collected.
   */
  public ArrayList predictions() {
    if (m_DiscardPredictions) {
      return null;
    } else {
      return m_Predictions;
    }
  }

  /**
   * Wraps a static classifier in enough source to test using the weka class
   * libraries. The result is the source of a class named WekaWrapper (in
   * package weka.classifiers) followed by the source produced by the
   * classifier's toSource() method.
   *
   * @param classifier a Sourcable Classifier (it is also cast to Classifier
   *          in order to obtain its capabilities)
   * @param className the name to give to the source code class
   * @return the source for a static classifier that can be tested with weka
   *         libraries.
   * @throws Exception if code-generation fails
   */
  public static String wekaStaticWrapper(Sourcable classifier, String className)
    throws Exception {

    StringBuffer result = new StringBuffer();
    // generate the classifier's own source first so a failure happens before
    // any wrapper text is assembled
    String staticClassifier = classifier.toSource(className);

    // file header, package declaration and imports of the generated source
    result.append("// Generated with Weka " + Version.VERSION + "\n");
    result.append("//\n");
    result
      .append("// This code is public domain and comes with no warranty.\n");
    result.append("//\n");
    result.append("// Timestamp: " + new Date() + "\n");
    result.append("\n");
    result.append("package weka.classifiers;\n");
    result.append("\n");
    result.append("import weka.core.Attribute;\n");
    result.append("import weka.core.Capabilities;\n");
    result.append("import weka.core.Capabilities.Capability;\n");
    result.append("import weka.core.Instance;\n");
    result.append("import weka.core.Instances;\n");
    result.append("import weka.core.RevisionUtils;\n");
    result
      .append("import weka.classifiers.Classifier;\nimport weka.classifiers.AbstractClassifier;\n");
    result.append("\n");
    result.append("public class WekaWrapper\n");
    result.append(" extends AbstractClassifier {\n");

    // globalInfo
    result.append("\n");
    result.append(" /**\n");
    result.append(" * Returns only the toString() method.\n");
    result.append(" *\n");
    result.append(" * @return a string describing the classifier\n");
    result.append(" */\n");
    result.append(" public String globalInfo() {\n");
    result.append(" return toString();\n");
    result.append(" }\n");

    // getCapabilities: the body is generated from the capabilities of the
    // wrapped classifier
    result.append("\n");
    result.append(" /**\n");
    result.append(" * Returns the capabilities of this classifier.\n");
    result.append(" *\n");
    result.append(" * @return the capabilities\n");
    result.append(" */\n");
    result.append(" public Capabilities getCapabilities() {\n");
    result.append(((Classifier) classifier).getCapabilities().toSource(
      "result", 4));
    result.append(" return result;\n");
    result.append(" }\n");

    // buildClassifier
    result.append("\n");
    result.append(" /**\n");
    result.append(" * only checks the data against its capabilities.\n");
    result.append(" *\n");
    result.append(" * @param i the training data\n");
    result.append(" */\n");
    result
      .append(" public void buildClassifier(Instances i) throws Exception {\n");
    result.append(" // can classifier handle the data?\n");
    result.append(" getCapabilities().testWithFail(i);\n");
    result.append(" }\n");

    // classifyInstance: the generated method converts the instance into an
    // Object[] and hands it to the static classify() of the generated class
    result.append("\n");
    result.append(" /**\n");
    result.append(" * Classifies the given instance.\n");
    result.append(" *\n");
    result.append(" * @param i the instance to classify\n");
    result.append(" * @return the classification result\n");
    result.append(" */\n");
    result
      .append(" public double classifyInstance(Instance i) throws Exception {\n");
    result.append(" Object[] s = new Object[i.numAttributes()];\n");
    result.append(" \n");
    result.append(" for (int j = 0; j < s.length; j++) {\n");
    result.append(" if (!i.isMissing(j)) {\n");
    result.append(" if (i.attribute(j).isNominal())\n");
    result.append(" s[j] = new String(i.stringValue(j));\n");
    result.append(" else if (i.attribute(j).isNumeric())\n");
    result.append(" s[j] = new Double(i.value(j));\n");
    result.append(" }\n");
    result.append(" }\n");
    result.append(" \n");
    result.append(" // set class value to missing\n");
    result.append(" s[i.classIndex()] = null;\n");
    result.append(" \n");
    result.append(" return " + className + ".classify(s);\n");
    result.append(" }\n");

    // getRevision
    result.append("\n");
    result.append(" /**\n");
    result.append(" * Returns the revision string.\n");
    result.append(" * \n");
    result.append(" * @return the revision\n");
    result.append(" */\n");
    result.append(" public String getRevision() {\n");
    result.append(" return RevisionUtils.extract(\"1.0\");\n");
    result.append(" }\n");

    // toString: the description embeds the name of the wrapped classifier's
    // class and the Weka version at generation time
    result.append("\n");
    result.append(" /**\n");
    result
      .append(" * Returns only the classnames and what classifier it is based on.\n");
    result.append(" *\n");
    result.append(" * @return a short description\n");
    result.append(" */\n");
    result.append(" public String toString() {\n");
    result.append(" return \"Auto-generated classifier wrapper, based on "
      + classifier.getClass().getName() + " (generated with Weka "
      + Version.VERSION + ").\\n" + "\" + this.getClass().getName() + \"/"
      + className + "\";\n");
    result.append(" }\n");

    // main
    result.append("\n");
    result.append(" /**\n");
    result.append(" * Runs the classfier from commandline.\n");
    result.append(" *\n");
    result.append(" * @param args the commandline arguments\n");
    result.append(" */\n");
    result.append(" public static void main(String args[]) {\n");
    result.append(" runClassifier(new WekaWrapper(), args);\n");
    result.append(" }\n");
    result.append("}\n");

    // actual classifier code
    result.append("\n");
    result.append(staticClassifier);

    return result.toString();
  }

  /**
   * Gets the number of test instances that had a known class value (actually
   * the sum of the weights of test instances with known class value).
* * @return the number of test instances with known class */ public final double numInstances() { return m_WithClass; } /** * Gets the coverage of the test cases by the predicted regions at the * confidence level specified when evaluation was performed. * * @return the coverage of the test cases by the predicted regions */ public final double coverageOfTestCasesByPredictedRegions() { if (!m_CoverageStatisticsAvailable) { return Double.NaN; } return 100 * m_TotalCoverage / m_WithClass; } /** * Gets the average size of the predicted regions, relative to the range of * the target in the training data, at the confidence level specified when * evaluation was performed. * * @return the average size of the predicted regions */ public final double sizeOfPredictedRegions() { if (m_NoPriors || !m_CoverageStatisticsAvailable) { return Double.NaN; } return 100 * m_TotalSizeOfRegions / m_WithClass; } /** * Gets the weight of the instances that had a non-missing class value * * @return the weight of the instances that had a non-missing class value */ public final double withClass() { return m_WithClass; } /** * Gets the weight of the instances that had missing class values * * @return the weight of the instances that had missing class values */ public final double missingClass() { return m_MissingClass; } /** * Gets the number of instances incorrectly classified (that is, for which an * incorrect prediction was made). (Actually the sum of the weights of these * instances) * * @return the number of incorrectly classified instances */ public final double incorrect() { return m_Incorrect; } /** * Gets the percentage of instances incorrectly classified (that is, for which * an incorrect prediction was made). 
* * @return the percent of incorrectly classified instances (between 0 and 100) */ public final double pctIncorrect() { return 100 * m_Incorrect / m_WithClass; } /** * Gets the total cost, that is, the cost of each prediction times the weight * of the instance, summed over all instances. * * @return the total cost */ public final double totalCost() { return m_TotalCost; } /** * Gets the average cost, that is, total cost of misclassifications (incorrect * plus unclassified) over the total number of instances. * * @return the average cost. */ public final double avgCost() { return m_TotalCost / m_WithClass; } /** * Gets the number of instances correctly classified (that is, for which a * correct prediction was made). (Actually the sum of the weights of these * instances) * * @return the number of correctly classified instances */ public final double correct() { return m_Correct; } /** * Gets the percentage of instances correctly classified (that is, for which a * correct prediction was made). * * @return the percent of correctly classified instances (between 0 and 100) */ public final double pctCorrect() { return 100 * m_Correct / m_WithClass; } /** * Gets the number of instances not classified (that is, for which no * prediction was made by the classifier). (Actually the sum of the weights of * these instances) * * @return the number of unclassified instances */ public final double unclassified() { return m_Unclassified; } /** * Gets the percentage of instances not classified (that is, for which no * prediction was made by the classifier). * * @return the percent of unclassified instances (between 0 and 100) */ public final double pctUnclassified() { return 100 * m_Unclassified / m_WithClass; } /** * Returns the estimated error rate or the root mean squared error (if the * class is numeric). If a cost matrix was given this error rate gives the * average cost. 
* * @return the estimated error rate (between 0 and 1, or between 0 and maximum * cost) */ public final double errorRate() { if (!m_ClassIsNominal) { return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified)); } if (m_CostMatrix == null) { return m_Incorrect / m_WithClass; } else { return avgCost(); } } /** * Returns value of kappa statistic if class is nominal. * * @return the value of the kappa statistic */ public final double kappa() { double[] sumRows = new double[m_ConfusionMatrix.length]; double[] sumColumns = new double[m_ConfusionMatrix.length]; double sumOfWeights = 0; for (int i = 0; i < m_ConfusionMatrix.length; i++) { for (int j = 0; j < m_ConfusionMatrix.length; j++) { sumRows[i] += m_ConfusionMatrix[i][j]; sumColumns[j] += m_ConfusionMatrix[i][j]; sumOfWeights += m_ConfusionMatrix[i][j]; } } double correct = 0, chanceAgreement = 0; for (int i = 0; i < m_ConfusionMatrix.length; i++) { chanceAgreement += (sumRows[i] * sumColumns[i]); correct += m_ConfusionMatrix[i][i]; } chanceAgreement /= (sumOfWeights * sumOfWeights); correct /= sumOfWeights; if (chanceAgreement < 1) { return (correct - chanceAgreement) / (1 - chanceAgreement); } else { return 1; } } /** * Returns the correlation coefficient if the class is numeric. 
* * @return the correlation coefficient * @throws Exception if class is not numeric */ public final double correlationCoefficient() throws Exception { if (m_ClassIsNominal) { throw new Exception("Can't compute correlation coefficient: " + "class is nominal!"); } double correlation = 0; double varActual = m_SumSqrClass - m_SumClass * m_SumClass / (m_WithClass - m_Unclassified); double varPredicted = m_SumSqrPredicted - m_SumPredicted * m_SumPredicted / (m_WithClass - m_Unclassified); double varProd = m_SumClassPredicted - m_SumClass * m_SumPredicted / (m_WithClass - m_Unclassified); if (varActual * varPredicted <= 0) { correlation = 0.0; } else { correlation = varProd / Math.sqrt(varActual * varPredicted); } return correlation; } /** * Returns the mean absolute error. Refers to the error of the predicted * values for numeric classes, and the error of the predicted probability * distribution for nominal classes. * * @return the mean absolute error */ public final double meanAbsoluteError() { return m_SumAbsErr / (m_WithClass - m_Unclassified); } /** * Returns the mean absolute error of the prior. * * @return the mean absolute error */ public final double meanPriorAbsoluteError() { if (m_NoPriors) { return Double.NaN; } return m_SumPriorAbsErr / m_WithClass; } /** * Returns the relative absolute error. * * @return the relative absolute error * @throws Exception if it can't be computed */ public final double relativeAbsoluteError() throws Exception { if (m_NoPriors) { return Double.NaN; } return 100 * meanAbsoluteError() / meanPriorAbsoluteError(); } /** * Returns the root mean squared error. * * @return the root mean squared error */ public final double rootMeanSquaredError() { return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified)); } /** * Returns the root mean prior squared error. 
* * @return the root mean prior squared error */ public final double rootMeanPriorSquaredError() { if (m_NoPriors) { return Double.NaN; } return Math.sqrt(m_SumPriorSqrErr / m_WithClass); } /** * Returns the root relative squared error if the class is numeric. * * @return the root relative squared error */ public final double rootRelativeSquaredError() { if (m_NoPriors) { return Double.NaN; } return 100.0 * rootMeanSquaredError() / rootMeanPriorSquaredError(); } /** * Returns the mean base-2 log loss wrt the null model. Just calls * SFMeanPriorEntropy. * * @return the null model entropy per instance */ public final double priorEntropy() { // The previous version of this method calculated // mean base-2 log loss for the null model wrt the // instances passed into setPriors(). Now, this method will // return the loss for the null model wrt to the test // instances (whatever they are). return SFMeanPriorEntropy(); } /** * Return the total Kononenko & Bratko Information score in bits. * * @return the K&B information score * @throws Exception if the class is not nominal */ public final double KBInformation() throws Exception { if (!m_ClassIsNominal) { throw new Exception("Can't compute K&B Info score: " + "class numeric!"); } if (m_NoPriors) { return Double.NaN; } return m_SumKBInfo; } /** * Return the Kononenko & Bratko Information score in bits per instance. * * @return the K&B information score * @throws Exception if the class is not nominal */ public final double KBMeanInformation() throws Exception { if (!m_ClassIsNominal) { throw new Exception("Can't compute K&B Info score: class numeric!"); } if (m_NoPriors) { return Double.NaN; } return m_SumKBInfo / (m_WithClass - m_Unclassified); } /** * Return the Kononenko & Bratko Relative Information score. * Differs slightly from the expression used in KB's paper * because it uses the mean log-loss of the TEST instances * wrt to the null model for normalization. 
* * @return the K&B relative information score * @throws Exception if the class is not nominal */ public final double KBRelativeInformation() throws Exception { if (!m_ClassIsNominal) { throw new Exception("Can't compute K&B Info score: " + "class numeric!"); } if (m_NoPriors) { return Double.NaN; } return 100.0 * KBMeanInformation() / priorEntropy(); } /** * Returns the base-2 log loss wrt the null model. * * @return the total null model entropy */ public final double SFPriorEntropy() { if (m_NoPriors || !m_ComplexityStatisticsAvailable) { return Double.NaN; } return m_SumPriorEntropy; } /** * Returns the mean base-2 log loss wrt the null model. * * @return the null model entropy per instance */ public final double SFMeanPriorEntropy() { if (m_NoPriors || !m_ComplexityStatisticsAvailable) { return Double.NaN; } return m_SumPriorEntropy / m_WithClass; } /** * Returns the base-2 log loss wrt the scheme. * * @return the total scheme entropy */ public final double SFSchemeEntropy() { if (!m_ComplexityStatisticsAvailable) { return Double.NaN; } return m_SumSchemeEntropy; } /** * Returns the mean base-2 log loss wrt the scheme. * * @return the scheme entropy per instance */ public final double SFMeanSchemeEntropy() { if (!m_ComplexityStatisticsAvailable) { return Double.NaN; } return m_SumSchemeEntropy / (m_WithClass - m_Unclassified); } /** * Returns the difference in base-2 log loss between null model and scheme. * * @return the total "SF score" */ public final double SFEntropyGain() { if (m_NoPriors || !m_ComplexityStatisticsAvailable) { return Double.NaN; } return m_SumPriorEntropy - m_SumSchemeEntropy; } /** * Returns the mean difference in base-2 log loss between null model and scheme. 
* * @return the "SF score" per instance */ public final double SFMeanEntropyGain() { if (m_NoPriors || !m_ComplexityStatisticsAvailable) { return Double.NaN; } return (m_SumPriorEntropy - m_SumSchemeEntropy) / (m_WithClass - m_Unclassified); } /** * Output the cumulative margin distribution as a string suitable for input * for gnuplot or similar package. * * @return the cumulative margin distribution * @throws Exception if the class attribute is nominal */ public String toCumulativeMarginDistributionString() throws Exception { if (!m_ClassIsNominal) { throw new Exception("Class must be nominal for margin distributions"); } String result = ""; double cumulativeCount = 0; double margin; for (int i = 0; i <= k_MarginResolution; i++) { if (m_MarginCounts[i] != 0) { cumulativeCount += m_MarginCounts[i]; margin = i * 2.0 / k_MarginResolution - 1.0; result = result + Utils.doubleToString(margin, 7, 3) + ' ' + Utils.doubleToString(cumulativeCount * 100 / m_WithClass, 7, 3) + '\n'; } else if (i == 0) { result = Utils.doubleToString(-1.0, 7, 3) + ' ' + Utils.doubleToString(0, 7, 3) + '\n'; } } return result; } /** * Calls toSummaryString() with no title and no complexity stats. * * @return a summary description of the classifier evaluation */ @Override public String toSummaryString() { return toSummaryString("", false); } /** * Calls toSummaryString() with a default title. * * @param printComplexityStatistics if true, complexity statistics are * returned as well * @return the summary string */ public String toSummaryString(boolean printComplexityStatistics) { return toSummaryString("=== Summary ===\n", printComplexityStatistics); } /** * Outputs the performance statistics in summary form. Lists number (and * percentage) of instances classified correctly, incorrectly and * unclassified. Outputs the total number of instances classified, and the * number of instances (if any) that had no class value provided. 
*
   * Each statistic is only printed if its name is contained in the set of
   * metrics to display. Complexity statistics are silently switched off (with
   * a message on stderr) when priors are disabled.
   *
   * @param title the title for the statistics
   * @param printComplexityStatistics if true, complexity statistics are
   *          returned as well
   * @return the summary as a String
   */
  public String toSummaryString(String title,
    boolean printComplexityStatistics) {

    StringBuffer text = new StringBuffer();

    // complexity statistics need priors; without them the request is dropped
    if (printComplexityStatistics && m_NoPriors) {
      printComplexityStatistics = false;
      System.err
        .println("Priors disabled, cannot print complexity statistics!");
    }

    text.append(title + "\n");
    try {
      // per-metric output is only produced if something with a known class
      // value has been evaluated
      if (m_WithClass > 0) {
        if (m_ClassIsNominal) {
          // nominal class: counts, kappa, cost and K&B scores
          boolean displayCorrect = m_metricsToDisplay.contains("correct");
          boolean displayIncorrect = m_metricsToDisplay.contains("incorrect");
          boolean displayKappa = m_metricsToDisplay.contains("kappa");
          boolean displayTotalCost = m_metricsToDisplay.contains("total cost");
          boolean displayAverageCost =
            m_metricsToDisplay.contains("average cost");

          if (displayCorrect) {
            text.append("Correctly Classified Instances ");
            text.append(Utils.doubleToString(correct(), 12, 4) + " "
              + Utils.doubleToString(pctCorrect(), 12, 4) + " %\n");
          }
          if (displayIncorrect) {
            text.append("Incorrectly Classified Instances ");
            text.append(Utils.doubleToString(incorrect(), 12, 4) + " "
              + Utils.doubleToString(pctIncorrect(), 12, 4) + " %\n");
          }
          if (displayKappa) {
            text.append("Kappa statistic ");
            text.append(Utils.doubleToString(kappa(), 12, 4) + "\n");
          }

          // cost figures only make sense when a cost matrix was supplied
          if (m_CostMatrix != null) {
            if (displayTotalCost) {
              text.append("Total Cost ");
              text.append(Utils.doubleToString(totalCost(), 12, 4) + "\n");
            }
            if (displayAverageCost) {
              text.append("Average Cost ");
              text.append(Utils.doubleToString(avgCost(), 12, 4) + "\n");
            }
          }
          if (printComplexityStatistics) {
            boolean displayKBRelative =
              m_metricsToDisplay.contains("kb relative");
            boolean displayKBInfo =
              m_metricsToDisplay.contains("kb information");
            if (displayKBRelative) {
              text.append("K&B Relative Info Score ");
              text.append(Utils.doubleToString(KBRelativeInformation(), 12, 4)
                + " %\n");
            }
            if (displayKBInfo) {
              text.append("K&B Information Score ");
              text.append(Utils.doubleToString(KBInformation(), 12, 4)
                + " bits");
              text.append(Utils.doubleToString(KBMeanInformation(), 12, 4)
                + " bits/instance\n");
            }
          }
        } else {
          // numeric class: correlation coefficient
          boolean displayCorrelation =
            m_metricsToDisplay.contains("correlation");
          if (displayCorrelation) {
            text.append("Correlation coefficient ");
            text.append(Utils.doubleToString(correlationCoefficient(), 12, 4)
              + "\n");
          }
        }

        // entropy based statistics (total and per instance)
        if (printComplexityStatistics && m_ComplexityStatisticsAvailable) {
          boolean displayComplexityOrder0 =
            m_metricsToDisplay.contains("complexity 0");
          boolean displayComplexityScheme =
            m_metricsToDisplay.contains("complexity scheme");
          boolean displayComplexityImprovement =
            m_metricsToDisplay.contains("complexity improvement");
          if (displayComplexityOrder0) {
            text.append("Class complexity | order 0 ");
            text
              .append(Utils.doubleToString(SFPriorEntropy(), 12, 4) + " bits");
            text.append(Utils.doubleToString(SFMeanPriorEntropy(), 12, 4)
              + " bits/instance\n");
          }
          if (displayComplexityScheme) {
            text.append("Class complexity | scheme ");
            text.append(Utils.doubleToString(SFSchemeEntropy(), 12, 4)
              + " bits");
            text.append(Utils.doubleToString(SFMeanSchemeEntropy(), 12, 4)
              + " bits/instance\n");
          }
          if (displayComplexityImprovement) {
            text.append("Complexity improvement (Sf) ");
            text.append(Utils.doubleToString(SFEntropyGain(), 12, 4) + " bits");
            text.append(Utils.doubleToString(SFMeanEntropyGain(), 12, 4)
              + " bits/instance\n");
          }
        }

        // error measures; these are printed for both class types
        boolean displayMAE = m_metricsToDisplay.contains("mae");
        boolean displayRMSE = m_metricsToDisplay.contains("rmse");
        boolean displayRAE = m_metricsToDisplay.contains("rae");
        boolean displayRRSE = m_metricsToDisplay.contains("rrse");

        if (displayMAE) {
          text.append("Mean absolute error ");
          text.append(Utils.doubleToString(meanAbsoluteError(), 12, 4) + "\n");
        }
        if (displayRMSE) {
          text.append("Root mean squared error ");
          text.append(Utils.doubleToString(rootMeanSquaredError(), 12, 4)
            + "\n");
        }
        // the relative measures are normalized by the priors
        if (!m_NoPriors) {
          if (displayRAE) {
            text.append("Relative absolute error ");
            text.append(Utils.doubleToString(relativeAbsoluteError(), 12, 4)
              + " %\n");
          }
          if (displayRRSE) {
            text.append("Root relative squared error ");
            text.append(Utils.doubleToString(rootRelativeSquaredError(), 12, 4)
              + " %\n");
          }
        }

        // coverage / region size at the configured confidence level
        if (m_CoverageStatisticsAvailable) {
          boolean displayCoverage = m_metricsToDisplay.contains("coverage");
          boolean displayRegionSize =
            m_metricsToDisplay.contains("region size");
          if (displayCoverage) {
            text.append("Coverage of cases ("
              + Utils.doubleToString(m_ConfLevel, 4, 2) + " level) ");
            text.append(Utils.doubleToString(
              coverageOfTestCasesByPredictedRegions(), 12, 4) + " %\n");
          }
          if (!m_NoPriors) {
            if (displayRegionSize) {
              text.append("Mean rel. region size ("
                + Utils.doubleToString(m_ConfLevel, 4, 2) + " level) ");
              text.append(Utils.doubleToString(sizeOfPredictedRegions(), 12, 4)
                + " %\n");
            }
          }
        }
      }

      // instance totals are always reported, regardless of the metric set
      if (Utils.gr(unclassified(), 0)) {
        text.append("UnClassified Instances ");
        text.append(Utils.doubleToString(unclassified(), 12, 4) + " "
          + Utils.doubleToString(pctUnclassified(), 12, 4) + " %\n");
      }
      text.append("Total Number of Instances ");
      text.append(Utils.doubleToString(m_WithClass, 12, 4) + "\n");
      if (m_MissingClass > 0) {
        text.append("Ignored Class Unknown Instances ");
        text.append(Utils.doubleToString(m_MissingClass, 12, 4) + "\n");
      }
    } catch (Exception ex) {
      // Should never occur: the methods above that declare exceptions are
      // only called in the branch matching the class type they require
      System.err.println("Arggh - Must be a bug in Evaluation class");
      ex.printStackTrace();
    }

    return text.toString();
  }

  /**
   * Calls toMatrixString() with a default title.
   *
   * @return the confusion matrix as a string
   * @throws Exception if the class is numeric
   */
  public String toMatrixString() throws Exception {
    return toMatrixString("=== Confusion Matrix ===\n");
  }

  /**
   * Outputs the performance statistics as a classification confusion matrix.
   * For each class value, shows the distribution of predicted class values.
* * @param title the title for the confusion matrix * @return the confusion matrix as a String * @throws Exception if the class is numeric */ public String toMatrixString(String title) throws Exception { StringBuffer text = new StringBuffer(); char[] IDChars = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z' }; int IDWidth; boolean fractional = false; if (!m_ClassIsNominal) { throw new Exception("Evaluation: No confusion matrix possible!"); } // Find the maximum value in the matrix // and check for fractional display requirement double maxval = 0; for (int i = 0; i < m_NumClasses; i++) { for (int j = 0; j < m_NumClasses; j++) { double current = m_ConfusionMatrix[i][j]; if (current < 0) { current *= -10; } if (current > maxval) { maxval = current; } double fract = current - Math.rint(current); if (!fractional && ((Math.log(fract) / Math.log(10)) >= -2)) { fractional = true; } } } IDWidth = 1 + Math.max( (int) (Math.log(maxval) / Math.log(10) + (fractional ? 3 : 0)), (int) (Math.log(m_NumClasses) / Math.log(IDChars.length))); text.append(title).append("\n"); for (int i = 0; i < m_NumClasses; i++) { if (fractional) { text.append(" ").append(num2ShortID(i, IDChars, IDWidth - 3)) .append(" "); } else { text.append(" ").append(num2ShortID(i, IDChars, IDWidth)); } } text.append(" <-- classified as\n"); for (int i = 0; i < m_NumClasses; i++) { for (int j = 0; j < m_NumClasses; j++) { text.append(" ").append( Utils.doubleToString(m_ConfusionMatrix[i][j], IDWidth, (fractional ? 2 : 0))); } text.append(" | ").append(num2ShortID(i, IDChars, IDWidth)).append(" = ") .append(m_ClassNames[i]).append("\n"); } return text.toString(); } /** * Generates a breakdown of the accuracy for each class (with default title), * incorporating various information-retrieval statistics, such as true/false * positive rate, precision/recall/F-Measure. Should be useful for ROC curves, * recall/precision curves. 
* * @return the statistics presented as a string * @throws Exception if class is not nominal */ public String toClassDetailsString() throws Exception { return toClassDetailsString("=== Detailed Accuracy By Class ===\n"); } /** * Generates a breakdown of the accuracy for each class, incorporating various * information-retrieval statistics, such as true/false positive rate, * precision/recall/F-Measure. Should be useful for ROC curves, * recall/precision curves. * * @param title the title to prepend the stats string with * @return the statistics presented as a string * @throws Exception if class is not nominal */ public String toClassDetailsString(String title) throws Exception { if (!m_ClassIsNominal) { throw new Exception("Evaluation: No per class statistics possible!"); } boolean displayTP = m_metricsToDisplay.contains("tp rate"); boolean displayFP = m_metricsToDisplay.contains("fp rate"); boolean displayP = m_metricsToDisplay.contains("precision"); boolean displayR = m_metricsToDisplay.contains("recall"); boolean displayFM = m_metricsToDisplay.contains("f-measure"); boolean displayMCC = m_metricsToDisplay.contains("mcc"); boolean displayROC = m_metricsToDisplay.contains("roc area"); boolean displayPRC = m_metricsToDisplay.contains("prc area"); StringBuffer text = new StringBuffer(title + "\n " + (displayTP ? "TP Rate " : "") + (displayFP ? "FP Rate " : "") + (displayP ? "Precision " : "") + (displayR ? "Recall " : "") + (displayFM ? "F-Measure " : "") + (displayMCC ? "MCC " : "") + (displayROC ? "ROC Area " : "") + (displayPRC ? "PRC Area " : "")); text.append("Class\n"); for (int i = 0; i < m_NumClasses; i++) { text.append(" "); if (displayTP) { double tpr = truePositiveRate(i); if (Utils.isMissingValue(tpr)) { text.append("? "); } else { text.append(String.format("%-9.3f", tpr)); } } if (displayFP) { double fpr = falsePositiveRate(i); if (Utils.isMissingValue(fpr)) { text.append("? 
"); } else { text.append(String.format("%-9.3f", fpr)); } } if (displayP) { double p = precision(i); if (Utils.isMissingValue(p)) { text.append("? "); } else { text.append(String.format("%-11.3f", precision(i))); } } if (displayR) { double r = recall(i); if (Utils.isMissingValue(r)) { text.append("? "); } else { text.append(String.format("%-9.3f", recall(i))); } } if (displayFM) { double fm = fMeasure(i); if (Utils.isMissingValue(fm)) { text.append("? "); } else { text.append(String.format("%-11.3f", fMeasure(i))); } } if (displayMCC) { double mat = matthewsCorrelationCoefficient(i); if (Utils.isMissingValue(mat)) { text.append("? "); } else { text.append(String .format("%-9.3f", matthewsCorrelationCoefficient(i))); } } if (displayROC) { double rocVal = areaUnderROC(i); if (Utils.isMissingValue(rocVal)) { text.append("? "); } else { text.append(String.format("%-10.3f", rocVal)); } } if (displayPRC) { double prcVal = areaUnderPRC(i); if (Utils.isMissingValue(prcVal)) { text.append("? "); } else { text.append(String.format("%-10.3f", prcVal)); } } text.append(m_ClassNames[i]).append('\n'); } text.append("Weighted Avg. "); if (displayTP) { double wtpr = weightedTruePositiveRate(); if (Utils.isMissingValue(wtpr)) { text.append("? "); } else { text.append(String.format("%-9.3f", wtpr)); } } if (displayFP) { double wfpr = weightedFalsePositiveRate(); if (Utils.isMissingValue(wfpr)) { text.append("? "); } else { text.append(String.format("%-9.3f", wfpr)); } } if (displayP) { double wp = weightedPrecision(); if (Utils.isMissingValue(wp)) { text.append("? "); } else { text.append(String.format("%-11.3f", wp)); } } if (displayR) { double wr = weightedRecall(); if (Utils.isMissingValue(wr)) { text.append("? "); } else { text.append(String.format("%-9.3f", wr)); } } if (displayFM) { double wf = weightedFMeasure(); if (Utils.isMissingValue(wf)) { text.append("? 
"); } else { text.append(String.format("%-11.3f", wf)); } } if (displayMCC) { double wmc = weightedMatthewsCorrelation(); if (Utils.isMissingValue(wmc)) { text.append("? "); } else { text.append(String.format("%-9.3f", wmc)); } } if (displayROC) { double wroc = weightedAreaUnderROC(); if (Utils.isMissingValue(wroc)) { text.append("? "); } else { text.append(String.format("%-10.3f", wroc)); } } if (displayPRC) { double wprc = weightedAreaUnderPRC(); if (Utils.isMissingValue(wprc)) { text.append("? "); } else { text.append(String.format("%-10.3f", wprc)); } } text.append("\n"); return text.toString(); } /** * Returns the number (really, weight) of instances that have been classified. This is equal to the sum * of all entries in the confusion matrix. * * @return the number (really, weight) of instances that have been classified */ public double numClassified() { return m_WithClass - m_Unclassified; } /** * Calculate the number (really, weight) of instances of the given class. This * only includes instances for which a prediction is made. Unclassified instances * are not included even if they have a class value. * * @param classIndex the index of the class to consider as "positive" * @return the number of positive instances */ public double numPositives(int classIndex) { double total = 0; for (int j = 0; j < m_NumClasses; j++) { total += m_ConfusionMatrix[classIndex][j]; } return total; } /** * Calculate the number (really, weight) of instances predicted to be of the given class. * * Unclassified instances are not included in the calculation. * * @param classIndex the index of the class to consider as "positive" * @return the number of instances that are predicted to be positive */ public double numPredictedPositives(int classIndex) { double total = 0; for (int j = 0; j < m_NumClasses; j++) { total += m_ConfusionMatrix[j][classIndex]; } return total; } /** * Calculate the number (really, weight) of instances not in the given class. 
This * only includes instances for which a prediction is made. Unclassified instances * are not included even if they have a class value. * * @param classIndex the index of the class to consider as "positive" * @return the number of negative instances */ public double numNegatives(int classIndex) { return numClassified() - numPositives(classIndex); } /** * Calculate the number (really, weight) of instances predicted not to be of the given class. * * Unclassified instances are not included in the calculation. * * @param classIndex the index of the class to consider as "positive" * @return the number of instances that are predicted to be negative */ public double numPredictedNegatives(int classIndex) { return numClassified() - numPredictedPositives(classIndex); } /** * Returns the number (really, weight) of instances in each class. Unclassified instances are not included. * * @return an array with the number (really, weight) of instances in each class */ public double[] trueClassCounts() { double[] classCounts = new double[m_NumClasses]; for (int i = 0; i < m_NumClasses; i++) { classCounts[i] += numPositives(i); } return classCounts; } /** * Returns the predicted number (really, weight) of instances in each class. * * Unclassified instances are not included in the calculation. * * @return an array with the predicted number (really, weight) of instances in each class */ public double[] predictedClassCounts() { double[] classCounts = new double[m_NumClasses]; for (int i = 0; i < m_NumClasses; i++) { classCounts[i] += numPredictedPositives(i); } return classCounts; } /** * Calculate the number (really, weight) of true positives with respect to a particular class. * This is defined as *

* *

   * correctly classified positives
   * 
*
   * Unclassified instances are not included in the calculation.
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the number (really, weight) of true positives, i.e. the diagonal
   *         entry of the confusion matrix for this class
   */
  public double numTruePositives(int classIndex) {
    return m_ConfusionMatrix[classIndex][classIndex];
  }

  /**
   * Calculate the number (really, weight) of false negatives with respect to a particular class.
   * This is defined as
   *

* *

   * incorrectly classified positives
   * 
*
   * Unclassified instances are not included in the calculation.
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the number (really, weight) of false negatives: the positives of
   *         this class minus its true positives
   */
  public double numFalseNegatives(int classIndex) {
    return numPositives(classIndex) - numTruePositives(classIndex);
  }

  /**
   * Calculate the number (really, weight) of true negatives with respect to a particular class.
   * This is defined as
   *

* *

   * correctly classified negatives
   * 
*
   * Unclassified instances are not included in the calculation
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the number (really, weight) of true negatives: the negatives of
   *         this class minus its false positives
   */
  public double numTrueNegatives(int classIndex) {
    return numNegatives(classIndex) - numFalsePositives(classIndex);
  }

  /**
   * Calculate the number (really, weight) of false positives with respect to a particular class.
   * This is defined as
   *

* *

   * incorrectly classified negatives
   * 
*
   * Unclassified instances are not included in the calculation
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the number (really, weight) of false positives: the instances
   *         predicted as this class minus its true positives
   */
  public double numFalsePositives(int classIndex) {
    return numPredictedPositives(classIndex) - numTruePositives(classIndex);
  }

  /**
   * Calculate the true positive rate with respect to a particular class. This
   * is defined as
   *

* *

   * correctly classified positives
   * ------------------------------
   *       total positives
   * 
*
   * Instances that the classifier left unclassified are not part of the
   * calculation.
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the true positive rate
   */
  public double truePositiveRate(int classIndex) {
    final double positives = numPositives(classIndex);
    return truePositiveRate(classIndex, positives);
  }

  /**
   * Calculate the true positive rate with respect to a particular class. This
   * is defined as
   *

* *

   * correctly classified positives
   * ------------------------------
   *       total positives
   * 
* * Unclassified instances are not included in the calculation. * * @param classIndex the index of the class to consider as "positive" * @param numPositives the total number (really, weight) of positives for that class * @return the true positive rate */ public double truePositiveRate(int classIndex, double numPositives) { return numTruePositives(classIndex) / numPositives; } /** * Calculates the weighted (by class size) true positive rate. Unclassified instances are not * included in the calculation. * * Unclassified instances are not included in the calculation. * * @return the weighted true positive rate. */ public double weightedTruePositiveRate() { double[] classCounts = trueClassCounts(); double truePosTotal = 0; for (int i = 0; i < m_NumClasses; i++) { if (classCounts[i] > 0) { // If TPR is NaN, we want the sum to also be NaN if count > 0 truePosTotal += truePositiveRate(i, classCounts[i]) * classCounts[i]; } } return truePosTotal / numClassified(); } /** * Calculate the true negative rate with respect to a particular class. This * is defined as *

* *

   * correctly classified negatives
   * ------------------------------
   *       total negatives
   * 
*
   * Unclassified instances are not included in the calculation
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the true negative rate
   */
  public double trueNegativeRate(int classIndex) {
    return trueNegativeRate(classIndex, numNegatives(classIndex));
  }

  /**
   * Calculate the true negative rate with respect to a particular class. This
   * is defined as
   *

* *

   * correctly classified negatives
   * ------------------------------
   *       total negatives
   * 
*
   * Unclassified instances are not included in the calculation
   *
   * @param classIndex the index of the class to consider as "positive"
   * @param numNegatives the number (really, weight) of the negative instances
   * @return the true negative rate
   */
  public double trueNegativeRate(int classIndex, double numNegatives) {
    return numTrueNegatives(classIndex) / numNegatives;
  }

  /**
   * Calculates the weighted (by class size) true negative rate: each class's
   * true negative rate is multiplied by the number (really, weight) of
   * instances of that class, and the sum is divided by the number of
   * classified instances.
   *
   * Unclassified instances are not included in the calculation.
   *
   * @return the weighted true negative rate.
   */
  public double weightedTrueNegativeRate() {
    double[] classCounts = trueClassCounts();
    double trueNegTotal = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      // Empty classes are skipped; for a non-empty class a NaN rate is meant
      // to propagate into the total.
      if (classCounts[i] > 0) { // If TNR is NaN, we want the sum to also be NaN if count > 0
        // The negatives of class i are all classified instances not in class i.
        trueNegTotal += trueNegativeRate(i, numClassified() - classCounts[i])
          * classCounts[i];
      }
    }
    return trueNegTotal / numClassified();
  }

  /**
   * Calculate the false positive rate with respect to a particular class. This
   * is defined as
   *

* *

   * incorrectly classified negatives
   * --------------------------------
   *        total negatives
   * 
*
   * Instances left unclassified are not part of the calculation
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the false positive rate
   */
  public double falsePositiveRate(int classIndex) {
    final double negatives = numNegatives(classIndex);
    return falsePositiveRate(classIndex, negatives);
  }

  /**
   * Calculate the false positive rate with respect to a particular class. This
   * is defined as
   *

* *

   * incorrectly classified negatives
   * --------------------------------
   *        total negatives
   * 
* * Unclassified instances are not included in the calculation * * @param classIndex the index of the class to consider as "positive" * @param numNegatives the number (really, weight) of the negative instances * @return the false positive rate */ public double falsePositiveRate(int classIndex, double numNegatives) { return numFalsePositives(classIndex) / numNegatives; } /** * Calculates the weighted (by class size) false positive rate. Unclassified instances are not * included in the calculation. * * Unclassified instances are not included in the calculation. * * @return the weighted false positive rate. */ public double weightedFalsePositiveRate() { double[] classCounts = trueClassCounts(); double falsePosTotal = 0; for (int i = 0; i < m_NumClasses; i++) { if (classCounts[i] > 0) { // If FPR is NaN, we want the sum to also be NaN if count > 0 falsePosTotal += falsePositiveRate(i, numClassified() - classCounts[i]) * classCounts[i]; } } return falsePosTotal / numClassified(); } /** * Calculate the false negative rate with respect to a particular class. This * is defined as *

* *

   * incorrectly classified positives
   * --------------------------------
   *        total positives
   * 
*
   * Unclassified instances are not included in the calculation.
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the false negative rate
   */
  public double falseNegativeRate(int classIndex) {
    double numPositives = numPositives(classIndex);
    // False negatives are the positives that were not predicted correctly.
    return (numPositives - numTruePositives(classIndex)) / numPositives;
  }

  /**
   * Calculate the false negative rate with respect to a particular class. This
   * is defined as
   *

* *

   * incorrectly classified positives
   * --------------------------------
   *        total positives
   * 
*
   * Unclassified instances are not included in the calculation.
   *
   * @param classIndex the index of the class to consider as "positive"
   * @param numPositives the number (really, weight) of the instances in the positive class
   * @return the false negative rate
   */
  public double falseNegativeRate(int classIndex, double numPositives) {
    return (numPositives - numTruePositives(classIndex)) / numPositives;
  }

  /**
   * Calculates the weighted (by class size) false negative rate.
   *
   * Unclassified instances are not included in the calculation.
   *
   * @return the weighted false negative rate.
   */
  public double weightedFalseNegativeRate() {
    double[] classCounts = trueClassCounts();
    double falseNegTotal = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      // Empty classes are skipped; for a non-empty class a NaN rate is meant
      // to propagate into the total.
      if (classCounts[i] > 0) { // If FNR is NaN, we want the sum to also be NaN if count > 0
        falseNegTotal += falseNegativeRate(i, classCounts[i]) * classCounts[i];
      }
    }
    return falseNegTotal / numClassified();
  }

  /**
   * Calculates the Matthews correlation coefficient (sometimes called phi
   * coefficient) for the supplied class, using the true/false positive and
   * negative counts of that class.
   *
   * Unclassified instances are not included in the calculation.
   *
   * @param classIndex the index of the class to compute the Matthews
   *          correlation coefficient for
   * @return the Matthews correlation coefficient
   */
  public double matthewsCorrelationCoefficient(int classIndex) {
    return matthewsCorrelationCoefficient(classIndex,
      numTruePositives(classIndex), numTrueNegatives(classIndex),
      numFalsePositives(classIndex), numFalseNegatives(classIndex));
  }

  /**
   * Calculates the matthews correlation coefficient (sometimes called phi
   * coefficient) for the supplied class. Unclassified instances are
   * not included in the calculation.
* * @param classIndex the index of the class to compute the matthews * correlation coefficient for * @param numTP the number (really, weight) of the true positive instances * @param numTN the number (really, weight) of the true negative instances * @param numFP the number (really, weight) of the false positive instances * @param numFN the number (really, weight) of the false negative instances * * @return the mathews correlation coefficient */ public double matthewsCorrelationCoefficient(int classIndex, double numTP, double numTN, double numFP, double numFN) { double n = (numTP * numTN) - (numFP * numFN); double d = (numTP + numFP) * (numTP + numFN) * (numTN + numFP) * (numTN + numFN); return n / Math.sqrt(d); } /** * Calculates the weighted (by class size) matthews correlation coefficient. * * Unclassified instances are not included in the calculation. * * @return the weighted matthews correlation coefficient. */ public double weightedMatthewsCorrelation() { double[] trueClassCounts = trueClassCounts(); double[] predictedClassCounts = predictedClassCounts(); double mccTotal = 0; for (int i = 0; i < m_NumClasses; i++) { if (trueClassCounts[i] > 0) { // If MCC is NaN, we want the sum to also be NaN if count > 0 double numFP = predictedClassCounts[i] - numTruePositives(i); mccTotal += matthewsCorrelationCoefficient(i, numTruePositives(i), (numClassified() - trueClassCounts[i]) - numFP, numFP, trueClassCounts[i] - numTruePositives(i)) * trueClassCounts[i]; } } return mccTotal / numClassified(); } /** * Calculate the recall with respect to a particular class. This is defined as *

* *

   * correctly classified positives
   * ------------------------------
   *       total positives
   * 
*

*
   * This is the same as the true positive rate; the call is delegated to
   * truePositiveRate(int).
   *
   * Unclassified instances are not included in the calculation.
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the recall
   */
  public double recall(int classIndex) {
    return truePositiveRate(classIndex);
  }

  /**
   * Calculates the weighted (by class size) recall. This is the same as the
   * weighted true positive rate, to which the call is delegated.
   *
   * Unclassified instances are not included in the calculation.
   *
   * @return the weighted recall.
   */
  public double weightedRecall() {
    return weightedTruePositiveRate();
  }

  /**
   * Calculate the precision with respect to a particular class. This is defined
   * as
   *

* *

   * correctly classified positives
   * ------------------------------
   *  total predicted as positive
   * 
* * Unclassified instances are not included in the calculation. * * @param classIndex the index of the class to consider as "positive" * @return the precision */ public double precision(int classIndex) { return numTruePositives(classIndex) / numPredictedPositives(classIndex); } /** * Calculates the weighted (by class size) precision. * * Unclassified instances are not included in the calculation. * * @return the weighted precision. */ public double weightedPrecision() { double[] trueClassCounts = trueClassCounts(); double[] predictedClassCounts = predictedClassCounts(); double precisionTotal = 0; for (int i = 0; i < m_NumClasses; i++) { if (trueClassCounts[i] > 0) { // If precision is NaN, we want the sum to also be NaN if count > 0 precisionTotal += (numTruePositives(i) / predictedClassCounts[i]) * trueClassCounts[i]; } } return precisionTotal / numClassified(); } /** * Calculate the F-Measure with respect to a particular class. This is defined * as *

* *

   * 2 * recall * precision
   * ----------------------
   *   recall + precision
   * 
*
   * The result is zero when precision and recall are both zero.
   *
   * Unclassified instances are not included in the calculation.
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the F-Measure
   */
  public double fMeasure(int classIndex) {
    final double classPrecision = precision(classIndex);
    final double classRecall = recall(classIndex);
    return fMeasure(classPrecision, classRecall);
  }

  /**
   * Calculate the F-Measure with respect to a particular class. This is defined
   * as
   *

* *

   * 2 * recall * precision
   * ----------------------
   *   recall + precision
   * 
* * Returns zero when both precision and recall are zero. * * Unclassified instances are not included in the calculation. * * @param precision the precision * @param recall the recall * @return the F-Measure */ public double fMeasure(double precision, double recall) { if ((precision == 0) && (recall == 0)) { return 0; } return 2 * precision * recall / (precision + recall); } /** * Calculates the macro weighted (by class size) average F-Measure. * * Unclassified instances are not included in the calculation. * * @return the weighted F-Measure. */ public double weightedFMeasure() { double[] trueClassCounts = trueClassCounts(); double[] predictedClassCounts = predictedClassCounts(); double fMeasureTotal = 0; for (int i = 0; i < m_NumClasses; i++) { if (trueClassCounts[i] > 0) { // If F-measure is NaN, we want the sum to also be NaN if count > 0 fMeasureTotal += fMeasure(numTruePositives(i) / predictedClassCounts[i], numTruePositives(i) / trueClassCounts[i]) * trueClassCounts[i]; } } return fMeasureTotal / numClassified(); } /** * Unweighted macro-averaged F-measure. If some classes not present in the * test set, they're just skipped (since recall is undefined there anyway) . * * Unclassified instances are not included in the calculation. * * @return unweighted macro-averaged F-measure. * */ public double unweightedMacroFmeasure() { double[] trueClassCounts = trueClassCounts(); double[] predictedClassCounts = predictedClassCounts(); weka.experiment.Stats rr = new weka.experiment.Stats(); for (int c = 0; c < m_NumClasses; c++) { // skip if no testing positive cases of this class if (trueClassCounts[c] > 0) { rr.add(fMeasure(numTruePositives(c) / predictedClassCounts[c], numTruePositives(c) / trueClassCounts[c])); } } rr.calculateDerived(); return rr.mean; } /** * Unweighted micro-averaged F-measure. If some classes not present in the * test set, they have no effect. * * Note: if the test set is *single-label*, then this is the same as accuracy. 
* * Unclassified instances are not included in the calculation. * * @return unweighted micro-averaged F-measure. */ public double unweightedMicroFmeasure() { double[] trueClassCounts = trueClassCounts(); double[] predictedClassCounts = predictedClassCounts(); double tp = 0; double fn = 0; double fp = 0; for (int c = 0; c < m_NumClasses; c++) { tp += numTruePositives(c); fn += trueClassCounts[c] - numTruePositives(c); fp += predictedClassCounts[c] - numTruePositives(c); } return 2 * tp / (2 * tp + fn + fp); } /** * Sets the class prior probabilities. * * @param train the training instances used to determine the prior * probabilities * @throws Exception if the class attribute of the instances is not set */ public void setPriors(Instances train) throws Exception { m_NoPriors = false; if (!m_ClassIsNominal) { m_NumTrainClassVals = 0; m_TrainClassVals = null; m_TrainClassWeights = null; m_PriorEstimator = null; m_MinTarget = Double.MAX_VALUE; m_MaxTarget = -Double.MAX_VALUE; for (int i = 0; i < train.numInstances(); i++) { Instance currentInst = train.instance(i); if (!currentInst.classIsMissing()) { addNumericTrainClass(currentInst.classValue(), currentInst.weight()); } } m_ClassPriors[0] = m_ClassPriorsSum = 0; for (int i = 0; i < train.numInstances(); i++) { if (!train.instance(i).classIsMissing()) { m_ClassPriors[0] += train.instance(i).classValue() * train.instance(i).weight(); m_ClassPriorsSum += train.instance(i).weight(); } } } else { for (int i = 0; i < m_NumClasses; i++) { m_ClassPriors[i] = 1; } m_ClassPriorsSum = m_NumClasses; for (int i = 0; i < train.numInstances(); i++) { if (!train.instance(i).classIsMissing()) { m_ClassPriors[(int) train.instance(i).classValue()] += train.instance(i).weight(); m_ClassPriorsSum += train.instance(i).weight(); } } m_MaxTarget = m_NumClasses; m_MinTarget = 0; } } /** * Get the current weighted class counts. 
*
   * @return the weighted class counts
   */
  public double[] getClassPriors() {
    final double[] priors = m_ClassPriors;
    return priors;
  }

  /**
   * Updates the class prior probabilities or the mean respectively (when
   * incrementally training). Instances with a missing class are ignored.
   *
   * @param instance the new training instance seen
   * @throws Exception if the class of the instance is not set
   */
  public void updatePriors(Instance instance) throws Exception {
    if (instance.classIsMissing()) {
      return;
    }
    if (m_ClassIsNominal) {
      m_ClassPriors[(int) instance.classValue()] += instance.weight();
      m_ClassPriorsSum += instance.weight();
      return;
    }
    // Numeric class: accumulate the weighted class value.
    addNumericTrainClass(instance.classValue(), instance.weight());
    m_ClassPriors[0] += instance.classValue() * instance.weight();
    m_ClassPriorsSum += instance.weight();
  }

  /**
   * Disables the use of priors, e.g., in case of de-serialized schemes that
   * have no access to the original training set, but are evaluated on a test
   * set.
   */
  public void useNoPriors() {
    m_NoPriors = true;
  }

  /**
   * Tests whether the current evaluation object is equal to another evaluation
   * object.
* * @param obj the object to compare against * @return true if the two objects are equal */ @Override public boolean equals(Object obj) { if ((obj == null) || !(obj.getClass().equals(this.getClass()))) { return false; } Evaluation cmp = (Evaluation) obj; if (m_ClassIsNominal != cmp.m_ClassIsNominal) { return false; } if (m_NumClasses != cmp.m_NumClasses) { return false; } if (m_Incorrect != cmp.m_Incorrect) { return false; } if (m_Correct != cmp.m_Correct) { return false; } if (m_Unclassified != cmp.m_Unclassified) { return false; } if (m_MissingClass != cmp.m_MissingClass) { return false; } if (m_WithClass != cmp.m_WithClass) { return false; } if (m_SumErr != cmp.m_SumErr) { return false; } if (m_SumAbsErr != cmp.m_SumAbsErr) { return false; } if (m_SumSqrErr != cmp.m_SumSqrErr) { return false; } if (m_SumClass != cmp.m_SumClass) { return false; } if (m_SumSqrClass != cmp.m_SumSqrClass) { return false; } if (m_SumPredicted != cmp.m_SumPredicted) { return false; } if (m_SumSqrPredicted != cmp.m_SumSqrPredicted) { return false; } if (m_SumClassPredicted != cmp.m_SumClassPredicted) { return false; } if (m_ClassIsNominal) { for (int i = 0; i < m_NumClasses; i++) { for (int j = 0; j < m_NumClasses; j++) { if (m_ConfusionMatrix[i][j] != cmp.m_ConfusionMatrix[i][j]) { return false; } } } } return true; } /** * Make up the help string giving all the command line options. * * @param classifier the classifier to include options for * @param globalInfo include the global information string for the classifier * (if available). 
* @return a string detailing the valid command line options */ protected static String makeOptionString(Classifier classifier, boolean globalInfo) { StringBuffer optionsText = new StringBuffer(""); // General options optionsText.append("\n\nGeneral options:\n\n"); optionsText.append("-h or -help\n"); optionsText.append("\tOutput help information.\n"); optionsText.append("-synopsis or -info\n"); optionsText.append("\tOutput synopsis for classifier (use in conjunction " + " with -h)\n"); optionsText.append("-t \n"); optionsText.append("\tSets training file.\n"); optionsText.append("-T \n"); optionsText .append("\tSets test file. If missing, a cross-validation will be performed\n"); optionsText.append("\ton the training data.\n"); optionsText.append("-c \n"); optionsText.append("\tSets index of class attribute (default: last).\n"); optionsText.append("-x \n"); optionsText .append("\tSets number of folds for cross-validation (default: 10).\n"); optionsText.append("-no-cv\n"); optionsText.append("\tDo not perform any cross validation.\n"); optionsText.append("-force-batch-training\n"); optionsText .append("\tAlways train classifier in batch mode, never incrementally.\n"); optionsText.append("-split-percentage \n"); optionsText .append("\tSets the percentage for the train/test set split, e.g., 66.\n"); optionsText.append("-preserve-order\n"); optionsText.append("\tPreserves the order in the percentage split.\n"); optionsText.append("-s \n"); optionsText .append("\tSets random number seed for cross-validation or percentage split\n"); optionsText.append("\t(default: 1).\n"); optionsText.append("-m \n"); optionsText.append("\tSets file with cost matrix.\n"); optionsText.append("-continue-iterating\n"); optionsText.append("\tContinue training an IterativeClassifier model that has\n\t" + "been loaded via -l.\n"); optionsText.append("-clean-up\n"); optionsText.append("\tReduce storage size of an loaded IterativeClassifier\n" + "\tafter iterating. 
This effectively 'freezes' the model, and no further\n" + "\titeration is then possible.\n"); optionsText .append("-toggle \n"); optionsText .append("\tComma separated list of metric names to toggle in the output.\n\t" + "All metrics are output by default with the exception of 'Coverage' and " + "'Region size'.\n\t"); optionsText.append("Available metrics:\n\t"); List metricsToDisplay = new ArrayList(Arrays.asList(BUILT_IN_EVAL_METRICS)); int length = 0; for (int i = 0; i < metricsToDisplay.size(); i++) { optionsText.append(metricsToDisplay.get(i)); length += metricsToDisplay.get(i).length(); if (i != metricsToDisplay.size() - 1) { optionsText.append(","); } if (length >= 60) { optionsText.append("\n\t"); length = 0; } } optionsText.append("\n"); optionsText.append("-l \n"); optionsText .append("\tSets model input file. In case the filename ends with '.xml',\n"); optionsText .append("\ta PMML file is loaded or, if that fails, options are loaded\n"); optionsText.append("\tfrom the XML file.\n"); optionsText.append("-d \n"); optionsText .append("\tSets model output file. 
In case the filename ends with '.xml',\n"); optionsText .append("\tonly the options are saved to the XML file, not the model.\n"); optionsText.append("-v\n"); optionsText.append("\tOutputs no statistics for training data.\n"); optionsText.append("-o\n"); optionsText.append("\tOutputs statistics only, not the classifier.\n"); optionsText.append("-output-models-for-training-splits\n"); optionsText.append("\tOutput models for training splits if cross-validation or percentage-split evaluation is used.\n"); optionsText.append("-do-not-output-per-class-statistics\n"); optionsText.append("\tDo not output statistics for each class.\n"); optionsText.append("-k\n"); optionsText.append("\tOutputs information-theoretic statistics.\n"); optionsText .append("-classifications \"weka.classifiers.evaluation.output.prediction.AbstractOutput + options\"\n"); optionsText .append("\tUses the specified class for generating the classification output.\n"); optionsText.append("\tE.g.: " + PlainText.class.getName() + "\n"); optionsText.append("-p range\n"); optionsText .append("\tOutputs predictions for test instances (or the train instances if\n"); optionsText .append("\tno test instances provided and -no-cv is used), along with the \n"); optionsText .append("\tattributes in the specified range (and nothing else). 
\n"); optionsText.append("\tUse '-p 0' if no attributes are desired.\n"); optionsText.append("\tDeprecated: use \"-classifications ...\" instead.\n"); optionsText.append("-distribution\n"); optionsText .append("\tOutputs the distribution instead of only the prediction\n"); optionsText .append("\tin conjunction with the '-p' option (only nominal classes).\n"); optionsText.append("\tDeprecated: use \"-classifications ...\" instead.\n"); optionsText.append("-r\n"); optionsText.append("\tOnly outputs cumulative margin distribution.\n"); if (classifier instanceof Sourcable) { optionsText.append("-z \n"); optionsText.append("\tOnly outputs the source representation" + " of the classifier,\n\tgiving it the supplied" + " name.\n"); } if (classifier instanceof Drawable) { optionsText.append("-g\n"); optionsText.append("\tOnly outputs the graph representation" + " of the classifier.\n"); } optionsText.append("-xml filename | xml-string\n"); optionsText .append("\tRetrieves the options from the XML-data instead of the " + "command line.\n"); optionsText.append("-threshold-file \n"); optionsText .append("\tThe file to save the threshold data to.\n" + "\tThe format is determined by the extensions, e.g., '.arff' for ARFF \n" + "\tformat or '.csv' for CSV.\n"); optionsText.append("-threshold-label




© 2015 - 2024 Weber Informatics LLC | Privacy Policy