weka.classifiers.trees.ht.SplitNode Maven / Gradle / Ivy
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
*/
/*
* SplitNode.java
* Copyright (C) 2013 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees.ht;
import java.util.LinkedHashMap;
import java.util.Map;
import weka.core.Instance;
/**
* Class for a node that splits the data in a Hoeffding tree
*
* @author Richard Kirkby ([email protected])
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @version $Revision: 10169 $
*/
public class SplitNode extends HNode {
/**
* For serialization
*/
private static final long serialVersionUID = 1558033628618451073L;
/** The split itself */
protected Split m_split;
/** Child nodes */
protected Map m_children = new LinkedHashMap();
/**
* Construct a new SplitNode
*
* @param classDistrib the class distribution
* @param split the split
*/
public SplitNode(Map classDistrib, Split split) {
super(classDistrib);
m_split = split;
}
/**
* Return the branch that the supplied instance goes down
*
* @param inst the instance to find the branch for
* @return the branch that the supplied instance goes down
*/
public String branchForInstance(Instance inst) {
return m_split.branchForInstance(inst);
}
@Override
public boolean isLeaf() {
return false;
}
/**
* Number of child nodes
*
* @return the number of child nodes
*/
public int numChildred() {
return m_children.size();
}
/**
* Add a child
*
* @param branch the branch for the child
* @param child the child itself
*/
public void setChild(String branch, HNode child) {
m_children.put(branch, child);
}
@Override
public LeafNode leafForInstance(Instance inst, SplitNode parent,
String parentBranch) {
String branch = branchForInstance(inst);
if (branch != null) {
HNode child = m_children.get(branch);
if (child != null) {
return child.leafForInstance(inst, this, branch);
}
return new LeafNode(null, this, branch);
}
return new LeafNode(this, parent, parentBranch);
}
@Override
public void updateNode(Instance inst) {
// don't update the distribution
}
@Override
protected int dumpTree(int depth, int leafCount, StringBuffer buff) {
for (Map.Entry e : m_children.entrySet()) {
HNode child = e.getValue();
String branch = e.getKey();
if (child != null) {
buff.append("\n");
for (int i = 0; i < depth; i++) {
buff.append("| ");
}
buff.append(m_split.conditionForBranch(branch).trim());
buff.append(": ");
leafCount = child.dumpTree(depth + 1, leafCount, buff);
}
}
return leafCount;
}
@Override
public int installNodeNums(int nodeNum) {
nodeNum = super.installNodeNums(nodeNum);
for (Map.Entry e : m_children.entrySet()) {
HNode child = e.getValue();
if (child != null) {
nodeNum = child.installNodeNums(nodeNum);
}
}
return nodeNum;
}
@Override
public void graphTree(StringBuffer buff) {
boolean first = true;
for (Map.Entry e : m_children.entrySet()) {
HNode child = e.getValue();
String branch = e.getKey();
if (child != null) {
String conditionForBranch = m_split.conditionForBranch(branch);
if (first) {
String testAttName = null;
if (conditionForBranch.indexOf("<=") < 0) {
testAttName = conditionForBranch.substring(0,
conditionForBranch.indexOf("=")).trim();
} else {
testAttName = conditionForBranch.substring(0,
conditionForBranch.indexOf("<")).trim();
}
first = false;
buff.append("N" + m_nodeNum + " [label=\"" + testAttName + "\"]\n");
}
int startIndex = 0;
if (conditionForBranch.indexOf("<=") > 0) {
startIndex = conditionForBranch.indexOf("<") - 1;
} else if (conditionForBranch.indexOf("=") > 0) {
startIndex = conditionForBranch.indexOf("=") - 1;
} else {
startIndex = conditionForBranch.indexOf(">") - 1;
}
conditionForBranch = conditionForBranch.substring(startIndex,
conditionForBranch.length()).trim();
buff.append(
"N" + m_nodeNum + "->" + "N" + child.m_nodeNum + "[label=\""
+ conditionForBranch + "\"]\n").append("\n");
}
}
for (Map.Entry e : m_children.entrySet()) {
HNode child = e.getValue();
if (child != null) {
child.graphTree(buff);
}
}
}
@Override
protected void printLeafModels(StringBuffer buff) {
for (Map.Entry e : m_children.entrySet()) {
HNode child = e.getValue();
if (child != null) {
child.printLeafModels(buff);
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy