All downloads are free. The search and download functionality uses the official Maven repository.

com.blazebit.ai.decisiontree.impl.SimpleDecisionTree Maven / Gradle / Ivy

There is a newer version: 0.1.21
Show newest version
package com.blazebit.ai.decisiontree.impl;

import com.blazebit.ai.decisiontree.*;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

/**
 * @author Christian Beikov
 */
public class SimpleDecisionTree implements DecisionTree {

    private final Set attributes;
    private final AttributeSelector attributeSelector;
    private final DecisionNode root;

    public SimpleDecisionTree(final Set attributes, final Set> examples, final AttributeSelector attributeSelector) {
        this.attributes = new HashSet(attributes);
        this.attributeSelector = attributeSelector;
        this.root = new SimpleDecisionNodeFactory(new HashSet(0)).createNode(null, examples);
    }

    @Override
    public Set apply(final Item test) {
        return root.apply(test);
    }

    @Override
    public T applySingle(final Item test) {
        return root.applySingle(test);
    }

    private static class LeafNode implements DecisionNode {

        private final T result;
        private final Set results;

        public LeafNode() {
            this.result = null;
            this.results = Collections.emptySet();
        }

        public LeafNode(final Set> examples) {
            final Set tempResults = new HashSet(examples.size());

            for (final Example example : examples) {
                tempResults.add(example.getResult());
            }

            if (tempResults.size() > 1) {
                this.result = null;
            } else {
                this.result = tempResults.iterator().next();
            }

            this.results = Collections.unmodifiableSet(tempResults);
        }

        @Override
        public Attribute getAttribute() {
            return null;
        }

        @Override
        public Set apply(final Item item) {
            return results;
        }

        @Override
        public T applySingle(final Item item) {
            final T localResult = result;

            if (localResult == null) {
                throw new IllegalArgumentException("Ambigious result for the given item!");
            }

            return localResult;
        }
    }

    private class SimpleDecisionNodeFactory implements DecisionNodeFactory {

        private final Set usedAttributes;

        public SimpleDecisionNodeFactory(final Set usedAttributes) {
            this.usedAttributes = usedAttributes;
        }

        @Override
        public  DecisionNode createNode(final Attribute usedAttribute, final Set> examples) {
            if (examples.size() < 1) {
                return new LeafNode();
            }

            final Set localUsedAttributes = usedAttributes;
            final Set usedAttributesNew;

            if (usedAttribute != null) {
                usedAttributesNew = new HashSet(localUsedAttributes.size() + 1);
                usedAttributesNew.addAll(localUsedAttributes);
                usedAttributesNew.add(usedAttribute);
            } else {
                usedAttributesNew = localUsedAttributes;
            }

            final Attribute selectedAttribute = attributeSelector.select(examples, attributes, usedAttributesNew);

            if (selectedAttribute == null) {
                return new LeafNode(examples);
            }

            return selectedAttribute.createNode(new SimpleDecisionNodeFactory(usedAttributesNew), examples);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy