All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.base.cart.CART Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.base.cart;

import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.measure.Measure;
import smile.data.measure.NominalScale;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.feature.importance.SHAP;
import smile.math.MathEx;
import smile.sort.QuickSort;

import java.io.Serial;
import java.io.Serializable;
import java.math.BigInteger;
import java.util.*;
import java.util.stream.IntStream;
import java.util.AbstractMap.SimpleEntry;

/**
 * Classification and regression tree.
 *
 * @author Haifeng Li
 */
public abstract class CART implements SHAP<Tuple>, Serializable {
    @Serial
    private static final long serialVersionUID = 2L;
    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(CART.class);

    /** The model formula. If null, input tuples are used as the predictors directly. */
    protected Formula formula;

    /** The schema of predictors. */
    protected StructType schema;

    /** The schema of response variable. */
    protected StructField response;

    /** The root of decision tree. */
    protected Node root;
    /**
     * The maximum depth of the tree.
     */
    protected int maxDepth = 20;
    /**
     * The maximum number of leaf nodes in the tree.
     */
    protected int maxNodes = 6;
    /**
     * The number of instances in a node below which the tree will
     * not split, setting nodeSize = 5 generally gives good results.
     */
    protected int nodeSize = 5;
    /**
     * The number of input variables to be used to determine the decision
     * at a node of the tree.
     */
    protected int mtry = -1;

    /**
     * Variable importance. Every time a split of a node is made on variable
     * the (GINI, information gain, etc.) impurity criterion for the two
     * descendant nodes is less than the parent node. Adding up the decreases
     * for each individual variable over the tree gives a simple measure of
     * variable importance.
     */
    protected double[] importance;

    /** The training data. */
    protected transient DataFrame x;

    /**
     * The samples for training this node. Note that samples[i] is the
     * number of sampling of dataset[i]. 0 means that the datum is not
     * included and values of greater than 1 are possible because of
     * sampling with replacement.
     */
    protected transient int[] samples;

    /**
     * An index of samples to their original locations in training dataset.
     * Only instances with {@code samples[i] > 0} are included.
     */
    protected transient int[] index;

    /**
     * An index of training values. Initially, order[i] is a set of indices that iterate through the
     * training values for attribute i in ascending order. During training, the array is rearranged
     * so that all values for each leaf node occupy a contiguous range, but within that range they
     * maintain the original ordering. Note that only numeric attributes will be sorted; non-numeric
     * attributes will have a null in the corresponding place in the array.
     * Like {@link #index}, each array only contains instances with {@code samples[i] > 0}.
     */
    protected transient int[][] order;

    /**
     * The working buffer for reordering {@link #index} and {@link #order} arrays.
     */
    private transient int[] buffer;

    /** Private constructor for deserialization. */
    private CART() {

    }

    /**
     * Constructor.
     * @param formula The model formula.
     * @param schema The data schema of predictors.
     * @param response The response variable.
     * @param root The root node.
     * @param importance The feature importance.
     */
    public CART(Formula formula, StructType schema, StructField response, Node root, double[] importance) {
        this.formula = formula;
        this.schema = schema;
        this.response = response;
        this.root = root;
        this.importance = importance;
    }

    /**
     * Constructor.
     * @param x the data frame of the explanatory variable.
     * @param y the response variables.
     * @param maxDepth the maximum depth of the tree.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize the minimum size of leaf nodes.
     * @param mtry the number of input variables to pick to split on at each
     *             node. It seems that sqrt(p) give generally good performance,
     *             where p is the number of variables. Out-of-range values
     *             fall back to all features.
     * @param samples the sample set of instances for stochastic learning.
     *               samples[i] is the number of sampling for instance i.
     *               If null, every instance is used once.
     * @param order the index of training values in ascending order. Note
     *              that only numeric attributes need be sorted. If null,
     *              it is computed from x. The arrays are copied, not modified.
     * @throws IllegalArgumentException if maxDepth, maxNodes or nodeSize is invalid.
     */
    public CART(DataFrame x, StructField y, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order) {
        this.x = x;
        this.response = y;
        this.schema = x.schema();
        this.importance = new double[x.ncol()];
        this.maxDepth = maxDepth;
        this.maxNodes = maxNodes;
        this.nodeSize = nodeSize;
        this.mtry = mtry;

        int n = x.size();
        int p = x.ncol();

        if (mtry < 1 || mtry > p) {
            logger.warn("Invalid mtry. Use all features.");
            this.mtry = p;
        }

        if (maxDepth < 1) {
            throw new IllegalArgumentException("Invalid maximum depth: " + maxDepth);
        }

        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxNodes);
        }

        if (nodeSize < 1) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
        }

        if (samples == null) {
            // Every instance takes part exactly once.
            this.samples = new int[n];
            Arrays.fill(this.samples, 1);
            this.index = IntStream.range(0, n).toArray();
        } else {
            this.samples = samples;
            this.index = IntStream.range(0, samples.length).filter(i -> samples[i] > 0).toArray();
        }

        buffer = new int[index.length];

        // Each order array must hold exactly the instances in index so that shuffle() can
        // partition all of them over the same [lo, hi) ranges. This applies whether the order
        // is supplied or computed here. Filtering also copies the arrays, so a caller-supplied
        // (possibly shared) order is never mutated by shuffle().
        int[][] sorted = order == null ? order(x) : order;
        this.order = new int[sorted.length][];
        for (int j = 0; j < sorted.length; j++) {
            if (sorted[j] != null) {
                this.order[j] = Arrays.stream(sorted[j]).filter(o -> this.samples[o] > 0).toArray();
            }
        }
    }

    /**
     * Returns the number of nodes in the tree.
     * @return the number of nodes in the tree.
     */
    public int size() {
        return size(root);
    }

    /** Returns the number of nodes of the subtree, counting both internal and leaf nodes. */
    private int size(Node node) {
        if (node instanceof LeafNode) return 1;

        InternalNode parent = (InternalNode) node;
        return size(parent.trueChild) + size(parent.falseChild) + 1;
    }

    /**
     * Returns the index of ordered samples for each ordinal column.
     * Columns with a nominal measure get a null entry.
     * @param x the predictors.
     * @return the index of ordered samples for each ordinal column.
     */
    public static int[][] order(DataFrame x) {
        int n = x.size();
        int p = x.ncol();
        StructType schema = x.schema();

        // Scratch buffer reused for every column; QuickSort.sort sorts it in place.
        double[] a = new double[n];
        int[][] order = new int[p][];

        for (int j = 0; j < p; j++) {
            Measure measure = schema.field(j).measure;
            if (!(measure instanceof NominalScale)) {
                x.column(j).toDoubleArray(a);
                order[j] = QuickSort.sort(a);
            }
        }

        return order;
    }

    /**
     * Returns the predictors by the model formula if it is not null.
     * Otherwise, return the input tuple.
     * @param x the input tuple.
     * @return the predictors.
     */
    protected Tuple predictors(Tuple x) {
        return formula == null ? x : formula.x(x);
    }

    /** Clear the workspace of building tree. */
    protected void clear() {
        this.x = null;
        this.order = null;
        this.index = null;
        this.samples = null;
        this.buffer = null;
    }

    /**
     * Split a node into two children nodes.
     *
     * @param split the split candidate.
     * @param queue the queue of splits for best-first growth, or null to
     *              grow the subtree depth-first immediately.
     * @return true if split success.
     */
    protected boolean split(final Split split, PriorityQueue<Split> queue) {
        if (split.feature < 0) {
            throw new IllegalStateException("Split a node with invalid feature.");
        }

        if (split.depth >= maxDepth) {
            logger.debug("Reach maximum depth");
            return false;
        }

        if (split.trueCount < nodeSize || split.falseCount < nodeSize) {
            // We should not reach here as findBestSplit filters this situation out.
            logger.debug("Node size is too small after splitting");
            return false;
        }

        // Build the predicate once rather than once per sample.
        var predicate = split.predicate();
        int[] trueSamples = Arrays.stream(index, split.lo, split.hi).filter(i -> predicate.test(i)).toArray();

        // Cache the results of predicate.test() so the false side and shuffle() need not re-evaluate it.
        boolean[] trues = new boolean[samples.length];
        for (int i : trueSamples) {
            trues[i] = true;
        }

        int[] falseSamples = Arrays.stream(index, split.lo, split.hi).filter(i -> !trues[i]).toArray();
        int mid = split.lo + trueSamples.length;

        LeafNode trueChild = newNode(trueSamples);
        assert trueChild.size == split.trueCount : String.format("trueChild.size != split.trueCount: %d != %d", trueChild.size, split.trueCount);
        assert trueChild.size >= nodeSize : String.format("trueChild size is too small: %d < %d", trueChild.size, nodeSize);

        LeafNode falseChild = newNode(falseSamples);
        assert falseChild.size == split.falseCount : String.format("falseChild.size != split.falseCount: %d != %d", falseChild.size, split.falseCount);
        assert falseChild.size >= nodeSize : String.format("falseChild size is too small: %d < %d", falseChild.size, nodeSize);
        InternalNode node = split.toNode(trueChild, falseChild);

        // Make [lo, mid) hold the true samples and [mid, hi) the false ones in index and order.
        shuffle(split.lo, mid, split.hi, trues);

        Optional<Split> trueSplit = findBestSplit(trueChild, split.lo, mid, split.unsplittable.clone());
        Optional<Split> falseSplit = findBestSplit(falseChild, mid, split.hi, split.unsplittable); // reuse parent's array

        // Prune the branch if both children are leaf nodes and of same output value.
        if (trueChild.equals(falseChild) && trueSplit.isEmpty() && falseSplit.isEmpty()) {
            return false;
        }

        // Replace the leaf being split with the new internal node.
        if (split.parent == null) {
            this.root = node;
        } else if (split.parent.trueChild == split.leaf) {
            split.parent.trueChild = node;
        } else if (split.parent.falseChild == split.leaf) {
            split.parent.falseChild = node;
        } else {
            throw new IllegalStateException("split.parent and leaf don't match");
        }

        importance[node.feature] += node.score;
        trueSplit.ifPresent(s -> {s.parent = node; s.depth = split.depth + 1;});
        falseSplit.ifPresent(s -> {s.parent = node; s.depth = split.depth + 1;});

        if (queue == null) {
            // depth-first split
            trueSplit.ifPresent(s -> split(s, null));
            falseSplit.ifPresent(s -> split(s, null));
        } else {
            // best-first split
            trueSplit.ifPresent(queue::add);
            falseSplit.ifPresent(queue::add);
        }

        return true;
    }

    /**
     * Finds the best attribute to split on for the samples of the current node.
     * Returns Optional.empty if a split doesn't exist to reduce the impurity.
     * As a side effect, columns for which no split is found are marked in
     * {@code unsplittable}.
     * @param node the leaf node to split.
     * @param lo the inclusive lower bound of the data partition in the reordered sample index array.
     * @param hi the exclusive upper bound of the data partition in the reordered sample index array.
     * @param unsplittable unsplittable[j] is true if the column j cannot be split further in the node.
     * @return the best split candidate.
     */
    protected Optional<Split> findBestSplit(LeafNode node, int lo, int hi, boolean[] unsplittable) {
        if (node.size() < 2 * nodeSize) {
            return Optional.empty(); // one child will have less than nodeSize samples.
        }

        final double impurity = impurity(node);
        if (impurity == 0.0) {
            return Optional.empty(); // all the samples in the node have the same response
        }

        // skip the unsplittable columns
        int p = schema.length();
        int[] columns = IntStream.range(0, p).filter(i -> !unsplittable[i]).toArray();

        // random forest: evaluate a random subset of mtry columns
        if (mtry < p) {
            MathEx.permutate(columns);
        }

        IntStream stream = Arrays.stream(columns).limit(mtry);
        Optional<Split> split = (mtry < p ? stream : stream.parallel()) // random forest is in parallel already
                .mapToObj(j -> {
                    Optional<Split> s = findBestSplit(node, j, impurity, lo, hi);
                    // Each task writes a distinct element j, so concurrent writes do not conflict.
                    if (s.isEmpty()) unsplittable[j] = true;
                    return s;
                })
                .filter(Optional::isPresent)
                .map(Optional::get)
                .max(Split.comparator);

        split.ifPresent(s -> s.unsplittable = unsplittable);
        return split;
    }

    /**
     * Returns the impurity of node.
     * @param node the node to calculate the impurity.
     * @return the impurity of node.
     */
    protected abstract double impurity(LeafNode node);

    /**
     * Creates a new leaf node.
     * @param nodeSamples the samples belonging to this node.
     * @return the new leaf node.
     */
    protected abstract LeafNode newNode(int[] nodeSamples);

    /**
     * Finds the best split for given column.
     * @param node the node to split.
     * @param column the column to split on.
     * @param impurity the impurity of node.
     * @param lo the lower bound of sample index in the node.
     * @param hi the upper bound of sample index in the node.
     * @return the best split.
     */
    protected abstract Optional<Split> findBestSplit(LeafNode node, int column, double impurity, int lo, int hi);

    /**
     * Returns the variable importance. Every time a split of a node is made
     * on variable the (GINI, information gain, etc.) impurity criterion for
     * the two descendant nodes is less than the parent node. Adding up the
     * decreases for each individual variable over the tree gives a simple
     * measure of variable importance.
     *
     * @return the variable importance
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Returns the root node.
     * @return root node.
     */
    public Node root() {
        return root;
    }

    /**
     * Returns the graphic representation in Graphviz dot format.
     * Try http://viz-js.com/
     * to visualize the returned string.
     * @return the graphic representation in Graphviz dot format.
     */
    public String dot() {
        StringBuilder builder = new StringBuilder();
        builder.append("digraph CART {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n");

        String trueLabel  = " [labeldistance=2.5, labelangle=45, headlabel=\"True\"];\n";
        String falseLabel = " [labeldistance=2.5, labelangle=-45, headlabel=\"False\"];\n";

        // Breadth-first traversal. Nodes are numbered like a binary heap: the root is 1,
        // and the children of node id are 2*id (true branch) and 2*id+1 (false branch).
        // NOTE(review): these int ids overflow for nodes more than 30 levels below the root.
        Queue<SimpleEntry<Integer, Node>> queue = new ArrayDeque<>();
        queue.add(new SimpleEntry<>(1, root));

        while (!queue.isEmpty()) {
            SimpleEntry<Integer, Node> entry = queue.poll();
            int id = entry.getKey();
            Node node = entry.getValue();

            // vertex (internal or leaf node)
            builder.append(node.dot(schema, response, id));

            if (node instanceof InternalNode inode) {
                int tid = 2 * id;
                int fid = 2 * id + 1;
                queue.add(new SimpleEntry<>(tid, inode.trueChild));
                queue.add(new SimpleEntry<>(fid, inode.falseChild));

                // add edge
                builder.append(' ').append(id).append(" -> ").append(tid).append(trueLabel);
                builder.append(' ').append(id).append(" -> ").append(fid).append(falseLabel);

                // only draw edge label at top
                if (id == 1) {
                    trueLabel = "\n";
                    falseLabel = "\n";
                }
            }
        }

        builder.append("}");
        return builder.toString();
    }

    /**
     * Shuffles {@link #index} and {@link #order} by partitioning the range
     * from low (inclusive) to high (exclusive) so that all elements i for
     * which predicate(i) is true come before all elements for which it is
     * false, but element ordering is otherwise preserved. The number of
     * true values returned by predicate must equal {@code split - low}.
     *
     * @param low the low bound of the segment of the order arrays
     *            which will be partitioned.
     * @param split where the partition's split point will end up.
     * @param high the high bound of the segment of the order arrays
     *             which will be partitioned.
     * @param predicate whether an element goes to the left side or
     *                  the right side of the partition.
     */
    private void shuffle(int low, int split, int high, boolean[] predicate) {
        Arrays.stream(order).filter(Objects::nonNull).forEach(o -> shuffle(o, low, split, high, predicate));
        shuffle(index, low, split, high, predicate);
    }

    /**
     * Shuffles an array in-place by partitioning the range from
     * low (inclusive) to high (exclusive) so that all elements i
     * for which predicate(i) is true come before all elements for
     * which it is false, but element ordering is otherwise preserved.
     * The number of true values returned by predicate must equal
     * {@code split - low}. {@link #buffer} is the scratch space
     * (at least {@code high - split} long) that holds all elements
     * for which predicate is false.
     *
     * @param a the array to shuffle.
     * @param low the low bound of the segment of the order arrays
     *            which will be partitioned.
     * @param split where the partition's split point will end up.
     * @param high the high bound of the segment of the order arrays
     *             which will be partitioned.
     * @param predicate whether an element goes to the left side or
     *                  the right side of the partition.
     */
    private void shuffle(int[] a, int low, int split, int high, boolean[] predicate) {
        int k = 0;
        for (int i = low, j = low; i < high; i++) {
            if (predicate[a[i]]) {
                a[j++] = a[i];
            } else {
                buffer[k++] = a[i];
            }
        }

        assert(split + k == high);
        System.arraycopy(buffer, 0, a, split, k);
    }

    /**
     * Returns a text representation of the tree in R's rpart format.
     * A semi-graphical layout of the tree. Indentation is used to convey
     * the tree topology. Information for each node includes the node number,
     * split, size, deviance, and fitted value. For the decision tree,
     * the class probabilities are also printed.
     */
    @Override
    public String toString() {
        // Build up the lines in reverse order:
        // the false-child-first postorder turns into
        // the true-child-first preorder, which is what's needed.
        List<String> lines = new ArrayList<>();
        root.toString(schema, response, null, 0, BigInteger.ONE, lines);
        lines.add("* denotes terminal node");
        lines.add("node), split, n, loss, yval, (yprob)");
        lines.add("n=" + root.size());
        Collections.reverse(lines);
        return String.join("\n", lines);
    }

    /**
     * Returns the average of absolute SHAP values over a data frame.
     * @param data the data.
     * @return the average of absolute SHAP values.
     */
    public double[] shap(DataFrame data) {
        // Binds the formula to the data frame's schema in case that
        // it is different from that of training data. Without a formula
        // the rows are used as the predictors as-is (see predictors()).
        if (formula != null) {
            formula.bind(data.schema());
        }
        return shap(data.stream().parallel());
    }

    /**
     * Returns the SHAP values of a single instance, computed with the
     * polynomial-time TreeSHAP path algorithm. For classification trees the
     * result is laid out as {@code phi[feature * k + class]}, where k is the
     * number of classes. For regression trees k = 1.
     */
    @Override
    public double[] shap(Tuple x) {
        // Determine the number of outputs per feature from any leaf.
        int k = 1;
        Node node = root;
        while (node instanceof InternalNode inode) {
            node = inode.trueChild;
        }
        if (node instanceof DecisionNode dnode) {
            k = dnode.count().length;
        }

        int p = schema.length();
        double[] phi = new double[p * k];
        Path m = new Path(new int[0], new double[0], new double[0], new double[0]);
        recurse(phi, predictors(x), root, m, 1, 1, -1);
        return phi;
    }

    /**
     * Recursively keep track of what proportion of all possible subsets
     * flow down into each of the leaves of the tree.
     *
     * @param phi the SHAP value accumulator.
     * @param x the predictors of the instance.
     * @param node the current node.
     * @param m the unique feature path leading to the parent of node.
     * @param pz the fraction of zero paths flowing through this edge.
     * @param po the fraction of one paths flowing through this edge.
     * @param pi the feature split on by the parent, or -1 at the root.
     */
    private void recurse(double[] phi, Tuple x, Node node, Path m, double pz, double po, int pi) {
        int l = m.length;
        // extend() returns a copy, so the caller's path is never modified.
        m = m.extend(pz, po, pi);

        if (node instanceof InternalNode split) {
            int dj = split.feature();
            // h is the branch the instance follows ("hot"), c is the other ("cold").
            Node h, c;
            if (split.branch(x)) {
                h = split.trueChild();
                c = split.falseChild();
            } else {
                h = split.falseChild();
                c = split.trueChild();
            }

            int rh = h.size();
            int rc = c.size();
            int rj = node.size();

            // If the feature already appears on the path, undo that extension so
            // each feature occurs at most once, carrying over its fractions.
            int k = 0;
            for (; k <= l; k++) {
                if (m.d[k] == dj) break;
            }

            double iz = 1.0;
            double io = 1.0;
            if (k <= l) {
                iz = m.z[k];
                io = m.o[k];
                m.unwind(k);
            }

            recurse(phi, x, h, m, iz * rh / rj, io, dj);
            recurse(phi, x, c, m, iz * rc / rj, 0, dj);
        } else {
            // Index 0 of the path is the placeholder root entry (feature -1), so start at 1.
            if (node instanceof DecisionNode leaf) {
                int k = leaf.count().length;
                double[] prob = new double[k];
                leaf.posteriori(prob);
                for (int i = 1; i <= l; i++) {
                    double w = m.unwoundSum(i) * (m.o[i] - m.z[i]);
                    int di = m.d[i] * k;
                    for (int j = 0; j < k; j++) {
                        phi[di + j] += w * prob[j];
                    }
                }
            } else {
                double vj = ((RegressionNode) node).output();
                for (int i = 1; i <= l; i++) {
                    double w = m.unwoundSum(i);
                    phi[m.d[i]] += w * (m.o[i] - m.z[i]) * vj;
                }
            }
        }
    }

    /**
     * The path of unique features we have split
     * on so far during SHAP recursive traverse.
     */
    private static class Path {
        /** The length of path. May be shorter than the backing arrays after unwind(). */
        int length;
        /** The unique feature index. */
        final int[] d;
        /**
         * The fraction of zero paths (where this feature is not
         * in the non-zero index set S) that flow through this path.
         */
        final double[] z;
        /**
         * The fraction of one path (where this feature is
         * in the non-zero index set S) that flow through this path.
         */
        final double[] o;
        /**
         * The proportion of sets of a given cardinality that are present.
         */
        final double[] w;

        /**
         * Constructor.
         */
        Path(int[] d, double[] z, double[] o, double[] w) {
            this.length = d.length;
            this.d = d;
            this.z = z;
            this.o = o;
            this.w = w;
        }

        /**
         * To keep track of each possible subset size during the recursion,
         * grows all these subsets according to a given fraction of ones and
         * zeros. Returns a new path; this path is left unchanged.
         */
        Path extend(double pz, double po, int pi) {
            int l = length;
            Path m = new Path(
                    // Arrays.copyOf will truncate or pad with zeros.
                    Arrays.copyOf(d, l + 1),
                    Arrays.copyOf(z, l + 1),
                    Arrays.copyOf(o, l + 1),
                    Arrays.copyOf(w, l + 1)
            );

            m.d[l] = pi;
            m.z[l] = pz;
            m.o[l] = po;
            m.w[l] = l == 0 ? 1 : 0;

            for (int i = l-1; i >= 0; i--) {
                m.w[i+1] += po * m.w[i] * (i+1) / (l+1);
                m.w[i] = pz * m.w[i] * (l - i) / (l+1);
            }

            return m;
        }

        /**
         * Undo previous extensions when we split on the same feature twice,
         * and undo each extension of the path inside a leaf to compute
         * weights for each feature in the path. Modifies this path in place.
         */
        void unwind(int i) {
            double po = o[i];
            double pz = z[i];
            int l = --length;

            double n = w[l];
            if (po != 0) {
                for (int j = l - 1; j >= 0; j--) {
                    double t = w[j];
                    w[j] = n * (l + 1) / ((j + 1) * po);
                    n = t - w[j] * pz * (l - j) / (l + 1);
                }
            } else {
                for (int j = l - 1; j >= 0; j--) {
                    w[j] = (w[j] * (l+1)) / (pz * (l - j));
                }
            }

            for (int j = i; j < l; j++) {
                d[j] = d[j+1];
                z[j] = z[j+1];
                o[j] = o[j+1];
            }
        }

        /**
         * Return the total permutation weight if we unwind a previous
         * extension in the decision path. Does not modify this path.
         */
        double unwoundSum(int i) {
            double po = o[i];
            double pz = z[i];
            int l = length - 1;
            double sum = 0.0;

            double n = w[l];
            if (po != 0) {
                for (int j = l - 1; j >= 0; j--) {
                    double t = n / ((j+1) * po);
                    sum += t;
                    n = w[j] - t * pz * (l - j);
                }
            } else if (pz != 0) {
                for (int j = l - 1; j >= 0; j--) {
                    sum += w[j] / (pz * (l - j));
                }
            }
            // If both fractions are zero, no subset flows through this element:
            // contribute 0 rather than dividing by zero (which could yield NaN).

            return sum * (l + 1);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy