All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.tree.VectorThresholdGiniImpurityLearner Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                VectorThresholdGiniImpurityLearner.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright January 14, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 * See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;

/**
 * Learns vector thresholds based on the Gini impurity measure. It attempts to
 * minimize the Gini impurity in splits. If f_i is the fraction of examples
 * belonging to category i in split f, then the Gini impurity measure is defined
 * as:
 * 
sum_i f_i * (1 - f_i) *
Notice that sum_i f_i = 1, so the value will range between 0 and 1. *

* This measure is the one used in the Classification and Regression Tree * (CART) algorithm. * * @param * The type of the output categories to learn over. * @author Justin Basilico * @since 3.0 */ @PublicationReference( author="Wikipedia", title="Decision tree learning", year=2010, type=PublicationType.WebPage, url="http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity") public class VectorThresholdGiniImpurityLearner extends AbstractVectorThresholdMaximumGainLearner { /** * Creates a new instance of VectorThresholdGiniImpurityLearner. */ public VectorThresholdGiniImpurityLearner() { super(); } /** * Creates a new {@code VectorThresholdGiniImpurityLearner}. * * @param minSplitSize * The minimum split size. Must be positive. */ public VectorThresholdGiniImpurityLearner( final int minSplitSize) { super(minSplitSize, null); } @Override public VectorThresholdGiniImpurityLearner clone() { return (VectorThresholdGiniImpurityLearner) super.clone(); } /** * Computes the split gain by computing the Gini impurity for the * given split. * * @param baseCounts * The histogram of counts before the split. * @param positiveCounts * The counts on the positive side of the threshold. * @param negativeCounts * The counts on the negative side of the threshold. * @return * The split gain by computing the gain in Gini impurity for * the given split. Will be between 0.0 and 1.0. */ @Override public double computeSplitGain( final DefaultDataDistribution baseCounts, final DefaultDataDistribution positiveCounts, final DefaultDataDistribution negativeCounts) { // TODO: This is almost the same as the code in the // InformationGainLearner. They should be merged. // -- jdbasil (2010-01-14) // Compute the initial impurity. final double impurityBase = giniImpurity(baseCounts); final double impurityPositive = giniImpurity(positiveCounts); final double impurityNegative = giniImpurity(negativeCounts); // Compute the proportion positive and negative. final double totalCount = baseCounts.getTotal(); final double proportionPositive = positiveCounts.getTotal() / totalCount; final double proportionNegative = negativeCounts.getTotal() / totalCount; // Compute the gain in impurity. final double gain = impurityBase - proportionPositive * impurityPositive - proportionNegative * impurityNegative; // Return the gain. return gain; } /** * Computes the Gini impurity of a histogram. For each item in the * histogram, it is the probability that it is randomly assigned to the * wrong category, given the frequency of the different categories. This * is computed by looping over all the categories and multiplying the * fraction of elements in that category (f_i) times the probability of * choosing a different category (1 - f_i). That is: * * sum_i f_i * (1 - f_i) * * @param * The type of data the counts are over. * @param counts * The distribution to compute the impurity over. * @return * The Gini impurity of the given distribution. */ public static double giniImpurity( final DefaultDataDistribution counts) { double sum = 0.0; for (DataType value : counts.getDomain()) { final double fraction = counts.getFraction(value); sum += fraction * (1.0 - fraction); } return sum; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy