com.blazebit.ai.decisiontree.impl.ID3AttributeSelector Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of blaze-ai-utils Show documentation
Show all versions of blaze-ai-utils Show documentation
Artificial Intelligence Utilities.
package com.blazebit.ai.decisiontree.impl;
import com.blazebit.ai.decisiontree.*;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
/**
* @author Christian Beikov
*/
public class ID3AttributeSelector implements AttributeSelector {
@Override
public Attribute select(final Set> examples, final Set availableAttributes, final Set usedAttributes) {
Attribute attribute = null;
float attributeRem = Float.MAX_VALUE;
int attributeValueCount = Integer.MAX_VALUE;
float positives = 0;
float negatives = 0;
final Map> attributeUsage = new HashMap>();
/* Make array for performance */
final Example[] exampleArray = examples.toArray(new Example[0]);
final int examplesSize = exampleArray.length;
for (final Attribute attr : availableAttributes) {
if (usedAttributes.contains(attr)) {
continue;
}
final Map valueUsage = new HashMap();
attributeUsage.put(attr, valueUsage);
for (int i = 0; i < examplesSize; i++) {
final AttributeValue value = exampleArray[i].getValues().get(attr);
Pair valueUsageExamples = valueUsage.get(value);
if (valueUsageExamples == null) {
valueUsageExamples = new Pair();
valueUsage.put(value, valueUsageExamples);
}
if (exampleArray[i].getResult()) {
++valueUsageExamples.positive;
++positives;
} else {
++valueUsageExamples.negative;
++negatives;
}
}
}
if (positives > 0 && negatives > 0) {
for (final Map.Entry> entry : attributeUsage.entrySet()) {
final Attribute attr = entry.getKey();
final float rem = Pair.rem(entry.getValue().values(), positives, negatives);
if (rem < attributeRem) {
attribute = attr;
attributeRem = rem;
if (attr instanceof DiscreteAttribute) {
attributeValueCount = ((DiscreteAttribute) attr).getValues().size();
}
} else if (attr instanceof DiscreteAttribute && (rem == attributeRem) && ((DiscreteAttribute) attr).getValues().size() < attributeValueCount) {
attribute = attr;
attributeRem = rem;
attributeValueCount = ((DiscreteAttribute) attr).getValues().size();
}
}
}
return attribute;
}
private static class Pair {
static final float log2 = (float) Math.log(2);
float positive = 0;
float negative = 0;
double entropy() {
final float localPositive = positive;
final float localNegative = negative;
final float localLog2 = log2;
if (localPositive == 0 || localNegative == 0) {
return 0;
}
final float p = localPositive / (localPositive + localNegative);
return -p * (Math.log(p) / localLog2) - (1 - p) * (Math.log(1 - p) / localLog2);
}
static float rem(final Collection pairs, final float positives, final float negatives) {
float rem = 0;
for (final Pair p : pairs) {
rem += ((p.positive + p.negative) / (positives + negatives)) * p.entropy();
}
return rem;
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy