weka.classifiers.meta.MultiSearch Maven / Gradle / Ivy
Parameter optimization similar to GridSearch.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* MultiSearch.java
* Copyright (C) 2008-2017 University of Waikato, Hamilton, New Zealand
*/
package weka.classifiers.meta;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.meta.multisearch.AbstractEvaluationFactory;
import weka.classifiers.meta.multisearch.AbstractEvaluationMetrics;
import weka.classifiers.meta.multisearch.AbstractSearch;
import weka.classifiers.meta.multisearch.AbstractSearch.SearchResult;
import weka.classifiers.meta.multisearch.DefaultEvaluationFactory;
import weka.classifiers.meta.multisearch.DefaultSearch;
import weka.classifiers.meta.multisearch.MultiSearchCapable;
import weka.classifiers.meta.multisearch.Performance;
import weka.classifiers.meta.multisearch.PerformanceComparator;
import weka.classifiers.meta.multisearch.TraceableOptimizer;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Debug;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.SerializedObject;
import weka.core.SetupGenerator;
import weka.core.SingleIndex;
import weka.core.Summarizable;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.setupgenerator.AbstractParameter;
import weka.core.setupgenerator.MathParameter;
import weka.core.setupgenerator.ParameterGroup;
import weka.core.setupgenerator.Point;
import weka.core.setupgenerator.Space;
import java.io.File;
import java.io.Serializable;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Vector;
/**
* Performs a search of an arbitrary number of parameters of a classifier and chooses the best pair found for the actual filtering and training.
* The default MultiSearch is using the following Classifier setup:
* LinearRegression, searching for the "Ridge"
* The properties being explored are totally up to the user.
*
* E.g., if you have a FilteredClassifier selected as base classifier, sporting a PLSFilter and you want to explore the number of PLS components, then your property will be made up of the following components:
* - filter: referring to the FilteredClassifier's property (= PLSFilter)
* - numComponents: the actual property of the PLSFilter that we want to modify
* And assembled, the property looks like this:
* filter.numComponents
*
*
* The best classifier setup can be accessed after the buildClassifier call via the getBestClassifier method.
*
* The trace of setups evaluated can be accessed after the buildClassifier call as well, using the following methods:
* - getTrace()
* - getTraceSize()
* - getTraceValue(int)
* - getTraceFolds(int)
* - getTraceClassifierAsCli(int)
* - getTraceParameterSettings(int)
*
* Using the weka.core.setupgenerator.ParameterGroup parameter, it is possible to group dependent parameters. In this case, all top-level parameters must be of type weka.core.setupgenerator.ParameterGroup.
*
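*
* As an illustrative sketch (not part of the original documentation), the FilteredClassifier/PLSFilter
* example above might be wired up programmatically roughly as follows. The choice of PLSFilter and
* LinearRegression, the parameter range, and the variable "data" (some training Instances) are
* assumptions made for the sake of the example:
*
*   import weka.classifiers.functions.LinearRegression;
*   import weka.classifiers.meta.FilteredClassifier;
*   import weka.classifiers.meta.MultiSearch;
*   import weka.core.Instances;
*   import weka.core.setupgenerator.AbstractParameter;
*   import weka.core.setupgenerator.MathParameter;
*   import weka.filters.supervised.attribute.PLSFilter;
*
*   // base classifier: a PLSFilter in front of LinearRegression
*   FilteredClassifier base = new FilteredClassifier();
*   base.setFilter(new PLSFilter());
*   base.setClassifier(new LinearRegression());
*
*   // explore filter.numComponents over 1..10 (the expression "I" simply returns the loop value)
*   MathParameter components = new MathParameter();
*   components.setProperty("filter.numComponents");
*   components.setMin(1);
*   components.setMax(10);
*   components.setStep(1);
*   components.setBase(10);
*   components.setExpression("I");
*
*   MultiSearch multi = new MultiSearch();
*   multi.setClassifier(base);
*   multi.setSearchParameters(new AbstractParameter[]{components});
*   multi.buildClassifier(data);                      // data: training Instances with class attribute set
*   System.out.println(multi.getBestClassifier());    // best setup found during the search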
*
* Valid options are:
*
* -E <CC|MCC|RMSE|RRSE|MAE|RAE|COMB|ACC|KAP|PREC|WPREC|REC|WREC|AUC|WAUC|PRC|WPRC|FM|WFM|TPR|TNR|FPR|FNR>
* Determines the parameter used for evaluation:
* CC = Correlation coefficient
* MCC = Matthews correlation coefficient
* RMSE = Root mean squared error
* RRSE = Root relative squared error
* MAE = Mean absolute error
* RAE = Root absolute error
* COMB = Combined = (1-abs(CC)) + RRSE + RAE
* ACC = Accuracy
* KAP = Kappa
* PREC = Precision (per class)
* WPREC = Weighted precision
* REC = Recall (per class)
* WREC = Weighted recall
* AUC = Area under ROC (per class)
* WAUC = Weighted area under ROC
* PRC = Area under PRC (per class)
* WPRC = Weighted area under PRC
* FM = F-Measure (per class)
* WFM = Weighted F-Measure
* TPR = True positive rate (per class)
* TNR = True negative rate (per class)
* FPR = False positive rate (per class)
* FNR = False negative rate (per class)
* (default: CC)
*
* -class-label "<1-based index>"
* The class label index to retrieve the metric for (if applicable).
*
*
* -search "<classname options>"
* A property search setup.
*
*
* -algorithm "<classname options>"
* A search algorithm.
*
*
* -log-file <filename>
* The log file to log the messages to.
* (default: none)
*
* -S <num>
* Random number seed.
* (default 1)
*
* -W
* Full name of base classifier.
* (default: weka.classifiers.functions.LinearRegression)
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
* -num-decimal-places
* The number of decimal places for the output of numbers in the model (default 2).
*
* -batch-size
* The desired batch size for batch prediction (default 100).
*
*
* Options specific to classifier weka.classifiers.functions.LinearRegression:
*
*
* -S <number of selection method>
* Set the attribute selection method to use. 1 = None, 2 = Greedy.
* (default 0 = M5' method)
*
* -C
* Do not try to eliminate colinear attributes.
*
*
* -R <double>
* Set ridge parameter (default 1.0e-8).
*
*
* -minimal
* Conserve memory, don't keep dataset header and means/stdevs.
* Model cannot be printed out if this option is enabled. (default: keep data)
*
* -additional-stats
* Output additional statistics.
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
* -num-decimal-places
* The number of decimal places for the output of numbers in the model (default 4).
*
* -batch-size
* The desired batch size for batch prediction (default 100).
*
*
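* As an illustrative command-line sketch (not generated output): the option names of
* weka.core.setupgenerator.MathParameter (-property, -min, -max, -step, -base, -expression)
* and the dataset name are assumptions here; -t is the standard WEKA flag for the training file:
*
*   java weka.classifiers.meta.MultiSearch \
*     -E RMSE \
*     -search "weka.core.setupgenerator.MathParameter -property ridge -min -10 -max 5 -step 1 -base 10 -expression pow(BASE,I)" \
*     -algorithm weka.classifiers.meta.multisearch.DefaultSearch \
*     -W weka.classifiers.functions.LinearRegression \
*     -t dataset.arff
*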
* @author fracpete (fracpete at waikato dot ac dot nz)
* @version $Revision: 4521 $
*/
public class MultiSearch
extends RandomizableSingleClassifierEnhancer
implements MultiSearchCapable, AdditionalMeasureProducer, Summarizable, TraceableOptimizer {
/** for serialization. */
private static final long serialVersionUID = -5129316523575906233L;
/** the Classifier with the best setup. */
protected SearchResult m_BestClassifier;
/** the evaluation factory to use. */
protected AbstractEvaluationFactory m_Factory;
/** the metrics to use. */
protected AbstractEvaluationMetrics m_Metrics;
/** the type of evaluation. */
protected int m_Evaluation;
/** the class label index (if applicable). */
protected SingleIndex m_ClassLabel;
/** the log file to use. */
protected File m_LogFile = new File(System.getProperty("user.dir"));
/** the default parameters. */
protected AbstractParameter[] m_DefaultParameters;
/** the parameters. */
protected AbstractParameter[] m_Parameters;
/** the search algorithm. */
protected AbstractSearch m_Algorithm;
/** the current setup generator. */
protected SetupGenerator m_Generator;
/** for tracking the setups. */
protected List<Entry<Integer, Performance>> m_Trace;
/**
* the default constructor.
*/
public MultiSearch() {
super();
m_Factory = newFactory();
m_Metrics = m_Factory.newMetrics();
m_Evaluation = m_Metrics.getDefaultMetric();
m_ClassLabel = new SingleIndex("1");
m_Classifier = defaultClassifier();
m_DefaultParameters = defaultSearchParameters();
m_Parameters = defaultSearchParameters();
m_Algorithm = defaultAlgorithm();
m_Trace = new ArrayList<Entry<Integer, Performance>>();
try {
m_BestClassifier = new SearchResult();
m_BestClassifier.classifier = AbstractClassifier.makeCopy(m_Classifier);
}
catch (Exception e) {
System.err.println("Failed to create copy of default classifier!");
e.printStackTrace();
}
}
/**
* Returns a string describing classifier.
*
* @return a description suitable for displaying in the
* explorer/experimenter gui
*/
public String globalInfo() {
return
"Performs a search of an arbitrary number of parameters of a classifier "
+ "and chooses the best pair found for the actual filtering and training.\n"
+ "The default MultiSearch is using the following Classifier setup:\n"
+ " LinearRegression, searching for the \"Ridge\"\n"
+ "The properties being explored are totally up to the user.\n"
+ "\n"
+ "E.g., if you have a FilteredClassifier selected as base classifier, "
+ "sporting a PLSFilter and you want to explore the number of PLS components, "
+ "then your property will be made up of the following components:\n"
+ " - filter: referring to the FilteredClassifier's property (= PLSFilter)\n"
+ " - numComponents: the actual property of the PLSFilter that we want to modify\n"
+ "And assembled, the property looks like this:\n"
+ " filter.numComponents\n"
+ "\n"
+ "\n"
+ "The best classifier setup can be accessed after the buildClassifier "
+ "call via the getBestClassifier method.\n"
+ "\n"
+ "The trace of setups evaluated can be accessed after the buildClassifier "
+ "call as well, using the following methods:\n"
+ "- getTrace()\n"
+ "- getTraceSize()\n"
+ "- getTraceValue(int)\n"
+ "- getTraceFolds(int)\n"
+ "- getTraceClassifierAsCli(int)\n"
+ "- getTraceParameterSettings(int)\n"
+ "\n"
+ "Using the " + ParameterGroup.class.getName() + " parameter, it is "
+ "possible to group dependent parameters. In this case, all top-level "
+ "parameters must be of type " + ParameterGroup.class.getName() + ".";
}
/**
* String describing default classifier.
*
* @return the classname of the default classifier
*/
@Override
protected String defaultClassifierString() {
return defaultClassifier().getClass().getName();
}
/**
* Returns the default classifier to use.
*
* @return the default classifier
*/
protected Classifier defaultClassifier() {
LinearRegression result;
result = new LinearRegression();
result.setAttributeSelectionMethod(new SelectedTag(LinearRegression.SELECTION_NONE, LinearRegression.TAGS_SELECTION));
result.setEliminateColinearAttributes(false);
return result;
}
/**
* Returns the default search parameters.
*
* @return the parameters
*/
protected AbstractParameter[] defaultSearchParameters() {
AbstractParameter[] result;
MathParameter param;
result = new AbstractParameter[1];
param = new MathParameter();
param.setProperty("ridge");
param.setMin(-10);
param.setMax(+5);
param.setStep(1);
param.setBase(10);
param.setExpression("pow(BASE,I)");
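// with BASE = 10 and I running from -10 to +5 in steps of 1, the expression pow(BASE,I)
// evaluates the "ridge" property at 1e-10, 1e-9, ..., 1e+5 (16 candidate values)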
result[0] = param;
try {
result = (AbstractParameter[]) new SerializedObject(result).getObject();
}
catch (Exception e) {
result = new AbstractParameter[0];
System.err.println("Failed to create copy of default parameters!");
e.printStackTrace();
}
return result;
}
/**
* Gets an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration<Option> listOptions() {
Vector<Option> result;
Enumeration<Option> en;
String desc;
SelectedTag tag;
int i;
result = new Vector<Option>();
desc = "";
for (i = 0; i < m_Metrics.getTags().length; i++) {
tag = new SelectedTag(m_Metrics.getTags()[i].getID(), m_Metrics.getTags());
desc += "\t" + tag.getSelectedTag().getIDStr()
+ " = " + tag.getSelectedTag().getReadable()
+ "\n";
}
result.addElement(new Option(
"\tDetermines the parameter used for evaluation:\n"
+ desc
+ "\t(default: " + new SelectedTag(m_Metrics.getDefaultMetric(), m_Metrics.getTags()) + ")",
"E", 1, "-E " + Tag.toOptionList(m_Metrics.getTags())));
result.addElement(new Option(
"\tThe class label index to retrieve the metric for (if applicable).\n",
"class-label", 1, "-class-label \"<1-based index>\""));
result.addElement(new Option(
"\tA property search setup.\n",
"search", 1, "-search \"\""));
result.addElement(new Option(
"\tA search algorithm.\n",
"algorithm", 1, "-algorithm \"\""));
result.addElement(new Option(
"\tThe log file to log the messages to.\n"
+ "\t(default: none)",
"log-file", 1, "-log-file "));
en = super.listOptions();
while (en.hasMoreElements())
result.addElement(en.nextElement());
return result.elements();
}
/**
* returns the options of the current setup.
*
* @return the current options
*/
@Override
public String[] getOptions() {
int i;
Vector<String> result;
String[] options;
result = new Vector<String>();
result.add("-E");
result.add("" + getEvaluation());
for (i = 0; i < getSearchParameters().length; i++) {
result.add("-search");
result.add(getCommandline(getSearchParameters()[i]));
}
result.add("-class-label");
result.add(getClassLabel());
result.add("-algorithm");
result.add(getCommandline(m_Algorithm));
result.add("-log-file");
result.add("" + getLogFile());
options = super.getOptions();
for (i = 0; i < options.length; i++)
result.add(options[i]);
return result.toArray(new String[result.size()]);
}
/**
* Parses the options for this object.
*
* @param options the options to use
* @throws Exception if setting of options fails
*/
@Override
public void setOptions(String[] options) throws Exception {
String tmpStr;
String[] tmpOptions;
Vector<String> search;
int i;
AbstractParameter[] params;
tmpStr = Utils.getOption('E', options);
if (tmpStr.length() != 0)
setEvaluation(new SelectedTag(tmpStr, m_Metrics.getTags()));
else
setEvaluation(new SelectedTag(m_Metrics.getDefaultMetric(), m_Metrics.getTags()));
search = new Vector<String>();
do {
tmpStr = Utils.getOption("search", options);
if (tmpStr.length() > 0)
search.add(tmpStr);
}
while (tmpStr.length() > 0);
if (search.size() == 0) {
for (i = 0; i < m_DefaultParameters.length; i++)
search.add(getCommandline(m_DefaultParameters[i]));
}
params = new AbstractParameter[search.size()];
for (i = 0; i < search.size(); i++) {
tmpOptions = Utils.splitOptions(search.get(i));
tmpStr = tmpOptions[0];
tmpOptions[0] = "";
params[i] = (AbstractParameter) Utils.forName(AbstractParameter.class, tmpStr, tmpOptions);
}
setSearchParameters(params);
tmpStr = Utils.getOption("class-label", options);
if (!tmpStr.isEmpty())
setClassLabel(tmpStr);
else
setClassLabel("1");
tmpStr = Utils.getOption("algorithm", options);
if (!tmpStr.isEmpty()) {
tmpOptions = Utils.splitOptions(tmpStr);
tmpStr = tmpOptions[0];
tmpOptions[0] = "";
setAlgorithm((AbstractSearch) Utils.forName(AbstractSearch.class, tmpStr, tmpOptions));
}
else {
setAlgorithm(new DefaultSearch());
}
tmpStr = Utils.getOption("log-file", options);
if (tmpStr.length() != 0)
setLogFile(new File(tmpStr));
else
setLogFile(new File(System.getProperty("user.dir")));
super.setOptions(options);
}
/**
* Set the base learner.
*
* @param newClassifier the classifier to use.
*/
@Override
public void setClassifier(Classifier newClassifier) {
super.setClassifier(newClassifier);
try {
m_BestClassifier.classifier = AbstractClassifier.makeCopy(m_Classifier);
}
catch (Exception e) {
e.printStackTrace();
}
}
/**
* Returns the tip text for this property.
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String searchParametersTipText() {
return "Defines the search parameters.";
}
/**
* Sets the search parameters.
*
* @param value the parameters
*/
public void setSearchParameters(AbstractParameter[] value) {
m_Parameters = value;
}
/**
* Returns the search parameters.
*
* @return the parameters
*/
public AbstractParameter[] getSearchParameters() {
return m_Parameters;
}
/**
* Returns the tip text for this property.
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String algorithmTipText() {
return "Defines the search algorithm.";
}
/**
* Sets the search algorithm.
*
* @param value the algorithm
*/
public void setAlgorithm(AbstractSearch value) {
m_Algorithm = value;
}
/**
* Returns the search algorithm.
*
* @return the algorithm
*/
public AbstractSearch getAlgorithm() {
return m_Algorithm;
}
/**
* Returns the tip text for this property.
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String classLabelTipText() {
return "The class label index (1-based) to retrieve the metrics for (if applicable).";
}
/**
* Sets the class label to retrieve the metrics for (if applicable).
*
* @param value the class label index (1-based)
*/
public void setClassLabel(String value) {
m_ClassLabel.setSingleIndex(value);
}
/**
* Returns the class label to retrieve the metrics for (if applicable).
*
* @return the class label index (1-based)
*/
public String getClassLabel() {
return m_ClassLabel.getSingleIndex();
}
/**
* Returns the integer index.
*
* @param upper the maximum to use
* @return the index (0-based)
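* (for example, a class label of "first" resolves to index 0 and "last" to the supplied upper bound)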
*/
public int getClassLabelIndex(int upper) {
SingleIndex index;
index = new SingleIndex(m_ClassLabel.getSingleIndex());
index.setUpper(upper);
return index.getIndex();
}
/**
* Creates the default search algorithm.
*
* @return the algorithm
*/
public AbstractSearch defaultAlgorithm() {
DefaultSearch result;
result = new DefaultSearch();
return result;
}
/**
* Returns the tip text for this property.
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String evaluationTipText() {
return
"Sets the criterion for evaluating the classifier performance and "
+ "choosing the best one.";
}
/**
* Returns the underlying tags.
*
* @return the tags
*/
public Tag[] getMetricsTags() {
return m_Metrics.getTags();
}
/**
* Sets the criterion to use for evaluating the classifier performance.
*
* @param value the evaluation criterion
*/
public void setEvaluation(SelectedTag value) {
if (value.getTags() == m_Metrics.getTags()) {
m_Evaluation = value.getSelectedTag().getID();
}
}
/**
* Gets the criterion used for evaluating the classifier performance.
*
* @return the current evaluation criterion.
*/
public SelectedTag getEvaluation() {
return new SelectedTag(m_Evaluation, m_Metrics.getTags());
}
/**
* Returns the tip text for this property.
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String logFileTipText() {
return "The log file to log the messages to.";
}
/**
* Gets current log file.
*
* @return the log file.
*/
public File getLogFile() {
return m_LogFile;
}
/**
* Sets the log file to use.
*
* @param value the log file.
*/
public void setLogFile(File value) {
m_LogFile = value;
}
/**
* returns the best Classifier setup.
*
* @return the best Classifier setup
*/
public Classifier getBestClassifier() {
return m_BestClassifier.classifier;
}
/**
* Returns the setup generator.
*
* @return the generator
*/
public SetupGenerator getGenerator() {
return m_Generator;
}
/**
* Returns an enumeration of the measure names.
*
* @return an enumeration of the measure names
*/
public Enumeration<String> enumerateMeasures() {
Vector<String> result;
int i;
result = new Vector<String>();
if (getBestValues() != null) {
for (i = 0; i < getBestValues().dimensions(); i++) {
if (getBestValues().getValue(i) instanceof Double)
result.add("measure-" + i);
}
}
return result.elements();
}
/**
* Returns the value of the named measure.
*
* @param measureName the name of the measure to query for its value
* @return the value of the named measure
*/
public double getMeasure(String measureName) {
if (measureName.startsWith("measure-"))
return (Double) getBestValues().getValue(Integer.parseInt(measureName.replace("measure-", "")));
else
throw new IllegalArgumentException("Measure '" + measureName + "' not supported!");
}
/**
* Returns the evaluation factory to use.
*
* @return the factory
*/
protected AbstractEvaluationFactory newFactory() {
return new DefaultEvaluationFactory();
}
/**
* Returns the factory instance.
*
* @return the factory
*/
public AbstractEvaluationFactory getFactory() {
return m_Factory;
}
/**
* Returns the evaluation metrics.
*
* @return the metrics
*/
public AbstractEvaluationMetrics getMetrics() {
return m_Metrics;
}
/**
* returns the parameter values that were found to work best.
*
* @return the best parameter combination
*/
public Point