All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                CategorizationTreeLearner.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright October 22, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.function.categorization.Categorizer;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.ArgumentChecker;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * The {@code CategorizationTreeLearner} class implements a supervised learning 
 * algorithm for learning a categorization tree.
 *
 * @param   The input type for the tree.
 * @param   The output type for the tree.
 * @author Justin Basilico
 * @since  2.0
 */
public class CategorizationTreeLearner
    extends AbstractDecisionTreeLearner
    implements SupervisedBatchLearner>
{

    /** The default threshold for making a leaf node based on count. */
    public static final int DEFAULT_LEAF_COUNT_THRESHOLD = 1;

    /** The default maximum depth to grow the tree to. */
    public static final int DEFAULT_MAX_DEPTH = -1;

    /** The threshold for making a node a leaf, determined by how many
     *  instances fall in the threshold. */
    protected int leafCountThreshold;

    /** The maximum depth for the tree. Ignored if less than 1. */
    protected int maxDepth;

    /** Prior probabilities for the different categories.  If null,
     *  the priors default to the category frequencies of the training
     *  data.
     */
    protected Map priors;

    /**
     * How often each category appears in training data.
     */
    protected transient Map trainCounts;


    /**
     * Creates a new instance of CategorizationTreeLearner.
     */
    public CategorizationTreeLearner()
    {
        this(null);
    }

    /**
     * Creates a new instance of CategorizationTreeLearner.
     *
     * @param  deciderLearner The learner for the decision function
     */
    public CategorizationTreeLearner(
        final DeciderLearner deciderLearner)
    {
        this(deciderLearner, 
             DEFAULT_LEAF_COUNT_THRESHOLD, 
             DEFAULT_MAX_DEPTH,
             null);
    }

    /**
     * Creates a new instance of CategorizationTreeLearner.
     *
     * @param   deciderLearner
     *      The learner for the decision function.
     * @param   leafCountThreshold
     *      The leaf count threshold. Must be non-negative.
     * @param   maxDepth
     *      The maximum depth for the tree.
     */
    public CategorizationTreeLearner(
        final DeciderLearner deciderLearner,
        final int leafCountThreshold,
        final int maxDepth)
    {
        this(deciderLearner, leafCountThreshold, maxDepth, null);
    }


    /**
     * Creates a new instance of CategorizationTreeLearner.
     *
     * @param   deciderLearner
     *      The learner for the decision function.
     * @param   leafCountThreshold
     *      The leaf count threshold. Must be non-negative.
     * @param   maxDepth
     *      The maximum depth for the tree.
     * @param   priors
     *      Prior probabilities for categories.  (See setCategoryPriors().)
     */
    public CategorizationTreeLearner(
        final DeciderLearner deciderLearner,
        final int leafCountThreshold,
        final int maxDepth,
        final Map priors)
    {
        super(deciderLearner);

        this.setLeafCountThreshold(leafCountThreshold);
        this.setMaxDepth(maxDepth);
        this.setCategoryPriors(priors);
    }

    @Override
    public CategorizationTreeLearner clone()
    {
        final CategorizationTreeLearner result = (CategorizationTreeLearner) 
            super.clone();
        
        result.priors = this.priors == null ? null : new LinkedHashMap<>(this.priors);
        result.trainCounts = this.trainCounts == null ? null : new LinkedHashMap<>(this.trainCounts);
        
        return result;
    }
    
    @Override
    public CategorizationTree learn(
        Collection>
            data)
    {
        if ( data == null )
        {
            // Bad data.
            return null;
        }

        final DefaultDataDistribution rootCounts =
            getOutputCounts(data);

        trainCounts = new HashMap();
        for (OutputType category : rootCounts.getDomain()) 
        {
            int freq = (int)(rootCounts.get(category));
            trainCounts.put(category, new Integer(freq));
        }

        // Configure prior weights if supported by split criterion.
        if (deciderLearner instanceof PriorWeightedNodeLearner)
        {
            // The compiler is unable to figure out that this cast is
            // safe.  Suppress the incorrect warning.
            @SuppressWarnings("unchecked")
            PriorWeightedNodeLearner criterion =
                (PriorWeightedNodeLearner)(deciderLearner);
            criterion.configure(priors, trainCounts);
        }

        // Recursively learn the node.
	CategorizationTree tree = 
            new CategorizationTree(
                this.learnNode(data, null),
                new HashSet(rootCounts.getDomain()));

        trainCounts = null;
        return tree;
    }

    /**
     * Recursively learns the categorization tree using the given collection
     * of data, returning the created node.
     *
     * @param  data The set of data to learn a node from.
     * @param  parent The parent node.
     * @return The categorization tree node learned from the given data.
     */
    @Override
    protected CategorizationTreeNode learnNode(
        final Collection> data,
        final AbstractDecisionTreeNode parent)
    {
        if (data == null || data.size() <= 0)
        {
            // Invalid data, nothing to learn. This case should never happen.
            return null;
        }


// TODO: Revisit the tree data structures later to determine if
// really necessary to include category on internal nodes.  An
// implementation should be possible that does not spend the
// storage on internal nodes and only searches for the most
// common output when it is known that the node is a leaf.
//
// Also, would it be cleaner to push areAllOutputsEqual() into
// the search for a decision function?  Both of these likely
// need to scan through all the data.
// -- mamunso (2010/11/18)

        // We put the most probable output category on every node in
        // the tree, in case we get a bad decision function or leaf
        // node. This ensures That we can always make a
        // categorization.
        OutputType mostProbOutput = computeMaxProbPrediction(data);

        // Give the node we are creating the most probable output.
        final CategorizationTreeNode node =
            new CategorizationTreeNode(
                parent, mostProbOutput);

        // Check for termination conditions that produce a leaf node.
        boolean isLeaf = this.areAllOutputsEqual(data)
            || data.size() <= this.leafCountThreshold
            || (this.maxDepth > 0 && node.getDepth() >= this.maxDepth);

        if (!isLeaf)
        {
            // Search for a decision function to split the data.
            final Categorizer decider =
                this.getDeciderLearner().learn(data);

            if (decider != null)
            {
                // We learned a good decider.
                node.setDecider(decider);

                // Learn the child nodes.
                super.learnChildNodes(node, data, decider);
            }
            else
            {
                // Failed to find a decision function ==> there is no
                // attribute that separates the values of different
                // output categories.  This node necessarily becomes a
                // leaf.
                isLeaf = true;
            }
        }

        // Return the new node we've created.
        return node;
    }
    
    /**
     * Creates a histogram of values based on the output values in the given
     * collection of pairs.
     *
     * @param    The type of the outputs to count over.
     * @param  data The data to create the output count histogram for.
     * @return The output count histogram.
     */
    public static  DefaultDataDistribution getOutputCounts(
        final Collection> data)
    {
        // Create the histogram.
        DefaultDataDistribution counts =
            new DefaultDataDistribution();
        
        if ( data == null )
        {
            // Bad data.
            return counts;
        }
        
        for ( InputOutputPair example : data )
        {
            // Add the output.
            final OutputType output = example.getOutput();
            counts.increment(output);
        }
        
        // Return the histogram.
        return counts;
    }

    /**
     * Return the most probable output value, taking into
     * consideration both the frequency counts in the data sample (at
     * the current node) and the prior proabalities for possible
     * output values.
     *
     * @param  The type of outputs.
     * @param data 
     *    The data sample, with output labels for each data point.
     *    The sample must contain at least 1 data point.
     * @return The output value with highest conditional probability.
     */
    private OutputType computeMaxProbPrediction(
        final Collection> data)
    {
        DefaultDataDistribution nodeCounts = getOutputCounts(data);
        if (priors == null) {
            // With no explicit prior, the most probable prediction is
            // the most common category.
            return nodeCounts.getMaxValueKey();
        }

        // Iterate over possible predictions, and keep track of the
        // prediction with highest probability.  Note that these
        // probabilities are not normalized.  (It would be easy, just
        // divide by sum of the unnormalized probs . . . but since
        // that would be a constant scaling, the maximum probability
        // prediction would be the same.)
        double bestProb = -1.0;
        OutputType bestVal = null;
        for (OutputType category : nodeCounts.getDomain()) {
            double likelihood = 
                nodeCounts.get(category) / trainCounts.get(category);
            double prior = priors.get(category);
            double prob = prior * likelihood;
            if (prob > bestProb) {
                bestProb = prob;
                bestVal = category;
            }
        }

        return bestVal;
    }


    /**
     * Gets the leaf count threshold, which determines the number of elements
     * at which to make an element into a leaf.
     *
     * @return The leaf count threshold.
     */
    public int getLeafCountThreshold()
    {
        return this.leafCountThreshold;
    }

    /**
     * Sets the leaf count threshold, which determines the number of elements
     * at which to make an element into a leaf.
     *
     * @param   leafCountThreshold
     *          The leaf count threshold. Must be non-negative.
     */
    public void setLeafCountThreshold(
        final int leafCountThreshold)
    {
        ArgumentChecker.assertIsNonNegative("leafCountThreshold", leafCountThreshold);
        this.leafCountThreshold = leafCountThreshold;
    }

    /**
     * Gets the maximum depth to grow the tree.
     *
     * @return
     *      The maximum depth to grow the tree. Zero or less means no
     *      maximum depth.
     */
    public int getMaxDepth()
    {
        return this.maxDepth;
    }

    /**
     * Sets the maximum depth to grow the tree.
     *
     * @param   maxDepth
     *      The maximum depth to grow the tree. Zero or less means no
     *      maximum depth.
     */
    public void setMaxDepth(
        final int maxDepth)
    {
        this.maxDepth = maxDepth;
    }

    /**
     * 

Set prior category probabilities. A higher prior * probability for a category will cause the tree learner to * weight examples from that category more highly.

* *

If the priors are not manually specified (through this * method or passing priors into the constructor), prior * probabilities default to the frequencies of the different * categories in the training data.

* * @param priors * If null, use default prior probabilities. Otherwise, priors * becomes the new prior weights. In the latter case, * priors.keySet() contain the same values as the possible * categories in data passed to the learn() method. */ public void setCategoryPriors(Map priors) { if (priors == null) { this.priors = null; } else { this.priors = new LinkedHashMap(priors); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy