weka.classifiers.trees.SimpleCart Maven / Gradle / Ivy
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the stable version. Apart from bugfixes, this version
does not receive any other updates.
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* SimpleCart.java
* Copyright (C) 2007 Haijian Shi
*
*/
package weka.classifiers.trees;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.matrix.Matrix;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
/**
* Class implementing minimal cost-complexity pruning.
* Note when dealing with missing values, use "fractional instances" method instead of surrogate split method.
*
* For more information, see:
*
* Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984). Classification and Regression Trees. Wadsworth International Group, Belmont, California.
*
*
* BibTeX:
*
* @book{Breiman1984,
* address = {Belmont, California},
* author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
* publisher = {Wadsworth International Group},
* title = {Classification and Regression Trees},
* year = {1984}
* }
*
*
*
* Valid options are:
*
* -S <num>
* Random number seed.
* (default 1)
*
* -D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -M <min no>
* The minimal number of instances at the terminal nodes.
* (default 2)
*
* -N <num folds>
* The number of folds used in the minimal cost-complexity pruning.
* (default 5)
*
* -U
* Don't use the minimal cost-complexity pruning.
* (default yes).
*
* -H
* Don't use the heuristic method for binary split.
* (default true).
*
* -A
* Use 1 SE rule to make pruning decision.
* (default no).
*
* -C
* Percentage of training data size (0-1].
* (default 1).
*
*
* @author Haijian Shi ([email protected])
* @version $Revision: 10491 $
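 *
 * Example command line (illustrative; "data.arff" stands in for any ARFF
 * dataset with a nominal class):
 *
 *   java weka.classifiers.trees.SimpleCart -t data.arff -N 5 -M 2
 *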
*/
public class SimpleCart
extends RandomizableClassifier
implements AdditionalMeasureProducer, TechnicalInformationHandler {
/** For serialization. */
private static final long serialVersionUID = 4154189200352566053L;
/** Training data. */
protected Instances m_train;
/** Successor nodes. */
protected SimpleCart[] m_Successors;
/** Attribute used to split data. */
protected Attribute m_Attribute;
/** Split point for a numeric attribute. */
protected double m_SplitValue;
/** Split subset used to split data for nominal attributes. */
protected String m_SplitString;
/** Class value if the node is leaf. */
protected double m_ClassValue;
/** Class attribute of data. */
protected Attribute m_ClassAttribute;
/** Minimum number of instances at the terminal nodes. */
protected double m_minNumObj = 2;
/** Number of folds for minimal cost-complexity pruning. */
protected int m_numFoldsPruning = 5;
/** Alpha-value (for pruning) at the node. */
protected double m_Alpha;
/** Number of training examples misclassified when the node is replaced by a leaf. */
protected double m_numIncorrectModel;
/** Number of training examples misclassified by the subtree rooted at the node. */
protected double m_numIncorrectTree;
/** Indicate if the node is a leaf node. */
protected boolean m_isLeaf;
/** Whether to use minimal cost-complexity pruning. */
protected boolean m_Prune = true;
/** Total number of instances used to build the classifier. */
protected int m_totalTrainInstances;
/** Proportion for each branch. */
protected double[] m_Props;
/** Class probabilities. */
protected double[] m_ClassProbs = null;
/** Distributions of leaf node (or temporary leaf node in minimal cost-complexity pruning) */
protected double[] m_Distribution;
/** Whether to use heuristic search for nominal attributes in multi-class problems (default true). */
protected boolean m_Heuristic = true;
/** Whether to use the 1SE rule to select the final tree. */
protected boolean m_UseOneSE = false;
/** Training data size. */
protected double m_SizePer = 1;
/**
* Return a description suitable for displaying in the explorer/experimenter.
*
* @return a description suitable for displaying in the
* explorer/experimenter
*/
public String globalInfo() {
return
"Class implementing minimal cost-complexity pruning.\n"
+ "Note when dealing with missing values, use \"fractional "
+ "instances\" method instead of surrogate split method.\n\n"
+ "For more information, see:\n\n"
+ getTechnicalInformation().toString();
}
/**
* Returns an instance of a TechnicalInformation object, containing
* detailed information about the technical background of this class,
* e.g., paper reference or book this class is based on.
*
* @return the technical information about this class
*/
public TechnicalInformation getTechnicalInformation() {
TechnicalInformation result;
result = new TechnicalInformation(Type.BOOK);
result.setValue(Field.AUTHOR, "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
result.setValue(Field.YEAR, "1984");
result.setValue(Field.TITLE, "Classification and Regression Trees");
result.setValue(Field.PUBLISHER, "Wadsworth International Group");
result.setValue(Field.ADDRESS, "Belmont, California");
return result;
}
/**
* Returns default capabilities of the classifier.
*
* @return the capabilities of this classifier
*/
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
result.disableAll();
// attributes
result.enable(Capability.NOMINAL_ATTRIBUTES);
result.enable(Capability.NUMERIC_ATTRIBUTES);
result.enable(Capability.MISSING_VALUES);
// class
result.enable(Capability.NOMINAL_CLASS);
return result;
}
/**
* Build the classifier.
*
* @param data the training instances
* @throws Exception if something goes wrong
*/
public void buildClassifier(Instances data) throws Exception {
getCapabilities().testWithFail(data);
data = new Instances(data);
data.deleteWithMissingClass();
// unpruned CART decision tree
if (!m_Prune) {
// calculate sorted indices and weights, and compute initial class counts.
int[][] sortedIndices = new int[data.numAttributes()][0];
double[][] weights = new double[data.numAttributes()][0];
double[] classProbs = new double[data.numClasses()];
double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);
makeTree(data, data.numInstances(),sortedIndices,weights,classProbs,
totalWeight,m_minNumObj, m_Heuristic);
return;
}
Random random = new Random(m_Seed);
Instances cvData = new Instances(data);
cvData.randomize(random);
cvData = new Instances(cvData,0,(int)(cvData.numInstances()*m_SizePer)-1);
cvData.stratify(m_numFoldsPruning);
double[][] alphas = new double[m_numFoldsPruning][];
double[][] errors = new double[m_numFoldsPruning][];
// calculate errors and alphas for each fold
for (int i = 0; i < m_numFoldsPruning; i++) {
//for every fold, grow tree on training set and fix error on test set.
Instances train = cvData.trainCV(m_numFoldsPruning, i);
Instances test = cvData.testCV(m_numFoldsPruning, i);
// calculate sorted indices and weights, and compute initial class counts for each fold
int[][] sortedIndices = new int[train.numAttributes()][0];
double[][] weights = new double[train.numAttributes()][0];
double[] classProbs = new double[train.numClasses()];
double totalWeight = computeSortedInfo(train,sortedIndices, weights,classProbs);
makeTree(train, train.numInstances(),sortedIndices,weights,classProbs,
totalWeight,m_minNumObj, m_Heuristic);
int numNodes = numInnerNodes();
alphas[i] = new double[numNodes + 2];
errors[i] = new double[numNodes + 2];
// prune back and log alpha-values and errors on test set
prune(alphas[i], errors[i], test);
}
// calculate sorted indices and weights, and compute initial class counts on all training instances
int[][] sortedIndices = new int[data.numAttributes()][0];
double[][] weights = new double[data.numAttributes()][0];
double[] classProbs = new double[data.numClasses()];
double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);
//build tree using all the data
makeTree(data, data.numInstances(),sortedIndices,weights,classProbs,
totalWeight,m_minNumObj, m_Heuristic);
int numNodes = numInnerNodes();
double[] treeAlphas = new double[numNodes + 2];
// prune back and log alpha-values
int iterations = prune(treeAlphas, null, null);
double[] treeErrors = new double[numNodes + 2];
// for each pruned subtree, find the cross-validated error
for (int i = 0; i <= iterations; i++){
//compute midpoint alphas
double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i+1]);
double error = 0;
for (int k = 0; k < m_numFoldsPruning; k++) {
int l = 0;
while (alphas[k][l] <= alpha) l++;
error += errors[k][l - 1];
}
treeErrors[i] = error/m_numFoldsPruning;
}
// find best alpha
int best = -1;
double bestError = Double.MAX_VALUE;
for (int i = iterations; i >= 0; i--) {
if (treeErrors[i] < bestError) {
bestError = treeErrors[i];
best = i;
}
}
// 1 SE rule: choose the smallest tree within one standard error of the best error
if (m_UseOneSE) {
double oneSE = Math.sqrt(bestError*(1-bestError)/(data.numInstances()));
for (int i = iterations; i >= 0; i--) {
if (treeErrors[i] <= bestError+oneSE) {
best = i;
break;
}
}
}
double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);
//"unprune" final tree (faster than regrowing it)
unprune();
prune(bestAlpha);
}
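  // To make the midpoint-alpha step above concrete (hypothetical numbers): if
  // pruning the full tree produced treeAlphas = {0, 0.01, 0.04, 1.0}, the
  // candidate alphas tested against the cross-validation folds are
  // sqrt(0*0.01) = 0, sqrt(0.01*0.04) = 0.02 and sqrt(0.04*1.0) = 0.2; the
  // candidate with the smallest averaged fold error (or the simplest tree
  // within one SE of it, if -A is set) determines the final pruning level.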
/**
* Make binary decision tree recursively.
*
* @param data the training instances
* @param totalInstances total number of instances
* @param sortedIndices sorted indices of the instances
* @param weights weights of the instances
* @param classProbs class probabilities
* @param totalWeight total weight of instances
* @param minNumObj minimal number of instances at leaf nodes
* @param useHeuristic if use heuristic search for nominal attributes in multi-class problem
* @throws Exception if something goes wrong
*/
protected void makeTree(Instances data, int totalInstances, int[][] sortedIndices,
double[][] weights, double[] classProbs, double totalWeight, double minNumObj,
boolean useHeuristic) throws Exception{
// if no instances have reached this node (normally won't happen)
if (totalWeight == 0){
m_Attribute = null;
m_ClassValue = Instance.missingValue();
m_Distribution = new double[data.numClasses()];
return;
}
m_totalTrainInstances = totalInstances;
m_isLeaf = true;
m_Successors = null;
m_ClassProbs = new double[classProbs.length];
m_Distribution = new double[classProbs.length];
System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
if (Utils.sum(m_ClassProbs)!=0) Utils.normalize(m_ClassProbs);
// Compute class distributions and value of splitting
// criterion for each attribute
double[][][] dists = new double[data.numAttributes()][0][0];
double[][] props = new double[data.numAttributes()][0];
double[][] totalSubsetWeights = new double[data.numAttributes()][2];
double[] splits = new double[data.numAttributes()];
String[] splitString = new String[data.numAttributes()];
double[] giniGains = new double[data.numAttributes()];
// for each attribute find split information
for (int i = 0; i < data.numAttributes(); i++) {
Attribute att = data.attribute(i);
if (i==data.classIndex()) continue;
if (att.isNumeric()) {
// numeric attribute
splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
weights[i], totalSubsetWeights, giniGains, data);
} else {
// nominal attribute
splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
weights[i], totalSubsetWeights, giniGains, data, useHeuristic);
}
}
// Find best attribute (split with maximum Gini gain)
int attIndex = Utils.maxIndex(giniGains);
m_Attribute = data.attribute(attIndex);
m_train = new Instances(data, sortedIndices[attIndex].length);
    for (int i = 0; i < sortedIndices[attIndex].length; i++) {
      m_train.add(data.instance(sortedIndices[attIndex][i]));
    }

    // if the best split does not reduce impurity, or the node contains too
    // few instances, make this node a leaf
    if (totalWeight < 2 * minNumObj || giniGains[attIndex] == 0) {
      makeLeaf(data);
    } else {
      m_Props = props[attIndex];
      m_SplitValue = splits[attIndex];
      m_SplitString = splitString[attIndex];

      // split the data into two subsets based on the best attribute
      int[][][] subsetIndices = new int[2][data.numAttributes()][0];
      double[][][] subsetWeights = new double[2][data.numAttributes()][0];
      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,
          m_SplitString, sortedIndices, weights, data);

      // if the split results in a branch with fewer than the minimal number
      // of instances, make this node a leaf
      if (subsetIndices[0][attIndex].length < minNumObj ||
          subsetIndices[1][attIndex].length < minNumObj) {
        makeLeaf(data);
      } else {
        // otherwise grow the two successor nodes recursively
        m_isLeaf = false;
        m_Successors = new SimpleCart[2];
        for (int i = 0; i < 2; i++) {
          m_Successors[i] = new SimpleCart();
          m_Successors[i].makeTree(data, m_totalTrainInstances, subsetIndices[i],
              subsetWeights[i], dists[attIndex][i], totalSubsetWeights[attIndex][i],
              minNumObj, useHeuristic);
        }
      }
    }
  }

  /**
   * Prunes the original tree using the CART pruning scheme, given a
   * cost-complexity parameter alpha.
   *
   * @param alpha the cost-complexity parameter
   * @throws Exception if something goes wrong
   */
  public void prune(double alpha) throws Exception {
    Vector nodeList;
    // determine training error of subtrees (both with and without replacing
    // a subtree), and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();
    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();
    boolean prune = (nodeList.size() > 0);
double preAlpha = Double.MAX_VALUE;
while (prune) {
// select node with minimum alpha
SimpleCart nodeToPrune = nodeToPrune(nodeList);
// stop when the smallest alpha in the tree exceeds the target alpha
if (nodeToPrune.m_Alpha > alpha) {
break;
}
nodeToPrune.makeLeaf(nodeToPrune.m_train);
// normally would not happen
if (nodeToPrune.m_Alpha==preAlpha) {
nodeToPrune.makeLeaf(nodeToPrune.m_train);
treeErrors();
calculateAlphas();
nodeList = getInnerNodes();
prune = (nodeList.size() > 0);
continue;
}
preAlpha = nodeToPrune.m_Alpha;
//update tree errors and alphas
treeErrors();
calculateAlphas();
nodeList = getInnerNodes();
prune = (nodeList.size() > 0);
}
}
/**
* Method for performing one fold in the cross-validation of minimal
* cost-complexity pruning. Generates a sequence of alpha-values with error
* estimates for the corresponding (partially pruned) trees, given the test
* set of that fold.
*
* @param alphas array to hold the generated alpha-values
* @param errors array to hold the corresponding error estimates
* @param test test set of that fold (to obtain error estimates)
* @return the iteration of the pruning
* @throws Exception if something goes wrong
*/
public int prune(double[] alphas, double[] errors, Instances test)
throws Exception {
Vector nodeList;
// determine training error of subtrees (both with and without replacing a subtree),
// and calculate alpha-values from them
modelErrors();
treeErrors();
calculateAlphas();
// get list of all inner nodes in the tree
nodeList = getInnerNodes();
boolean prune = (nodeList.size() > 0);
//alpha_0 is always zero (unpruned tree)
alphas[0] = 0;
Evaluation eval;
// error of unpruned tree
if (errors != null) {
eval = new Evaluation(test);
eval.evaluateModel(this, test);
errors[0] = eval.errorRate();
}
int iteration = 0;
double preAlpha = Double.MAX_VALUE;
while (prune) {
iteration++;
// get node with minimum alpha
SimpleCart nodeToPrune = nodeToPrune(nodeList);
// do not set m_sons null, want to unprune
nodeToPrune.m_isLeaf = true;
// normally would not happen
if (nodeToPrune.m_Alpha==preAlpha) {
iteration--;
treeErrors();
calculateAlphas();
nodeList = getInnerNodes();
prune = (nodeList.size() > 0);
continue;
}
// get alpha-value of node
alphas[iteration] = nodeToPrune.m_Alpha;
// log error
if (errors != null) {
eval = new Evaluation(test);
eval.evaluateModel(this, test);
errors[iteration] = eval.errorRate();
}
preAlpha = nodeToPrune.m_Alpha;
//update errors/alphas
treeErrors();
calculateAlphas();
nodeList = getInnerNodes();
prune = (nodeList.size() > 0);
}
//set last alpha 1 to indicate end
alphas[iteration + 1] = 1.0;
return iteration;
}
/**
* Method to "unprune" the CART tree. Sets all leaf-fields to false.
* Faster than re-growing the tree because the model does not have to be fit again.
*/
protected void unprune() {
if (m_Successors != null) {
m_isLeaf = false;
for (int i = 0; i < m_Successors.length; i++) m_Successors[i].unprune();
}
}
/**
* Compute distributions, proportions and total weights of two successor
* nodes for a given numeric attribute.
*
* @param props proportions of each two branches for each attribute
* @param dists class distributions of two branches for each attribute
* @param att numeric attribute to split on
* @param sortedIndices sorted indices of instances for the attribute
* @param weights weights of instances for the attribute
* @param subsetWeights total weight of two branches split based on the attribute
* @param giniGains Gini gains for each attribute
* @param data training instances
* @return the best split point for the given numeric attribute
* @throws Exception if something goes wrong
*/
protected double numericDistribution(double[][] props, double[][][] dists,
Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
double[] giniGains, Instances data)
throws Exception {
double splitPoint = Double.NaN;
double[][] dist = null;
int numClasses = data.numClasses();
int i; // loop index; instances with missing values sit at the end of sortedIndices
double[][] currDist = new double[2][numClasses];
dist = new double[2][numClasses];
// Move all instances without missing values into second subset
double[] parentDist = new double[numClasses];
int missingStart = 0;
for (int j = 0; j < sortedIndices.length; j++) {
Instance inst = data.instance(sortedIndices[j]);
if (!inst.isMissing(att)) {
missingStart ++;
currDist[1][(int)inst.classValue()] += weights[j];
}
parentDist[(int)inst.classValue()] += weights[j];
}
System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);
// Try all possible split points
double currSplit = data.instance(sortedIndices[0]).value(att);
double currGiniGain;
double bestGiniGain = -Double.MAX_VALUE;
for (i = 0; i < sortedIndices.length; i++) {
Instance inst = data.instance(sortedIndices[i]);
if (inst.isMissing(att)) {
break;
}
if (inst.value(att) > currSplit) {
double[][] tempDist = new double[2][numClasses];
for (int k=0; k<2; k++) {
//tempDist[k] = currDist[k];
System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
}
double[] tempProps = new double[2];
for (int k=0; k<2; k++) {
tempProps[k] = Utils.sum(tempDist[k]);
}
if (Utils.sum(tempProps) !=0) Utils.normalize(tempProps);
// split missing values
int index = missingStart;
while (index < sortedIndices.length) {
Instance insta = data.instance(sortedIndices[index]);
for (int j = 0; j < 2; j++) {
tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
}
index++;
}
currGiniGain = computeGiniGain(parentDist,tempDist);
if (currGiniGain > bestGiniGain) {
bestGiniGain = currGiniGain;
// clean split point
// splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0;
splitPoint = (inst.value(att) + currSplit) / 2.0;
for (int j = 0; j < currDist.length; j++) {
System.arraycopy(tempDist[j], 0, dist[j], 0,
dist[j].length);
}
}
}
currSplit = inst.value(att);
currDist[0][(int)inst.classValue()] += weights[i];
currDist[1][(int)inst.classValue()] -= weights[i];
}
// Compute weights
int attIndex = att.index();
props[attIndex] = new double[2];
for (int k = 0; k < 2; k++) {
props[attIndex][k] = Utils.sum(dist[k]);
}
if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);
// Compute subset weights
subsetWeights[attIndex] = new double[2];
for (int j = 0; j < 2; j++) {
subsetWeights[attIndex][j] += Utils.sum(dist[j]);
}
// clean Gini gain
//giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
giniGains[attIndex] = bestGiniGain;
dists[attIndex] = dist;
return splitPoint;
}
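  // Illustration (hypothetical data): for a numeric attribute with sorted
  // values 1.0 and 2.0 in class A and 3.0 in class B, the largest Gini gain
  // occurs at the boundary between 2.0 and 3.0, so the method would return the
  // midpoint splitPoint = (2.0 + 3.0) / 2.0 = 2.5.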
/**
* Compute distributions, proportions and total weights of two successor
* nodes for a given nominal attribute.
*
* @param props proportions of each two branches for each attribute
* @param dists class distributions of two branches for each attribute
* @param att nominal attribute to split on
* @param sortedIndices sorted indices of instances for the attribute
* @param weights weights of instances for the attribute
* @param subsetWeights total weight of two branches split based on the attribute
* @param giniGains Gini gains for each attribute
* @param data training instances
* @param useHeuristic if use heuristic search
* @return the best split subset (as a string) for the given nominal attribute
* @throws Exception if something goes wrong
*/
protected String nominalDistribution(double[][] props, double[][][] dists,
Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
double[] giniGains, Instances data, boolean useHeuristic)
throws Exception {
String[] values = new String[att.numValues()];
int numCat = values.length; // number of values of the attribute
int numClasses = data.numClasses();
String bestSplitString = "";
double bestGiniGain = -Double.MAX_VALUE;
// class frequency for each value
int[] classFreq = new int[numCat];
    for (int j = 0; j < numCat; j++) classFreq[j] = 0;

    double[] parentDist = new double[numClasses];
    double[][] currDist = new double[2][numClasses];
    double[][] dist = new double[2][numClasses];
    int missingStart = 0;

    for (int i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (!inst.isMissing(att)) {
        missingStart++;
        classFreq[(int)inst.value(att)]++;
      }
      parentDist[(int)inst.classValue()] += weights[i];
    }

    // number of attribute values with non-zero class frequency
    int nonEmpty = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] != 0) nonEmpty++;
    }

    // attribute values with non-zero class frequency
    String[] nonEmptyValues = new String[nonEmpty];
    int nonEmptyIndex = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] != 0) {
        nonEmptyValues[nonEmptyIndex] = att.value(j);
        nonEmptyIndex++;
      }
    }

    // attribute values with zero class frequency
    int empty = numCat - nonEmpty;
    String[] emptyValues = new String[empty];
    int emptyIndex = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] == 0) {
        emptyValues[emptyIndex] = att.value(j);
        emptyIndex++;
      }
    }

    // a split is impossible if only one value actually occurs
    if (nonEmpty <= 1) {
      giniGains[att.index()] = 0;
      return "";
    }

    // two-class problems: sort the attribute values by the probability of the
    // first class, then only consider splits along that ordering
    if (data.numClasses() == 2) {
      // probability of class 0 for each attribute value
      double[] pClass0 = new double[nonEmpty];
      // class distribution for each attribute value
      double[][] valDist = new double[nonEmpty][2];

      for (int i = 0; i < sortedIndices.length; i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }
        for (int j = 0; j < nonEmpty; j++) {
          if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j]) == 0) {
            valDist[j][(int)inst.classValue()] += inst.weight();
            break;
          }
        }
      }

      for (int j = 0; j < nonEmpty; j++) {
        double distSum = Utils.sum(valDist[j]);
        if (distSum == 0) pClass0[j] = 0;
        else pClass0[j] = valDist[j][0] / distSum;
      }

      // sort attribute values according to the probability of class 0
      String[] sortedValues = new String[nonEmpty];
      for (int j = 0; j < nonEmpty; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
      }

      // find the subset of attribute values that maximizes the Gini gain
      String tempStr = "";
      for (int j = 0; j < nonEmpty - 1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr == "") tempStr = "(" + sortedValues[j] + ")";
        else tempStr += "|" + "(" + sortedValues[j] + ")";

        for (int i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }
          if (tempStr.indexOf("(" + att.value((int)inst.value(att)) + ")") != -1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else currDist[1][(int)inst.classValue()] += weights[i];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk = 0; kk < 2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk = 0; kk < 2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split the weight of instances with missing values between the branches
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[index];
          }
          index++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
bestGiniGain = currGiniGain;
bestSplitString = tempStr;
for (int jj = 0; jj < 2; jj++) {
//dist[jj] = new double[currDist[jj].length];
System.arraycopy(tempDist[jj], 0, dist[jj], 0,
dist[jj].length);
}
}
}
}
// multi-class problems - exhaustive search
else if (!useHeuristic || nonEmpty<=4) {
// first, consider only the attribute values whose class frequency is not zero
for (int i=0; i<(int)Math.pow(2,nonEmpty-1); i++) {
String tempStr="";
currDist = new double[2][numClasses];
int mod;
int bit10 = i;
for (int j=nonEmpty-1; j>=0; j--) {
mod = bit10%2; // extract the next binary digit of i (each bit selects one attribute value)
if (mod==1) {
if (tempStr=="") tempStr = "("+nonEmptyValues[j]+")";
else tempStr += "|" + "("+nonEmptyValues[j]+")";
}
bit10 = bit10/2;
}
        for (int j = 0; j < sortedIndices.length; j++) {
          Instance inst = data.instance(sortedIndices[j]);
          if (inst.isMissing(att)) {
            break;
          }
          if (tempStr.indexOf("(" + att.value((int)inst.value(att)) + ")") != -1) {
            currDist[0][(int)inst.classValue()] += weights[j];
          } else currDist[1][(int)inst.classValue()] += weights[j];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int k = 0; k < 2; k++) {
          tempDist[k] = currDist[k];
        }

        double[] tempProps = new double[2];
        for (int k = 0; k < 2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split the weight of instances with missing values between the branches
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain>bestGiniGain) {
bestGiniGain = currGiniGain;
bestSplitString = tempStr;
for (int j = 0; j < 2; j++) {
//dist[jj] = new double[currDist[jj].length];
System.arraycopy(tempDist[j], 0, dist[j], 0,
dist[j].length);
}
}
}
}
// heuristic search for multi-class problems
else {
// first, consider only the attribute values whose class frequency is not zero
int n = nonEmpty;
int k = data.numClasses(); // number of classes of the data
double[][] P = new double[n][k]; // class probability matrix
int[] numInstancesValue = new int[n]; // number of instances for an attribute value
double[] meanClass = new double[k]; // vector of mean class probability
int numInstances = data.numInstances(); // total number of instances
// initialize the vector of mean class probability
      for (int j = 0; j < meanClass.length; j++) meanClass[j] = 0;

      for (int j = 0; j < numInstances; j++) {
        Instance inst = data.instance(j);
        if (inst.isMissing(att)) continue; // skip instances with missing attribute values
        int valueIndex = 0; // attribute value index in nonEmptyValues
        for (int i = 0; i < nonEmpty; i++) {
          if (att.value((int)inst.value(att)).compareToIgnoreCase(nonEmptyValues[i]) == 0) {
            valueIndex = i;
            break;
          }
        }
        P[valueIndex][(int)inst.classValue()]++;
        numInstancesValue[valueIndex]++;
        meanClass[(int)inst.classValue()]++;
      }

      // calculate the class probability matrix
      for (int i = 0; i < P.length; i++) {
        for (int j = 0; j < P[0].length; j++) {
          if (numInstancesValue[i] == 0) P[i][j] = 0;
          else P[i][j] /= numInstancesValue[i];
        }
      }

      for (int i = 0; i < meanClass.length; i++) meanClass[i] /= numInstances;

      // calculate the (weighted) covariance matrix of the class probabilities
      double[][] covariance = new double[k][k];
      for (int i1 = 0; i1 < k; i1++) {
        for (int i2 = 0; i2 < k; i2++) {
          double element = 0;
          for (int j = 0; j < n; j++) {
            element += (P[j][i2] - meanClass[i2]) * (P[j][i1] - meanClass[i1])
              * numInstancesValue[j];
          }
          covariance[i1][i2] = element;
        }
      }

      Matrix matrix = new Matrix(covariance);
      weka.core.matrix.EigenvalueDecomposition eigen = matrix.eig();
      double[] eigenValues = eigen.getRealEigenvalues();

      // find index of the largest eigenvalue
      int index = 0;
      double largest = eigenValues[0];
      for (int i = 1; i < eigenValues.length; i++) {
        if (eigenValues[i] > largest) {
index=i;
largest = eigenValues[i];
}
}
// calculate the first principal component
double[] FPC = new double[k];
Matrix eigenVector = eigen.getV();
double[][] vectorArray = eigenVector.getArray();
      for (int i = 0; i < FPC.length; i++) {
        FPC[i] = vectorArray[i][index];
      }

      // calculate the first principal component score for each attribute value
      double[] Sa = new double[n];
      for (int i = 0; i < n; i++) {
        Sa[i] = 0;
        for (int j = 0; j < k; j++) {
          Sa[i] += FPC[j] * P[i][j];
        }
      }

      // sort attribute values according to their first principal component score
      double[] pCopy = new double[n];
      System.arraycopy(Sa, 0, pCopy, 0, n);
      String[] sortedValues = new String[n];
      Arrays.sort(Sa);

      for (int j = 0; j < n; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
        pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
      }

      // only consider splits along the sorted ordering of the values
      String tempStr = "";
      for (int j = 0; j < nonEmpty - 1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr == "") tempStr = "(" + sortedValues[j] + ")";
        else tempStr += "|" + "(" + sortedValues[j] + ")";

        for (int i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }
          if (tempStr.indexOf("(" + att.value((int)inst.value(att)) + ")") != -1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else currDist[1][(int)inst.classValue()] += weights[i];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk = 0; kk < 2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk = 0; kk < 2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split the weight of instances with missing values between the branches
        int index2 = missingStart;
        while (index2 < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index2]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[index2];
          }
          index2++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain>bestGiniGain) {
bestGiniGain = currGiniGain;
bestSplitString = tempStr;
for (int jj = 0; jj < 2; jj++) {
//dist[jj] = new double[currDist[jj].length];
System.arraycopy(tempDist[jj], 0, dist[jj], 0,
dist[jj].length);
}
}
}
}
// Compute weights
int attIndex = att.index();
props[attIndex] = new double[2];
for (int k = 0; k < 2; k++) {
props[attIndex][k] = Utils.sum(dist[k]);
}
if (!(Utils.sum(props[attIndex]) > 0)) {
for (int k = 0; k < props[attIndex].length; k++) {
props[attIndex][k] = 1.0 / (double)props[attIndex].length;
}
} else {
Utils.normalize(props[attIndex]);
}
// Compute subset weights
subsetWeights[attIndex] = new double[2];
for (int j = 0; j < 2; j++) {
subsetWeights[attIndex][j] += Utils.sum(dist[j]);
}
    // finally, assign the attribute values with zero class frequency to the
    // branch with the larger proportion of training weight
    for (int j = 0; j < empty; j++) {
      if (props[attIndex][0] >= props[attIndex][1]) {
if (bestSplitString=="") bestSplitString = "(" + emptyValues[j] + ")";
else bestSplitString += "|" + "(" + emptyValues[j] + ")";
}
}
// clean Gini gain for the attribute
//giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
giniGains[attIndex] = bestGiniGain;
dists[attIndex] = dist;
return bestSplitString;
}
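  // Illustration (hypothetical attribute): for a nominal attribute "color"
  // with values {red, green, blue}, a returned split string of "(red)|(blue)"
  // sends instances with color red or blue down the first branch and all
  // remaining values down the second (see the indexOf test in splitData below).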
/**
* Split data into two subsets and store sorted indices and weights for two
* successor nodes.
*
* @param subsetIndices sorted indices of instances for each attribute
* for the two successor nodes
* @param subsetWeights weights of instances for each attribute for
* the two successor nodes
* @param att attribute the split is based on
* @param splitPoint split point of the split if att is numeric
* @param splitStr split subset of the split if att is nominal
* @param sortedIndices sorted indices of the instances to be split
* @param weights weights of the instances to be split
* @param data training data
* @throws Exception if something goes wrong
*/
protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
double[][] weights, Instances data) throws Exception {
int j;
// For each attribute
for (int i = 0; i < data.numAttributes(); i++) {
if (i==data.classIndex()) continue;
int[] num = new int[2];
for (int k = 0; k < 2; k++) {
subsetIndices[k][i] = new int[sortedIndices[i].length];
subsetWeights[k][i] = new double[weights[i].length];
}
for (j = 0; j < sortedIndices[i].length; j++) {
Instance inst = data.instance(sortedIndices[i][j]);
if (inst.isMissing(att)) {
// Split instance up
for (int k = 0; k < 2; k++) {
if (m_Props[k] > 0) {
subsetIndices[k][i][num[k]] = sortedIndices[i][j];
subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
num[k]++;
}
}
} else {
int subset;
if (att.isNumeric()) {
subset = (inst.value(att) < splitPoint) ? 0 : 1;
} else { // nominal attribute
if (splitStr.indexOf
("(" + att.value((int)inst.value(att.index()))+")")!=-1) {
subset = 0;
} else subset = 1;
}
subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
subsetWeights[subset][i][num[subset]] = weights[i][j];
num[subset]++;
}
}
// Trim arrays
for (int k = 0; k < 2; k++) {
int[] copy = new int[num[k]];
System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
subsetIndices[k][i] = copy;
double[] copyWeights = new double[num[k]];
System.arraycopy(subsetWeights[k][i], 0 ,copyWeights, 0, num[k]);
subsetWeights[k][i] = copyWeights;
}
}
}
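  // Note on the "fractional instances" method: an instance whose value for the
  // split attribute is missing is sent down both branches, with its weight
  // multiplied by the branch proportions. For example (hypothetical numbers),
  // with m_Props = {0.7, 0.3} an instance of weight 1.0 contributes weight 0.7
  // to the left subset and 0.3 to the right one.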
/**
* Updates the numIncorrectModel field for all nodes, i.e., the training error
* when the subtree rooted at a node is replaced by a leaf. This is needed for
* calculating the alpha-values.
*
* @throws Exception if something goes wrong
*/
public void modelErrors() throws Exception{
Evaluation eval = new Evaluation(m_train);
if (!m_isLeaf) {
m_isLeaf = true; //temporarily make leaf
// calculate distribution for evaluation
eval.evaluateModel(this, m_train);
m_numIncorrectModel = eval.incorrect();
m_isLeaf = false;
for (int i = 0; i < m_Successors.length; i++)
m_Successors[i].modelErrors();
} else {
eval.evaluateModel(this, m_train);
m_numIncorrectModel = eval.incorrect();
}
}
/**
* Updates the numIncorrectTree field for all nodes. This is needed for
* calculating the alpha-values.
*
* @throws Exception if something goes wrong
*/
public void treeErrors() throws Exception {
if (m_isLeaf) {
m_numIncorrectTree = m_numIncorrectModel;
} else {
m_numIncorrectTree = 0;
for (int i = 0; i < m_Successors.length; i++) {
m_Successors[i].treeErrors();
m_numIncorrectTree += m_Successors[i].m_numIncorrectTree;
}
}
}
/**
* Updates the alpha field for all nodes.
*
* @throws Exception if something goes wrong
*/
public void calculateAlphas() throws Exception {
if (!m_isLeaf) {
double errorDiff = m_numIncorrectModel - m_numIncorrectTree;
if (errorDiff <=0) {
//split increases training error (should not normally happen).
//prune it instantly.
makeLeaf(m_train);
m_Alpha = Double.MAX_VALUE;
} else {
//compute alpha
errorDiff /= m_totalTrainInstances;
m_Alpha = errorDiff / (double)(numLeaves() - 1);
long alphaLong = Math.round(m_Alpha*Math.pow(10,10));
m_Alpha = (double)alphaLong/Math.pow(10,10);
for (int i = 0; i < m_Successors.length; i++) {
m_Successors[i].calculateAlphas();
}
}
} else {
//alpha = infinite for leaves (do not want to prune)
m_Alpha = Double.MAX_VALUE;
}
}
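  // Worked example (hypothetical numbers): if collapsing a node to a leaf would
  // misclassify 10 training instances, its subtree misclassifies 4, the tree
  // was built from 100 instances and the subtree has 3 leaves, then
  // alpha = ((10 - 4) / 100.0) / (3 - 1) = 0.03: the cost-complexity price per
  // saved leaf at which pruning this node becomes worthwhile.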
/**
* Find the node with minimal alpha value. If two nodes have the same alpha,
* choose the one with more leaf nodes.
*
* @param nodeList list of inner nodes
* @return the node to be pruned
*/
protected SimpleCart nodeToPrune(Vector nodeList) {
if (nodeList.size()==0) return null;
if (nodeList.size()==1) return (SimpleCart)nodeList.elementAt(0);
SimpleCart returnNode = (SimpleCart)nodeList.elementAt(0);
double baseAlpha = returnNode.m_Alpha;
    for (int i = 1; i < nodeList.size(); i++) {
      SimpleCart node = (SimpleCart)nodeList.elementAt(i);
      if (node.m_Alpha < baseAlpha) {
        baseAlpha = node.m_Alpha;
        returnNode = node;
      } else if (node.m_Alpha == baseAlpha) { // break tie
        if (node.numLeaves() > returnNode.numLeaves()) {
returnNode = node;
}
}
}
return returnNode;
}
/**
* Compute sorted indices, weights and class probabilities for a given
* dataset. Return total weights of the data at the node.
*
* @param data training data
* @param sortedIndices sorted indices of instances at the node
* @param weights weights of instances at the node
* @param classProbs class probabilities at the node
* @return total weights of instances at the node
* @throws Exception if something goes wrong
*/
protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights,
double[] classProbs) throws Exception {
// Create array of sorted indices and weights
double[] vals = new double[data.numInstances()];
for (int j = 0; j < data.numAttributes(); j++) {
if (j==data.classIndex()) continue;
weights[j] = new double[data.numInstances()];
if (data.attribute(j).isNominal()) {
// Handling nominal attributes. Putting indices of
// instances with missing values at the end.
sortedIndices[j] = new int[data.numInstances()];
int count = 0;
for (int i = 0; i < data.numInstances(); i++) {
Instance inst = data.instance(i);
if (!inst.isMissing(j)) {
sortedIndices[j][count] = i;
weights[j][count] = inst.weight();
count++;
}
}
for (int i = 0; i < data.numInstances(); i++) {
Instance inst = data.instance(i);
if (inst.isMissing(j)) {
sortedIndices[j][count] = i;
weights[j][count] = inst.weight();
count++;
}
}
} else {
// Sorted indices are computed for numeric attributes
// missing values instances are put to end
for (int i = 0; i < data.numInstances(); i++) {
Instance inst = data.instance(i);
vals[i] = inst.value(j);
}
sortedIndices[j] = Utils.sort(vals);
for (int i = 0; i < data.numInstances(); i++) {
weights[j][i] = data.instance(sortedIndices[j][i]).weight();
}
}
}
// Compute initial class counts
double totalWeight = 0;
for (int i = 0; i < data.numInstances(); i++) {
Instance inst = data.instance(i);
classProbs[(int)inst.classValue()] += inst.weight();
totalWeight += inst.weight();
}
return totalWeight;
}
/**
* Compute and return gini gain for given distributions of a node and its
* successor nodes.
*
* @param parentDist class distributions of parent node
* @param childDist class distributions of successor nodes
* @return Gini gain computed
*/
protected double computeGiniGain(double[] parentDist, double[][] childDist) {
double totalWeight = Utils.sum(parentDist);
if (totalWeight==0) return 0;
double leftWeight = Utils.sum(childDist[0]);
double rightWeight = Utils.sum(childDist[1]);
double parentGini = computeGini(parentDist, totalWeight);
double leftGini = computeGini(childDist[0],leftWeight);
double rightGini = computeGini(childDist[1], rightWeight);
return parentGini - leftWeight/totalWeight*leftGini -
rightWeight/totalWeight*rightGini;
}
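  // Worked example (hypothetical counts): parentDist = {6, 4} gives
  // Gini = 1 - (0.6^2 + 0.4^2) = 0.48; children {4, 0} (Gini 0) and {2, 4}
  // (Gini = 1 - (1/9 + 4/9) = 4/9) yield a gain of
  // 0.48 - 0.4*0 - 0.6*(4/9) = 0.2133...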
/**
* Compute and return gini index for a given distribution of a node.
*
* @param dist class distributions
* @param total total weight of the class distribution
* @return Gini index of the class distributions
*/
protected double computeGini(double[] dist, double total) {
if (total==0) return 0;
double val = 0;
    for (int i = 0; i < dist.length; i++) {
      val += (dist[i]/total) * (dist[i]/total);
    }
    return 1 - val;
  }

  /**
   * Computes class probabilities for the instance using the decision tree.
   *
   * @param instance the instance for which class probabilities are to be computed
   * @return the class probabilities for the given instance
   * @throws Exception if something goes wrong
   */
  public double[] distributionForInstance(Instance instance)
    throws Exception {
    if (!m_isLeaf) {
      // value of split attribute is missing: pass the instance down both
      // branches, weighted by the branch proportions
      if (instance.isMissing(m_Attribute)) {
        double[] returnedDist = new double[m_ClassProbs.length];
        for (int i = 0; i < m_Successors.length; i++) {
          double[] help = m_Successors[i].distributionForInstance(instance);
          if (help != null) {
            for (int j = 0; j < help.length; j++) {
              returnedDist[j] += m_Props[i] * help[j];
            }
          }
        }
        return returnedDist;
      }
      else if (m_Attribute.isNominal()) {
        // split attribute is nominal
        if (m_SplitString.indexOf("(" +
            m_Attribute.value((int)instance.value(m_Attribute)) + ")") != -1)
          return m_Successors[0].distributionForInstance(instance);
        else return m_Successors[1].distributionForInstance(instance);
      }
      else {
        // split attribute is numeric
        if (instance.value(m_Attribute) < m_SplitValue)
          return m_Successors[0].distributionForInstance(instance);
        else
          return m_Successors[1].distributionForInstance(instance);
      }
    }
    // leaf node
    else return m_ClassProbs;
  }

  /**
   * Make the node a leaf node.
   *
   * @param data training data
   */
  protected void makeLeaf(Instances data) {
    m_Attribute = null;
    m_isLeaf = true;
    m_ClassValue = Utils.maxIndex(m_ClassProbs);
    m_ClassAttribute = data.classAttribute();
  }

  /**
   * Prints the decision tree using the protected toString method from below.
   *
   * @return a textual description of the classifier
   */
  public String toString() {
    if ((m_ClassProbs == null) && (m_Successors == null)) {
      return "CART Tree: No model built yet.";
    }
    return "CART Decision Tree\n" + toString(0) + "\n\n"
      + "Number of Leaf Nodes: " + numLeaves() + "\n\n"
      + "Size of the Tree: " + numNodes();
  }

  /**
   * Outputs a tree at a certain level.
   *
   * @param level the level at which the tree is to be printed
   * @return a tree at a certain level
   */
  protected String toString(int level) {
    StringBuffer text = new StringBuffer();
    // leaf node: print class value and distribution
    if (m_Attribute == null) {
      if (Instance.isMissingValue(m_ClassValue)) {
        text.append(": null");
      } else {
        double correctNum = (int)(m_Distribution[Utils.maxIndex(m_Distribution)]*100)/100.0;
        double wrongNum = (int)((Utils.sum(m_Distribution) -
            m_Distribution[Utils.maxIndex(m_Distribution)])*100)/100.0;
        String str = "(" + correctNum + "/" + wrongNum + ")";
        text.append(": " + m_ClassAttribute.value((int)m_ClassValue) + str);
      }
    } else {
      for (int j = 0; j < 2; j++) {
        text.append("\n");
        for (int i = 0; i < level; i++) {
          text.append("|  ");
        }
        if (j == 0) {
          if (m_Attribute.isNumeric())
            text.append(m_Attribute.name() + " < " + m_SplitValue);
          else
            text.append(m_Attribute.name() + "=" + m_SplitString);
        } else {
          if (m_Attribute.isNumeric())
            text.append(m_Attribute.name() + " >= " + m_SplitValue);
else
text.append(m_Attribute.name() + "!=" + m_SplitString);
}
text.append(m_Successors[j].toString(level + 1));
}
}
return text.toString();
}
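  // The resulting output looks like the following sketch (attribute and class
  // names are hypothetical):
  //
  //   petallength < 2.45: setosa(50.0/0.0)
  //   petallength >= 2.45
  //   |  petalwidth < 1.75: versicolor(49.0/5.0)
  //   |  petalwidth >= 1.75: virginica(45.0/1.0)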
/**
* Compute size of the tree.
*
* @return size of the tree
*/
public int numNodes() {
if (m_isLeaf) {
return 1;
} else {
int size =1;
      for (int i = 0; i < m_Successors.length; i++) {
        size += m_Successors[i].numNodes();
      }
      return size;
    }
  }

  /**
   * Method to count the number of inner nodes in the tree.
   *
   * @return the number of inner nodes
   */
  public int numInnerNodes() {
    if (m_isLeaf) return 0;
    int numNodes = 1;
    for (int i = 0; i < m_Successors.length; i++)
      numNodes += m_Successors[i].numInnerNodes();
    return numNodes;
  }

  /**
   * Return a list of all inner nodes in the tree.
   *
   * @return the list of all inner nodes
   */
  protected Vector getInnerNodes() {
    Vector nodeList = new Vector();
    fillInnerNodes(nodeList);
    return nodeList;
  }

  /**
   * Fills a list with all inner nodes in the tree.
   *
   * @param nodeList the list to be filled
   */
  protected void fillInnerNodes(Vector nodeList) {
    if (!m_isLeaf) {
      nodeList.add(this);
      for (int i = 0; i < 2; i++)
        m_Successors[i].fillInnerNodes(nodeList);
    }
  }

  /**
   * Compute number of leaf nodes.
   *
   * @return number of leaf nodes
   */
  public int numLeaves() {
    if (m_isLeaf) return 1;
    int size = 0;
    for (int i = 0; i < m_Successors.length; i++) {
      size += m_Successors[i].numLeaves();
    }
    return size;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector result = new Vector();

    Enumeration en = super.listOptions();
    while (en.hasMoreElements())
      result.addElement(en.nextElement());

    result.addElement(new Option(
        "\tThe minimal number of instances at the terminal nodes.\n"
        + "\t(default 2)",
        "M", 1, "-M <min no>"));
result.addElement(new Option(
"\tThe number of folds used in the minimal cost-complexity pruning.\n"
+ "\t(default 5)",
"N", 1, "-N "));
result.addElement(new Option(
"\tDon't use the minimal cost-complexity pruning.\n"
+ "\t(default yes).",
"U", 0, "-U"));
result.addElement(new Option(
"\tDon't use the heuristic method for binary split.\n"
+ "\t(default true).",
"H", 0, "-H"));
result.addElement(new Option(
"\tUse 1 SE rule to make pruning decision.\n"
+ "\t(default no).",
"A", 0, "-A"));
result.addElement(new Option(
"\tPercentage of training data size (0-1].\n"
+ "\t(default 1).",
"C", 1, "-C"));
return result.elements();
}
/**
* Parses a given list of options.
*
* Valid options are:
*
* -S <num>
* Random number seed.
* (default 1)
*
* -D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -M <min no>
* The minimal number of instances at the terminal nodes.
* (default 2)
*
* -N <num folds>
* The number of folds used in the minimal cost-complexity pruning.
* (default 5)
*
* -U
* Don't use the minimal cost-complexity pruning.
* (default yes).
*
* -H
* Don't use the heuristic method for binary split.
* (default true).
*
* -A
* Use 1 SE rule to make pruning decision.
* (default no).
*
* -C
* Percentage of training data size (0-1].
* (default 1).
*
*
* @param options the list of options as an array of strings
* @throws Exception if an options is not supported
*/
public void setOptions(String[] options) throws Exception {
String tmpStr;
super.setOptions(options);
tmpStr = Utils.getOption('M', options);
if (tmpStr.length() != 0)
setMinNumObj(Double.parseDouble(tmpStr));
else
setMinNumObj(2);
tmpStr = Utils.getOption('N', options);
if (tmpStr.length()!=0)
setNumFoldsPruning(Integer.parseInt(tmpStr));
else
setNumFoldsPruning(5);
setUsePrune(!Utils.getFlag('U',options));
setHeuristic(!Utils.getFlag('H',options));
setUseOneSE(Utils.getFlag('A',options));
tmpStr = Utils.getOption('C', options);
if (tmpStr.length()!=0)
setSizePer(Double.parseDouble(tmpStr));
else
setSizePer(1);
Utils.checkForRemainingOptions(options);
}
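  // For instance (illustrative), an option array accepted by this parser can
  // be built with weka.core.Utils; "classifier" stands for any SimpleCart instance:
  //
  //   String[] opts = Utils.splitOptions("-M 3 -N 10 -A");
  //   classifier.setOptions(opts); // 10-fold pruning, >=3 instances per leaf, 1 SE rule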
/**
* Gets the current settings of the classifier.
*
* @return the current setting of the classifier
*/
public String[] getOptions() {
int i;
Vector result;
String[] options;
result = new Vector();
options = super.getOptions();
for (i = 0; i < options.length; i++)
result.add(options[i]);
result.add("-M");
result.add("" + getMinNumObj());
result.add("-N");
result.add("" + getNumFoldsPruning());
if (!getUsePrune())
result.add("-U");
if (!getHeuristic())
result.add("-H");
if (getUseOneSE())
result.add("-A");
result.add("-C");
result.add("" + getSizePer());
return (String[]) result.toArray(new String[result.size()]);
}
/**
* Return an enumeration of the measure names.
*
* @return an enumeration of the measure names
*/
public Enumeration enumerateMeasures() {
Vector result = new Vector();
result.addElement("measureTreeSize");
return result.elements();
}
/**
* Returns the size of the tree.
*
* @return the size of the tree
*/
public double measureTreeSize() {
return numNodes();
}
/**
* Returns the value of the named measure.
*
* @param additionalMeasureName the name of the measure to query for its value
* @return the value of the named measure
* @throws IllegalArgumentException if the named measure is not supported
*/
public double getMeasure(String additionalMeasureName) {
if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
return measureTreeSize();
} else {
throw new IllegalArgumentException(additionalMeasureName
+ " not supported (Cart pruning)");
}
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String minNumObjTipText() {
return "The minimal number of observations at the terminal nodes (default 2).";
}
/**
* Set minimal number of instances at the terminal nodes.
*
* @param value minimal number of instances at the terminal nodes
*/
public void setMinNumObj(double value) {
m_minNumObj = value;
}
/**
* Get minimal number of instances at the terminal nodes.
*
* @return minimal number of instances at the terminal nodes
*/
public double getMinNumObj() {
return m_minNumObj;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String numFoldsPruningTipText() {
return "The number of folds in the internal cross-validation (default 5).";
}
/**
* Set number of folds in internal cross-validation.
*
* @param value number of folds in internal cross-validation.
*/
public void setNumFoldsPruning(int value) {
m_numFoldsPruning = value;
}
/**
* Get number of folds in internal cross-validation.
*
* @return number of folds in internal cross-validation.
*/
public int getNumFoldsPruning() {
return m_numFoldsPruning;
}
/**
* Return the tip text for this property
*
* @return tip text for this property suitable for displaying in
* the explorer/experimenter gui.
*/
public String usePruneTipText() {
return "Use minimal cost-complexity pruning (default yes).";
}
/**
 * Set whether to use minimal cost-complexity pruning.
 *
 * @param value whether to use minimal cost-complexity pruning
 */
public void setUsePrune(boolean value) {
  m_Prune = value;
}
/**
 * Get whether minimal cost-complexity pruning is used.
 *
 * @return whether minimal cost-complexity pruning is used
 */
public boolean getUsePrune() {
  return m_Prune;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui.
*/
public String heuristicTipText() {
return
"If heuristic search is used for binary split for nominal attributes "
+ "in multi-class problems (default yes).";
}
/**
 * Set whether to use heuristic search for nominal attributes in multi-class problems.
 *
 * @param value whether to use heuristic search for nominal attributes in
 * multi-class problems
 */
public void setHeuristic(boolean value) {
  m_Heuristic = value;
}
/**
 * Get whether heuristic search is used for nominal attributes in multi-class problems.
 *
 * @return whether heuristic search is used for nominal attributes in
 * multi-class problems
 */
public boolean getHeuristic() {
  return m_Heuristic;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui.
*/
public String useOneSETipText() {
return "Use the 1SE rule to make pruning decisoin.";
}
/**
 * Set whether to use the 1SE rule to choose the final model.
 *
 * @param value whether to use the 1SE rule to choose the final model
 */
public void setUseOneSE(boolean value) {
  m_UseOneSE = value;
}
/**
 * Get whether the 1SE rule is used to choose the final model.
 *
 * @return whether the 1SE rule is used to choose the final model
 */
public boolean getUseOneSE() {
  return m_UseOneSE;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui.
*/
public String sizePerTipText() {
return "The percentage of the training set size (0-1, 0 not included).";
}
/**
* Set training set size.
*
* @param value training set size
*/
public void setSizePer(double value) {
if ((value <= 0) || (value > 1))
System.err.println(
"The percentage of the training set size must be in range 0 to 1 "
+ "(0 not included) - ignored!");
else
m_SizePer = value;
}
/**
* Get training set size.
*
* @return training set size
*/
public double getSizePer() {
return m_SizePer;
}
/**
* Returns the revision string.
*
* @return the revision
*/
public String getRevision() {
return RevisionUtils.extract("$Revision: 10491 $");
}
/**
* Main method.
* @param args the options for the classifier
*/
public static void main(String[] args) {
runClassifier(new SimpleCart(), args);
}
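  // A minimal programmatic usage sketch (illustrative; "data.arff" is a
  // placeholder for an ARFF file whose last attribute is a nominal class):
  //
  //   Instances data = new Instances(
  //       new java.io.BufferedReader(new java.io.FileReader("data.arff")));
  //   data.setClassIndex(data.numAttributes() - 1);
  //   SimpleCart cart = new SimpleCart();
  //   cart.buildClassifier(data);
  //   System.out.println(cart);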
}