// [JAR search and dependency download from the Maven repository]
// moa.classifiers.trees.HoeffdingTree Maven / Gradle / Ivy
// Go to download
// Show more of this group. Show more artifacts with this name.
// Show all versions of moa. Show documentation.
// Show all versions of moa. Show documentation.
// Massive On-line Analysis is an environment for massive data mining. MOA
// provides a framework for data stream mining and includes tools for evaluation
// and a collection of machine learning algorithms. It is related to the WEKA
// project and is also written in Java, while scaling to more demanding problems.
// The newest version!
/*
* HoeffdingTree.java
* Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package moa.classifiers.trees;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import moa.AbstractMOAObject;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.bayes.NaiveBayes;
import moa.classifiers.core.AttributeSplitSuggestion;
import moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver;
import moa.classifiers.core.attributeclassobservers.NullAttributeClassObserver;
import moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver;
import moa.classifiers.core.conditionaltests.InstanceConditionalTest;
import moa.classifiers.core.splitcriteria.SplitCriterion;
import moa.core.AutoExpandVector;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.SizeOf;
import moa.core.StringUtils;
import moa.core.Utils;
import moa.options.ClassOption;
import com.yahoo.labs.samoa.instances.Instance;
/**
* Hoeffding Tree or VFDT.
*
* A Hoeffding tree is an incremental, anytime decision tree induction algorithm
* that is capable of learning from massive data streams, assuming that the
* distribution generating examples does not change over time. Hoeffding trees
* exploit the fact that a small sample can often be enough to choose an optimal
* splitting attribute. This idea is supported mathematically by the Hoeffding
* bound, which quantifies the number of observations (in our case, examples)
* needed to estimate some statistics within a prescribed precision (in our
* case, the goodness of an attribute). A theoretically appealing feature
* of Hoeffding Trees not shared by other incremental decision tree learners is
* that it has sound guarantees of performance. Using the Hoeffding bound one
* can show that its output is asymptotically nearly identical to that of a
* non-incremental learner using infinitely many examples. See for details:
*
* G. Hulten, L. Spencer, and P. Domingos. Mining time-changing data streams.
* In KDD’01, pages 97–106, San Francisco, CA, 2001. ACM Press.
*
* <p>Parameters:</p>
* <ul>
* <li>-m : Maximum memory consumed by the tree</li>
* <li>-n : Numeric estimator to use:
* <ul>
* <li>Gaussian approximation evaluating 10 splitpoints</li>
* <li>Gaussian approximation evaluating 100 splitpoints</li>
* <li>Greenwald-Khanna quantile summary with 10 tuples</li>
* <li>Greenwald-Khanna quantile summary with 100 tuples</li>
* <li>Greenwald-Khanna quantile summary with 1000 tuples</li>
* <li>VFML method with 10 bins</li>
* <li>VFML method with 100 bins</li>
* <li>VFML method with 1000 bins</li>
* <li>Exhaustive binary tree</li>
* </ul>
* </li>
* <li>-e : How many instances between memory consumption checks</li>
* <li>-g : The number of instances a leaf should observe between split
* attempts</li>
* <li>-s : Split criterion to use. Example : InfoGainSplitCriterion</li>
* <li>-c : The allowable error in split decision, values closer to 0 will
* take longer to decide</li>
* <li>-t : Threshold below which a split will be forced to break ties</li>
* <li>-b : Only allow binary splits</li>
* <li>-z : Stop growing as soon as memory limit is hit</li>
* <li>-r : Disable poor attributes</li>
* <li>-p : Disable pre-pruning</li>
* <li>-l : Leaf prediction to use: MajorityClass (MC), Naive Bayes (NB) or
* NaiveBayes adaptive (NBAdaptive).</li>
* <li>-q : The number of instances a leaf should observe before permitting
* Naive Bayes</li>
* </ul>
*
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class HoeffdingTree extends AbstractClassifier implements MultiClassClassifier,
CapabilitiesHandler {
private static final long serialVersionUID = 1L;
/**
 * Gives the short, human-readable description of this learner.
 *
 * @return the fixed purpose text for the Hoeffding tree
 */
@Override
public String getPurposeString() {
    final String purpose = "Hoeffding Tree or VFDT.";
    return purpose;
}
/**
 * Option -m: memory budget for the tree. enforceTrackerLimit() and
 * estimateModelByteSizes() compare their byte-size figures against this value.
 */
public IntOption maxByteSizeOption = new IntOption("maxByteSize", 'm',
"Maximum memory consumed by the tree.", 33554432, 0,
Integer.MAX_VALUE);
/*
* Legacy multi-choice form of the numeric estimator option, kept for
* reference only; the ClassOption declared right after it is the live one.
*
* public MultiChoiceOption numericEstimatorOption = new MultiChoiceOption(
* "numericEstimator", 'n', "Numeric estimator to use.", new String[]{
* "GAUSS10", "GAUSS100", "GK10", "GK100", "GK1000", "VFML10", "VFML100",
* "VFML1000", "BINTREE"}, new String[]{ "Gaussian approximation evaluating
* 10 splitpoints", "Gaussian approximation evaluating 100 splitpoints",
* "Greenwald-Khanna quantile summary with 10 tuples", "Greenwald-Khanna
* quantile summary with 100 tuples", "Greenwald-Khanna quantile summary
* with 1000 tuples", "VFML method with 10 bins", "VFML method with 100
* bins", "VFML method with 1000 bins", "Exhaustive binary tree"}, 0);
*/
/** Option -n: observer class used for numeric attributes (see newNumericClassObserver()). */
public ClassOption numericEstimatorOption = new ClassOption("numericEstimator",
'n', "Numeric estimator to use.", NumericAttributeClassObserver.class,
"GaussianNumericAttributeClassObserver");
/** Option -d: observer class used for nominal attributes (see newNominalClassObserver()). */
public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator",
'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class,
"NominalAttributeClassObserver");
/**
 * Option -e: period of the memory check; trainOnInstanceImpl() calls
 * estimateModelByteSizes() when the training weight seen is a multiple of it.
 */
public IntOption memoryEstimatePeriodOption = new IntOption(
"memoryEstimatePeriod", 'e',
"How many instances between memory consumption checks.", 1000000,
0, Integer.MAX_VALUE);
/**
 * Option -g: weight a leaf must accumulate since its last split evaluation
 * before trainOnInstanceImpl() calls attemptToSplit() again.
 */
public IntOption gracePeriodOption = new IntOption(
"gracePeriod",
'g',
"The number of instances a leaf should observe between split attempts.",
200, 0, Integer.MAX_VALUE);
/** Option -s: split criterion class used by attemptToSplit() to rank candidate splits. */
public ClassOption splitCriterionOption = new ClassOption("splitCriterion",
's', "Split criterion to use.", SplitCriterion.class,
"InfoGainSplitCriterion");
/** Option -c: confidence value passed to computeHoeffdingBound() in attemptToSplit(). */
public FloatOption splitConfidenceOption = new FloatOption(
"splitConfidence",
'c',
"The allowable error in split decision, values closer to 0 will take longer to decide.",
0.0000001, 0.0, 1.0);
/** Option -t: when the Hoeffding bound drops below this value attemptToSplit() splits anyway. */
public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
't', "Threshold below which a split will be forced to break ties.",
0.05, 0.0, 1.0);
/** Option -b: flag forwarded to the attribute observers when asking for split suggestions. */
public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b',
"Only allow binary splits.");
/** Option -z: when set, enforceTrackerLimit() disables growth instead of (de)activating leaves. */
public FlagOption stopMemManagementOption = new FlagOption(
"stopMemManagement", 'z',
"Stop growing as soon as memory limit is hit.");
/**
 * Option -r: when set, attemptToSplit() disables attributes whose merit trails the best
 * by more than the Hoeffding bound. NOTE: resetLearningImpl() sets this field to null
 * when a Naive Bayes leaf predictor is selected, so readers must null-check it.
 */
public FlagOption removePoorAttsOption = new FlagOption("removePoorAtts",
'r', "Disable poor attributes.");
/** Option -p: when set, the "no split" candidate is not added to the split suggestions. */
public FlagOption noPrePruneOption = new FlagOption("noPrePrune", 'p',
"Disable pre-pruning.");
/**
 * Outcome of routing an instance through the tree: the node that was reached
 * plus the split node it hangs from and the branch index under that parent.
 */
public static class FoundNode {

    /** Node reached; {@code null} when the selected branch has no child yet. */
    public Node node;

    /** Split node that owns {@link #node}; {@code null} when searching from the root. */
    public SplitNode parent;

    /** Index of the branch of {@link #parent} that leads to {@link #node}. */
    public int parentBranch;

    /**
     * @param node         node reached (may be {@code null})
     * @param parent       split node above it (may be {@code null})
     * @param parentBranch branch index under {@code parent}
     */
    public FoundNode(Node node, SplitNode parent, int parentBranch) {
        this.parentBranch = parentBranch;
        this.parent = parent;
        this.node = node;
    }
}
/**
 * Base class of all tree nodes. A plain {@code Node} behaves as a leaf that
 * only stores the class distribution observed at its position.
 */
public static class Node extends AbstractMOAObject {

    private static final long serialVersionUID = 1L;

    /** Per-class weight observed at this node. */
    protected DoubleVector observedClassDistribution;

    /**
     * @param classObservations initial per-class weights
     */
    public Node(double[] classObservations) {
        this.observedClassDistribution = new DoubleVector(classObservations);
    }

    /** @return size of this node and its class distribution, without any children */
    public long calcByteSize() {
        return SizeOf.sizeOf(this) + SizeOf.fullSizeOf(this.observedClassDistribution);
    }

    /** @return size of this node including its subtree; a plain node has no subtree */
    public long calcByteSizeIncludingSubtree() {
        return calcByteSize();
    }

    /** @return {@code true}; subclasses that branch override this */
    public boolean isLeaf() {
        return true;
    }

    /**
     * Routes an instance to the leaf responsible for it. A plain node is its
     * own destination.
     *
     * @param inst         instance being routed (unused here)
     * @param parent       split node above this node, or {@code null}
     * @param parentBranch branch of {@code parent} leading here
     * @return this node wrapped together with its parent information
     */
    public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
            int parentBranch) {
        return new FoundNode(this, parent, parentBranch);
    }

    /** @return a copy of the class distribution observed at this node */
    public double[] getObservedClassDistribution() {
        return this.observedClassDistribution.getArrayCopy();
    }

    /** @return for a plain node, a copy of its own class distribution */
    public double[] getObservedClassDistributionAtLeavesReachableThroughThisNode() {
        return this.observedClassDistribution.getArrayCopy();
    }

    /**
     * @param inst instance to classify (unused here)
     * @param ht   owning tree (unused here)
     * @return a copy of the observed class distribution, used as votes
     */
    public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
        return this.observedClassDistribution.getArrayCopy();
    }

    /** @return {@code true} when fewer than two classes have non-zero weight */
    public boolean observedClassDistributionIsPure() {
        return this.observedClassDistribution.numNonZeroEntries() < 2;
    }

    /**
     * Appends a one-line textual description of this leaf.
     *
     * @param ht     owning tree, used for class names; may be {@code null}
     *               (as when called from {@link #getDescription}), in which
     *               case generic placeholders are printed instead
     * @param out    buffer to append to
     * @param indent indentation passed to {@code StringUtils.appendIndented}
     */
    public void describeSubtree(HoeffdingTree ht, StringBuilder out,
            int indent) {
        int majorityIndex = this.observedClassDistribution.maxIndex();
        String className;
        String classLabel;
        int numValues;
        if (ht != null) {
            className = ht.getClassNameString();
            classLabel = ht.getClassLabelString(majorityIndex);
            // Width of the printed weight vector follows the root's distribution
            // when a root exists; otherwise fall back to this node's own.
            numValues = (ht.treeRoot != null)
                    ? ht.treeRoot.observedClassDistribution.numValues()
                    : this.observedClassDistribution.numValues();
        } else {
            // No tree available: previously this path threw a NullPointerException.
            className = "[class]";
            classLabel = "<class " + (majorityIndex + 1) + ">";
            numValues = this.observedClassDistribution.numValues();
        }
        StringUtils.appendIndented(out, indent, "Leaf ");
        out.append(className);
        out.append(" = ");
        out.append(classLabel);
        out.append(" weights: ");
        this.observedClassDistribution.getSingleLineDescription(out, numValues);
        StringUtils.appendNewline(out);
    }

    /** @return depth of the subtree below this node; 0 for a plain node */
    public int subtreeDepth() {
        return 0;
    }

    /**
     * @return total observed weight minus the weight of the majority class,
     *         or 0 when nothing has been observed
     */
    public double calculatePromise() {
        double totalSeen = this.observedClassDistribution.sumOfValues();
        if (totalSeen > 0.0) {
            int majorityIndex = this.observedClassDistribution.maxIndex();
            return totalSeen - this.observedClassDistribution.getValue(majorityIndex);
        }
        return 0.0;
    }

    /**
     * Describes this node without access to the owning tree.
     *
     * @param sb     buffer to append to
     * @param indent indentation level
     */
    @Override
    public void getDescription(StringBuilder sb, int indent) {
        describeSubtree(null, sb, indent);
    }
}
/**
 * Internal decision node: holds a conditional test and one child slot per
 * branch of that test.
 */
public static class SplitNode extends Node {

    private static final long serialVersionUID = 1L;

    /** Test that decides which branch an instance follows. */
    protected InstanceConditionalTest splitTest;

    /** Children indexed by branch; slots may be {@code null}. */
    protected AutoExpandVector<Node> children;

    /**
     * @param splitTest         test applied at this node
     * @param classObservations class distribution observed before the split
     * @param size              initial capacity of the child vector
     */
    public SplitNode(InstanceConditionalTest splitTest,
            double[] classObservations, int size) {
        super(classObservations);
        this.splitTest = splitTest;
        this.children = new AutoExpandVector<Node>(size);
    }

    /**
     * @param splitTest         test applied at this node
     * @param classObservations class distribution observed before the split
     */
    public SplitNode(InstanceConditionalTest splitTest,
            double[] classObservations) {
        super(classObservations);
        this.splitTest = splitTest;
        this.children = new AutoExpandVector<Node>();
    }

    @Override
    public long calcByteSize() {
        return super.calcByteSize()
                + SizeOf.sizeOf(this.children) + SizeOf.fullSizeOf(this.splitTest);
    }

    /** @return size of this node plus the sizes of all non-null children, recursively */
    @Override
    public long calcByteSizeIncludingSubtree() {
        long byteSize = calcByteSize();
        for (Node child : this.children) {
            if (child != null) {
                byteSize += child.calcByteSizeIncludingSubtree();
            }
        }
        return byteSize;
    }

    /**
     * @return element-wise sum of the distributions reported by all non-null
     *         children, starting from a zero vector as long as this node's
     *         own distribution
     */
    @Override
    public double[] getObservedClassDistributionAtLeavesReachableThroughThisNode() {
        DoubleVector sumObservedClassDistributionAtLeaves =
                new DoubleVector(new double[this.getObservedClassDistribution().length]);
        for (Node childNode : this.children) {
            if (childNode != null) {
                double[] childDist = childNode.getObservedClassDistributionAtLeavesReachableThroughThisNode();
                sumObservedClassDistributionAtLeaves.addValues(childDist);
            }
        }
        return sumObservedClassDistributionAtLeaves.getArrayCopy();
    }

    /** @return the live child vector (not a copy) */
    public AutoExpandVector<Node> getChildren() {
        return children;
    }

    /** @return the test applied at this node */
    public InstanceConditionalTest getSplitTest() {
        return splitTest;
    }

    /** @return current size of the child vector, including {@code null} slots */
    public int numChildren() {
        return this.children.size();
    }

    /**
     * Stores a child under the given branch.
     *
     * @param index branch index
     * @param child node to attach
     * @throws IndexOutOfBoundsException if the test reports a non-negative
     *         branch limit and {@code index} reaches or exceeds it
     */
    public void setChild(int index, Node child) {
        int maxBranches = this.splitTest.maxBranches();
        if ((maxBranches >= 0) && (index >= maxBranches)) {
            throw new IndexOutOfBoundsException();
        }
        this.children.set(index, child);
    }

    /**
     * @param index branch index
     * @return child stored under that branch
     */
    public Node getChild(int index) {
        return this.children.get(index);
    }

    /**
     * @param inst instance to test
     * @return branch index chosen by the split test for the instance
     */
    public int instanceChildIndex(Instance inst) {
        return this.splitTest.branchForInstance(inst);
    }

    @Override
    public boolean isLeaf() {
        return false;
    }

    /**
     * Routes the instance down the matching branch. When the branch has no
     * child yet, the result carries a {@code null} node with this node as
     * parent; when the test yields a negative branch, this node itself is
     * returned.
     */
    @Override
    public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
            int parentBranch) {
        int childIndex = instanceChildIndex(inst);
        if (childIndex < 0) {
            return new FoundNode(this, parent, parentBranch);
        }
        Node child = getChild(childIndex);
        if (child == null) {
            return new FoundNode(null, this, childIndex);
        }
        return child.filterInstanceToLeaf(inst, this, childIndex);
    }

    /**
     * Appends an "if &lt;condition&gt;:" line per populated branch followed by
     * the description of that child, indented two further columns.
     *
     * @param ht     owning tree supplying the model context; may be
     *               {@code null}, in which case a {@code null} context is
     *               handed to the split test
     * @param out    buffer to append to
     * @param indent indentation level
     */
    @Override
    public void describeSubtree(HoeffdingTree ht, StringBuilder out,
            int indent) {
        for (int branch = 0; branch < numChildren(); branch++) {
            Node child = getChild(branch);
            if (child != null) {
                StringUtils.appendIndented(out, indent, "if ");
                // NOTE(review): assumes describeConditionForBranch tolerates a
                // null context; previously a null ht failed here with an NPE.
                out.append(this.splitTest.describeConditionForBranch(branch,
                        (ht != null) ? ht.getModelContext() : null));
                out.append(": ");
                StringUtils.appendNewline(out);
                child.describeSubtree(ht, out, indent + 2);
            }
        }
    }

    /** @return one plus the deepest subtree depth among non-null children */
    @Override
    public int subtreeDepth() {
        int maxChildDepth = 0;
        for (Node child : this.children) {
            if (child != null) {
                maxChildDepth = Math.max(maxChildDepth, child.subtreeDepth());
            }
        }
        return maxChildDepth + 1;
    }
}
/**
 * Leaf that can be updated with training instances.
 */
public abstract static class LearningNode extends Node {

    private static final long serialVersionUID = 1L;

    /**
     * @param initialClassObservations per-class weights the leaf starts with
     */
    public LearningNode(double[] initialClassObservations) {
        super(initialClassObservations);
    }

    /**
     * Updates this leaf with one training instance.
     *
     * @param inst training instance
     * @param ht   tree that owns this leaf
     */
    public abstract void learnFromInstance(Instance inst, HoeffdingTree ht);
}
/**
 * Leaf that only keeps class counts; it holds no attribute observers.
 */
public static class InactiveLearningNode extends LearningNode {

    private static final long serialVersionUID = 1L;

    /**
     * @param initialClassObservations per-class weights the leaf starts with
     */
    public InactiveLearningNode(double[] initialClassObservations) {
        super(initialClassObservations);
    }

    /**
     * Adds the instance's weight to the counter of its class.
     */
    @Override
    public void learnFromInstance(Instance inst, HoeffdingTree ht) {
        final int classIndex = (int) inst.classValue();
        final double weight = inst.weight();
        this.observedClassDistribution.addToValue(classIndex, weight);
    }
}
/**
 * Leaf that, besides class counts, maintains one attribute observer per
 * attribute so that split candidates can be evaluated.
 */
public static class ActiveLearningNode extends LearningNode {

    private static final long serialVersionUID = 1L;

    /** Total weight seen when a split was last evaluated at this leaf. */
    protected double weightSeenAtLastSplitEvaluation;

    /** Observers indexed by model attribute index; slots are created lazily. */
    protected AutoExpandVector<AttributeClassObserver> attributeObservers =
            new AutoExpandVector<AttributeClassObserver>();

    /** Whether the observer vector has been sized from a first instance. */
    protected boolean isInitialized;

    /**
     * @param initialClassObservations per-class weights the leaf starts with
     */
    public ActiveLearningNode(double[] initialClassObservations) {
        super(initialClassObservations);
        this.weightSeenAtLastSplitEvaluation = getWeightSeen();
        this.isInitialized = false;
    }

    @Override
    public long calcByteSize() {
        return super.calcByteSize()
                + SizeOf.fullSizeOf(this.attributeObservers);
    }

    /**
     * Updates the class counts and feeds every attribute value to its
     * observer, creating observers on first use.
     *
     * @param inst training instance
     * @param ht   owning tree, used as factory for new observers
     */
    @Override
    public void learnFromInstance(Instance inst, HoeffdingTree ht) {
        if (!this.isInitialized) {
            // Size the observer vector once, from the first instance seen.
            this.attributeObservers =
                    new AutoExpandVector<AttributeClassObserver>(inst.numAttributes());
            this.isInitialized = true;
        }
        int classIndex = (int) inst.classValue();
        double weight = inst.weight();
        this.observedClassDistribution.addToValue(classIndex, weight);
        // numAttributes() - 1: the loop covers every attribute slot but one
        // (presumably the class attribute — confirm against Instance).
        for (int i = 0; i < inst.numAttributes() - 1; i++) {
            int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
            AttributeClassObserver obs = this.attributeObservers.get(i);
            if (obs == null) {
                obs = inst.attribute(instAttIndex).isNominal()
                        ? ht.newNominalClassObserver()
                        : ht.newNumericClassObserver();
                this.attributeObservers.set(i, obs);
            }
            obs.observeAttributeClass(inst.value(instAttIndex), classIndex, weight);
        }
    }

    /** @return sum of the observed class weights */
    public double getWeightSeen() {
        return this.observedClassDistribution.sumOfValues();
    }

    /** @return weight recorded at the last split evaluation */
    public double getWeightSeenAtLastSplitEvaluation() {
        return this.weightSeenAtLastSplitEvaluation;
    }

    /** @param weight weight to record as seen at the last split evaluation */
    public void setWeightSeenAtLastSplitEvaluation(double weight) {
        this.weightSeenAtLastSplitEvaluation = weight;
    }

    /**
     * Collects the best split suggestion of every attribute observer and,
     * unless pre-pruning is disabled on the tree, a "no split" candidate.
     *
     * @param criterion split criterion used to score candidates
     * @param ht        owning tree, consulted for its options
     * @return the collected suggestions, unsorted
     */
    public AttributeSplitSuggestion[] getBestSplitSuggestions(
            SplitCriterion criterion, HoeffdingTree ht) {
        List<AttributeSplitSuggestion> bestSuggestions =
                new LinkedList<AttributeSplitSuggestion>();
        double[] preSplitDist = this.observedClassDistribution.getArrayCopy();
        if (!ht.noPrePruneOption.isSet()) {
            // add null split as an option
            bestSuggestions.add(new AttributeSplitSuggestion(null,
                    new double[0][], criterion.getMeritOfSplit(
                            preSplitDist,
                            new double[][]{preSplitDist})));
        }
        for (int i = 0; i < this.attributeObservers.size(); i++) {
            AttributeClassObserver obs = this.attributeObservers.get(i);
            if (obs != null) {
                AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion,
                        preSplitDist, i, ht.binarySplitsOption.isSet());
                if (bestSuggestion != null) {
                    bestSuggestions.add(bestSuggestion);
                }
            }
        }
        return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]);
    }

    /**
     * Replaces the observer of an attribute with a
     * {@link NullAttributeClassObserver}.
     *
     * @param attIndex model attribute index
     */
    public void disableAttribute(int attIndex) {
        this.attributeObservers.set(attIndex,
                new NullAttributeClassObserver());
    }
}
/** Root of the tree; null until the first training instance arrives or after a reset. */
protected Node treeRoot;
/** Number of split nodes; incremented in attemptToSplit() when a leaf is replaced by a split. */
protected int decisionNodeCount;
/** Number of leaves currently counted as active learning nodes. */
protected int activeLeafNodeCount;
/** Number of leaves currently counted as inactive learning nodes. */
protected int inactiveLeafNodeCount;
/** Average byte size of an inactive leaf, refreshed by estimateModelByteSizes(). */
protected double inactiveLeafByteSizeEstimate;
/** Average byte size of an active leaf, refreshed by estimateModelByteSizes(). */
protected double activeLeafByteSizeEstimate;
/** Ratio of measured model size to the leaf-based size estimate; starts at 1.0 on reset. */
protected double byteSizeEstimateOverheadFraction;
/** Cleared by enforceTrackerLimit() when stopMemManagementOption is set; gates split attempts. */
protected boolean growthAllowed;
/**
 * Measures this learner's own size plus, when a tree exists, the size of the
 * whole tree.
 *
 * @return size in bytes as reported by {@code SizeOf} and the node hierarchy
 */
public long calcByteSize() {
    final long ownSize = SizeOf.sizeOf(this);
    if (this.treeRoot == null) {
        return ownSize;
    }
    return ownSize + this.treeRoot.calcByteSizeIncludingSubtree();
}
/**
 * @return the sum of the decision-node, active-leaf and inactive-leaf counters
 */
public int getNodeCount() {
    int total = this.decisionNodeCount;
    total += this.activeLeafNodeCount;
    total += this.inactiveLeafNodeCount;
    return total;
}
/**
 * @return the current root node, or {@code null} when no tree has been built
 */
public Node getTreeRoot() {
    final Node root = this.treeRoot;
    return root;
}
/**
 * Reports the model size by delegating to {@link #calcByteSize()}.
 *
 * @return size in bytes
 */
@Override
public long measureByteSize() {
    final long bytes = calcByteSize();
    return bytes;
}
/**
 * Discards the tree and restores all counters and size estimates to their
 * starting values.
 */
@Override
public void resetLearningImpl() {
    // Drop the model and allow growth again.
    this.treeRoot = null;
    this.growthAllowed = true;

    // Node bookkeeping.
    this.decisionNodeCount = 0;
    this.activeLeafNodeCount = 0;
    this.inactiveLeafNodeCount = 0;

    // Memory estimates.
    this.activeLeafByteSizeEstimate = 0.0;
    this.inactiveLeafByteSizeEstimate = 0.0;
    this.byteSizeEstimateOverheadFraction = 1.0;

    // Any leaf predictor other than the first choice ("MC") drops the
    // poor-attribute option; attemptToSplit() null-checks it accordingly.
    final boolean nonMajorityClassLeaves = this.leafpredictionOption.getChosenIndex() > 0;
    if (nonMajorityClassLeaves) {
        this.removePoorAttsOption = null;
    }
}
/**
 * Trains the tree on one instance: routes it to a leaf (creating the root or
 * a missing leaf on demand), updates that leaf, attempts a split once the
 * grace period has elapsed, and periodically refreshes the memory estimates.
 *
 * @param inst training instance
 */
@Override
public void trainOnInstanceImpl(Instance inst) {
    if (this.treeRoot == null) {
        this.treeRoot = newLearningNode();
        this.activeLeafNodeCount = 1;
    }

    final FoundNode located = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
    Node target = located.node;
    if (target == null) {
        // The chosen branch has no child yet: grow a fresh leaf there.
        target = newLearningNode();
        located.parent.setChild(located.parentBranch, target);
        this.activeLeafNodeCount++;
    }

    if (target instanceof LearningNode) {
        ((LearningNode) target).learnFromInstance(inst, this);
        if (this.growthAllowed && (target instanceof ActiveLearningNode)) {
            final ActiveLearningNode active = (ActiveLearningNode) target;
            final double seenNow = active.getWeightSeen();
            final double seenSinceLastAttempt =
                    seenNow - active.getWeightSeenAtLastSplitEvaluation();
            if (seenSinceLastAttempt >= this.gracePeriodOption.getValue()) {
                attemptToSplit(active, located.parent, located.parentBranch);
                active.setWeightSeenAtLastSplitEvaluation(seenNow);
            }
        }
    }

    final boolean memoryCheckDue =
            this.trainingWeightSeenByModel % this.memoryEstimatePeriodOption.getValue() == 0;
    if (memoryCheckDue) {
        estimateModelByteSizes();
    }
}
/**
 * Computes class votes for an instance.
 *
 * @param inst instance to classify
 * @return votes from the node the instance is routed to (or from its parent
 *         split node when the branch is empty); an all-zero array sized by
 *         the dataset's class count when no tree exists yet
 */
@Override
public double[] getVotesForInstance(Instance inst) {
    if (this.treeRoot == null) {
        return new double[inst.dataset().numClasses()];
    }
    final FoundNode located = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
    final Node voter = (located.node != null) ? located.node : located.parent;
    return voter.getClassVotes(inst, this);
}
/**
 * Builds the model measurements: node and leaf counts, tree depth and the
 * current memory estimates.
 *
 * @return the measurements, in a fixed order
 */
@Override
protected Measurement[] getModelMeasurementsImpl() {
    final int leafCount = this.activeLeafNodeCount + this.inactiveLeafNodeCount;
    final int nodeCount = this.decisionNodeCount + leafCount;

    final Measurement[] measurements = new Measurement[7];
    measurements[0] = new Measurement("tree size (nodes)", nodeCount);
    measurements[1] = new Measurement("tree size (leaves)", leafCount);
    measurements[2] = new Measurement("active learning leaves",
            this.activeLeafNodeCount);
    measurements[3] = new Measurement("tree depth", measureTreeDepth());
    measurements[4] = new Measurement("active leaf byte size estimate",
            this.activeLeafByteSizeEstimate);
    measurements[5] = new Measurement("inactive leaf byte size estimate",
            this.inactiveLeafByteSizeEstimate);
    measurements[6] = new Measurement("byte size estimate overhead",
            this.byteSizeEstimateOverheadFraction);
    return measurements;
}
/**
 * @return depth of the tree as reported by the root, or 0 when there is no root
 */
public int measureTreeDepth() {
    return (this.treeRoot == null) ? 0 : this.treeRoot.subtreeDepth();
}
/**
 * Appends a textual description of the tree.
 *
 * @param out    buffer to append to
 * @param indent indentation level
 */
@Override
public void getModelDescription(StringBuilder out, int indent) {
    if (this.treeRoot == null) {
        // No tree has been built yet; previously this dereferenced null.
        return;
    }
    this.treeRoot.describeSubtree(this, out, indent);
}
/**
 * @return {@code false}; this learner reports itself as not randomizable
 */
@Override
public boolean isRandomizable() {
    final boolean randomizable = false;
    return randomizable;
}
/**
 * Evaluates {@code sqrt(range^2 * ln(1 / confidence) / (2 * n))}.
 *
 * @param range      value range of the quantity being estimated
 * @param confidence confidence parameter whose reciprocal is passed to the logarithm
 * @param n          number (weight) of observations
 * @return the bound computed from the three arguments
 */
public static double computeHoeffdingBound(double range, double confidence,
        double n) {
    final double squaredRange = range * range;
    final double logTerm = Math.log(1.0 / confidence);
    final double numerator = squaredRange * logTerm;
    final double denominator = 2.0 * n;
    return Math.sqrt(numerator / denominator);
}
/**
 * Factory for split nodes with a pre-sized child vector.
 * Procedure added for Hoeffding Adaptive Trees (ADWIN).
 *
 * @param splitTest         test applied at the new node
 * @param classObservations class distribution observed before the split
 * @param size              initial capacity of the child vector
 * @return a new {@link SplitNode}
 */
protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
        double[] classObservations, int size) {
    final SplitNode created = new SplitNode(splitTest, classObservations, size);
    return created;
}
/**
 * Factory for split nodes with a default-sized child vector.
 *
 * @param splitTest         test applied at the new node
 * @param classObservations class distribution observed before the split
 * @return a new {@link SplitNode}
 */
protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
        double[] classObservations) {
    final SplitNode created = new SplitNode(splitTest, classObservations);
    return created;
}
/**
 * Creates an observer for a nominal attribute by copying the prepared
 * instance configured through {@code nominalEstimatorOption}.
 *
 * @return a fresh copy of the configured nominal observer
 */
protected AttributeClassObserver newNominalClassObserver() {
    final Object prepared = getPreparedClassOption(this.nominalEstimatorOption);
    final AttributeClassObserver template = (AttributeClassObserver) prepared;
    return (AttributeClassObserver) template.copy();
}
/**
 * Creates an observer for a numeric attribute by copying the prepared
 * instance configured through {@code numericEstimatorOption}.
 *
 * @return a fresh copy of the configured numeric observer
 */
protected AttributeClassObserver newNumericClassObserver() {
    final Object prepared = getPreparedClassOption(this.numericEstimatorOption);
    final AttributeClassObserver template = (AttributeClassObserver) prepared;
    return (AttributeClassObserver) template.copy();
}
/**
 * Evaluates whether an active leaf should be split and, if so, replaces it
 * either by a split node with fresh children or, when the "no split"
 * candidate wins, by an inactive leaf. Memory limits are enforced after
 * every structural change.
 *
 * @param node        leaf under evaluation
 * @param parent      split node above the leaf, or {@code null} for the root
 * @param parentIndex branch of {@code parent} that holds the leaf
 */
protected void attemptToSplit(ActiveLearningNode node, SplitNode parent,
        int parentIndex) {
    if (node.observedClassDistributionIsPure()) {
        // Only one class seen so far: nothing to separate.
        return;
    }
    SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption);
    AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this);
    // Ascending order: the best candidate ends up last.
    Arrays.sort(bestSplitSuggestions);
    boolean shouldSplit = false;
    if (bestSplitSuggestions.length < 2) {
        shouldSplit = bestSplitSuggestions.length > 0;
    } else {
        double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()),
                this.splitConfidenceOption.getValue(), node.getWeightSeen());
        AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
        AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2];
        // Split when the best clearly beats the runner-up, or when the bound
        // is so tight that the two are considered tied.
        if ((bestSuggestion.merit - secondBestSuggestion.merit > hoeffdingBound)
                || (hoeffdingBound < this.tieThresholdOption.getValue())) {
            shouldSplit = true;
        }
        // The option field may have been nulled by resetLearningImpl().
        if ((this.removePoorAttsOption != null)
                && this.removePoorAttsOption.isSet()) {
            disablePoorAttributes(node, bestSplitSuggestions, hoeffdingBound);
        }
    }
    if (!shouldSplit) {
        return;
    }
    AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1];
    if (splitDecision.splitTest == null) {
        // preprune - null wins
        deactivateLearningNode(node, parent, parentIndex);
    } else {
        int numSplits = splitDecision.numSplits();
        SplitNode newSplit = newSplitNode(splitDecision.splitTest,
                node.getObservedClassDistribution(), numSplits);
        for (int i = 0; i < numSplits; i++) {
            Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i));
            newSplit.setChild(i, newChild);
        }
        // One active leaf becomes a decision node with numSplits active leaves.
        this.activeLeafNodeCount--;
        this.decisionNodeCount++;
        this.activeLeafNodeCount += numSplits;
        if (parent == null) {
            this.treeRoot = newSplit;
        } else {
            parent.setChild(parentIndex, newSplit);
        }
    }
    // manage memory
    enforceTrackerLimit();
}

/**
 * Disables, on the given leaf, every single-attribute candidate whose merit
 * trails the best candidate by more than the Hoeffding bound, unless another
 * candidate on the same attribute is within the bound.
 *
 * @param node             leaf whose observers may be disabled
 * @param sortedSuggestions candidates sorted ascending (best last)
 * @param hoeffdingBound   bound used as the merit-gap threshold
 */
private void disablePoorAttributes(ActiveLearningNode node,
        AttributeSplitSuggestion[] sortedSuggestions, double hoeffdingBound) {
    AttributeSplitSuggestion bestSuggestion = sortedSuggestions[sortedSuggestions.length - 1];
    Set<Integer> poorAtts = new HashSet<Integer>();
    // scan 1 - add any poor to set
    for (AttributeSplitSuggestion suggestion : sortedSuggestions) {
        if (suggestion.splitTest != null) {
            int[] splitAtts = suggestion.splitTest.getAttsTestDependsOn();
            if ((splitAtts.length == 1)
                    && (bestSuggestion.merit - suggestion.merit > hoeffdingBound)) {
                poorAtts.add(Integer.valueOf(splitAtts[0]));
            }
        }
    }
    // scan 2 - remove good ones from set
    for (AttributeSplitSuggestion suggestion : sortedSuggestions) {
        if (suggestion.splitTest != null) {
            int[] splitAtts = suggestion.splitTest.getAttsTestDependsOn();
            if ((splitAtts.length == 1)
                    && (bestSuggestion.merit - suggestion.merit < hoeffdingBound)) {
                poorAtts.remove(Integer.valueOf(splitAtts[0]));
            }
        }
    }
    for (int poorAtt : poorAtts) {
        node.disableAttribute(poorAtt);
    }
}
/**
 * Keeps the estimated model size within {@code maxByteSizeOption}. When
 * inactive leaves exist or the estimate exceeds the limit, either growth is
 * stopped (if {@code stopMemManagementOption} is set) or leaves are ranked
 * by promise and the lowest-ranked ones are deactivated while the remaining
 * ones are (re)activated.
 */
public void enforceTrackerLimit() {
    double estimatedSize = (this.activeLeafNodeCount * this.activeLeafByteSizeEstimate
            + this.inactiveLeafNodeCount * this.inactiveLeafByteSizeEstimate)
            * this.byteSizeEstimateOverheadFraction;
    boolean needsManagement = (this.inactiveLeafNodeCount > 0)
            || (estimatedSize > this.maxByteSizeOption.getValue());
    if (!needsManagement) {
        return;
    }
    if (this.stopMemManagementOption.isSet()) {
        this.growthAllowed = false;
        return;
    }
    FoundNode[] learningNodes = findLearningNodes();
    // Ascending by promise: the least promising leaves come first.
    Arrays.sort(learningNodes, new Comparator<FoundNode>() {
        @Override
        public int compare(FoundNode fn1, FoundNode fn2) {
            return Double.compare(fn1.node.calculatePromise(), fn2.node.calculatePromise());
        }
    });
    // Find the largest number of active leaves that still fits the budget.
    int maxActive = 0;
    while (maxActive < learningNodes.length) {
        maxActive++;
        if ((maxActive * this.activeLeafByteSizeEstimate + (learningNodes.length - maxActive)
                * this.inactiveLeafByteSizeEstimate)
                * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue()) {
            maxActive--;
            break;
        }
    }
    int cutoff = learningNodes.length - maxActive;
    for (int i = 0; i < cutoff; i++) {
        if (learningNodes[i].node instanceof ActiveLearningNode) {
            deactivateLearningNode(
                    (ActiveLearningNode) learningNodes[i].node,
                    learningNodes[i].parent,
                    learningNodes[i].parentBranch);
        }
    }
    for (int i = cutoff; i < learningNodes.length; i++) {
        if (learningNodes[i].node instanceof InactiveLearningNode) {
            activateLearningNode(
                    (InactiveLearningNode) learningNodes[i].node,
                    learningNodes[i].parent,
                    learningNodes[i].parentBranch);
        }
    }
}
/**
 * Refreshes the average active/inactive leaf size estimates and the overhead
 * fraction (measured model size divided by the leaf-based estimate), then
 * enforces the memory limit when the measured size exceeds it.
 */
public void estimateModelByteSizes() {
    FoundNode[] learningNodes = findLearningNodes();
    long totalActiveSize = 0;
    long totalInactiveSize = 0;
    for (FoundNode foundNode : learningNodes) {
        if (foundNode.node instanceof ActiveLearningNode) {
            totalActiveSize += SizeOf.fullSizeOf(foundNode.node);
        } else {
            totalInactiveSize += SizeOf.fullSizeOf(foundNode.node);
        }
    }
    if (totalActiveSize > 0) {
        this.activeLeafByteSizeEstimate = (double) totalActiveSize
                / this.activeLeafNodeCount;
    }
    if (totalInactiveSize > 0) {
        this.inactiveLeafByteSizeEstimate = (double) totalInactiveSize
                / this.inactiveLeafNodeCount;
    }
    long actualModelSize = this.measureByteSize();
    double estimatedModelSize = (this.activeLeafNodeCount
            * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount
            * this.inactiveLeafByteSizeEstimate);
    // With a zero estimate (e.g. no learning leaves yet) the ratio would be
    // Infinity/NaN and poison later comparisons; keep the previous fraction.
    if (estimatedModelSize > 0.0) {
        this.byteSizeEstimateOverheadFraction = actualModelSize
                / estimatedModelSize;
    }
    if (actualModelSize > this.maxByteSizeOption.getValue()) {
        enforceTrackerLimit();
    }
}
/**
 * Replaces every active learning leaf in the tree with an inactive one.
 */
public void deactivateAllLeaves() {
    for (FoundNode candidate : findLearningNodes()) {
        if (!(candidate.node instanceof ActiveLearningNode)) {
            continue;
        }
        deactivateLearningNode((ActiveLearningNode) candidate.node,
                candidate.parent, candidate.parentBranch);
    }
}
/**
 * Swaps an active leaf for an inactive leaf carrying the same class
 * distribution and updates the leaf counters.
 *
 * @param toDeactivate leaf to replace
 * @param parent       split node above the leaf, or {@code null} for the root
 * @param parentBranch branch of {@code parent} holding the leaf
 */
protected void deactivateLearningNode(ActiveLearningNode toDeactivate,
        SplitNode parent, int parentBranch) {
    final double[] distribution = toDeactivate.getObservedClassDistribution();
    final Node replacement = new InactiveLearningNode(distribution);
    if (parent != null) {
        parent.setChild(parentBranch, replacement);
    } else {
        this.treeRoot = replacement;
    }
    this.activeLeafNodeCount--;
    this.inactiveLeafNodeCount++;
}
/**
 * Swaps an inactive leaf for a new learning node carrying the same class
 * distribution and updates the leaf counters.
 *
 * @param toActivate   leaf to replace
 * @param parent       split node above the leaf, or {@code null} for the root
 * @param parentBranch branch of {@code parent} holding the leaf
 */
protected void activateLearningNode(InactiveLearningNode toActivate,
        SplitNode parent, int parentBranch) {
    final double[] distribution = toActivate.getObservedClassDistribution();
    final Node replacement = newLearningNode(distribution);
    if (parent != null) {
        parent.setChild(parentBranch, replacement);
    } else {
        this.treeRoot = replacement;
    }
    this.activeLeafNodeCount++;
    this.inactiveLeafNodeCount--;
}
/**
 * Collects every learning node of the tree together with its parent and
 * branch index.
 *
 * @return the learning nodes found, in traversal order; empty when there is
 *         no tree
 */
protected FoundNode[] findLearningNodes() {
    List<FoundNode> foundList = new LinkedList<FoundNode>();
    findLearningNodes(this.treeRoot, null, -1, foundList);
    return foundList.toArray(new FoundNode[foundList.size()]);
}
/**
 * Recursive helper: appends {@code node} to {@code found} when it is a
 * learning node and descends into every child slot when it is a split node.
 *
 * @param node         subtree root to inspect; {@code null} is ignored
 * @param parent       split node above {@code node}, or {@code null}
 * @param parentBranch branch of {@code parent} leading to {@code node}
 * @param found        accumulator receiving the learning nodes
 */
protected void findLearningNodes(Node node, SplitNode parent,
        int parentBranch, List<FoundNode> found) {
    if (node == null) {
        return;
    }
    if (node instanceof LearningNode) {
        found.add(new FoundNode(node, parent, parentBranch));
    }
    if (node instanceof SplitNode) {
        SplitNode splitNode = (SplitNode) node;
        for (int i = 0; i < splitNode.numChildren(); i++) {
            findLearningNodes(splitNode.getChild(i), splitNode, i,
                    found);
        }
    }
}
/**
 * Option -l: leaf predictor. newLearningNode(double[]) maps chosen index 0 ("MC")
 * to ActiveLearningNode, 1 ("NB") to LearningNodeNB and anything else
 * ("NBAdaptive") to LearningNodeNBAdaptive.
 */
public MultiChoiceOption leafpredictionOption = new MultiChoiceOption(
"leafprediction", 'l', "Leaf prediction to use.", new String[]{
"MC", "NB", "NBAdaptive"}, new String[]{
"Majority class",
"Naive Bayes",
"Naive Bayes Adaptive"}, 2);
/**
 * Option -q: weight a LearningNodeNB leaf must have seen before it answers with
 * Naive Bayes instead of its observed class distribution.
 */
public IntOption nbThresholdOption = new IntOption(
"nbThreshold",
'q',
"The number of instances a leaf should observe before permitting Naive Bayes.",
0, 0, Integer.MAX_VALUE);
/**
 * Active leaf that predicts with Naive Bayes once it has seen enough weight.
 */
public static class LearningNodeNB extends ActiveLearningNode {

    private static final long serialVersionUID = 1L;

    /**
     * @param initialClassObservations per-class weights the leaf starts with
     */
    public LearningNodeNB(double[] initialClassObservations) {
        super(initialClassObservations);
    }

    /**
     * Votes with Naive Bayes when the weight seen reaches the tree's
     * {@code nbThresholdOption}; otherwise defers to the parent class.
     */
    @Override
    public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
        final boolean belowThreshold = getWeightSeen() < ht.nbThresholdOption.getValue();
        if (belowThreshold) {
            return super.getClassVotes(inst, ht);
        }
        return NaiveBayes.doNaiveBayesPrediction(inst,
                this.observedClassDistribution,
                this.attributeObservers);
    }

    /**
     * Intentionally does nothing: poor attributes stay enabled because they
     * are used in the Naive Bayes calculation.
     */
    @Override
    public void disableAttribute(int attIndex) {
        // should not disable poor atts - they are used in NB calc
    }
}
/**
 * Naive Bayes leaf that tracks how often the majority-class and Naive Bayes
 * predictors were right on the training stream and votes with whichever has
 * accumulated more correct weight.
 */
public static class LearningNodeNBAdaptive extends LearningNodeNB {

    private static final long serialVersionUID = 1L;

    /** Weight of training instances the majority-class rule got right. */
    protected double mcCorrectWeight = 0.0;

    /** Weight of training instances Naive Bayes got right. */
    protected double nbCorrectWeight = 0.0;

    /**
     * @param initialClassObservations per-class weights the leaf starts with
     */
    public LearningNodeNBAdaptive(double[] initialClassObservations) {
        super(initialClassObservations);
    }

    /**
     * Scores both predictors on the instance before the leaf is updated with
     * it, then performs the regular update.
     */
    @Override
    public void learnFromInstance(Instance inst, HoeffdingTree ht) {
        final int actualClass = (int) inst.classValue();

        final boolean majorityHit = this.observedClassDistribution.maxIndex() == actualClass;
        if (majorityHit) {
            this.mcCorrectWeight += inst.weight();
        }

        final double[] nbVotes = NaiveBayes.doNaiveBayesPrediction(inst,
                this.observedClassDistribution, this.attributeObservers);
        if (Utils.maxIndex(nbVotes) == actualClass) {
            this.nbCorrectWeight += inst.weight();
        }

        super.learnFromInstance(inst, ht);
    }

    /**
     * @return the observed class distribution when the majority-class rule
     *         has strictly more correct weight, otherwise Naive Bayes votes
     */
    @Override
    public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
        final boolean majorityClassWins = this.mcCorrectWeight > this.nbCorrectWeight;
        return majorityClassWins
                ? this.observedClassDistribution.getArrayCopy()
                : NaiveBayes.doNaiveBayesPrediction(inst,
                        this.observedClassDistribution, this.attributeObservers);
    }
}
/**
 * Creates a learning node that starts with an empty class distribution.
 *
 * @return a new leaf of the kind selected by {@code leafpredictionOption}
 */
protected LearningNode newLearningNode() {
    final double[] noObservations = new double[0];
    return newLearningNode(noObservations);
}
/**
 * Creates a learning node of the kind selected by
 * {@code leafpredictionOption}: index 0 gives a plain active leaf, index 1 a
 * Naive Bayes leaf, any other index an adaptive Naive Bayes leaf.
 *
 * @param initialClassObservations per-class weights the leaf starts with
 * @return the new leaf
 */
protected LearningNode newLearningNode(double[] initialClassObservations) {
    switch (this.leafpredictionOption.getChosenIndex()) {
        case 0: // MC
            return new ActiveLearningNode(initialClassObservations);
        case 1: // NB
            return new LearningNodeNB(initialClassObservations);
        default: // NBAdaptive
            return new LearningNodeNBAdaptive(initialClassObservations);
    }
}
/**
 * Declares the capabilities of this learner: the standard and lite views for
 * {@code HoeffdingTree} itself, only the standard view for subclasses.
 *
 * @return the capability set for the runtime class
 */
@Override
public ImmutableCapabilities defineImmutableCapabilities() {
    final boolean isExactlyThisClass = HoeffdingTree.class == this.getClass();
    if (!isExactlyThisClass) {
        return new ImmutableCapabilities(Capability.VIEW_STANDARD);
    }
    return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy