weka.classifiers.trees.j48.NBTreeSplit Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-stable Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the stable version. Apart from bugfixes, this version
does not receive any other updates.
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* NBTreeSplit.java
* Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees.j48;
import weka.classifiers.bayes.NaiveBayesUpdateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.Discretize;
import java.util.Random;
/**
 * Class implementing a NBTree split on an attribute.
 * <p>
 * The partition of the training data is delegated to a {@link C45Split}.
 * Each resulting subset that is large enough is discretized, shuffled,
 * stratified into folds and evaluated with a naive Bayes model via
 * {@code NBTreeNoSplit.crossValidate}. The summed error of all subsets is
 * available through {@link #getErrors()}.
 *
 * @author Mark Hall ([email protected])
 * @version $Revision: 1.5 $
 */
public class NBTreeSplit
  extends ClassifierSplitModel{

  /** for serialization */
  private static final long serialVersionUID = 8922627123884975070L;

  /** Number of folds each subset is stratified into before cross-validation. */
  private static final int NUM_FOLDS = 5;

  /** Seed for the random number generator used to shuffle the subsets. */
  private static final int RANDOM_SEED = 1;

  /**
   * Minimum number of instances a subset of a nominal split needs in order
   * to be evaluated with naive Bayes (i.e. "at least 5").
   */
  private static final int MIN_SUBSET_SIZE_NOMINAL = 5;

  /**
   * Minimum number of instances a subset of a numeric split needs in order
   * to be evaluated with naive Bayes (i.e. "more than 5").
   * NOTE(review): this differs by one from MIN_SUBSET_SIZE_NOMINAL. The
   * asymmetry is inherited from the original implementation and kept so
   * that the trees built are unchanged; confirm whether it is intentional.
   */
  private static final int MIN_SUBSET_SIZE_NUMERIC = 6;

  /** Desired number of branches. */
  private int m_complexityIndex;

  /** Attribute to split on. */
  private int m_attIndex;

  /**
   * Minimum number of objects in a split.
   * NOTE(review): stored by the constructor but not otherwise read in this
   * class; the underlying C45Split is built with a fixed minimum of 2.
   */
  private int m_minNoObj;

  /** Value of split point. */
  private double m_splitPoint;

  /** The sum of the weights of the instances. */
  private double m_sumOfWeights;

  /**
   * The weight of the instances incorrectly classified by the
   * naive bayes models arising from this split.
   */
  private double m_errors;

  /** The C4.5 split that decides how instances are partitioned. */
  private C45Split m_c45S;

  /** The global naive bayes model for this node */
  NBTreeNoSplit m_globalNB;

  /**
   * Initializes the split model.
   *
   * @param attIndex the index of the attribute to split on
   * @param minNoObj the minimum number of objects in a split
   * @param sumOfWeights the sum of the weights of the instances; passed on
   * to the underlying C45Split
   */
  public NBTreeSplit(int attIndex, int minNoObj, double sumOfWeights) {

    // Get index of attribute to split on.
    m_attIndex = attIndex;

    // Set minimum number of objects.
    m_minNoObj = minNoObj;

    // Set the sum of the weights
    m_sumOfWeights = sumOfWeights;
  }

  /**
   * Creates a NBTree-type split on the given data. Assumes that none of
   * the class values is missing.
   * <p>
   * If a global model has been set, the error count starts out as that
   * model's errors. It is replaced by the split's own errors only when the
   * underlying C45Split produces at least one subset. Note that for a
   * numeric attribute the given data set is sorted in place.
   *
   * @param trainInstances the training data
   * @exception Exception if something goes wrong
   */
  public void buildClassifier(Instances trainInstances)
       throws Exception {

    // Initialize the remaining instance variables.
    m_numSubsets = 0;
    m_splitPoint = Double.MAX_VALUE;
    m_errors = 0;
    if (m_globalNB != null) {
      m_errors = m_globalNB.getErrors();
    }

    // Different treatment for enumerated and numeric
    // attributes.
    if (trainInstances.attribute(m_attIndex).isNominal()) {
      m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
      handleEnumeratedAttribute(trainInstances);
    } else {
      m_complexityIndex = 2;
      trainInstances.sort(trainInstances.attribute(m_attIndex));
      handleNumericAttribute(trainInstances);
    }
  }

  /**
   * Returns index of attribute for which split was generated.
   *
   * @return the attribute index
   */
  public final int attIndex() {

    return m_attIndex;
  }

  /**
   * Creates split on enumerated attribute.
   *
   * @param trainInstances the training data
   * @exception Exception if something goes wrong
   */
  private void handleEnumeratedAttribute(Instances trainInstances)
       throws Exception {

    handleAttribute(trainInstances, MIN_SUBSET_SIZE_NOMINAL);
  }

  /**
   * Creates split on numeric attribute.
   *
   * @param trainInstances the training data (already sorted on the
   * split attribute by buildClassifier)
   * @exception Exception if something goes wrong
   */
  private void handleNumericAttribute(Instances trainInstances)
       throws Exception {

    handleAttribute(trainInstances, MIN_SUBSET_SIZE_NUMERIC);
  }

  /**
   * Builds the underlying C45Split and evaluates the naive Bayes models
   * arising from it. The split is only accepted (m_numSubsets set) if at
   * least two subsets contain at least minSubsetSize instances.
   *
   * @param trainInstances the training data
   * @param minSubsetSize minimum number of instances a subset needs to be
   * evaluated with naive Bayes; smaller subsets count entirely as errors
   * @exception Exception if building the split or the models fails
   */
  private void handleAttribute(Instances trainInstances, int minSubsetSize)
       throws Exception {

    m_c45S = new C45Split(m_attIndex, 2, m_sumOfWeights);
    m_c45S.buildClassifier(trainInstances);
    if (m_c45S.numSubsets() == 0) {
      // no usable C4.5 split: keep the errors set up in buildClassifier
      return;
    }
    m_errors = 0;

    Instances[] trainingSets = partition(trainInstances);
    int evaluatedSubsets = evaluateSubsets(trainingSets, minSubsetSize);

    // Accept the split only if at least two subsets were large enough
    // to be evaluated.
    if (evaluatedSubsets > 1) {
      m_numSubsets = m_complexityIndex;
    }
  }

  /**
   * Distributes copies of the training instances over m_complexityIndex
   * subsets according to m_c45S. An instance that C45Split cannot assign
   * to a single subset (whichSubset == -1) is copied into every subset.
   * Each copy's weight is scaled by the corresponding entry of
   * C45Split.weights, or divided evenly when the lengths do not match.
   *
   * @param trainInstances the training data
   * @return one data set per subset
   * @exception Exception if the C45Split cannot assign an instance
   */
  private Instances[] partition(Instances trainInstances) throws Exception {

    Instances[] trainingSets = new Instances[m_complexityIndex];
    for (int i = 0; i < m_complexityIndex; i++) {
      trainingSets[i] = new Instances(trainInstances, 0);
    }

    for (int i = 0; i < trainInstances.numInstances(); i++) {
      Instance instance = trainInstances.instance(i);
      int subset = m_c45S.whichSubset(instance);
      if (subset > -1) {
        trainingSets[subset].add((Instance) instance.copy());
      } else {
        double[] weights = m_c45S.weights(instance);
        for (int j = 0; j < m_complexityIndex; j++) {
          Instance temp = (Instance) instance.copy();
          if (weights.length == m_complexityIndex) {
            temp.setWeight(temp.weight() * weights[j]);
          } else {
            temp.setWeight(temp.weight() / m_complexityIndex);
          }
          trainingSets[j].add(temp);
        }
      }
    }
    return trainingSets;
  }

  /**
   * Adds to m_errors the cross-validated naive Bayes error of each subset
   * that has at least minSubsetSize instances. For every smaller subset,
   * the summed weight of its instances is added instead.
   *
   * @param trainingSets the subsets produced by partition
   * @param minSubsetSize minimum number of instances for a subset to be
   * evaluated with naive Bayes
   * @return the number of subsets that were evaluated with naive Bayes
   * @exception Exception if discretization or model building fails
   */
  private int evaluateSubsets(Instances[] trainingSets, int minSubsetSize)
       throws Exception {

    // A single generator is shared by all subsets, in order, so that
    // results are reproducible.
    Random r = new Random(RANDOM_SEED);
    int evaluatedSubsets = 0;

    for (int i = 0; i < trainingSets.length; i++) {
      if (trainingSets[i].numInstances() >= minSubsetSize) {
        evaluatedSubsets++;

        // Discretize the set
        Discretize disc = new Discretize();
        disc.setInputFormat(trainingSets[i]);
        Instances subset = Filter.useFilter(trainingSets[i], disc);
        subset.randomize(r);
        subset.stratify(NUM_FOLDS);

        NaiveBayesUpdateable fullModel = new NaiveBayesUpdateable();
        fullModel.buildClassifier(subset);

        // add the errors for this branch of the split
        m_errors += NBTreeNoSplit.crossValidate(fullModel, subset, r);
      } else {
        // too few instances: count all of them as errors
        for (int j = 0; j < trainingSets[i].numInstances(); j++) {
          m_errors += trainingSets[i].instance(j).weight();
        }
      }
    }
    return evaluatedSubsets;
  }

  /**
   * Returns index of subset instance is assigned to, as determined by the
   * underlying C45Split. Returns -1 if instance is assigned to more than
   * one subset.
   *
   * @param instance the instance to assign
   * @return the subset index, or -1
   * @exception Exception if something goes wrong
   */
  public final int whichSubset(Instance instance)
       throws Exception {

    return m_c45S.whichSubset(instance);
  }

  /**
   * Returns weights if instance is assigned to more than one subset.
   * Returns null if instance is only assigned to one subset.
   * Delegates to the underlying C45Split.
   *
   * @param instance the instance to get the weights for
   * @return the weights, or null
   */
  public final double [] weights(Instance instance) {

    return m_c45S.weights(instance);
  }

  /**
   * Returns a string containing java source code equivalent to the test
   * made at this node. The instance being tested is called "i".
   *
   * @param index index of the nominal value tested
   * @param data the data containing instance structure info
   * @return a value of type 'String'
   */
  public final String sourceExpression(int index, Instances data) {

    return m_c45S.sourceExpression(index, data);
  }

  /**
   * Prints the condition satisfied by instances in a subset.
   *
   * @param index of subset
   * @param data training set.
   * @return the right-hand side of the condition
   */
  public final String rightSide(int index,Instances data) {

    return m_c45S.rightSide(index, data);
  }

  /**
   * Prints left side of condition.
   *
   * @param data training set.
   * @return the left-hand side of the condition
   */
  public final String leftSide(Instances data) {

    return m_c45S.leftSide(data);
  }

  /**
   * Return the probability for a class value, computed by the global
   * naive bayes model (see setGlobalModel).
   *
   * @param classIndex the index of the class value
   * @param instance the instance to generate a probability for
   * @param theSubset the subset to consider; must not be negative
   * @return a probability
   * @exception Exception if theSubset is negative or the global model fails
   */
  public double classProb(int classIndex, Instance instance, int theSubset)
    throws Exception {

    // use the global naive bayes model
    if (theSubset > -1) {
      return m_globalNB.classProb(classIndex, instance, theSubset);
    } else {
      throw new Exception("This shouldn't happen!!!");
    }
  }

  /**
   * Return the global naive bayes model for this node
   *
   * @return a NBTreeNoSplit value
   */
  public NBTreeNoSplit getGlobalModel() {
    return m_globalNB;
  }

  /**
   * Set the global naive bayes model for this node
   *
   * @param global a NBTreeNoSplit value
   */
  public void setGlobalModel(NBTreeNoSplit global) {
    m_globalNB = global;
  }

  /**
   * Return the errors made by the naive bayes models arising
   * from this split.
   *
   * @return a double value
   */
  public double getErrors() {
    return m_errors;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.5 $");
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy