aima.core.learning.learners.DecisionTreeLearner Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of aima-core Show documentation
Show all versions of aima-core Show documentation
AIMA-Java Core: algorithms from the book "Artificial Intelligence: A Modern Approach", 3rd Ed.
package aima.core.learning.learners;
import java.util.Iterator;
import java.util.List;
import aima.core.learning.framework.DataSet;
import aima.core.learning.framework.Example;
import aima.core.learning.framework.Learner;
import aima.core.learning.inductive.ConstantDecisonTree;
import aima.core.learning.inductive.DecisionTree;
import aima.core.util.Util;
/**
 * Learner that induces a decision tree from a data set and uses that tree to
 * classify examples.
 *
 * <p>
 * The tree is either built by {@link #train(DataSet)} or handed in through
 * {@link #DecisionTreeLearner(DecisionTree, String)}. While no tree is
 * available, {@link #predict(Example)} throws {@link IllegalStateException},
 * and so does {@link #test(DataSet)} as soon as it has an example to classify.
 *
 * @author Ravi Mohan
 * @author Mike Stampone
 */
public class DecisionTreeLearner implements Learner {
    private DecisionTree tree;

    private final String defaultValue;

    /**
     * Creates an untrained learner whose default classification is
     * {@code "Unable To Classify"}.
     */
    public DecisionTreeLearner() {
        this.defaultValue = "Unable To Classify";
    }

    /**
     * Creates a learner around an existing tree; used when a tree that was not
     * induced by this learner has to be exercised (e.g. in tests).
     *
     * @param tree
     *            the decision tree to use for predictions
     * @param defaultValue
     *            the value wrapped in the default tree handed to the induction
     *            when {@link #train(DataSet)} is called later
     */
    public DecisionTreeLearner(DecisionTree tree, String defaultValue) {
        this.tree = tree;
        this.defaultValue = defaultValue;
    }

    //
    // START-Learner

    /**
     * Induces the decision tree from the specified set of examples, replacing
     * any tree this learner held before.
     *
     * @param ds
     *            a set of examples for constructing the decision tree
     */
    @Override
    public void train(DataSet ds) {
        List<String> attributes = ds.getNonTargetAttributes();
        this.tree = decisionTreeLearning(ds, attributes,
                new ConstantDecisonTree(defaultValue));
    }

    /**
     * Classifies a single example with the current decision tree.
     *
     * @param e
     *            the example to classify
     * @return the tree's prediction for {@code e}
     * @throws IllegalStateException
     *             if no tree has been trained or supplied yet
     */
    @Override
    public String predict(Example e) {
        return (String) requireTree().predict(e);
    }

    /**
     * Runs the current decision tree over every example of a data set and
     * counts how often the prediction equals the example's target value.
     *
     * @param ds
     *            the examples to evaluate
     * @return a two element array: index 0 holds the number of examples whose
     *         target value equals the prediction, index 1 the number of those
     *         where it does not
     * @throws IllegalStateException
     *             if there is at least one example and no tree has been
     *             trained or supplied yet
     */
    @Override
    public int[] test(DataSet ds) {
        int[] results = new int[] { 0, 0 };
        for (Example e : ds.examples) {
            // The tree is checked lazily so that an empty data set still
            // yields {0, 0} even when no tree is available.
            if (e.targetValue().equals(requireTree().predict(e))) {
                results[0] = results[0] + 1;
            } else {
                results[1] = results[1] + 1;
            }
        }
        return results;
    }

    // END-Learner
    //

    /**
     * Returns the decision tree of this decision tree learner
     *
     * @return the decision tree of this decision tree learner, or
     *         {@code null} if none has been trained or supplied
     */
    public DecisionTree getDecisionTree() {
        return tree;
    }

    //
    // PRIVATE METHODS
    //

    /**
     * Returns the current tree, failing with a descriptive exception instead
     * of a bare NullPointerException when there is none.
     */
    private DecisionTree requireTree() {
        if (tree == null) {
            throw new IllegalStateException(
                    "No decision tree available: call train(DataSet) or supply a tree to the constructor first");
        }
        return tree;
    }

    /**
     * Recursively builds a decision tree for the given examples.
     *
     * @param ds
     *            the examples still to be separated at this point
     * @param attributeNames
     *            the attributes that may still be used for splitting
     * @param defaultTree
     *            the tree returned when {@code ds} contains no examples
     * @return {@code defaultTree} for an empty data set; a constant tree when
     *         all examples share one target value or no attributes are left;
     *         otherwise a tree that splits on the chosen attribute with one
     *         sub tree per possible value of it
     */
    private DecisionTree decisionTreeLearning(DataSet ds,
            List<String> attributeNames, ConstantDecisonTree defaultTree) {
        if (ds.size() == 0) {
            return defaultTree;
        }
        if (allExamplesHaveSameClassification(ds)) {
            return new ConstantDecisonTree(ds.getExample(0).targetValue());
        }
        if (attributeNames.size() == 0) {
            return majorityValue(ds);
        }

        String chosenAttribute = chooseAttribute(ds, attributeNames);
        // Named "node" so that it does not shadow the field "tree".
        DecisionTree node = new DecisionTree(chosenAttribute);
        // Becomes the default for every branch that ends up without examples.
        ConstantDecisonTree majority = majorityValue(ds);
        // Does not depend on the value being branched on, so it is computed
        // once rather than once per value.
        // NOTE(review): assumes Util.removeFrom yields the reduced list
        // without relying on per-call state - confirm against Util.
        List<String> remainingAttributes = Util.removeFrom(attributeNames,
                chosenAttribute);

        List<String> values = ds.getPossibleAttributeValues(chosenAttribute);
        for (String value : values) {
            DataSet filtered = ds.matchingDataSet(chosenAttribute, value);
            DecisionTree subTree = decisionTreeLearning(filtered,
                    remainingAttributes, majority);
            node.addNode(value, subTree);
        }
        return node;
    }

    /**
     * Trains a {@link MajorityLearner} on the data set and wraps what it
     * predicts for the first example in a constant tree.
     *
     * @param ds
     *            the examples to summarise; callers in this class only pass
     *            non-empty data sets
     */
    private ConstantDecisonTree majorityValue(DataSet ds) {
        Learner learner = new MajorityLearner();
        learner.train(ds);
        return new ConstantDecisonTree(learner.predict(ds.getExample(0)));
    }

    /**
     * Picks the attribute for which {@code ds.calculateGainFor} reports the
     * greatest value.
     *
     * @param ds
     *            the examples the gain is calculated on
     * @param attributeNames
     *            the candidate attributes; must not be empty
     * @return the candidate with the greatest gain; the first candidate when
     *         no gain is greater than zero; on ties the earlier candidate
     */
    private String chooseAttribute(DataSet ds, List<String> attributeNames) {
        double greatestGain = 0.0;
        String attributeWithGreatestGain = attributeNames.get(0);
        for (String attr : attributeNames) {
            double gain = ds.calculateGainFor(attr);
            // Strict comparison: an equal gain never replaces an earlier pick.
            if (gain > greatestGain) {
                greatestGain = gain;
                attributeWithGreatestGain = attr;
            }
        }
        return attributeWithGreatestGain;
    }

    /**
     * Tells whether every example has the same target value as the first one.
     *
     * @param ds
     *            the examples to inspect; must contain at least one example
     *            because the first one is used as the reference
     */
    private boolean allExamplesHaveSameClassification(DataSet ds) {
        String classification = ds.getExample(0).targetValue();
        Iterator<Example> iter = ds.iterator();
        while (iter.hasNext()) {
            Example element = iter.next();
            if (!(element.targetValue().equals(classification))) {
                return false;
            }
        }
        return true;
    }
}