All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.experiment.CrossValidationSplitResultProducer Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    CrossValidationSplitResultProducer.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.experiment;

import java.util.Random;

import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 *  Carries out one split of a repeated k-fold
 * cross-validation, using the set SplitEvaluator to generate some results. Note
 * that the run number is actually the nth split of a repeated k-fold
 * cross-validation, i.e. if k=10, run number 100 is the 10th fold of the 10th
 * cross-validation run. This producer's sole purpose is to allow more
 * fine-grained distribution of cross-validation experiments. If the class
 * attribute is nominal, the dataset is stratified.
 * 

* * * Valid options are: *

* *

 * -X <number of folds>
 *  The number of folds to use for the cross-validation.
 *  (default 10)
 * 
* *
 * -D
 * Save raw split evaluator output.
 * 
* *
 * -O <file/directory name/path>
 *  The filename where raw output will be stored.
 *  If a directory name is specified then then individual
 *  outputs will be gzipped, otherwise all output will be
 *  zipped to the named file. Use in conjuction with -D. (default splitEvalutorOut.zip)
 * 
* *
 * -W <class name>
 *  The full class name of a SplitEvaluator.
 *  eg: weka.experiment.ClassifierSplitEvaluator
 * 
* *
 * Options specific to split evaluator weka.experiment.ClassifierSplitEvaluator:
 * 
* *
 * -W <class name>
 *  The full class name of the classifier.
 *  eg: weka.classifiers.bayes.NaiveBayes
 * 
* *
 * -C <index>
 *  The index of the class for which IR statistics
 *  are to be output. (default 1)
 * 
* *
 * -I <index>
 *  The index of an attribute to output in the
 *  results. This attribute should identify an
 *  instance in order to know which instances are
 *  in the test set of a cross validation. if 0
 *  no output (default 0).
 * 
* *
 * -P
 *  Add target and prediction columns to the result
 *  for each fold.
 * 
* *
 * Options specific to classifier weka.classifiers.rules.ZeroR:
 * 
* *
 * -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
 * 
* * * * All options after -- will be passed to the split evaluator. * * @author Len Trigg * @author Eibe Frank * @version $Revision: 10203 $ */ public class CrossValidationSplitResultProducer extends CrossValidationResultProducer { /** for serialization */ static final long serialVersionUID = 1403798164046795073L; /** * Returns a string describing this result producer * * @return a description of the result producer suitable for displaying in the * explorer/experimenter gui */ @Override public String globalInfo() { return "Carries out one split of a repeated k-fold cross-validation, " + "using the set SplitEvaluator to generate some results. " + "Note that the run number is actually the nth split of a repeated " + "k-fold cross-validation, i.e. if k=10, run number 100 is the 10th " + "fold of the 10th cross-validation run. This producer's sole purpose " + "is to allow more fine-grained distribution of cross-validation " + "experiments. If the class attribute is nominal, the dataset is stratified."; } /** * Gets the keys for a specified run number. Different run numbers correspond * to different randomizations of the data. Keys produced should be sent to * the current ResultListener * * @param run the run number to get keys for. * @throws Exception if a problem occurs while getting the keys */ @Override public void doRunKeys(int run) throws Exception { if (m_Instances == null) { throw new Exception("No Instances set"); } // Add in some fields to the key like run and fold number, dataset name Object[] seKey = m_SplitEvaluator.getKey(); Object[] key = new Object[seKey.length + 3]; key[0] = Utils.backQuoteChars(m_Instances.relationName()); key[2] = "" + (((run - 1) % m_NumFolds) + 1); key[1] = "" + (((run - 1) / m_NumFolds) + 1); System.arraycopy(seKey, 0, key, 3, seKey.length); if (m_ResultListener.isResultRequired(this, key)) { try { m_ResultListener.acceptResult(this, key, null); } catch (Exception ex) { // Save the train and test datasets for debugging purposes? 
throw ex; } } } /** * Gets the results for a specified run number. Different run numbers * correspond to different randomizations of the data. Results produced should * be sent to the current ResultListener * * @param run the run number to get results for. * @throws Exception if a problem occurs while getting the results */ @Override public void doRun(int run) throws Exception { if (getRawOutput()) { if (m_ZipDest == null) { m_ZipDest = new OutputZipper(m_OutputFile); } } if (m_Instances == null) { throw new Exception("No Instances set"); } // Compute run and fold number from given run int fold = (run - 1) % m_NumFolds; run = ((run - 1) / m_NumFolds) + 1; // Randomize on a copy of the original dataset Instances runInstances = new Instances(m_Instances); Random random = new Random(run); runInstances.randomize(random); if (runInstances.classAttribute().isNominal()) { runInstances.stratify(m_NumFolds); } // Add in some fields to the key like run and fold number, dataset name Object[] seKey = m_SplitEvaluator.getKey(); Object[] key = new Object[seKey.length + 3]; key[0] = Utils.backQuoteChars(m_Instances.relationName()); key[1] = "" + run; key[2] = "" + (fold + 1); System.arraycopy(seKey, 0, key, 3, seKey.length); if (m_ResultListener.isResultRequired(this, key)) { // Just to make behaviour absolutely consistent with // CrossValidationResultProducer for (int tempFold = 0; tempFold < fold; tempFold++) { runInstances.trainCV(m_NumFolds, tempFold, random); } Instances train = runInstances.trainCV(m_NumFolds, fold, random); Instances test = runInstances.testCV(m_NumFolds, fold); try { Object[] seResults = m_SplitEvaluator.getResult(train, test); Object[] results = new Object[seResults.length + 1]; results[0] = getTimestamp(); System.arraycopy(seResults, 0, results, 1, seResults.length); if (m_debugOutput) { String resultName = ("" + run + "." + (fold + 1) + "." + Utils.backQuoteChars(runInstances.relationName()) + "." 
+ m_SplitEvaluator .toString()).replace(' ', '_'); resultName = Utils.removeSubstring(resultName, "weka.classifiers."); resultName = Utils.removeSubstring(resultName, "weka.filters."); resultName = Utils.removeSubstring(resultName, "weka.attributeSelection."); m_ZipDest.zipit(m_SplitEvaluator.getRawResultOutput(), resultName); } m_ResultListener.acceptResult(this, key, results); } catch (Exception ex) { // Save the train and test datasets for debugging purposes? throw ex; } } } /** * Gets a text descrption of the result producer. * * @return a text description of the result producer. */ @Override public String toString() { String result = "CrossValidationSplitResultProducer: "; result += getCompatibilityState(); if (m_Instances == null) { result += ": "; } else { result += ": " + Utils.backQuoteChars(m_Instances.relationName()); } return result; } /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 10203 $"); } } // CrossValidationSplitResultProducer




© 2015 - 2025 Weber Informatics LLC | Privacy Policy