// com.github.chen0040.trees.ensembles.MultiClassAdaBoost (Maven / Gradle / Ivy)
package com.github.chen0040.trees.ensembles;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.TupleTwo;
import com.github.chen0040.data.utils.discretizers.KMeansDiscretizer;
import com.github.chen0040.trees.id3.ID3;
import lombok.Setter;
import lombok.Getter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
* Created by xschen on 9/6/2017.
*/
public class MultiClassAdaBoost {
private final List classifiers = new ArrayList<>();
private final List> model = new ArrayList<>();
@Getter
@Setter
private int treeCount = 100;
private KMeansDiscretizer discretizer=new KMeansDiscretizer();
@Getter
private final List classLabels = new ArrayList<>();
@Setter
@Getter
public double dataSampleRate = 0.2; // value between 0 and 1
public MultiClassAdaBoost(){
}
public void fit(DataFrame frame){
frame = discretizer.fitAndTransform(frame);
classifiers.clear();
classLabels.clear();
for(int m = 0; m < treeCount; ++m) {
ID3 classifier = new ID3(false);
classifier.fit(frame.shuffle().split(0.2)._1());
classifiers.add(classifier);
}
final int N = frame.rowCount();
double[] weights = new double[N];
Set labels = new HashSet<>();
for(int i=0; i < N; ++i){
weights[i] = 1.0 / N;
labels.add(frame.row(i).categoricalTarget());
}
classLabels.addAll(labels);
for(int t = 0; t < treeCount; ++t) {
double min_err = Double.MAX_VALUE;
int M = -1;
for (int m = 0; m < treeCount; ++m) {
ID3 classifier_m = classifiers.get(m);
double err_m = 0;
for (int i = 0; i < N; ++i) {
DataRow row = frame.row(i);
String predicted = classifier_m.classify(row);
if (!predicted.equals(row.categoricalTarget())) {
err_m += weights[i];
}
}
if (min_err > err_m) {
min_err = err_m;
M = m;
}
}
// Add next classifier
ID3 classifier_t = classifiers.get(M);
double alpha_t = 0.5 * Math.log((1-min_err) / min_err);
model.add(new TupleTwo<>(M, alpha_t));
// Update weight
double sum = 0;
for(int i=0; i < N; ++i){
DataRow row_i = frame.row(i);
String predicted = classifier_t.classify(row_i);
double II = predicted.equals(row_i.categoricalTarget()) ? 0 : 1;
weights[i] = weights[i] * Math.exp(alpha_t * II);
sum += weights[i];
}
// Normalize weight
for(int i=0; i < N; ++i) {
weights[i] /= sum;
}
}
}
public String classify(DataRow row) {
row = discretizer.transform(row);
double max_sum_k = Double.NEGATIVE_INFINITY;
int K = -1;
for(int k =0; k < classLabels.size(); ++k){
String candidate = classLabels.get(k);
double sum_k = 0;
for(int m = 0; m < treeCount; ++m) {
TupleTwo t = model.get(m);
ID3 classifier_t = classifiers.get(t._1());
double alpha_t = t._2();
String predicted = classifier_t.classify(row);
double II = predicted.equals(candidate) ? 1 : 0;
sum_k += alpha_t * II;
}
if(sum_k > max_sum_k) {
max_sum_k =sum_k;
K = k;
}
}
return classLabels.get(K);
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy