All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.tree.AbstractVectorThresholdMaximumGainLearner Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                AbstractVectorThresholdMaximumGainLearner.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright November 25, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 * See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;

/**
 * An abstract class for decider learners that produce a threshold function
 * on a vector element based on maximizing some gain value. It handles the
 * looping over the elements of the vector and then for each element looping
 * over the possible split points. Subclasses only need to define a method to
 * compute the gain of a given split.
 * 
 * @param   
 *      The output category type for the training data.
 * @author  Justin Basilico
 * @since   3.0
 */
public abstract class AbstractVectorThresholdMaximumGainLearner
    extends AbstractCloneableSerializable
    implements VectorThresholdMaximumGainLearner
{

    /** The array of dimensions for the learner to consider. If this is null,
     *  then all dimensions are considered. */
    protected int[] dimensionsToConsider;

    /**
     * Creates a new {@code AbstractVectorThresholdMaximumGainLearner}.
     */
    public AbstractVectorThresholdMaximumGainLearner()
    {
        super();
    }

    @Override
    public VectorElementThresholdCategorizer learn(
        final Collection> data)
    {
        final int totalCount = CollectionUtil.size(data);
        if (totalCount <= 1)
        {
            // Nothing to learn.
            return null;
        }

        // Compute the base count values for the node.
        final DefaultDataDistribution baseCounts =
            CategorizationTreeLearner.getOutputCounts(data);

        // Pre-allocate a workspace of data for computing the gain.
        final ArrayList> workspace =
            new ArrayList>(totalCount);
        for (int i = 0; i < totalCount; i++)
        {
            workspace.add(new DefaultWeightedValue());
        }

        // Figure out the dimensionality of the data.
        final int dimensionality = getDimensionality(data);

        // Go through all the dimensions to find the one with the best gain
        // and the best threshold.
        double bestGain = -1.0;
        int bestIndex = -1;
        double bestThreshold = 0.0;

        final int dimensionsCount = this.dimensionsToConsider == null ?
            dimensionality : this.dimensionsToConsider.length;
        for (int i = 0; i < dimensionsCount; i++)
        {
            final int index = this.dimensionsToConsider == null ?
                i : this.dimensionsToConsider[i];

            // Compute the best gain-threshold pair for the given dimension of
            // the data.
            final DefaultPair gainThresholdPair =
                this.computeBestGainAndThreshold(data, index, baseCounts);

            if (gainThresholdPair == null)
            {
                // There was no gain-threshold pair that created a
                // threshold.
                continue;
            }

            // Get the gain from the pair.
            final double gain = gainThresholdPair.getFirst();

            // Determine if this is the best gain seen.
            if (bestIndex == -1 || gain > bestGain)
            {
                // This is the best gain, so store the gain, threshold,
                // and index.
                final double threshold = gainThresholdPair.getSecond();
                bestGain = gain;
                bestIndex = index;
                bestThreshold = threshold;
            }
        }

        if (bestIndex < 0)
        {
            // There was no dimension that provided any gain for the data,
            // so no decision function can be made.
            return null;
        }
        else
        {
            // Create the decision function for the best gain.
            return new VectorElementThresholdCategorizer(
                bestIndex, bestThreshold);
        }
    }

    /**
     * Computes the best gain and threshold for a given dimension using the
     * computeSplitGain method for each potential split point of values for the
     * given dimension.
     *
     * @param   data
     *      The data to use to compute the threshold.
     * @param   dimension
     *      The dimension to compute the threshold for.
     * @param   baseCounts
     *      Information about the base category counts.
     * @return
     *      A pair containing the best gain computed and its associated
     *      threshold. If there is no good split point, null is returned. This
     *      can happen if there is no data or every value is the same.
     */
    public DefaultPair computeBestGainAndThreshold(
        final Collection> data,
        final int dimension,
        final DefaultDataDistribution baseCounts)
    {
        final int totalCount = data.size();
     
        final ArrayList> workspace =
            new ArrayList>(totalCount);
        for (int i = 0; i < totalCount; i++)
        {
            workspace.add(new DefaultWeightedValue());
        }
        return this.computeBestGainAndThreshold(data, dimension, baseCounts,
            workspace);
    }

    /**
     * Computes the best gain and threshold for a given dimension using the
     * computeSplitGain method for each potential split point of values for the
     * given dimension.
     *
     * @param   data
     *      The data to use to compute the threshold.
     * @param   dimension
     *      The dimension to compute the threshold for.
     * @param   baseCounts
     *      Information about the base category counts.
     * @param   values
     *      A workspace to store the values of the data in. Recycled to avoid
     *      recreating a large array each time.
     * @return
     *      A pair containing the best gain computed and its associated
     *      threshold. If there is no good split point, null is returned. This
     *      can happen if there is no data or every value is the same.
     */
    protected DefaultPair computeBestGainAndThreshold(
        final Collection> data,
        final int dimension,
        final DefaultDataDistribution baseCounts,
        final ArrayList> values)
    {
        // We can only compute thresholds for at least 1 value.
        final int totalCount = data.size();
        if (totalCount <= 1)
        {
            return null;
        }

        // To compute the gain we will sort all of the values along the given
        // dimension and then walk along the values to determine the best
        // threshold.

        // The first step is to create a list of (value, output) pairs for the
        // given dimension. We do this by using the given workspace.
        int index = 0;
        for (InputOutputPair example
            : data)
        {
            // Add this example to the list.
            final Vector input = example.getInput().convertToVector();
            final OutputType output = example.getOutput();
            final double value = input.getElement(dimension);
            DefaultWeightedValue entry = values.get(index);
            entry.setWeight(value);
            entry.setValue(output);
            index++;
        }

        // Sort the list in ascending order by value.
        Collections.sort(values, 
            DefaultWeightedValue.WeightComparator.getInstance());

        // Get the smallest and largest values.
        final double smallestValue = values.get(0).getWeight();
        final double largestValue = values.get(totalCount - 1).getWeight();

        // If all the values on this dimension are the same then there is
        // nothing to split on.
        if (smallestValue >= largestValue)
        {
            // All of the values are the same.
            return null;
        }

        // In order to find the best split we are going to keep track of the
        // counts of each label on each side of the threshold. This means
        // that we maintain two counting objects.
        // To start with all of the examples are on the positive side of
        // the split, so we initialize the base counts (all the data points)
        // and the negative counts with nothing.
        final DefaultDataDistribution positiveCounts =
            baseCounts.clone();
        final DefaultDataDistribution negativeCounts =
            new DefaultDataDistribution(baseCounts.getDomain().size());

        // We are going to loop over all the values to compute the best
        // gain and the best threshold.
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestTieBreaker = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NEGATIVE_INFINITY;

        // We need to keep track of the previous value for two reasons:
        //    1) To determine if we've already tested the value, since we loop
        //       over a >= threshold.
        //    2) So that the threshold can be computed to be half way between
        //       two values.
	       //
        // We advance i through values and stop whenever value[i] != value[i-1].
        // These are all the points where it is meaningful to evaluate a split.
        // All values to the left of i go into the negative count bucket.
        double previousValue = smallestValue;
        for (int i = 1; i < totalCount; i++)
        {
            // Move previous value to negative count bucket.
            final OutputType label = values.get(i - 1).getValue();
            positiveCounts.decrement(label);
            negativeCounts.increment(label);

            final double value = values.get(i).getWeight();

            // Check if it is worth evaluating a threshold between
            // previous and next value.
            if (value != previousValue)
            {
                // Compute the gain.
                final double gain = computeSplitGain(
                    baseCounts, positiveCounts, negativeCounts);

                if (gain >= bestGain)
                {
                    final double proportionPositive =
                        positiveCounts.getTotal() / totalCount;
                    final double proportionNegative =
                        negativeCounts.getTotal() / totalCount;

                    // This is our tiebreaker criteria for the case where the
                    // gains are equal. It means that we prefer ties that are
                    // more balanced in how they split (50%/50% being optimal).
                    final double tieBreaker = 1.0
                        - Math.abs(proportionPositive - proportionNegative);

                    if (    gain > bestGain
                         || tieBreaker > bestTieBreaker)
                    {
                        // For the decision threshold we actually want to pick
                        // the point that is half way between the current value
                        // and the previous value. Hopefully this will be more
                        // robust than using just the value itself.
                        final double threshold =
                            (value + previousValue) / 2.0;

                        bestGain = gain;
                        bestTieBreaker = tieBreaker;
                        bestThreshold = threshold;
                    }
                }

                // Store this value as the previous value.
                previousValue = value;
            }
        }

        // Sanity check to make sure we found a threshold that
        // partitions the values.
        if (   bestThreshold <= smallestValue
            || bestThreshold >= largestValue)
        {
            throw new RuntimeException(
                "bestThreshold (" + bestThreshold + ") lies outside range of values (" + smallestValue + ", " + largestValue + ")");
        }

        // Return the pair containing the best gain and best threshold
        // found.
        return new DefaultPair(bestGain, bestThreshold);
    }

    /**
     * Computes the gain of a given split. The base counts contains the
     * category information before the split.
     *
     * @param   baseCounts
     *      The base category information before splitting. Contains the sum of
     *      the positive and negative counts.
     * @param positiveCounts
     *      The category information on the positive side of the split.
     * @param negativeCounts
     *      The category information on the negative side of the split.
     * @return
     *      The gain of the given split computed by comparing the positive and
     *      negative counts to the base counts.
     */
    public abstract double computeSplitGain(
        final DefaultDataDistribution baseCounts,
        final DefaultDataDistribution positiveCounts,
        final DefaultDataDistribution negativeCounts);

    @Override
    public int[] getDimensionsToConsider()
    {
        return this.dimensionsToConsider;
    }

    @Override
    public void setDimensionsToConsider(
        final int[] dimensionsToConsider)
    {
        this.dimensionsToConsider = dimensionsToConsider;
    }

    /**
     * Figures out the dimensionality of the Vector data.
     *
     * @param  data The data.
     * @return The dimensionality of the data in the vector.
     */
    protected static int getDimensionality(
        final Collection
            >
            data)
    {
        if (CollectionUtil.isEmpty(data))
        {
            // Bad data.
            return 0;
        }
        else
        {
            // Get the dimensionality of the first data element.
            return CollectionUtil.getFirst(data).getInput().convertToVector()
                .getDimensionality();
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy