weka.classifiers.trees.ht.HNode Maven / Gradle / Ivy
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* HNode.java
* Copyright (C) 2013 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees.ht;
import java.io.Serializable;
import java.util.LinkedHashMap;
import java.util.Map;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Utils;
/**
 * Abstract base class for nodes in a Hoeffding tree.
 * <p>
 * Every node keeps a class distribution that maps each class-value label to
 * the weight accumulated for it at this node. The base implementation reports
 * itself as a leaf ({@link #isLeaf()} returns {@code true}).
 *
 * @author Richard Kirkby ([email protected])
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 9707 $
 */
public abstract class HNode implements Serializable {

  /**
   * For serialization
   */
  private static final long serialVersionUID = 197233928177240264L;

  /**
   * Class distribution at this node: class-value label -> accumulated weight.
   * A LinkedHashMap is used so iteration follows the order in which class
   * values were first seen (this decides ties when picking the majority class).
   */
  public Map<String, WeightMass> m_classDistribution = new LinkedHashMap<String, WeightMass>();

  /** Holds the leaf number (if this is a leaf) */
  protected int m_leafNum;

  /** Holds the node number (for graphing purposes) */
  protected int m_nodeNum;

  /**
   * Construct a new HNode with an empty class distribution.
   */
  public HNode() {
  }

  /**
   * Construct a new HNode with the supplied class distribution.
   *
   * @param classDistrib the initial class distribution; it is stored by
   *          reference, not copied
   */
  public HNode(Map<String, WeightMass> classDistrib) {
    m_classDistribution = classDistrib;
  }

  /**
   * Returns true if this is a leaf.
   *
   * @return always {@code true} in this base implementation
   */
  public boolean isLeaf() {
    return true;
  }

  /**
   * The size of the class distribution.
   *
   * @return the number of entries in the class distribution
   */
  public int numEntriesInClassDistribution() {
    return m_classDistribution.size();
  }

  /**
   * Returns true if the class distribution is pure, i.e. at most one class
   * value has a strictly positive weight.
   *
   * @return true if the class distribution is pure
   */
  public boolean classDistributionIsPure() {
    int count = 0;
    for (Map.Entry<String, WeightMass> el : m_classDistribution.entrySet()) {
      if (el.getValue().m_weight > 0) {
        count++;
        if (count > 1) {
          // A second positive class is enough to decide: not pure.
          break;
        }
      }
    }

    return (count < 2);
  }

  /**
   * Update the class frequency distribution with the supplied instance.
   * Instances whose class value is missing are ignored.
   * <p>
   * The first time a class value is seen, its weight is seeded with 1.0 before
   * the instance weight is added. Every observed class therefore carries a +1
   * prior. This matches the 1.0 that
   * {@link #getDistribution(Instance, Attribute)} assigns to classes that have
   * never been seen at this node.
   *
   * @param inst the instance to update with
   */
  public void updateDistribution(Instance inst) {
    if (inst.classIsMissing()) {
      return;
    }
    String classVal = inst.stringValue(inst.classAttribute());

    WeightMass m = m_classDistribution.get(classVal);
    if (m == null) {
      m = new WeightMass();
      m.m_weight = 1.0; // +1 prior, see method comment

      m_classDistribution.put(classVal, m);
    }
    m.m_weight += inst.weight();
  }

  /**
   * Return a class probability distribution computed from the frequency counts
   * at this node. A class value with no entry in the distribution gets a count
   * of 1.0, the same prior that {@link #updateDistribution(Instance)} seeds
   * observed classes with. The counts are normalized with
   * {@code Utils.normalize}.
   *
   * @param inst the instance to get a prediction for (not used by this base
   *          implementation)
   * @param classAtt the class attribute
   * @return a class probability distribution, indexed like the values of
   *         {@code classAtt}
   * @throws Exception if a problem occurs
   */
  public double[] getDistribution(Instance inst, Attribute classAtt)
    throws Exception {
    double[] dist = new double[classAtt.numValues()];

    for (int i = 0; i < classAtt.numValues(); i++) {
      WeightMass w = m_classDistribution.get(classAtt.value(i));
      if (w != null) {
        dist[i] = w.m_weight;
      } else {
        dist[i] = 1.0;
      }
    }

    Utils.normalize(dist);
    return dist;
  }

  /**
   * Assign this node the next node number (used when graphing).
   *
   * @param nodeNum the last node number assigned so far
   * @return the number assigned to this node ({@code nodeNum + 1})
   */
  public int installNodeNums(int nodeNum) {
    nodeNum++;
    m_nodeNum = nodeNum;
    return nodeNum;
  }

  /**
   * Append a textual description of this leaf to the buffer. The description
   * is the majority class label followed by its weight in parentheses. Also
   * records this node's leaf number.
   *
   * @param depth the depth of this node (not used by this base implementation)
   * @param leafCount the number of leaves output so far
   * @param buff the buffer to append to
   * @return the updated leaf count ({@code leafCount + 1})
   */
  protected int dumpTree(int depth, int leafCount, StringBuffer buff) {
    buff.append(majorityClassLabel());

    leafCount++;
    m_leafNum = leafCount;

    return leafCount;
  }

  /**
   * Hook for appending the models held at the leaves (e.g. NB, NB adaptive).
   * The base implementation appends nothing.
   *
   * @param buff the buffer to append to
   */
  protected void printLeafModels(StringBuffer buff) {
  }

  /**
   * Append a DOT-format node statement for this node. The statement uses the
   * node number as its id and the majority class label and weight as its
   * label.
   *
   * @param text the buffer to append to
   */
  public void graphTree(StringBuffer text) {
    text.append("N" + m_nodeNum + " [label=\"" + majorityClassLabel()
      + "\" shape=box style=filled]\n");
  }

  /**
   * Build the "{@code <class> (<weight>)}" label shared by
   * {@link #dumpTree(int, int, StringBuffer)} and {@link #graphTree(StringBuffer)}.
   * The class with the largest weight wins, and the first one in iteration
   * order wins ties. An empty distribution yields {@code " (-1.000)"}.
   * The weight is formatted to three decimals using the default locale.
   *
   * @return the majority class label followed by its weight in parentheses
   */
  private String majorityClassLabel() {
    double max = -1;
    String classVal = "";
    for (Map.Entry<String, WeightMass> e : m_classDistribution.entrySet()) {
      if (e.getValue().m_weight > max) {
        max = e.getValue().m_weight;
        classVal = e.getKey();
      }
    }
    return classVal + " (" + String.format("%.3f", max) + ")";
  }

  /**
   * Print a textual description of the tree. Node numbers are (re)assigned
   * starting from 1 before the tree is dumped.
   *
   * @param printLeaf true if leaf models (NB, NB adaptive) should be output
   * @return a textual description of the tree
   */
  public String toString(boolean printLeaf) {

    installNodeNums(0);

    StringBuffer buff = new StringBuffer();

    dumpTree(0, 0, buff);

    if (printLeaf) {
      buff.append("\n\n");
      printLeafModels(buff);
    }

    return buff.toString();
  }

  /**
   * Return the total weight of instances seen at this node. For entries
   * created by {@link #updateDistribution(Instance)}, this includes the 1.0
   * seed added for each observed class value.
   *
   * @return the sum of all weights in the class distribution
   */
  public double totalWeight() {
    double tw = 0;

    for (Map.Entry<String, WeightMass> e : m_classDistribution.entrySet()) {
      tw += e.getValue().m_weight;
    }

    return tw;
  }

  /**
   * Return the leaf that the supplied instance ends up at. The base
   * implementation returns a new LeafNode built from this node, the parent
   * and the parent branch.
   *
   * @param inst the instance to find the leaf for (not used by this base
   *          implementation)
   * @param parent the parent node
   * @param parentBranch the parent branch
   * @return the leaf that the supplied instance ends up at
   */
  public LeafNode leafForInstance(Instance inst, SplitNode parent,
    String parentBranch) {

    return new LeafNode(this, parent, parentBranch);
  }

  /**
   * Update the node with the supplied instance.
   *
   * @param inst the instance to update with
   * @throws Exception if a problem occurs
   */
  public abstract void updateNode(Instance inst) throws Exception;
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy