All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.tree.AbstractVectorThresholdMaximumGainLearner Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                AbstractVectorThresholdMaximumGainLearner.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright November 25, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 * See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.collection.ArrayUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;

/**
 * An abstract class for decider learners that produce a threshold function
 * on a vector element based on maximizing some gain value. It handles the
 * looping over the elements of the vector and then for each element looping
 * over the possible split points. Subclasses only need to define a method to
 * compute the gain of a given split.
 * 
 * @param   
 *      The output category type for the training data.
 * @author  Justin Basilico
 * @since   3.0
 */
public abstract class AbstractVectorThresholdMaximumGainLearner
    extends AbstractCloneableSerializable
    implements VectorThresholdLearner
{

    /** The default value for the minimum split size is {@value}. */
    public static final int DEFAULT_MIN_SPLIT_SIZE = 1;
    
    /** The threshold for allowing a split to be made, determined by how many
     *  instances fall in each left or right sides of the split. Both sides
     *  must have at least this number of instances. Must be positive. */
    protected int minSplitSize;

    /** The array of dimensions for the learner to consider. If this is null,
     *  then all dimensions are considered. */
    protected int[] dimensionsToConsider;

    /**
     * Creates a new {@code AbstractVectorThresholdMaximumGainLearner}.
     */
    public AbstractVectorThresholdMaximumGainLearner()
    {
        this(DEFAULT_MIN_SPLIT_SIZE, null);
    }

    /**
     * 
     * Creates a new {@code AbstractVectorThresholdMaximumGainLearner}.
     * 
     * @param   minSplitSize
     *      The minimum split size. Must be positive.
     * @param   dimensionsToConsider
     *      The array of vector dimensions to consider. Null means all of them
     *      are considered.
     */
    public AbstractVectorThresholdMaximumGainLearner(
        final int minSplitSize,
        final int[] dimensionsToConsider)
    {
        super();
        
        this.setMinSplitSize(minSplitSize);
        this.setDimensionsToConsider(dimensionsToConsider);
    }
    
    @Override
    public AbstractVectorThresholdMaximumGainLearner clone()
    {
        @SuppressWarnings("unchecked")
        final AbstractVectorThresholdMaximumGainLearner result = (AbstractVectorThresholdMaximumGainLearner)
            super.clone();
        result.dimensionsToConsider = ArrayUtil.copy(this.dimensionsToConsider);
        return result;
    }

    @Override
    public VectorElementThresholdCategorizer learn(
        final Collection> data)
    {
        final int totalCount = CollectionUtil.size(data);
        if (totalCount <= 1)
        {
            // Nothing to learn.
            return null;
        }

        // Compute the base count values for the node.
        final DefaultDataDistribution baseCounts =
            CategorizationTreeLearner.getOutputCounts(data);

        // Pre-allocate a workspace of data for computing the gain.
        final ArrayList> workspace =
            new ArrayList>(totalCount);
        for (int i = 0; i < totalCount; i++)
        {
            workspace.add(new DefaultWeightedValue());
        }

        // Figure out the dimensionality of the data.
        final int dimensionality = DatasetUtil.getInputDimensionality(data);

        // Go through all the dimensions to find the one with the best gain
        // and the best threshold.
        double bestGain = -1.0;
        int bestIndex = -1;
        double bestThreshold = 0.0;

        final int dimensionsCount = this.dimensionsToConsider == null ?
            dimensionality : this.dimensionsToConsider.length;
        for (int i = 0; i < dimensionsCount; i++)
        {
            final int index = this.dimensionsToConsider == null ?
                i : this.dimensionsToConsider[i];

            // Compute the best gain-threshold pair for the given dimension of
            // the data.
            final DefaultPair gainThresholdPair =
                this.computeBestGainAndThreshold(data, index, baseCounts);

            if (gainThresholdPair == null)
            {
                // There was no gain-threshold pair that created a
                // threshold.
                continue;
            }

            // Get the gain from the pair.
            final double gain = gainThresholdPair.getFirst();

            // Determine if this is the best gain seen.
            if (bestIndex == -1 || gain > bestGain)
            {
                // This is the best gain, so store the gain, threshold,
                // and index.
                final double threshold = gainThresholdPair.getSecond();
                bestGain = gain;
                bestIndex = index;
                bestThreshold = threshold;
            }
        }

        if (bestIndex < 0)
        {
            // There was no dimension that provided any gain for the data,
            // so no decision function can be made.
            return null;
        }
        else
        {
            // Create the decision function for the best gain.
            return new VectorElementThresholdCategorizer(
                bestIndex, bestThreshold);
        }
    }

    /**
     * Computes the best gain and threshold for a given dimension using the
     * computeSplitGain method for each potential split point of values for the
     * given dimension.
     *
     * @param   data
     *      The data to use to compute the threshold.
     * @param   dimension
     *      The dimension to compute the threshold for.
     * @param   baseCounts
     *      Information about the base category counts.
     * @return
     *      A pair containing the best gain computed and its associated
     *      threshold. If there is no good split point, null is returned. This
     *      can happen if there is no data or every value is the same.
     */
    public DefaultPair computeBestGainAndThreshold(
        final Collection> data,
        final int dimension,
        final DefaultDataDistribution baseCounts)
    {
        final int totalCount = data.size();
     
        final ArrayList> workspace =
            new ArrayList>(totalCount);
        for (int i = 0; i < totalCount; i++)
        {
            workspace.add(new DefaultWeightedValue());
        }
        return this.computeBestGainAndThreshold(data, dimension, baseCounts,
            workspace);
    }

    /**
     * Computes the best gain and threshold for a given dimension using the
     * computeSplitGain method for each potential split point of values for the
     * given dimension.
     *
     * @param   data
     *      The data to use to compute the threshold.
     * @param   dimension
     *      The dimension to compute the threshold for.
     * @param   baseCounts
     *      Information about the base category counts.
     * @param   values
     *      A workspace to store the values of the data in. Recycled to avoid
     *      recreating a large array each time.
     * @return
     *      A pair containing the best gain computed and its associated
     *      threshold. If there is no good split point, null is returned. This
     *      can happen if there is no data or every value is the same.
     */
    protected DefaultPair computeBestGainAndThreshold(
        final Collection> data,
        final int dimension,
        final DefaultDataDistribution baseCounts,
        final ArrayList> values)
    {
        // We can only compute thresholds if we can put the minimum split
        // size on each side.
        final int totalCount = data.size();
        if (totalCount < 2 * this.minSplitSize)
        {
            return null;
        }

        // To compute the gain we will sort all of the values along the given
        // dimension and then walk along the values to determine the best
        // threshold.

        // The first step is to create a list of (value, output) pairs for the
        // given dimension. We do this by using the given workspace.
        int index = 0;
        for (InputOutputPair example
            : data)
        {
            // Add this example to the list.
            final Vector input = example.getInput().convertToVector();
            final OutputType output = example.getOutput();
            final double value = input.getElement(dimension);
            DefaultWeightedValue entry = values.get(index);
            entry.setWeight(value);
            entry.setValue(output);
            index++;
        }

        // Sort the list in ascending order by value.
        Collections.sort(values,
            DefaultWeightedValue.WeightComparator.getInstance());

        // Get the smallest and largest values.
        final double smallestValue = values.get(0).getWeight();
        final double largestValue = values.get(totalCount - 1).getWeight();

        // If all the values on this dimension are the same then there is
        // nothing to split on.
        if (smallestValue >= largestValue)
        {
            // All of the values are the same.
            return null;
        }

        // In order to find the best split we are going to keep track of the
        // counts of each label on each side of the threshold. This means
        // that we maintain two counting objects.
        // To start with all of the examples are on the positive side of
        // the split, so we initialize the base counts (all the data points)
        // and the negative counts with nothing.
        final DefaultDataDistribution positiveCounts =
            baseCounts.clone();
        final DefaultDataDistribution negativeCounts =
            new DefaultDataDistribution(baseCounts.getDomain().size());

        // We are going to loop over all the values to compute the best
        // gain and the best threshold.
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestTieBreaker = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NEGATIVE_INFINITY;
        boolean validSplit = false;

        // We need to keep track of the previous value for two reasons:
        //    1) To determine if we've already tested the value, since we loop
        //       over a >= threshold.
        //    2) So that the threshold can be computed to be half way between
        //       two values.
	    //
        // We advance i through values and stop whenever value[i] != value[i-1].
        // These are all the points where it is meaningful to evaluate a split.
        // All values to the left of i go into the negative count bucket.
        final int maxIndex = totalCount - this.minSplitSize;
        double previousValue = 0.0;
        for (int i = 0; i <= maxIndex; i++)
        {
            final DefaultWeightedValue valueLabel = values.get(i);
            final OutputType label = valueLabel.getValue();
            final double value = valueLabel.getWeight();

            // Check if it is worth evaluating a threshold between
            // previous and next value.
            if (i < this.minSplitSize)
            {
                // We are going to loop over a threshold value that is >=
                // to handle equivalent values properly. However, we also need
                // to ignore the first minSplitSize values.
                bestThreshold = value;
            }
            else if (value != previousValue)
            {
                // Compute the gain.
                final double gain = computeSplitGain(
                    baseCounts, positiveCounts, negativeCounts);

                if (gain >= bestGain)
                {
                    final double proportionPositive =
                        positiveCounts.getTotal() / totalCount;
                    final double proportionNegative =
                        negativeCounts.getTotal() / totalCount;

                    // This is our tiebreaker criteria for the case where the
                    // gains are equal. It means that we prefer ties that are
                    // more balanced in how they split (50%/50% being optimal).
                    final double tieBreaker = 1.0
                        - Math.abs(proportionPositive - proportionNegative);

                    if (    gain > bestGain
                         || tieBreaker > bestTieBreaker)
                    {
                        // For the decision threshold we actually want to pick
                        // the point that is half way between the current value
                        // and the previous value. Hopefully this will be more
                        // robust than using just the value itself.
                        double threshold = (value + previousValue) / 2.0;

                        // If a round-off error drops the threshold back to
                        // the previous value, set it equal to the current
                        // value since we use a >= threshold.
                        if (threshold <= previousValue)
                        {
                            threshold = value;
                        }

                        bestGain = gain;
                        bestTieBreaker = tieBreaker;
                        bestThreshold = threshold;
                        validSplit = true;
                    }
                }
            }
            
            // Store this value as the previous value.
            previousValue = value;
            
            // Update the sides of the split.
            positiveCounts.decrement(label);
            negativeCounts.increment(label);
        }

        // Sanity check to make sure we found a threshold that
        // partitions the values.
        if (   bestThreshold <= smallestValue
            || bestThreshold >  largestValue)
        {
            throw new RuntimeException(
                "bestThreshold (" + bestThreshold + ") lies outside range of "
                + "values (" + smallestValue + ", " + largestValue + "]");
        }

        if (!validSplit)
        {
            return null;
        }
        
        // Return the pair containing the best gain and best threshold
        // found.
        return new DefaultPair(bestGain, bestThreshold);
    }

    /**
     * Computes the gain of a given split. The base counts contains the
     * category information before the split.
     *
     * @param   baseCounts
     *      The base category information before splitting. Contains the sum of
     *      the positive and negative counts.
     * @param positiveCounts
     *      The category information on the positive side of the split.
     * @param negativeCounts
     *      The category information on the negative side of the split.
     * @return
     *      The gain of the given split computed by comparing the positive and
     *      negative counts to the base counts.
     */
    public abstract double computeSplitGain(
        final DefaultDataDistribution baseCounts,
        final DefaultDataDistribution positiveCounts,
        final DefaultDataDistribution negativeCounts);

    @Override
    public int[] getDimensionsToConsider()
    {
        return this.dimensionsToConsider;
    }

    @Override
    public void setDimensionsToConsider(
        final int... dimensionsToConsider)
    {
        this.dimensionsToConsider = dimensionsToConsider;
    }

    /**
     * Gets the minimum split size. This is the minimum number of examples
     * that can fall on either side of the split for it to be valid. If there
     * is not at least twice this number of examples in the input data, then
     * no split is returned.
     * 
     * @return 
     *      The minimum split size. Must be positive.
     */
    public int getMinSplitSize()
    {
        return this.minSplitSize;
    }
    
    /**
     * Sets the minimum split size. This is the minimum number of examples
     * that can fall on either side of the split for it to be valid. If there
     * is not at least twice this number of examples in the input data, then
     * no split is returned.
     * 
     * @param   minSplitSize
     *      The minimum split size. Must be positive.
     */
    public void setMinSplitSize(
        final int minSplitSize)
    {
        ArgumentChecker.assertIsPositive("minSplitSize", minSplitSize);
        this.minSplitSize = minSplitSize;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy