All downloads are free. Search and download functionality uses the official Maven repository.

ai.libs.reduction.single.confusion.ConfusionBasedAlgorithm Maven / Gradle / Ivy

There is a newer version: 0.2.7
Show newest version
package ai.libs.reduction.single.confusion;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.MCTreeNodeReD;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;

public class ConfusionBasedAlgorithm extends AConfusionBasedAlgorithm {

	private Logger logger = LoggerFactory.getLogger(ConfusionBasedAlgorithm.class);

	public MCTreeNodeReD buildClassifier(final Instances data, final Collection pClassifierNames) throws Exception {

		if (this.logger.isInfoEnabled()) {
			this.logger.info("START: {}", data.relationName());
		}
		int seed = 0;

		Map confusionMatrices = new HashMap<>();
		int numClasses = data.numClasses();
		this.logger.info("Computing confusion matrices ...");
		for (int i = 0; i < 10; i++) {
			List split = WekaUtil.getStratifiedSplit(data, seed, .7f);

			/* compute confusion matrices for each classifier */
			for (String classifier : pClassifierNames) {
				try {
					Classifier c = AbstractClassifier.forName(classifier, null);
					c.buildClassifier(split.get(0));
					Evaluation eval = new Evaluation(split.get(0));
					eval.evaluateModel(c, split.get(1));
					if (!confusionMatrices.containsKey(classifier)) {
						confusionMatrices.put(classifier, new double[numClasses][numClasses]);
					}

					double[][] currentCM = confusionMatrices.get(classifier);
					double[][] addedCM = eval.confusionMatrix();

					for (int j = 0; j < numClasses; j++) {
						for (int k = 0; k < numClasses; k++) {
							currentCM[j][k] += addedCM[j][k];
						}
					}
				} catch (Exception e) {
					this.logger.error("Unexpected exception has been thrown", e);
				}
			}
		}
		this.logger.info("done");

		/* compute zero-conflict sets for each classifier */
		Map>> zeroConflictSets = new HashMap<>();
		for (Entry entry : confusionMatrices.entrySet()) {
			zeroConflictSets.put(entry.getKey(), this.getZeroConflictSets(entry.getValue()));
		}

		/* greedily identify the best left and right pair (that make least mistakes) */
		Collection> classifierPairs = SetUtil.cartesianProduct(confusionMatrices.keySet(), 2);
		String bestLeft = null;
		String bestRight = null;
		String bestInner = null;
		Collection bestLeftClasses = null;
		Collection bestRightClasses = null;
		for (List classifierPair : classifierPairs) {
			String c1 = classifierPair.get(0);
			String c2 = classifierPair.get(1);

			Collection> z1 = zeroConflictSets.get(c1);
			Collection> z2 = zeroConflictSets.get(c2);

			/* create candidate split */
			int sizeOfBestCombo = 0;
			for (Collection zeroSet1 : z1) {
				for (Collection zeroSet2 : z2) {
					Collection coveredClassesOfThisPair = SetUtil.union(zeroSet1, zeroSet2);
					if (coveredClassesOfThisPair.size() > sizeOfBestCombo) {
						bestLeft = c1;
						bestRight = c2;
						sizeOfBestCombo = coveredClassesOfThisPair.size();
						bestLeftClasses = zeroSet1;
						bestRightClasses = zeroSet2;
					}
				}
			}
		}

		/* greedily complete the best candidates */
		double[][] cm1 = confusionMatrices.get(bestLeft);
		double[][] cm2 = confusionMatrices.get(bestRight);
		for (int cId = 0; cId < numClasses; cId++) {
			if (!bestLeftClasses.contains(cId) && !bestRightClasses.contains(cId)) {

				/* compute effect of adding this class to the respective clusters */
				Collection newBestZ1 = new ArrayList<>(bestLeftClasses);
				newBestZ1.add(cId);
				int p1 = this.getPenaltyOfCluster(newBestZ1, cm1);
				Collection newBestZ2 = new ArrayList<>(bestRightClasses);
				newBestZ2.add(cId);
				int p2 = this.getPenaltyOfCluster(newBestZ2, cm2);

				if (p1 < p2) {
					bestLeftClasses = newBestZ1;
				} else {
					bestRightClasses = newBestZ2;
				}
			}
		}
		int p1 = this.getPenaltyOfCluster(bestLeftClasses, cm1);
		int p2 = this.getPenaltyOfCluster(bestRightClasses, cm2);

		/* create the split problem */
		Map classMap = new HashMap<>();
		for (int i1 : bestLeftClasses) {
			classMap.put(data.classAttribute().value(i1), "l");
		}
		for (int i2 : bestRightClasses) {
			classMap.put(data.classAttribute().value(i2), "r");
		}
		Instances newData = WekaUtil.getRefactoredInstances(data, classMap);
		List binaryInnerSplit = WekaUtil.getStratifiedSplit(newData, seed, .7f);

		/* now identify the classifier that can best separate these two clusters */
		int leastSeenMistakes = Integer.MAX_VALUE;
		for (String classifier : pClassifierNames) {
			try {
				Classifier c = AbstractClassifier.forName(classifier, null);

				c.buildClassifier(binaryInnerSplit.get(0));
				Evaluation eval = new Evaluation(newData);
				eval.evaluateModel(c, binaryInnerSplit.get(1));
				int mistakes = (int) eval.incorrect();
				int overallMistakes = p1 + p2 + mistakes;
				if (overallMistakes < leastSeenMistakes) {
					leastSeenMistakes = overallMistakes;
					this.logger.info("New best system: {}/{}/{} with {}", bestLeft, bestRight, classifier, leastSeenMistakes);
					bestInner = classifier;
				}
			} catch (Exception e) {
				this.logger.error("Exception has been thrown unexpectedly.", e);
			}
		}
		if (bestInner == null) {
			throw new IllegalStateException("No best inner has been chosen!");
		}

		/* now create MCTreeNode with choices */
		MCTreeNodeReD tree = new MCTreeNodeReD(bestInner, bestLeftClasses.stream().map(i -> data.classAttribute().value(i)).collect(Collectors.toList()), bestLeft,
				bestRightClasses.stream().map(i -> data.classAttribute().value(i)).collect(Collectors.toList()), bestRight);
		tree.buildClassifier(data);
		return tree;
	}

	private Collection> getZeroConflictSets(final double[][] confusionMatrix) {
		Collection blackList = new ArrayList<>();
		Collection> partitions = new ArrayList<>();

		int leastConflictingClass = -1;
		do {
			leastConflictingClass = this.getLeastConflictingClass(confusionMatrix, blackList);
			if (leastConflictingClass >= 0) {
				Collection cluster = new ArrayList<>();
				cluster.add(leastConflictingClass);
				do {
					Collection newCluster = this.incrementCluster(cluster, confusionMatrix, blackList);
					if (newCluster.size() == cluster.size()) {
						break;
					}
					cluster = newCluster;
					if (cluster.contains(-1)) {
						throw new IllegalStateException("Computed illegal cluster: " + cluster);
					}
				} while (this.getPenaltyOfCluster(cluster, confusionMatrix) == 0 && cluster.size() < confusionMatrix.length);
				blackList.addAll(cluster);
				partitions.add(cluster);
			}
		} while (leastConflictingClass >= 0 && blackList.size() < confusionMatrix.length);

		return partitions;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy