moa.classifiers.trees.iadem.Iadem3 Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of moa Show documentation
Show all versions of moa Show documentation
Massive On-line Analysis (MOA) is an environment for massive data mining. MOA
provides a framework for data stream mining and includes tools for evaluation
and a collection of machine learning algorithms. It is related to the WEKA
project and is also written in Java, while scaling to more demanding problems.
/*
* IADEM3Tree.java
*
* @author Isvani Frías-Blanco
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*
*/
package moa.classifiers.trees.iadem;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.io.Serializable;
import java.util.Arrays;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.AbstractChangeDetector;
import moa.core.AutoExpandVector;
import java.util.ArrayList;
import java.util.TreeSet;
import java.util.logging.Level;
import java.util.logging.Logger;
import moa.classifiers.core.conditionaltests.InstanceConditionalTest;
import moa.classifiers.core.conditionaltests.NominalAttributeMultiwayTest;
import moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
import moa.core.DoubleVector;
import moa.core.Measurement;
import weka.core.Utils;
/**
*
* @author Isvani Frías Blanco (ifriasb at hotmail dot com)
*/
public class Iadem3 extends Iadem2 implements MultiClassClassifier {
private static final long serialVersionUID = 1L;

/** Option returned by {@link #getMaxNestingLevels()}: nesting limit for alternative subtrees. */
public IntOption maxNestingLevelOption = new IntOption("maxNestingLevel", 'p',
        "Maximum level of nesting for alternative subtrees (-1 => unbounded).",
        1, -1, Integer.MAX_VALUE);

/** Option consulted by {@link #canCreateSubtree()}: limit of alternative subtrees per split node. */
public IntOption maxSubtreesPerNodeOption = new IntOption("maxSubtreesPerNode", 'w',
        "Maximum number of alternative subtrees per split node (-1 => unbounded).",
        1, -1, Integer.MAX_VALUE);

/** Checked by the node classes before they feed their drift detectors; constant {@code true} here. */
protected final boolean restartAtDrift = true;

/** Number of times an alternative subtree replaced a main subtree (reported as "interchanged trees"). */
protected int interchangedTrees = 0;

/** Incremented by {@link #newDeletedTree()}. */
protected int deletedTrees = 0;

/** Incremented by {@code setNewTree()}, decremented by {@link #newTreeChange()} and {@link #newDeletedTree()}. */
protected int numTrees = 0;

/** Index of the class last predicted by {@link #getClassVotes(Instance)}. */
protected int lastPrediction = -1;

/** Index of the class last predicted by {@code getClassVotesFromLeaf(Instance)}. */
protected int lastPredictionInLeaf = -1;

/** Value returned by {@link #getTreeLevel()}; decremented when a subtree is promoted. */
protected int treeLevel = 0;

/**
 * Alternative subtrees registered through {@link #addSubtree(Iadem3Subtree)}.
 * The element type is declared explicitly so that typed iteration and
 * {@code get} calls compile without casts.
 */
protected AutoExpandVector<Iadem3Subtree> subtreeList = new AutoExpandVector<>();

/** Cause of the split currently being performed; one of the two constants below, or -1 initially. */
protected int currentSplitState = -1;

/** Split-state code stored when a leaf splits through the fast split-suggestion path. */
protected final int SPLIT_BY_TIE_BREAKING = 0;

/** Split-state code stored when a leaf splits through the regular best-split path. */
protected final int SPLIT_WITH_CONFIDENCE = 1;

/** Running counter adjusted by {@link #updateNumberOfNodesSplitByTieBreaking(int)}. */
public int numSplitsByBreakingTies = 0;
/**
 * Reports the model statistics of this tree: total nodes, leaves, and the
 * number of interchanged (promoted) alternative trees.
 *
 * @return the three measurements, in that order
 */
@Override
protected Measurement[] getModelMeasurementsImpl() {
    final Measurement nodes = new Measurement("tree size (nodes)", this.getNumberOfNodes());
    final Measurement leaves = new Measurement("tree size (leaves)", this.getNumberOfLeaves());
    final Measurement interchanged = new Measurement("interchanged trees", this.getChangedTrees());
    return new Measurement[]{nodes, leaves, interchanged};
}
/**
 * Builds a copy of the change detector configured through
 * {@code driftDetectionMethodOption}.
 *
 * @return a copy of the prepared drift detector
 */
public AbstractChangeDetector getEstimatorCopy() {
    final AbstractChangeDetector configured =
            (AbstractChangeDetector) getPreparedClassOption(this.driftDetectionMethodOption);
    return (AbstractChangeDetector) configured.copy();
}
/**
 * Creates the root of the tree as a leaf with an all-zero class distribution.
 *
 * @param instance instance used to size the class distribution and to
 *                 initialise the new leaf
 */
@Override
public void createRoot(Instance instance) {
    // A freshly allocated double[] is already zero-filled.
    final double[] emptyClassCounts = new double[instance.numClasses()];
    this.treeRoot = newLeafNode(null, 0, 0, emptyClassCounts, instance);
}
/**
 * Registers an alternative subtree in this tree's subtree list.
 *
 * @param subtree the subtree to append
 */
public void addSubtree(Iadem3Subtree subtree) {
    subtreeList.add(subtree);
}
/**
 * Removes an alternative subtree from this tree's subtree list.
 *
 * @param subtree the subtree to remove
 */
public void removeSubtree(Iadem3Subtree subtree) {
    subtreeList.remove(subtree);
}
/**
 * Tells whether another alternative subtree may be created.
 *
 * <p>A non-positive {@code maxSubtreesPerNode} value imposes no limit;
 * otherwise the current subtree count must be below the configured maximum.
 *
 * @return {@code true} if a new subtree is allowed
 */
public boolean canCreateSubtree() {
    final int limit = this.maxSubtreesPerNodeOption.getValue();
    if (limit <= 0) {
        return true;
    }
    return getNumberOfSubtrees() < limit;
}
/**
 * Creates a leaf whose concrete type depends on {@code leafPredictionOption}:
 * index 0 gives {@link AdaptiveLeafNode}, 1 gives {@link AdaptiveLeafNodeNB},
 * 2 gives {@link AdaptiveLeafNodeNBKirkby}, and any other index gives
 * {@link AdaptiveLeafNodeWeightedVote}.
 *
 * @param parent                    parent node, or {@code null} for the root
 * @param instTreeCountSinceVirtual tree-level instance counter passed to the leaf
 * @param instNodeCountSinceVirtual node-level instance counter passed to the leaf
 * @param initialClassCount         initial class distribution of the leaf
 * @param instance                  instance forwarded to the leaf constructor
 * @return the newly created leaf
 */
@Override
public LeafNode newLeafNode(Node parent,
        long instTreeCountSinceVirtual,
        long instNodeCountSinceVirtual,
        double[] initialClassCount,
        Instance instance) {
    // Split-test restrictions derived from the splitTests option.
    final boolean multiwayOnly = this.splitTestsOption.getChosenIndex() == 2;
    final boolean binaryOnly = this.splitTestsOption.getChosenIndex() == 0;

    final int predictionMode = this.leafPredictionOption.getChosenIndex();
    switch (predictionMode) {
        case 0:
            return new AdaptiveLeafNode(this, parent,
                    instTreeCountSinceVirtual, instNodeCountSinceVirtual,
                    initialClassCount, newNumericClassObserver(),
                    this.estimator, multiwayOnly, binaryOnly, instance);
        case 1:
            return new AdaptiveLeafNodeNB(this, parent,
                    instTreeCountSinceVirtual, instNodeCountSinceVirtual,
                    initialClassCount, newNumericClassObserver(),
                    this.naiveBayesLimit, this.estimator,
                    multiwayOnly, binaryOnly, instance);
        case 2:
            return new AdaptiveLeafNodeNBKirkby(this, parent,
                    instTreeCountSinceVirtual, instNodeCountSinceVirtual,
                    initialClassCount, newNumericClassObserver(),
                    this.naiveBayesLimit, multiwayOnly, binaryOnly,
                    this.estimator, instance);
        default:
            return new AdaptiveLeafNodeWeightedVote(this, parent,
                    instTreeCountSinceVirtual, instNodeCountSinceVirtual,
                    initialClassCount, newNumericClassObserver(),
                    this.naiveBayesLimit, multiwayOnly, binaryOnly,
                    this.estimator, instance);
    }
}
/**
 * @return the current value of {@code treeLevel}
 */
public int getTreeLevel() {
    return this.treeLevel;
}
/**
 * @return the configured value of the {@code maxSubtreesPerNode} option
 */
public int getMaxAltSubtreesPerNode() {
    final int configured = maxSubtreesPerNodeOption.getValue();
    return configured;
}
/**
 * @return the configured value of the {@code maxNestingLevel} option
 */
public int getMaxNestingLevels() {
    final int configured = maxNestingLevelOption.getValue();
    return configured;
}
/**
 * Exposes the {@code restartAtDrift} flag (a constant {@code true} in this class).
 *
 * @return the value of {@code restartAtDrift}
 */
public boolean isRestaurarVectoresPrediccion() {
    return this.restartAtDrift;
}
/**
 * @return how many times {@link #newDeletedTree()} has been recorded
 */
public int numDeletedTrees() {
    return this.deletedTrees;
}
/**
 * Returns the tree count reported by the root.
 *
 * @return 0 when the root is a leaf; otherwise the root split node's
 *         {@code getNumTrees()} value
 */
public int numTrees() {
    if (this.treeRoot instanceof AdaptiveLeafNode) {
        return 0;
    }
    final AdaptiveSplitNode rootSplit = (AdaptiveSplitNode) this.treeRoot;
    return rootSplit.getNumTrees();
}
/**
 * Records that an alternative tree replaced a main one: one more
 * interchanged tree, one fewer tree being tracked.
 */
public void newTreeChange() {
    this.interchangedTrees += 1;
    this.numTrees -= 1;
}
/**
 * Records that an alternative tree was discarded: one more deleted tree,
 * one fewer tree being tracked.
 */
public void newDeletedTree() {
    this.deletedTrees += 1;
    this.numTrees -= 1;
}
/**
 * Counts subtrees starting at the root of this tree.
 *
 * @return the count computed by the recursive helper for {@code treeRoot}
 */
public int numSubtrees() {
    final Node root = this.treeRoot;
    return tmpNumSubtrees(root);
}
/**
 * Recursively counts subtrees below {@code node}.
 *
 * <p>Every split node contributes one, plus the {@code numSubtrees()} of each
 * of its alternative subtrees, plus the counts of its children. Leaves (and
 * {@code null}) contribute zero.
 *
 * @param node the node to start from
 * @return the accumulated count
 */
private int tmpNumSubtrees(Node node) {
    if (!(node instanceof AdaptiveSplitNode)) {
        return 0;
    }
    AdaptiveSplitNode splitNode = (AdaptiveSplitNode) node;
    int count = 1;
    // Typed local so the for-each yields Iadem3Subtree elements rather than Object.
    AutoExpandVector<Iadem3Subtree> alternatives = splitNode.alternativeTree;
    for (Iadem3Subtree alternative : alternatives) {
        count += alternative.numSubtrees();
    }
    for (Node child : splitNode.children) {
        count += tmpNumSubtrees(child);
    }
    return count;
}
/**
 * Tells whether {@code node}, or any split node below it, currently holds at
 * least one alternative subtree.
 *
 * <p>The alternative-subtree vector of a split node is created eagerly, so a
 * plain null check would report {@code true} for every split node and never
 * descend into the children; the vector must also be non-empty.
 *
 * @param node root of the (sub)tree to inspect; leaves and {@code null} yield {@code false}
 * @return {@code true} if an alternative subtree exists at or below {@code node}
 */
protected boolean hasTree(Node node) {
    if (!(node instanceof AdaptiveSplitNode)) {
        return false;
    }
    AdaptiveSplitNode splitNode = (AdaptiveSplitNode) node;
    if (splitNode.alternativeTree != null && !splitNode.alternativeTree.isEmpty()) {
        return true;
    }
    for (int i = 0; i < splitNode.children.size(); i++) {
        if (hasTree(splitNode.children.get(i))) {
            return true;
        }
    }
    return false;
}
/**
 * Trains the tree with one instance.
 *
 * <p>Before delegating to the superclass, both prediction caches are
 * refreshed: {@code lastPrediction} (via {@link #getClassVotes(Instance)})
 * and {@code lastPredictionInLeaf} (via {@code getClassVotesFromLeaf}).
 *
 * @param instance the training instance
 * @throws IademException propagated from the superclass
 */
@Override
public void learnFromInstance(Instance instance)
        throws IademException {
    // Called only for its side effect of updating lastPrediction.
    this.getClassVotes(instance);
    this.getClassVotesFromLeaf(instance);
    super.learnFromInstance(instance);
}
/**
 * Routes the instance from the root down to a leaf and stores the index of
 * the winning class in {@code lastPredictionInLeaf}.
 *
 * <p>When a split node cannot route the instance (negative child index), the
 * split node's own prediction leaf supplies the votes.
 *
 * @param instance the instance to classify
 */
protected void getClassVotesFromLeaf(Instance instance) {
    Node current = this.treeRoot;
    double[] leafVotes = null;
    while (leafVotes == null) {
        if (!(current instanceof AdaptiveSplitNode)) {
            leafVotes = ((AdaptiveLeafNode) current).getClassVotes(instance);
            continue;
        }
        AdaptiveSplitNode inner = (AdaptiveSplitNode) current;
        int branch = inner.instanceChildIndex(instance);
        if (branch < 0) {
            leafVotes = inner.leaf.getClassVotes(instance);
        } else {
            current = inner.getChild(branch);
        }
    }
    this.lastPredictionInLeaf = Utils.maxIndex(leafVotes);
}
/**
 * Makes this tree share the root of the given subtree (the root reference is
 * assigned, not cloned).
 *
 * @param arbol subtree whose root becomes this tree's root
 */
public void copyTree(Iadem3Subtree arbol) {
    final Node adoptedRoot = arbol.treeRoot;
    this.treeRoot = adoptedRoot;
}
/**
 * Records that one more tree is being tracked.
 */
void setNewTree() {
    this.numTrees += 1;
}
/**
 * @return the number of interchanged trees recorded so far
 */
public int getChangedTrees() {
    return this.interchangedTrees;
}
/**
 * Delegates to the superclass and remembers the winning class index in
 * {@code lastPrediction}.
 *
 * @param instance the instance to classify
 * @return the votes computed by the superclass
 */
@Override
public double[] getClassVotes(Instance instance) {
    final double[] result = super.getClassVotes(instance);
    lastPrediction = Utils.maxIndex(result);
    return result;
}
/**
 * @return the root split node's subtree count, or 0 when the root is not a
 *         split node
 */
public int getNumberOfSubtrees() {
    if (!(this.treeRoot instanceof AdaptiveSplitNode)) {
        return 0;
    }
    final AdaptiveSplitNode rootSplit = (AdaptiveSplitNode) this.treeRoot;
    return rootSplit.getNumberOfSubtrees();
}
/**
 * @return this instance, which is its own main tree
 */
protected Iadem3 getMainTree() {
    final Iadem3 self = this;
    return self;
}
/**
 * Adjusts the leaf counter.
 *
 * @param amount signed delta added to {@code numberOfLeaves}
 */
public void updateNumberOfLeaves(int amount) {
    numberOfLeaves += amount;
}
/**
 * Adjusts the node counter.
 *
 * @param amount signed delta added to {@code numberOfNodes}
 */
public void updateNumberOfNodes(int amount) {
    numberOfNodes += amount;
}
/**
 * Adjusts the counter of splits decided by tie breaking.
 *
 * @param amount signed delta added to {@code numSplitsByBreakingTies}
 */
public void updateNumberOfNodesSplitByTieBreaking(int amount) {
    numSplitsByBreakingTies += amount;
}
/**
 * Leaf that monitors its own prediction error with a change detector and
 * restarts its statistics when the detector signals a change.
 */
public class AdaptiveLeafNode extends LeafNode implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Private copy of the drift detector; {@code null} disables drift monitoring in this leaf. */
    protected AbstractChangeDetector estimator;

    /**
     * @param arbol                     owning tree
     * @param parent                    parent node, or {@code null} for the root
     * @param instTreeCountSinceVirtual tree-level instance counter forwarded to the superclass
     * @param instNodeCountSinceVirtual node-level instance counter forwarded to the superclass
     * @param initialClassCount         initial class distribution
     * @param numericAttClassObserver   observer used for numeric attributes
     * @param estimator                 detector prototype; copied when non-null
     * @param onlyMultiwayTest          forwarded to the superclass and nominal virtual nodes
     * @param onlyBinaryTest            forwarded to the superclass and nominal virtual nodes
     * @param instance                  instance describing the attribute schema
     */
    public AdaptiveLeafNode(Iadem3 arbol,
            Node parent,
            long instTreeCountSinceVirtual,
            long instNodeCountSinceVirtual,
            double[] initialClassCount,
            IademNumericAttributeObserver numericAttClassObserver,
            AbstractChangeDetector estimator,
            boolean onlyMultiwayTest,
            boolean onlyBinaryTest,
            Instance instance) {
        super(arbol,
                parent, instTreeCountSinceVirtual, instNodeCountSinceVirtual, initialClassCount,
                numericAttClassObserver, onlyMultiwayTest, onlyBinaryTest, instance);
        if (estimator != null) {
            this.estimator = (AbstractChangeDetector) estimator.copy();
        } else {
            this.estimator = null;
        }
    }

    /**
     * Populates {@code virtualChildren} with one entry per attribute: an
     * adaptive nominal or numeric virtual node, or {@code null} for the class
     * attribute, for nominal attributes reported by {@code nominalAttUsed},
     * and for attributes that are neither nominal nor numeric.
     */
    @Override
    protected void createVirtualNodes(IademNumericAttributeObserver numericAttClassObserver,
            boolean onlyMultiwayTest,
            boolean onlyBinaryTest,
            Instance instance) {
        // Typed collections: with raw types the int/element comparison below does not type-check.
        ArrayList<Integer> nominalUsed = nominalAttUsed(instance);
        TreeSet<Integer> sort = new TreeSet<>(nominalUsed);
        for (int i = 0; i < instance.numAttributes(); i++) {
            boolean isClass = instance.classIndex() == i;
            if (!isClass && instance.attribute(i).isNominal()) {
                if (!sort.isEmpty() && i == sort.first()) {
                    // Attribute index i is the smallest remaining "used" index: consume it.
                    sort.pollFirst();
                    virtualChildren.set(i, null);
                } else {
                    virtualChildren.set(i, new AdaptiveNominalVirtualNode((Iadem3) tree,
                            this,
                            i,
                            onlyMultiwayTest,
                            onlyBinaryTest));
                }
            } else if (!isClass && instance.attribute(i).isNumeric()) {
                virtualChildren.set(i, new AdaptiveNumericVirtualNode((Iadem3) tree,
                        this,
                        i,
                        numericAttClassObserver));
            } else {
                virtualChildren.set(i, null);
            }
        }
    }

    /**
     * Feeds the 0/1 prediction error for this instance to the detector and
     * restarts the leaf statistics when a change is reported.
     */
    private void updateCounters(Instance experiencia) {
        double[] classVotes = this.getClassVotes(experiencia);
        boolean trueClass = (Utils.maxIndex(classVotes) == (int) experiencia.classValue());
        if (estimator != null && ((Iadem3) this.tree).restartAtDrift) {
            double error = trueClass ? 0.0 : 1.0;
            this.estimator.input(error);
            if (this.estimator.getChange()) {
                this.restartVariablesAtDrift();
            }
        }
    }

    /**
     * Tries to split this leaf. Nothing happens unless more than one class
     * has been observed and {@code hasInformationToSplit()} holds. After more
     * than 5000 instances since the leaf became real, the fast suggestion is
     * used and recorded as a tie-breaking split; otherwise the regular best
     * suggestion is used. An {@link IademException} is logged, not rethrown.
     */
    @Override
    public void attemptToSplit(Instance instance) {
        if (this.classValueDist.numNonZeroEntries() <= 1 || !hasInformationToSplit()) {
            return;
        }
        try {
            this.instSeenSinceLastSplitAttempt = 0;
            Iadem3 owner = (Iadem3) this.tree;
            IademAttributeSplitSuggestion bestSplitSuggestion;
            int splitState;
            if (this.instNodeCountSinceReal > 5000) {
                // The counter is bumped before it is known whether a split suggestion exists.
                owner.updateNumberOfNodesSplitByTieBreaking(1);
                bestSplitSuggestion = getFastSplitSuggestion(instance);
                splitState = owner.SPLIT_BY_TIE_BREAKING;
            } else {
                bestSplitSuggestion = getBestSplitSuggestion(instance);
                splitState = owner.SPLIT_WITH_CONFIDENCE;
            }
            if (bestSplitSuggestion != null) {
                owner.currentSplitState = splitState;
                doSplit(bestSplitSuggestion, instance);
            }
        } catch (IademException ex) {
            Logger.getLogger(LeafNode.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    /** Updates the drift detector, then lets the superclass learn from the instance. */
    @Override
    public Node learnFromInstance(Instance inst) {
        updateCounters(inst);
        return super.learnFromInstance(inst);
    }

    /**
     * Replaces this leaf with a split node built by the virtual child of the
     * attribute chosen in {@code mejorExpansion}, gives the split node a new
     * estimator from the tree, rewires the parent (or the tree root), and
     * notifies the tree of the split.
     *
     * @return always {@code null}
     */
    @Override
    public AdaptiveLeafNode[] doSplit(IademAttributeSplitSuggestion mejorExpansion, Instance instance) {
        int splitAttribute = mejorExpansion.splitTest.getAttsTestDependsOn()[0];
        AdaptiveSplitNode splitNode = (AdaptiveSplitNode) virtualChildren.get(splitAttribute)
                .getNewSplitNode(instTreeCountSinceReal,
                        parent,
                        mejorExpansion,
                        instance);
        splitNode.setParent(this.parent);
        splitNode.estimator = this.tree.newEstimator();
        if (this.parent == null) {
            tree.setTreeRoot(splitNode);
        } else {
            ((SplitNode) parent).changeChildren(this, splitNode);
        }
        this.tree.newSplit(splitNode.getLeaves().size());
        return null;
    }

    /**
     * Clears this leaf's counters and class distribution and resets every
     * non-null virtual child.
     */
    protected void restartVariablesAtDrift() {
        instNodeCountSinceVirtual = 0;
        classValueDist = new DoubleVector();
        instTreeCountSinceReal = 0;
        instNodeCountSinceReal = 0;
        for (int i = 0; i < virtualChildren.size(); i++) {
            if (virtualChildren.get(i) != null) {
                ((restartsVariablesAtDrift) virtualChildren.get(i)).resetVariablesAtDrift();
            }
        }
    }
}
/**
 * Adaptive leaf that predicts with Naive Bayes once enough instances have
 * been seen, and with the majority class before that.
 */
public class AdaptiveLeafNodeNB extends AdaptiveLeafNode {

    private static final long serialVersionUID = 1L;

    /** Instance count below which the majority-class prediction is used. */
    protected int limitNaiveBayes;

    /**
     * Forwards everything to {@link AdaptiveLeafNode} and stores the Naive
     * Bayes threshold.
     */
    public AdaptiveLeafNodeNB(Iadem3 tree,
            Node parent,
            long instTreeCountSinceVirtual,
            long instNodeCountSinceVirtual,
            double[] initialClassCount,
            IademNumericAttributeObserver numericAttClassObserver,
            int limitNaiveBayes,
            AbstractChangeDetector estimator,
            boolean onlyMultiwayTest,
            boolean onlyBinaryTest,
            Instance instance) {
        super(tree, parent, instTreeCountSinceVirtual, instNodeCountSinceVirtual,
                initialClassCount, numericAttClassObserver, estimator,
                onlyMultiwayTest, onlyBinaryTest, instance);
        this.limitNaiveBayes = limitNaiveBayes;
    }

    /**
     * @return majority-class votes while no instance has been counted or the
     *         real-instance counter is below the limit; Naive Bayes votes otherwise
     */
    @Override
    public double[] getClassVotes(Instance inst) {
        final boolean tooFewInstances =
                instNodeCountSinceVirtual == 0 || instNodeCountSinceReal < limitNaiveBayes;
        return tooFewInstances ? getMajorityClassVotes(inst) : getNaiveBayesPrediction(inst);
    }

    /**
     * Multiplies the majority-class votes by the conditional probabilities
     * supplied by each informative virtual child, then normalises. If the
     * total is zero a uniform distribution is returned.
     */
    protected double[] getNaiveBayesPrediction(Instance inst) {
        final double[] posterior = getMajorityClassVotes(inst);
        for (int att = 0; att < virtualChildren.size(); att++) {
            final VirtualNode virtual = virtualChildren.get(att);
            if (virtual == null || !virtual.hasInformation()) {
                continue;
            }
            final double attValue = inst.value(att);
            final DoubleVector likelihood = virtual.computeConditionalProbability(attValue);
            if (likelihood == null) {
                continue;
            }
            for (int c = 0; c < posterior.length; c++) {
                posterior[c] *= likelihood.getValue(c);
            }
        }
        double total = 0.0;
        for (double value : posterior) {
            total += value;
        }
        if (total == 0.0) {
            Arrays.fill(posterior, 1.0 / posterior.length);
        } else {
            for (int c = 0; c < posterior.length; c++) {
                posterior[c] /= total;
            }
        }
        return posterior;
    }
}
/**
 * Leaf that tracks the error of both the majority-class and the Naive Bayes
 * predictors with change detectors and predicts with whichever currently has
 * the lower estimated error.
 */
public class AdaptiveLeafNodeNBAdaptive extends AdaptiveLeafNodeNB {

    private static final long serialVersionUID = 1L;

    /** Error tracker of the Naive Bayes predictor. */
    protected AbstractChangeDetector naiveBayesError;

    /** Error tracker of the majority-class predictor. */
    protected AbstractChangeDetector majorityClassError;

    /**
     * Forwards to {@link AdaptiveLeafNodeNB} and creates the two error
     * trackers as copies of {@code estimator}, which is dereferenced
     * unconditionally and therefore must not be {@code null}.
     */
    public AdaptiveLeafNodeNBAdaptive(Iadem3 tree,
            Node parent,
            long instancesProcessedByTheTree,
            long instancesProcessedByThisLeaf,
            double[] classDist,
            IademNumericAttributeObserver observadorContinuos,
            int naiveBayesLimit,
            boolean onlyMultiwayTest,
            boolean onlyBinaryTest,
            AbstractChangeDetector estimator,
            Instance instance) {
        super(tree, parent, instancesProcessedByTheTree, instancesProcessedByThisLeaf,
                classDist, observadorContinuos, naiveBayesLimit, estimator,
                onlyMultiwayTest, onlyBinaryTest, instance);
        this.naiveBayesError = (AbstractChangeDetector) estimator.copy();
        this.majorityClassError = (AbstractChangeDetector) estimator.copy();
    }

    /**
     * @return majority-class votes when the Naive Bayes error estimate is
     *         strictly larger; Naive Bayes votes otherwise
     */
    @Override
    public double[] getClassVotes(Instance instance) {
        final double bayesError = this.naiveBayesError.getEstimation();
        final double majorityError = this.majorityClassError.getEstimation();
        if (bayesError > majorityError) {
            return getMajorityClassVotes(instance);
        }
        return getNaiveBayesPrediction(instance);
    }

    /**
     * Feeds each predictor's 0/1 error on this instance to its tracker
     * (majority class first), then delegates to the superclass.
     */
    @Override
    public Node learnFromInstance(Instance inst) {
        final double[] majorityVotes = getMajorityClassVotes(inst);
        final boolean majorityRight = Utils.maxIndex(majorityVotes) == (int) inst.classValue();
        this.majorityClassError.input(majorityRight ? 0.0 : 1.0);

        final double[] bayesVotes = getNaiveBayesPrediction(inst);
        final boolean bayesRight = Utils.maxIndex(bayesVotes) == (int) inst.classValue();
        this.naiveBayesError.input(bayesRight ? 0.0 : 1.0);

        return super.learnFromInstance(inst);
    }
}
/**
 * Leaf that counts the mistakes of the majority-class and Naive Bayes
 * predictors and predicts with the one that has made fewer so far.
 */
public class AdaptiveLeafNodeNBKirkby extends AdaptiveLeafNodeNB {

    private static final long serialVersionUID = 1L;

    /** Number of instances the Naive Bayes predictor got wrong. */
    protected int naiveBayesError;

    /** Number of instances the majority-class predictor got wrong. */
    protected int majorityClassError;

    /**
     * Forwards to {@link AdaptiveLeafNodeNB} and starts both mistake
     * counters at zero.
     */
    public AdaptiveLeafNodeNBKirkby(Iadem3 tree,
            Node parent,
            long instancesProcessedByTheTree,
            long instancesProcessedByThisLeaf,
            double[] classDist,
            IademNumericAttributeObserver observadorContinuos,
            int naiveBayesLimit,
            boolean onlyMultiwayTest,
            boolean onlyBinaryTest,
            AbstractChangeDetector estimator,
            Instance instance) {
        super(tree, parent, instancesProcessedByTheTree, instancesProcessedByThisLeaf,
                classDist, observadorContinuos, naiveBayesLimit, estimator,
                onlyMultiwayTest, onlyBinaryTest, instance);
        this.naiveBayesError = 0;
        this.majorityClassError = 0;
    }

    /**
     * @return majority-class votes when Naive Bayes has strictly more
     *         mistakes; Naive Bayes votes otherwise
     */
    @Override
    public double[] getClassVotes(Instance instance) {
        final boolean bayesIsWorse = naiveBayesError > majorityClassError;
        return bayesIsWorse ? getMajorityClassVotes(instance) : getNaiveBayesPrediction(instance);
    }

    /**
     * Updates both mistake counters for this instance (majority class first),
     * then delegates to the superclass.
     */
    @Override
    public Node learnFromInstance(Instance inst) {
        final double[] majorityVotes = getMajorityClassVotes(inst);
        if (Utils.maxIndex(majorityVotes) != (int) inst.classValue()) {
            this.majorityClassError++;
        }
        final double[] bayesVotes = getNaiveBayesPrediction(inst);
        if (Utils.maxIndex(bayesVotes) != (int) inst.classValue()) {
            this.naiveBayesError++;
        }
        return super.learnFromInstance(inst);
    }
}
/**
 * Leaf that combines the majority-class and Naive Bayes votes, each weighted
 * by one minus its estimated error.
 */
public class AdaptiveLeafNodeWeightedVote extends AdaptiveLeafNodeNBAdaptive {

    private static final long serialVersionUID = 1L;

    /** Forwards all arguments unchanged to {@link AdaptiveLeafNodeNBAdaptive}. */
    public AdaptiveLeafNodeWeightedVote(Iadem3 tree,
            Node parent,
            long instTreeCountSinceVirtual,
            long instNodeCountSinceVirtual,
            double[] classDist,
            IademNumericAttributeObserver observadorContinuos,
            int naiveBayesLimit,
            boolean onlyMultiwayTest,
            boolean onlyBinaryTest,
            AbstractChangeDetector estimator,
            Instance instance) {
        super(tree, parent, instTreeCountSinceVirtual, instNodeCountSinceVirtual,
                classDist, observadorContinuos, naiveBayesLimit,
                onlyMultiwayTest, onlyBinaryTest, estimator, instance);
    }

    /**
     * @return per-class sum of the majority-class vote times
     *         {@code 1 - majorityClassError} and the Naive Bayes vote times
     *         {@code 1 - naiveBayesError}; the result is not normalised
     */
    @Override
    public double[] getClassVotes(Instance instance) {
        final double bayesWeight = 1 - this.naiveBayesError.getEstimation();
        final double majorityWeight = 1 - this.majorityClassError.getEstimation();
        final double[] majority = getMajorityClassVotes(instance);
        final double[] bayes = getNaiveBayesPrediction(instance);
        final double[] combined = new double[majority.length];
        for (int c = 0; c < majority.length; c++) {
            combined[c] = majority[c] * majorityWeight + bayes[c] * bayesWeight;
        }
        return combined;
    }

    /**
     * Tests whether {@code mean1} exceeds {@code mean2} by more than
     * {@code sqrt((1/n1 + 1/n2) * ln(1/0.001) / 2)}.
     */
    protected boolean isSignificantlyGreaterThan(double mean1, double mean2, int n1, int n2) {
        final double confidence = 0.001;
        final double m = 1.0 / n1 + 1.0 / n2;
        final double logTerm = Math.log(1.0 / confidence);
        final double bound = Math.sqrt(m * logTerm / 2);
        return mean1 - mean2 > bound;
    }
}
/**
 * Virtual node for a nominal attribute that can reset its statistics on
 * drift and that builds {@link AdaptiveSplitNode}s when chosen for a split.
 */
public class AdaptiveNominalVirtualNode extends NominalVirtualNode implements Serializable, restartsVariablesAtDrift {

    private static final long serialVersionUID = 1L;

    /**
     * Drift detector for this node. It is not assigned anywhere in this
     * class, so the drift check below is inactive unless it is set elsewhere.
     */
    protected AbstractChangeDetector estimador;

    /** Forwards all arguments unchanged to {@link NominalVirtualNode}. */
    public AdaptiveNominalVirtualNode(Iadem3 tree,
            Node parent,
            int attID,
            boolean onlyMultiwayTest,
            boolean onlyBinaryTest) {
        super(tree, parent, attID, onlyMultiwayTest, onlyBinaryTest);
    }

    /**
     * Updates the drift detector unless the attribute value is missing, then
     * delegates to the superclass.
     */
    @Override
    public Node learnFromInstance(Instance inst) {
        double attValue = inst.value(attIndex);
        if (!Utils.isMissingValue(attValue)) {
            updateCountersForChange(inst);
        }
        return super.learnFromInstance(inst);
    }

    /**
     * Feeds the 0/1 prediction error to the detector (when one is set) and
     * resets the node statistics if a change is reported.
     */
    private void updateCountersForChange(Instance inst) {
        double[] classVotes = this.getClassVotes(inst);
        boolean trueClass = (Utils.maxIndex(classVotes) == (int) inst.classValue());
        if (estimador != null && ((Iadem3) this.tree).restartAtDrift) {
            double error = trueClass ? 0.0 : 1.0;
            this.estimador.input(error);
            if (this.estimador.getChange()) {
                this.resetVariablesAtDrift();
            }
        }
    }

    /**
     * Builds the split node for this attribute and its child leaves.
     *
     * <p>For a multiway test there is one child per attribute value, seeded
     * with the class counts observed for that value. Otherwise the test is
     * binary: child 0 receives the counts of the tested value and child 1 the
     * remainder of this node's class distribution. Values with no recorded
     * distribution are treated as all-zero counts in both cases.
     *
     * @param counter   tree-level instance counter handed to the new leaves
     * @param parent    parent for the new split node
     * @param bestSplit suggestion whose {@code splitTest} drives the split
     * @param instance  instance describing the attribute schema
     * @return the new split node with its children attached
     */
    @Override
    public SplitNode getNewSplitNode(long counter,
            Node parent,
            IademAttributeSplitSuggestion bestSplit,
            Instance instance) {
        AdaptiveSplitNode splitNode = new AdaptiveSplitNode((Iadem3) this.tree,
                parent,
                null,
                ((LeafNode) this.parent).getMajorityClassVotes(instance),
                bestSplit.splitTest,
                ((AdaptiveLeafNode) this.parent).estimator,
                (AdaptiveLeafNode) this.parent,
                ((Iadem3) this.tree).currentSplitState);
        int numClasses = instance.attribute(instance.classIndex()).numValues();
        Node[] children;
        if (bestSplit.splitTest instanceof NominalAttributeMultiwayTest) {
            children = new Node[instance.attribute(this.attIndex).numValues()];
            for (int i = 0; i < children.length; i++) {
                long tmpConter = 0;
                double[] newClassDist = new double[numClasses];
                // The per-value distribution does not depend on the class index: look it up once.
                DoubleVector tmpClassDist = nominalAttClassObserver.get(i);
                for (int j = 0; j < newClassDist.length; j++) {
                    newClassDist[j] = tmpClassDist != null ? tmpClassDist.getValue(j) : 0.0;
                    tmpConter += newClassDist[j];
                }
                children[i] = ((Iadem3) tree).newLeafNode(splitNode,
                        counter,
                        tmpConter,
                        newClassDist,
                        instance);
            }
        } else { // binary split
            children = new Node[2];
            IademNominalAttributeBinaryTest binarySplit = (IademNominalAttributeBinaryTest) bestSplit.splitTest;
            double[] leftClassDist = new double[numClasses];
            double tmpCounter = 0;
            DoubleVector classDist = nominalAttClassObserver.get(binarySplit.getAttValue());
            if (classDist != null) {
                for (int i = 0; i < leftClassDist.length; i++) {
                    leftClassDist[i] = classDist.getValue(i);
                    tmpCounter += classDist.getValue(i);
                }
            }
            children[0] = ((Iadem3) tree).newLeafNode(splitNode,
                    counter,
                    (int) tmpCounter,
                    leftClassDist,
                    instance);
            // Right branch: everything not sent to the left. A separate array is
            // used so the left child's initial counts cannot be overwritten.
            tmpCounter = this.classValueDist.sumOfValues() - tmpCounter;
            double[] rightClassDist = new double[numClasses];
            for (int i = 0; i < rightClassDist.length; i++) {
                rightClassDist[i] = this.classValueDist.getValue(i) - leftClassDist[i];
            }
            children[1] = ((Iadem3) tree).newLeafNode(splitNode,
                    counter,
                    (int) tmpCounter,
                    rightClassDist,
                    instance);
        }
        splitNode.setChildren(children);
        return splitNode;
    }

    /** Discards the attribute-value, per-value class, and class distributions. */
    @Override
    public void resetVariablesAtDrift() {
        attValueDist = new DoubleVector();
        nominalAttClassObserver = new AutoExpandVector<>();
        classValueDist = new DoubleVector();
    }
}
/**
 * Virtual node for a numeric attribute that can reset its statistics on
 * drift and that builds binary {@link AdaptiveSplitNode}s when chosen for a
 * split.
 */
public class AdaptiveNumericVirtualNode extends NumericVirtualNode implements Serializable, restartsVariablesAtDrift {

    private static final long serialVersionUID = 1L;

    /** Not used within this class. */
    protected IademNumericAttributeObserver altAttClassObserver;

    /** Not used within this class. */
    protected DoubleVector altClassDist;

    /**
     * Drift detector for this node. It is not assigned anywhere in this
     * class, so the drift check is inactive unless it is set elsewhere.
     */
    protected AbstractChangeDetector estimator;

    /** Forwards all arguments unchanged to {@link NumericVirtualNode}. */
    public AdaptiveNumericVirtualNode(Iadem3 tree,
            Node parent,
            int attID,
            IademNumericAttributeObserver observadorContinuos) {
        super(tree, parent, attID, observadorContinuos);
    }

    /** Updates the drift detector, then delegates to the superclass. */
    @Override
    public Node learnFromInstance(Instance inst) {
        updateCounters(inst);
        return super.learnFromInstance(inst);
    }

    /**
     * Feeds the 0/1 prediction error to the detector (when one is set) and
     * resets the node statistics if a change is reported.
     */
    private void updateCounters(Instance inst) {
        final double[] votes = this.getClassVotes(inst);
        final boolean hit = Utils.maxIndex(votes) == (int) inst.classValue();
        if (this.estimator == null || !((Iadem3) this.tree).restartAtDrift) {
            return;
        }
        this.estimator.input(hit ? 0.0 : 1.0);
        if (this.estimator.getChange()) {
            this.resetVariablesAtDrift();
        }
    }

    /** @return the sum of all entries of {@code arr} */
    private long sum(long[] arr) {
        long total = 0;
        for (long value : arr) {
            total += value;
        }
        return total;
    }

    /**
     * Builds a binary split node at {@code bestCutPoint} with two child
     * leaves: child 0 is seeded with the class counts left of the cut point,
     * child 1 with the remaining counts. Equality passes the test unless the
     * observer is an {@code IademVFMLNumericAttributeClassObserver}.
     *
     * @param counter   tree-level instance counter handed to the new leaves
     * @param parent    parent for the new split node
     * @param bestSplit not read by this implementation
     * @param instance  instance describing the attribute schema
     * @return the new split node with its children attached
     */
    @Override
    public SplitNode getNewSplitNode(long counter,
            Node parent,
            IademAttributeSplitSuggestion bestSplit,
            Instance instance) {
        final double cut = bestCutPoint;
        final long[] leftCounts = numericAttClassObserver.getLeftClassDist(bestCutPoint);
        final long totalCount = numericAttClassObserver.getValueCount();
        final long[] allCounts = numericAttClassObserver.getClassDist();
        final boolean equalsPassesTest =
                !(this.numericAttClassObserver instanceof IademVFMLNumericAttributeClassObserver);

        final AdaptiveSplitNode splitNode = new AdaptiveSplitNode((Iadem3) this.tree,
                parent,
                null,
                ((LeafNode) this.parent).getMajorityClassVotes(instance),
                new NumericAttributeBinaryTest(this.attIndex, cut, equalsPassesTest),
                ((AdaptiveLeafNode) this.parent).estimator,
                (AdaptiveLeafNode) this.parent,
                ((Iadem3) this.tree).currentSplitState);

        final long leftTotal = sum(leftCounts);
        final long rightTotal = totalCount - leftTotal;
        // Fresh double arrays start zero-filled.
        final double[] leftDist = new double[instance.attribute(instance.classIndex()).numValues()];
        final double[] rightDist = new double[instance.attribute(instance.classIndex()).numValues()];
        for (int c = 0; c < leftCounts.length; c++) {
            leftDist[c] = leftCounts[c];
            rightDist[c] = allCounts[c] - leftDist[c];
        }

        final Node[] children = new Node[2]; // a binary split
        splitNode.setChildren(null);
        children[0] = ((Iadem3) tree).newLeafNode(splitNode, counter, leftTotal, leftDist, instance);
        children[1] = ((Iadem3) tree).newLeafNode(splitNode, counter, rightTotal, rightDist, instance);
        splitNode.setChildren(children);
        return splitNode;
    }

    /**
     * Drops the cached split suggestion, marks the heuristic measure stale,
     * resets the numeric observer and clears the class distribution.
     */
    @Override
    public void resetVariablesAtDrift() {
        this.bestSplitSuggestion = null;
        this.heuristicMeasureUpdated = false;
        numericAttClassObserver.reset();
        classValueDist = new DoubleVector();
    }
}
public class AdaptiveSplitNode extends SplitNode implements Serializable {
private static final long serialVersionUID = 1L;

/**
 * Alternative subtrees growing alongside this split node. The element type
 * is declared explicitly so that typed iteration and {@code get} calls
 * compile without casts.
 */
protected AutoExpandVector<Iadem3Subtree> alternativeTree = new AutoExpandVector<>();

/** Concept drift detector for this split node. */
protected AbstractChangeDetector estimator;

/** Split-state code of the owning tree at the time this node was created. */
protected int causeOfSplit;

/** The leaf this split node replaced; kept to provide fallback predictions and an error baseline. */
protected AdaptiveLeafNode leaf;
/**
 * Creates a split node that keeps the leaf it replaces as a prediction leaf.
 *
 * @param tree           owning tree
 * @param parent         parent node, or {@code null} for the root
 * @param child          initial children forwarded to the superclass
 * @param freq           class frequencies forwarded to the superclass
 * @param splitTest      test applied at this node
 * @param estimator      detector prototype; copied when non-null
 * @param predictionLeaf leaf being replaced; must not be {@code null}
 *                       because {@code setSplit(false)} is called on it
 * @param causeOfSplit   split-state code stored for later bookkeeping
 */
public AdaptiveSplitNode(Iadem3 tree,
        Node parent,
        Node[] child,
        double[] freq,
        InstanceConditionalTest splitTest,
        AbstractChangeDetector estimator,
        AdaptiveLeafNode predictionLeaf,
        int causeOfSplit) {
    super(tree, parent, child, freq, splitTest);
    this.estimator = (estimator == null)
            ? null
            : (AbstractChangeDetector) estimator.copy();
    this.leaf = predictionLeaf;
    this.leaf.setSplit(false);
    this.causeOfSplit = causeOfSplit;
}
/**
 * Trains this split node with one instance.
 *
 * <p>First the error estimate of this node is compared with that of the
 * prediction leaf it replaced, using the detectors' {@code getDelay()} values
 * as sample sizes in the bound {@code sqrt((1/n1 + 1/n2) * ln(2/delta) / 2)}.
 * If this node is significantly worse (and both sizes exceed 600) the node is
 * pruned and the leaf is returned. If the leaf is significantly worse, the
 * leaf's statistics are restarted. Afterwards the alternative subtrees are
 * checked; when no replacement happens, every alternative subtree, the
 * prediction leaf and the superclass all learn from the instance.
 *
 * <p>NOTE(review): {@code this.estimator} and {@code this.leaf.estimator} are
 * dereferenced without a null check although both constructors accept a null
 * detector; confirm that they are always set before this method runs.
 *
 * @param instance the training instance
 * @return the leaf after pruning, the result of the replacing node, the
 *         superclass result, or {@code null} if an {@link IademException}
 *         was caught and logged
 */
@Override
public Node learnFromInstance(Instance instance) {
try {
double thisError = this.estimator.getEstimation(),
thisSize = this.estimator.getDelay();
double leafError = this.leaf.estimator.getEstimation(),
leafSize = this.leaf.estimator.getDelay();
double m = 1.0 / thisSize + 1.0 / leafSize;
double delta = 0.0001;
double bound = Math.sqrt(m * Math.log(2.0 / delta) / 2.0);
double diff = thisError - leafError;
if (diff > bound && thisSize > 600 && leafSize > 600/**/) {
// The split performs significantly worse than the leaf it replaced: undo it.
prune();
return this.leaf;
} else if (-diff > bound) {
// The leaf is significantly worse: restart its statistics and replace its
// detector with a copy of itself (presumably a fresh detector; confirm copy() semantics).
this.leaf.restartVariablesAtDrift();
this.leaf.estimator = (AbstractChangeDetector) this.leaf.estimator.copy();
}
Node node;
// Whether the leaf-level prediction cached by the tree matched the true class.
boolean rightPredicted = ((Iadem3) this.tree).lastPredictionInLeaf == instance.classValue();
node = checkAlternativeSubtrees(rightPredicted, instance);
if (node == null) {
// no subtree change
for (Iadem3Subtree subtree : this.alternativeTree) {
try {
subtree.learnFromInstance(instance);
subtree.incrNumberOfInstancesProcessed();
} catch (IademException ex) {
// A failing alternative subtree is logged and does not stop the others.
Logger.getLogger(AdaptiveSplitNode.class.getName()).log(Level.SEVERE, null, ex);
}
}
this.leaf.learnFromInstance(instance);
return super.learnFromInstance(instance);
} else {
// subtree change
return node.learnFromInstance(instance);
}
} catch (IademException ex) {
Logger.getLogger(AdaptiveSplitNode.class.getName()).log(Level.SEVERE, null, ex);
}
// Reached only after an IademException was logged above.
return null;
}
/**
 * Updates this node's drift detector and decides, for each alternative
 * subtree, whether to promote it, discard it, or keep it growing.
 *
 * <p>Does nothing and returns {@code null} when this node has no detector.
 *
 * @param acierto  whether the last prediction was correct ("acierto" = hit)
 * @param instance the current training instance, forwarded to {@code createTree}
 * @return the root of the promoted subtree when a replacement happened,
 *         otherwise {@code null}
 * @throws IademException propagated from {@code createTree}
 */
private Node checkAlternativeSubtrees(boolean acierto, Instance instance) throws IademException {
if (this.estimator != null) {
// 0/1 loss of the last prediction.
double loss = (acierto == true ? 0.0 : 1.0);
estimator.input(loss);
if (estimator.getChange()) {
// Drift signalled at this node (createTree presumably starts a new alternative subtree; body not shown here).
this.createTree(instance);
}
for (int i = 0; i < this.alternativeTree.size(); i++) {
Iadem3Subtree subtree = alternativeTree.get(i);
double treeError = subtree.estimacionValorMedio(),
thisError = this.estimator.getEstimation();
double bound = IademCommonProcedures.AverageComparitionByHoeffdingCorollary(this.estimator.getDelay(),
subtree.windowWidth(),
1e-4);
if (thisError - treeError > bound/**/) {
// The alternative is significantly more accurate: promote it.
// NOTE(review): changeTrees() also calls newTreeChange() on the tree, which
// increments interchangedTrees in Iadem3, so this path may count twice
// unless the tree's runtime class overrides it; verify.
((Iadem3) this.tree).interchangedTrees++;
return changeTrees(i);
} else if (isUseless(i)) {
// isUseless() has already removed entry i; subtract its size from the tree totals.
((Iadem3) this.tree).updateNumberOfLeaves(-subtree.getNumberOfLeaves());
((Iadem3) this.tree).updateNumberOfNodes(-subtree.getNumberOfNodes());
((Iadem3) this.tree).updateNumberOfNodesSplitByTieBreaking(-subtree.numSplitsByBreakingTies);
// Stay on the same index: the next element shifted into position i.
i--;
} else if (this.estimator.getDelay() > 6000
&& subtree.windowWidth() > 6000/**/) {
{
if (treeError - thisError > bound) {
// The alternative is significantly less accurate: discard it.
this.alternativeTree.remove(i);
// update number of nodes
((Iadem3) this.tree).updateNumberOfLeaves(-subtree.getNumberOfLeaves());
((Iadem3) this.tree).updateNumberOfNodes(-subtree.getNumberOfNodes());
((Iadem3) this.tree).updateNumberOfNodesSplitByTieBreaking(-subtree.numSplitsByBreakingTies);
i--;
} else /**/ {
// No significant accuracy difference: compare sizes instead.
// Slots 0 and 1 are summed as the node total below (changeTrees treats
// slot 1 as leaves and slot 2 as tie-breaking splits).
int[] countMain = new int[3],
countAlt = new int[3];
for (Node child : this.children) {
child.getNumberOfNodes(countMain);
}
subtree.getNumberOfNodes(countAlt);/**/
// The +1 accounts for this split node itself.
if (countMain[0] + countMain[1] + 1 > countAlt[0] + countAlt[1]) {
// The alternative is smaller: promote it.
return changeTrees(i);
} else {
this.alternativeTree.remove(i);
// update number of nodes
((Iadem3) this.tree).updateNumberOfLeaves(-subtree.getNumberOfLeaves());
((Iadem3) this.tree).updateNumberOfNodes(-subtree.getNumberOfNodes());
((Iadem3) this.tree).updateNumberOfNodesSplitByTieBreaking(-subtree.numSplitsByBreakingTies);
i--;
}
}
}/**/
}
}
}
return null;
}
/**
 * Decides whether the alternative subtree at index {@code i} is useless and,
 * if so, removes it from {@code alternativeTree}.
 *
 * <p>Only subtrees whose root is already a split node are considered. Such a
 * subtree is useless when its error exceeds this node's error by more than
 * {@code sqrt((1/nMain + 1/nAlt) * ln(2/1e-4) / 2)} (checked only when both
 * sample sizes are positive), or when its root applies the same split test
 * as this node.
 *
 * @param i index into {@code alternativeTree}
 * @return {@code true} if the subtree was removed
 */
public boolean isUseless(int i) {
    Iadem3Subtree subtree = this.alternativeTree.get(i);
    if (!(subtree.getTreeRoot() instanceof AdaptiveSplitNode)) {
        return false;
    }
    AdaptiveSplitNode splitNode = (AdaptiveSplitNode) subtree.getTreeRoot();
    int nMain = (int) this.estimator.getDelay();
    int nAlt = (int) subtree.getEstimador().getDelay();
    double errorMain = this.estimator.getEstimation();
    double errorAlt = subtree.getEstimador().getEstimation();

    boolean useless = false;
    if (nMain > 0 && nAlt > 0) {
        double m = 1.0 / nMain + 1.0 / nAlt;
        double delta = 1e-4;
        double bound = Math.sqrt(m * Math.log(2.0 / delta) / 2.0);
        // Alternative tree is too inaccurate.
        useless = (errorAlt - errorMain) > bound;
    }
    if (!useless) {
        // Alternative tree starts with the very same split as the main one.
        useless = duplicatesMainSplitTest(splitNode.splitTest);
    }
    if (useless) {
        this.alternativeTree.remove(i);
    }
    return useless;
}

/**
 * Tells whether {@code condTest} is the same kind of test as this node's
 * split test, on the same attribute and, for binary tests, with the same
 * attribute value or split value.
 */
private boolean duplicatesMainSplitTest(InstanceConditionalTest condTest) {
    if (condTest instanceof IademNominalAttributeBinaryTest
            && this.splitTest instanceof IademNominalAttributeBinaryTest) {
        IademNominalAttributeBinaryTest altTest = (IademNominalAttributeBinaryTest) condTest;
        IademNominalAttributeBinaryTest mainTest = (IademNominalAttributeBinaryTest) this.splitTest;
        return mainTest.getAttValue() == altTest.getAttValue()
                && mainTest.getAttsTestDependsOn()[0] == altTest.getAttsTestDependsOn()[0];
    }
    if (condTest instanceof NominalAttributeMultiwayTest
            && this.splitTest instanceof NominalAttributeMultiwayTest) {
        NominalAttributeMultiwayTest altTest = (NominalAttributeMultiwayTest) condTest;
        NominalAttributeMultiwayTest mainTest = (NominalAttributeMultiwayTest) this.splitTest;
        return mainTest.getAttsTestDependsOn()[0] == altTest.getAttsTestDependsOn()[0];
    }
    if (condTest instanceof NumericAttributeBinaryTest
            && this.splitTest instanceof NumericAttributeBinaryTest) {
        NumericAttributeBinaryTest altTest = (NumericAttributeBinaryTest) condTest;
        NumericAttributeBinaryTest mainTest = (NumericAttributeBinaryTest) this.splitTest;
        return mainTest.getAttsTestDependsOn()[0] == altTest.getAttsTestDependsOn()[0]
                && mainTest.getSplitValue() == altTest.getSplitValue();
    }
    return false;
}
/**
 * Replaces this split node with the root of the alternative subtree at
 * {@code index} and updates the owning tree's size counters accordingly.
 *
 * @param index position in {@code alternativeTree} of the subtree to promote
 * @return the root node of the promoted subtree
 */
private Node changeTrees(int index) {
// Subtract the sizes of all alternatives that are NOT being promoted. They are
// not removed from the vector here; they are only reachable through this node.
for (int i = 0; i < this.alternativeTree.size(); i++) {
if (i != index) {
Iadem3Subtree subtree = this.alternativeTree.get(i);
((Iadem3) this.tree).updateNumberOfLeaves(-subtree.getNumberOfLeaves());
((Iadem3) this.tree).updateNumberOfNodes(-subtree.getNumberOfNodes());
((Iadem3) this.tree).updateNumberOfNodesSplitByTieBreaking(-subtree.numSplitsByBreakingTies);
}
}
Iadem3Subtree subtree = this.alternativeTree.get(index);
// Subtract the nodes of the main subtree rooted at this node.
// Slot usage below: count[1] is subtracted as leaves, count[0] + count[1] as
// nodes, count[2] as tie-breaking splits.
int count[] = new int[3];
super.getNumberOfNodes(count);
if (this.causeOfSplit == ((Iadem3) this.tree).SPLIT_BY_TIE_BREAKING) {
// This node itself was created by a tie-breaking split.
count[2]++;
}
((Iadem3) this.tree).updateNumberOfLeaves(-count[1]);
((Iadem3) this.tree).updateNumberOfNodes(-count[0] - count[1]);
((Iadem3) this.tree).updateNumberOfNodesSplitByTieBreaking(-count[2]);
//
AdaptiveSplitNode tmpParent = (AdaptiveSplitNode) this.parent;
Node newNode = subtree.getTreeRoot();
((Iadem3) tree).newTreeChange();
if (tmpParent == null) {
// This node is the root: the tree adopts the subtree's root.
((Iadem3) tree).copyTree(subtree);
} else {
// Replace this node in the parent's child list with the promoted root.
for (int i = 0; i < tmpParent.children.size(); i++) {
if (tmpParent.children.get(i) == this) {
tmpParent.children.set(i, newNode);
newNode.parent = tmpParent;
}
}
}
// Point every promoted node at this node's tree.
updateAttributes(newNode);
if (newNode instanceof AdaptiveSplitNode) {
AdaptiveSplitNode splitNode = (AdaptiveSplitNode) newNode;
// Alternatives nested inside the promoted root have their tree level lowered.
for (Iadem3Subtree currentSubtree : splitNode.alternativeTree) {
updateSubtreeLevel(currentSubtree.getTreeRoot());
}
}
return newNode;
}
/**
 * Recursively re-binds {@code newNode} and the nodes below it to this node's
 * tree.
 * <p>
 * For a split node, its embedded leaf and all of its children are updated;
 * alternative subtrees hanging from the split node are not visited. For a
 * leaf node, every non-null virtual child is updated.
 *
 * @param newNode root of the branch to re-bind; {@code null} is ignored
 */
void updateAttributes(Node newNode) {
    if (newNode == null) {
        return;
    }
    newNode.setTree(this.tree);
    if (newNode instanceof AdaptiveSplitNode) {
        AdaptiveSplitNode splitNode = (AdaptiveSplitNode) newNode;
        splitNode.leaf.setTree(this.tree);
        for (Node child : splitNode.children) {
            updateAttributes(child);
        }
    } else if (newNode instanceof AdaptiveLeafNode) {
        AdaptiveLeafNode leafNode = (AdaptiveLeafNode) newNode;
        // The vector must be parameterized: iterating a raw AutoExpandVector
        // yields Object elements, which cannot be bound to a VirtualNode
        // loop variable.
        AutoExpandVector<VirtualNode> virtualChildren = leafNode.getVirtualChildren();
        for (VirtualNode child : virtualChildren) {
            if (child != null) {
                child.setTree(this.tree);
            }
        }
    }
}
/**
 * Decrements the {@code treeLevel} of the tree that owns {@code node} and
 * propagates the update downwards: children of a split node are walked with
 * {@link #updateSubtreeLevelAux(Node)}, while the roots of its alternative
 * subtrees are handled by this method again.
 *
 * @param node root of an alternative subtree; {@code null} is ignored
 */
protected void updateSubtreeLevel(Node node) {
    if (node == null) {
        return;
    }
    Iadem3 owner = (Iadem3) node.getTree();
    owner.treeLevel--;
    if (!(node instanceof AdaptiveSplitNode)) {
        return;
    }
    AdaptiveSplitNode split = (AdaptiveSplitNode) node;
    for (Node descendant : split.children) {
        updateSubtreeLevelAux(descendant);
    }
    for (Iadem3Subtree nested : split.alternativeTree) {
        updateSubtreeLevel(nested.getTreeRoot());
    }
}
/**
 * Walks the branch below {@code node} without changing the level of the tree
 * that owns it; only the alternative subtrees found on split nodes are passed
 * to {@link #updateSubtreeLevel(Node)}.
 *
 * @param node node to walk; {@code null} and non-split nodes are ignored
 */
protected void updateSubtreeLevelAux(Node node) {
    // instanceof is false for null, so this single guard covers both cases.
    if (!(node instanceof AdaptiveSplitNode)) {
        return;
    }
    AdaptiveSplitNode split = (AdaptiveSplitNode) node;
    // Alternative subtrees first, then the regular children.
    for (Iadem3Subtree nested : split.alternativeTree) {
        updateSubtreeLevel(nested.getTreeRoot());
    }
    for (Node descendant : split.children) {
        updateSubtreeLevelAux(descendant);
    }
}
/**
 * Attaches a new alternative subtree to this node when the owning tree allows
 * it.
 * <p>
 * A subtree is created only if the tree reports it can create one, the
 * nesting-level and per-node subtree limits are not reached (a limit of
 * {@code -1} disables the corresponding check) and this node has an
 * estimator.
 *
 * @param instance instance handed to the new subtree's constructor
 * @throws IademException declared by the original signature
 */
void createTree(Instance instance) throws IademException {
    Iadem3 owner = (Iadem3) this.tree;
    if (!owner.canCreateSubtree()) {
        return;
    }
    int nestingLimit = owner.getMaxNestingLevels();
    int subtreeLimit = owner.getMaxAltSubtreesPerNode();
    if (nestingLimit != -1 && owner.getTreeLevel() >= nestingLimit) {
        return;
    }
    if (subtreeLimit != -1 && this.alternativeTree.size() >= subtreeLimit) {
        return;
    }
    if (this.estimator == null) {
        return;
    }
    // The new subtree is created one level deeper than the owning tree.
    Iadem3Subtree created = new Iadem3Subtree(this,
            owner.getTreeLevel() + 1,
            owner,
            instance);
    this.alternativeTree.add(created);
    owner.setNewTree();
}
/**
 * Counts the split nodes that carry at least one alternative subtree, looking
 * at this node, at the split nodes below it and at the split nodes inside its
 * alternative subtrees.
 *
 * @return 1 if this node has alternative subtrees (0 otherwise) plus the
 *         counts returned by child split nodes and by split-node roots of the
 *         alternative subtrees
 */
public int getNumTrees() {
    int total = 0;
    if (this.alternativeTree.size() != 0) {
        total = 1;
    }
    for (Node descendant : children) {
        if (descendant instanceof AdaptiveSplitNode) {
            AdaptiveSplitNode split = (AdaptiveSplitNode) descendant;
            total += split.getNumTrees();
        }
    }
    for (Iadem3Subtree alternative : this.alternativeTree) {
        Node altRoot = alternative.getTreeRoot();
        if (altRoot instanceof AdaptiveSplitNode) {
            total += ((AdaptiveSplitNode) altRoot).getNumTrees();
        }
    }
    return total;
}
/**
 * Computes the class votes for an instance at this split node.
 * <p>
 * The votes start from the embedded leaf. If the instance maps to a child and
 * this node's error estimate is lower than the leaf's, the child's votes are
 * used instead. Each alternative subtree then adds its own votes, scaled by
 * one minus its mean-value estimate.
 *
 * @param observacion instance to classify
 * @return the combined vote array (the array obtained from the leaf or child
 *         is updated in place)
 */
@Override
public double[] getClassVotes(Instance observacion) {
    double[] votes = this.leaf.getClassVotes(observacion);
    double splitError = this.estimator.getEstimation();
    double leafError = this.leaf.estimator.getEstimation();
    int branch = instanceChildIndex(observacion);
    boolean preferChild = branch >= 0 && splitError < leafError;
    if (preferChild) {
        votes = getChild(branch).getClassVotes(observacion);
    }
    for (Iadem3Subtree alternative : this.alternativeTree) {
        double[] altVotes = alternative.getClassVotes(observacion);
        double weight = 1.0 - alternative.estimacionValorMedio();
        // NOTE(review): assumes altVotes is at least as long as votes - confirm.
        for (int cls = 0; cls < votes.length; cls++) {
            votes[cls] += altVotes[cls] * weight;
        }
    }
    return votes;
}
/**
 * Returns the node count reported by the superclass plus the node count of
 * the root of every alternative subtree attached to this node.
 *
 * @return the accumulated node count
 */
@Override
public int getSubtreeNodeCount() {
    int total = super.getSubtreeNodeCount();
    for (Iadem3Subtree alternative : this.alternativeTree) {
        Node altRoot = alternative.getTreeRoot();
        total += altRoot.getSubtreeNodeCount();
    }
    return total;
}
/**
 * Returns the current estimation held by this node's estimator.
 *
 * @return the value of {@code estimator.getEstimation()}
 */
public double getErrorEstimation() {
    final double estimation = this.estimator.getEstimation();
    return estimation;
}
/**
 * Accumulates node statistics for this node into {@code count}.
 * <p>
 * The alternative subtrees add their figures first, then slot 2 is
 * incremented when this node was split by tie breaking, and finally the
 * superclass adds its own figures.
 *
 * @param count accumulator array, updated in place; this method touches
 *              index 2 directly
 */
@Override
public void getNumberOfNodes(int[] count) {
    for (Iadem3Subtree alternative : this.alternativeTree) {
        alternative.getNumberOfNodes(count);
    }
    boolean splitByTie = this.causeOfSplit == ((Iadem3) this.tree).SPLIT_BY_TIE_BREAKING;
    if (splitByTie) {
        count[2]++;
    }
    super.getNumberOfNodes(count);
}
/**
 * Counts the alternative subtrees reachable from this node: the ones attached
 * directly, the ones each of them reports, and the ones reported by child
 * split nodes.
 *
 * @return the total number of alternative subtrees found
 */
public int getNumberOfSubtrees() {
    int total = this.alternativeTree.size();
    for (Iadem3Subtree alternative : this.alternativeTree) {
        Iadem3 asTree = (Iadem3) alternative;
        total += asTree.getNumberOfSubtrees();
    }
    for (Node descendant : children) {
        if (!(descendant instanceof AdaptiveSplitNode)) {
            continue;
        }
        total += ((AdaptiveSplitNode) descendant).getNumberOfSubtrees();
    }
    return total;
}
/**
 * Collapses this split node into its embedded leaf.
 * <p>
 * The leaf is marked through {@code setSplit(true)}, every ancestor's leaf has
 * its drift variables restarted, the leaf and its non-null virtual children
 * are bound to this node's tree, and the leaf takes this node's place (as
 * tree root when there is no parent, otherwise in the parent's children).
 * Finally the tree's counters are adjusted by the figures of the removed
 * subtree.
 */
private void prune() {
    this.leaf.setSplit(true);
    for (Node node = this.parent; node != null; node = node.parent) {
        ((AdaptiveSplitNode) node).leaf.restartVariablesAtDrift();
        // NOTE(review): this statement does not depend on `node`; it re-copies
        // this leaf's estimator once per ancestor and never when there is no
        // parent. Confirm whether it was meant to act on the ancestor's leaf.
        this.leaf.estimator = (AbstractChangeDetector) this.leaf.estimator.copy();
    }
    // update tree
    this.leaf.setTree(this.tree);
    // The vector must be parameterized: iterating a raw AutoExpandVector
    // yields Object elements, which cannot be bound to a VirtualNode loop
    // variable.
    AutoExpandVector<VirtualNode> nodeList = this.leaf.getVirtualChildren();
    for (VirtualNode node : nodeList) {
        if (node != null) {
            node.setTree(this.tree);
        }
    }
    this.leaf.setParent(this.parent);
    if (this.parent == null) {
        this.tree.setTreeRoot(this.leaf);
    } else {
        ((SplitNode) this.parent).changeChildren(this, this.leaf);
    }
    // Gather the figures of the pruned subtree (including its alternative
    // subtrees) and subtract them from the tree's counters; the +1 presumably
    // accounts for the leaf that now stands in for this node.
    int[] count = new int[3];
    getNumberOfNodes(count);
    ((Iadem3) this.tree).updateNumberOfLeaves(-count[1] + 1);
    ((Iadem3) this.tree).updateNumberOfNodes(-count[0] - count[1] + 1);
    ((Iadem3) this.tree).updateNumberOfNodesSplitByTieBreaking(-count[2]);
}
}
/**
 * Contract for components that expose a hook to reset their variables at
 * drift.
 */
public interface restartsVariablesAtDrift {

    /**
     * Resets the implementor's variables at drift.
     */
    void resetVariablesAtDrift();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy