All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.experiment.AveragingResultProducer Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    AveragingResultProducer.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.experiment;

import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Vector;

import weka.core.AdditionalMeasureProducer;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 *  Takes the results from a ResultProducer and submits
 * the average to the result listener. Normally used with a
 * CrossValidationResultProducer to perform n x m fold cross validation. For
 * non-numeric result fields, the first value is used.
 * 

* * * Valid options are: *

* *

 * -F <field name>
 *  The name of the field to average over.
 *  (default "Fold")
 * 
* *
 * -X <num results>
 *  The number of results expected per average.
 *  (default 10)
 * 
* *
 * -S
 *  Calculate standard deviations.
 *  (default only averages)
 * 
* *
 * -W <class name>
 *  The full class name of a ResultProducer.
 *  eg: weka.experiment.CrossValidationResultProducer
 * 
* *
 * Options specific to result producer weka.experiment.CrossValidationResultProducer:
 * 
* *
 * -X <number of folds>
 *  The number of folds to use for the cross-validation.
 *  (default 10)
 * 
* *
 * -D
 * Save raw split evaluator output.
 * 
* *
 * -O <file/directory name/path>
 *  The filename where raw output will be stored.
 *  If a directory name is specified then then individual
 *  outputs will be gzipped, otherwise all output will be
 *  zipped to the named file. Use in conjuction with -D. (default splitEvalutorOut.zip)
 * 
* *
 * -W <class name>
 *  The full class name of a SplitEvaluator.
 *  eg: weka.experiment.ClassifierSplitEvaluator
 * 
* *
 * Options specific to split evaluator weka.experiment.ClassifierSplitEvaluator:
 * 
* *
 * -W <class name>
 *  The full class name of the classifier.
 *  eg: weka.classifiers.bayes.NaiveBayes
 * 
* *
 * -C <index>
 *  The index of the class for which IR statistics
 *  are to be output. (default 1)
 * 
* *
 * -I <index>
 *  The index of an attribute to output in the
 *  results. This attribute should identify an
 *  instance in order to know which instances are
 *  in the test set of a cross validation. if 0
 *  no output (default 0).
 * 
* *
 * -P
 *  Add target and prediction columns to the result
 *  for each fold.
 * 
* *
 * Options specific to classifier weka.classifiers.rules.ZeroR:
 * 
* *
 * -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
 * 
* * * * All options after -- will be passed to the result producer. * * @author Len Trigg ([email protected]) * @version $Revision: 11198 $ */ public class AveragingResultProducer implements ResultListener, ResultProducer, OptionHandler, AdditionalMeasureProducer, RevisionHandler { /** for serialization */ static final long serialVersionUID = 2551284958501991352L; /** The dataset of interest */ protected Instances m_Instances; /** The ResultListener to send results to */ protected ResultListener m_ResultListener = new CSVResultListener(); /** The ResultProducer used to generate results */ protected ResultProducer m_ResultProducer = new CrossValidationResultProducer(); /** The names of any additional measures to look for in SplitEvaluators */ protected String[] m_AdditionalMeasures = null; /** The number of results expected to average over for each run */ protected int m_ExpectedResultsPerAverage = 10; /** True if standard deviation fields should be produced */ protected boolean m_CalculateStdDevs; /** * The name of the field that will contain the number of results averaged * over. */ protected String m_CountFieldName = "Num_" + CrossValidationResultProducer .FOLD_FIELD_NAME; /** The name of the key field to average over */ protected String m_KeyFieldName = CrossValidationResultProducer .FOLD_FIELD_NAME; /** The index of the field to average over in the resultproducers key */ protected int m_KeyIndex = -1; /** Collects the keys from a single run */ protected FastVector m_Keys = new FastVector(); /** Collects the results from a single run */ protected FastVector m_Results = new FastVector(); /** * Returns a string describing this result producer * * @return a description of the result producer suitable for displaying in the * explorer/experimenter gui */ public String globalInfo() { return "Takes the results from a ResultProducer " + "and submits the average to the result listener. Normally used with " + "a CrossValidationResultProducer to perform n x m fold cross " + "validation. For non-numeric result fields, the first value is used."; } /** * Scans through the key field names of the result producer to find the index * of the key field to average over. Sets the value of m_KeyIndex to the * index, or -1 if no matching key field was found. * * @return the index of the key field to average over */ protected int findKeyIndex() { m_KeyIndex = -1; try { if (m_ResultProducer != null) { String[] keyNames = m_ResultProducer.getKeyNames(); for (int i = 0; i < keyNames.length; i++) { if (keyNames[i].equals(m_KeyFieldName)) { m_KeyIndex = i; break; } } } } catch (Exception ex) { } return m_KeyIndex; } /** * Determines if there are any constraints (imposed by the destination) on the * result columns to be produced by resultProducers. Null should be returned * if there are NO constraints, otherwise a list of column names should be * returned as an array of Strings. * * @param rp the ResultProducer to which the constraints will apply * @return an array of column names to which resutltProducer's results will be * restricted. * @throws Exception if constraints can't be determined */ @Override public String[] determineColumnConstraints(ResultProducer rp) throws Exception { return null; } /** * Simulates a run to collect the keys the sub-resultproducer could generate. * Does some checking on the keys and determines the template key. * * @param run the run number * @return a template key (null for the field being averaged) * @throws Exception if an error occurs */ protected Object[] determineTemplate(int run) throws Exception { if (m_Instances == null) { throw new Exception("No Instances set"); } m_ResultProducer.setInstances(m_Instances); // Clear the collected results m_Keys.removeAllElements(); m_Results.removeAllElements(); m_ResultProducer.doRunKeys(run); checkForMultipleDifferences(); Object[] template = ((Object[]) m_Keys.elementAt(0)).clone(); template[m_KeyIndex] = null; // Check for duplicate keys checkForDuplicateKeys(template); return template; } /** * Gets the keys for a specified run number. Different run numbers correspond * to different randomizations of the data. Keys produced should be sent to * the current ResultListener * * @param run the run number to get keys for. * @throws Exception if a problem occurs while getting the keys */ @Override public void doRunKeys(int run) throws Exception { // Generate the template Object[] template = determineTemplate(run); String[] newKey = new String[template.length - 1]; System.arraycopy(template, 0, newKey, 0, m_KeyIndex); System.arraycopy(template, m_KeyIndex + 1, newKey, m_KeyIndex, template.length - m_KeyIndex - 1); m_ResultListener.acceptResult(this, newKey, null); } /** * Gets the results for a specified run number. Different run numbers * correspond to different randomizations of the data. Results produced should * be sent to the current ResultListener * * @param run the run number to get results for. * @throws Exception if a problem occurs while getting the results */ @Override public void doRun(int run) throws Exception { // Generate the key and ask whether the result is required Object[] template = determineTemplate(run); String[] newKey = new String[template.length - 1]; System.arraycopy(template, 0, newKey, 0, m_KeyIndex); System.arraycopy(template, m_KeyIndex + 1, newKey, m_KeyIndex, template.length - m_KeyIndex - 1); if (m_ResultListener.isResultRequired(this, newKey)) { // Clear the collected keys m_Keys.removeAllElements(); m_Results.removeAllElements(); m_ResultProducer.doRun(run); // Average the results collected // System.err.println("Number of results collected: " + m_Keys.size()); // Check that the keys only differ on the selected key field checkForMultipleDifferences(); template = ((Object[]) m_Keys.elementAt(0)).clone(); template[m_KeyIndex] = null; // Check for duplicate keys checkForDuplicateKeys(template); // Calculate the average and submit it if necessary doAverageResult(template); } } /** * Compares a key to a template to see whether they match. Null fields in the * template are ignored in the matching. * * @param template the template to match against * @param test the key to test * @return true if the test key matches the template on all non-null template * fields */ protected boolean matchesTemplate(Object[] template, Object[] test) { if (template.length != test.length) { return false; } for (int i = 0; i < test.length; i++) { if ((template[i] != null) && (!template[i].equals(test[i]))) { return false; } } return true; } /** * Asks the resultlistener whether an average result is required, and if so, * calculates it. * * @param template the template to match keys against when calculating the * average * @throws Exception if an error occurs */ protected void doAverageResult(Object[] template) throws Exception { // Generate the key and ask whether the result is required String[] newKey = new String[template.length - 1]; System.arraycopy(template, 0, newKey, 0, m_KeyIndex); System.arraycopy(template, m_KeyIndex + 1, newKey, m_KeyIndex, template.length - m_KeyIndex - 1); if (m_ResultListener.isResultRequired(this, newKey)) { Object[] resultTypes = m_ResultProducer.getResultTypes(); Stats[] stats = new Stats[resultTypes.length]; for (int i = 0; i < stats.length; i++) { stats[i] = new Stats(); } Object[] result = getResultTypes(); int numMatches = 0; for (int i = 0; i < m_Keys.size(); i++) { Object[] currentKey = (Object[]) m_Keys.elementAt(i); // Skip non-matching keys if (!matchesTemplate(template, currentKey)) { continue; } // Add the results to the stats accumulator Object[] currentResult = (Object[]) m_Results.elementAt(i); numMatches++; for (int j = 0; j < resultTypes.length; j++) { if (resultTypes[j] instanceof Double) { if (currentResult[j] == null) { // set the stats object for this result to null--- // more than likely this is an additional measure field // not supported by the low level split evaluator if (stats[j] != null) { stats[j] = null; } /* * throw new Exception("Null numeric result field found:\n" + * DatabaseUtils.arrayToString(currentKey) + " -- " + * DatabaseUtils .arrayToString(currentResult)); */ } if (stats[j] != null) { double currentVal = ((Double) currentResult[j]).doubleValue(); stats[j].add(currentVal); } } } } if (numMatches != m_ExpectedResultsPerAverage) { throw new Exception("Expected " + m_ExpectedResultsPerAverage + " results matching key \"" + DatabaseUtils.arrayToString(template) + "\" but got " + numMatches); } result[0] = new Double(numMatches); Object[] currentResult = (Object[]) m_Results.elementAt(0); int k = 1; for (int j = 0; j < resultTypes.length; j++) { if (resultTypes[j] instanceof Double) { if (stats[j] != null) { stats[j].calculateDerived(); result[k++] = new Double(stats[j].mean); } else { result[k++] = null; } if (getCalculateStdDevs()) { if (stats[j] != null) { result[k++] = new Double(stats[j].stdDev); } else { result[k++] = null; } } } else { result[k++] = currentResult[j]; } } m_ResultListener.acceptResult(this, newKey, result); } } /** * Checks whether any duplicate results (with respect to a key template) were * received. * * @param template the template key. * @throws Exception if duplicate results are detected */ protected void checkForDuplicateKeys(Object[] template) throws Exception { Hashtable hash = new Hashtable(); int numMatches = 0; for (int i = 0; i < m_Keys.size(); i++) { Object[] current = (Object[]) m_Keys.elementAt(i); // Skip non-matching keys if (!matchesTemplate(template, current)) { continue; } if (hash.containsKey(current[m_KeyIndex])) { throw new Exception("Duplicate result received:" + DatabaseUtils.arrayToString(current)); } numMatches++; hash.put(current[m_KeyIndex], current[m_KeyIndex]); } if (numMatches != m_ExpectedResultsPerAverage) { throw new Exception("Expected " + m_ExpectedResultsPerAverage + " results matching key \"" + DatabaseUtils.arrayToString(template) + "\" but got " + numMatches); } } /** * Checks that the keys for a run only differ in one key field. If they differ * in more than one field, a more sophisticated averager will submit multiple * results - for now an exception is thrown. Currently assumes that the most * differences will be shown between the first and last result received. * * @throws Exception if the keys differ on fields other than the key averaging * field */ protected void checkForMultipleDifferences() throws Exception { Object[] firstKey = (Object[]) m_Keys.elementAt(0); Object[] lastKey = (Object[]) m_Keys.elementAt(m_Keys.size() - 1); /* * System.err.println("First key:" + DatabaseUtils.arrayToString(firstKey)); * System.err.println("Last key :" + DatabaseUtils.arrayToString(lastKey)); */ for (int i = 0; i < firstKey.length; i++) { if ((i != m_KeyIndex) && !firstKey[i].equals(lastKey[i])) { throw new Exception("Keys differ on fields other than \"" + m_KeyFieldName + "\" -- time to implement multiple averaging"); } } } /** * Prepare for the results to be received. * * @param rp the ResultProducer that will generate the results * @throws Exception if an error occurs during preprocessing. */ @Override public void preProcess(ResultProducer rp) throws Exception { if (m_ResultListener == null) { throw new Exception("No ResultListener set"); } m_ResultListener.preProcess(this); } /** * Prepare to generate results. The ResultProducer should call * preProcess(this) on the ResultListener it is to send results to. * * @throws Exception if an error occurs during preprocessing. */ @Override public void preProcess() throws Exception { if (m_ResultProducer == null) { throw new Exception("No ResultProducer set"); } // Tell the resultproducer to send results to us m_ResultProducer.setResultListener(this); findKeyIndex(); if (m_KeyIndex == -1) { throw new Exception("No key field called " + m_KeyFieldName + " produced by " + m_ResultProducer.getClass().getName()); } m_ResultProducer.preProcess(); } /** * When this method is called, it indicates that no more results will be sent * that need to be grouped together in any way. * * @param rp the ResultProducer that generated the results * @throws Exception if an error occurs */ @Override public void postProcess(ResultProducer rp) throws Exception { m_ResultListener.postProcess(this); } /** * When this method is called, it indicates that no more requests to generate * results for the current experiment will be sent. The ResultProducer should * call preProcess(this) on the ResultListener it is to send results to. * * @throws Exception if an error occurs */ @Override public void postProcess() throws Exception { m_ResultProducer.postProcess(); } /** * Accepts results from a ResultProducer. * * @param rp the ResultProducer that generated the results * @param key an array of Objects (Strings or Doubles) that uniquely identify * a result for a given ResultProducer with given compatibilityState * @param result the results stored in an array. The objects stored in the * array may be Strings, Doubles, or null (for the missing value). * @throws Exception if the result could not be accepted. */ @Override public void acceptResult(ResultProducer rp, Object[] key, Object[] result) throws Exception { if (m_ResultProducer != rp) { throw new Error("Unrecognized ResultProducer sending results!!"); } m_Keys.addElement(key); m_Results.addElement(result); } /** * Determines whether the results for a specified key must be generated. * * @param rp the ResultProducer wanting to generate the results * @param key an array of Objects (Strings or Doubles) that uniquely identify * a result for a given ResultProducer with given compatibilityState * @return true if the result should be generated * @throws Exception if it could not be determined if the result is needed. */ @Override public boolean isResultRequired(ResultProducer rp, Object[] key) throws Exception { if (m_ResultProducer != rp) { throw new Error("Unrecognized ResultProducer sending results!!"); } return true; } /** * Gets the names of each of the columns produced for a single run. * * @return an array containing the name of each column * @throws Exception if key names cannot be generated */ @Override public String[] getKeyNames() throws Exception { if (m_KeyIndex == -1) { throw new Exception("No key field called " + m_KeyFieldName + " produced by " + m_ResultProducer.getClass().getName()); } String[] keyNames = m_ResultProducer.getKeyNames(); String[] newKeyNames = new String[keyNames.length - 1]; System.arraycopy(keyNames, 0, newKeyNames, 0, m_KeyIndex); System.arraycopy(keyNames, m_KeyIndex + 1, newKeyNames, m_KeyIndex, keyNames.length - m_KeyIndex - 1); return newKeyNames; } /** * Gets the data types of each of the columns produced for a single run. This * method should really be static. * * @return an array containing objects of the type of each column. The objects * should be Strings, or Doubles. * @throws Exception if the key types could not be determined (perhaps because * of a problem from a nested sub-resultproducer) */ @Override public Object[] getKeyTypes() throws Exception { if (m_KeyIndex == -1) { throw new Exception("No key field called " + m_KeyFieldName + " produced by " + m_ResultProducer.getClass().getName()); } Object[] keyTypes = m_ResultProducer.getKeyTypes(); // Find and remove the key field that is being averaged over Object[] newKeyTypes = new String[keyTypes.length - 1]; System.arraycopy(keyTypes, 0, newKeyTypes, 0, m_KeyIndex); System.arraycopy(keyTypes, m_KeyIndex + 1, newKeyTypes, m_KeyIndex, keyTypes.length - m_KeyIndex - 1); return newKeyTypes; } /** * Gets the names of each of the columns produced for a single run. A new * result field is added for the number of results used to produce each * average. If only averages are being produced the names are not altered, if * standard deviations are produced then "Dev_" and "Avg_" are prepended to * each result deviation and average field respectively. * * @return an array containing the name of each column * @throws Exception if the result names could not be determined (perhaps * because of a problem from a nested sub-resultproducer) */ @Override public String[] getResultNames() throws Exception { String[] resultNames = m_ResultProducer.getResultNames(); // Add in the names of our extra Result fields if (getCalculateStdDevs()) { Object[] resultTypes = m_ResultProducer.getResultTypes(); int numNumeric = 0; for (Object resultType : resultTypes) { if (resultType instanceof Double) { numNumeric++; } } String[] newResultNames = new String[resultNames.length + 1 + numNumeric]; newResultNames[0] = m_CountFieldName; int j = 1; for (int i = 0; i < resultNames.length; i++) { newResultNames[j++] = "Avg_" + resultNames[i]; if (resultTypes[i] instanceof Double) { newResultNames[j++] = "Dev_" + resultNames[i]; } } return newResultNames; } else { String[] newResultNames = new String[resultNames.length + 1]; newResultNames[0] = m_CountFieldName; System.arraycopy(resultNames, 0, newResultNames, 1, resultNames.length); return newResultNames; } } /** * Gets the data types of each of the columns produced for a single run. * * @return an array containing objects of the type of each column. The objects * should be Strings, or Doubles. * @throws Exception if the result types could not be determined (perhaps * because of a problem from a nested sub-resultproducer) */ @Override public Object[] getResultTypes() throws Exception { Object[] resultTypes = m_ResultProducer.getResultTypes(); // Add in the types of our extra Result fields if (getCalculateStdDevs()) { int numNumeric = 0; for (Object resultType : resultTypes) { if (resultType instanceof Double) { numNumeric++; } } Object[] newResultTypes = new Object[resultTypes.length + 1 + numNumeric]; newResultTypes[0] = new Double(0); int j = 1; for (Object resultType : resultTypes) { newResultTypes[j++] = resultType; if (resultType instanceof Double) { newResultTypes[j++] = new Double(0); } } return newResultTypes; } else { Object[] newResultTypes = new Object[resultTypes.length + 1]; newResultTypes[0] = new Double(0); System.arraycopy(resultTypes, 0, newResultTypes, 1, resultTypes.length); return newResultTypes; } } /** * Gets a description of the internal settings of the result producer, * sufficient for distinguishing a ResultProducer instance from another with * different settings (ignoring those settings set through this interface). * For example, a cross-validation ResultProducer may have a setting for the * number of folds. For a given state, the results produced should be * compatible. Typically if a ResultProducer is an OptionHandler, this string * will represent the command line arguments required to set the * ResultProducer to that state. * * @return the description of the ResultProducer state, or null if no state is * defined */ @Override public String getCompatibilityState() { String result = // "-F " + Utils.quote(getKeyFieldName()) " -X " + getExpectedResultsPerAverage() + " "; if (getCalculateStdDevs()) { result += "-S "; } if (m_ResultProducer == null) { result += ""; } else { result += "-W " + m_ResultProducer.getClass().getName(); result += " -- " + m_ResultProducer.getCompatibilityState(); } return result.trim(); } /** * Returns an enumeration describing the available options.. * * @return an enumeration of all the available options. */ @Override public Enumeration listOptions() { Vector newVector = new Vector(2); newVector.addElement(new Option( "\tThe name of the field to average over.\n" + "\t(default \"Fold\")", "F", 1, "-F ")); newVector.addElement(new Option( "\tThe number of results expected per average.\n" + "\t(default 10)", "X", 1, "-X ")); newVector.addElement(new Option( "\tCalculate standard deviations.\n" + "\t(default only averages)", "S", 0, "-S")); newVector.addElement(new Option( "\tThe full class name of a ResultProducer.\n" + "\teg: weka.experiment.CrossValidationResultProducer", "W", 1, "-W ")); if ((m_ResultProducer != null) && (m_ResultProducer instanceof OptionHandler)) { newVector.addElement(new Option( "", "", 0, "\nOptions specific to result producer " + m_ResultProducer.getClass().getName() + ":")); Enumeration enu = ((OptionHandler) m_ResultProducer).listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } } return newVector.elements(); } /** * Parses a given list of options. *

* * Valid options are: *

* *

   * -F <field name>
   *  The name of the field to average over.
   *  (default "Fold")
   * 
* *
   * -X <num results>
   *  The number of results expected per average.
   *  (default 10)
   * 
* *
   * -S
   *  Calculate standard deviations.
   *  (default only averages)
   * 
* *
   * -W <class name>
   *  The full class name of a ResultProducer.
   *  eg: weka.experiment.CrossValidationResultProducer
   * 
* *
   * Options specific to result producer weka.experiment.CrossValidationResultProducer:
   * 
* *
   * -X <number of folds>
   *  The number of folds to use for the cross-validation.
   *  (default 10)
   * 
* *
   * -D
   * Save raw split evaluator output.
   * 
* *
   * -O <file/directory name/path>
   *  The filename where raw output will be stored.
   *  If a directory name is specified then then individual
   *  outputs will be gzipped, otherwise all output will be
   *  zipped to the named file. Use in conjuction with -D. (default splitEvalutorOut.zip)
   * 
* *
   * -W <class name>
   *  The full class name of a SplitEvaluator.
   *  eg: weka.experiment.ClassifierSplitEvaluator
   * 
* *
   * Options specific to split evaluator weka.experiment.ClassifierSplitEvaluator:
   * 
* *
   * -W <class name>
   *  The full class name of the classifier.
   *  eg: weka.classifiers.bayes.NaiveBayes
   * 
* *
   * -C <index>
   *  The index of the class for which IR statistics
   *  are to be output. (default 1)
   * 
* *
   * -I <index>
   *  The index of an attribute to output in the
   *  results. This attribute should identify an
   *  instance in order to know which instances are
   *  in the test set of a cross validation. if 0
   *  no output (default 0).
   * 
* *
   * -P
   *  Add target and prediction columns to the result
   *  for each fold.
   * 
* *
   * Options specific to classifier weka.classifiers.rules.ZeroR:
   * 
* *
   * -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console
   * 
* * * * All options after -- will be passed to the result producer. * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ @Override public void setOptions(String[] options) throws Exception { String keyFieldName = Utils.getOption('F', options); if (keyFieldName.length() != 0) { setKeyFieldName(keyFieldName); } else { setKeyFieldName(CrossValidationResultProducer.FOLD_FIELD_NAME); } String numResults = Utils.getOption('X', options); if (numResults.length() != 0) { setExpectedResultsPerAverage(Integer.parseInt(numResults)); } else { setExpectedResultsPerAverage(10); } setCalculateStdDevs(Utils.getFlag('S', options)); String rpName = Utils.getOption('W', options); if (rpName.length() > 0) { // Do it first without options, so if an exception is thrown during // the option setting, listOptions will contain options for the actual // RP. setResultProducer((ResultProducer) Utils.forName( ResultProducer.class, rpName, null)); } if (getResultProducer() instanceof OptionHandler) { ((OptionHandler) getResultProducer()) .setOptions(Utils.partitionOptions(options)); } } /** * Gets the current settings of the result producer. * * @return an array of strings suitable for passing to setOptions */ @Override public String[] getOptions() { String[] seOptions = new String[0]; if ((m_ResultProducer != null) && (m_ResultProducer instanceof OptionHandler)) { seOptions = ((OptionHandler) m_ResultProducer).getOptions(); } String[] options = new String[seOptions.length + 8]; int current = 0; options[current++] = "-F"; options[current++] = "" + getKeyFieldName(); options[current++] = "-X"; options[current++] = "" + getExpectedResultsPerAverage(); if (getCalculateStdDevs()) { options[current++] = "-S"; } if (getResultProducer() != null) { options[current++] = "-W"; options[current++] = getResultProducer().getClass().getName(); } options[current++] = "--"; System.arraycopy(seOptions, 0, options, current, seOptions.length); current += seOptions.length; while (current < options.length) { options[current++] = ""; } return options; } /** * Set a list of method names for additional measures to look for in * SplitEvaluators. This could contain many measures (of which only a subset * may be produceable by the current resultProducer) if an experiment is the * type that iterates over a set of properties. * * @param additionalMeasures an array of measure names, null if none */ @Override public void setAdditionalMeasures(String[] additionalMeasures) { m_AdditionalMeasures = additionalMeasures; if (m_ResultProducer != null) { System.err.println("AveragingResultProducer: setting additional " + "measures for " + "ResultProducer"); m_ResultProducer.setAdditionalMeasures(m_AdditionalMeasures); } } /** * Returns an enumeration of any additional measure names that might be in the * result producer * * @return an enumeration of the measure names */ @Override public Enumeration enumerateMeasures() { Vector newVector = new Vector(); if (m_ResultProducer instanceof AdditionalMeasureProducer) { Enumeration en = ((AdditionalMeasureProducer) m_ResultProducer). enumerateMeasures(); while (en.hasMoreElements()) { String mname = (String) en.nextElement(); newVector.addElement(mname); } } return newVector.elements(); } /** * Returns the value of the named measure * * @param additionalMeasureName the name of the measure to query for its value * @return the value of the named measure * @throws IllegalArgumentException if the named measure is not supported */ @Override public double getMeasure(String additionalMeasureName) { if (m_ResultProducer instanceof AdditionalMeasureProducer) { return ((AdditionalMeasureProducer) m_ResultProducer). getMeasure(additionalMeasureName); } else { throw new IllegalArgumentException("AveragingResultProducer: " + "Can't return value for : " + additionalMeasureName + ". " + m_ResultProducer.getClass().getName() + " " + "is not an AdditionalMeasureProducer"); } } /** * Sets the dataset that results will be obtained for. * * @param instances a value of type 'Instances'. */ @Override public void setInstances(Instances instances) { m_Instances = instances; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String calculateStdDevsTipText() { return "Record standard deviations for each run."; } /** * Get the value of CalculateStdDevs. * * @return Value of CalculateStdDevs. */ public boolean getCalculateStdDevs() { return m_CalculateStdDevs; } /** * Set the value of CalculateStdDevs. * * @param newCalculateStdDevs Value to assign to CalculateStdDevs. */ public void setCalculateStdDevs(boolean newCalculateStdDevs) { m_CalculateStdDevs = newCalculateStdDevs; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String expectedResultsPerAverageTipText() { return "Set the expected number of results to average per run. " + "For example if a CrossValidationResultProducer is being used " + "(with the number of folds set to 10), then the expected number " + "of results per run is 10."; } /** * Get the value of ExpectedResultsPerAverage. * * @return Value of ExpectedResultsPerAverage. */ public int getExpectedResultsPerAverage() { return m_ExpectedResultsPerAverage; } /** * Set the value of ExpectedResultsPerAverage. * * @param newExpectedResultsPerAverage Value to assign to * ExpectedResultsPerAverage. */ public void setExpectedResultsPerAverage(int newExpectedResultsPerAverage) { m_ExpectedResultsPerAverage = newExpectedResultsPerAverage; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String keyFieldNameTipText() { return "Set the field name that will be unique for a run."; } /** * Get the value of KeyFieldName. * * @return Value of KeyFieldName. */ public String getKeyFieldName() { return m_KeyFieldName; } /** * Set the value of KeyFieldName. * * @param newKeyFieldName Value to assign to KeyFieldName. */ public void setKeyFieldName(String newKeyFieldName) { m_KeyFieldName = newKeyFieldName; m_CountFieldName = "Num_" + m_KeyFieldName; findKeyIndex(); } /** * Sets the object to send results of each run to. * * @param listener a value of type 'ResultListener' */ @Override public void setResultListener(ResultListener listener) { m_ResultListener = listener; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String resultProducerTipText() { return "Set the resultProducer for which results are to be averaged."; } /** * Get the ResultProducer. * * @return the ResultProducer. */ public ResultProducer getResultProducer() { return m_ResultProducer; } /** * Set the ResultProducer. * * @param newResultProducer new ResultProducer to use. */ public void setResultProducer(ResultProducer newResultProducer) { m_ResultProducer = newResultProducer; m_ResultProducer.setResultListener(this); findKeyIndex(); } /** * Gets a text descrption of the result producer. * * @return a text description of the result producer. */ @Override public String toString() { String result = "AveragingResultProducer: "; result += getCompatibilityState(); if (m_Instances == null) { result += ": "; } else { result += ": " + Utils.backQuoteChars(m_Instances.relationName()); } return result; } /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 11198 $"); } } // AveragingResultProducer




© 2015 - 2025 Weber Informatics LLC | Privacy Policy