gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: CategorizationTreeLearner.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright October 22, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.function.categorization.Categorizer;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.ArgumentChecker;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* The {@code CategorizationTreeLearner} class implements a supervised learning
* algorithm for learning a categorization tree.
*
* @param The input type for the tree.
* @param The output type for the tree.
* @author Justin Basilico
* @since 2.0
*/
public class CategorizationTreeLearner
extends AbstractDecisionTreeLearner
implements SupervisedBatchLearner>
{
/** The default threshold for making a leaf node based on count. */
public static final int DEFAULT_LEAF_COUNT_THRESHOLD = 1;
/** The default maximum depth to grow the tree to. */
public static final int DEFAULT_MAX_DEPTH = -1;
/** The threshold for making a node a leaf, determined by how many
* instances fall in the threshold. */
protected int leafCountThreshold;
/** The maximum depth for the tree. Ignored if less than 1. */
protected int maxDepth;
/** Prior probabilities for the different categories. If null,
* the priors default to the category frequencies of the training
* data.
*/
protected Map priors;
/**
* How often each category appears in training data.
*/
protected transient Map trainCounts;
/**
* Creates a new instance of CategorizationTreeLearner.
*/
public CategorizationTreeLearner()
{
this(null);
}
/**
* Creates a new instance of CategorizationTreeLearner.
*
* @param deciderLearner The learner for the decision function
*/
public CategorizationTreeLearner(
final DeciderLearner super InputType, OutputType, ?, ?> deciderLearner)
{
this(deciderLearner,
DEFAULT_LEAF_COUNT_THRESHOLD,
DEFAULT_MAX_DEPTH,
null);
}
/**
* Creates a new instance of CategorizationTreeLearner.
*
* @param deciderLearner
* The learner for the decision function.
* @param leafCountThreshold
* The leaf count threshold. Must be non-negative.
* @param maxDepth
* The maximum depth for the tree.
*/
public CategorizationTreeLearner(
final DeciderLearner super InputType, OutputType, ?, ?> deciderLearner,
final int leafCountThreshold,
final int maxDepth)
{
this(deciderLearner, leafCountThreshold, maxDepth, null);
}
/**
* Creates a new instance of CategorizationTreeLearner.
*
* @param deciderLearner
* The learner for the decision function.
* @param leafCountThreshold
* The leaf count threshold. Must be non-negative.
* @param maxDepth
* The maximum depth for the tree.
* @param priors
* Prior probabilities for categories. (See setCategoryPriors().)
*/
public CategorizationTreeLearner(
final DeciderLearner super InputType, OutputType, ?, ?> deciderLearner,
final int leafCountThreshold,
final int maxDepth,
final Map priors)
{
super(deciderLearner);
this.setLeafCountThreshold(leafCountThreshold);
this.setMaxDepth(maxDepth);
this.setCategoryPriors(priors);
}
@Override
public CategorizationTreeLearner clone()
{
final CategorizationTreeLearner result = (CategorizationTreeLearner)
super.clone();
result.priors = this.priors == null ? null : new LinkedHashMap<>(this.priors);
result.trainCounts = this.trainCounts == null ? null : new LinkedHashMap<>(this.trainCounts);
return result;
}
@Override
public CategorizationTree learn(
Collection extends InputOutputPair extends InputType, OutputType>>
data)
{
if ( data == null )
{
// Bad data.
return null;
}
final DefaultDataDistribution rootCounts =
getOutputCounts(data);
trainCounts = new HashMap();
for (OutputType category : rootCounts.getDomain())
{
int freq = (int)(rootCounts.get(category));
trainCounts.put(category, new Integer(freq));
}
// Configure prior weights if supported by split criterion.
if (deciderLearner instanceof PriorWeightedNodeLearner)
{
// The compiler is unable to figure out that this cast is
// safe. Suppress the incorrect warning.
@SuppressWarnings("unchecked")
PriorWeightedNodeLearner criterion =
(PriorWeightedNodeLearner)(deciderLearner);
criterion.configure(priors, trainCounts);
}
// Recursively learn the node.
CategorizationTree tree =
new CategorizationTree(
this.learnNode(data, null),
new HashSet(rootCounts.getDomain()));
trainCounts = null;
return tree;
}
/**
* Recursively learns the categorization tree using the given collection
* of data, returning the created node.
*
* @param data The set of data to learn a node from.
* @param parent The parent node.
* @return The categorization tree node learned from the given data.
*/
@Override
protected CategorizationTreeNode learnNode(
final Collection extends InputOutputPair extends InputType, OutputType>> data,
final AbstractDecisionTreeNode parent)
{
if (data == null || data.size() <= 0)
{
// Invalid data, nothing to learn. This case should never happen.
return null;
}
// TODO: Revisit the tree data structures later to determine if
// really necessary to include category on internal nodes. An
// implementation should be possible that does not spend the
// storage on internal nodes and only searches for the most
// common output when it is known that the node is a leaf.
//
// Also, would it be cleaner to push areAllOutputsEqual() into
// the search for a decision function? Both of these likely
// need to scan through all the data.
// -- mamunso (2010/11/18)
// We put the most probable output category on every node in
// the tree, in case we get a bad decision function or leaf
// node. This ensures That we can always make a
// categorization.
OutputType mostProbOutput = computeMaxProbPrediction(data);
// Give the node we are creating the most probable output.
final CategorizationTreeNode node =
new CategorizationTreeNode(
parent, mostProbOutput);
// Check for termination conditions that produce a leaf node.
boolean isLeaf = this.areAllOutputsEqual(data)
|| data.size() <= this.leafCountThreshold
|| (this.maxDepth > 0 && node.getDepth() >= this.maxDepth);
if (!isLeaf)
{
// Search for a decision function to split the data.
final Categorizer super InputType, ? extends Object> decider =
this.getDeciderLearner().learn(data);
if (decider != null)
{
// We learned a good decider.
node.setDecider(decider);
// Learn the child nodes.
super.learnChildNodes(node, data, decider);
}
else
{
// Failed to find a decision function ==> there is no
// attribute that separates the values of different
// output categories. This node necessarily becomes a
// leaf.
isLeaf = true;
}
}
// Return the new node we've created.
return node;
}
/**
* Creates a histogram of values based on the output values in the given
* collection of pairs.
*
* @param The type of the outputs to count over.
* @param data The data to create the output count histogram for.
* @return The output count histogram.
*/
public static DefaultDataDistribution getOutputCounts(
final Collection extends InputOutputPair, OutputType>> data)
{
// Create the histogram.
DefaultDataDistribution counts =
new DefaultDataDistribution();
if ( data == null )
{
// Bad data.
return counts;
}
for ( InputOutputPair, OutputType> example : data )
{
// Add the output.
final OutputType output = example.getOutput();
counts.increment(output);
}
// Return the histogram.
return counts;
}
/**
* Return the most probable output value, taking into
* consideration both the frequency counts in the data sample (at
* the current node) and the prior proabalities for possible
* output values.
*
* @param The type of outputs.
* @param data
* The data sample, with output labels for each data point.
* The sample must contain at least 1 data point.
* @return The output value with highest conditional probability.
*/
private OutputType computeMaxProbPrediction(
final Collection extends InputOutputPair, OutputType>> data)
{
DefaultDataDistribution nodeCounts = getOutputCounts(data);
if (priors == null) {
// With no explicit prior, the most probable prediction is
// the most common category.
return nodeCounts.getMaxValueKey();
}
// Iterate over possible predictions, and keep track of the
// prediction with highest probability. Note that these
// probabilities are not normalized. (It would be easy, just
// divide by sum of the unnormalized probs . . . but since
// that would be a constant scaling, the maximum probability
// prediction would be the same.)
double bestProb = -1.0;
OutputType bestVal = null;
for (OutputType category : nodeCounts.getDomain()) {
double likelihood =
nodeCounts.get(category) / trainCounts.get(category);
double prior = priors.get(category);
double prob = prior * likelihood;
if (prob > bestProb) {
bestProb = prob;
bestVal = category;
}
}
return bestVal;
}
/**
* Gets the leaf count threshold, which determines the number of elements
* at which to make an element into a leaf.
*
* @return The leaf count threshold.
*/
public int getLeafCountThreshold()
{
return this.leafCountThreshold;
}
/**
* Sets the leaf count threshold, which determines the number of elements
* at which to make an element into a leaf.
*
* @param leafCountThreshold
* The leaf count threshold. Must be non-negative.
*/
public void setLeafCountThreshold(
final int leafCountThreshold)
{
ArgumentChecker.assertIsNonNegative("leafCountThreshold", leafCountThreshold);
this.leafCountThreshold = leafCountThreshold;
}
/**
* Gets the maximum depth to grow the tree.
*
* @return
* The maximum depth to grow the tree. Zero or less means no
* maximum depth.
*/
public int getMaxDepth()
{
return this.maxDepth;
}
/**
* Sets the maximum depth to grow the tree.
*
* @param maxDepth
* The maximum depth to grow the tree. Zero or less means no
* maximum depth.
*/
public void setMaxDepth(
final int maxDepth)
{
this.maxDepth = maxDepth;
}
/**
* Set prior category probabilities. A higher prior
* probability for a category will cause the tree learner to
* weight examples from that category more highly.
*
* If the priors are not manually specified (through this
* method or passing priors into the constructor), prior
* probabilities default to the frequencies of the different
* categories in the training data.
*
* @param priors
* If null, use default prior probabilities. Otherwise, priors
* becomes the new prior weights. In the latter case,
* priors.keySet() contain the same values as the possible
* categories in data passed to the learn() method.
*/
public void setCategoryPriors(Map priors)
{
if (priors == null) {
this.priors = null;
}
else {
this.priors = new LinkedHashMap(priors);
}
}
}