data:image/s3,"s3://crabby-images/02ace/02ace956f9868cf2a1a780bd2c0a517cd3a46077" alt="JAR search and dependency download from the Maven repository"
gov.sandia.cognition.learning.algorithm.tree.AbstractVectorThresholdMaximumGainLearner Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: AbstractVectorThresholdMaximumGainLearner.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright November 25, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
/**
* An abstract class for decider learners that produce a threshold function
* on a vector element based on maximizing some gain value. It handles the
* looping over the elements of the vector and then for each element looping
* over the possible split points. Subclasses only need to define a method to
* compute the gain of a given split.
*
* @param
* The output category type for the training data.
* @author Justin Basilico
* @since 3.0
*/
public abstract class AbstractVectorThresholdMaximumGainLearner
extends AbstractCloneableSerializable
implements VectorThresholdMaximumGainLearner
{
/** The array of dimensions for the learner to consider. If this is null,
* then all dimensions are considered. */
protected int[] dimensionsToConsider;
/**
* Creates a new {@code AbstractVectorThresholdMaximumGainLearner}.
*/
public AbstractVectorThresholdMaximumGainLearner()
{
super();
}
@Override
public VectorElementThresholdCategorizer learn(
final Collection extends InputOutputPair extends Vectorizable, OutputType>> data)
{
final int totalCount = CollectionUtil.size(data);
if (totalCount <= 1)
{
// Nothing to learn.
return null;
}
// Compute the base count values for the node.
final DefaultDataDistribution baseCounts =
CategorizationTreeLearner.getOutputCounts(data);
// Pre-allocate a workspace of data for computing the gain.
final ArrayList> workspace =
new ArrayList>(totalCount);
for (int i = 0; i < totalCount; i++)
{
workspace.add(new DefaultWeightedValue());
}
// Figure out the dimensionality of the data.
final int dimensionality = getDimensionality(data);
// Go through all the dimensions to find the one with the best gain
// and the best threshold.
double bestGain = -1.0;
int bestIndex = -1;
double bestThreshold = 0.0;
final int dimensionsCount = this.dimensionsToConsider == null ?
dimensionality : this.dimensionsToConsider.length;
for (int i = 0; i < dimensionsCount; i++)
{
final int index = this.dimensionsToConsider == null ?
i : this.dimensionsToConsider[i];
// Compute the best gain-threshold pair for the given dimension of
// the data.
final DefaultPair gainThresholdPair =
this.computeBestGainAndThreshold(data, index, baseCounts);
if (gainThresholdPair == null)
{
// There was no gain-threshold pair that created a
// threshold.
continue;
}
// Get the gain from the pair.
final double gain = gainThresholdPair.getFirst();
// Determine if this is the best gain seen.
if (bestIndex == -1 || gain > bestGain)
{
// This is the best gain, so store the gain, threshold,
// and index.
final double threshold = gainThresholdPair.getSecond();
bestGain = gain;
bestIndex = index;
bestThreshold = threshold;
}
}
if (bestIndex < 0)
{
// There was no dimension that provided any gain for the data,
// so no decision function can be made.
return null;
}
else
{
// Create the decision function for the best gain.
return new VectorElementThresholdCategorizer(
bestIndex, bestThreshold);
}
}
/**
* Computes the best gain and threshold for a given dimension using the
* computeSplitGain method for each potential split point of values for the
* given dimension.
*
* @param data
* The data to use to compute the threshold.
* @param dimension
* The dimension to compute the threshold for.
* @param baseCounts
* Information about the base category counts.
* @return
* A pair containing the best gain computed and its associated
* threshold. If there is no good split point, null is returned. This
* can happen if there is no data or every value is the same.
*/
public DefaultPair computeBestGainAndThreshold(
final Collection extends InputOutputPair extends Vectorizable, OutputType>> data,
final int dimension,
final DefaultDataDistribution baseCounts)
{
final int totalCount = data.size();
final ArrayList> workspace =
new ArrayList>(totalCount);
for (int i = 0; i < totalCount; i++)
{
workspace.add(new DefaultWeightedValue());
}
return this.computeBestGainAndThreshold(data, dimension, baseCounts,
workspace);
}
/**
* Computes the best gain and threshold for a given dimension using the
* computeSplitGain method for each potential split point of values for the
* given dimension.
*
* @param data
* The data to use to compute the threshold.
* @param dimension
* The dimension to compute the threshold for.
* @param baseCounts
* Information about the base category counts.
* @param values
* A workspace to store the values of the data in. Recycled to avoid
* recreating a large array each time.
* @return
* A pair containing the best gain computed and its associated
* threshold. If there is no good split point, null is returned. This
* can happen if there is no data or every value is the same.
*/
protected DefaultPair computeBestGainAndThreshold(
final Collection extends InputOutputPair extends Vectorizable, OutputType>> data,
final int dimension,
final DefaultDataDistribution baseCounts,
final ArrayList> values)
{
// We can only compute thresholds for at least 1 value.
final int totalCount = data.size();
if (totalCount <= 1)
{
return null;
}
// To compute the gain we will sort all of the values along the given
// dimension and then walk along the values to determine the best
// threshold.
// The first step is to create a list of (value, output) pairs for the
// given dimension. We do this by using the given workspace.
int index = 0;
for (InputOutputPair extends Vectorizable, OutputType> example
: data)
{
// Add this example to the list.
final Vector input = example.getInput().convertToVector();
final OutputType output = example.getOutput();
final double value = input.getElement(dimension);
DefaultWeightedValue entry = values.get(index);
entry.setWeight(value);
entry.setValue(output);
index++;
}
// Sort the list in ascending order by value.
Collections.sort(values,
DefaultWeightedValue.WeightComparator.getInstance());
// Get the smallest and largest values.
final double smallestValue = values.get(0).getWeight();
final double largestValue = values.get(totalCount - 1).getWeight();
// If all the values on this dimension are the same then there is
// nothing to split on.
if (smallestValue >= largestValue)
{
// All of the values are the same.
return null;
}
// In order to find the best split we are going to keep track of the
// counts of each label on each side of the threshold. This means
// that we maintain two counting objects.
// To start with all of the examples are on the positive side of
// the split, so we initialize the base counts (all the data points)
// and the negative counts with nothing.
final DefaultDataDistribution positiveCounts =
baseCounts.clone();
final DefaultDataDistribution negativeCounts =
new DefaultDataDistribution(baseCounts.getDomain().size());
// We are going to loop over all the values to compute the best
// gain and the best threshold.
double bestGain = Double.NEGATIVE_INFINITY;
double bestTieBreaker = Double.NEGATIVE_INFINITY;
double bestThreshold = Double.NEGATIVE_INFINITY;
// We need to keep track of the previous value for two reasons:
// 1) To determine if we've already tested the value, since we loop
// over a >= threshold.
// 2) So that the threshold can be computed to be half way between
// two values.
//
// We advance i through values and stop whenever value[i] != value[i-1].
// These are all the points where it is meaningful to evaluate a split.
// All values to the left of i go into the negative count bucket.
double previousValue = smallestValue;
for (int i = 1; i < totalCount; i++)
{
// Move previous value to negative count bucket.
final OutputType label = values.get(i - 1).getValue();
positiveCounts.decrement(label);
negativeCounts.increment(label);
final double value = values.get(i).getWeight();
// Check if it is worth evaluating a threshold between
// previous and next value.
if (value != previousValue)
{
// Compute the gain.
final double gain = computeSplitGain(
baseCounts, positiveCounts, negativeCounts);
if (gain >= bestGain)
{
final double proportionPositive =
positiveCounts.getTotal() / totalCount;
final double proportionNegative =
negativeCounts.getTotal() / totalCount;
// This is our tiebreaker criteria for the case where the
// gains are equal. It means that we prefer ties that are
// more balanced in how they split (50%/50% being optimal).
final double tieBreaker = 1.0
- Math.abs(proportionPositive - proportionNegative);
if ( gain > bestGain
|| tieBreaker > bestTieBreaker)
{
// For the decision threshold we actually want to pick
// the point that is half way between the current value
// and the previous value. Hopefully this will be more
// robust than using just the value itself.
final double threshold =
(value + previousValue) / 2.0;
bestGain = gain;
bestTieBreaker = tieBreaker;
bestThreshold = threshold;
}
}
// Store this value as the previous value.
previousValue = value;
}
}
// Sanity check to make sure we found a threshold that
// partitions the values.
if ( bestThreshold <= smallestValue
|| bestThreshold >= largestValue)
{
throw new RuntimeException(
"bestThreshold (" + bestThreshold + ") lies outside range of values (" + smallestValue + ", " + largestValue + ")");
}
// Return the pair containing the best gain and best threshold
// found.
return new DefaultPair(bestGain, bestThreshold);
}
/**
* Computes the gain of a given split. The base counts contains the
* category information before the split.
*
* @param baseCounts
* The base category information before splitting. Contains the sum of
* the positive and negative counts.
* @param positiveCounts
* The category information on the positive side of the split.
* @param negativeCounts
* The category information on the negative side of the split.
* @return
* The gain of the given split computed by comparing the positive and
* negative counts to the base counts.
*/
public abstract double computeSplitGain(
final DefaultDataDistribution baseCounts,
final DefaultDataDistribution positiveCounts,
final DefaultDataDistribution negativeCounts);
@Override
public int[] getDimensionsToConsider()
{
return this.dimensionsToConsider;
}
@Override
public void setDimensionsToConsider(
final int[] dimensionsToConsider)
{
this.dimensionsToConsider = dimensionsToConsider;
}
/**
* Figures out the dimensionality of the Vector data.
*
* @param data The data.
* @return The dimensionality of the data in the vector.
*/
protected static int getDimensionality(
final Collection
extends InputOutputPair extends Vectorizable, ?>>
data)
{
if (CollectionUtil.isEmpty(data))
{
// Bad data.
return 0;
}
else
{
// Get the dimensionality of the first data element.
return CollectionUtil.getFirst(data).getInput().convertToVector()
.getDimensionality();
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy