smile.regression.RegressionTree
/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.regression;

import smile.data.Attribute;
import smile.data.AttributeDataset;
import smile.math.Math;
import smile.data.NominalAttribute;
import smile.data.NumericAttribute;
import smile.sort.QuickSort;
import smile.util.MulticoreExecutor;

import java.io.Serializable;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.stream.IntStream;

/**
 * Decision tree for regression. A decision tree can be learned by
 * splitting the training set into subsets based on an attribute value
 * test. This process is repeated on each derived subset in a recursive
 * manner called recursive partitioning.
 * <p>
 * Classification and Regression Tree techniques have a number of advantages
 * over many of those alternative techniques.
 * <dl>
 * <dt>Simple to understand and interpret.</dt>
 * <dd>In most cases, the interpretation of results summarized in a tree is
 * very simple. This simplicity is useful not only for purposes of rapid
 * classification of new observations, but can also often yield a much simpler
 * "model" for explaining why observations are classified or predicted in a
 * particular manner.</dd>
 * <dt>Able to handle both numerical and categorical data.</dt>
 * <dd>Other techniques are usually specialized in analyzing datasets that
 * have only one type of variable.</dd>
 * <dt>Tree methods are nonparametric and nonlinear.</dt>
 * <dd>The final results of using tree methods for classification or regression
 * can be summarized in a series of (usually few) logical if-then conditions
 * (tree nodes). Therefore, there is no implicit assumption that the underlying
 * relationships between the predictor variables and the dependent variable
 * are linear, follow some specific non-linear link function, or that they
 * are even monotonic in nature. Thus, tree methods are particularly well
 * suited for data mining tasks, where there is often little a priori
 * knowledge nor any coherent set of theories or predictions regarding which
 * variables are related and how. In those types of data analytics, tree
 * methods can often reveal simple relationships between just a few variables
 * that could have easily gone unnoticed using other analytic techniques.</dd>
 * </dl>
 * One major problem with classification and regression trees is their high
 * variance. Often a small change in the data can result in a very different
 * series of splits, making interpretation somewhat precarious. Besides,
 * decision-tree learners can create over-complex trees that cause over-fitting.
 * Mechanisms such as pruning are necessary to avoid this problem.
 * Another limitation of trees is the lack of smoothness of the prediction
 * surface.
 * <p>
 * Some techniques such as bagging, boosting, and random forest use more than
 * one decision tree for their analysis.
 *
 * @author Haifeng Li
 * @see GradientTreeBoost
 * @see RandomForest
 */
public class RegressionTree implements Regression<double[]> {
    private static final long serialVersionUID = 1L;

    /**
     * The attributes of independent variable.
     */
    private Attribute[] attributes;
    /**
     * Variable importance. Every time a split of a node is made on a variable,
     * the impurity criterion for the two descendant nodes is less than that of
     * the parent node. Adding up the decreases for each individual variable
     * over the tree gives a simple measure of variable importance.
     */
    private double[] importance;
    /**
     * Values in [-1, 1] that represent the monotonic regression coefficient
     * for each attribute. They can be used to force the model to keep a
     * monotonic relationship between the target and an attribute.
     * A positive value pushes the target to be positively correlated with the
     * feature, a negative value pushes it to be negatively correlated, and
     * zero turns monotonic regression off for that feature.
     */
    private double[] monotonicRegression;
    /**
     * The root of the regression tree.
     */
    private Node root;
    /**
     * The number of instances in a node below which the tree will
     * not split. Setting nodeSize = 5 generally gives good results.
     */
    private int nodeSize = 5;
    /**
     * The maximum number of leaf nodes in the tree.
     */
    private int maxNodes = 6;
    /**
     * The number of input variables to be used to determine the decision
     * at a node of the tree.
     */
    private int mtry;
    /**
     * The number of binary features.
     */
    private int numFeatures;
    /**
     * The index of training values in ascending order. Note that only numeric
     * attributes will be sorted.
     */
    private transient int[][] order;

    /**
     * Trainer for regression tree.
     */
    public static class Trainer extends RegressionTrainer<double[]> {
        /**
         * The minimum size of leaf nodes.
         */
        private int nodeSize = 1;
        /**
         * The maximum number of leaf nodes in the tree.
         */
        private int maxNodes = 100;
        /**
         * The number of sparse binary features.
         */
        private int numFeatures = -1;

        /**
         * Constructor.
         *
         * @param maxNodes the maximum number of leaf nodes in the tree.
         */
        public Trainer(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }

            this.maxNodes = maxNodes;
        }

        /**
         * Constructor.
         *
         * @param attributes the attributes of independent variable.
         * @param maxNodes the maximum number of leaf nodes in the tree.
         */
        public Trainer(Attribute[] attributes, int maxNodes) {
            super(attributes);

            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }

            this.maxNodes = maxNodes;
        }

        /**
         * Constructor.
         *
         * @param numFeatures the number of sparse binary features.
         * @param maxNodes the maximum number of leaf nodes in the tree.
         */
        public Trainer(int numFeatures, int maxNodes) {
            if (numFeatures <= 0) {
                throw new IllegalArgumentException("Invalid number of sparse binary features: " + numFeatures);
            }

            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }

            this.numFeatures = numFeatures;
            this.maxNodes = maxNodes;
        }

        /**
         * Sets the maximum number of leaf nodes in the tree.
         *
         * @param maxNodes the maximum number of leaf nodes in the tree.
         */
        public Trainer setMaxNodes(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }

            this.maxNodes = maxNodes;
            return this;
        }

        /**
         * Sets the minimum size of leaf nodes.
         *
         * @param nodeSize the minimum size of leaf nodes.
         */
        public Trainer setNodeSize(int nodeSize) {
            if (nodeSize < 2) {
                throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
            }

            this.nodeSize = nodeSize;
            return this;
        }

        @Override
        public RegressionTree train(double[][] x, double[] y) {
            return new RegressionTree(attributes, x, y, maxNodes, nodeSize);
        }

        public RegressionTree train(int[][] x, double[] y) {
            if (numFeatures <= 0) {
                return new RegressionTree(Math.max(x) + 1, x, y, maxNodes, nodeSize);
            } else {
                return new RegressionTree(numFeatures, x, y, maxNodes, nodeSize);
            }
        }
    }
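    /*
     * Usage sketch (illustrative only, not part of the library API): train a
     * tree through the Trainer and predict on a new instance. The data arrays
     * here are hypothetical placeholders; with no attributes given, all
     * columns are treated as numeric.
     *
     *   double[][] x = ...;  // training instances
     *   double[] y = ...;    // responses
     *   RegressionTree tree = new RegressionTree.Trainer(100)
     *           .setNodeSize(5)
     *           .train(x, y);
     *   double yhat = tree.predict(x[0]);
     */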
    /**
     * An interface to calculate node output. Note that samples[i] is the
     * number of samplings of dataset[i]. 0 means that the datum is not
     * included and values greater than 1 are possible because of
     * sampling with replacement.
     */
    public interface NodeOutput {
        /**
         * Calculate the node output.
         *
         * @param samples the samples in the node.
         * @return the node output
         */
        public double calculate(int[] samples);
    }
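    /*
     * Sketch of a NodeOutput implementation (illustrative; gradient tree
     * boosting supplies its own functor). This one recomputes a node value as
     * the sample-weighted mean of a hypothetical response array y, honoring
     * the samples[i] multiplicity convention documented above.
     *
     *   double[] y = ...;  // responses aligned with the training data
     *   NodeOutput weightedMean = samples -> {
     *       double sum = 0.0;
     *       int count = 0;
     *       for (int i = 0; i < samples.length; i++) {
     *           sum += samples[i] * y[i];  // weight each datum by its multiplicity
     *           count += samples[i];
     *       }
     *       return sum / count;
     *   };
     */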
    /**
     * Regression tree node.
     */
    class Node implements Serializable {

        /**
         * Predicted real value for this node.
         */
        double output = 0.0;
        /**
         * The split feature for this node.
         */
        int splitFeature = -1;
        /**
         * The split value.
         */
        double splitValue = Double.NaN;
        /**
         * The split gain.
         */
        double gain = 0.0;
        /**
         * Reduction in squared error compared to parent.
         */
        double splitScore = 0.0;
        /**
         * The child node that passes the test.
         */
        Node trueChild;
        /**
         * The child node that fails the test.
         */
        Node falseChild;
        /**
         * Predicted output for the true child node.
         */
        double trueChildOutput = 0.0;
        /**
         * Predicted output for the false child node.
         */
        double falseChildOutput = 0.0;

        /**
         * Constructor.
         */
        public Node(double output) {
            this.output = output;
        }

        /**
         * Evaluate the regression tree over an instance.
         */
        public double predict(double[] x) {
            if (trueChild == null && falseChild == null) {
                return output;
            } else {
                if (attributes[splitFeature].getType() == Attribute.Type.NOMINAL) {
                    if (Math.equals(x[splitFeature], splitValue)) {
                        return trueChild.predict(x);
                    } else {
                        return falseChild.predict(x);
                    }
                } else if (attributes[splitFeature].getType() == Attribute.Type.NUMERIC) {
                    if (x[splitFeature] <= splitValue) {
                        return trueChild.predict(x);
                    } else {
                        return falseChild.predict(x);
                    }
                } else {
                    throw new IllegalStateException("Unsupported attribute type: " + attributes[splitFeature].getType());
                }
            }
        }

        /**
         * Evaluate the regression tree over an instance with sparse binary features.
         */
        public double predict(int[] x) {
            if (trueChild == null && falseChild == null) {
                return output;
            } else if (x[splitFeature] == (int) splitValue) {
                return trueChild.predict(x);
            } else {
                return falseChild.predict(x);
            }
        }
    }

    /**
     * Regression tree node for training purpose.
     */
    class TrainNode implements Comparable<TrainNode> {
        /**
         * The associated regression tree node.
         */
        Node node;
        /**
         * Child node that passes the test.
         */
        TrainNode trueChild;
        /**
         * Child node that fails the test.
         */
        TrainNode falseChild;
        /**
         * Training dataset.
         */
        double[][] x;
        /**
         * Training data response value.
         */
        double[] y;
        /**
         * The samples for training this node. Note that samples[i] is the
         * number of samplings of dataset[i]. 0 means that the datum is not
         * included and values greater than 1 are possible because of
         * sampling with replacement.
         */
        int[] samples;

        /**
         * Constructor.
         */
        public TrainNode(Node node, double[][] x, double[] y, int[] samples) {
            this.node = node;
            this.x = x;
            this.y = y;
            this.samples = samples;
        }

        @Override
        public int compareTo(TrainNode a) {
            // Inverted so that the priority queue polls larger split scores first.
            return (int) Math.signum(a.node.splitScore - node.splitScore);
        }

        /**
         * Calculate the node output for leaves.
         *
         * @param output the output calculate functor.
         */
        public void calculateOutput(NodeOutput output) {
            if (node.trueChild == null && node.falseChild == null) {
                node.output = output.calculate(samples);
            } else {
                if (trueChild != null) {
                    trueChild.calculateOutput(output);
                }
                if (falseChild != null) {
                    falseChild.calculateOutput(output);
                }
            }
        }

        /**
         * Finds the best attribute to split on at the current node. Returns
         * true if a split exists to reduce squared error, false otherwise.
         */
        public boolean findBestSplit() {
            int n = 0;
            for (int s : samples) {
                n += s;
            }

            if (n <= nodeSize) {
                return false;
            }

            double sum = node.output * n;
            int p = attributes.length;
            int[] variables = IntStream.range(0, attributes.length).toArray();

            // Loop through features and compute the reduction of squared error,
            // which is trueCount * trueMean^2 + falseCount * falseMean^2 - count * parentMean^2
            if (mtry < p) {
                Math.permutate(variables);

                // Random forest already runs in parallel, so search sequentially here.
                for (int j = 0; j < mtry; j++) {
                    Node split = findBestSplit(n, sum, variables[j]);
                    if (split.splitScore > node.splitScore) {
                        node.splitFeature = split.splitFeature;
                        node.splitValue = split.splitValue;
                        node.splitScore = split.splitScore;
                        node.gain = split.gain;
                        node.trueChildOutput = split.trueChildOutput;
                        node.falseChildOutput = split.falseChildOutput;
                    }
                }
            } else {
                List<SplitTask> tasks = new ArrayList<>(mtry);
                for (int j = 0; j < mtry; j++) {
                    tasks.add(new SplitTask(n, sum, variables[j]));
                }

                try {
                    for (Node split : MulticoreExecutor.run(tasks)) {
                        if (split.splitScore > node.splitScore) {
                            node.splitFeature = split.splitFeature;
                            node.splitValue = split.splitValue;
                            node.splitScore = split.splitScore;
                            node.gain = split.gain;
                            node.trueChildOutput = split.trueChildOutput;
                            node.falseChildOutput = split.falseChildOutput;
                        }
                    }
                } catch (Exception ex) {
                    // Fall back to the sequential search.
                    for (int j = 0; j < mtry; j++) {
                        Node split = findBestSplit(n, sum, variables[j]);
                        if (split.splitScore > node.splitScore) {
                            node.splitFeature = split.splitFeature;
                            node.splitValue = split.splitValue;
                            node.splitScore = split.splitScore;
                            node.gain = split.gain;
                            node.trueChildOutput = split.trueChildOutput;
                            node.falseChildOutput = split.falseChildOutput;
                        }
                    }
                }
            }

            return (node.splitFeature != -1);
        }
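        /*
         * Worked example of the reduction-in-SSE criterion used by the split
         * search below (hypothetical numbers): a node with n = 6 and sum = 12
         * (parent mean 2) considers a split sending 3 samples with responses
         * {3, 3, 3} to the true branch and {1, 1, 1} to the false branch. Then
         *
         *   gain = 3 * 3^2 + 3 * 1^2 - 6 * 2^2 = 27 + 3 - 24 = 6,
         *
         * which equals the drop in within-node squared error achieved by the
         * split. Note that findBestSplit(n, sum, j) subtracts n * split.output^2
         * with split.output initialized to 0; since that term is the same
         * constant for every candidate split of a node, the ranking of
         * candidates matches the SSE reduction exactly.
         */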
        /**
         * Task to find the best split cutoff for attribute j at the current node.
         */
        class SplitTask implements Callable<Node> {
            /**
             * The number of instances in this node.
             */
            int n;
            /**
             * The sum of responses of this node.
             */
            double sum;
            /**
             * The index of variables for this task.
             */
            int j;

            SplitTask(int n, double sum, int j) {
                this.n = n;
                this.sum = sum;
                this.j = j;
            }

            @Override
            public Node call() {
                return findBestSplit(n, sum, j);
            }
        }

        /**
         * Finds the best split cutoff for attribute j at the current node.
         *
         * @param n the number of instances in this node.
         * @param sum the sum of responses of this node.
         * @param j the attribute to split on.
         */
        public Node findBestSplit(int n, double sum, int j) {
            Node split = new Node(0.0);
            if (attributes[j].getType() == Attribute.Type.NOMINAL) {
                int m = ((NominalAttribute) attributes[j]).size();
                double[] trueSum = new double[m];
                int[] trueCount = new int[m];

                for (int i = 0; i < x.length; i++) {
                    if (samples[i] > 0) {
                        double target = samples[i] * y[i];

                        // For each true feature of this datum increment the
                        // sufficient statistics for the "true" branch to evaluate
                        // splitting on this feature.
                        int index = (int) x[i][j];
                        trueSum[index] += target;
                        trueCount[index] += samples[i];
                    }
                }

                for (int k = 0; k < m; k++) {
                    double tc = (double) trueCount[k];
                    double fc = n - tc;

                    // If either side is smaller than nodeSize, skip this candidate.
                    if (tc < nodeSize || fc < nodeSize) {
                        continue;
                    }

                    // compute penalized means
                    double trueMean = trueSum[k] / tc;
                    double falseMean = (sum - trueSum[k]) / fc;

                    double gain = (tc * trueMean * trueMean + fc * falseMean * falseMean) - n * split.output * split.output;
                    if (gain > split.splitScore) {
                        // new best split
                        split.splitFeature = j;
                        split.splitValue = k;
                        split.splitScore = gain;
                        split.trueChildOutput = trueMean;
                        split.falseChildOutput = falseMean;
                    }
                }
            } else if (attributes[j].getType() == Attribute.Type.NUMERIC) {
                double trueSum = 0.0;
                int trueCount = 0;
                double prevx = Double.NaN;

                for (int i : order[j]) {
                    if (samples[i] > 0) {
                        if (Double.isNaN(prevx) || x[i][j] == prevx) {
                            prevx = x[i][j];
                            trueSum += samples[i] * y[i];
                            trueCount += samples[i];
                            continue;
                        }

                        double falseCount = n - trueCount;

                        // If either side is smaller than nodeSize, skip this candidate.
                        if (trueCount < nodeSize || falseCount < nodeSize) {
                            prevx = x[i][j];
                            trueSum += samples[i] * y[i];
                            trueCount += samples[i];
                            continue;
                        }

                        // compute penalized means
                        double trueMean = trueSum / trueCount;
                        double falseMean = (sum - trueSum) / falseCount;

                        // The gain is the same reduction in squared error as in the
                        // nominal case; TrainNode.compareTo inverts the ordering so
                        // that the priority queue polls larger scores first.
                        double gain = (trueCount * trueMean * trueMean + falseCount * falseMean * falseMean) - n * split.output * split.output;
                        double score = gain;
                        double monoRegForFeature = monotonicRegression[j];

                        // The false child holds the larger values of the feature.
                        // Penalize the score when the split direction violates the
                        // requested monotonic relationship.
                        if (monoRegForFeature > 0) {
                            boolean isTargetDecreasing = trueMean > falseMean;
                            if (isTargetDecreasing) {
                                score *= 1 - Math.abs(monoRegForFeature);
                            }
                        } else if (monoRegForFeature < 0) {
                            boolean isTargetIncreasing = trueMean < falseMean;
                            if (isTargetIncreasing) {
                                score *= 1 - Math.abs(monoRegForFeature);
                            }
                        }
                        // monoRegForFeature == 0 - no monotonic regression

                        if (score > split.splitScore) {
                            // new best split
                            split.gain = gain;
                            split.splitFeature = j;
                            split.splitValue = (x[i][j] + prevx) / 2;
                            split.splitScore = score;
                            split.trueChildOutput = trueMean;
                            split.falseChildOutput = falseMean;
                        }

                        prevx = x[i][j];
                        trueSum += samples[i] * y[i];
                        trueCount += samples[i];
                    }
                }
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + attributes[j].getType());
            }

            return split;
        }
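        /*
         * Worked example of the monotonic penalty above (hypothetical
         * numbers): with monotonicRegression[j] = 0.8 the target is expected
         * to grow with feature j. A candidate split with trueMean = 5 and
         * falseMean = 3 predicts a decrease across the threshold (the false
         * child holds the larger feature values), so its score is shrunk to
         *
         *   score = gain * (1 - 0.8) = 0.2 * gain,
         *
         * making a direction-violating split much less likely to win, while a
         * coefficient of 1.0 would veto such splits entirely.
         */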
        /**
         * Split the node into two children nodes.
         */
        public void split(PriorityQueue<TrainNode> nextSplits) {
            if (nextSplits == null) {
                throw new IllegalArgumentException("nextSplits cannot be null");
            }

            if (node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }

            int n = x.length;
            int tc = 0;
            int fc = 0;
            int[] trueSamples = new int[n];
            int[] falseSamples = new int[n];

            if (attributes[node.splitFeature].getType() == Attribute.Type.NOMINAL) {
                for (int i = 0; i < n; i++) {
                    if (samples[i] > 0) {
                        if (Math.equals(x[i][node.splitFeature], node.splitValue)) {
                            trueSamples[i] = samples[i];
                            tc += trueSamples[i];
                            samples[i] = 0;
                        } else {
                            falseSamples[i] = samples[i];
                            fc += samples[i];
                        }
                    }
                }
            } else if (attributes[node.splitFeature].getType() == Attribute.Type.NUMERIC) {
                for (int i = 0; i < n; i++) {
                    if (samples[i] > 0) {
                        if (x[i][node.splitFeature] <= node.splitValue) {
                            trueSamples[i] = samples[i];
                            tc += trueSamples[i];
                            samples[i] = 0;
                        } else {
                            falseSamples[i] = samples[i];
                            fc += samples[i];
                        }
                    }
                }
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + attributes[node.splitFeature].getType());
            }

            if (tc < nodeSize || fc < nodeSize) {
                node.splitFeature = -1;
                node.splitValue = Double.NaN;
                node.splitScore = 0.0;
                node.gain = 0.0;
                return;
            }

            node.trueChild = new Node(node.trueChildOutput);
            node.falseChild = new Node(node.falseChildOutput);

            trueChild = new TrainNode(node.trueChild, x, y, trueSamples);
            if (tc > nodeSize && trueChild.findBestSplit()) {
                nextSplits.add(trueChild);
            }

            falseChild = new TrainNode(node.falseChild, x, y, falseSamples);
            if (fc > nodeSize && falseChild.findBestSplit()) {
                nextSplits.add(falseChild);
            }

            importance[node.splitFeature] += node.gain;
        }
    }
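    /*
     * Worked example of how TrainNode.split redistributes the multiplicity
     * array (hypothetical numbers): with samples = {1, 0, 2, 1} and rows 0
     * and 2 passing the split test, the children receive
     *
     *   trueSamples  = {1, 0, 2, 0}   (tc = 3)
     *   falseSamples = {0, 0, 0, 1}   (fc = 1)
     *
     * so each bootstrap multiplicity flows to exactly one child, and the
     * parent's samples entries for true-branch rows are zeroed in place.
     */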
    /**
     * Regression tree training node for sparse binary features.
     */
    class SparseBinaryTrainNode implements Comparable<SparseBinaryTrainNode> {
        /**
         * The associated regression tree node.
         */
        Node node;
        /**
         * Child node that passes the test.
         */
        SparseBinaryTrainNode trueChild;
        /**
         * Child node that fails the test.
         */
        SparseBinaryTrainNode falseChild;
        /**
         * Training dataset.
         */
        int[][] x;
        /**
         * Training data response value.
         */
        double[] y;
        /**
         * The samples for training this node. Note that samples[i] is the
         * number of samplings of dataset[i]. 0 means that the datum is not
         * included and values greater than 1 are possible because of
         * sampling with replacement.
         */
        int[] samples;

        /**
         * Constructor.
         */
        public SparseBinaryTrainNode(Node node, int[][] x, double[] y, int[] samples) {
            this.node = node;
            this.x = x;
            this.y = y;
            this.samples = samples;
        }

        @Override
        public int compareTo(SparseBinaryTrainNode a) {
            return (int) Math.signum(a.node.splitScore - node.splitScore);
        }

        /**
         * Finds the best attribute to split on at the current node. Returns
         * true if a split exists to reduce squared error, false otherwise.
         */
        public boolean findBestSplit() {
            if (node.trueChild != null || node.falseChild != null) {
                throw new IllegalStateException("Split non-leaf node.");
            }

            int p = numFeatures;
            double[] trueSum = new double[p];
            int[] trueCount = new int[p];
            int[] featureIndex = new int[p];
            int n = Math.sum(samples);
            double sumX = 0.0;

            for (int i = 0; i < x.length; i++) {
                if (samples[i] == 0) {
                    continue;
                }

                double target = samples[i] * y[i];
                sumX += y[i];

                // For each true feature of this datum increment the
                // sufficient statistics for the "true" branch to evaluate
                // splitting on this feature.
                for (int j = 0; j < x[i].length; ++j) {
                    int index = x[i][j];
                    trueSum[index] += target;
                    trueCount[index] += samples[i];
                    featureIndex[index] = j;
                }
            }

            // Loop through features and compute the reduction of squared error,
            // which is trueCount * trueMean^2 + falseCount * falseMean^2 - count * parentMean^2

            // Initialize the information in the leaf.
            node.splitScore = 0.0;
            node.splitFeature = -1;
            node.splitValue = -1;

            for (int i = 0; i < p; ++i) {
                double tc = (double) trueCount[i];
                double fc = n - tc;

                // If either side would have fewer than the minimum number of
                // data points, skip this feature.
                if (tc < nodeSize || fc < nodeSize) {
                    continue;
                }

                // compute penalized means
                double trueMean = trueSum[i] / tc;
                double falseMean = (sumX - trueSum[i]) / fc;

                double gain = (tc * trueMean * trueMean + fc * falseMean * falseMean) - n * node.output * node.output;
                if (gain > node.splitScore) {
                    // new best split
                    node.splitFeature = featureIndex[i];
                    node.splitValue = i;
                    node.splitScore = gain;
                    node.trueChildOutput = trueMean;
                    node.falseChildOutput = falseMean;
                }
            }

            return (node.splitFeature != -1);
        }

        /**
         * Split the node into two children nodes.
         */
        public void split(PriorityQueue<SparseBinaryTrainNode> nextSplits) {
            if (node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }

            if (node.trueChild != null || node.falseChild != null) {
                throw new IllegalStateException("Split non-leaf node.");
            }

            int n = x.length;
            int tc = 0;
            int fc = 0;
            int[] trueSamples = new int[n];
            int[] falseSamples = new int[n];

            for (int i = 0; i < n; i++) {
                if (samples[i] > 0) {
                    if (x[i][node.splitFeature] == (int) node.splitValue) {
                        trueSamples[i] = samples[i];
                        tc += trueSamples[i];
                        samples[i] = 0;
                    } else {
                        falseSamples[i] = samples[i];
                        fc += samples[i];
                    }
                }
            }

            node.trueChild = new Node(node.trueChildOutput);
            node.falseChild = new Node(node.falseChildOutput);

            trueChild = new SparseBinaryTrainNode(node.trueChild, x, y, trueSamples);
            if (tc > nodeSize && trueChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(trueChild);
                } else {
                    trueChild.split(null);
                }
            }

            falseChild = new SparseBinaryTrainNode(node.falseChild, x, y, falseSamples);
            if (fc > nodeSize && falseChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(falseChild);
                } else {
                    falseChild.split(null);
                }
            }

            importance[node.splitFeature] += node.gain;
        }

        /**
         * Calculate the node output for leaves.
         *
         * @param output the output calculate functor.
         */
        public void calculateOutput(NodeOutput output) {
            if (node.trueChild == null && node.falseChild == null) {
                node.output = output.calculate(samples);
            } else {
                if (trueChild != null) {
                    trueChild.calculateOutput(output);
                }
                if (falseChild != null) {
                    falseChild.calculateOutput(output);
                }
            }
        }
    }
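    /*
     * Note on the sparse binary encoding (illustrative): each row of the
     * int[][] x used by SparseBinaryTrainNode lists the indices of the
     * features that are "on" for that instance. For example, with
     * numFeatures = 10,
     *
     *   int[][] x = {
     *       {0, 3, 7},   // instance 0 has features 0, 3 and 7 set to 1
     *       {1, 3}       // instance 1 has features 1 and 3 set to 1
     *   };
     *
     * All other features are implicitly 0.
     */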
    /**
     * Constructor. Learns a regression tree with (at most) the given number
     * of leaves. All attributes are assumed to be numeric.
     *
     * @param x the training instances.
     * @param y the response variable.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     */
    public RegressionTree(double[][] x, double[] y, int maxNodes) {
        this(null, x, y, maxNodes, 5);
    }

    /**
     * Constructor. Learns a regression tree with (at most) the given number
     * of leaves. All attributes are assumed to be numeric.
     *
     * @param x the training instances.
     * @param y the response variable.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize the minimum size of leaf nodes.
     */
    public RegressionTree(double[][] x, double[] y, int maxNodes, int nodeSize) {
        this(null, x, y, maxNodes, nodeSize);
    }

    /**
     * Constructor. Learns a regression tree with (at most) the given number
     * of leaves.
     *
     * @param attributes the attribute properties.
     * @param x the training instances.
     * @param y the response variable.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     */
    public RegressionTree(Attribute[] attributes, double[][] x, double[] y, int maxNodes) {
        this(attributes, x, y, maxNodes, 5);
    }

    /**
     * Constructor. Learns a regression tree with (at most) the given number
     * of leaves.
     *
     * @param data the dataset.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     */
    public RegressionTree(AttributeDataset data, int maxNodes) {
        this(data.attributes(), data.x(), data.y(), maxNodes);
    }

    /**
     * Constructor. Learns a regression tree with (at most) the given number
     * of leaves.
     *
     * @param attributes the attribute properties.
     * @param x the training instances.
     * @param y the response variable.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize the minimum size of leaf nodes.
     */
    public RegressionTree(Attribute[] attributes, double[][] x, double[] y, int maxNodes, int nodeSize) {
        this(attributes, x, y, maxNodes, nodeSize, x[0].length, null, null, null);
    }

    /**
     * Constructor. Learns a regression tree with (at most) the given number
     * of leaves.
     *
     * @param data the dataset.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize the number of instances in a node below which the tree
     * will not split. Setting nodeSize = 5 generally gives good results.
     */
    public RegressionTree(AttributeDataset data, int maxNodes, int nodeSize) {
        this(data.attributes(), data.x(), data.y(), maxNodes, nodeSize);
    }

    /**
     * Constructor. Learns a regression tree for random forest and gradient
     * tree boosting.
     *
     * @param attributes the attribute properties.
     * @param x the training instances.
     * @param y the response variable.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize the number of instances in a node below which the tree
     * will not split. Setting nodeSize = 5 generally gives good results.
     * @param mtry the number of input variables to pick to split on at each
     * node. It seems that p/3 gives generally good performance, where p
     * is the number of variables.
     * @param order the index of training values in ascending order. Note
     * that only numeric attributes need be sorted.
     * @param samples the sample set of instances for stochastic learning.
     * samples[i] should be 0 or 1 to indicate if the instance is used for training.
     * @param output the functor to calculate the node output.
     */
    public RegressionTree(Attribute[] attributes, double[][] x, double[] y, int maxNodes, int nodeSize, int mtry, int[][] order, int[] samples, NodeOutput output) {
        this(attributes, x, y, maxNodes, nodeSize, mtry, order, samples, output, null);
    }
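    /*
     * Usage sketch with attribute metadata (illustrative; how the dataset is
     * obtained is left open): datasets loaded through smile's parsers carry
     * their Attribute descriptions, so nominal and numeric columns are
     * handled with the right split test.
     *
     *   AttributeDataset data = ...;  // e.g. loaded with a smile parser
     *   RegressionTree tree = new RegressionTree(data, 50, 5);
     *   double yhat = tree.predict(data.x()[0]);
     */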
    /**
     * Constructor. Learns a regression tree for random forest and gradient
     * tree boosting.
     *
     * @param data the dataset.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize the number of instances in a node below which the tree
     * will not split. Setting nodeSize = 5 generally gives good results.
     * @param mtry the number of input variables to pick to split on at each
     * node. It seems that p/3 gives generally good performance, where p
     * is the number of variables.
     * @param order the index of training values in ascending order. Note
     * that only numeric attributes need be sorted.
     * @param samples the sample set of instances for stochastic learning.
     * samples[i] should be 0 or 1 to indicate if the instance is used for training.
     * @param output the functor to calculate the node output.
     */
    public RegressionTree(AttributeDataset data, int maxNodes, int nodeSize, int mtry, int[][] order, int[] samples, NodeOutput output) {
        this(data.attributes(), data.x(), data.y(), maxNodes, nodeSize, mtry, order, samples, output);
    }

    /**
     * Constructor. Learns a regression tree with optional monotonic constraints.
     *
     * @param monotonicRegression the monotonic regression coefficient in
     * [-1, 1] for each attribute, or null to disable monotonic constraints.
     */
    public RegressionTree(AttributeDataset data, int maxNodes, int nodeSize, int mtry, int[][] order, int[] samples, NodeOutput output, double[] monotonicRegression) {
        this(data.attributes(), data.x(), data.y(), maxNodes, nodeSize, mtry, order, samples, output, monotonicRegression);
    }

    /**
     * Constructor. Learns a regression tree with optional monotonic constraints.
     *
     * @param monotonicRegression the monotonic regression coefficient in
     * [-1, 1] for each attribute, or null to disable monotonic constraints.
     */
    public RegressionTree(Attribute[] attributes, double[][] x, double[] y, int maxNodes, int nodeSize, int mtry, int[][] order, int[] samples, NodeOutput output, double[] monotonicRegression) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        if (mtry < 1 || mtry > x[0].length) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
        }

        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxNodes);
        }

        if (nodeSize < 2) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
        }

        if (attributes == null) {
            int p = x[0].length;
            attributes = new Attribute[p];
            for (int i = 0; i < p; i++) {
                attributes[i] = new NumericAttribute("V" + (i + 1));
            }
        }

        if (monotonicRegression == null) {
            // Initialized with zeros, the neutral monotonic regression value.
            monotonicRegression = new double[attributes.length];
        }

        this.attributes = attributes;
        this.monotonicRegression = monotonicRegression;
        this.maxNodes = maxNodes;
        this.nodeSize = nodeSize;
        this.mtry = mtry;
        importance = new double[attributes.length];

        if (order != null) {
            this.order = order;
        } else {
            int n = x.length;
            int p = x[0].length;

            double[] a = new double[n];
            this.order = new int[p][];

            for (int j = 0; j < p; j++) {
                if (attributes[j] instanceof NumericAttribute) {
                    for (int i = 0; i < n; i++) {
                        a[i] = x[i][j];
                    }
                    this.order[j] = QuickSort.sort(a);
                }
            }
        }

        int n = 0;
        double sum = 0.0;
        if (samples == null) {
            n = y.length;
            samples = new int[n];
            for (int i = 0; i < n; i++) {
                samples[i] = 1;
                sum += y[i];
            }
        } else {
            for (int i = 0; i < y.length; i++) {
                n += samples[i];
                sum += samples[i] * y[i];
            }
        }

        root = new Node(sum / n);

        TrainNode trainRoot = new TrainNode(root, x, y, samples);
        // Now add splits to the tree until max tree size is reached
        if (trainRoot.findBestSplit()) {
            // Priority queue for best-first tree growing.
            PriorityQueue<TrainNode> nextSplits = new PriorityQueue<>();
            nextSplits.add(trainRoot);

            // Pop the best leaf from the priority queue, split it, and push
            // its children into the queue if possible.
            for (int leaves = 1; leaves < this.maxNodes; leaves++) {
                // parent is the leaf to split
                TrainNode node = nextSplits.poll();
                if (node == null) {
                    break;
                }

                node.split(nextSplits); // Split the parent node into two children nodes
            }
        }

        if (output != null) {
            trainRoot.calculateOutput(output);
        }
    }
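    /*
     * Usage sketch for the monotonic constraints (illustrative; the arrays
     * are hypothetical): force the prediction to increase with feature 0,
     * leave feature 1 unconstrained, and discourage increases with feature 2.
     *
     *   double[] mono = {1.0, 0.0, -0.5};
     *   RegressionTree tree = new RegressionTree(
     *           attributes, x, y,
     *           100,            // maxNodes
     *           5,              // nodeSize
     *           x[0].length,    // mtry: consider all features at each node
     *           null, null, null,
     *           mono);
     */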
    /**
     * Constructor. Learns a regression tree on sparse binary samples.
     *
     * @param numFeatures the number of sparse binary features.
     * @param x the training instances of sparse binary features.
     * @param y the response variable.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     */
    public RegressionTree(int numFeatures, int[][] x, double[] y, int maxNodes) {
        this(numFeatures, x, y, maxNodes, 5);
    }

    /**
     * Constructor. Learns a regression tree on sparse binary samples.
     *
     * @param numFeatures the number of sparse binary features.
     * @param x the training instances of sparse binary features.
     * @param y the response variable.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize the number of instances in a node below which the tree
     * will not split. Setting nodeSize = 5 generally gives good results.
     */
    public RegressionTree(int numFeatures, int[][] x, double[] y, int maxNodes, int nodeSize) {
        this(numFeatures, x, y, maxNodes, nodeSize, null, null);
    }

    /**
     * Constructor. Learns a regression tree on sparse binary samples.
     *
     * @param numFeatures the number of sparse binary features.
     * @param x the training instances.
     * @param y the response variable.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize the number of instances in a node below which the tree
     * will not split. Setting nodeSize = 5 generally gives good results.
     * @param samples the sample set of instances for stochastic learning.
     * samples[i] should be 0 or 1 to indicate if the instance is used for training.
     * @param output the functor to calculate the node output.
     */
    public RegressionTree(int numFeatures, int[][] x, double[] y, int maxNodes, int nodeSize, int[] samples, NodeOutput output) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum number of leaves: " + maxNodes);
        }

        if (nodeSize < 2) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
        }

        this.maxNodes = maxNodes;
        this.nodeSize = nodeSize;
        this.numFeatures = numFeatures;
        this.mtry = numFeatures;
        importance = new double[numFeatures];

        // Priority queue for best-first tree growing.
        PriorityQueue<SparseBinaryTrainNode> nextSplits = new PriorityQueue<>();

        int n = 0;
        double sum = 0.0;
        if (samples == null) {
            n = y.length;
            samples = new int[n];
            for (int i = 0; i < n; i++) {
                samples[i] = 1;
                sum += y[i];
            }
        } else {
            for (int i = 0; i < y.length; i++) {
                n += samples[i];
                sum += samples[i] * y[i];
            }
        }

        root = new Node(sum / n);

        SparseBinaryTrainNode trainRoot = new SparseBinaryTrainNode(root, x, y, samples);
        // Now add splits to the tree until max tree size is reached
        if (trainRoot.findBestSplit()) {
            nextSplits.add(trainRoot);
        }

        // Pop the best leaf from the priority queue, split it, and push
        // its children into the queue if possible.
        for (int leaves = 1; leaves < this.maxNodes; leaves++) {
            // parent is the leaf to split
            SparseBinaryTrainNode node = nextSplits.poll();
            if (node == null) {
                break;
            }

            node.split(nextSplits); // Split the parent node into two children nodes
        }

        if (output != null) {
            trainRoot.calculateOutput(output);
        }
    }

    /**
     * Returns the variable importance. Every time a split of a node is made
     * on a variable, the impurity criterion for the two descendant nodes is
     * less than that of the parent node. Adding up the decreases for each
     * individual variable over the tree gives a simple measure of variable
     * importance.
     *
     * @return the variable importance
     */
    public double[] importance() {
        return importance;
    }

    @Override
    public double predict(double[] x) {
        return root.predict(x);
    }

    /**
     * Predicts the dependent variable of an instance with sparse binary features.
     *
     * @param x the instance.
     * @return the predicted value of dependent variable.
     */
    public double predict(int[] x) {
        return root.predict(x);
    }
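    /*
     * Usage sketch for the sparse binary constructor (illustrative; the toy
     * rows are hypothetical): each row lists the indices of the active
     * features, as in the encoding note above.
     *
     *   int[][] x = {{0, 2}, {1}, {0, 1, 3}, {2, 3}, {1, 2}};
     *   double[] y = {1.0, 0.5, 2.0, 1.5, 0.75};
     *   RegressionTree tree = new RegressionTree(4, x, y, 10, 2);
     *   double yhat = tree.predict(new int[]{0, 2});
     */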
    /**
     * Returns the maximum depth of the tree -- the number of
     * nodes along the longest path from the root node
     * down to the farthest leaf node.
     */
    public int maxDepth() {
        return maxDepth(root);
    }

    private int maxDepth(Node node) {
        if (node == null) {
            return 0;
        }

        // compute the depth of each subtree
        int lDepth = maxDepth(node.trueChild);
        int rDepth = maxDepth(node.falseChild);

        // use the larger one
        if (lDepth > rDepth) {
            return (lDepth + 1);
        } else {
            return (rDepth + 1);
        }
    }

    // For dot() tree traversal.
    private class DotNode {
        int parent;
        int id;
        Node node;

        DotNode(int parent, int id, Node node) {
            this.parent = parent;
            this.id = id;
            this.node = node;
        }
    }

    /**
     * Returns the graphic representation in Graphviz dot format.
     * Try http://viz-js.com/ to visualize the returned string.
     */
    public String dot() {
        StringBuilder builder = new StringBuilder();
        builder.append("digraph RegressionTree {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n");

        int n = 0; // number of nodes processed
        Queue<DotNode> queue = new LinkedList<>();
        queue.add(new DotNode(-1, 0, root));

        while (!queue.isEmpty()) {
            // Dequeue a vertex from queue and print it
            DotNode dnode = queue.poll();
            int id = dnode.id;
            int parent = dnode.parent;
            Node node = dnode.node;

            // leaf node
            if (node.trueChild == null && node.falseChild == null) {
                builder.append(String.format(" %d [label=<%.4f>, fillcolor=\"#00000000\", shape=ellipse];\n", id, node.output));
            } else {
                Attribute attr = attributes[node.splitFeature];
                if (attr.getType() == Attribute.Type.NOMINAL) {
                    builder.append(String.format(" %d [label=<%s = %s<br/>score = %.4f>, fillcolor=\"#00000000\"];\n", id, attr.getName(), attr.toString(node.splitValue), node.splitScore));
                } else if (attr.getType() == Attribute.Type.NUMERIC) {
                    builder.append(String.format(" %d [label=<%s &le; %.4f<br/>score = %.4f>, fillcolor=\"#00000000\"];\n", id, attr.getName(), node.splitValue, node.splitScore));
                } else {
                    throw new IllegalStateException("Unsupported attribute type: " + attr.getType());
                }
            }

            // add edge
            if (parent >= 0) {
                builder.append(' ').append(parent).append(" -> ").append(id);

                // only draw edge label at top
                if (parent == 0) {
                    if (id == 1) {
                        builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
                    } else {
                        builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
                    }
                }

                builder.append(";\n");
            }

            if (node.trueChild != null) {
                queue.add(new DotNode(id, ++n, node.trueChild));
            }

            if (node.falseChild != null) {
                queue.add(new DotNode(id, ++n, node.falseChild));
            }
        }

        builder.append("}");
        return builder.toString();
    }

    /**
     * Returns the root node.
     *
     * @return root node.
     */
    public Node getRoot() {
        return root;
    }
}
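/*
 * Usage sketch for dot() (illustrative): export the learned tree and render
 * it with Graphviz or an online viewer such as http://viz-js.com/.
 *
 *   RegressionTree tree = new RegressionTree(x, y, 20);
 *   String graph = tree.dot();
 *   java.nio.file.Files.write(
 *           java.nio.file.Paths.get("tree.dot"),
 *           graph.getBytes(java.nio.charset.StandardCharsets.UTF_8));
 *   // then render on the command line: dot -Tpng tree.dot -o tree.png
 */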




