ai.libs.reduction.single.confusion.ConfusionBasedAlgorithm Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mlplan-ext-reduction Show documentation
This project provides an implementation of the AutoML tool ML-Plan.
package ai.libs.reduction.single.confusion;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.MCTreeNodeReD;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
public class ConfusionBasedAlgorithm extends AConfusionBasedAlgorithm {
private Logger logger = LoggerFactory.getLogger(ConfusionBasedAlgorithm.class);
public MCTreeNodeReD buildClassifier(final Instances data, final Collection pClassifierNames) throws Exception {
if (this.logger.isInfoEnabled()) {
this.logger.info("START: {}", data.relationName());
}
int seed = 0;
Map confusionMatrices = new HashMap<>();
int numClasses = data.numClasses();
this.logger.info("Computing confusion matrices ...");
for (int i = 0; i < 10; i++) {
List split = WekaUtil.getStratifiedSplit(data, seed, .7f);
/* compute confusion matrices for each classifier */
for (String classifier : pClassifierNames) {
try {
Classifier c = AbstractClassifier.forName(classifier, null);
c.buildClassifier(split.get(0));
Evaluation eval = new Evaluation(split.get(0));
eval.evaluateModel(c, split.get(1));
if (!confusionMatrices.containsKey(classifier)) {
confusionMatrices.put(classifier, new double[numClasses][numClasses]);
}
double[][] currentCM = confusionMatrices.get(classifier);
double[][] addedCM = eval.confusionMatrix();
for (int j = 0; j < numClasses; j++) {
for (int k = 0; k < numClasses; k++) {
currentCM[j][k] += addedCM[j][k];
}
}
} catch (Exception e) {
this.logger.error("Unexpected exception has been thrown", e);
}
}
}
this.logger.info("done");
/* compute zero-conflict sets for each classifier */
Map>> zeroConflictSets = new HashMap<>();
for (Entry entry : confusionMatrices.entrySet()) {
zeroConflictSets.put(entry.getKey(), this.getZeroConflictSets(entry.getValue()));
}
/* greedily identify the best left and right pair (that make least mistakes) */
Collection> classifierPairs = SetUtil.cartesianProduct(confusionMatrices.keySet(), 2);
String bestLeft = null;
String bestRight = null;
String bestInner = null;
Collection bestLeftClasses = null;
Collection bestRightClasses = null;
for (List classifierPair : classifierPairs) {
String c1 = classifierPair.get(0);
String c2 = classifierPair.get(1);
Collection> z1 = zeroConflictSets.get(c1);
Collection> z2 = zeroConflictSets.get(c2);
/* create candidate split */
int sizeOfBestCombo = 0;
for (Collection zeroSet1 : z1) {
for (Collection zeroSet2 : z2) {
Collection coveredClassesOfThisPair = SetUtil.union(zeroSet1, zeroSet2);
if (coveredClassesOfThisPair.size() > sizeOfBestCombo) {
bestLeft = c1;
bestRight = c2;
sizeOfBestCombo = coveredClassesOfThisPair.size();
bestLeftClasses = zeroSet1;
bestRightClasses = zeroSet2;
}
}
}
}
/* greedily complete the best candidates */
double[][] cm1 = confusionMatrices.get(bestLeft);
double[][] cm2 = confusionMatrices.get(bestRight);
for (int cId = 0; cId < numClasses; cId++) {
if (!bestLeftClasses.contains(cId) && !bestRightClasses.contains(cId)) {
/* compute effect of adding this class to the respective clusters */
Collection newBestZ1 = new ArrayList<>(bestLeftClasses);
newBestZ1.add(cId);
int p1 = this.getPenaltyOfCluster(newBestZ1, cm1);
Collection newBestZ2 = new ArrayList<>(bestRightClasses);
newBestZ2.add(cId);
int p2 = this.getPenaltyOfCluster(newBestZ2, cm2);
if (p1 < p2) {
bestLeftClasses = newBestZ1;
} else {
bestRightClasses = newBestZ2;
}
}
}
int p1 = this.getPenaltyOfCluster(bestLeftClasses, cm1);
int p2 = this.getPenaltyOfCluster(bestRightClasses, cm2);
/* create the split problem */
Map classMap = new HashMap<>();
for (int i1 : bestLeftClasses) {
classMap.put(data.classAttribute().value(i1), "l");
}
for (int i2 : bestRightClasses) {
classMap.put(data.classAttribute().value(i2), "r");
}
Instances newData = WekaUtil.getRefactoredInstances(data, classMap);
List binaryInnerSplit = WekaUtil.getStratifiedSplit(newData, seed, .7f);
/* now identify the classifier that can best separate these two clusters */
int leastSeenMistakes = Integer.MAX_VALUE;
for (String classifier : pClassifierNames) {
try {
Classifier c = AbstractClassifier.forName(classifier, null);
c.buildClassifier(binaryInnerSplit.get(0));
Evaluation eval = new Evaluation(newData);
eval.evaluateModel(c, binaryInnerSplit.get(1));
int mistakes = (int) eval.incorrect();
int overallMistakes = p1 + p2 + mistakes;
if (overallMistakes < leastSeenMistakes) {
leastSeenMistakes = overallMistakes;
this.logger.info("New best system: {}/{}/{} with {}", bestLeft, bestRight, classifier, leastSeenMistakes);
bestInner = classifier;
}
} catch (Exception e) {
this.logger.error("Exception has been thrown unexpectedly.", e);
}
}
if (bestInner == null) {
throw new IllegalStateException("No best inner has been chosen!");
}
/* now create MCTreeNode with choices */
MCTreeNodeReD tree = new MCTreeNodeReD(bestInner, bestLeftClasses.stream().map(i -> data.classAttribute().value(i)).collect(Collectors.toList()), bestLeft,
bestRightClasses.stream().map(i -> data.classAttribute().value(i)).collect(Collectors.toList()), bestRight);
tree.buildClassifier(data);
return tree;
}
private Collection> getZeroConflictSets(final double[][] confusionMatrix) {
Collection blackList = new ArrayList<>();
Collection> partitions = new ArrayList<>();
int leastConflictingClass = -1;
do {
leastConflictingClass = this.getLeastConflictingClass(confusionMatrix, blackList);
if (leastConflictingClass >= 0) {
Collection cluster = new ArrayList<>();
cluster.add(leastConflictingClass);
do {
Collection newCluster = this.incrementCluster(cluster, confusionMatrix, blackList);
if (newCluster.size() == cluster.size()) {
break;
}
cluster = newCluster;
if (cluster.contains(-1)) {
throw new IllegalStateException("Computed illegal cluster: " + cluster);
}
} while (this.getPenaltyOfCluster(cluster, confusionMatrix) == 0 && cluster.size() < confusionMatrix.length);
blackList.addAll(cluster);
partitions.add(cluster);
}
} while (leastConflictingClass >= 0 && blackList.size() < confusionMatrix.length);
return partitions;
}
}