com.blazebit.ai.decisiontree.impl.SimpleDecisionTree Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of blaze-ai-utils Show documentation
Show all versions of blaze-ai-utils Show documentation
Artificial Intelligence Utilities.
package com.blazebit.ai.decisiontree.impl;
import com.blazebit.ai.decisiontree.*;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
/**
* @author Christian Beikov
*/
public class SimpleDecisionTree implements DecisionTree {
private final Set attributes;
private final AttributeSelector attributeSelector;
private final DecisionNode root;
public SimpleDecisionTree(final Set attributes, final Set> examples, final AttributeSelector attributeSelector) {
this.attributes = new HashSet(attributes);
this.attributeSelector = attributeSelector;
this.root = new SimpleDecisionNodeFactory(new HashSet(0)).createNode(null, examples);
}
@Override
public Set apply(final Item test) {
return root.apply(test);
}
@Override
public T applySingle(final Item test) {
return root.applySingle(test);
}
private static class LeafNode implements DecisionNode {
private final T result;
private final Set results;
public LeafNode() {
this.result = null;
this.results = Collections.emptySet();
}
public LeafNode(final Set> examples) {
final Set tempResults = new HashSet(examples.size());
for (final Example example : examples) {
tempResults.add(example.getResult());
}
if (tempResults.size() > 1) {
this.result = null;
} else {
this.result = tempResults.iterator().next();
}
this.results = Collections.unmodifiableSet(tempResults);
}
@Override
public Attribute getAttribute() {
return null;
}
@Override
public Set apply(final Item item) {
return results;
}
@Override
public T applySingle(final Item item) {
final T localResult = result;
if (localResult == null) {
throw new IllegalArgumentException("Ambigious result for the given item!");
}
return localResult;
}
}
private class SimpleDecisionNodeFactory implements DecisionNodeFactory {
private final Set usedAttributes;
public SimpleDecisionNodeFactory(final Set usedAttributes) {
this.usedAttributes = usedAttributes;
}
@Override
public DecisionNode createNode(final Attribute usedAttribute, final Set> examples) {
if (examples.size() < 1) {
return new LeafNode();
}
final Set localUsedAttributes = usedAttributes;
final Set usedAttributesNew;
if (usedAttribute != null) {
usedAttributesNew = new HashSet(localUsedAttributes.size() + 1);
usedAttributesNew.addAll(localUsedAttributes);
usedAttributesNew.add(usedAttribute);
} else {
usedAttributesNew = localUsedAttributes;
}
final Attribute selectedAttribute = attributeSelector.select(examples, attributes, usedAttributesNew);
if (selectedAttribute == null) {
return new LeafNode(examples);
}
return selectedAttribute.createNode(new SimpleDecisionNodeFactory(usedAttributesNew), examples);
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy