weka.classifiers.meta.GridSearch Maven / Gradle / Ivy
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* GridSearch.java
* Copyright (C) 2006-2010 University of Waikato, Hamilton, New Zealand
*/
package weka.classifiers.meta;
import java.beans.PropertyDescriptor;
import java.io.File;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.*;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Debug;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.PropertyPath;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.SerializedObject;
import weka.core.Summarizable;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WekaException;
import weka.core.expressionlanguage.common.Primitives.DoubleExpression;
import weka.core.expressionlanguage.common.SimpleVariableDeclarations;
import weka.core.expressionlanguage.common.MacroDeclarationsCompositor;
import weka.core.expressionlanguage.common.MathFunctions;
import weka.core.expressionlanguage.common.IfElseMacro;
import weka.core.expressionlanguage.common.JavaMacro;
import weka.core.expressionlanguage.parser.Parser;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.MathExpression;
import weka.filters.unsupervised.instance.Resample;
/**
* Performs a grid search of parameter pairs for a
* classifier and chooses the best pair found for the actual predicting.
*
* The initial grid is worked on with 2-fold CV to determine the values of the
* parameter pairs for the selected type of evaluation (e.g., accuracy). The
* best point in the grid is then taken and a 10-fold CV is performed with the
* adjacent parameter pairs. If a better pair is found, then this will act as
* new center and another 10-fold CV will be performed (kind of hill-climbing).
* This process is repeated until no better pair is found or the best pair is on
* the border of the grid.
* In case the best pair is on the border, one can let GridSearch automatically
* extend the grid and continue the search. Check out the properties
* 'gridIsExtendable' (option '-extend-grid') and 'maxGridExtensions' (option
* '-max-grid-extensions <num>').
*
* GridSearch can handle doubles, integers (values are just cast to int) and
* booleans (0 is false, otherwise true). float, char and long are supported as
* well.
*
* The best classifier setup can be accessed after the buildClassifier call via
* the getBestClassifier methods.
* Note: with -num-slots/numExecutionSlots you can specify how many setups are
* evaluated in parallel, taking advantage of multi-cpu/core architectures.
*
*
*
* Valid options are:
*
*
*
* -E <CC|RMSE|RRSE|MAE|RAE|COMB|ACC|WAUC|KAP>
* Determines the parameter used for evaluation:
* CC = Correlation coefficient
* RMSE = Root mean squared error
* RRSE = Root relative squared error
* MAE = Mean absolute error
* RAE = Relative absolute error
* COMB = Combined = (1-abs(CC)) + RRSE + RAE
* ACC = Accuracy
* WAUC = Weighted AUC
* KAP = Kappa
* (default: CC)
*
*
*
* -y-property <option>
* The Y option to test (without leading dash).
* (default: kernel.gamma)
*
*
*
* -y-min <num>
* The minimum for Y.
* (default: -3)
*
*
*
* -y-max <num>
* The maximum for Y.
* (default: +3)
*
*
*
* -y-step <num>
* The step size for Y.
* (default: 1)
*
*
*
* -y-base <num>
* The base for Y.
* (default: 10)
*
*
*
* -y-expression <expr>
* The expression for Y.
* Available parameters:
* BASE
* FROM
* TO
* STEP
* I - the current iteration value
* (from 'FROM' to 'TO' with stepsize 'STEP')
* (default: 'pow(BASE,I)')
*
*
*
* -x-property <option>
* The X option to test (without leading dash).
* (default: C)
*
*
*
* -x-min <num>
* The minimum for X.
* (default: -3)
*
*
*
* -x-max <num>
* The maximum for X.
* (default: 3)
*
*
*
* -x-step <num>
* The step size for X.
* (default: 1)
*
*
*
* -x-base <num>
* The base for X.
* (default: 10)
*
*
*
* -x-expression <expr>
* The expression for the X value.
* Available parameters:
* BASE
* MIN
* MAX
* STEP
* I - the current iteration value
* (from 'FROM' to 'TO' with stepsize 'STEP')
* (default: 'pow(BASE,I)')
*
*
*
* -extend-grid
* Whether the grid can be extended.
* (default: no)
*
*
*
* -max-grid-extensions <num>
* The maximum number of grid extensions (-1 is unlimited).
* (default: 3)
*
*
*
* -sample-size <num>
* The size (in percent) of the sample to search the initial grid with.
* (default: 100)
*
*
*
* -traversal <ROW-WISE|COLUMN-WISE>
* The type of traversal for the grid.
* (default: COLUMN-WISE)
*
*
*
* -log-file <filename>
* The log file to log the messages to.
* (default: none)
*
*
*
* -num-slots <num>
* Number of execution slots.
* (default 1 - i.e. no parallelism)
*
*
*
* -S <num>
* Random number seed.
* (default 1)
*
*
*
* -W
* Full name of base classifier.
* (default: weka.classifiers.functions.SMOreg with options -K weka.classifiers.functions.supportVector.RBFKernel)
*
*
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
*
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
*
*
* Options specific to classifier weka.classifiers.functions.SMOreg:
*
*
*
* -C <double>
* The complexity constant C.
* (default 1)
*
*
*
* -N
* Whether to 0=normalize/1=standardize/2=neither.
* (default 0=normalize)
*
*
*
* -I <classname and parameters>
* Optimizer class used for solving quadratic optimization problem
* (default weka.classifiers.functions.supportVector.RegSMOImproved)
*
*
*
* -K <classname and parameters>
* The Kernel to use.
* (default: weka.classifiers.functions.supportVector.PolyKernel)
*
*
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
*
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
*
*
* Options specific to optimizer ('-I') weka.classifiers.functions.supportVector.RegSMOImproved:
*
*
*
* -T <double>
* The tolerance parameter for checking the stopping criterion.
* (default 0.001)
*
*
*
* -V
* Use variant 1 of the algorithm when true, otherwise use variant 2.
* (default true)
*
*
*
* -P <double>
* The epsilon for round-off error.
* (default 1.0e-12)
*
*
*
* -L <double>
* The epsilon parameter in epsilon-insensitive loss function.
* (default 1.0e-3)
*
*
*
* -W <double>
* The random number seed.
* (default 1)
*
*
*
* Options specific to kernel ('-K') weka.classifiers.functions.supportVector.RBFKernel:
*
*
*
* -G <num>
* The Gamma parameter.
* (default: 0.01)
*
*
*
* -C <num>
* The size of the cache (a prime number), 0 for full cache and
* -1 to turn it off.
* (default: 250007)
*
*
*
* -output-debug-info
* Enables debugging output (if available) to be printed.
* (default: off)
*
*
*
* -no-checks
* Turns off all checks - use with caution!
* (default: checks on)
*
*
*
*
* Examples:
*
* -
* Optimizing SMO with RBFKernel (C and gamma)
*
* - Set the evaluation to Accuracy.
* - Set the filter to <code>weka.filters.AllFilter</code> since we don't need
* any special data processing and we don't optimize the filter in this case
* (data always gets passed through the filter!).
* - Set <code>weka.classifiers.functions.SMO</code> as classifier with
* <code>weka.classifiers.functions.supportVector.RBFKernel</code> as kernel.
* - Set the XProperty to "classifier.c", XMin to "1", XMax to "16", XStep to
* "1" and the XExpression to "I". This will test the "C" parameter of SMO for
* the values from 1 to 16.
* - Set the YProperty to "classifier.kernel.gamma", YMin to "-5", YMax to
* "2", YStep to "1", YBase to "10" and YExpression to "pow(BASE,I)". This will
* test the gamma of the RBFKernel with the values 10^-5, 10^-4,..,10^2.
*
*
* -
* Optimizing PLSFilter with LinearRegression (# of components and ridge) -
* default setup
*
* - Set the evaluation to Correlation coefficient.
* - Set the filter to
* <code>weka.filters.supervised.attribute.PLSFilter</code>.
* - Set <code>weka.classifiers.functions.LinearRegression</code> as
* classifier and use no attribute selection and no elimination of colinear
* attributes.
* - Set the XProperty to "filter.numComponents", XMin to "5", XMax to "20"
* (this depends heavily on your dataset, should be no more than the number of
* attributes!), XStep to "1" and XExpression to "I". This will test the number
* of components the PLSFilter will produce from 5 to 20.
* - Set the YProperty to "classifier.ridge", YMin to "-10", YMax to "5",
* YStep to "1" and YExpression to "pow(BASE,I)". This will try ridge parameters
* from 10^-10 to 10^5.
*
*
*
*
* General notes:
*
* - Turn the debug flag on in order to see some progress output in the
* console
* - If you want to view the fitness landscape that GridSearch explores,
* select a log file. This log will then contain Gnuplot data and script
* block for viewing the landscape. Just copy paste those blocks into files
* named accordingly and run Gnuplot with them.
*
*
* @author Bernhard Pfahringer (bernhard at cs dot waikato dot ac dot nz)
* @author Geoff Holmes (geoff at cs dot waikato dot ac dot nz)
* @author fracpete (fracpete at waikato dot ac dot nz)
* @version $Revision: 14326 $
*/
public class GridSearch extends RandomizableSingleClassifierEnhancer implements
AdditionalMeasureProducer, Summarizable {
/**
 * A serializable version of Point2D.Double whose coordinates are compared
 * with {@link Utils#eq(double, double)} instead of exact equality.
 *
 * @see java.awt.geom.Point2D.Double
 */
protected static class PointDouble extends java.awt.geom.Point2D.Double
  implements Serializable, RevisionHandler {

  /** for serialization. */
  private static final long serialVersionUID = 7151661776161898119L;

  /**
   * Creates a point with the given coordinates.
   *
   * @param x the x value of the point
   * @param y the y value of the point
   */
  public PointDouble(double x, double y) {
    super(x, y);
  }

  /**
   * Determines whether or not two points are equal. Both coordinates are
   * compared via {@link Utils#eq(double, double)}.
   * <p>
   * NOTE(review): hashCode() is inherited unchanged; presumably it is derived
   * from the exact coordinate values, so avoid using points that are only
   * "Utils.eq"-equal as keys of hash-based collections.
   *
   * @param obj an object to be compared with this PointDouble
   * @return true if the object is a PointDouble with the same values; false
   *         otherwise (including for null and for objects of other types)
   */
  @Override
  public boolean equals(Object obj) {
    // equals() must never throw for null or foreign objects
    if (!(obj instanceof PointDouble)) {
      return false;
    }

    PointDouble other = (PointDouble) obj;
    return Utils.eq(this.getX(), other.getX())
      && Utils.eq(this.getY(), other.getY());
  }

  /**
   * Returns a string representation of the point: the inherited
   * representation with everything before the opening bracket removed.
   *
   * @return the point as string
   */
  @Override
  public String toString() {
    return super.toString().replaceAll(".*\\[", "[");
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 14326 $");
  }
}
/**
 * Integer point that can be serialized and reports its revision.
 *
 * @see java.awt.Point
 */
protected static class PointInt extends java.awt.Point implements
  Serializable, RevisionHandler {

  /** for serialization. */
  private static final long serialVersionUID = -5900415163698021618L;

  /**
   * Creates a point at the given integer coordinates.
   *
   * @param x the x value of the point
   * @param y the y value of the point
   */
  public PointInt(int x, int y) {
    super(x, y);
  }

  /**
   * Builds the textual form of the point by taking the inherited
   * representation and dropping everything in front of the opening bracket.
   *
   * @return the point as string
   */
  @Override
  public String toString() {
    final String inherited = super.toString();
    return inherited.replaceAll(".*\\[", "[");
  }

  /**
   * Reports the revision of this class.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    final String revision = RevisionUtils.extract("$Revision: 14326 $");
    return revision;
  }
}
/**
 * Generates the parameter pairs of a regular two-dimensional grid. The grid
 * is described by a minimum, a maximum and a step size per axis; the maximum
 * has to be reachable from the minimum with a whole number of steps.
 */
protected static class Grid implements Serializable, RevisionHandler {

  /** for serialization. */
  private static final long serialVersionUID = 7290732613611243139L;

  /** the minimum on the X axis. */
  protected double m_MinX;

  /** the maximum on the X axis. */
  protected double m_MaxX;

  /** the step size for the X axis. */
  protected double m_StepX;

  /** the label for the X axis. */
  protected String m_LabelX;

  /** the minimum on the Y axis. */
  protected double m_MinY;

  /** the maximum on the Y axis. */
  protected double m_MaxY;

  /** the step size for the Y axis. */
  protected double m_StepY;

  /** the label for the Y axis. */
  protected String m_LabelY;

  /** the number of points on the X axis. */
  protected int m_Width;

  /** the number of points on the Y axis. */
  protected int m_Height;

  /**
   * Initializes the grid with empty axis labels.
   *
   * @param minX the minimum on the X axis
   * @param maxX the maximum on the X axis
   * @param stepX the step size for the X axis
   * @param minY the minimum on the Y axis
   * @param maxY the maximum on the Y axis
   * @param stepY the step size for the Y axis
   * @throws IllegalArgumentException if the setup is inconsistent, see
   *           {@link #Grid(double, double, double, String, double, double, double, String)}
   */
  public Grid(double minX, double maxX, double stepX, double minY,
    double maxY, double stepY) {
    this(minX, maxX, stepX, "", minY, maxY, stepY, "");
  }

  /**
   * Initializes the grid.
   *
   * @param minX the minimum on the X axis
   * @param maxX the maximum on the X axis
   * @param stepX the step size for the X axis
   * @param labelX the label for the X axis
   * @param minY the minimum on the Y axis
   * @param maxY the maximum on the Y axis
   * @param stepY the step size for the Y axis
   * @param labelY the label for the Y axis
   * @throws IllegalArgumentException if a minimum is not smaller than its
   *           maximum, a step size is not positive, or a maximum cannot be
   *           reached from its minimum with a whole number of steps
   */
  public Grid(double minX, double maxX, double stepX, String labelX,
    double minY, double maxY, double stepY, String labelY) {
    super();

    m_MinX = minX;
    m_MaxX = maxX;
    m_StepX = stepX;
    m_LabelX = labelX;
    m_MinY = minY;
    m_MaxY = maxY;
    m_StepY = stepY;
    m_LabelY = labelY;

    // is min < max?
    if (m_MinX >= m_MaxX) {
      throw new IllegalArgumentException("XMin must be smaller than XMax!");
    }
    if (m_MinY >= m_MaxY) {
      throw new IllegalArgumentException("YMin must be smaller than YMax!");
    }

    // steps positive?
    if (m_StepX <= 0) {
      throw new IllegalArgumentException("XStep must be a positive number!");
    }
    if (m_StepY <= 0) {
      throw new IllegalArgumentException("YStep must be a positive number!");
    }

    // the dimensions are only computed once the step sizes are known to be
    // usable as divisors
    m_Height = (int) StrictMath.round((m_MaxY - m_MinY) / m_StepY) + 1;
    m_Width = (int) StrictMath.round((m_MaxX - m_MinX) / m_StepX) + 1;

    // check borders: walking from min in whole steps has to end on max
    if (!Utils.eq(m_MinX + (m_Width - 1) * m_StepX, m_MaxX)) {
      throw new IllegalArgumentException(
        "X axis doesn't match! Provided max: " + m_MaxX
          + ", calculated max via min and step size: "
          + (m_MinX + (m_Width - 1) * m_StepX));
    }
    if (!Utils.eq(m_MinY + (m_Height - 1) * m_StepY, m_MaxY)) {
      throw new IllegalArgumentException(
        "Y axis doesn't match! Provided max: " + m_MaxY
          + ", calculated max via min and step size: "
          + (m_MinY + (m_Height - 1) * m_StepY));
    }
  }

  /**
   * Tests itself against the provided object. Two grids are equal if they
   * have the same dimensions, minimums, step sizes and labels.
   *
   * @param o the object to compare against
   * @return true if o is a grid with the same setup; false otherwise
   *         (including for null and for objects of other types)
   */
  @Override
  public boolean equals(Object o) {
    // equals() must never throw for null or foreign objects
    if (!(o instanceof Grid)) {
      return false;
    }

    Grid g = (Grid) o;
    return (width() == g.width()) && (height() == g.height())
      && (getMinX() == g.getMinX()) && (getMinY() == g.getMinY())
      && (getStepX() == g.getStepX()) && (getStepY() == g.getStepY())
      && getLabelX().equals(g.getLabelX())
      && getLabelY().equals(g.getLabelY());
  }

  /**
   * Returns a hash code that is consistent with {@link #equals(Object)},
   * i.e., it is computed from exactly the properties compared there.
   *
   * @return the hash code
   */
  @Override
  public int hashCode() {
    int hash = 17;
    hash = 31 * hash + width();
    hash = 31 * hash + height();
    hash = 31 * hash + hashOf(getMinX());
    hash = 31 * hash + hashOf(getMinY());
    hash = 31 * hash + hashOf(getStepX());
    hash = 31 * hash + hashOf(getStepY());
    hash = 31 * hash + ((getLabelX() == null) ? 0 : getLabelX().hashCode());
    hash = 31 * hash + ((getLabelY() == null) ? 0 : getLabelY().hashCode());
    return hash;
  }

  /**
   * Hashes a double such that values that are equal under "==" get the same
   * hash code.
   *
   * @param value the value to hash
   * @return the hash code
   */
  private static int hashOf(double value) {
    // adding 0.0 turns -0.0 into 0.0, which are "==" but differ in their bits
    long bits = Double.doubleToLongBits(value + 0.0);
    return (int) (bits ^ (bits >>> 32));
  }

  /**
   * returns the left border.
   *
   * @return the left border
   */
  public double getMinX() {
    return m_MinX;
  }

  /**
   * returns the right border.
   *
   * @return the right border
   */
  public double getMaxX() {
    return m_MaxX;
  }

  /**
   * returns the step size on the X axis.
   *
   * @return the step size
   */
  public double getStepX() {
    return m_StepX;
  }

  /**
   * returns the label for the X axis.
   *
   * @return the label
   */
  public String getLabelX() {
    return m_LabelX;
  }

  /**
   * returns the bottom border.
   *
   * @return the bottom border
   */
  public double getMinY() {
    return m_MinY;
  }

  /**
   * returns the top border.
   *
   * @return the top border
   */
  public double getMaxY() {
    return m_MaxY;
  }

  /**
   * returns the step size on the Y axis.
   *
   * @return the step size
   */
  public double getStepY() {
    return m_StepY;
  }

  /**
   * returns the label for the Y axis.
   *
   * @return the label
   */
  public String getLabelY() {
    return m_LabelY;
  }

  /**
   * returns the number of points in the grid on the Y axis (incl. borders)
   *
   * @return the number of points in the grid on the Y axis
   */
  public int height() {
    return m_Height;
  }

  /**
   * returns the number of points in the grid on the X axis (incl. borders)
   *
   * @return the number of points in the grid on the X axis
   */
  public int width() {
    return m_Width;
  }

  /**
   * returns the values at the given point in the grid.
   *
   * @param x the x-th point on the X axis (0-based)
   * @param y the y-th point on the Y axis (0-based)
   * @return the value pair at the given position
   * @throws IllegalArgumentException if an index is negative or not smaller
   *           than the number of points on its axis
   */
  public PointDouble getValues(int x, int y) {
    // negative indices would silently yield values outside of the grid
    if (x < 0) {
      throw new IllegalArgumentException("Index out of scope on X axis (" + x
        + " < 0)!");
    }
    if (x >= width()) {
      throw new IllegalArgumentException("Index out of scope on X axis (" + x
        + " >= " + width() + ")!");
    }
    if (y < 0) {
      throw new IllegalArgumentException("Index out of scope on Y axis (" + y
        + " < 0)!");
    }
    if (y >= height()) {
      throw new IllegalArgumentException("Index out of scope on Y axis (" + y
        + " >= " + height() + ")!");
    }

    return new PointDouble(m_MinX + m_StepX * x, m_MinY + m_StepY * y);
  }

  /**
   * returns the closest index pair for the given value pair in the grid.
   * Per axis, only grid points that are closer to the value than one step
   * size (as judged by Utils.sm) are candidates; if there is no such point,
   * the index on that axis is 0.
   *
   * @param values the values to get the indices for
   * @return the closest indices in the grid
   */
  public PointInt getLocation(PointDouble values) {
    int x;
    int y;
    double distance;
    double currDistance;
    int i;

    // determine x
    x = 0;
    distance = m_StepX;
    for (i = 0; i < width(); i++) {
      currDistance = StrictMath.abs(values.getX() - getValues(i, 0).getX());
      if (Utils.sm(currDistance, distance)) {
        distance = currDistance;
        x = i;
      }
    }

    // determine y
    y = 0;
    distance = m_StepY;
    for (i = 0; i < height(); i++) {
      currDistance = StrictMath.abs(values.getY() - getValues(0, i).getY());
      if (Utils.sm(currDistance, distance)) {
        distance = currDistance;
        y = i;
      }
    }

    return new PointInt(x, y);
  }

  /**
   * checks whether the given values are on the border of the grid.
   *
   * @param values the values to check
   * @return true if the values are on the border
   * @see #getLocation(PointDouble)
   */
  public boolean isOnBorder(PointDouble values) {
    return isOnBorder(getLocation(values));
  }

  /**
   * checks whether the given location is on the border of the grid, i.e.,
   * whether it lies in the first/last column or the first/last row.
   *
   * @param location the location to check
   * @return true if the location is on the border
   */
  public boolean isOnBorder(PointInt location) {
    return (location.getX() == 0)
      || (location.getX() == width() - 1)
      || (location.getY() == 0)
      || (location.getY() == height() - 1);
  }

  /**
   * returns a subgrid with the same step sizes, but different borders.
   *
   * @param top the top index (becomes the maximum on the Y axis)
   * @param left the left index (becomes the minimum on the X axis)
   * @param bottom the bottom index (becomes the minimum on the Y axis)
   * @param right the right index (becomes the maximum on the X axis)
   * @return the Sub-Grid
   * @throws IllegalArgumentException if an index is outside of this grid or
   *           the resulting borders do not form a valid grid
   */
  public Grid subgrid(int top, int left, int bottom, int right) {
    double minX = getValues(left, top).getX();
    double maxX = getValues(right, top).getX();
    double minY = getValues(left, bottom).getY();
    double maxY = getValues(left, top).getY();

    return new Grid(minX, maxX, getStepX(), getLabelX(), minY, maxY,
      getStepY(), getLabelY());
  }

  /**
   * returns an extended grid that encompasses the given point (won't be on
   * the border of the grid). Each border that the point lies on or beyond is
   * moved outwards by a whole number of steps.
   *
   * @param values the point that the grid should contain
   * @return the extended grid
   * @throws IllegalStateException if the resulting grid equals this grid
   */
  public Grid extend(PointDouble values) {
    double minX = getMinX();
    double maxX = getMaxX();
    double minY = getMinY();
    double maxY = getMaxY();

    // left
    if (Utils.smOrEq(values.getX(), getMinX())) {
      minX = getMinX() - getStepX()
        * stepsToExtend(getMinX() - values.getX(), getStepX());
    }

    // right
    if (Utils.grOrEq(values.getX(), getMaxX())) {
      maxX = getMaxX() + getStepX()
        * stepsToExtend(values.getX() - getMaxX(), getStepX());
    }

    // bottom
    if (Utils.smOrEq(values.getY(), getMinY())) {
      minY = getMinY() - getStepY()
        * stepsToExtend(getMinY() - values.getY(), getStepY());
    }

    // top
    if (Utils.grOrEq(values.getY(), getMaxY())) {
      maxY = getMaxY() + getStepY()
        * stepsToExtend(values.getY() - getMaxY(), getStepY());
    }

    Grid result = new Grid(minX, maxX, getStepX(), getLabelX(), minY, maxY,
      getStepY(), getLabelY());

    // did the grid really extend?
    if (equals(result)) {
      throw new IllegalStateException("Grid extension failed!");
    }

    return result;
  }

  /**
   * Determines by how many steps a border has to be moved outwards for a
   * point at the given distance beyond that border.
   *
   * @param distance the distance of the point from the border
   * @param step the step size of the axis
   * @return the number of steps to move the border by
   */
  private static long stepsToExtend(double distance, double step) {
    long steps = StrictMath.round(distance / step);
    // exactly on the border? one extra step keeps the point off the border
    if (Utils.eq(distance, 0)) {
      steps++;
    }
    return steps;
  }

  /**
   * returns an Enumeration over all pairs in the given row.
   *
   * @param y the row to retrieve
   * @return an Enumeration over all pairs
   * @see #getValues(int, int)
   */
  public Enumeration<PointDouble> row(int y) {
    Vector<PointDouble> result = new Vector<PointDouble>();
    for (int i = 0; i < width(); i++) {
      result.add(getValues(i, y));
    }
    return result.elements();
  }

  /**
   * returns an Enumeration over all pairs in the given column.
   *
   * @param x the column to retrieve
   * @return an Enumeration over all pairs
   * @see #getValues(int, int)
   */
  public Enumeration<PointDouble> column(int x) {
    Vector<PointDouble> result = new Vector<PointDouble>();
    for (int i = 0; i < height(); i++) {
      result.add(getValues(x, i));
    }
    return result.elements();
  }

  /**
   * returns a string representation of the grid: one line per axis (range,
   * step size and, if not empty, the label) followed by the dimensions.
   *
   * @return a string representation
   */
  @Override
  public String toString() {
    StringBuilder result = new StringBuilder();

    result.append("X: " + m_MinX + " - " + m_MaxX + ", Step " + m_StepX);
    if (m_LabelX.length() != 0) {
      result.append(" (" + m_LabelX + ")");
    }
    result.append("\n");

    result.append("Y: " + m_MinY + " - " + m_MaxY + ", Step " + m_StepY);
    if (m_LabelY.length() != 0) {
      result.append(" (" + m_LabelY + ")");
    }
    result.append("\n");

    result.append("Dimensions (Rows x Columns): " + height() + " x "
      + width());

    return result.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 14326 $");
  }
}
/**
 * Container holding the performance measures that were obtained for one
 * values-pair. Instances can be ordered with the PerformanceComparator class.
 *
 * @see PerformanceComparator
 */
protected static class Performance implements Serializable, RevisionHandler {

  /** for serialization. */
  private static final long serialVersionUID = -4374706475277588755L;

  /** the value pair the classifier was built with. */
  protected PointDouble m_Values;

  /** the Correlation coefficient. */
  protected double m_CC;

  /** the Root mean squared error. */
  protected double m_RMSE;

  /** the Root relative squared error. */
  protected double m_RRSE;

  /** the Mean absolute error. */
  protected double m_MAE;

  /** the Relative absolute error. */
  protected double m_RAE;

  /** the Accuracy. */
  protected double m_ACC;

  /** The weighted AUC value. */
  protected double m_wAUC;

  /** the kappa value. */
  protected double m_Kappa;

  /**
   * Reads all measures from the evaluation and stores them alongside the
   * values-pair. Weighted AUC, correlation coefficient, accuracy and kappa
   * are stored as NaN if querying them raises an exception.
   *
   * @param values the values-pair
   * @param evaluation the evaluation to extract the performance measures from
   * @throws Exception if retrieving of measures fails
   */
  public Performance(PointDouble values, Evaluation evaluation)
    throws Exception {
    super();

    this.m_Values = values;

    // error measures: a failure here is passed on to the caller
    this.m_RMSE = evaluation.rootMeanSquaredError();
    this.m_RRSE = evaluation.rootRelativeSquaredError();
    this.m_MAE = evaluation.meanAbsoluteError();
    this.m_RAE = evaluation.relativeAbsoluteError();

    // remaining measures: a failure is recorded as NaN
    try {
      this.m_wAUC = evaluation.weightedAreaUnderROC();
    } catch (Exception ex) {
      this.m_wAUC = Double.NaN;
    }

    try {
      this.m_CC = evaluation.correlationCoefficient();
    } catch (Exception ex) {
      this.m_CC = Double.NaN;
    }

    try {
      this.m_ACC = evaluation.pctCorrect();
    } catch (Exception ex) {
      this.m_ACC = Double.NaN;
    }

    try {
      this.m_Kappa = evaluation.kappa();
    } catch (Exception ex) {
      this.m_Kappa = Double.NaN;
    }
  }

  /**
   * Looks up the stored measure for the given evaluation type.
   *
   * @param evaluation the type of measure to return
   * @return the performance measure
   * @throws IllegalArgumentException if the evaluation type is unknown
   */
  public double getPerformance(int evaluation) {
    switch (evaluation) {
      case EVALUATION_CC:
        return m_CC;
      case EVALUATION_RMSE:
        return m_RMSE;
      case EVALUATION_RRSE:
        return m_RRSE;
      case EVALUATION_MAE:
        return m_MAE;
      case EVALUATION_RAE:
        return m_RAE;
      case EVALUATION_COMBINED:
        return (1 - StrictMath.abs(m_CC)) + m_RRSE + m_RAE;
      case EVALUATION_ACC:
        return m_ACC;
      case EVALUATION_KAPPA:
        return m_Kappa;
      case EVALUATION_WAUC:
        return m_wAUC;
      default:
        throw new IllegalArgumentException("Evaluation type '" + evaluation
          + "' not supported!");
    }
  }

  /**
   * Gives access to the values-pair these measures belong to.
   *
   * @return the values-pair
   */
  public PointDouble getValues() {
    return m_Values;
  }

  /**
   * Describes the values-pair together with a single measure.
   *
   * @param evaluation the type of performance to return
   * @return a string representation
   */
  public String toString(int evaluation) {
    StringBuilder text = new StringBuilder("Performance (");
    text.append(getValues()).append("): ");
    text.append(getPerformance(evaluation));
    text.append(" (");
    text.append(new SelectedTag(evaluation, TAGS_EVALUATION));
    text.append(")");
    return text.toString();
  }

  /**
   * Formats the values-pair and a single measure as one tab-separated
   * Gnuplot data line.
   *
   * @param evaluation the type of performance to return
   * @return the gnuplot string (x, y, z)
   */
  public String toGnuplot(int evaluation) {
    PointDouble pair = getValues();
    double x = pair.getX();
    double y = getValues().getY();
    return x + "\t" + y + "\t" + getPerformance(evaluation);
  }

  /**
   * Describes the values-pair together with all measures listed in
   * TAGS_EVALUATION, separated by commas.
   *
   * @return a string representation
   */
  @Override
  public String toString() {
    StringBuilder text = new StringBuilder();
    text.append("Performance (").append(getValues()).append("): ");

    boolean first = true;
    for (Tag tag : TAGS_EVALUATION) {
      if (!first) {
        text.append(", ");
      }
      first = false;

      text.append(getPerformance(tag.getID()));
      text.append(" (");
      text.append(new SelectedTag(tag.getID(), TAGS_EVALUATION));
      text.append(")");
    }

    return text.toString();
  }

  /**
   * Reports the revision of this class.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 14326 $");
  }
}
/**
 * A concrete Comparator for the Performance class. A "greater" performance
 * is a better one: for the error-based measures the natural order of the
 * numbers is inverted.
 *
 * @see Performance
 */
protected static class PerformanceComparator implements
  Comparator<Performance>, Serializable, RevisionHandler {

  /** for serialization. */
  private static final long serialVersionUID = 6507592831825393847L;

  /**
   * the performance measure to use for comparison.
   *
   * @see GridSearch#TAGS_EVALUATION
   */
  protected int m_Evaluation;

  /**
   * initializes the comparator with the given performance measure.
   *
   * @param evaluation the performance measure to use
   * @see GridSearch#TAGS_EVALUATION
   */
  public PerformanceComparator(int evaluation) {
    super();
    m_Evaluation = evaluation;
  }

  /**
   * returns the performance measure that's used to compare the objects.
   *
   * @return the performance measure
   * @see GridSearch#TAGS_EVALUATION
   */
  public int getEvaluation() {
    return m_Evaluation;
  }

  /**
   * Compares its two arguments for order. Returns a negative integer, zero,
   * or a positive integer as the first argument is less than, equal to, or
   * greater than the second.
   * <p>
   * If exactly one of the two measures is NaN, that performance is always
   * the lesser one, regardless of the measure. Ties (including two NaN
   * measures) are broken by the X and then the Y value of the values-pairs.
   *
   * @param o1 the first performance
   * @param o2 the second performance
   * @return the order
   * @throws IllegalArgumentException if the evaluation type of this
   *           comparator is not supported by the performances
   */
  @Override
  public int compare(Performance o1, Performance o2) {
    final int evaluation = getEvaluation();
    final double p1 = o1.getPerformance(evaluation);
    final double p2 = o2.getPerformance(evaluation);

    // An undefined measure compares as neither smaller nor larger than a
    // number; treating that as a tie makes the order non-transitive. Hence
    // an undefined measure is ranked below every defined one.
    final boolean undefined1 = Double.isNaN(p1);
    final boolean undefined2 = Double.isNaN(p2);
    if (undefined1 != undefined2) {
      return undefined1 ? -1 : 1;
    }

    int result = order(p1, p2);
    // Need to make order deterministic
    if (result == 0) {
      result = order(o1.getValues().getX(), o2.getValues().getX());
    }
    if (result == 0) {
      result = order(o1.getValues().getY(), o2.getValues().getY());
    }

    // only correlation coefficient/accuracy/weighted AUC/kappa obey to this
    // order, for the errors (and the combination of all three), the smaller
    // the number the better -> hence invert them
    final boolean greaterIsBetter = (evaluation == EVALUATION_CC)
      || (evaluation == EVALUATION_ACC)
      || (evaluation == EVALUATION_WAUC)
      || (evaluation == EVALUATION_KAPPA);
    if (!greaterIsBetter) {
      result = -result;
    }

    return result;
  }

  /**
   * Orders two numbers ascendingly.
   *
   * @param a the first number
   * @param b the second number
   * @return -1 if a is smaller than b, 1 if a is larger than b, otherwise 0
   */
  private static int order(double a, double b) {
    if (a < b) {
      return -1;
    }
    if (a > b) {
      return 1;
    }
    return 0;
  }

  /**
   * Indicates whether some other object is "equal to" this Comparator.
   *
   * @param obj the object to compare with
   * @return true if obj is a PerformanceComparator that uses the same
   *         evaluation type; false otherwise (including for null and for
   *         objects of other types)
   */
  @Override
  public boolean equals(Object obj) {
    // equals() must not throw for null or foreign objects
    if (!(obj instanceof PerformanceComparator)) {
      return false;
    }
    return (m_Evaluation == ((PerformanceComparator) obj).m_Evaluation);
  }

  /**
   * Returns a hash code that is consistent with {@link #equals(Object)}.
   *
   * @return the hash code
   */
  @Override
  public int hashCode() {
    return m_Evaluation;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 14326 $");
  }
}
/**
* Generates a 2-dim array for the performances from a grid for a certain
* type. x-min/y-min is in the bottom-left corner, i.e., getTable()[0][0]
* returns the performance for the x-min/y-max pair.
*
*
* x-min x-max
* |-------------|
* - y-max
* |
* |
* - y-min
*
*/
protected static class PerformanceTable implements Serializable,
RevisionHandler {
/** for serialization. */
private static final long serialVersionUID = 5486491313460338379L;
/** the owning classifier. */
protected GridSearch m_Owner;
/** the corresponding grid. */
protected Grid m_Grid;
/** the performances. */
protected Vector m_Performances;
/** the type of performance the table was generated for. */
protected int m_Type;
/** the table with the values. */
protected double[][] m_Table;
/** the minimum performance. */
protected double m_Min;
/** the maximum performance. */
protected double m_Max;
/**
 * Stores the given setup and immediately builds the table from it.
 *
 * @param owner the owning GridSearch
 * @param grid the underlying grid
 * @param performances the performances
 * @param type the type of performance
 * @see #generate()
 */
public PerformanceTable(GridSearch owner, Grid grid,
  Vector performances, int type) {
  super();

  this.m_Owner = owner;
  this.m_Grid = grid;
  this.m_Performances = performances;
  this.m_Type = type;

  // all members are in place, the table can be filled now
  generate();
}
/**
 * Generates the table: every performance is written to the cell that
 * corresponds to the grid location of its values-pair (row 0 holds the
 * largest Y index) and the minimum/maximum over all performances is
 * recorded. Without any performances, minimum and maximum are both 0.
 */
protected void generate() {
  final Grid grid = getGrid();
  final int rows = grid.height();

  m_Table = new double[rows][grid.width()];
  m_Min = 0;
  m_Max = 0;

  for (int i = 0; i < getPerformances().size(); i++) {
    // explicit cast: the performances are kept in an untyped Vector
    final Performance perf = (Performance) getPerformances().get(i);
    final PointInt location = grid.getLocation(perf.getValues());
    final double value = perf.getPerformance(getType());

    // rows are stored top-down, i.e., the highest Y index ends up in row 0
    m_Table[rows - (int) location.getY() - 1][(int) location.getX()] = value;

    // determine min/max
    if (i == 0) {
      m_Min = value;
      m_Max = value;
    } else {
      if (value < m_Min) {
        m_Min = value;
      }
      if (value > m_Max) {
        m_Max = value;
      }
    }
  }
}
/**
 * Gives access to the grid this table was set up with.
 *
 * @return the underlying grid
 */
public Grid getGrid() {
  return this.m_Grid;
}
/**
 * Gives access to the performances this table was set up with.
 *
 * @return the underlying performances
 */
public Vector getPerformances() {
  return this.m_Performances;
}
/**
 * Tells which type of performance the table was set up for.
 *
 * @return the type of performance
 */
public int getType() {
  return this.m_Type;
}
/**
 * Gives access to the table with the performance values. The array itself
 * is handed out, not a copy.
 *
 * @return the performance table
 * @see #m_Table
 * @see #generate()
 */
public double[][] getTable() {
  return this.m_Table;
}
/**
 * Reports the smallest performance recorded while generating the table.
 *
 * @return the performance
 */
public double getMin() {
  return this.m_Min;
}
/**
 * Reports the largest performance recorded while generating the table.
 *
 * @return the performance
 */
public double getMax() {
  return this.m_Max;
}
/**
 * Renders the table as text: a header line naming the performance type and
 * the axis labels, followed by one line per table row with comma-separated
 * values.
 *
 * @return the table as string
 */
@Override
public String toString() {
  StringBuilder text = new StringBuilder();

  // header
  text.append("Table (");
  text.append(new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag()
    .getReadable());
  text.append(") - ");
  text.append("X: ").append(getGrid().getLabelX());
  text.append(", Y: ").append(getGrid().getLabelY());
  text.append(":\n");

  // one line per row
  double[][] table = getTable();
  for (int row = 0; row < table.length; row++) {
    if (row > 0) {
      text.append("\n");
    }
    for (int col = 0; col < table[row].length; col++) {
      if (col > 0) {
        text.append(",");
      }
      text.append(table[row][col]);
    }
  }

  return text.toString();
}
/**
 * Assembles a single string that holds both the gnuplot data file
 * ('gridsearch.data') and the matching plot script ('gridsearch.plot'); the
 * two sections are delimited by "# begin"/"# end" comment lines.
 *
 * @return the data in gnuplot format
 */
public String toGnuplot() {
  final Tag type = new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag();
  final StringBuilder out = new StringBuilder();

  // data section: one line per performance
  out.append("Gnuplot (" + type.getReadable() + "):\n");
  out.append("# begin 'gridsearch.data'\n");
  out.append("# " + type.getReadable() + "\n");
  for (int idx = 0; idx < getPerformances().size(); idx++) {
    out.append(getPerformances().get(idx).toGnuplot(type.getID()) + "\n");
  }
  out.append("# end 'gridsearch.data'\n\n");

  // script section
  out.append("# begin 'gridsearch.plot'\n");
  out.append("# " + type.getReadable() + "\n");
  out.append("set data style lines\n");
  out.append("set contour base\n");
  out.append("set surface\n");
  out.append("set title '" + m_Owner.getData().relationName() + "'\n");
  out.append("set xrange [" + getGrid().getMinX() + ":"
    + getGrid().getMaxX() + "]\n");
  out.append("set xlabel 'x ("
    + m_Owner.getClassifier().getClass().getName() + ": "
    + m_Owner.getXProperty() + ")'\n");
  out.append("set yrange [" + getGrid().getMinY() + ":"
    + getGrid().getMaxY() + "]\n");
  out.append("set ylabel 'y - ("
    + m_Owner.getClassifier().getClass().getName() + ": "
    + m_Owner.getYProperty() + ")'\n");
  // the z axis gets a margin of 10% of the value range on either side
  out.append("set zrange [" + (getMin() - (getMax() - getMin()) * 0.1)
    + ":" + (getMax() + (getMax() - getMin()) * 0.1) + "]\n");
  out.append("set zlabel 'z - " + type.getReadable() + "'\n");
  out.append("set dgrid3d " + getGrid().height() + ","
    + getGrid().width() + ",1\n");
  out.append("show contour\n");
  out.append("splot 'gridsearch.data'\n");
  out.append("pause -1\n");
  out.append("# end 'gridsearch.plot'");

  return out.toString();
}
/**
 * Returns the revision string extracted from the embedded revision keyword.
 *
 * @return the revision
 */
@Override
public String getRevision() {
  final String revision = RevisionUtils.extract("$Revision: 14326 $");
  return revision;
}
}
/**
 * Represents a simple cache for performance objects. Entries are keyed by a
 * string made up of the number of cross-validation folds and the X/Y values
 * of the grid point.
 */
protected static class PerformanceCache implements Serializable,
  RevisionHandler {

  /** for serialization. */
  private static final long serialVersionUID = 5838863230451530252L;

  /**
   * the cache for points in the grid that got calculated. Typed explicitly,
   * so that lookups yield Performance objects without casts.
   */
  protected Hashtable<String, Performance> m_Cache =
    new Hashtable<String, Performance>();

  /**
   * returns the ID string for a cache item.
   *
   * @param cv the number of folds in the cross-validation
   * @param values the point in the grid
   * @return the ID string (folds, X and Y separated by tabs)
   */
  protected String getID(int cv, PointDouble values) {
    return cv + "\t" + values.getX() + "\t" + values.getY();
  }

  /**
   * checks whether the point was already calculated once.
   *
   * @param cv the number of folds in the cross-validation
   * @param values the point in the grid
   * @return true if the value is already cached
   */
  public boolean isCached(int cv, PointDouble values) {
    return get(cv, values) != null;
  }

  /**
   * returns a cached performance object, null if not yet in the cache.
   *
   * @param cv the number of folds in the cross-validation
   * @param values the point in the grid
   * @return the cached performance item, null if not in cache
   */
  public Performance get(int cv, PointDouble values) {
    return m_Cache.get(getID(cv, values));
  }

  /**
   * adds the performance to the cache, replacing any entry stored for the
   * same folds/point combination.
   *
   * @param cv the number of folds in the cross-validation
   * @param p the performance object to store
   */
  public void add(int cv, Performance p) {
    m_Cache.put(getID(cv, p.getValues()), p);
  }

  /**
   * returns a string representation of the cache.
   *
   * @return the string representation of the cache
   */
  @Override
  public String toString() {
    return m_Cache.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 14326 $");
  }
}
/**
* Helper class for generating the setups.
*/
protected static class SetupGenerator implements Serializable,
RevisionHandler {
/** for serialization. */
private static final long serialVersionUID = -2517395033342543417L;
/** variables exposed to expressions. */
private static final SimpleVariableDeclarations variables =
new SimpleVariableDeclarations();
static {
variables.addDouble("BASE");
variables.addDouble("FROM");
variables.addDouble("TO");
variables.addDouble("STEP");
variables.addDouble("I");
}
/** the owner. */
protected GridSearch m_Owner;
/** the Y option to work on. */
protected String m_Y_Property;
/** the minimum of Y. */
protected double m_Y_Min;
/** the maximum of Y. */
protected double m_Y_Max;
/** the step size of Y. */
protected double m_Y_Step;
/** the base for Y. */
protected double m_Y_Base;
/** The expression for the Y property. */
protected String m_Y_Expression;
/** the compiled expression for the Y property. */
private DoubleExpression m_Y_Node;
/** the X option to work on. */
protected String m_X_Property;
/** the minimum of X. */
protected double m_X_Min;
/** the maximum of X. */
protected double m_X_Max;
/** the step size of X. */
protected double m_X_Step;
/** the base for X. */
protected double m_X_Base;
/** the compiled expression for the X property. */
private DoubleExpression m_X_Node;
/** The expression for the X property. */
protected String m_X_Expression;
/**
* Initializes the setup generator.
*
* @param owner the owning classifier
*/
public SetupGenerator(GridSearch owner) {
super();
m_Owner = owner;
m_Y_Expression = m_Owner.getYExpression();
m_Y_Property = m_Owner.getYProperty();
m_Y_Min = m_Owner.getYMin();
m_Y_Max = m_Owner.getYMax();
m_Y_Step = m_Owner.getYStep();
m_Y_Base = m_Owner.getYBase();
// precompile expression
try {
m_Y_Node = (DoubleExpression) Parser.parse(
// expression
m_Y_Expression,
// variables
variables,
// macros
new MacroDeclarationsCompositor(
new MathFunctions(),
new IfElseMacro(),
new JavaMacro()
)
);
} catch (Exception e) {
m_Y_Node = null;
System.err.println("Failed to compile Y expression '"
+ m_Y_Expression + "'");
e.printStackTrace();
}
m_X_Expression = m_Owner.getXExpression();
m_X_Property = m_Owner.getXProperty();
m_X_Min = m_Owner.getXMin();
m_X_Max = m_Owner.getXMax();
m_X_Step = m_Owner.getXStep();
m_X_Base = m_Owner.getXBase();
try {
// precompile expression
m_X_Node = (DoubleExpression) Parser.parse(
// expression
m_X_Expression,
// variables
variables,
// macros
new MacroDeclarationsCompositor(
new MathFunctions(),
new IfElseMacro(),
new JavaMacro()
)
);
} catch (Exception e) {
m_X_Node = null;
System.err.println("Failed to compile X expression '"
+ m_X_Expression + "'");
e.printStackTrace();
}
}
/**
* evalutes the expression for the current iteration.
*
* @param value the current iteration value (from 'min' to 'max' with
* stepsize 'step')
* @param isX true if X is to be evaluated otherwise Y
* @return the generated value, NaN if the evaluation fails
*/
public double evaluate(double value, boolean isX) {
DoubleExpression expr;
if (isX) {
if (variables.getInitializer().hasVariable("BASE"))
variables.getInitializer().setDouble("BASE", m_X_Base);
if (variables.getInitializer().hasVariable("FROM"))
variables.getInitializer().setDouble("FROM", m_X_Min);
if (variables.getInitializer().hasVariable("TO"))
variables.getInitializer().setDouble("TO", m_X_Max);
if (variables.getInitializer().hasVariable("STEP"))
variables.getInitializer().setDouble("STEP", m_X_Step);
expr = m_X_Node;
} else {
if (variables.getInitializer().hasVariable("BASE"))
variables.getInitializer().setDouble("BASE", m_Y_Base);
if (variables.getInitializer().hasVariable("FROM"))
variables.getInitializer().setDouble("FROM", m_Y_Min);
if (variables.getInitializer().hasVariable("TO"))
variables.getInitializer().setDouble("TO", m_Y_Max);
if (variables.getInitializer().hasVariable("STEP"))
variables.getInitializer().setDouble("STEP", m_Y_Step);
expr = m_Y_Node;
}
if (variables.getInitializer().hasVariable("I"))
variables.getInitializer().setDouble("I", value);
try {
return expr.evaluate();
} catch (Exception e) {
e.printStackTrace();
return Double.NaN;
}
}
/**
* tries to set the value as double, integer (just casts it to int!) or
* boolean (false if 0, otherwise true) in the object according to the
* specified path. float, char and long are also supported.
*
* @param o the object to modify
* @param path the property path
* @param value the value to set
* @return the modified object
* @throws Exception if neither double nor int could be set
*/
public Object setValue(Object o, String path, double value)
throws Exception {
PropertyDescriptor desc;
Class> c;
desc = PropertyPath.getPropertyDescriptor(o, path);
if (desc == null) {
throw new IllegalArgumentException("Failed to set property " + path + " on object " + o.getClass().getName());
}
c = desc.getPropertyType();
// float
if ((c == Float.class) || (c == Float.TYPE)) {
PropertyPath.setValue(o, path, new Float((float) value));
} else if ((c == Double.class) || (c == Double.TYPE)) {
PropertyPath.setValue(o, path, new Double(value));
} else if ((c == Character.class) || (c == Character.TYPE)) {
PropertyPath.setValue(o, path, new Integer((char) value));
} else if ((c == Integer.class) || (c == Integer.TYPE)) {
PropertyPath.setValue(o, path, new Integer((int) value));
} else if ((c == Long.class) || (c == Long.TYPE)) {
PropertyPath.setValue(o, path, new Long((long) value));
} else if ((c == Boolean.class) || (c == Boolean.TYPE)) {
PropertyPath.setValue(o, path, (value == 0 ? new Boolean(false)
: new Boolean(true)));
} else {
throw new Exception(
"Could neither set double nor integer nor boolean value for '" + path
+ "'!");
}
return o;
}
/**
* returns a fully configured object (a copy of the provided one).
*
* @param original the object to create a copy from and set the parameters
* @param valueX the current iteration value for X
* @param valueY the current iteration value for Y
* @return the configured classifier
* @throws Exception if setup fails
*/
public Object setup(Object original, double valueX, double valueY)
throws Exception {
Object result;
result = new SerializedObject(original).getObject();
if (original instanceof Classifier) {
setValue(result, m_X_Property, valueX);
setValue(result, m_Y_Property, valueY);
} else {
throw new IllegalArgumentException(
"Object must be a classifier!");
}
return result;
}
/**
* Returns the revision string.
*
* @return the revision
*/
@Override
public String getRevision() {
return RevisionUtils.extract("$Revision: 14326 $");
}
}
/**
 * Helper class for evaluating a single setup, i.e., one point of the grid,
 * via cross-validation. The outcome is reported back to the owning
 * GridSearch.
 */
protected static class EvaluationTask implements Callable, RevisionHandler {

  /** the owner. */
  protected GridSearch m_Owner;

  /** for generating the setups. */
  protected SetupGenerator m_Generator;

  /** the classifier to use. */
  protected Classifier m_Classifier;

  /** the data to use for training. */
  protected Instances m_Data;

  /** the values to use. */
  protected PointDouble m_Values;

  /** the number of folds for cross-validation. */
  protected int m_Folds;

  /** the type of evaluation. */
  protected int m_Evaluation;

  /**
   * Initializes the task. The classifier to configure is obtained from the
   * owner.
   *
   * @param owner the owning GridSearch classifier
   * @param generator the generator for the setups
   * @param inst the data
   * @param values the values in the grid
   * @param folds the number of cross-validation folds
   * @param eval the type of evaluation
   */
  public EvaluationTask(GridSearch owner, SetupGenerator generator,
    Instances inst, PointDouble values, int folds, int eval) {
    super();

    m_Owner = owner;
    m_Classifier = owner.getClassifier();
    m_Generator = generator;
    m_Data = inst;
    m_Values = values;
    m_Folds = folds;
    m_Evaluation = eval;
  }

  /**
   * Drops the references to owner and data once the task is finished.
   */
  private void releaseReferences() {
    m_Owner = null;
    m_Data = null;
  }

  /**
   * Performs the evaluation: configures a copy of the classifier with the
   * evaluated X/Y values, cross-validates it and hands the resulting
   * performance to the owner.
   *
   * @return null if the evaluation succeeded, otherwise the exception that
   *         was encountered
   */
  @Override
  public Exception call() {
    Classifier configured = null;
    final double xValue = m_Generator.evaluate(m_Values.getX(), true);
    final double yValue = m_Generator.evaluate(m_Values.getY(), false);

    try {
      final Instances train = m_Data;

      // setup classifier
      configured = (Classifier) m_Generator.setup(m_Classifier, xValue, yValue);

      // evaluate
      final Evaluation evaluation = new Evaluation(train);
      evaluation.crossValidateModel(configured, train, m_Folds,
        new Random(m_Owner.getSeed()));

      // store performance
      final Performance outcome = new Performance(m_Values, evaluation);
      m_Owner.addPerformance(outcome, m_Folds);

      // log
      m_Owner.log(outcome + ": cached=false");

      // release slot
      m_Owner.completedEvaluation(configured, null);

      releaseReferences();
      return null;
    } catch (Exception e) {
      if (m_Owner.getDebug()) {
        String setupInfo;
        if (configured != null) {
          setupInfo = Utils.toCommandLine(configured);
        } else {
          setupInfo = "-no setup-";
        }
        System.err
          .println("Encountered exception while evaluating classifier, skipping!");
        System.err.println("- Values....: " + m_Values);
        System.err.println("- Classifier: " + setupInfo);
        e.printStackTrace();
      }

      // report the failure for this grid point
      m_Owner.completedEvaluation(m_Values, e);

      releaseReferences();
      return e;
    }
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 14326 $");
  }
}
/** for serialization. */
private static final long serialVersionUID = -3034773968581595348L;
/** evaluation via: Correlation coefficient. */
public static final int EVALUATION_CC = 0;
/** evaluation via: Root mean squared error. */
public static final int EVALUATION_RMSE = 1;
/** evaluation via: Root relative squared error. */
public static final int EVALUATION_RRSE = 2;
/** evaluation via: Mean absolute error. */
public static final int EVALUATION_MAE = 3;
/** evaluation via: Relative absolute error. */
public static final int EVALUATION_RAE = 4;
/**
* evaluation via: Combined = (1-abs(CC)) + RRSE + RAE, as stated in the
* description of the corresponding TAGS_EVALUATION entry.
* NOTE(review): this comment used to read "(1-CC)"; confirm the formula
* against the code that computes the combined value.
*/
public static final int EVALUATION_COMBINED = 5;
/** evaluation via: Accuracy. */
public static final int EVALUATION_ACC = 6;
/** evaluation via: kappa statistic. */
public static final int EVALUATION_KAPPA = 7;
/** evaluation via: weighted AUC. */
public static final int EVALUATION_WAUC = 8;
/**
 * the tags for the available evaluation types; each tag pairs one of the
 * EVALUATION_* constants with a short ID string and a readable description.
 */
public static final Tag[] TAGS_EVALUATION = {
  new Tag(EVALUATION_CC, "CC", "Correlation coefficient"),
  new Tag(EVALUATION_RMSE, "RMSE", "Root mean squared error"),
  new Tag(EVALUATION_RRSE, "RRSE", "Root relative squared error"),
  new Tag(EVALUATION_MAE, "MAE", "Mean absolute error"),
  // RAE is the relative (not "root") absolute error, see EVALUATION_RAE
  new Tag(EVALUATION_RAE, "RAE", "Relative absolute error"),
  new Tag(EVALUATION_COMBINED, "COMB",
    "Combined = (1-abs(CC)) + RRSE + RAE"),
  new Tag(EVALUATION_ACC, "ACC", "Accuracy"),
  new Tag(EVALUATION_WAUC, "WAUC", "Weighted AUC"),
  new Tag(EVALUATION_KAPPA, "KAP", "Kappa") };
/** traversal type: row-wise grid traversal. */
public static final int TRAVERSAL_BY_ROW = 0;
/** traversal type: column-wise grid traversal. */
public static final int TRAVERSAL_BY_COLUMN = 1;
/** the tags for the available traversal types (TRAVERSAL_* constants). */
public static final Tag[] TAGS_TRAVERSAL = {
new Tag(TRAVERSAL_BY_ROW, "row-wise", "row-wise"),
new Tag(TRAVERSAL_BY_COLUMN, "column-wise", "column-wise") };
/** the Classifier with the best setup. */
protected Classifier m_BestClassifier;
/** the best values. */
protected PointDouble m_Values = null;
/** the type of evaluation (initially EVALUATION_CC). */
protected int m_Evaluation = EVALUATION_CC;
/**
* the Y option to work on (without leading dash).
*/
protected String m_Y_Property = "kernel.gamma";
/** the minimum of Y. */
protected double m_Y_Min = -3;
/** the maximum of Y. */
protected double m_Y_Max = +3;
/** the step size of Y. */
protected double m_Y_Step = 1;
/** the base for Y. */
protected double m_Y_Base = 10;
/**
* The expression for the Y property. Available parameters for the expression:
* <ul>
* <li>BASE</li>
* <li>FROM (= min)</li>
* <li>TO (= max)</li>
* <li>STEP</li>
* <li>I - the current value (from 'from' to 'to' with stepsize 'step')</li>
* </ul>
*
* @see MathExpression
*/
protected String m_Y_Expression = "pow(BASE,I)";
/**
* the X option to work on (without leading dash).
*/
protected String m_X_Property = "C";
/** the minimum of X. */
protected double m_X_Min = -3;
/** the maximum of X. */
protected double m_X_Max = 3;
/** the step size of X. */
protected double m_X_Step = 1;
/** the base for X. */
protected double m_X_Base = 10;
/**
* The expression for the X property. Available parameters for the expression:
* <ul>
* <li>BASE</li>
* <li>FROM (= min)</li>
* <li>TO (= max)</li>
* <li>STEP</li>
* <li>I - the current value (from 'from' to 'to' with stepsize 'step')</li>
* </ul>
*
* @see MathExpression
*/
protected String m_X_Expression = "pow(BASE,I)";
/** whether the grid can be extended. */
protected boolean m_GridIsExtendable = false;
/** maximum number of grid extensions (-1 means unlimited). */
protected int m_MaxGridExtensions = 3;
/** the number of extensions performed. */
protected int m_GridExtensionsPerformed = 0;
/**
* the sample size to search the initial grid with.
* NOTE(review): presumably a percentage (defaultsFromProps feeds the
* 'sampleSizePercent' property into setSampleSizePercent) - confirm.
*/
protected double m_SampleSize = 100;
/** the traversal (initially column-wise). */
protected int m_Traversal = TRAVERSAL_BY_COLUMN;
/** the log file to use; initially points to the 'user.dir' directory. */
protected File m_LogFile = new File(System.getProperty("user.dir"));
/** the value-pairs grid. */
protected Grid m_Grid;
/** the training data. */
protected Instances m_Data;
/** the cache for points in the grid that got calculated. */
protected PerformanceCache m_Cache;
/** for storing the performances. */
protected Vector m_Performances;
/** whether all performances in the grid are the same. */
protected boolean m_UniformPerformance = false;
/** The number of threads to have executing at any one time. */
protected int m_NumExecutionSlots = 1;
/** Pool of threads to train models with (transient, i.e., not serialized). */
protected transient ExecutorService m_ExecutorPool;
/** The number of setups completed so far. */
protected int m_Completed;
/**
* The number of setups that experienced a failure of some sort during
* construction.
*/
protected int m_Failed;
/** the number of setups to evaluate. */
protected int m_NumSetups;
/** the generator for generating the setups. */
protected SetupGenerator m_Generator;
/**
* for storing an exception that happened in one of the worker threads
* (transient, i.e., not serialized).
*/
protected transient Exception m_Exception;
/** The properties file containing default settings. */
protected static final String PROPERTY_FILE =
  "weka/classifiers/meta/GridSearch.props";

/**
 * The properties object holding the loaded defaults; it is not assigned if
 * reading the properties file fails in the static initializer below.
 */
protected static Properties GRID_SEARCH_PROPS;

static {
  try {
    // use the constant rather than a second copy of the path
    GRID_SEARCH_PROPS = Utils.readProperties(PROPERTY_FILE);
  } catch (Exception ex) {
    ex.printStackTrace();
  }
}
/**
* the default constructor.
*/
public GridSearch() {
super();
defaultsFromProps();
try {
m_BestClassifier = AbstractClassifier.makeCopy(m_Classifier);
} catch (Exception e) {
e.printStackTrace();
}
}
/**
 * Set defaults for all options from a properties file. A default properties
 * file is included in the gridSearch jar file
 * (weka/classifiers/meta/GridSearch.props) - this can be copied, altered and
 * placed into ${WEKA_HOME}/props
 * <p>
 * The properties only count as usable if the classifier can be instantiated
 * and all six settings (property, min, max, step, base, expression) are
 * present for both the Y and the X axis. In every other case - no properties
 * loaded, no or an invalid 'classifier' entry, incomplete or unparsable
 * values - the built-in GaussianProcesses setup is applied, so that
 * classifier and X/Y properties always match each other.
 */
protected void defaultsFromProps() {
  try {
    boolean ok = false;

    String classifierSpec = (GRID_SEARCH_PROPS == null)
      ? null : GRID_SEARCH_PROPS.getProperty("classifier");

    if (classifierSpec != null && classifierSpec.length() > 0) {
      ok = true;
      try {
        String[] spec = Utils.splitOptions(classifierSpec);
        String classifier = spec[0];
        spec[0] = "";
        Classifier result = AbstractClassifier.forName(classifier, spec);
        setClassifier(result);

        // Y axis: all six settings are required
        String yProp = GRID_SEARCH_PROPS.getProperty("yProperty", "");
        String yMin = GRID_SEARCH_PROPS.getProperty("yMin", "");
        String yMax = GRID_SEARCH_PROPS.getProperty("yMax", "");
        String yStep = GRID_SEARCH_PROPS.getProperty("yStep", "");
        String yBase = GRID_SEARCH_PROPS.getProperty("yBase", "");
        String yExpression =
          GRID_SEARCH_PROPS.getProperty("yExpression", "");
        if (yProp.length() > 0 && yMin.length() > 0 && yMax.length() > 0
          && yStep.length() > 0 && yBase.length() > 0
          && yExpression.length() > 0) {
          setYProperty(yProp);
          setYMin(Double.parseDouble(yMin));
          setYMax(Double.parseDouble(yMax));
          setYStep(Double.parseDouble(yStep));
          setYBase(Double.parseDouble(yBase));
          setYExpression(yExpression);
        } else {
          ok = false;
        }

        // X axis: all six settings are required
        String xProp = GRID_SEARCH_PROPS.getProperty("xProperty", "");
        String xMin = GRID_SEARCH_PROPS.getProperty("xMin", "");
        String xMax = GRID_SEARCH_PROPS.getProperty("xMax", "");
        String xStep = GRID_SEARCH_PROPS.getProperty("xStep", "");
        String xBase = GRID_SEARCH_PROPS.getProperty("xBase", "");
        String xExpression =
          GRID_SEARCH_PROPS.getProperty("xExpression", "");
        if (xProp.length() > 0 && xMin.length() > 0 && xMax.length() > 0
          && xStep.length() > 0 && xBase.length() > 0
          && xExpression.length() > 0) {
          setXProperty(xProp);
          setXMin(Double.parseDouble(xMin));
          setXMax(Double.parseDouble(xMax));
          setXStep(Double.parseDouble(xStep));
          setXBase(Double.parseDouble(xBase));
          setXExpression(xExpression);
        } else {
          ok = false;
        }

        // optionals
        String gridExtend =
          GRID_SEARCH_PROPS.getProperty("gridIsExtendable", "false");
        setGridIsExtendable(Boolean.parseBoolean(gridExtend));
        String maxExtensions =
          GRID_SEARCH_PROPS.getProperty("maxGridExtensions", "3");
        setMaxGridExtensions(Integer.parseInt(maxExtensions));
        String sampleSizePerc =
          GRID_SEARCH_PROPS.getProperty("sampleSizePercent", "100");
        setSampleSizePercent(Integer.parseInt(sampleSizePerc));
        String traversal = GRID_SEARCH_PROPS.getProperty("traversal", "0");
        m_Traversal = Integer.parseInt(traversal);
        String eval = GRID_SEARCH_PROPS.getProperty("evaluation", "0");
        m_Evaluation = Integer.parseInt(eval);
        String numSlots = GRID_SEARCH_PROPS.getProperty("numSlots", "1");
        setNumExecutionSlots(Integer.parseInt(numSlots));
      } catch (Exception ex) {
        // continue with the default of GaussianProcesses
        ok = false;
      }
    }

    if (!ok) {
      // setup GaussianProcesses; this also covers the case of missing
      // properties, which would otherwise leave X/Y settings behind that do
      // not belong to the fallback classifier
      setClassifier(new weka.classifiers.functions.GaussianProcesses());
      setYProperty("kernel.exponent");
      setYMin(1);
      setYMax(5);
      setYStep(1);
      setYBase(10);
      setYExpression("I");

      setXProperty("noise");
      setXMin(0.2);
      setXMax(2);
      setXStep(0.2);
      setXBase(10);
      setXExpression("I");
    }
  } catch (Exception e) {
    e.printStackTrace();
  }
}
/**
 * Returns a string describing classifier.
 *
 * @return a description suitable for displaying in the explorer/experimenter
 *         gui
 */
public String globalInfo() {
  return "Performs a grid search of parameter pairs for a classifier and chooses the best "
    + "pair found for the actual predicting.\n\n"
    + "The initial grid is worked on with 2-fold CV to determine the values "
    + "of the parameter pairs for the selected type of evaluation (e.g., "
    + "accuracy). The best point in the grid is then taken and a 10-fold CV "
    + "is performed with the adjacent parameter pairs. If a better pair is "
    + "found, then this will act as new center and another 10-fold CV will "
    + "be performed (kind of hill-climbing). This process is repeated until "
    + "no better pair is found or the best pair is on the border of the grid.\n"
    + "In case the best pair is on the border, one can let GridSearch "
    + "automatically extend the grid and continue the search. Check out the "
    + "properties 'gridIsExtendable' (option '-extend-grid') and "
    // the option takes an argument; its placeholder had gone missing
    + "'maxGridExtensions' (option '-max-grid-extensions <num>').\n\n"
    + "GridSearch can handle doubles, integers (values are just cast to int) "
    + "and booleans (0 is false, otherwise true). float, char and long are "
    + "supported as well.\n\n"
    + "The best classifier setup can be accessed after the buildClassifier "
    + "call via the getBestClassifier methods.\n"
    + "Note: with -num-slots/numExecutionSlots you can specify how many "
    + "setups are evaluated in parallel, taking advantage of multi-cpu/core "
    + "architectures.";
}
/**
 * Builds the specification string of the current classifier: its class
 * name, followed by a blank and its joined options if the classifier is an
 * OptionHandler.
 *
 * @return the classifier string.
 */
@Override
protected String getClassifierSpec() {
  final Classifier current = getClassifier();
  final String className = current.getClass().getName();

  if (!(current instanceof OptionHandler)) {
    return className;
  }

  final String[] options = ((OptionHandler) current).getOptions();
  return className + " " + Utils.joinOptions(options);
}
/**
 * String describing default classifier: the class name from the
 * 'classifier' entry of the loaded properties, or GaussianProcesses if that
 * entry is missing, empty or cannot be split into options.
 *
 * @return the classname of the default classifier
 */
@Override
protected String defaultClassifierString() {
  try {
    if (GRID_SEARCH_PROPS != null) {
      String classifierSpec = GRID_SEARCH_PROPS.getProperty("classifier");
      if (classifierSpec != null && classifierSpec.length() > 0) {
        // tokenize the same way as defaultsFromProps/defaultClassifierOptions
        // do; a plain split(" ") yields an empty class name for
        // specifications that start with whitespace
        String[] parts = Utils.splitOptions(classifierSpec);
        if (parts.length > 0) {
          String name = parts[0].trim();
          if (name.length() > 0) {
            return name;
          }
        }
      }
    }
  } catch (Exception ex) {
    // don't complain here
  }

  return "weka.classifiers.functions.GaussianProcesses";
}
/**
 * String array with default classifier options, taken from the 'classifier'
 * entry of the loaded properties. If that entry consists of more than the
 * class name, the split entry is returned with the class name (first
 * element) blanked out; otherwise an empty array is returned.
 *
 * @return string array with default classifier options
 */
@Override
protected String[] defaultClassifierOptions() {
  try {
    String classifierSpec = (GRID_SEARCH_PROPS != null)
      ? GRID_SEARCH_PROPS.getProperty("classifier") : null;
    if (classifierSpec != null && classifierSpec.length() > 0) {
      String[] tokens = Utils.splitOptions(classifierSpec);
      if (tokens.length > 1) {
        // first token is the class name, not an option
        tokens[0] = "";
        return tokens;
      }
    }
  } catch (Exception ex) {
    // don't complain here
  }

  return new String[0];
}
/**
* Gets an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration
© 2015 - 2025 Weber Informatics LLC | Privacy Policy