weka.classifiers.meta.IterativeClassifierOptimizer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This version represents the developer version, the
"bleeding edge" of development, you could say. New functionality gets added
to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* IterativeClassifierOptimizer.java
* Copyright (C) 2014 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.meta;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.evaluation.EvaluationMetricHelper;
import weka.classifiers.evaluation.ThresholdProducingMetric;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.AdditionalMeasureProducer;
/**
* Chooses the best number of iterations for an IterativeClassifier such as
* LogitBoost using cross-validation.
*
* Optimizes the number of iterations of the given iterative classifier using cross-validation.
*
*
* Valid options are:
*
* -A
* If set, average estimate is used rather than one estimate from pooled predictions.
*
*
* -L <num>
* The number of iterations to look ahead for to find a better optimum.
* (default 50)
*
* -P <int>
* The size of the thread pool, for example, the number of cores in the CPU.
* (default 1)
*
* -E <int>
* The number of threads to use, which should be >= size of thread pool.
* (default 1)
*
* -I <num>
* Step size for the evaluation, if evaluation is time consuming.
* (default 1)
*
* -F <num>
* Number of folds for cross-validation.
* (default 10)
*
* -R <num>
* Number of runs for cross-validation.
* (default 1)
*
* -W
* Full name of base classifier.
* (default: weka.classifiers.meta.LogitBoost)
*
* -metric <name>
* Evaluation metric to optimise (default rmse). Available metrics:
* correct,incorrect,kappa,total cost,average cost,kb relative,kb information,
* correlation,complexity 0,complexity scheme,complexity improvement,
* mae,rmse,rae,rrse,coverage,region size,tp rate,fp rate,precision,recall,
* f-measure,mcc,roc area,prc area
*
* -class-value-index <0-based index>
* Class value index to optimise. Ignored for all but information-retrieval
* type metrics (such as roc area). If unspecified (or a negative value is supplied),
* and an information-retrieval metric is specified, then the class-weighted average
* of the metric is used. (default -1)
*
* -S <num>
* Random number seed.
* (default 1)
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
*
* Options specific to classifier weka.classifiers.meta.LogitBoost:
*
*
* -Q
* Use resampling instead of reweighting for boosting.
*
* -P <percent>
* Percentage of weight mass to base training on.
* (default 100, reduce to around 90 speed up)
*
* -L <num>
* Threshold on the improvement of the likelihood.
* (default -Double.MAX_VALUE)
*
* -H <num>
* Shrinkage parameter.
* (default 1)
*
* -Z <num>
* Z max threshold for responses.
* (default 3)
*
* -O <int>
* The size of the thread pool, for example, the number of cores in the CPU. (default 1)
*
* -E <int>
* The number of threads to use for batch prediction, which should be >= size of thread pool.
* (default 1)
*
* -S <num>
* Random number seed.
* (default 1)
*
* -I <num>
* Number of iterations.
* (default 10)
*
* -W
* Full name of base classifier.
* (default: weka.classifiers.trees.DecisionStump)
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
*
* Options specific to classifier weka.classifiers.trees.DecisionStump:
*
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
*
* @author Eibe Frank ([email protected])
* @version $Revision: 10141 $
*/
public class IterativeClassifierOptimizer extends RandomizableClassifier
implements AdditionalMeasureProducer {
/** for serialization */
private static final long serialVersionUID = -3665485256313525864L;

/** The base classifier to use */
protected IterativeClassifier m_IterativeClassifier = new LogitBoost();

/** The number of folds for the cross-validation. */
protected int m_NumFolds = 10;

/** The number of runs for the cross-validation. */
protected int m_NumRuns = 1;

/** The steps size determining when evaluations happen. */
protected int m_StepSize = 1;

/** Whether to use average. */
protected boolean m_UseAverage = false;

/** The number of iterations to look ahead for to find a better optimum. */
protected int m_lookAheadIterations = 50;

/**
 * One tag per metric name reported by
 * {@code EvaluationMetricHelper.getAllMetricNames()}, with the tag id equal to
 * the metric's position in that list.
 */
public static Tag[] TAGS_EVAL;

static {
  // Typed list: Tag's constructor needs String arguments, not Object.
  List<String> evalNames = EvaluationMetricHelper.getAllMetricNames();
  TAGS_EVAL = new Tag[evalNames.size()];
  for (int i = 0; i < evalNames.size(); i++) {
    TAGS_EVAL[i] = new Tag(i, evalNames.get(i), evalNames.get(i), false);
  }
}

/** The evaluation metric to use */
protected String m_evalMetric = "rmse";

/**
 * The class value index to use with information retrieval type metrics. < 0
 * indicates to use the class weighted average version of the metric.
 */
protected int m_classValueIndex = -1;

/**
 * The thresholds to be used for classification, if the metric implements
 * ThresholdProducingMetric.
 */
protected double[] m_thresholds = null;

/** The best value found for the criterion to be optimized. */
protected double m_bestResult = Double.MAX_VALUE;

/** The best number of iterations identified. */
protected int m_bestNumIts;

/** The number of threads to use for parallel building of classifiers. */
protected int m_numThreads = 1;

/** The size of the thread pool. */
protected int m_poolSize = 1;
/**
 * Describes this meta-classifier for display in the Explorer/Experimenter.
 *
 * @return a one-sentence description of what the classifier does
 */
public String globalInfo() {
  final String description =
    "Optimizes the number of iterations of the given iterative classifier using cross-validation.";
  return description;
}
/**
 * Supplies the fully qualified class name of the default base classifier.
 *
 * @return the class name of the default iterative classifier
 */
protected String defaultIterativeClassifierString() {
  final String defaultClassName = "weka.classifiers.meta.LogitBoost";
  return defaultClassName;
}
/**
 * Provides the GUI tip text for the useAverage property.
 *
 * @return the tip text shown in the Explorer/Experimenter
 */
public String useAverageTipText() {
  return "If true, average estimates are used instead of "
    + "one estimate from pooled predictions.";
}
/**
 * Reports whether the per-fold estimates are averaged instead of computing a
 * single estimate from the pooled predictions.
 *
 * @return true if the average estimate is used
 */
public boolean getUseAverage() {
  return this.m_UseAverage;
}
/**
 * Chooses between averaging per-fold estimates and using one estimate from
 * the pooled predictions.
 *
 * @param newUseAverage true to average the per-fold estimates
 */
public void setUseAverage(boolean newUseAverage) {
  this.m_UseAverage = newUseAverage;
}
/**
 * Provides the GUI tip text for the numThreads property.
 *
 * @return the tip text shown in the Explorer/Experimenter
 */
public String numThreadsTipText() {
  return "The number of threads to use, "
    + "which should be >= size of thread pool.";
}
/**
 * Returns the configured number of threads (parallel jobs) used when
 * iterating the cross-validation models.
 *
 * @return the number of threads
 */
public int getNumThreads() {
  return this.m_numThreads;
}
/**
 * Configures the number of threads (parallel jobs) used when iterating the
 * cross-validation models.
 *
 * @param nT the number of threads
 */
public void setNumThreads(int nT) {
  this.m_numThreads = nT;
}
/**
 * Provides the GUI tip text for the poolSize property.
 *
 * @return the tip text shown in the Explorer/Experimenter
 */
public String poolSizeTipText() {
  return "The size of the thread pool, "
    + "for example, the number of cores in the CPU.";
}
/**
 * Gets the size of the thread pool used for parallel execution.
 *
 * @return the thread pool size
 */
public int getPoolSize() {
  return m_poolSize;
}
/**
 * Sets the size of the thread pool used for parallel execution, for example
 * the number of cores in the CPU.
 *
 * @param nT the thread pool size
 */
public void setPoolSize(int nT) {
  m_poolSize = nT;
}
/**
 * Provides the GUI tip text for the stepSize property.
 *
 * @return the tip text shown in the Explorer/Experimenter
 */
public String stepSizeTipText() {
  return "Step size for the evaluation, "
    + "if evaluation is time consuming.";
}
/**
 * Returns how many iterations separate consecutive evaluations.
 *
 * @return the evaluation step size
 */
public int getStepSize() {
  return this.m_StepSize;
}
/**
 * Configures how many iterations separate consecutive evaluations.
 *
 * @param newStepSize the evaluation step size
 */
public void setStepSize(int newStepSize) {
  this.m_StepSize = newStepSize;
}
/**
 * Provides the GUI tip text for the numRuns property.
 *
 * @return the tip text shown in the Explorer/Experimenter
 */
public String numRunsTipText() {
  final String tip = "Number of runs for cross-validation.";
  return tip;
}
/**
 * Returns how many times the cross-validation is repeated.
 *
 * @return the number of cross-validation runs
 */
public int getNumRuns() {
  return this.m_NumRuns;
}
/**
 * Configures how many times the cross-validation is repeated.
 *
 * @param newNumRuns the number of cross-validation runs
 */
public void setNumRuns(int newNumRuns) {
  this.m_NumRuns = newNumRuns;
}
/**
 * Provides the GUI tip text for the numFolds property.
 *
 * @return the tip text shown in the Explorer/Experimenter
 */
public String numFoldsTipText() {
  final String tip = "Number of folds for cross-validation.";
  return tip;
}
/**
 * Returns the configured number of cross-validation folds.
 *
 * @return the number of folds
 */
public int getNumFolds() {
  return this.m_NumFolds;
}
/**
 * Configures the number of cross-validation folds.
 *
 * @param newNumFolds the number of folds
 */
public void setNumFolds(int newNumFolds) {
  this.m_NumFolds = newNumFolds;
}
/**
 * Provides the GUI tip text for the lookAheadIterations property.
 *
 * @return the tip text shown in the Explorer/Experimenter
 */
public String lookAheadIterationsTipText() {
  return "The number of iterations to look ahead for "
    + "to find a better optimum.";
}
/**
 * Returns how many iterations without improvement are tolerated before the
 * search for the best number of iterations stops.
 *
 * @return the number of look-ahead iterations
 */
public int getLookAheadIterations() {
  return this.m_lookAheadIterations;
}
/**
 * Configures how many iterations without improvement are tolerated before the
 * search for the best number of iterations stops.
 *
 * @param newLookAheadIterations the number of look-ahead iterations
 */
public void setLookAheadIterations(int newLookAheadIterations) {
  this.m_lookAheadIterations = newLookAheadIterations;
}
/**
* Builds the classifier.
*/
@Override
public void buildClassifier(Instances data) throws Exception {
if (m_IterativeClassifier == null) {
throw new Exception("A base classifier has not been specified!");
}
// Can classifier handle the data?
getCapabilities().testWithFail(data);
// Need to shuffle the data
Random randomInstance = new Random(m_Seed);
// Save reference to original data
Instances origData = data;
// Remove instances with missing class
data = new Instances(data);
data.deleteWithMissingClass();
if (data.numInstances() < m_NumFolds) {
System.err.println("WARNING: reducing number of folds to number of instances in " +
"IterativeClassifierOptimizer");
m_NumFolds = data.numInstances();
}
// Initialize datasets and classifiers
Instances[][] trainingSets = new Instances[m_NumRuns][m_NumFolds];
Instances[][] testSets = new Instances[m_NumRuns][m_NumFolds];
final IterativeClassifier[][] classifiers = new IterativeClassifier[m_NumRuns][m_NumFolds];
for (int j = 0; j < m_NumRuns; j++) {
data.randomize(randomInstance);
if (data.classAttribute().isNominal()) {
data.stratify(m_NumFolds);
}
for (int i = 0; i < m_NumFolds; i++) {
trainingSets[j][i] = data.trainCV(m_NumFolds, i, randomInstance);
testSets[j][i] = data.testCV(m_NumFolds, i);
classifiers[j][i] =
(IterativeClassifier) AbstractClassifier.makeCopy(m_IterativeClassifier);
classifiers[j][i].initializeClassifier(trainingSets[j][i]);
}
}
// The thread pool to be used for parallel execution.
ExecutorService pool = Executors.newFixedThreadPool(m_poolSize);;
// Perform evaluation
Evaluation eval = new Evaluation(data);
EvaluationMetricHelper helper = new EvaluationMetricHelper(eval);
boolean maximise = helper.metricIsMaximisable(m_evalMetric);
if (maximise) {
m_bestResult = Double.MIN_VALUE;
} else {
m_bestResult = Double.MAX_VALUE;
}
m_thresholds = null;
int numIts = 0;
m_bestNumIts = 0;
int numberOfIterationsSinceMinimum = -1;
while (true) {
// Should we perform an evaluation?
if (numIts % m_StepSize == 0) {
double result = 0;
double[] tempThresholds = null;
// Shall we use the average score obtained from the folds or not?
if (!m_UseAverage) {
eval = new Evaluation(data);
helper.setEvaluation(eval);
for (int r = 0; r < m_NumRuns; r++) {
for (int i = 0; i < m_NumFolds; i++) {
eval.evaluateModel(classifiers[r][i], testSets[r][i]);
}
}
result =
getClassValueIndex() >= 0 ?
helper.getNamedMetric(m_evalMetric,
getClassValueIndex()) : helper.getNamedMetric(m_evalMetric);
tempThresholds = helper.getNamedMetricThresholds(m_evalMetric);
} else {
// Using average score
for (int r = 0; r < m_NumRuns; r++) {
for (int i = 0; i < m_NumFolds; i++) {
eval = new Evaluation(trainingSets[r][i]);
helper.setEvaluation(eval);
eval.evaluateModel(classifiers[r][i], testSets[r][i]);
result +=
getClassValueIndex() >= 0 ?
helper.getNamedMetric(m_evalMetric,
getClassValueIndex()) : helper.getNamedMetric(m_evalMetric);
double[] thresholds = helper.getNamedMetricThresholds(m_evalMetric);
// Add thresholds (if applicable) so that we can compute average thresholds later
if (thresholds != null) {
if (tempThresholds == null) {
tempThresholds = new double[data.numClasses()];
}
for (int j = 0; j < thresholds.length; j++) {
tempThresholds[j] += thresholds[j];
}
}
}
}
result /= (double)(m_NumFolds * m_NumRuns);
// Compute average thresholds if applicable
if (tempThresholds != null) {
for (int j = 0; j < tempThresholds.length; j++) {
tempThresholds[j] /= (double) (m_NumRuns * m_NumFolds);
}
}
}
if (m_Debug) {
System.err.println("Iteration: " + numIts + " " + "Measure: " + result);
if (tempThresholds != null) {
System.err.print("Thresholds:");
for (int j = 0; j < tempThresholds.length; j++) {
System.err.print(" " + tempThresholds[j]);
}
System.err.println();
}
}
double delta = maximise ? m_bestResult - result : result - m_bestResult;
// Is there an improvement?
if (delta < 0) {
m_bestResult = result;
m_bestNumIts = numIts;
m_thresholds = tempThresholds;
numberOfIterationsSinceMinimum = -1;
}
}
numberOfIterationsSinceMinimum++;
numIts++;
if (numberOfIterationsSinceMinimum >= m_lookAheadIterations) {
break;
}
// Set up result set, and chunk size
int numRuns = m_NumRuns * m_NumFolds;
final int N = m_NumFolds;
final int chunksize = numRuns / m_numThreads;
Set> results = new HashSet>();
// For each thread
for (int j = 0; j < m_numThreads; j++) {
// Determine batch to be processed
final int lo = j * chunksize;
final int hi = (j < m_numThreads - 1) ? (lo + chunksize) : numRuns;
// Create and submit new job
Future futureT = pool.submit(new Callable() {
@Override
public Boolean call() throws Exception {
for (int k = lo; k < hi; k++) {
if (!classifiers[k / N][k % N].next()) {
if (m_Debug) {
System.err.println("Classifier failed to iterate in cross-validation.");
}
return false;
}
}
return true;
}
});
results.add(futureT);
}
// Check that all classifiers succeeded
try {
boolean failure = false;
for (Future futureT : results) {
if (!futureT.get()) {
failure = true;
break; // Break out if one classifier fails to iterate
}
}
if (failure) {
break;
}
} catch (Exception e) {
System.out.println("Classifiers could not be generated.");
e.printStackTrace();
}
}
trainingSets = null;
testSets = null;
data = null;
// Build classifieer based on identified number of iterations
m_IterativeClassifier.initializeClassifier(origData);
int i = 0;
while (i++ < m_bestNumIts && m_IterativeClassifier.next()) {
}
;
m_IterativeClassifier.done();
// Shut down thread pool
pool.shutdown();
}
/**
 * Returns the class distribution for an instance.
 * <p>
 * If the optimised metric produced per-class thresholds, every class whose
 * predicted probability reaches its threshold gets an equal share of the
 * probability mass. If no class reaches its threshold, the base classifier's
 * unthresholded distribution is returned instead, because an all-zero vector
 * cannot be normalised.
 *
 * @param inst the instance to classify
 * @return the (possibly thresholded) class distribution
 * @throws Exception if the base classifier cannot produce a distribution
 */
@Override
public double[] distributionForInstance(Instance inst) throws Exception {
  double[] dist = m_IterativeClassifier.distributionForInstance(inst);

  // Does the metric produce thresholds that need to be applied?
  if (m_thresholds == null) {
    return dist;
  }

  double[] newDist = new double[dist.length];
  boolean anyClassSelected = false;
  for (int i = 0; i < dist.length; i++) {
    if (dist[i] >= m_thresholds[i]) {
      newDist[i] = 1.0;
      anyClassSelected = true;
    }
  }

  // Utils.normalize rejects an array whose sum is zero, so fall back to the
  // base distribution when no class reached its threshold.
  if (!anyClassSelected) {
    return dist;
  }
  Utils.normalize(newDist); // Could have multiple 1.0 entries
  return newDist;
}
/**
 * Describes the optimisation outcome (best metric value, best number of
 * iterations and, if any, the thresholds found), followed by the textual form
 * of the base classifier.
 *
 * @return a textual description of this classifier
 */
@Override
public String toString() {
  if (m_IterativeClassifier == null) {
    return "No classifier built yet.";
  }

  StringBuilder text = new StringBuilder();
  text.append("Best value found: ").append(m_bestResult).append('\n');
  text.append("Best number of iterations found: ").append(m_bestNumIts).append("\n\n");
  if (m_thresholds != null) {
    text.append("Thresholds found: ");
    for (double threshold : m_thresholds) {
      text.append(threshold).append(' ');
    }
  }
  text.append("\n\n");
  text.append(m_IterativeClassifier.toString());
  return text.toString();
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration
© 2015 - 2024 Weber Informatics LLC | Privacy Policy