All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.classification.DecisionTree Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.classification;

import java.io.Serial;
import java.util.*;
import java.util.stream.Collectors;

import smile.base.cart.*;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.measure.Measure;
import smile.data.measure.NominalScale;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.math.MathEx;
import smile.util.IntSet;

/**
 * Decision tree. A classification/regression tree can be learned by
 * splitting the training set into subsets based on an attribute value
 * test. This process is repeated on each derived subset in a recursive
 * manner called recursive partitioning. The recursion is completed when
 * the subset at a node all has the same value of the target variable,
 * or when splitting no longer adds value to the predictions.
 * 

* The algorithms that are used for constructing decision trees usually * work top-down by choosing a variable at each step that is the next best * variable to use in splitting the set of items. "Best" is defined by how * well the variable splits the set into homogeneous subsets that have * the same value of the target variable. Different algorithms use different * formulae for measuring "best". Used by the CART algorithm, Gini impurity * is a measure of how often a randomly chosen element from the set would * be incorrectly labeled if it were randomly labeled according to the * distribution of labels in the subset. Gini impurity can be computed by * summing the probability of each item being chosen times the probability * of a mistake in categorizing that item. It reaches its minimum (zero) when * all cases in the node fall into a single target category. Information gain * is another popular measure, used by the ID3, C4.5 and C5.0 algorithms. * Information gain is based on the concept of entropy used in information * theory. For categorical variables with different number of levels, however, * information gain are biased in favor of those attributes with more levels. * Instead, one may employ the information gain ratio, which solves the drawback * of information gain. *

* Classification and Regression Tree techniques have a number of advantages * over many of those alternative techniques. *

*
Simple to understand and interpret.
*
In most cases, the interpretation of results summarized in a tree is * very simple. This simplicity is useful not only for purposes of rapid * classification of new observations, but can also often yield a much simpler * "model" for explaining why observations are classified or predicted in a * particular manner.
*
Able to handle both numerical and categorical data.
*
Other techniques are usually specialized in analyzing datasets that * have only one type of variable.
*
Tree methods are nonparametric and nonlinear.
*
The final results of using tree methods for classification or regression * can be summarized in a series of (usually few) logical if-then conditions * (tree nodes). Therefore, there is no implicit assumption that the underlying * relationships between the predictor variables and the dependent variable * are linear, follow some specific non-linear link function, or that they * are even monotonic in nature. Thus, tree methods are particularly well * suited for data mining tasks, where there is often little a priori * knowledge nor any coherent set of theories or predictions regarding which * variables are related and how. In those types of data analytics, tree * methods can often reveal simple relationships between just a few variables * that could have easily gone unnoticed using other analytic techniques.
*
* One major problem with classification and regression trees is their high * variance. Often a small change in the data can result in a very different * series of splits, making interpretation somewhat precarious. Besides, * decision-tree learners can create over-complex trees that cause over-fitting. * Mechanisms such as pruning are necessary to avoid this problem. * Another limitation of trees is the lack of smoothness of the prediction * surface. *

* Some techniques such as bagging, boosting, and random forest use more than * one decision tree for their analysis. * * @see AdaBoost * @see GradientTreeBoost * @see RandomForest * * @author Haifeng Li */ public class DecisionTree extends CART implements Classifier, DataFrameClassifier { @Serial private static final long serialVersionUID = 2L; /** * The splitting rule. */ private final SplitRule rule; /** * The number of classes. */ private final int k; /** * The class labels. */ private IntSet classes; /** The dependent variable. */ private transient int[] y; @Override protected double impurity(LeafNode node) { return ((DecisionNode) node).impurity(rule); } @Override protected LeafNode newNode(int[] nodeSamples) { int[] count = new int[k]; for (int i : nodeSamples) { count[y[i]] += samples[i]; } return new DecisionNode(count); } @Override protected Optional findBestSplit(LeafNode leaf, int j, double impurity, int lo, int hi) { DecisionNode node = (DecisionNode) leaf; BaseVector xj = x.column(j); int[] falseCount = new int[k]; Split split = null; double splitScore = 0.0; int splitTrueCount = 0; int splitFalseCount = 0; Measure measure = schema.field(j).measure; if (measure instanceof NominalScale scale) { int splitValue = -1; int m = scale.size(); int[][] trueCount = new int[m][k]; for (int i = lo; i < hi; i++) { int o = index[i]; trueCount[xj.getInt(o)][y[o]] += samples[o]; } for (int l : scale.values()) { int tc = (int) MathEx.sum(trueCount[l]); int fc = node.size() - tc; // If either side is too small, skip this value. 
if (tc < nodeSize || fc < nodeSize) { continue; } for (int q = 0; q < k; q++) { falseCount[q] = node.count()[q] - trueCount[l][q]; } double gain = impurity - (double) tc / node.size() * DecisionNode.impurity(rule, tc, trueCount[l]) - (double) fc / node.size() * DecisionNode.impurity(rule, fc, falseCount); // new best split if (gain > splitScore) { splitValue = l; splitTrueCount = tc; splitFalseCount = fc; splitScore = gain; } } if (splitScore > 0.0) { final int value = splitValue; split = new NominalSplit(leaf, j, splitValue, splitScore, lo, hi, splitTrueCount, splitFalseCount, (int o) -> xj.getInt(o) == value); } } else { double splitValue = 0.0; int[] trueCount = new int[k]; int[] orderj = order[j]; int first = orderj[lo]; double prevx = xj.getDouble(first); int prevy = y[first]; for (int i = lo; i < hi; i++) { int tc = 0; int fc = 0; int o = orderj[i]; int yi = y[o]; double xij = xj.getDouble(o); if (yi != prevy && !MathEx.isZero(xij - prevx, 1E-7)) { tc = (int) MathEx.sum(trueCount); fc = node.size() - tc; } // If either side is empty, skip this value. if (tc >= nodeSize && fc >= nodeSize) { for (int l = 0; l < k; l++) { falseCount[l] = node.count()[l] - trueCount[l]; } double gain = impurity - (double) tc / node.size() * DecisionNode.impurity(rule, tc, trueCount) - (double) fc / node.size() * DecisionNode.impurity(rule, fc, falseCount); // new best split if (gain > splitScore) { splitValue = (xij + prevx) / 2; splitTrueCount = tc; splitFalseCount = fc; splitScore = gain; } } prevx = xij; prevy = yi; trueCount[prevy] += samples[o]; } if (splitScore > 0.0) { final double value = splitValue; split = new OrdinalSplit(leaf, j, splitValue, splitScore, lo, hi, splitTrueCount, splitFalseCount, (int o) -> xj.getDouble(o) <= value); } } return Optional.ofNullable(split); } /** * Constructor. Fits a classification tree for AdaBoost and Random Forest. * @param x the data frame of the explanatory variable. * @param y the response variables. 
* @param response the metadata of response variable. * @param k the number of classes. * @param maxDepth the maximum depth of the tree. * @param maxNodes the maximum number of leaf nodes in the tree. * @param nodeSize the minimum size of leaf nodes. * @param mtry the number of input variables to pick to split on at each * node. It seems that sqrt(p) give generally good performance, * where p is the number of variables. * @param rule the splitting rule. * @param samples the sample set of instances for stochastic learning. * samples[i] is the number of sampling for instance i. * @param order the index of training values in ascending order. Note * that only numeric attributes need be sorted. */ public DecisionTree(DataFrame x, int[] y, StructField response, int k, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order) { super(x, response, maxDepth, maxNodes, nodeSize, mtry, samples, order); this.k = k; this.y = y; this.rule = rule; final int[] count = new int[k]; int n = x.size(); for (int i = 0; i < n; i++) { count[y[i]] += this.samples[i]; } LeafNode node = new DecisionNode(count); this.root = node; Optional split = findBestSplit(node, 0, index.length, new boolean[x.ncol()]); if (maxNodes == Integer.MAX_VALUE) { // deep-first split split.ifPresent(s -> split(s, null)); } else { // best-first split PriorityQueue queue = new PriorityQueue<>(2 * maxNodes, Split.comparator.reversed()); split.ifPresent(queue::add); for (int leaves = 1; leaves < this.maxNodes && !queue.isEmpty(); ) { if (split(queue.poll(), queue)) leaves++; } } // merge the sister leaves that produce the same output. this.root = this.root.merge(); clear(); } /** * Fits a classification tree. * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @return the model. 
*/ public static DecisionTree fit(Formula formula, DataFrame data) { return fit(formula, data, new Properties()); } /** * Fits a classification tree. * The hyperparameters in prop include *

    *
  • smile.cart.split.rule *
  • smile.cart.node.size *
  • smile.cart.max.nodes *
* @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @param params the hyperparameters. * @return the model. */ public static DecisionTree fit(Formula formula, DataFrame data, Properties params) { SplitRule rule = SplitRule.valueOf(params.getProperty("smile.cart.split_rule", "GINI")); int maxDepth = Integer.parseInt(params.getProperty("smile.cart.max_depth", "20")); int maxNodes = Integer.parseInt(params.getProperty("smile.cart.max_nodes", String.valueOf(data.size() / 5))); int nodeSize = Integer.parseInt(params.getProperty("smile.cart.node_size", "5")); return fit(formula, data, rule, maxDepth, maxNodes, nodeSize); } /** * Fits a classification tree. * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @param rule the splitting rule. * @param maxDepth the maximum depth of the tree. * @param maxNodes the maximum number of leaf nodes in the tree. * @param nodeSize the minimum size of leaf nodes. * @return the model. */ public static DecisionTree fit(Formula formula, DataFrame data, SplitRule rule, int maxDepth, int maxNodes, int nodeSize) { formula = formula.expand(data.schema()); DataFrame x = formula.x(data); BaseVector y = formula.y(data); ClassLabels codec = ClassLabels.fit(y); int mtry = x.ncol(); DecisionTree tree = new DecisionTree(x, codec.y, y.field(), codec.k, rule, maxDepth, maxNodes, nodeSize, mtry, null, null); tree.formula = formula; tree.classes = codec.classes; return tree; } @Override public int numClasses() { return classes.size(); } @Override public int[] classes() { return classes.values; } @Override public int predict(Tuple x) { DecisionNode leaf = (DecisionNode) root.predict(predictors(x)); int y = leaf.output(); return classes == null ? 
y : classes.valueOf(y); } @Override public boolean soft() { return true; } /** * Predicts the class label of an instance and also calculate a posteriori * probabilities. The posteriori estimation is based on sample distribution * in the leaf node. It is not accurate at all when be used in a single tree. * It is mainly used by RandomForest in an ensemble way. */ @Override public int predict(Tuple x, double[] posteriori) { DecisionNode leaf = (DecisionNode) root.predict(predictors(x)); leaf.posteriori(posteriori); int y = leaf.output(); return classes == null ? y : classes.valueOf(y); } /** Returns null if the tree is part of ensemble algorithm. */ @Override public Formula formula() { return formula; } @Override public StructType schema() { return schema; } /** Private constructor for prune(). */ private DecisionTree(Formula formula, StructType schema, StructField response, Node root, int k, SplitRule rule, double[] importance, IntSet classes) { super(formula, schema, response, root, importance); this.k = k; this.rule = rule; this.classes = classes; } /** * Returns a new decision tree by reduced error pruning. * @param test the test data set to evaluate the errors of nodes. * @return a new pruned tree. */ public DecisionTree prune(DataFrame test) { return prune(test, formula, classes); } /** * Reduced error pruning for random forest. * @param test the test data set to evaluate the errors of nodes. * @return a new pruned tree. */ DecisionTree prune(DataFrame test, Formula formula, IntSet classes) { double[] imp = importance.clone(); Prune prune = prune(root, test.stream().collect(Collectors.toList()), imp, formula, classes); return new DecisionTree(this.formula, schema, response, prune.node, k, rule, imp, this.classes); } /** * The result of pruning a subtree. * @param node The merged node if pruned. Otherwise, the original node. * @param error The test error on this node. * @param count The training sample size of each class. 
*/ record Prune(Node node, int error, int[] count) { } /** Prunes a subtree. */ private Prune prune(Node node, List test, double[] importance, Formula formula, IntSet labels) { if (node instanceof DecisionNode leaf) { int y = leaf.output(); int error = 0; for (Tuple t : test) { if (y != labels.indexOf(formula.yint(t))) error++; } return new Prune(node, error, leaf.count()); } InternalNode parent = (InternalNode) node; List trueBranch = new ArrayList<>(); List falseBranch = new ArrayList<>(); for (Tuple t : test) { if (parent.branch(formula.x(t))) trueBranch.add(t); else falseBranch.add(t); } Prune trueChild = prune(parent.trueChild(), trueBranch, importance, formula, labels); Prune falseChild = prune(parent.falseChild(), falseBranch, importance, formula, labels); int[] count = new int[k]; for (int i = 0; i < k; i++) { count[i] = trueChild.count[i] + falseChild.count[i]; } int y = MathEx.whichMax(count); int error = 0; for (Tuple t : test) { if (y != labels.indexOf(formula.yint(t))) error++; } if (error < trueChild.error + falseChild.error) { node = new DecisionNode(count); importance[parent.feature()] -= parent.score(); } else { error = trueChild.error + falseChild.error; node = parent.replace(trueChild.node, falseChild.node); } return new Prune(node, error, count); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy