
aima.core.learning.learners.AdaBoostLearner Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of aima-core Show documentation
Show all versions of aima-core Show documentation
AIMA-Java Core Algorithms from the book Artificial Intelligence a Modern Approach 3rd Ed.
package aima.core.learning.learners;
import java.util.Hashtable;
import java.util.List;
import aima.core.learning.framework.DataSet;
import aima.core.learning.framework.Example;
import aima.core.learning.framework.Learner;
import aima.core.util.Util;
import aima.core.util.datastructure.Table;
/**
* @author Ravi Mohan
*
*/
public class AdaBoostLearner implements Learner {
private List learners;
private DataSet dataSet;
private double[] exampleWeights;
private Hashtable learnerWeights;
public AdaBoostLearner(List learners, DataSet ds) {
this.learners = learners;
this.dataSet = ds;
initializeExampleWeights(ds.examples.size());
initializeHypothesisWeights(learners.size());
}
public void train(DataSet ds) {
initializeExampleWeights(ds.examples.size());
for (Learner learner : learners) {
learner.train(ds);
double error = calculateError(ds, learner);
if (error < 0.0001) {
break;
}
adjustExampleWeights(ds, learner, error);
double newHypothesisWeight = learnerWeights.get(learner)
* Math.log((1.0 - error) / error);
learnerWeights.put(learner, newHypothesisWeight);
}
}
public String predict(Example e) {
return weightedMajority(e);
}
public int[] test(DataSet ds) {
int[] results = new int[] { 0, 0 };
for (Example e : ds.examples) {
if (e.targetValue().equals(predict(e))) {
results[0] = results[0] + 1;
} else {
results[1] = results[1] + 1;
}
}
return results;
}
//
// PRIVATE METHODS
//
private String weightedMajority(Example e) {
List targetValues = dataSet.getPossibleAttributeValues(dataSet
.getTargetAttributeName());
Table table = createTargetValueLearnerTable(
targetValues, e);
return getTargetValueWithTheMaximumVotes(targetValues, table);
}
private Table createTargetValueLearnerTable(
List targetValues, Example e) {
// create a table with target-attribute values as rows and learners as
// columns and cells containing the weighted votes of each Learner for a
// target value
// Learner1 Learner2 Laerner3 .......
// Yes 0.83 0.5 0
// No 0 0 0.6
Table table = new Table(
targetValues, learners);
// initialize table
for (Learner l : learners) {
for (String s : targetValues) {
table.set(s, l, 0.0);
}
}
for (Learner learner : learners) {
String predictedValue = learner.predict(e);
for (String v : targetValues) {
if (predictedValue.equals(v)) {
table.set(v, learner, table.get(v, learner)
+ learnerWeights.get(learner) * 1);
}
}
}
return table;
}
private String getTargetValueWithTheMaximumVotes(List targetValues,
Table table) {
String targetValueWithMaxScore = targetValues.get(0);
double score = scoreOfValue(targetValueWithMaxScore, table, learners);
for (String value : targetValues) {
double scoreOfValue = scoreOfValue(value, table, learners);
if (scoreOfValue > score) {
targetValueWithMaxScore = value;
score = scoreOfValue;
}
}
return targetValueWithMaxScore;
}
private void initializeExampleWeights(int size) {
if (size == 0) {
throw new RuntimeException(
"cannot initialize Ensemble learning with Empty Dataset");
}
double value = 1.0 / (1.0 * size);
exampleWeights = new double[size];
for (int i = 0; i < size; i++) {
exampleWeights[i] = value;
}
}
private void initializeHypothesisWeights(int size) {
if (size == 0) {
throw new RuntimeException(
"cannot initialize Ensemble learning with Zero Learners");
}
learnerWeights = new Hashtable();
for (Learner le : learners) {
learnerWeights.put(le, 1.0);
}
}
private double calculateError(DataSet ds, Learner l) {
double error = 0.0;
for (int i = 0; i < ds.examples.size(); i++) {
Example e = ds.getExample(i);
if (!(l.predict(e).equals(e.targetValue()))) {
error = error + exampleWeights[i];
}
}
return error;
}
private void adjustExampleWeights(DataSet ds, Learner l, double error) {
double epsilon = error / (1.0 - error);
for (int j = 0; j < ds.examples.size(); j++) {
Example e = ds.getExample(j);
if ((l.predict(e).equals(e.targetValue()))) {
exampleWeights[j] = exampleWeights[j] * epsilon;
}
}
exampleWeights = Util.normalize(exampleWeights);
}
private double scoreOfValue(String targetValue,
Table table, List learners) {
double score = 0.0;
for (Learner l : learners) {
score += table.get(targetValue, l);
}
return score;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy