All downloads are FREE. The search and download functionality uses the official Maven repository.

weka.classifiers.trees.j48.NBTreeSplit Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bug fixes, this version does not receive any updates.

There is a newer version: 3.8.6
Show newest version
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    NBTreeSplit.java
 *    Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees.j48;

import weka.classifiers.bayes.NaiveBayesUpdateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.Discretize;

import java.util.Random;

/**
 * Class implementing an NBTree split on an attribute.
 *
 * @author Mark Hall ([email protected])
 * @version $Revision: 1.5 $
 */
public class NBTreeSplit
  extends ClassifierSplitModel{

  /** for serialization */
  private static final long serialVersionUID = 8922627123884975070L;

  /**
   * Desired number of branches. buildClassifier sets this to the number of
   * values of the split attribute when it is nominal, and to 2 otherwise.
   */
  private int m_complexityIndex;  

  /** Index of the attribute to split on (set by the constructor). */
  private int m_attIndex;         

  /**
   * Minimum number of objects in a split. Stored by the constructor; it is
   * not read anywhere else in this class.
   */
  private int m_minNoObj;         

  /**
   * Value of split point. buildClassifier resets it to Double.MAX_VALUE;
   * it is not otherwise used in this class.
   */
  private double m_splitPoint;   

  /** The sum of the weights of the instances (set by the constructor). */
  private double m_sumOfWeights;  

  /**
   * The weight of the instances incorrectly classified by the
   * naive bayes models arising from this split.
   */
  private double m_errors;

  /**
   * The C4.5 split built on the split attribute. It decides which subset
   * an instance belongs to and answers the whichSubset, weights, leftSide,
   * rightSide and sourceExpression queries. It stays null until
   * buildClassifier has been called.
   */
  private C45Split m_c45S;

  /**
   * The global naive bayes model for this node. It is null until
   * setGlobalModel is called.
   */
  NBTreeNoSplit m_globalNB;

  /**
   * Creates a split model for one attribute. Only the three supplied
   * values are recorded here; the split itself is computed later by
   * buildClassifier.
   *
   * @param attIndex index of the attribute to split on
   * @param minNoObj minimum number of objects in a split
   * @param sumOfWeights sum of the weights of the instances
   */
  public NBTreeSplit(int attIndex, int minNoObj, double sumOfWeights) {
    this.m_sumOfWeights = sumOfWeights;
    this.m_minNoObj = minNoObj;
    this.m_attIndex = attIndex;
  }

  /**
   * Builds an NBTree-type split on the given data. Assumes that none of
   * the class values is missing.
   *
   * The error count starts at the global naive bayes model's error when
   * such a model has been set, and at zero otherwise. For a non-nominal
   * split attribute, sort is called on the supplied instances before the
   * split is evaluated.
   *
   * @param trainInstances the training data to split
   * @exception Exception if something goes wrong
   */
  public void buildClassifier(Instances trainInstances)
       throws Exception {

    // Reset the state left over from any earlier build.
    m_numSubsets = 0;
    m_splitPoint = Double.MAX_VALUE;
    m_errors = (m_globalNB == null) ? 0 : m_globalNB.getErrors();

    final boolean nominal = trainInstances.attribute(m_attIndex).isNominal();

    if (nominal) {
      // One branch per attribute value.
      m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
      handleEnumeratedAttribute(trainInstances);
      return;
    }

    // Numeric attribute: binary split on data sorted by that attribute.
    m_complexityIndex = 2;
    trainInstances.sort(trainInstances.attribute(m_attIndex));
    handleNumericAttribute(trainInstances);
  }

  /**
   * Returns the index of the attribute this split was generated for.
   *
   * @return the attribute index passed to the constructor
   */
  public final int attIndex() {
    return this.m_attIndex;
  }

  /**
   * Creates split on enumerated attribute.
   *
   * A C4.5 split is built on the attribute first. If it yields no subsets
   * the method returns, leaving m_errors and m_numSubsets as they were on
   * entry. Otherwise the training instances are partitioned into
   * m_complexityIndex subsets according to that split, and the error of
   * each subset is accumulated into m_errors:
   * subsets with at least five instances are discretized and evaluated by
   * cross-validating a naive bayes model; smaller subsets contribute the
   * full weight of their instances as error. The split is accepted
   * (m_numSubsets set to m_complexityIndex) only if at least two subsets
   * were large enough to be evaluated.
   *
   * Failures while distributing an instance are propagated to the caller
   * rather than terminating the JVM.
   *
   * @param trainInstances the training data to split
   * @exception Exception if something goes wrong
   */
  private void handleEnumeratedAttribute(Instances trainInstances)
       throws Exception {

    m_c45S = new C45Split(m_attIndex, 2, m_sumOfWeights);
    m_c45S.buildClassifier(trainInstances);
    if (m_c45S.numSubsets() == 0) {
      // No usable C4.5 split on this attribute.
      return;
    }
    m_errors = 0;

    // One initially empty data set per branch, created from trainInstances.
    Instances [] trainingSets = new Instances [m_complexityIndex];
    for (int i = 0; i < m_complexityIndex; i++) {
      trainingSets[i] = new Instances(trainInstances, 0);
    }

    // Populate the subsets.
    for (int i = 0; i < trainInstances.numInstances(); i++) {
      Instance instance = trainInstances.instance(i);
      int subset = m_c45S.whichSubset(instance);
      if (subset > -1) {
        trainingSets[subset].add((Instance)instance.copy());
      } else {
        // The instance is not assigned to a single subset: add a
        // re-weighted copy to every subset.
        double [] weights = m_c45S.weights(instance);
        for (int j = 0; j < m_complexityIndex; j++) {
          Instance temp = (Instance) instance.copy();
          if (weights.length == m_complexityIndex) {
            temp.setWeight(temp.weight() * weights[j]);
          } else {
            // Weight vector does not match the branch count: spread the
            // instance weight evenly instead.
            temp.setWeight(temp.weight() / m_complexityIndex);
          }
          trainingSets[j].add(temp);
        }
      }
    }

    // Fixed seed so that repeated builds give the same error estimate.
    Random r = new Random(1);
    int minNumCount = 0;
    for (int i = 0; i < m_complexityIndex; i++) {
      // NOTE(review): handleNumericAttribute uses a strict "> 5" for the
      // same check; confirm which threshold is intended.
      if (trainingSets[i].numInstances() >= 5) {
        minNumCount++;
        // Discretize the sets
        Discretize disc = new Discretize();
        disc.setInputFormat(trainingSets[i]);
        trainingSets[i] = Filter.useFilter(trainingSets[i], disc);

        trainingSets[i].randomize(r);
        trainingSets[i].stratify(5);
        NaiveBayesUpdateable fullModel = new NaiveBayesUpdateable();
        fullModel.buildClassifier(trainingSets[i]);

        // add the errors for this branch of the split
        m_errors += NBTreeNoSplit.crossValidate(fullModel, trainingSets[i], r);
      } else {
        // if fewer than min obj then just count them as errors
        for (int j = 0; j < trainingSets[i].numInstances(); j++) {
          m_errors += trainingSets[i].instance(j).weight();
        }
      }
    }

    // Check if there are at least five instances in at least two of the
    // subsets.
    if (minNumCount > 1) {
      m_numSubsets = m_complexityIndex;
    }
  }

  /**
   * Creates split on numeric attribute.
   *
   * Builds a C4.5 split on the attribute and, unless that split has no
   * subsets, divides the training instances between two branch data sets.
   * A branch holding more than five instances is discretized and scored
   * by cross-validating a naive bayes model; any smaller branch adds the
   * weight of all its instances to the error. The split is kept only when
   * at least two branches were large enough to be scored.
   *
   * @param trainInstances the training data to split
   * @exception Exception if something goes wrong
   */
  private void handleNumericAttribute(Instances trainInstances)
       throws Exception {

    m_c45S = new C45Split(m_attIndex, 2, m_sumOfWeights);
    m_c45S.buildClassifier(trainInstances);
    if (m_c45S.numSubsets() == 0) {
      return;
    }
    m_errors = 0;

    final Instances[] branches = new Instances[m_complexityIndex];
    branches[0] = new Instances(trainInstances, 0);
    branches[1] = new Instances(trainInstances, 0);

    // Route every training instance to its branch (or to all of them).
    for (int row = 0; row < trainInstances.numInstances(); row++) {
      final Instance current = trainInstances.instance(row);
      final int target = m_c45S.whichSubset(current);

      if (target == -1) {
        // No single branch: give each branch a re-weighted copy.
        final double[] shares = m_c45S.weights(current);
        for (int b = 0; b < m_complexityIndex; b++) {
          final Instance portion = (Instance) current.copy();
          final double newWeight = (shares.length == m_complexityIndex)
            ? portion.weight() * shares[b]
            : portion.weight() / m_complexityIndex;
          portion.setWeight(newWeight);
          branches[b].add(portion);
        }
      } else {
        branches[target].add((Instance) current.copy());
      }
    }

    final Random rng = new Random(1);
    int scoredBranches = 0;
    for (int b = 0; b < m_complexityIndex; b++) {

      if (branches[b].numInstances() <= 5) {
        // Too small to evaluate: every instance counts as an error.
        for (int k = 0; k < branches[b].numInstances(); k++) {
          m_errors += branches[b].instance(k).weight();
        }
        continue;
      }

      scoredBranches++;

      // Discretize the branch data before fitting naive bayes.
      final Discretize discretizer = new Discretize();
      discretizer.setInputFormat(branches[b]);
      branches[b] = Filter.useFilter(branches[b], discretizer);

      branches[b].randomize(rng);
      branches[b].stratify(5);

      final NaiveBayesUpdateable branchModel = new NaiveBayesUpdateable();
      branchModel.buildClassifier(branches[b]);

      // Add the cross-validated error of this branch.
      m_errors += NBTreeNoSplit.crossValidate(branchModel, branches[b], rng);
    }

    // Accept the split only if two or more branches could be scored.
    if (scoredBranches >= 2) {
      m_numSubsets = m_complexityIndex;
    }
  }

  /**
   * Returns index of subset instance is assigned to.
   * Returns -1 if instance is assigned to more than one subset.
   * The answer is taken directly from the underlying C4.5 split.
   *
   * @param instance the instance to look up
   * @return the subset index, or -1
   * @exception Exception if something goes wrong
   */
  public final int whichSubset(Instance instance)
    throws Exception {

    final int subsetIndex = m_c45S.whichSubset(instance);
    return subsetIndex;
  }

  /**
   * Returns weights if instance is assigned to more than one subset.
   * Returns null if instance is only assigned to one subset.
   * The answer is taken directly from the underlying C4.5 split.
   *
   * @param instance the instance to look up
   * @return the per-subset weights, or null
   */
  public final double [] weights(Instance instance) {
    final double[] subsetWeights = m_c45S.weights(instance);
    return subsetWeights;
  }

  /**
   * Returns a string containing java source code equivalent to the test
   * made at this node. The instance being tested is called "i".
   * The expression is produced by the underlying C4.5 split.
   *
   * @param index index of the nominal value tested
   * @param data the data containing instance structure info
   * @return a value of type 'String'
   */
  public final String sourceExpression(int index, Instances data) {
    final String expression = m_c45S.sourceExpression(index, data);
    return expression;
  }

  /**
   * Returns the condition satisfied by instances in a subset, as
   * produced by the underlying C4.5 split.
   *
   * @param index of subset
   * @param data training set.
   * @return the right side of the condition
   */
  public final String rightSide(int index,Instances data) {
    final String condition = m_c45S.rightSide(index, data);
    return condition;
  }

  /**
   * Returns the left side of the condition, as produced by the
   * underlying C4.5 split.
   *
   * @param data training set.
   * @return the left side of the condition
   */
  public final String leftSide(Instances data) {
    final String lhs = m_c45S.leftSide(data);
    return lhs;
  }

  /**
   * Return the probability for a class value, as computed by the global
   * naive bayes model of this node.
   *
   * @param classIndex the index of the class value
   * @param instance the instance to generate a probability for
   * @param theSubset the subset to consider; must be greater than -1
   * @return a probability
   * @exception Exception if theSubset is not greater than -1 or an
   * error occurs
   */
  public double classProb(int classIndex, Instance instance, int theSubset)
    throws Exception {

    // A subset of -1 or below is not a valid subset index here.
    if (theSubset <= -1) {
      throw new Exception("This shouldn't happen!!!");
    }

    // use the global naive bayes model
    return m_globalNB.classProb(classIndex, instance, theSubset);
  }

  /**
   * Gets the global naive bayes model attached to this node.
   *
   * @return the NBTreeNoSplit model, or null if none has been set
   */
  public NBTreeNoSplit getGlobalModel() {
    return this.m_globalNB;
  }

  /**
   * Attaches the global naive bayes model for this node.
   *
   * @param global the NBTreeNoSplit model to use
   */
  public void setGlobalModel(NBTreeNoSplit global) {
    this.m_globalNB = global;
  }

  /**
   * Gets the errors made by the naive bayes models arising
   * from this split.
   *
   * @return the accumulated error weight
   */
  public double getErrors() {
    return this.m_errors;
  }
  
  /**
   * Returns the revision string.
   *
   * @return		the revision, extracted from the revision keyword
   */
  public String getRevision() {
    final String revisionKeyword = "$Revision: 1.5 $";
    return RevisionUtils.extract(revisionKeyword);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy