All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.tree.RegressionTreeLearner Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                RegressionTreeLearner.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright November 29, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.Categorizer;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * The {@code RegressionTreeLearner} class implements a learning algorithm for
 * a regression tree that makes use of a decider learner and a regression 
 * learner. The tree grows as a decision tree until it gets to a leaf node
 * (determined by a minimum number of nodes), and then learns a regression
 * function at the leaf node.
 *
 * @param   The type of the input to the tree.
 * @author Justin Basilico
 * @since  2.0
 */
public class RegressionTreeLearner
    extends AbstractDecisionTreeLearner
    implements SupervisedBatchLearner>
{

    /** The default threshold for making a leaf node based on count. */
    public static final int DEFAULT_LEAF_COUNT_THRESHOLD = 4;

    /** The default maximum depth to grow the tree to. */
    public static final int DEFAULT_MAX_DEPTH = -1;

    /** The learning algorithm for the regression function. */
    protected BatchLearner
        >, 
         ? extends Evaluator> 
         regressionLearner;

    /** The threshold for making a node a leaf, determined by how many 
     *  instances fall in the threshold. */
    protected int leafCountThreshold;

    /** The maximum depth for the tree. Ignored if less than 1. */
    protected int maxDepth;

    /**
     * Creates a new instance of RegressionTreeLearner
     */
    public RegressionTreeLearner()
    {
        this(null);
    }
    
    /**
     * Creates a new instance of CategorizationTreeLearner with a mean node
     * learner
     *
     * @param  deciderLearner The learner for the decision function.
     */
    public RegressionTreeLearner(
        final DeciderLearner deciderLearner)
    {
        this(deciderLearner, null);
    }

    /**
     * Creates a new instance of CategorizationTreeLearner.
     *
     * @param  deciderLearner The learner for the decision function.
     * @param  regressionLearner The learner for the regression function.
     */
    public RegressionTreeLearner(
        final DeciderLearner deciderLearner,
        final BatchLearner
            >, 
             ? extends Evaluator>  
             regressionLearner)
    {
        this(deciderLearner, regressionLearner, 
            DEFAULT_LEAF_COUNT_THRESHOLD, DEFAULT_MAX_DEPTH);
    }

    /**
     * Creates a new instance of CategorizationTreeLearner.
     *
     * @param   deciderLearner The learner for the decision function.
     * @param   regressionLearner The learner for the regression function.
     * @param   leafCountThreshold 
     *          The leaf count threshold, which determines the number of 
     *          elements at which to learn a regression function.
     * @param   maxDepth
     *          The maximum depth to learn the tree. Must be positive.
     */
    public RegressionTreeLearner(
        final DeciderLearner deciderLearner,
        final BatchLearner
            >, 
             ? extends Evaluator>  
             regressionLearner,
        final int leafCountThreshold,
        final int maxDepth)
    {
        super(deciderLearner);

        this.setRegressionLearner(regressionLearner);
        this.setLeafCountThreshold(leafCountThreshold);
        this.setMaxDepth(maxDepth);
    }

    @Override
    public RegressionTreeLearner clone()
    {
        final RegressionTreeLearner result = (RegressionTreeLearner) super.clone();
        result.regressionLearner = ObjectUtil.cloneSafe(this.regressionLearner);
        return result;
    }
    
    @Override
    public RegressionTree learn(
        Collection> data)
    {
        if (data == null)
        {
            // Bad data.
            return null;
        }
        else
        {
            // Recursively learn the node.
            return new RegressionTree(
                this.learnNode(data, null));
        }
    }

    /**
     * Recursively learns the regression tree using the given collection
     * of data, returning the created node.
     *
     * @param  data The set of data to learn a node from.
     * @param  parent The parent node.
     * @return The regression tree node learned from the given data.
     */
    @Override
    protected RegressionTreeNode learnNode(
        final Collection> data,
        final AbstractDecisionTreeNode parent)
    {

        if (data == null || data.size() <= 0)
        {
            // Invalid data, nothing to learn.
            return null;
        }
        
        // Figure out the depth of the node.
        int depth = parent == null ? 1 : 1 + parent.getDepth();

        // Determine if this is a leaf node by checking the cound threshold and
        // determining if all the outputs are equal.
        final boolean isLeaf =
               data.size() <= this.leafCountThreshold
            || (this.maxDepth > 0 && depth >= maxDepth)
            || this.areAllOutputsEqual(data);

        // We use the mean value as part of the node.
        final double mean = DatasetUtil.computeOutputMean(data);

        // Learn the decision function for this node.
        Categorizer decider = null;
        if (!isLeaf)
        {
            // Only learn for a leaf node.
            decider = this.getDeciderLearner().learn(data);
        }

        // If we couldn't learn a decider, then this is also aleaf node.
        if (isLeaf || decider == null)
        {
            // This is a leaf node.
            // Build a regression function for the node.
            Evaluator scalarFunction = null;

            if (this.regressionLearner != null)
            {
                scalarFunction = this.regressionLearner.learn(data);
            }
            // else - Without a regression learner the output value for the
            //        tree will be the mean.

            // Create the leaf node.
            return new RegressionTreeNode(
                parent, scalarFunction, mean);
        }

        // We give the node we are creating the most common output value.
        final RegressionTreeNode node =
            new RegressionTreeNode(
                parent, decider, mean);

        // Learn the child nodes.
        this.learnChildNodes(node, data, decider);

        // Return the new node we've created.
        return node;
    }

    /**
     * Gets the regression learner that is to be used to fit a function
     * approximator to the values in the tree.
     *
     * @return  The regression learner.
     */
    public BatchLearner
        >, 
         ? extends Evaluator>
        getRegressionLearner()
    {
        return this.regressionLearner;
    }

    /**
     * Sets the regression learner that is to be used to fit a function
     * approximator to the values in the tree.
     *
     * @param   regressionLearner The regression learner.
     */
    public void setRegressionLearner(
        final BatchLearner
            >, 
             ? extends Evaluator>  
             regressionLearner)
    {
        this.regressionLearner = regressionLearner;
    }

    /**
     * Gets the leaf count threshold, which determines the number of elements
     * at which to learn a regression function.
     *
     * @return The leaf count threshold.
     */
    public int getLeafCountThreshold()
    {
        return this.leafCountThreshold;
    }

    /**
     * Sets the leaf count threshold, which determines the number of elements
     * at which to learn a regression function.
     *
     * @param   leafCountThreshold 
     *          The leaf count threshold. Must be non-negative.
     */
    public void setLeafCountThreshold(
        final int leafCountThreshold)
    {
        ArgumentChecker.assertIsNonNegative("leafCountThreshold", leafCountThreshold);
        this.leafCountThreshold = leafCountThreshold;
    }

    /**
     * Gets the maximum depth to grow the tree.
     *
     * @return
     *      The maximum depth to grow the tree. Zero or less means no
     *      maximum depth.
     */
    public int getMaxDepth()
    {
        return this.maxDepth;
    }

    /**
     * Sets the maximum depth to grow the tree.
     *
     * @param   maxDepth
     *      The maximum depth to grow the tree. Zero or less means no
     *      maximum depth.
     */
    public void setMaxDepth(
        final int maxDepth)
    {
        this.maxDepth = maxDepth;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy