/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    RuleNode.java
 *    Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees.m5;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegression;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Constructs a node for use in an m5 tree or rule
 *
 * @author Mark Hall ([email protected])
 * @version $Revision: 1.13 $
 */
public class RuleNode 
  extends Classifier {

  /** for serialization */
  static final long serialVersionUID = 1979807611124337144L;
  
  /**
   * instances reaching this node
   */
  private Instances	   m_instances;

  /**
   * the class index
   */
  private int		   m_classIndex;

  /**
   * the number of instances reaching this node
   */
  protected int		   m_numInstances;

  /**
   * the number of attributes
   */
  private int		   m_numAttributes;

  /**
   * Node is a leaf
   */
  private boolean	   m_isLeaf;

  /**
   * attribute this node splits on
   */
  private int		   m_splitAtt;

  /**
   * the value of the split attribute
   */
  private double	   m_splitValue;

  /**
   * the linear model at this node
   */
  private PreConstructedLinearModel m_nodeModel;

  /**
   * the number of parameters in the chosen model for this node---either
   * the subtree model or the linear model.
   * The constant term is counted as a parameter---this is for pruning
   * purposes
   */
  public int		   m_numParameters;

  /**
   * the root mean squared error of the model at this node (either linear
   * or subtree)
   */
  private double	   m_rootMeanSquaredError;

  /**
   * left child node
   */
  protected RuleNode	   m_left;

  /**
   * right child node
   */
  protected RuleNode	   m_right;

  /**
   * the parent of this node
   */
  private RuleNode	   m_parent;

  /**
   * a node will not be split if it contains fewer than m_splitNum instances
   */
  private double	   m_splitNum = 4;

  /**
   * a node will not be split if its class standard deviation is less
   * than 5% of the class standard deviation of all the instances
   */
  private double	   m_devFraction = 0.05;

  /**
   * multiplier applied to the number of model parameters when computing
   * the pruning factor
   */
  private double	   m_pruningMultiplier = 2;

  /**
   * the number assigned to the linear model if this node is a leaf
   * (0 if this node is not a leaf)
   */
  private int		   m_leafModelNum;

  /**
   * the standard deviation of the class over all the training instances;
   * a node will not be split if its class standard deviation is less
   * than m_devFraction of this value
   */
  private double	   m_globalDeviation;

  /**
   * the absolute deviation of the global class
   */
  private double	   m_globalAbsDeviation;

  /**
   * Indices of the attributes to be used in generating a linear model
   * at this node
   */
  private int [] m_indices;
    
  /**
   * Constant used in original m5 smoothing calculation
   */
  private static final double	   SMOOTHING_CONSTANT = 15.0;

  /**
   * Node id.
   */
  private int m_id;

  /**
   * Save the instances at each node (for visualizing in the
   * Explorer's tree visualizer).
   */
  private boolean m_saveInstances = false;

  /**
   * Make a regression tree instead of a model tree
   */
  private boolean m_regressionTree;

  /**
   * Creates a new RuleNode instance.
   *
   * @param globalDev the global standard deviation of the class
   * @param globalAbsDev the global absolute deviation of the class
   * @param parent the parent of this node
   */
  public RuleNode(double globalDev, double globalAbsDev, RuleNode parent) {
    m_nodeModel = null;
    m_right = null;
    m_left = null;
    m_parent = parent;
    m_globalDeviation = globalDev;
    m_globalAbsDeviation = globalAbsDev;
  }

    
  /**
   * Build this node (find an attribute and split point)
   *
   * @param data the instances on which to build this node
   * @throws Exception if an error occurs
   */
  public void buildClassifier(Instances data) throws Exception {

    m_rootMeanSquaredError = Double.MAX_VALUE;
    //    m_instances = new Instances(data);
    m_instances = data;
    m_classIndex = m_instances.classIndex();
    m_numInstances = m_instances.numInstances();
    m_numAttributes = m_instances.numAttributes();
    m_nodeModel = null;
    m_right = null;
    m_left = null;

    if ((m_numInstances < m_splitNum) 
	|| (Rule.stdDev(m_classIndex, m_instances) 
	    < (m_globalDeviation * m_devFraction))) {
      m_isLeaf = true;
    } else {
      m_isLeaf = false;
    } 

    split();
  } 
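
  // A minimal usage sketch (hypothetical driver code; in WEKA the tree is
  // grown by weka.classifiers.trees.m5.M5Base, which supplies the global
  // deviations computed from the training data):
  //
  //   RuleNode root = new RuleNode(globalStdDev, globalAbsDev, null);
  //   root.setMinNumInstances(4);
  //   root.buildClassifier(train);   // grow the tree
  //   root.numLeaves(0);             // number the leaf models
  //   root.prune();                  // install linear models and prune
  //   double pred = root.classifyInstance(inst);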
 
  /**
   * Classify an instance using this node. Recursively calls classifyInstance
   * on child nodes.
   *
   * @param inst the instance to classify
   * @return the prediction for this instance
   * @throws Exception if an error occurs
   */
  public double classifyInstance(Instance inst) throws Exception {
    if (m_isLeaf) {
      if (m_nodeModel == null) {
	throw new Exception("Classifier has not been built correctly.");
      } 

      return m_nodeModel.classifyInstance(inst);
    }

    if (inst.value(m_splitAtt) <= m_splitValue) {
      return m_left.classifyInstance(inst);
    } else {
      return m_right.classifyInstance(inst);
    } 
  } 

  /**
   * Applies the m5 smoothing procedure to a prediction
   *
   * @param n number of instances in selected child of this node
   * @param pred the prediction so far
   * @param supportPred the prediction of the linear model at this node
   * @return the current prediction smoothed with the prediction of the
   * linear model at this node
   * @throws Exception if an error occurs
   */
  protected static double smoothingOriginal(double n, double pred, 
					    double supportPred) 
    throws Exception {
    double   smoothed;

    smoothed = 
      ((n * pred) + (SMOOTHING_CONSTANT * supportPred)) /
      (n + SMOOTHING_CONSTANT);

    return smoothed;
  } 
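
  // Worked example (illustrative numbers): with n = 20 instances reaching
  // the selected child, pred = 10.0 from below and supportPred = 12.0 from
  // this node's model, the smoothed prediction is
  //   (20 * 10.0 + 15 * 12.0) / (20 + 15) = 380 / 35 = 10.857...
  // so predictions are pulled towards the ancestor models, and the pull
  // weakens as n grows relative to SMOOTHING_CONSTANT.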


  /**
   * Finds an attribute and split point for this node
   *
   * @throws Exception if an error occurs
   */
  public void split() throws Exception {
    int		  i;
    Instances     leftSubset, rightSubset;
    SplitEvaluate bestSplit, currentSplit;
    boolean[]     attsBelow;

    if (!m_isLeaf) {
     
      bestSplit = new YongSplitInfo(0, m_numInstances - 1, -1);
      currentSplit = new YongSplitInfo(0, m_numInstances - 1, -1);

      // find the best attribute to split on
      for (i = 0; i < m_numAttributes; i++) {
	if (i != m_classIndex) {

	  // sort the instances by this attribute
	  m_instances.sort(i);
	  currentSplit.attrSplit(i, m_instances);

	  if ((Math.abs(currentSplit.maxImpurity() - 
			bestSplit.maxImpurity()) > 1.e-6) 
	      && (currentSplit.maxImpurity() 
		  > bestSplit.maxImpurity() + 1.e-6)) {
	    bestSplit = currentSplit.copy();
	  } 
	} 
      } 

      // can't find a good split attribute or split point?
      if (bestSplit.splitAttr() < 0 || bestSplit.position() < 1 
	  || bestSplit.position() > m_numInstances - 1) {
	m_isLeaf = true;
      } else {
	m_splitAtt = bestSplit.splitAttr();
	m_splitValue = bestSplit.splitValue();
	leftSubset = new Instances(m_instances, m_numInstances);
	rightSubset = new Instances(m_instances, m_numInstances);

	for (i = 0; i < m_numInstances; i++) {
	  if (m_instances.instance(i).value(m_splitAtt) <= m_splitValue) {
	    leftSubset.add(m_instances.instance(i));
	  } else {
	    rightSubset.add(m_instances.instance(i));
	  } 
	} 

	leftSubset.compactify();
	rightSubset.compactify();

	// build left and right nodes
	m_left = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this);
	m_left.setMinNumInstances(m_splitNum);
	m_left.setRegressionTree(m_regressionTree);
	m_left.setSaveInstances(m_saveInstances);
	m_left.buildClassifier(leftSubset);

	m_right = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this);
	m_right.setMinNumInstances(m_splitNum);
	m_right.setRegressionTree(m_regressionTree);
	m_right.setSaveInstances(m_saveInstances);
	m_right.buildClassifier(rightSubset);

	// now find out what attributes are tested in the left and right
	// subtrees and use them to learn a linear model for this node
	if (!m_regressionTree) {
	  attsBelow = attsTestedBelow();
	  attsBelow[m_classIndex] = true;
	  int count = 0, j;

	  for (j = 0; j < m_numAttributes; j++) {
	    if (attsBelow[j]) {
	      count++;
	    } 
	  } 
	  
	  int[] indices = new int[count];

	  count = 0;
	  
	  for (j = 0; j < m_numAttributes; j++) {
	    if (attsBelow[j] && (j != m_classIndex)) {
	      indices[count++] = j;
	    } 
	  } 
	  
	  indices[count] = m_classIndex;
	  m_indices = indices;
	} else {
	  m_indices = new int [1];
	  m_indices[0] = m_classIndex;
	  m_numParameters = 1;
	}
      } 
    } 

    if (m_isLeaf) {
      int [] indices = new int [1];
      indices[0] = m_classIndex;
      m_indices = indices;
      m_numParameters = 1;
     
      // the model would need to be evaluated here to get correct stats
      // for an unpruned tree
    } 
  } 

  /**
   * Build a linear model for this node using those attributes
   * specified in indices.
   *
   * @param indices an array of attribute indices to include in the linear
   * model
   * @throws Exception if something goes wrong
   */
  private void buildLinearModel(int [] indices) throws Exception {
    // copy the training instances and remove all but the tested
    // attributes
    Instances reducedInst = new Instances(m_instances);
    Remove attributeFilter = new Remove();
    
    attributeFilter.setInvertSelection(true);
    attributeFilter.setAttributeIndicesArray(indices);
    attributeFilter.setInputFormat(reducedInst);

    reducedInst = Filter.useFilter(reducedInst, attributeFilter);
    
    // build a linear regression for the training data using the
    // tested attributes
    LinearRegression temp = new LinearRegression();
    temp.buildClassifier(reducedInst);

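    // the mapping below assumes coefficients() returns one slot per
    // attribute of the reduced data set (the class slot being unused),
    // with the intercept in the last position; each coefficient is copied
    // back to the position of its original attribute index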
    double [] lmCoeffs = temp.coefficients();
    double [] coeffs = new double [m_instances.numAttributes()];

    for (int i = 0; i < lmCoeffs.length - 1; i++) {
      if (indices[i] != m_classIndex) {
	coeffs[indices[i]] = lmCoeffs[i];
      }
    }
    m_nodeModel = new PreConstructedLinearModel(coeffs, lmCoeffs[lmCoeffs.length - 1]);
    m_nodeModel.buildClassifier(m_instances);
  }

  /**
   * Returns a boolean array indicating which attributes are used in
   * tests at and above this node
   *
   * @return a boolean array of attribute flags
   */
  private boolean[] attsTestedAbove() {
    boolean[] atts = new boolean[m_numAttributes];
    boolean[] attsAbove = null;

    if (m_parent != null) {
      attsAbove = m_parent.attsTestedAbove();
    } 

    if (attsAbove != null) {
      for (int i = 0; i < m_numAttributes; i++) {
	atts[i] = attsAbove[i];
      } 
    } 

    atts[m_splitAtt] = true;
    return atts;
  } 

  /**
   * Returns a boolean array indicating which attributes are used in
   * tests at and below this node
   *
   * @return a boolean array of attribute flags
   */
  private boolean[] attsTestedBelow() {
    boolean[] attsBelow = new boolean[m_numAttributes];
    boolean[] attsBelowLeft = null;
    boolean[] attsBelowRight = null;

    if (m_right != null) {
      attsBelowRight = m_right.attsTestedBelow();
    } 

    if (m_left != null) {
      attsBelowLeft = m_left.attsTestedBelow();
    } 

    for (int i = 0; i < m_numAttributes; i++) {
      if (attsBelowLeft != null) {
	attsBelow[i] = (attsBelow[i] || attsBelowLeft[i]);
      } 

      if (attsBelowRight != null) {
	attsBelow[i] = (attsBelow[i] || attsBelowRight[i]);
      } 
    } 

    if (!m_isLeaf) {
      attsBelow[m_splitAtt] = true;
    } 
    return attsBelow;
  } 

  /**
   * Numbers the leaves in the tree
   *
   * @param leafCounter the number of leaves counted so far
   * @return the total number of leaves counted after processing this node
   */
  public int numLeaves(int leafCounter) {

    if (!m_isLeaf) {
      // node
      m_leafModelNum = 0;

      if (m_left != null) {
	leafCounter = m_left.numLeaves(leafCounter);
      } 

      if (m_right != null) {
	leafCounter = m_right.numLeaves(leafCounter);
      } 
    } else {
      // leaf
      leafCounter++;
      m_leafModelNum = leafCounter;
    } 
    return leafCounter;
  } 

  /**
   * Returns a textual description of the linear model at this node
   * 
   * @return the linear model as a string
   */
  public String toString() {
    return printNodeLinearModel();
  } 

  /**
   * Returns a textual description of the linear model at this node
   * 
   * @return the linear model at this node as a string
   */
  public String printNodeLinearModel() {
    return m_nodeModel.toString();
  } 

  /**
   * print all leaf models
   * 
   * @return the leaf models
   */
  public String printLeafModels() {
    StringBuffer text = new StringBuffer();

    if (m_isLeaf) {
      text.append("\nLM num: " + m_leafModelNum);
      text.append(m_nodeModel.toString());
      text.append("\n");
    } else {
      text.append(m_left.printLeafModels());
      text.append(m_right.printLeafModels());
    } 
    return text.toString();
  } 

  /**
   * Returns a description of this node (debugging purposes)
   *
   * @return a string describing this node
   */
  public String nodeToString() {
    StringBuffer text = new StringBuffer();

    text.append("Node:\n\tnum inst: " + m_numInstances);

    if (m_isLeaf) {
      text.append("\n\tleaf");
    } else {
      text.append("\tnode");
    }

    text.append("\n\tSplit att: " + m_instances.attribute(m_splitAtt).name());
    text.append("\n\tSplit val: " + Utils.doubleToString(m_splitValue, 1, 3));
    text.append("\n\tLM num: " + m_leafModelNum);
    text.append("\n\tLinear model\n" + m_nodeModel.toString());
    text.append("\n\n");

    if (m_left != null) {
      text.append(m_left.nodeToString());
    } 

    if (m_right != null) {
      text.append(m_right.nodeToString());
    } 

    return text.toString();
  } 

  /**
   * Recursively builds a textual description of the tree
   *
   * @param level the level of this node
   * @return string describing the tree
   */
  public String treeToString(int level) {
    int		 i;
    StringBuffer text = new StringBuffer();

    if (!m_isLeaf) {
      text.append("\n");

      for (i = 1; i <= level; i++) {
	text.append("|   ");
      } 

      if (m_instances.attribute(m_splitAtt).name().charAt(0) != '[') {
	text.append(m_instances.attribute(m_splitAtt).name() + " <= " 
		    + Utils.doubleToString(m_splitValue, 1, 3) + " : ");
      } else {
	text.append(m_instances.attribute(m_splitAtt).name() + " false : ");
      } 

      if (m_left != null) {
	text.append(m_left.treeToString(level + 1));
      } else {
	text.append("NULL\n");
      }

      for (i = 1; i <= level; i++) {
	text.append("|   ");
      } 

      if (m_instances.attribute(m_splitAtt).name().charAt(0) != '[') {
	text.append(m_instances.attribute(m_splitAtt).name() + " >  " 
		    + Utils.doubleToString(m_splitValue, 1, 3) + " : ");
      } else {
	text.append(m_instances.attribute(m_splitAtt).name() + " true : ");
      } 

      if (m_right != null) {
	text.append(m_right.treeToString(level + 1));
      } else {
	text.append("NULL\n");
      }
    } else {
      text.append("LM" + m_leafModelNum);

      if (m_globalDeviation > 0.0) {
	text
	  .append(" (" + m_numInstances + "/" 
		  + Utils.doubleToString((100.0 * m_rootMeanSquaredError /
					     m_globalDeviation), 1, 3) 
		  + "%)\n");
      } else {
	text.append(" (" + m_numInstances + ")\n");
      } 
    } 
    return text.toString();
  } 

  /**
   * Traverses the tree and installs linear models at each node.
   * This method must be called if pruning is not to be performed.
   *
   * @throws Exception if an error occurs
   */
  public void installLinearModels() throws Exception {
    Evaluation nodeModelEval;
    if (m_isLeaf) {
      buildLinearModel(m_indices);
    } else {
      if (m_left != null) {
	m_left.installLinearModels();
      }

      if (m_right != null) {
	m_right.installLinearModels();
      }
      buildLinearModel(m_indices);
    }
    nodeModelEval = new Evaluation(m_instances);
    nodeModelEval.evaluateModel(m_nodeModel, m_instances);
    m_rootMeanSquaredError = nodeModelEval.rootMeanSquaredError();
    // save space
    if (!m_saveInstances) {
      m_instances = new Instances(m_instances, 0);
    }
  }

  /**
   * Traverses the tree and installs the smoothed linear models at the
   * leaves (original M5 smoothing)
   *
   * @throws Exception if an error occurs
   */
  public void installSmoothedModels() throws Exception {

    if (m_isLeaf) {
      double [] coefficients = new double [m_numAttributes];
      double intercept;
      double  [] coeffsUsedByLinearModel = m_nodeModel.coefficients();
      RuleNode current = this;
      
      // prime array with leaf node coefficients
      for (int i = 0; i < coeffsUsedByLinearModel.length; i++) {
	if (i != m_classIndex) {
	  coefficients[i] = coeffsUsedByLinearModel[i];
	}
      }
      // intercept
      intercept = m_nodeModel.intercept();

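      // walk from this leaf towards the root; at each step blend the
      // coefficients accumulated so far (weight n) with the ancestor's
      // linear model (weight SMOOTHING_CONSTANT) --- the coefficient-wise
      // analogue of smoothingOriginal()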
      do {
	if (current.m_parent != null) {
	  double n = current.m_numInstances;
	  // contribution of the model below
	  for (int i = 0; i < coefficients.length; i++) {
	    coefficients[i] = ((coefficients[i] * n) / (n + SMOOTHING_CONSTANT));
	  }
	  intercept =  ((intercept * n) / (n + SMOOTHING_CONSTANT));

	  // contribution of this model
	  coeffsUsedByLinearModel = current.m_parent.getModel().coefficients();
	  for (int i = 0; i < coeffsUsedByLinearModel.length; i++) {
	    if (i != m_classIndex) {
	      // smooth in these coefficients (at this node)
	      coefficients[i] += 
		((SMOOTHING_CONSTANT * coeffsUsedByLinearModel[i]) /
		 (n + SMOOTHING_CONSTANT));
	    }
	  }
	  // smooth in the intercept
	  intercept += 
	    ((SMOOTHING_CONSTANT * 
	      current.m_parent.getModel().intercept()) /
	     (n + SMOOTHING_CONSTANT));
	  current = current.m_parent;
	}
      } while (current.m_parent != null);
      m_nodeModel = 
	new PreConstructedLinearModel(coefficients, intercept);
      m_nodeModel.buildClassifier(m_instances);
    }
    if (m_left != null) {
      m_left.installSmoothedModels();
    }
    if (m_right != null) {
      m_right.installSmoothedModels();
    }
  }
    
  /**
   * Recursively prune the tree
   *
   * @throws Exception if an error occurs
   */
  public void prune() throws Exception {
    Evaluation nodeModelEval = null;

    if (m_isLeaf) {
      buildLinearModel(m_indices);
      nodeModelEval = new Evaluation(m_instances);

      // count the constant term as a parameter for a leaf
      // Evaluate the model
      nodeModelEval.evaluateModel(m_nodeModel, m_instances);

      m_rootMeanSquaredError = nodeModelEval.rootMeanSquaredError();
    } else {

      // Prune the left and right subtrees
      if (m_left != null) {
	m_left.prune();
      } 

      if (m_right != null) {
	m_right.prune();	
      } 
      
      buildLinearModel(m_indices);
      nodeModelEval = new Evaluation(m_instances);

      double rmsModel;
      double adjustedErrorModel;

      nodeModelEval.evaluateModel(m_nodeModel, m_instances);

      rmsModel = nodeModelEval.rootMeanSquaredError();
      adjustedErrorModel = rmsModel 
	* pruningFactor(m_numInstances, 
			m_nodeModel.numParameters() + 1);

      // Evaluate this node (i.e. its left and right subtrees)
      Evaluation nodeEval = new Evaluation(m_instances);
      double     rmsSubTree;
      double     adjustedErrorNode;
      int	 l_params = 0, r_params = 0;

      nodeEval.evaluateModel(this, m_instances);

      rmsSubTree = nodeEval.rootMeanSquaredError();

      if (m_left != null) {
	l_params = m_left.numParameters();
      } 

      if (m_right != null) {
	r_params = m_right.numParameters();
      } 

      adjustedErrorNode = rmsSubTree 
	* pruningFactor(m_numInstances, 
			(l_params + r_params + 1));

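      // collapse to a leaf when the linear model's adjusted error is no
      // worse than the subtree's, or when it is already negligible
      // relative to the global deviation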
      if ((adjustedErrorModel <= adjustedErrorNode) 
	  || (adjustedErrorModel < (m_globalDeviation * 0.00001))) {

	// Choose linear model for this node rather than subtree model
	m_isLeaf = true;
	m_right = null;
	m_left = null;
	m_numParameters = m_nodeModel.numParameters() + 1;
	m_rootMeanSquaredError = rmsModel;
      } else {
	m_numParameters = (l_params + r_params + 1);
	m_rootMeanSquaredError = rmsSubTree;
      } 
    }
    // save space
    if (!m_saveInstances) {
      m_instances = new Instances(m_instances, 0);
    }
  } 


  /**
   * Compute the pruning factor
   *
   * @param num_instances number of instances
   * @param num_params number of parameters in the model
   * @return the pruning factor
   */
  private double pruningFactor(int num_instances, int num_params) {
    if (num_instances <= num_params) {
      return 10.0;    // Caution says Yong in his code
    } 

    return ((double) (num_instances + m_pruningMultiplier * num_params) 
	    / (double) (num_instances - num_params));
  } 
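
  // Worked example (illustrative numbers): with num_instances = 100 and
  // num_params = 5 the factor is (100 + 2 * 5) / (100 - 5) = 110 / 95,
  // approximately 1.158, so during pruning each extra parameter must buy
  // a lower raw RMSE before the adjusted error comes out ahead.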

  /**
   * Find the leaf with greatest coverage
   *
   * @param maxCoverage a single element array holding the greatest
   * coverage found so far (updated in place)
   * @param bestLeaf a single element array holding the leaf with the
   * greatest coverage (updated in place)
   */
  public void findBestLeaf(double[] maxCoverage, RuleNode[] bestLeaf) {
    if (!m_isLeaf) {
      if (m_left != null) {
	m_left.findBestLeaf(maxCoverage, bestLeaf);
      } 

      if (m_right != null) {
	m_right.findBestLeaf(maxCoverage, bestLeaf);
      } 
    } else {
      if (m_numInstances > maxCoverage[0]) {
	maxCoverage[0] = m_numInstances;
	bestLeaf[0] = this;
      } 
    } 
  } 

  /**
   * Collects all the leaves in the tree
   *
   * @param v a single element array containing a vector to which the
   * leaves are added
   */
  public void returnLeaves(FastVector[] v) {
    if (m_isLeaf) {
      v[0].addElement(this);
    } else {
      if (m_left != null) {
	m_left.returnLeaves(v);
      } 

      if (m_right != null) {
	m_right.returnLeaves(v);
      } 
    } 
  } 

  /**
   * Get the parent of this node
   *
   * @return the parent of this node
   */
  public RuleNode parentNode() {
    return m_parent;
  } 

  /**
   * Get the left child of this node
   *
   * @return the left child of this node
   */
  public RuleNode leftNode() {
    return m_left;
  } 

  /**
   * Get the right child of this node
   *
   * @return the right child of this node
   */
  public RuleNode rightNode() {
    return m_right;
  } 

  /**
   * Get the index of the splitting attribute for this node
   *
   * @return the index of the splitting attribute
   */
  public int splitAtt() {
    return m_splitAtt;
  } 

  /**
   * Get the split point for this node
   *
   * @return the split point for this node
   */
  public double splitVal() {
    return m_splitValue;
  } 

  /**
   * Get the number of linear models in the tree
   *
   * @return the number of linear models
   */
  public int numberOfLinearModels() {
    if (m_isLeaf) {
      return 1;
    } else {
      return m_left.numberOfLinearModels() + m_right.numberOfLinearModels();
    } 
  } 

  /**
   * Return true if this node is a leaf
   *
   * @return true if this node is a leaf
   */
  public boolean isLeaf() {
    return m_isLeaf;
  } 

  /**
   * Get the root mean squared error at this node
   *
   * @return the root mean squared error
   */
  protected double rootMeanSquaredError() {
    return m_rootMeanSquaredError;
  } 

  /**
   * Get the linear model at this node
   *
   * @return the linear model at this node
   */
  public PreConstructedLinearModel getModel() {
    return m_nodeModel;
  }

  /**
   * Return the number of instances that reach this node.
   *
   * @return the number of instances at this node.
   */
  public int getNumInstances() {
    return m_numInstances;
  }

  /**
   * Get the number of parameters in the model at this node
   *
   * @return the number of parameters in the model at this node
   */
  private int numParameters() {
    return m_numParameters;
  } 

  /**
   * Get the value of regressionTree.
   *
   * @return Value of regressionTree.
   */
  public boolean getRegressionTree() {
    
    return m_regressionTree;
  }

  /**
   * Set the minimum number of instances to allow at a leaf node
   *
   * @param minNum the minimum number of instances
   */
  public void setMinNumInstances(double minNum) {
    m_splitNum = minNum;
  }

  /**
   * Get the minimum number of instances to allow at a leaf node
   *
   * @return the minimum number of instances
   */
  public double getMinNumInstances() {
    return m_splitNum;
  }
  
  /**
   * Set the value of regressionTree.
   *
   * @param newregressionTree Value to assign to regressionTree.
   */
  public void setRegressionTree(boolean newregressionTree) {
    
    m_regressionTree = newregressionTree;
  }
							  
  /**
   * Print all the linear models at the leaves (debugging purposes)
   */
  public void printAllModels() {
    if (m_isLeaf) {
      System.out.println(m_nodeModel.toString());
    } else {
      System.out.println(m_nodeModel.toString());
      m_left.printAllModels();
      m_right.printAllModels();
    } 
  } 

  /**
   * Assigns a unique identifier to each node in the tree
   *
   * @param lastID last id number used
   * @return ID after processing child nodes
   */
  protected int assignIDs(int lastID) {
    int currLastID = lastID + 1;
    m_id = currLastID;

    if (m_left != null) {
      currLastID = m_left.assignIDs(currLastID);
    }

    if (m_right != null) {
      currLastID = m_right.assignIDs(currLastID);
    }
    return currLastID;
  }

  /**
   * Assigns a unique identifier to each node in the tree and then
   * calls graphTree
   *
   * @param text a StringBuffer to receive the dotty description
   */
  public void graph(StringBuffer text) {
    assignIDs(-1);
    graphTree(text);
  }

  /**
   * Appends a dotty style description of the tree to the supplied buffer
   *
   * @param text a StringBuffer to receive the dotty description
   */
  protected void graphTree(StringBuffer text) {
    text.append("N" + m_id
		+ (m_isLeaf 
		   ? " [label=\"LM " + m_leafModelNum
		   : " [label=\"" + Utils.backQuoteChars(m_instances.attribute(m_splitAtt).name()))
		+ (m_isLeaf
		 ? " (" + ((m_globalDeviation > 0.0) 
			  ?  m_numInstances + "/" 
			     + Utils.doubleToString((100.0 * 
						     m_rootMeanSquaredError /
						     m_globalDeviation), 
						    1, 3) 
			     + "%)"
			   : m_numInstances + ")")
		    + "\" shape=box style=filled "
		   : "\"")
		+ (m_saveInstances
		   ? "data=\n" + m_instances + "\n,\n"
		   : "")
		+ "]\n");
		
    if (m_left != null) {
      text.append("N" + m_id + "->" + "N" + m_left.m_id + " [label=\"<="
		  + Utils.doubleToString(m_splitValue, 1, 3)
		  + "\"]\n");
      m_left.graphTree(text);
    }
     
    if (m_right != null) {
      text.append("N" + m_id + "->" + "N" + m_right.m_id + " [label=\">"
		  + Utils.doubleToString(m_splitValue, 1, 3)
		  + "\"]\n");
      m_right.graphTree(text);
    }
  }
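
  // Example of the emitted dot fragment (illustrative node ids, attribute
  // and model names; actual labels depend on the tree):
  //
  //   N0 [label="MMAX"]
  //   N0->N1 [label="<=14000"]
  //   N1 [label="LM 1 (31/15.6%)" shape=box style=filled ]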

  /**
   * Set whether to save instances for visualization purposes.
   * The default is false, which saves memory.
   *
   * @param save a boolean value
   */
  protected void setSaveInstances(boolean save) {
    m_saveInstances = save;
  }
  
  /**
   * Returns the revision string.
   * 
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.13 $");
  }
}