weka.classifiers.meta.multisearch.RandomSearch Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of multisearch-weka-package Show documentation
Show all versions of multisearch-weka-package Show documentation
Parameter optimization similar to GridSearch.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* RandomSearch.java
* Copyright (C) 2016 Leiden University, NL
* Copyright (C) 2018 University of Waikato, Hamilton, NZ
*/
package weka.classifiers.meta.multisearch;
import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.setupgenerator.Point;
import weka.core.setupgenerator.Space;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.Resample;
import java.io.File;
import java.io.Serializable;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
public class RandomSearch
extends AbstractMultiThreadedSearch {
/** for serialization. */
private static final long serialVersionUID = 2542453917013899104L;
/** the sample size (in percent, see the -sample-size option) to search with. */
protected double m_SampleSize = 100;
/**
 * number of cross-validation folds for each point in space; per the option
 * help, values smaller than 2 turn cross-validation off.
 */
protected int m_SearchSpaceNumFolds = 2;
/**
 * the optional test set file to use for the evaluation (overrides
 * cross-validation; ignored if it points to a directory, which is why the
 * default is the directory ".").
 */
protected File m_SearchSpaceTestSet = new File(".");
/**
 * the optional test set instances to use for the evaluation.
 * NOTE(review): presumably loaded from m_SearchSpaceTestSet elsewhere in
 * this class - confirm.
 */
protected Instances m_SearchSpaceTestInst;
/** maximum number of iterations to find optimum (see the -num-iterations option). */
protected int m_NumIterations = 100;
/** the random seed (see the -S option). */
protected int m_RandomSeed = 1;
/**
 * Returns a short, human-readable description of this search algorithm.
 *
 * @return the description text
 */
@Override
public String globalInfo() {
  final String description =
    "Performs a search of an arbitrary number of parameters of a classifier "
      + "and chooses the best pair found for the actual filtering and training.\n";
  return description;
}
/**
 * Gets an enumeration describing the available options: the options
 * specific to this search (sample size, number of folds, test set, number
 * of iterations, random seed) followed by the options of the superclass.
 *
 * @return an enumeration of all the available options.
 */
@Override
public Enumeration listOptions() {
  Vector<Option> result = new Vector<Option>();

  result.addElement(new Option(
    "\tThe size (in percent) of the sample to search the initial space with.\n"
      + "\t(default: 100)", "sample-size", 1,
    "-sample-size <num>"));

  result.addElement(new Option(
    "\tThe number of cross-validation folds for the search space.\n"
      + "\tNumbers smaller than 2 turn off cross-validation and\n"
      + "\tjust perform evaluation on the training set.\n"
      + "\t(default: 2)", "num-folds", 1, "-num-folds <num>"));

  // the test set is ignored when it points to a directory (not a file),
  // matching the tip text of the corresponding property
  result.addElement(new Option(
    "\tThe (optional) test set to use for the search space.\n"
      + "\tGets ignored if pointing to a directory. Overrides cross-validation.\n"
      + "\t(default: .)", "test-set", 1,
    "-test-set <file>"));

  result.addElement(new Option(
    "\tThe number of parameter settings that are tried "
      + "(i.e., the number of points in the search space that are checked).\n"
      + "\t(default: 100)", "num-iterations", 1,
    "-num-iterations <num>"));

  result.addElement(new Option("\tThe random seed", "seed", 1, "-S <num>"));

  // wildcard keeps this compiling whether the superclass declares a raw or
  // a parameterized Enumeration; the elements are expected to be Options
  Enumeration<?> en = super.listOptions();
  while (en.hasMoreElements()) {
    result.addElement((Option) en.nextElement());
  }

  return result.elements();
}
/**
 * Returns the options of the current setup: the options specific to this
 * search, each flag followed by its current value, followed by the options
 * reported by the superclass.
 *
 * @return the current options
 */
@Override
public String[] getOptions() {
  // typed list: with a raw collection, toArray(String[]) would return
  // Object[] and could not be returned as String[]
  List<String> result = new ArrayList<String>();

  result.add("-sample-size");
  result.add(String.valueOf(getSampleSizePercent()));

  result.add("-num-folds");
  result.add(String.valueOf(getSearchSpaceNumFolds()));

  result.add("-test-set");
  result.add(String.valueOf(getSearchSpaceTestSet()));

  result.add("-num-iterations");
  result.add(String.valueOf(getNumIterations()));

  result.add("-S");
  result.add(String.valueOf(getRandomSeed()));

  Collections.addAll(result, super.getOptions());

  return result.toArray(new String[result.size()]);
}
/**
 * Parses the options for this object. Every option that is not present in
 * the array falls back to its default: sample size 100, 2 folds, the
 * current working directory ("user.dir") as test set, 100 iterations and
 * seed 1. The remaining options are handed to the superclass afterwards.
 *
 * @param options
 *          the options to use
 * @throws Exception
 *           if setting of options fails
 */
@Override
public void setOptions(String[] options) throws Exception {
  final String sampleSize = Utils.getOption("sample-size", options);
  setSampleSizePercent(sampleSize.isEmpty() ? 100 : Double.parseDouble(sampleSize));

  final String numFolds = Utils.getOption("num-folds", options);
  setSearchSpaceNumFolds(numFolds.isEmpty() ? 2 : Integer.parseInt(numFolds));

  final String testSet = Utils.getOption("test-set", options);
  setSearchSpaceTestSet(
    new File(testSet.isEmpty() ? System.getProperty("user.dir") : testSet));

  final String numIterations = Utils.getOption("num-iterations", options);
  setNumIterations(numIterations.isEmpty() ? 100 : Integer.parseInt(numIterations));

  final String seed = Utils.getOption("S", options);
  setRandomSeed(seed.isEmpty() ? 1 : Integer.parseInt(seed));

  super.setOptions(options);
}
/**
 * Returns the tip text for the sample size property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String sampleSizePercentTipText() {
  final String tip = "The sample size (in percent) to use in the search.";
  return tip;
}
/**
 * Gets the sample size (in percent) used for the search space search.
 *
 * @return the sample size.
 */
public double getSampleSizePercent() {
  return this.m_SampleSize;
}
/**
 * Sets the sample size (in percent) used for the search space search.
 *
 * @param value
 *          the sample size for the search space search.
 */
public void setSampleSizePercent(double value) {
  this.m_SampleSize = value;
}
/**
 * Returns the tip text for the number-of-folds property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String searchSpaceNumFoldsTipText() {
  final String tip =
    "The number of cross-validation folds when evaluating the search space; "
      + "values smaller than 2 turn cross-validation off and simple evaluation "
      + "on the training set is performed.";
  return tip;
}
/**
 * Gets the number of cross-validation folds used for the search space.
 *
 * @return the number of folds.
 */
public int getSearchSpaceNumFolds() {
  return this.m_SearchSpaceNumFolds;
}
/**
 * Sets the number of cross-validation folds used for the search space.
 *
 * @param value
 *          the number of folds.
 */
public void setSearchSpaceNumFolds(int value) {
  this.m_SearchSpaceNumFolds = value;
}
/**
 * Returns the tip text for the test set property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String searchSpaceTestSetTipText() {
  final String tip =
    "The (optional) test set to use for evaluating the search space; overrides "
      + "cross-validation; gets ignored if pointing to a directory.";
  return tip;
}
/**
 * Gets the test set file to use for the search space.
 *
 * @return the test set file.
 */
public File getSearchSpaceTestSet() {
  return this.m_SearchSpaceTestSet;
}
/**
 * Sets the test set file to use for the search space.
 *
 * @param value
 *          the test set, ignored if it points to a directory.
 */
public void setSearchSpaceTestSet(File value) {
  this.m_SearchSpaceTestSet = value;
}
/**
 * Returns the tip text for the number-of-iterations property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String numIterationsTipText() {
  // wording aligned with the -num-iterations option help; the previous text
  // lacked "of" and ended in a dangling "; "
  return "The number of parameter settings that are tried "
    + "(i.e., the number of points in the search space that are checked).";
}
/**
 * Gets the number of iterations.
 *
 * @return the number of iterations.
 */
public int getNumIterations() {
  return this.m_NumIterations;
}
/**
 * Sets the number of iterations.
 *
 * @param value
 *          the number of iterations.
 */
public void setNumIterations(int value) {
  this.m_NumIterations = value;
}
/**
 * Returns the tip text for the random seed property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String randomSeedTipText() {
  final String tip = "The seed used for randomization";
  return tip;
}
/**
 * Gets the random seed.
 *
 * @return the random seed.
 */
public int getRandomSeed() {
  return this.m_RandomSeed;
}
/**
 * Sets the random seed.
 *
 * @param value
 *          the seed to use for randomization
 */
public void setRandomSeed(int value) {
  this.m_RandomSeed = value;
}
/**
* determines the best point for the given space, using CV with specified
* number of folds.
*
* @param space
* the space to work on
* @param train
* the training data to work with
* @param test
* the test data to use, null if to use cross-validation
* @param folds
* the number of folds for cross-validation, if <2 then
* evaluation based on the training set is used
* @return the best point (not actual parameters!)
* @throws Exception
* if setup or training fails
*/
protected Performance determineBestInSpace(Space space, Instances train,
Instances test, int folds, Random random) throws Exception {
Performance result;
List> enm;
Performance performance;
Point