gov.sandia.cognition.learning.algorithm.tree.AbstractVectorThresholdMaximumGainLearner Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: AbstractVectorThresholdMaximumGainLearner.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright November 25, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.collection.ArrayUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
/**
* An abstract class for decider learners that produce a threshold function
* on a vector element based on maximizing some gain value. It handles the
* looping over the elements of the vector and then for each element looping
* over the possible split points. Subclasses only need to define a method to
* compute the gain of a given split.
*
* @param
* The output category type for the training data.
* @author Justin Basilico
* @since 3.0
*/
public abstract class AbstractVectorThresholdMaximumGainLearner
extends AbstractCloneableSerializable
implements VectorThresholdLearner
{
/** The default value for the minimum split size is {@value}. */
public static final int DEFAULT_MIN_SPLIT_SIZE = 1;
/** The threshold for allowing a split to be made, determined by how many
* instances fall in each left or right sides of the split. Both sides
* must have at least this number of instances. Must be positive. */
protected int minSplitSize;
/** The array of dimensions for the learner to consider. If this is null,
* then all dimensions are considered. */
protected int[] dimensionsToConsider;
/**
* Creates a new {@code AbstractVectorThresholdMaximumGainLearner}.
*/
public AbstractVectorThresholdMaximumGainLearner()
{
this(DEFAULT_MIN_SPLIT_SIZE, null);
}
/**
*
* Creates a new {@code AbstractVectorThresholdMaximumGainLearner}.
*
* @param minSplitSize
* The minimum split size. Must be positive.
* @param dimensionsToConsider
* The array of vector dimensions to consider. Null means all of them
* are considered.
*/
public AbstractVectorThresholdMaximumGainLearner(
final int minSplitSize,
final int[] dimensionsToConsider)
{
super();
this.setMinSplitSize(minSplitSize);
this.setDimensionsToConsider(dimensionsToConsider);
}
@Override
public AbstractVectorThresholdMaximumGainLearner clone()
{
@SuppressWarnings("unchecked")
final AbstractVectorThresholdMaximumGainLearner result = (AbstractVectorThresholdMaximumGainLearner)
super.clone();
result.dimensionsToConsider = ArrayUtil.copy(this.dimensionsToConsider);
return result;
}
@Override
public VectorElementThresholdCategorizer learn(
final Collection extends InputOutputPair extends Vectorizable, OutputType>> data)
{
final int totalCount = CollectionUtil.size(data);
if (totalCount <= 1)
{
// Nothing to learn.
return null;
}
// Compute the base count values for the node.
final DefaultDataDistribution baseCounts =
CategorizationTreeLearner.getOutputCounts(data);
// Pre-allocate a workspace of data for computing the gain.
final ArrayList> workspace =
new ArrayList>(totalCount);
for (int i = 0; i < totalCount; i++)
{
workspace.add(new DefaultWeightedValue());
}
// Figure out the dimensionality of the data.
final int dimensionality = DatasetUtil.getInputDimensionality(data);
// Go through all the dimensions to find the one with the best gain
// and the best threshold.
double bestGain = -1.0;
int bestIndex = -1;
double bestThreshold = 0.0;
final int dimensionsCount = this.dimensionsToConsider == null ?
dimensionality : this.dimensionsToConsider.length;
for (int i = 0; i < dimensionsCount; i++)
{
final int index = this.dimensionsToConsider == null ?
i : this.dimensionsToConsider[i];
// Compute the best gain-threshold pair for the given dimension of
// the data.
final DefaultPair gainThresholdPair =
this.computeBestGainAndThreshold(data, index, baseCounts);
if (gainThresholdPair == null)
{
// There was no gain-threshold pair that created a
// threshold.
continue;
}
// Get the gain from the pair.
final double gain = gainThresholdPair.getFirst();
// Determine if this is the best gain seen.
if (bestIndex == -1 || gain > bestGain)
{
// This is the best gain, so store the gain, threshold,
// and index.
final double threshold = gainThresholdPair.getSecond();
bestGain = gain;
bestIndex = index;
bestThreshold = threshold;
}
}
if (bestIndex < 0)
{
// There was no dimension that provided any gain for the data,
// so no decision function can be made.
return null;
}
else
{
// Create the decision function for the best gain.
return new VectorElementThresholdCategorizer(
bestIndex, bestThreshold);
}
}
/**
* Computes the best gain and threshold for a given dimension using the
* computeSplitGain method for each potential split point of values for the
* given dimension.
*
* @param data
* The data to use to compute the threshold.
* @param dimension
* The dimension to compute the threshold for.
* @param baseCounts
* Information about the base category counts.
* @return
* A pair containing the best gain computed and its associated
* threshold. If there is no good split point, null is returned. This
* can happen if there is no data or every value is the same.
*/
public DefaultPair computeBestGainAndThreshold(
final Collection extends InputOutputPair extends Vectorizable, OutputType>> data,
final int dimension,
final DefaultDataDistribution baseCounts)
{
final int totalCount = data.size();
final ArrayList> workspace =
new ArrayList>(totalCount);
for (int i = 0; i < totalCount; i++)
{
workspace.add(new DefaultWeightedValue());
}
return this.computeBestGainAndThreshold(data, dimension, baseCounts,
workspace);
}
/**
* Computes the best gain and threshold for a given dimension using the
* computeSplitGain method for each potential split point of values for the
* given dimension.
*
* @param data
* The data to use to compute the threshold.
* @param dimension
* The dimension to compute the threshold for.
* @param baseCounts
* Information about the base category counts.
* @param values
* A workspace to store the values of the data in. Recycled to avoid
* recreating a large array each time.
* @return
* A pair containing the best gain computed and its associated
* threshold. If there is no good split point, null is returned. This
* can happen if there is no data or every value is the same.
*/
protected DefaultPair computeBestGainAndThreshold(
final Collection extends InputOutputPair extends Vectorizable, OutputType>> data,
final int dimension,
final DefaultDataDistribution baseCounts,
final ArrayList> values)
{
// We can only compute thresholds if we can put the minimum split
// size on each side.
final int totalCount = data.size();
if (totalCount < 2 * this.minSplitSize)
{
return null;
}
// To compute the gain we will sort all of the values along the given
// dimension and then walk along the values to determine the best
// threshold.
// The first step is to create a list of (value, output) pairs for the
// given dimension. We do this by using the given workspace.
int index = 0;
for (InputOutputPair extends Vectorizable, OutputType> example
: data)
{
// Add this example to the list.
final Vector input = example.getInput().convertToVector();
final OutputType output = example.getOutput();
final double value = input.getElement(dimension);
DefaultWeightedValue entry = values.get(index);
entry.setWeight(value);
entry.setValue(output);
index++;
}
// Sort the list in ascending order by value.
Collections.sort(values,
DefaultWeightedValue.WeightComparator.getInstance());
// Get the smallest and largest values.
final double smallestValue = values.get(0).getWeight();
final double largestValue = values.get(totalCount - 1).getWeight();
// If all the values on this dimension are the same then there is
// nothing to split on.
if (smallestValue >= largestValue)
{
// All of the values are the same.
return null;
}
// In order to find the best split we are going to keep track of the
// counts of each label on each side of the threshold. This means
// that we maintain two counting objects.
// To start with all of the examples are on the positive side of
// the split, so we initialize the base counts (all the data points)
// and the negative counts with nothing.
final DefaultDataDistribution positiveCounts =
baseCounts.clone();
final DefaultDataDistribution negativeCounts =
new DefaultDataDistribution(baseCounts.getDomain().size());
// We are going to loop over all the values to compute the best
// gain and the best threshold.
double bestGain = Double.NEGATIVE_INFINITY;
double bestTieBreaker = Double.NEGATIVE_INFINITY;
double bestThreshold = Double.NEGATIVE_INFINITY;
boolean validSplit = false;
// We need to keep track of the previous value for two reasons:
// 1) To determine if we've already tested the value, since we loop
// over a >= threshold.
// 2) So that the threshold can be computed to be half way between
// two values.
//
// We advance i through values and stop whenever value[i] != value[i-1].
// These are all the points where it is meaningful to evaluate a split.
// All values to the left of i go into the negative count bucket.
final int maxIndex = totalCount - this.minSplitSize;
double previousValue = 0.0;
for (int i = 0; i <= maxIndex; i++)
{
final DefaultWeightedValue valueLabel = values.get(i);
final OutputType label = valueLabel.getValue();
final double value = valueLabel.getWeight();
// Check if it is worth evaluating a threshold between
// previous and next value.
if (i < this.minSplitSize)
{
// We are going to loop over a threshold value that is >=
// to handle equivalent values properly. However, we also need
// to ignore the first minSplitSize values.
bestThreshold = value;
}
else if (value != previousValue)
{
// Compute the gain.
final double gain = computeSplitGain(
baseCounts, positiveCounts, negativeCounts);
if (gain >= bestGain)
{
final double proportionPositive =
positiveCounts.getTotal() / totalCount;
final double proportionNegative =
negativeCounts.getTotal() / totalCount;
// This is our tiebreaker criteria for the case where the
// gains are equal. It means that we prefer ties that are
// more balanced in how they split (50%/50% being optimal).
final double tieBreaker = 1.0
- Math.abs(proportionPositive - proportionNegative);
if ( gain > bestGain
|| tieBreaker > bestTieBreaker)
{
// For the decision threshold we actually want to pick
// the point that is half way between the current value
// and the previous value. Hopefully this will be more
// robust than using just the value itself.
double threshold = (value + previousValue) / 2.0;
// If a round-off error drops the threshold back to
// the previous value, set it equal to the current
// value since we use a >= threshold.
if (threshold <= previousValue)
{
threshold = value;
}
bestGain = gain;
bestTieBreaker = tieBreaker;
bestThreshold = threshold;
validSplit = true;
}
}
}
// Store this value as the previous value.
previousValue = value;
// Update the sides of the split.
positiveCounts.decrement(label);
negativeCounts.increment(label);
}
// Sanity check to make sure we found a threshold that
// partitions the values.
if ( bestThreshold <= smallestValue
|| bestThreshold > largestValue)
{
throw new RuntimeException(
"bestThreshold (" + bestThreshold + ") lies outside range of "
+ "values (" + smallestValue + ", " + largestValue + "]");
}
if (!validSplit)
{
return null;
}
// Return the pair containing the best gain and best threshold
// found.
return new DefaultPair(bestGain, bestThreshold);
}
/**
* Computes the gain of a given split. The base counts contains the
* category information before the split.
*
* @param baseCounts
* The base category information before splitting. Contains the sum of
* the positive and negative counts.
* @param positiveCounts
* The category information on the positive side of the split.
* @param negativeCounts
* The category information on the negative side of the split.
* @return
* The gain of the given split computed by comparing the positive and
* negative counts to the base counts.
*/
public abstract double computeSplitGain(
final DefaultDataDistribution baseCounts,
final DefaultDataDistribution positiveCounts,
final DefaultDataDistribution negativeCounts);
@Override
public int[] getDimensionsToConsider()
{
return this.dimensionsToConsider;
}
@Override
public void setDimensionsToConsider(
final int... dimensionsToConsider)
{
this.dimensionsToConsider = dimensionsToConsider;
}
/**
* Gets the minimum split size. This is the minimum number of examples
* that can fall on either side of the split for it to be valid. If there
* is not at least twice this number of examples in the input data, then
* no split is returned.
*
* @return
* The minimum split size. Must be positive.
*/
public int getMinSplitSize()
{
return this.minSplitSize;
}
/**
* Sets the minimum split size. This is the minimum number of examples
* that can fall on either side of the split for it to be valid. If there
* is not at least twice this number of examples in the input data, then
* no split is returned.
*
* @param minSplitSize
* The minimum split size. Must be positive.
*/
public void setMinSplitSize(
final int minSplitSize)
{
ArgumentChecker.assertIsPositive("minSplitSize", minSplitSize);
this.minSplitSize = minSplitSize;
}
}