All downloads are free. Search and download functionality uses the official Maven repository.

moa.classifiers.trees.HoeffdingOptionTree Maven / Gradle / Ivy

Go to download

Massive On-line Analysis (MOA) is an environment for massive data mining. MOA provides a framework for data stream mining and includes tools for evaluation and a collection of machine learning algorithms. It is related to the WEKA project and is also written in Java, while scaling to more demanding problems.

The newest version!
/*
 *    HoeffdingOptionTree.java
 *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 *    @author Richard Kirkby ([email protected])
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program. If not, see <http://www.gnu.org/licenses/>.
 *    
 */
package moa.classifiers.trees;

import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import com.github.javacliparser.FileOption;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import moa.AbstractMOAObject;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.bayes.NaiveBayes;
import moa.classifiers.core.AttributeSplitSuggestion;
import moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver;
import moa.classifiers.core.attributeclassobservers.NullAttributeClassObserver;
import moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver;
import moa.classifiers.core.conditionaltests.InstanceConditionalTest;
import moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
import moa.classifiers.core.splitcriteria.SplitCriterion;
import moa.core.AutoExpandVector;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.SizeOf;
import moa.core.StringUtils;
import moa.core.Utils;
import moa.options.ClassOption;
import com.yahoo.labs.samoa.instances.Instance;

/**
 * Hoeffding Option Tree.
 *
 * <p>Hoeffding Option Trees are regular Hoeffding trees containing additional
 * option nodes that allow several tests to be applied, leading to multiple
 * Hoeffding trees as separate paths. They consist of a single structure that
 * efficiently represents multiple trees. A particular example can travel down
 * multiple paths of the tree, contributing, in different ways, to different
 * options.</p>

 *
 * <p>See for details:</p>
 * <p>B. Pfahringer, G. Holmes, and R. Kirkby. New options for Hoeffding
 * trees. In AI, pages 90–99, 2007.</p>

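 *
 * <p>Parameters:</p>
 * <ul>
 * <li>-o : Maximum number of option paths per node (default 5)</li>
 * <li>-m : Maximum memory consumed by the tree, in bytes (default 33554432)</li>
 * <li>-n : Numeric estimator to use, given as a NumericAttributeClassObserver
 * class (default GaussianNumericAttributeClassObserver). The legacy fixed
 * choices were: Gaussian approximation evaluating 10 or 100 split points;
 * Greenwald-Khanna quantile summary with 10, 100 or 1000 tuples; VFML method
 * with 10, 100 or 1000 bins; exhaustive binary tree.</li>
 * <li>-d : Nominal estimator to use, given as a DiscreteAttributeClassObserver
 * class (default NominalAttributeClassObserver)</li>
 * <li>-e : How many instances between memory consumption checks
 * (default 1000000)</li>
 * <li>-g : The number of instances a leaf should observe between split
 * attempts (default 200)</li>
 * <li>-s : Split criterion to use, e.g. InfoGainSplitCriterion (the default)</li>
 * <li>-c : The allowable error in split decisions; values closer to 0 will
 * take longer to decide (default 1e-7)</li>
 * <li>-w : The allowable error in secondary (option) split decisions; values
 * closer to 0 will take longer to decide (default 0.1)</li>
 * <li>-t : Threshold below which a split will be forced to break ties
 * (default 0.05)</li>
 * <li>-b : Only allow binary splits</li>
 * <li>-z : Memory strategy used to rank leaves when enforcing the memory
 * limit: 0 treats every leaf equally, 1 penalises option (internal) leaves,
 * any other value ranks all true leaves above option leaves (default 2)</li>
 * <li>-r : Disable poor attributes. Note: this option is cleared on reset
 * unless leaf prediction (-l) is MC, so it has no effect with NB or
 * NBAdaptive.</li>
 * <li>-p : Disable pre-pruning</li>
 * <li>-f : File to append the option table to</li>
 * <li>-l : Leaf prediction to use: MajorityClass (MC), Naive Bayes (NB) or
 * Naive Bayes Adaptive (NBAdaptive) (default NBAdaptive)</li>
 * <li>-q : The number of instances a leaf should observe before permitting
 * Naive Bayes (default 0)</li>
 * </ul>
 *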
* * @author Richard Kirkby ([email protected]) * @version $Revision: 7 $ */ public class HoeffdingOptionTree extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler { private static final long serialVersionUID = 1L; @Override public String getPurposeString() { return "Hoeffding Option Tree: single tree that represents multiple trees."; } public IntOption maxOptionPathsOption = new IntOption("maxOptionPaths", 'o', "Maximum number of option paths per node.", 5, 1, Integer.MAX_VALUE); public IntOption maxByteSizeOption = new IntOption("maxByteSize", 'm', "Maximum memory consumed by the tree.", 33554432, 0, Integer.MAX_VALUE); /* * public MultiChoiceOption numericEstimatorOption = new MultiChoiceOption( * "numericEstimator", 'n', "Numeric estimator to use.", new String[]{ * "GAUSS10", "GAUSS100", "GK10", "GK100", "GK1000", "VFML10", "VFML100", * "VFML1000", "BINTREE"}, new String[]{ "Gaussian approximation evaluating * 10 splitpoints", "Gaussian approximation evaluating 100 splitpoints", * "Greenwald-Khanna quantile summary with 10 tuples", "Greenwald-Khanna * quantile summary with 100 tuples", "Greenwald-Khanna quantile summary * with 1000 tuples", "VFML method with 10 bins", "VFML method with 100 * bins", "VFML method with 1000 bins", "Exhaustive binary tree"}, 0); */ public ClassOption numericEstimatorOption = new ClassOption("numericEstimator", 'n', "Numeric estimator to use.", NumericAttributeClassObserver.class, "GaussianNumericAttributeClassObserver"); public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator", 'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class, "NominalAttributeClassObserver"); public IntOption memoryEstimatePeriodOption = new IntOption( "memoryEstimatePeriod", 'e', "How many instances between memory consumption checks.", 1000000, 0, Integer.MAX_VALUE); public IntOption gracePeriodOption = new IntOption( "gracePeriod", 'g', "The number of instances a leaf should observe between 
split attempts.", 200, 0, Integer.MAX_VALUE); public ClassOption splitCriterionOption = new ClassOption("splitCriterion", 's', "Split criterion to use.", SplitCriterion.class, "InfoGainSplitCriterion"); public FloatOption splitConfidenceOption = new FloatOption( "splitConfidence", 'c', "The allowable error in split decision, values closer to 0 will take longer to decide.", 0.0000001, 0.0, 1.0); public FloatOption secondarySplitConfidenceOption = new FloatOption( "secondarySplitConfidence", 'w', "The allowable error in secondary split decisions, values closer to 0 will take longer to decide.", 0.1, 0.0, 1.0); public FloatOption tieThresholdOption = new FloatOption("tieThreshold", 't', "Threshold below which a split will be forced to break ties.", 0.05, 0.0, 1.0); public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b', "Only allow binary splits."); public FlagOption removePoorAttsOption = new FlagOption("removePoorAtts", 'r', "Disable poor attributes."); public FlagOption noPrePruneOption = new FlagOption("noPrePrune", 'p', "Disable pre-pruning."); public FileOption dumpFileOption = new FileOption("dumpFile", 'f', "File to append option table to.", null, "csv", true); public IntOption memoryStrategyOption = new IntOption("memStrategy", 'z', "Memory strategy to use.", 2); public static class FoundNode { public Node node; public SplitNode parent; public int parentBranch; // set to -999 for option leaves public FoundNode(Node node, SplitNode parent, int parentBranch) { this.node = node; this.parent = parent; this.parentBranch = parentBranch; } } public static class Node extends AbstractMOAObject { private static final long serialVersionUID = 1L; protected DoubleVector observedClassDistribution; public Node(double[] classObservations) { this.observedClassDistribution = new DoubleVector(classObservations); } public long calcByteSize() { return SizeOf.sizeOf(this) + SizeOf.fullSizeOf(this.observedClassDistribution); } public long 
calcByteSizeIncludingSubtree() { return calcByteSize(); } public boolean isLeaf() { return true; } public FoundNode[] filterInstanceToLeaves(Instance inst, SplitNode parent, int parentBranch, boolean updateSplitterCounts) { List nodes = new LinkedList(); filterInstanceToLeaves(inst, parent, parentBranch, nodes, updateSplitterCounts); return nodes.toArray(new FoundNode[nodes.size()]); } public void filterInstanceToLeaves(Instance inst, SplitNode splitparent, int parentBranch, List foundNodes, boolean updateSplitterCounts) { foundNodes.add(new FoundNode(this, splitparent, parentBranch)); } public double[] getObservedClassDistribution() { return this.observedClassDistribution.getArrayCopy(); } public double[] getClassVotes(Instance inst, HoeffdingOptionTree ht) { double[] dist = this.observedClassDistribution.getArrayCopy(); double distSum = Utils.sum(dist); if (distSum > 0.0) { Utils.normalize(dist, distSum); } return dist; } public boolean observedClassDistributionIsPure() { return this.observedClassDistribution.numNonZeroEntries() < 2; } public void describeSubtree(HoeffdingOptionTree ht, StringBuilder out, int indent) { StringUtils.appendIndented(out, indent, "Leaf "); out.append(ht.getClassNameString()); out.append(" = "); out.append(ht.getClassLabelString(this.observedClassDistribution.maxIndex())); out.append(" weights: "); this.observedClassDistribution.getSingleLineDescription(out, ht.treeRoot.observedClassDistribution.numValues()); StringUtils.appendNewline(out); } public int subtreeDepth() { return 0; } public double calculatePromise() { double totalSeen = this.observedClassDistribution.sumOfValues(); return totalSeen > 0.0 ? 
(totalSeen - this.observedClassDistribution.getValue(this.observedClassDistribution.maxIndex())) : 0.0; } public void getDescription(StringBuilder sb, int indent) { describeSubtree(null, sb, indent); } } public static class SplitNode extends Node { private static final long serialVersionUID = 1L; protected InstanceConditionalTest splitTest; protected SplitNode parent; protected Node nextOption; protected int optionCount; // set to -999 for optional splits protected AutoExpandVector children = new AutoExpandVector(); @Override public long calcByteSize() { return super.calcByteSize() + SizeOf.sizeOf(this.children) + SizeOf.fullSizeOf(this.splitTest); } @Override public long calcByteSizeIncludingSubtree() { long byteSize = calcByteSize(); for (Node child : this.children) { if (child != null) { byteSize += child.calcByteSizeIncludingSubtree(); } } if (this.nextOption != null) { byteSize += this.nextOption.calcByteSizeIncludingSubtree(); } return byteSize; } public SplitNode(InstanceConditionalTest splitTest, double[] classObservations) { super(classObservations); this.splitTest = splitTest; } public int numChildren() { return this.children.size(); } public void setChild(int index, Node child) { if ((this.splitTest.maxBranches() >= 0) && (index >= this.splitTest.maxBranches())) { throw new IndexOutOfBoundsException(); } this.children.set(index, child); } public Node getChild(int index) { return this.children.get(index); } public int instanceChildIndex(Instance inst) { return this.splitTest.branchForInstance(inst); } @Override public boolean isLeaf() { return false; } @Override public void filterInstanceToLeaves(Instance inst, SplitNode myparent, int parentBranch, List foundNodes, boolean updateSplitterCounts) { if (updateSplitterCounts) { this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); } int childIndex = instanceChildIndex(inst); if (childIndex >= 0) { Node child = getChild(childIndex); if (child != null) { 
child.filterInstanceToLeaves(inst, this, childIndex, foundNodes, updateSplitterCounts); } else { foundNodes.add(new FoundNode(null, this, childIndex)); } } if (this.nextOption != null) { this.nextOption.filterInstanceToLeaves(inst, this, -999, foundNodes, updateSplitterCounts); } } @Override public void describeSubtree(HoeffdingOptionTree ht, StringBuilder out, int indent) { for (int branch = 0; branch < numChildren(); branch++) { Node child = getChild(branch); if (child != null) { StringUtils.appendIndented(out, indent, "if "); out.append(this.splitTest.describeConditionForBranch(branch, ht.getModelContext())); out.append(": "); out.append("** option count = " + this.optionCount); StringUtils.appendNewline(out); child.describeSubtree(ht, out, indent + 2); } } } @Override public int subtreeDepth() { int maxChildDepth = 0; for (Node child : this.children) { if (child != null) { int depth = child.subtreeDepth(); if (depth > maxChildDepth) { maxChildDepth = depth; } } } return maxChildDepth + 1; } public double computeMeritOfExistingSplit( SplitCriterion splitCriterion, double[] preDist) { double[][] postDists = new double[this.children.size()][]; for (int i = 0; i < this.children.size(); i++) { postDists[i] = this.children.get(i).getObservedClassDistribution(); } return splitCriterion.getMeritOfSplit(preDist, postDists); } public void updateOptionCount(SplitNode source, HoeffdingOptionTree hot) { if (this.optionCount == -999) { this.parent.updateOptionCount(source, hot); } else { int maxChildCount = -999; SplitNode curr = this; while (curr != null) { for (Node child : curr.children) { if (child instanceof SplitNode) { SplitNode splitChild = (SplitNode) child; if (splitChild.optionCount > maxChildCount) { maxChildCount = splitChild.optionCount; } } } if ((curr.nextOption != null) && (curr.nextOption instanceof SplitNode)) { curr = (SplitNode) curr.nextOption; } else { curr = null; } } if (maxChildCount > this.optionCount) { // currently only works // one // way - 
adding, not // removing int delta = maxChildCount - this.optionCount; this.optionCount = maxChildCount; if (this.optionCount >= hot.maxOptionPathsOption.getValue()) { killOptionLeaf(hot); } curr = this; while (curr != null) { for (Node child : curr.children) { if (child instanceof SplitNode) { SplitNode splitChild = (SplitNode) child; if (splitChild != source) { splitChild.updateOptionCountBelow(delta, hot); } } } if ((curr.nextOption != null) && (curr.nextOption instanceof SplitNode)) { curr = (SplitNode) curr.nextOption; } else { curr = null; } } if (this.parent != null) { this.parent.updateOptionCount(this, hot); } } } } public void updateOptionCountBelow(int delta, HoeffdingOptionTree hot) { if (this.optionCount != -999) { this.optionCount += delta; if (this.optionCount >= hot.maxOptionPathsOption.getValue()) { killOptionLeaf(hot); } } for (Node child : this.children) { if (child instanceof SplitNode) { SplitNode splitChild = (SplitNode) child; splitChild.updateOptionCountBelow(delta, hot); } } if (this.nextOption instanceof SplitNode) { ((SplitNode) this.nextOption).updateOptionCountBelow(delta, hot); } } private void killOptionLeaf(HoeffdingOptionTree hot) { if (this.nextOption instanceof SplitNode) { ((SplitNode) this.nextOption).killOptionLeaf(hot); } else if (this.nextOption instanceof ActiveLearningNode) { this.nextOption = null; hot.activeLeafNodeCount--; } else if (this.nextOption instanceof InactiveLearningNode) { this.nextOption = null; hot.inactiveLeafNodeCount--; } } public int getHeadOptionCount() { SplitNode sn = this; while (sn.optionCount == -999) { sn = sn.parent; } return sn.optionCount; } } public static abstract class LearningNode extends Node { private static final long serialVersionUID = 1L; public LearningNode(double[] initialClassObservations) { super(initialClassObservations); } public abstract void learnFromInstance(Instance inst, HoeffdingOptionTree ht); } public static class InactiveLearningNode extends LearningNode { private static 
final long serialVersionUID = 1L; public InactiveLearningNode(double[] initialClassObservations) { super(initialClassObservations); } @Override public void learnFromInstance(Instance inst, HoeffdingOptionTree ht) { this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); } } public static class ActiveLearningNode extends LearningNode { private static final long serialVersionUID = 1L; protected double weightSeenAtLastSplitEvaluation; protected AutoExpandVector attributeObservers = new AutoExpandVector(); public ActiveLearningNode(double[] initialClassObservations) { super(initialClassObservations); this.weightSeenAtLastSplitEvaluation = getWeightSeen(); } @Override public long calcByteSize() { return super.calcByteSize() + SizeOf.fullSizeOf(this.attributeObservers); } @Override public void learnFromInstance(Instance inst, HoeffdingOptionTree ht) { this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); for (int i = 0; i < inst.numAttributes() - 1; i++) { int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); AttributeClassObserver obs = this.attributeObservers.get(i); if (obs == null) { obs = inst.attribute(instAttIndex).isNominal() ? 
ht.newNominalClassObserver() : ht.newNumericClassObserver(); this.attributeObservers.set(i, obs); } obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight()); } } public double getWeightSeen() { return this.observedClassDistribution.sumOfValues(); } public double getWeightSeenAtLastSplitEvaluation() { return this.weightSeenAtLastSplitEvaluation; } public void setWeightSeenAtLastSplitEvaluation(double weight) { this.weightSeenAtLastSplitEvaluation = weight; } public AttributeSplitSuggestion[] getBestSplitSuggestions( SplitCriterion criterion, HoeffdingOptionTree ht) { List bestSuggestions = new LinkedList(); double[] preSplitDist = this.observedClassDistribution.getArrayCopy(); if (!ht.noPrePruneOption.isSet()) { // add null split as an option bestSuggestions.add(new AttributeSplitSuggestion(null, new double[0][], criterion.getMeritOfSplit( preSplitDist, new double[][]{preSplitDist}))); } for (int i = 0; i < this.attributeObservers.size(); i++) { AttributeClassObserver obs = this.attributeObservers.get(i); if (obs != null) { AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion, preSplitDist, i, ht.binarySplitsOption.isSet()); if (bestSuggestion != null) { bestSuggestions.add(bestSuggestion); } } } return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]); } public void disableAttribute(int attIndex) { this.attributeObservers.set(attIndex, new NullAttributeClassObserver()); } } protected Node treeRoot; protected int decisionNodeCount; protected int activeLeafNodeCount; protected int inactiveLeafNodeCount; protected double inactiveLeafByteSizeEstimate; protected double activeLeafByteSizeEstimate; protected double byteSizeEstimateOverheadFraction; protected int maxPredictionPaths; public long calcByteSize() { long size = SizeOf.sizeOf(this); if (this.treeRoot != null) { size += this.treeRoot.calcByteSizeIncludingSubtree(); } return size; } @Override public long 
measureByteSize() { return calcByteSize(); } @Override public void resetLearningImpl() { this.treeRoot = null; this.decisionNodeCount = 0; this.activeLeafNodeCount = 0; this.inactiveLeafNodeCount = 0; this.inactiveLeafByteSizeEstimate = 0.0; this.activeLeafByteSizeEstimate = 0.0; this.byteSizeEstimateOverheadFraction = 1.0; this.maxPredictionPaths = 0; if (this.leafpredictionOption.getChosenIndex() > 0) { this.removePoorAttsOption = null; } } @Override public void trainOnInstanceImpl(Instance inst) { if (this.treeRoot == null) { this.treeRoot = newLearningNode(); this.activeLeafNodeCount = 1; } FoundNode[] foundNodes = this.treeRoot.filterInstanceToLeaves(inst, null, -1, true); for (FoundNode foundNode : foundNodes) { // option leaves will have a parentBranch of -999 // option splits will have an option count of -999 Node leafNode = foundNode.node; if (leafNode == null) { leafNode = newLearningNode(); foundNode.parent.setChild(foundNode.parentBranch, leafNode); this.activeLeafNodeCount++; } if (leafNode instanceof LearningNode) { LearningNode learningNode = (LearningNode) leafNode; learningNode.learnFromInstance(inst, this); if (learningNode instanceof ActiveLearningNode) { ActiveLearningNode activeLearningNode = (ActiveLearningNode) learningNode; double weightSeen = activeLearningNode.getWeightSeen(); if (weightSeen - activeLearningNode.getWeightSeenAtLastSplitEvaluation() >= this.gracePeriodOption.getValue()) { attemptToSplit(activeLearningNode, foundNode.parent, foundNode.parentBranch); activeLearningNode.setWeightSeenAtLastSplitEvaluation(weightSeen); } } } } if (this.trainingWeightSeenByModel % this.memoryEstimatePeriodOption.getValue() == 0) { estimateModelByteSizes(); } } @Override public double[] getVotesForInstance(Instance inst) { if (this.treeRoot != null) { FoundNode[] foundNodes = this.treeRoot.filterInstanceToLeaves(inst, null, -1, false); DoubleVector result = new DoubleVector(); int predictionPaths = 0; for (FoundNode foundNode : foundNodes) { if 
(foundNode.parentBranch != -999) { Node leafNode = foundNode.node; if (leafNode == null) { leafNode = foundNode.parent; } double[] dist = leafNode.getClassVotes(inst, this); //Albert: changed for weights //double distSum = Utils.sum(dist); //if (distSum > 0.0) { // Utils.normalize(dist, distSum); //} result.addValues(dist); predictionPaths++; } } if (predictionPaths > this.maxPredictionPaths) { this.maxPredictionPaths++; } return result.getArrayRef(); } return new double[0]; } @Override protected Measurement[] getModelMeasurementsImpl() { return new Measurement[]{ new Measurement("tree size (nodes)", this.decisionNodeCount + this.activeLeafNodeCount + this.inactiveLeafNodeCount), new Measurement("tree size (leaves)", this.activeLeafNodeCount + this.inactiveLeafNodeCount), new Measurement("active learning leaves", this.activeLeafNodeCount), new Measurement("tree depth", measureTreeDepth()), new Measurement("active leaf byte size estimate", this.activeLeafByteSizeEstimate), new Measurement("inactive leaf byte size estimate", this.inactiveLeafByteSizeEstimate), new Measurement("byte size estimate overhead", this.byteSizeEstimateOverheadFraction), new Measurement("maximum prediction paths used", this.maxPredictionPaths)}; } public int measureTreeDepth() { if (this.treeRoot != null) { return this.treeRoot.subtreeDepth(); } return 0; } @Override public void getModelDescription(StringBuilder out, int indent) { this.treeRoot.describeSubtree(this, out, indent); } @Override public boolean isRandomizable() { return false; } public static double computeHoeffdingBound(double range, double confidence, double n) { return Math.sqrt(((range * range) * Math.log(1.0 / confidence)) / (2.0 * n)); } protected AttributeClassObserver newNominalClassObserver() { AttributeClassObserver nominalClassObserver = (AttributeClassObserver) getPreparedClassOption(this.nominalEstimatorOption); return (AttributeClassObserver) nominalClassObserver.copy(); } protected AttributeClassObserver 
newNumericClassObserver() { AttributeClassObserver numericClassObserver = (AttributeClassObserver) getPreparedClassOption(this.numericEstimatorOption); return (AttributeClassObserver) numericClassObserver.copy(); } protected void attemptToSplit(ActiveLearningNode node, SplitNode parent, int parentIndex) { if (!node.observedClassDistributionIsPure()) { SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption); AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this); Arrays.sort(bestSplitSuggestions); boolean shouldSplit = false; if (parentIndex != -999) { if (bestSplitSuggestions.length < 2) { shouldSplit = bestSplitSuggestions.length > 0; } else { double hoeffdingBound = computeHoeffdingBound( splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()), this.splitConfidenceOption.getValue(), node.getWeightSeen()); AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2]; if ((bestSuggestion.merit - secondBestSuggestion.merit > hoeffdingBound) || (hoeffdingBound < this.tieThresholdOption.getValue())) { shouldSplit = true; } if ((this.removePoorAttsOption != null) && this.removePoorAttsOption.isSet()) { Set poorAtts = new HashSet(); // scan 1 - add any poor to set for (int i = 0; i < bestSplitSuggestions.length; i++) { if (bestSplitSuggestions[i].splitTest != null) { int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn(); if (splitAtts.length == 1) { if (bestSuggestion.merit - bestSplitSuggestions[i].merit > hoeffdingBound) { poorAtts.add(new Integer(splitAtts[0])); } } } } // scan 2 - remove good ones from set for (int i = 0; i < bestSplitSuggestions.length; i++) { if (bestSplitSuggestions[i].splitTest != null) { int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn(); if (splitAtts.length == 1) 
{ if (bestSuggestion.merit - bestSplitSuggestions[i].merit < hoeffdingBound) { poorAtts.remove(new Integer( splitAtts[0])); } } } } for (int poorAtt : poorAtts) { node.disableAttribute(poorAtt); } } } } else if (bestSplitSuggestions.length > 0) { double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()), this.secondarySplitConfidenceOption.getValue(), node.getWeightSeen()); AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; // in option case, scan back through existing options to // find best SplitNode current = parent; double bestPreviousMerit = Double.NEGATIVE_INFINITY; double[] preDist = node.getObservedClassDistribution(); while (true) { double merit = current.computeMeritOfExistingSplit( splitCriterion, preDist); if (merit > bestPreviousMerit) { bestPreviousMerit = merit; } if (current.optionCount != -999) { break; } current = current.parent; } if (bestSuggestion.merit - bestPreviousMerit > hoeffdingBound) { shouldSplit = true; } } if (shouldSplit) { AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1]; if (splitDecision.splitTest == null) { // preprune - null wins if (parentIndex != -999) { deactivateLearningNode(node, parent, parentIndex); } } else { SplitNode newSplit = new SplitNode(splitDecision.splitTest, node.getObservedClassDistribution()); newSplit.parent = parent; // add option procedure SplitNode optionHead = parent; if (parent != null) { while (optionHead.optionCount == -999) { optionHead = optionHead.parent; } } if ((parentIndex == -999) && (parent != null)) { // adding a new option newSplit.optionCount = -999; optionHead.updateOptionCountBelow(1, this); if (optionHead.parent != null) { optionHead.parent.updateOptionCount(optionHead, this); } addToOptionTable(splitDecision, optionHead.parent); } else { // adding a regular leaf if (optionHead == null) { newSplit.optionCount = 1; } else { 
newSplit.optionCount = optionHead.optionCount; } } int numOptions = 1; if (optionHead != null) { numOptions = optionHead.optionCount; } if (numOptions < this.maxOptionPathsOption.getValue()) { newSplit.nextOption = node; // preserve leaf // disable attribute just used int[] splitAtts = splitDecision.splitTest.getAttsTestDependsOn(); for (int i : splitAtts) { node.disableAttribute(i); } } else { this.activeLeafNodeCount--; } for (int i = 0; i < splitDecision.numSplits(); i++) { Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i)); newSplit.setChild(i, newChild); } this.decisionNodeCount++; this.activeLeafNodeCount += splitDecision.numSplits(); if (parent == null) { this.treeRoot = newSplit; } else { if (parentIndex != -999) { parent.setChild(parentIndex, newSplit); } else { parent.nextOption = newSplit; } } } // manage memory enforceTrackerLimit(); } } } private void addToOptionTable(AttributeSplitSuggestion bestSuggestion, SplitNode parent) { File dumpFile = this.dumpFileOption.getFile(); PrintStream immediateResultStream = null; if (dumpFile != null) { try { if (dumpFile.exists()) { immediateResultStream = new PrintStream( new FileOutputStream(dumpFile, true), true); } else { immediateResultStream = new PrintStream( new FileOutputStream(dumpFile), true); } } catch (Exception ex) { throw new RuntimeException("Unable to open dump file: " + dumpFile, ex); } int splitAtt = bestSuggestion.splitTest.getAttsTestDependsOn()[0]; double splitVal = -1.0; if (bestSuggestion.splitTest instanceof NumericAttributeBinaryTest) { NumericAttributeBinaryTest test = (NumericAttributeBinaryTest) bestSuggestion.splitTest; splitVal = test.getSplitValue(); } int treeDepth = 0; while (parent != null) { parent = parent.parent; treeDepth++; } immediateResultStream.println(this.trainingWeightSeenByModel + "," + treeDepth + "," + splitAtt + "," + splitVal); immediateResultStream.flush(); immediateResultStream.close(); } } public void enforceTrackerLimit() { if 
((this.inactiveLeafNodeCount > 0) || ((this.activeLeafNodeCount * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount * this.inactiveLeafByteSizeEstimate) * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue())) { FoundNode[] learningNodes = findLearningNodes(); Arrays.sort(learningNodes, new Comparator() { public int compare(FoundNode fn1, FoundNode fn2) { if (HoeffdingOptionTree.this.memoryStrategyOption.getValue() == 0) { // strategy 1 - every leaf treated equal return Double.compare(fn1.node.calculatePromise(), fn2.node.calculatePromise()); } else if (HoeffdingOptionTree.this.memoryStrategyOption.getValue() == 1) { // strategy 2 - internal leaves penalised double p1 = fn1.node.calculatePromise(); if (fn1.parentBranch == -999) { p1 /= fn1.parent.getHeadOptionCount(); } double p2 = fn2.node.calculatePromise(); if (fn2.parentBranch == -999) { p1 /= fn2.parent.getHeadOptionCount(); } return Double.compare(p1, p2); } else { // strategy 3 - all true leaves beat internal leaves if (fn1.parentBranch == -999) { if (fn2.parentBranch == -999) { return Double.compare(fn1.node.calculatePromise(), fn2.node.calculatePromise()); } return -1; // fn1 < fn2 } if (fn2.parentBranch == -999) { return 1; // fn1 > fn2 } return Double.compare(fn1.node.calculatePromise(), fn2.node.calculatePromise()); } } }); int maxActive = 0; while (maxActive < learningNodes.length) { maxActive++; if ((maxActive * this.activeLeafByteSizeEstimate + (learningNodes.length - maxActive) * this.inactiveLeafByteSizeEstimate) * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue()) { maxActive--; break; } } int cutoff = learningNodes.length - maxActive; for (int i = 0; i < cutoff; i++) { if (learningNodes[i].node instanceof ActiveLearningNode) { deactivateLearningNode( (ActiveLearningNode) learningNodes[i].node, learningNodes[i].parent, learningNodes[i].parentBranch); } } for (int i = cutoff; i < learningNodes.length; i++) { if (learningNodes[i].node 
instanceof InactiveLearningNode) { activateLearningNode( (InactiveLearningNode) learningNodes[i].node, learningNodes[i].parent, learningNodes[i].parentBranch); } } } }

    /**
     * Measures the current in-memory size of the learning leaves and refreshes
     * the per-leaf byte-size estimates.
     *
     * <p>Every learning leaf is sized with {@code SizeOf.fullSizeOf}. The totals
     * are averaged over the active/inactive leaf counters. An estimate is only
     * overwritten when at least one leaf of that kind contributed a non-zero size.
     * The ratio between the measured model size ({@code measureByteSize()}) and the
     * size predicted from the estimates is stored in
     * {@code byteSizeEstimateOverheadFraction}. If the measured size exceeds
     * {@code maxByteSizeOption}, {@code enforceTrackerLimit()} is invoked.
     */
    public void estimateModelByteSizes() {
        FoundNode[] learningNodes = findLearningNodes();
        long totalActiveSize = 0;
        long totalInactiveSize = 0;
        for (FoundNode foundNode : learningNodes) {
            if (foundNode.node instanceof ActiveLearningNode) {
                totalActiveSize += SizeOf.fullSizeOf(foundNode.node);
            } else {
                totalInactiveSize += SizeOf.fullSizeOf(foundNode.node);
            }
        }
        if (totalActiveSize > 0) {
            this.activeLeafByteSizeEstimate = (double) totalActiveSize
                    / this.activeLeafNodeCount;
        }
        if (totalInactiveSize > 0) {
            this.inactiveLeafByteSizeEstimate = (double) totalInactiveSize
                    / this.inactiveLeafNodeCount;
        }
        long actualModelSize = this.measureByteSize();
        double estimatedModelSize = (this.activeLeafNodeCount
                * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount
                * this.inactiveLeafByteSizeEstimate);
        // Correction factor applied on top of the per-leaf estimates.
        this.byteSizeEstimateOverheadFraction = actualModelSize
                / estimatedModelSize;
        if (actualModelSize > this.maxByteSizeOption.getValue()) {
            enforceTrackerLimit();
        }
    }

    /**
     * Replaces every {@code ActiveLearningNode} in the tree, including those
     * reachable through option chains, with an {@code InactiveLearningNode}.
     */
    public void deactivateAllLeaves() {
        FoundNode[] learningNodes = findLearningNodes();
        for (int i = 0; i < learningNodes.length; i++) {
            if (learningNodes[i].node instanceof ActiveLearningNode) {
                deactivateLearningNode(
                        (ActiveLearningNode) learningNodes[i].node,
                        learningNodes[i].parent, learningNodes[i].parentBranch);
            }
        }
    }

    /**
     * Swaps an active leaf for an inactive leaf that keeps only its observed
     * class distribution, and updates the active/inactive leaf counters.
     *
     * @param toDeactivate the active leaf to replace
     * @param parent       the split node holding the leaf, or {@code null} if the
     *                     leaf is the tree root
     * @param parentBranch the child index in {@code parent}, or {@code -999} when
     *                     the leaf occupies {@code parent.nextOption}
     */
    protected void deactivateLearningNode(ActiveLearningNode toDeactivate,
            SplitNode parent, int parentBranch) {
        Node newLeaf = new InactiveLearningNode(toDeactivate.getObservedClassDistribution());
        if (parent == null) {
            this.treeRoot = newLeaf;
        } else {
            // -999 is the sentinel findLearningNodes uses for the option slot.
            if (parentBranch != -999) {
                parent.setChild(parentBranch, newLeaf);
            } else {
                parent.nextOption = newLeaf;
            }
        }
        this.activeLeafNodeCount--;
        this.inactiveLeafNodeCount++;
    }

    /**
     * Swaps an inactive leaf for a new learning leaf created by
     * {@link #newLearningNode(double[])} from the inactive leaf's observed class
     * distribution, and updates the active/inactive leaf counters.
     *
     * @param toActivate   the inactive leaf to replace
     * @param parent       the split node holding the leaf, or {@code null} if the
     *                     leaf is the tree root
     * @param parentBranch the child index in {@code parent}, or {@code -999} when
     *                     the leaf occupies {@code parent.nextOption}
     */
    protected void activateLearningNode(InactiveLearningNode toActivate,
            SplitNode parent, int parentBranch) {
        Node newLeaf = newLearningNode(toActivate.getObservedClassDistribution());
        if (parent == null) {
            this.treeRoot = newLeaf;
        } else {
            // -999 is the sentinel findLearningNodes uses for the option slot.
            if (parentBranch != -999) {
                parent.setChild(parentBranch, newLeaf);
            } else {
                parent.nextOption = newLeaf;
            }
        }
        this.activeLeafNodeCount++;
        this.inactiveLeafNodeCount--;
    }

    /**
     * Collects every learning leaf in the tree, together with its parent and
     * branch position.
     *
     * @return the learning leaves in depth-first order (children before the
     *         option chain of each split)
     */
    protected FoundNode[] findLearningNodes() {
        List<FoundNode> foundList = new LinkedList<FoundNode>();
        findLearningNodes(this.treeRoot, null, -1, foundList);
        return foundList.toArray(new FoundNode[foundList.size()]);
    }

    /**
     * Recursively appends every learning leaf below {@code node} to {@code found}.
     *
     * <p>For split nodes, all regular children are visited first (recorded with
     * their child index), then {@code nextOption} is visited (recorded with the
     * sentinel branch {@code -999}).
     *
     * @param node         subtree root to search; {@code null} is ignored
     * @param parent       split node that holds {@code node}, or {@code null} for
     *                     the tree root
     * @param parentBranch position of {@code node} within {@code parent}
     * @param found        output list the results are appended to
     */
    protected void findLearningNodes(Node node, SplitNode parent,
            int parentBranch, List<FoundNode> found) {
        if (node != null) {
            if (node instanceof LearningNode) {
                found.add(new FoundNode(node, parent, parentBranch));
            }
            if (node instanceof SplitNode) {
                SplitNode splitNode = (SplitNode) node;
                for (int i = 0; i < splitNode.numChildren(); i++) {
                    findLearningNodes(splitNode.getChild(i), splitNode, i,
                            found);
                }
                findLearningNodes(splitNode.nextOption, splitNode, -999, found);
            }
        }
    }

    /** Leaf prediction strategy: index 0 = MC, 1 = NB, 2 = NBAdaptive (default). */
    public MultiChoiceOption leafpredictionOption = new MultiChoiceOption(
            "leafprediction", 'l', "Leaf prediction to use.", new String[]{
                "MC", "NB", "NBAdaptive"}, new String[]{
                "Majority class",
                "Naive Bayes",
                "Naive Bayes Adaptive"}, 2);

    /** Minimum observed weight before a Naive Bayes leaf uses NB predictions. */
    public IntOption nbThresholdOption = new IntOption(
            "nbThreshold",
            'q',
            "The number of instances a leaf should observe before permitting Naive Bayes.",
            0, 0, Integer.MAX_VALUE);

    /**
     * Learning leaf that predicts with Naive Bayes once it has seen at least
     * {@code nbThresholdOption} weight, and otherwise falls back to the majority-class
     * votes of {@code ActiveLearningNode}.
     */
    public static class LearningNodeNB extends ActiveLearningNode {

        private static final long serialVersionUID = 1L;

        public LearningNodeNB(double[] initialClassObservations) {
            super(initialClassObservations);
        }

        @Override
        public double[] getClassVotes(Instance inst, HoeffdingOptionTree hot) {
            if (getWeightSeen() >= hot.nbThresholdOption.getValue()) {
                return NaiveBayes.doNaiveBayesPrediction(inst,
                        this.observedClassDistribution,
                        this.attributeObservers);
            }
            return super.getClassVotes(inst, hot);
        }

        @Override
        public void disableAttribute(int attIndex) {
            // should not disable poor atts - they are used in NB calc
        }
    }

    /**
     * Naive Bayes leaf that tracks how much weight the majority-class and the
     * Naive Bayes predictors each classified correctly, and votes with whichever
     * has been more accurate so far (NB on ties).
     */
    public static class LearningNodeNBAdaptive extends LearningNodeNB {

        private static final long serialVersionUID = 1L;

        /** Total weight of training instances the majority class predicted correctly. */
        protected double mcCorrectWeight = 0.0;

        /** Total weight of training instances Naive Bayes predicted correctly. */
        protected double nbCorrectWeight = 0.0;

        public LearningNodeNBAdaptive(double[] initialClassObservations) {
            super(initialClassObservations);
        }

        @Override
        public void learnFromInstance(Instance inst, HoeffdingOptionTree hot) {
            int trueClass = (int) inst.classValue();
            // Score both predictors before the leaf's statistics are updated.
            if (this.observedClassDistribution.maxIndex() == trueClass) {
                this.mcCorrectWeight += inst.weight();
            }
            if (Utils.maxIndex(NaiveBayes.doNaiveBayesPrediction(inst,
                    this.observedClassDistribution, this.attributeObservers)) == trueClass) {
                this.nbCorrectWeight += inst.weight();
            }
            super.learnFromInstance(inst, hot);
        }

        @Override
        public double[] getClassVotes(Instance inst, HoeffdingOptionTree ht) {
            if (this.mcCorrectWeight > this.nbCorrectWeight) {
                return this.observedClassDistribution.getArrayCopy();
            }
            return NaiveBayes.doNaiveBayesPrediction(inst,
                    this.observedClassDistribution, this.attributeObservers);
        }
    }

    /** Creates a learning leaf with an empty class distribution. */
    protected LearningNode newLearningNode() {
        return newLearningNode(new double[0]);
    }

    /**
     * Creates the learning leaf type selected by {@code leafpredictionOption}.
     *
     * @param initialClassObservations initial class distribution for the leaf
     * @return an {@code ActiveLearningNode} (MC), {@code LearningNodeNB} (NB) or
     *         {@code LearningNodeNBAdaptive} (any other index)
     */
    protected LearningNode newLearningNode(double[] initialClassObservations) {
        LearningNode ret;
        int predictionOption = this.leafpredictionOption.getChosenIndex();
        if (predictionOption == 0) { // MC
            ret = new ActiveLearningNode(initialClassObservations);
        } else if (predictionOption == 1) { // NB
            ret = new LearningNodeNB(initialClassObservations);
        } else { // NBAdaptive
            ret = new LearningNodeNBAdaptive(initialClassObservations);
        }
        return ret;
    }

    /**
     * Declares the views this learner supports. Only the exact
     * {@code HoeffdingOptionTree} class advertises {@code VIEW_LITE}; subclasses
     * get {@code VIEW_STANDARD} only.
     */
    @Override
    public ImmutableCapabilities defineImmutableCapabilities() {
        if (this.getClass() == HoeffdingOptionTree.class) {
            return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
        } else {
            return new ImmutableCapabilities(Capability.VIEW_STANDARD);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy