weka.classifiers.mi.miti.TreeNode Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of multiInstanceLearning Show documentation
Show all versions of multiInstanceLearning Show documentation
A collection of multi-instance learning classifiers. Includes the Citation KNN method, several variants of the diverse density method, support vector machines for multi-instance learning, simple wrappers for applying standard propositional learners to multi-instance data, decision tree and rule learners, and some other methods.
The newest version!
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
*/
/*
* TreeNode.java
* Copyright (C) 2011 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.mi.miti;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import weka.core.Attribute;
import weka.core.Instance;
/**
* Represents a node in the decision tree.
*
* @author Luke Bjerring
* @version $Revision: 10369 $
*/
public class TreeNode implements Serializable {
/** ID added to avoid warning */
private static final long serialVersionUID = 9050803921532593168L;
// The instances associated with this node
private ArrayList instances;
// The score for this node
private double nodeScore;
// Flag to indicate whether node is a leaf.
private boolean leafNode;
// Flag to indicate whether node is a positive leaf.
private boolean positiveLeaf;
// Reference to parent node
private TreeNode parent = null;
// The left child node in the case of a binary split
private TreeNode left = null;
// The right child node in the case of a binary split
private TreeNode right = null;
// The children in the case of a nominal-attribute split
private TreeNode[] nominalNodes = null;
// The actual split used
public Split split;
/**
* Creates node based on given collection of instances and parent node.
*/
public TreeNode(TreeNode parent, ArrayList instances) {
this.parent = parent;
this.instances = instances;
}
/**
* Returns the score for this node.
*/
public double nodeScore() {
return nodeScore;
}
/**
* Removes deactivated instances from the node. Does NOT update the node
* score.
*/
public void removeDeactivatedInstances(HashMap instanceBags) {
ArrayList newInstances = new ArrayList();
for (Instance i : instances) {
if (instanceBags.get(i).isEnabled()) {
newInstances.add(i);
}
}
instances = newInstances;
}
/**
* Calculates the node score based on the given arguments.
*/
public void calculateNodeScore(HashMap instanceBags,
boolean unbiasedEstimate, int kBEPPConstant, boolean bagCount,
double multiplier) {
nodeScore = NextSplitHeuristic.getBepp(instances, instanceBags,
unbiasedEstimate, kBEPPConstant, bagCount, multiplier);
}
/**
* Is node a leaf?
*/
public boolean isLeafNode() {
return leafNode;
}
/**
* Is node a positive leaf?
*/
public boolean isPositiveLeaf() {
return positiveLeaf;
}
/**
* Checks whether all the instances at the node are associated with positive
* bags.
*/
public boolean isPureNegative(HashMap instanceBags) {
for (Instance i : instances) {
Bag bag = instanceBags.get(i);
if (bag.isEnabled() && bag.isPositive()) {
return false;
}
}
return true;
}
/**
* Checks whether all the instances at the node are associated with negative
* bags.
*/
public boolean isPurePositive(HashMap instanceBags) {
for (Instance i : instances) {
Bag bag = instanceBags.get(i);
if (bag.isEnabled() && !bag.isPositive()) {
return false;
}
}
return true;
}
/**
* Turns the node into a leaf node.
*/
public void makeLeafNode(boolean positiveLeaf) {
leafNode = true;
this.positiveLeaf = positiveLeaf;
}
/**
* Returns the parent.
*/
public TreeNode parent() {
return parent;
}
/**
* Returns the left child in the case of a binary split.
*/
public TreeNode left() {
return left;
}
/**
* Returns the right child in the case of a binary split.
*/
public TreeNode right() {
return right;
}
/**
* Returns the children in the case of a nominal-attribute split.
*/
public TreeNode[] nominals() {
return nominalNodes;
}
/**
* Deactives all instances associated with bags that have at least one
* instance in the current node.
*/
public void deactivateRelatedInstances(HashMap instanceBags,
List deactivated) {
for (Instance i : instances) {
Bag container = instanceBags.get(i);
container.disableInstances(deactivated);
}
}
/**
* @param a attribute to check
* @return true if any parent branch used the same attribute to split on
*/
private boolean hasSplitOnAttributePreviously(Attribute a) {
TreeNode n = this;
while (n != null) {
if (n.split != null && n.split.attribute.equals(a)) {
return true;
}
n = n.parent;
}
return false;
}
/**
* Splits the instances on this node into the best possible child nodes,
* according to the settings
*/
public void splitInstances(HashMap instanceBags,
AlgorithmConfiguration settings, Random rand, boolean debug) {
// All remaining instances are enabled
ArrayList enabled = instances;
int totalAttributes = instances.get(0).numAttributes();
Instance template = instances.get(0);
// Filter to only use attributes that are not constant
List attributes = new ArrayList();
for (int index = 0; index < totalAttributes; index++) {
double val = template.value(index);
for (Instance i : instances) {
if (i.value(index) != val) {
attributes.add(template.attribute(index));
break;
}
}
}
// Choose some random attributes
int attributesToSplit = settings.attributesToSplit;
if (settings.attributesToSplit == -1) {
attributesToSplit = totalAttributes;
}
if (settings.attributesToSplit == -2) {
attributesToSplit = (int) Math.sqrt(totalAttributes) + 1;
}
if (attributesToSplit < attributes.size()) {
// Select a random set of attributes
Collections.shuffle(attributes, rand);
attributes = attributes.subList(0, attributesToSplit);
}
// Collect a list of the split scores
ArrayList best = new ArrayList();
for (Attribute a : attributes) {
if (a.isNominal() && hasSplitOnAttributePreviously(a)) {
continue;
}
Split splitPoint = Split.getBestSplitPoint(a, enabled, instanceBags,
settings);
if (splitPoint == null) {
continue;
}
if (debug) {
System.out.println(a.name() + " scored " + splitPoint.score);
}
best.add(splitPoint);
continue;
}
// If we can't find a split point, make this a leaf node
if (best.size() == 0) {
makeImpureLeafNode(instanceBags, settings, debug);
return;
}
Collections.sort(best, new Comparator() {
@Override
public int compare(Split o1, Split o2) {
return Double.compare(o2.score, o1.score);
}
});
// Get a random best split based on the setting
int attributeSplitChoices = settings.attributeSplitChoices;
if (settings.attributeSplitChoices == -1) {
attributeSplitChoices = best.size();
} else if (settings.attributeSplitChoices == -2) {
attributeSplitChoices = (int) Math.sqrt(best.size()) + 1;
}
int pick = rand.nextInt(Math.min(attributeSplitChoices, best.size()));
split = best.get(pick);
if (debug) {
System.out.println("Selected best is " + split.attribute.name());
}
Attribute splittingAttribute = split.attribute;
if (splittingAttribute.isNominal()) {
// Create a multi-valued nominal-attribute split
int numNominalValues = splittingAttribute.numValues();
nominalNodes = new TreeNode[numNominalValues];
for (int i = 0; i < numNominalValues; i++) {
ArrayList list = new ArrayList();
for (Instance instance : enabled) {
if (instance.value(splittingAttribute) == i) {
list.add(instance);
}
}
nominalNodes[i] = new TreeNode(this, list);
}
} else {
// Create a binary split for a numeric attribute
ArrayList left = new ArrayList();
ArrayList right = new ArrayList();
for (Instance instance : enabled) {
if (instance.value(splittingAttribute) < split.splitPoint) {
left.add(instance);
} else {
right.add(instance);
}
}
this.left = new TreeNode(this, left);
this.right = new TreeNode(this, right);
if (debug) {
System.out.println(left.size() + " went left and " + right.size()
+ " went right");
}
}
}
/**
* Code to cover special case where impure leaf node needs to be created
* because data cannot be split any further.
*/
private void makeImpureLeafNode(HashMap instanceBags,
AlgorithmConfiguration settings, boolean debug) {
SufficientStatistics ss;
if (!settings.useBagStatistics) {
ss = new SufficientInstanceStatistics(instances, instanceBags);
} else {
ss = new SufficientBagStatistics(instances, instanceBags,
settings.bagCountMultiplier);
}
double bepp = BEPP.GetBEPP(ss.totalCountRight(), ss.positiveCountRight(),
settings.kBEPPConstant, settings.unbiasedEstimate);
makeLeafNode(ss.positiveCountRight() / ss.totalCountRight() > 0.5);
if (debug) {
System.out.println(bepp > 0.5);
}
// Deactivate the related instances if we decided this
// is a positive instance
if (!isPositiveLeaf()) {
return;
}
ArrayList deactivated = new ArrayList();
deactivateRelatedInstances(instanceBags, deactivated);
// Print out any deactivated bags if we're debugging
if (deactivated.size() > 0 && debug) {
Bag.printDeactivatedInstances(deactivated);
}
}
/**
* Recursively renders this node and its branches as a tabbed out tree
*
* @return a string containing the node and its children, tabbed to the given
* depth
*/
public String render(int depth, HashMap instanceBags) {
String s = "";
int pos = 0;
for (Instance i : instances) {
Bag bag = instanceBags.get(i);
if (bag.isPositive()) {
pos++;
}
}
s += instances.size() + " [" + pos + " / " + (instances.size() - pos) + "]";
if (isLeafNode()) {
s += isPositiveLeaf() ? " (+)" : " (-)";
}
if (!isLeafNode() && split != null) {
if (split.attribute.isNominal()) {
for (int i = 0; i < nominalNodes.length; i++) {
if (nominalNodes[i] != null) {
// New line, tab it out.
s += "\n";
for (int t = 0; t < depth; t++) {
s += "|\t";
}
s += split.attribute.name() + " = " + split.attribute.value(i)
+ " : ";
s += nominalNodes[i].render(depth + 1, instanceBags);
}
}
} else {
if (left != null) {
// New line, tab it out.
s += "\n";
for (int i = 0; i < depth; i++) {
s += "|\t";
}
s += split.attribute.name() + " <= "
+ String.format("%.4g", split.splitPoint) + " : ";
s += left.render(depth + 1, instanceBags);
}
if (right != null) {
// New line, tab it out.
s += "\n";
for (int i = 0; i < depth; i++) {
s += "|\t";
}
s += split.attribute.name() + " > "
+ String.format("%.4g", split.splitPoint) + " : ";
s += right.render(depth + 1, instanceBags);
}
}
}
return s;
}
/**
* Recursively removes all branches that do not contain a positive leaf. Used
* to create partial tree for MIRI rule learner.
*
* @return true if a positive leaf was encountered during the trim
*/
public boolean trimNegativeBranches() {
boolean positive = false;
if (nominalNodes != null) {
// Consider nominal split
for (int i = 0; i < nominalNodes.length; i++) {
TreeNode child = nominalNodes[i];
if (child.isPositiveLeaf()) {
positive = true;
} else if (child.trimNegativeBranches()) {
positive = true;
} else {
nominalNodes[i] = null;
}
}
} else {
// Consider numeric split
if (left != null) {
if (left.isPositiveLeaf()) {
positive = true;
} else if (left.trimNegativeBranches()) {
positive = true;
} else {
left = null;
}
}
if (right != null) {
if (right.isPositiveLeaf()) {
positive = true;
} else if (right.trimNegativeBranches()) {
positive = true;
} else {
right = null;
}
}
}
return positive;
}
}