/*
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file includes a modified version of Smile:
// https://github.com/haifengl/smile/blob/master/core/src/main/java/smile/classification/DecisionTree.java
package hivemall.smile.classification;
import static hivemall.smile.classification.PredictionHandler.Operator.EQ;
import static hivemall.smile.classification.PredictionHandler.Operator.GT;
import static hivemall.smile.classification.PredictionHandler.Operator.LE;
import static hivemall.smile.classification.PredictionHandler.Operator.NE;
import static hivemall.smile.utils.SmileExtUtils.NOMINAL;
import static hivemall.smile.utils.SmileExtUtils.NUMERIC;
import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName;
import static hivemall.smile.utils.SmileExtUtils.resolveName;
import hivemall.annotations.VisibleForTesting;
import matrix4j.matrix.Matrix;
import matrix4j.vector.DenseVector;
import matrix4j.vector.SparseVector;
import matrix4j.vector.Vector;
import matrix4j.vector.VectorProcedure;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.VariableOrder;
import hivemall.utils.collections.arrays.SparseIntArray;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.function.Consumer;
import hivemall.utils.function.IntPredicate;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.lang.ObjectUtils;
import hivemall.utils.lang.StringUtils;
import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import hivemall.utils.sampling.IntReservoirSampler;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import smile.classification.Classifier;
import smile.math.Math;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.roaringbitmap.IntConsumer;
import org.roaringbitmap.RoaringBitmap;
/**
* Decision tree for classification. A decision tree can be learned by splitting the training set
* into subsets based on an attribute value test. This process is repeated on each derived subset in
* a recursive manner called recursive partitioning. The recursion is completed when the subset at a
* node all has the same value of the target variable, or when splitting no longer adds value to the
* predictions.
*
* The algorithms that are used for constructing decision trees usually work top-down by choosing a
* variable at each step that is the next best variable to use in splitting the set of items. "Best"
* is defined by how well the variable splits the set into homogeneous subsets that have the same
* value of the target variable. Different algorithms use different formulae for measuring "best".
* Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen element
* from the set would be incorrectly labeled if it were randomly labeled according to the
* distribution of labels in the subset. Gini impurity can be computed by summing the probability of
* each item being chosen times the probability of a mistake in categorizing that item. It reaches
* its minimum (zero) when all cases in the node fall into a single target category. Information
* gain is another popular measure, used by the ID3, C4.5 and C5.0 algorithms. Information gain is
* based on the concept of entropy used in information theory. For categorical variables with
* different numbers of levels, however, information gain is biased in favor of attributes with
* more levels. Instead, one may employ the information gain ratio, which remedies this drawback
* of information gain.
*
* Classification and Regression Tree techniques have a number of advantages over many
* alternative techniques.
*
* - Simple to understand and interpret. In most cases, the interpretation of results summarized
* in a tree is very simple. This simplicity is useful not only for purposes of rapid
* classification of new observations, but can also often yield a much simpler "model" for
* explaining why observations are classified or predicted in a particular manner.
* - Able to handle both numerical and categorical data. Other techniques are usually specialized
* in analyzing datasets that have only one type of variable.
* - Tree methods are nonparametric and nonlinear. The final results of using tree methods for
* classification or regression can be summarized in a series of (usually few) logical if-then
* conditions (tree nodes). Therefore, there is no implicit assumption that the underlying
* relationships between the predictor variables and the dependent variable are linear, follow
* some specific non-linear link function, or that they are even monotonic in nature. Thus, tree
* methods are particularly well suited for data mining tasks, where there is often little a
* priori knowledge and no coherent set of theories or predictions regarding which variables are
* related and how. In those types of data analytics, tree methods can often reveal simple
* relationships between just a few variables that could have easily gone unnoticed using other
* analytic techniques.
*
* One major problem with classification and regression trees is their high variance. Often a small
* change in the data can result in a very different series of splits, making interpretation
* somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause
* over-fitting. Mechanisms such as pruning are necessary to avoid this problem. Another limitation
* of trees is the lack of smoothness of the prediction surface.
*
* Some techniques such as bagging, boosting, and random forest use more than one decision tree for
* their analysis.
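*
* A minimal usage sketch (the toy 4x2 data matrix, labels, and the RowMajorDenseMatrix2d
* construction below are illustrative assumptions, not part of this class):
*
* <pre>{@code
* double[][] data = { {1.0, 2.0}, {1.5, 1.8}, {5.0, 8.0}, {6.0, 9.0} };
* Matrix x = new RowMajorDenseMatrix2d(data, 2); // assumed matrix4j dense implementation
* int[] y = {0, 0, 1, 1};
* DecisionTree tree = new DecisionTree(null, x, y, 4); // no nominal attributes, at most 4 leaves
* int label = tree.predict(new DenseVector(new double[] {1.2, 1.9}));
* }</pre>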
*/
public class DecisionTree implements Classifier<Vector> {
private static final Log logger = LogFactory.getLog(DecisionTree.class);
/**
* Training dataset.
*/
@Nonnull
private final Matrix _X;
/**
* class labels.
*/
@Nonnull
private final int[] _y;
/**
* The samples for training this node. Note that samples[i] is the number of times dataset[i]
* was sampled. 0 means that the datum is not included, and values greater than 1 are possible
* because of sampling with replacement.
*/
@Nonnull
private final int[] _samples;
/**
* An index of training values. Initially, order[j] is a set of indices that iterate through the
* training values for attribute j in ascending order. During training, the array is rearranged
* so that all values for each leaf node occupy a contiguous range, but within that range they
* maintain the original ordering. Note that only numeric attributes will be sorted; non-numeric
* attributes will have a null in the corresponding place in the array.
*/
@Nonnull
private final VariableOrder _order;
/**
* An index that maps their current position in the {@link #_order} to their original locations
* in {@link #_samples}.
*/
@Nonnull
private final int[] _sampleIndex;
/**
* The attributes of independent variable.
*/
@Nonnull
private final RoaringBitmap _nominalAttrs;
/**
* Variable importance. Every time a split of a node is made on a variable, the (Gini,
* information gain, etc.) impurity criterion for the two descendant nodes is less than for the
* parent node. Adding up the decreases for each individual variable over the tree gives a
* simple measure of variable importance.
*/
@Nonnull
private final Vector _importance;
/**
* The root of the decision tree
*/
@Nonnull
private final Node _root;
/**
* The maximum depth of the tree
*/
private final int _maxDepth;
/**
* The splitting rule.
*/
@Nonnull
private final SplitRule _rule;
/**
* The number of classes.
*/
private final int _k;
/**
* The number of input variables to be used to determine the decision at a node of the tree.
*/
private final int _numVars;
/**
* The number of instances in a node below which the tree will not split.
*/
private final int _minSamplesSplit;
/**
* The minimum number of samples in a leaf node.
*/
private final int _minSamplesLeaf;
/**
* The random number generator.
*/
@Nonnull
private final PRNG _rnd;
/**
* The criterion to choose variable to split instances.
*/
public static enum SplitRule {
/**
* Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen
* element from the set would be incorrectly labeled if it were randomly labeled according
* to the distribution of labels in the subset. Gini impurity can be computed by summing the
* probability of each item being chosen times the probability of a mistake in categorizing
* that item. It reaches its minimum (zero) when all cases in the node fall into a single
* target category.
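* For example, a node with class proportions {0.5, 0.3, 0.2} has Gini impurity
* 1 - (0.5^2 + 0.3^2 + 0.2^2) = 0.62.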
*/
GINI,
/**
* Used by the ID3, C4.5 and C5.0 tree generation algorithms.
*/
ENTROPY,
/**
* Classification error.
*/
CLASSIFICATION_ERROR
}
/**
* Classification tree node.
*/
public static final class Node implements Externalizable {
/**
* Predicted class label for this node.
*/
int output = -1;
/**
* A posteriori probability based on sample ratios in this node.
*/
@Nullable
double[] posteriori = null;
/**
* The split feature for this node.
*/
int splitFeature = -1;
/**
* The type of split feature
*/
boolean quantitativeFeature = true;
/**
* The split value.
*/
double splitValue = Double.NaN;
/**
* Reduction in splitting criterion.
*/
double splitScore = 0.0;
/**
* Children node.
*/
Node trueChild = null;
/**
* Children node.
*/
Node falseChild = null;
public Node() {}// for Externalizable
public Node(@Nonnull double[] posteriori) {
this(Math.whichMax(posteriori), posteriori);
}
public Node(int output, @Nonnull double[] posteriori) {
this.output = output;
this.posteriori = posteriori;
}
private boolean isLeaf() {
return trueChild == null && falseChild == null;
}
private void markAsLeaf() {
this.splitFeature = -1;
this.splitValue = Double.NaN;
this.splitScore = 0.0;
this.trueChild = null;
this.falseChild = null;
}
@VisibleForTesting
public int predict(@Nonnull final double[] x) {
return predict(new DenseVector(x));
}
/**
* Evaluates the decision tree over an instance.
*/
public int predict(@Nonnull final Vector x) {
if (isLeaf()) {
return output;
} else {
if (quantitativeFeature) {
if (x.get(splitFeature, Double.NaN) <= splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
} else {
if (x.get(splitFeature, Double.NaN) == splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
}
}
}
/**
* Evaluates the decision tree over an instance, reporting each branch decision and the final
* leaf to the given handler.
*/
public void predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) {
if (isLeaf()) {
handler.visitLeaf(output, posteriori);
} else {
final double feature = x.get(splitFeature, Double.NaN);
if (quantitativeFeature) {
if (feature <= splitValue) {
handler.visitBranch(LE, splitFeature, feature, splitValue);
trueChild.predict(x, handler);
} else {
handler.visitBranch(GT, splitFeature, feature, splitValue);
falseChild.predict(x, handler);
}
} else {
if (feature == splitValue) {
handler.visitBranch(EQ, splitFeature, feature, splitValue);
trueChild.predict(x, handler);
} else {
handler.visitBranch(NE, splitFeature, feature, splitValue);
falseChild.predict(x, handler);
}
}
}
}
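/**
 * Exports the subtree rooted at this node as nested JavaScript if-else statements. For a
 * hypothetical numeric split on feature 3 at 2.5, the emitted condition reads
 * {@code if( x[3] <= 2.5 ) { ... } else { ... }}; nominal splits compare with {@code ==}.
 */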
public void exportJavascript(@Nonnull final StringBuilder builder,
@Nullable final String[] featureNames, @Nullable final String[] classNames,
final int depth) {
if (isLeaf()) {
indent(builder, depth);
builder.append("").append(resolveName(output, classNames)).append(";\n");
} else {
indent(builder, depth);
if (quantitativeFeature) {
if (featureNames == null) {
builder.append("if( x[")
.append(splitFeature)
.append("] <= ")
.append(splitValue)
.append(" ) {\n");
} else {
builder.append("if( ")
.append(resolveFeatureName(splitFeature, featureNames))
.append(" <= ")
.append(splitValue)
.append(" ) {\n");
}
} else {
if (featureNames == null) {
builder.append("if( x[")
.append(splitFeature)
.append("] == ")
.append(splitValue)
.append(" ) {\n");
} else {
builder.append("if( ")
.append(resolveFeatureName(splitFeature, featureNames))
.append(" == ")
.append(splitValue)
.append(" ) {\n");
}
}
trueChild.exportJavascript(builder, featureNames, classNames, depth + 1);
indent(builder, depth);
builder.append("} else {\n");
falseChild.exportJavascript(builder, featureNames, classNames, depth + 1);
indent(builder, depth);
builder.append("}\n");
}
}
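/**
 * Exports the subtree rooted at this node in Graphviz DOT body format: one node statement
 * per tree node and one edge statement per parent-child link, with "True"/"False" head
 * labels drawn only on the two edges leaving the root (parent node id 0).
 */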
public void exportGraphviz(@Nonnull final StringBuilder builder,
@Nullable final String[] featureNames, @Nullable final String[] classNames,
@Nonnull final String outputName, @Nullable final double[] colorBrew,
@Nonnull final MutableInt nodeIdGenerator, final int parentNodeId) {
final int myNodeId = nodeIdGenerator.getValue();
if (isLeaf()) {
// fillcolor=h,s,v
// https://en.wikipedia.org/wiki/HSL_and_HSV
// http://www.graphviz.org/doc/info/attrs.html#k:colorList
String hsvColor = (colorBrew == null || output >= colorBrew.length) ? "#00000000"
: String.format("%.4f,1.000,1.000", colorBrew[output]);
builder.append(
String.format(" %d [label=<%s = %s>, fillcolor=\"%s\", shape=ellipse];\n",
myNodeId, outputName, resolveName(output, classNames), hsvColor));
if (myNodeId != parentNodeId) {
builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId);
if (parentNodeId == 0) {
if (myNodeId == 1) {
builder.append(
" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
} else {
builder.append(
" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
}
}
builder.append(";\n");
}
} else {
if (quantitativeFeature) {
builder.append(
String.format(" %d [label=<%s ≤ %s>, fillcolor=\"#00000000\"];\n",
myNodeId, resolveFeatureName(splitFeature, featureNames),
Double.toString(splitValue)));
} else {
builder.append(
String.format(" %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n", myNodeId,
resolveFeatureName(splitFeature, featureNames),
Double.toString(splitValue)));
}
if (myNodeId != parentNodeId) {
builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId);
if (parentNodeId == 0) {//only draw edge label on top
if (myNodeId == 1) {
builder.append(
" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
} else {
builder.append(
" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
}
}
builder.append(";\n");
}
nodeIdGenerator.addValue(1);
trueChild.exportGraphviz(builder, featureNames, classNames, outputName, colorBrew,
nodeIdGenerator, myNodeId);
nodeIdGenerator.addValue(1);
falseChild.exportGraphviz(builder, featureNames, classNames, outputName, colorBrew,
nodeIdGenerator, myNodeId);
}
}
@Deprecated
public int opCodegen(@Nonnull final List<String> scripts, int depth) {
int selfDepth = 0;
final StringBuilder buf = new StringBuilder();
if (isLeaf()) {
buf.append("push ").append(output);
scripts.add(buf.toString());
buf.setLength(0);
buf.append("goto last");
scripts.add(buf.toString());
selfDepth += 2;
} else {
if (quantitativeFeature) {
buf.append("push ").append("x[").append(splitFeature).append("]");
scripts.add(buf.toString());
buf.setLength(0);
buf.append("push ").append(splitValue);
scripts.add(buf.toString());
buf.setLength(0);
buf.append("ifle ");
scripts.add(buf.toString());
depth += 3;
selfDepth += 3;
int trueDepth = trueChild.opCodegen(scripts, depth);
selfDepth += trueDepth;
scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth));
int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
selfDepth += falseDepth;
} else {
buf.append("push ").append("x[").append(splitFeature).append("]");
scripts.add(buf.toString());
buf.setLength(0);
buf.append("push ").append(splitValue);
scripts.add(buf.toString());
buf.setLength(0);
buf.append("ifeq ");
scripts.add(buf.toString());
depth += 3;
selfDepth += 3;
int trueDepth = trueChild.opCodegen(scripts, depth);
selfDepth += trueDepth;
scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth));
int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
selfDepth += falseDepth;
}
}
return selfDepth;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(splitFeature);
out.writeByte(quantitativeFeature ? NUMERIC : NOMINAL);
out.writeDouble(splitValue);
if (isLeaf()) {
out.writeBoolean(true);
out.writeInt(output);
out.writeInt(posteriori.length);
for (int i = 0; i < posteriori.length; i++) {
out.writeDouble(posteriori[i]);
}
} else {
out.writeBoolean(false);
if (trueChild == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
trueChild.writeExternal(out);
}
if (falseChild == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
falseChild.writeExternal(out);
}
}
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
this.splitFeature = in.readInt();
final byte typeId = in.readByte();
this.quantitativeFeature = (typeId == NUMERIC);
this.splitValue = in.readDouble();
if (in.readBoolean()) {//isLeaf
this.output = in.readInt();
final int size = in.readInt();
final double[] posteriori = new double[size];
for (int i = 0; i < size; i++) {
posteriori[i] = in.readDouble();
}
this.posteriori = posteriori;
} else {
if (in.readBoolean()) {
this.trueChild = new Node();
trueChild.readExternal(in);
}
if (in.readBoolean()) {
this.falseChild = new Node();
falseChild.readExternal(in);
}
}
}
}
private static void indent(final StringBuilder builder, final int depth) {
for (int i = 0; i < depth; i++) {
builder.append(" ");
}
}
/**
* Classification tree node for training purpose.
*/
private final class TrainNode implements Comparable<TrainNode> {
/**
* The associated decision tree node.
*/
@Nonnull
final Node node;
/**
* Depth of the node in the tree
*/
final int depth;
/**
* The lower bound (inclusive) in the order array of the samples belonging to this node.
*/
final int low;
/**
* The upper bound (exclusive) in the order array of the samples belonging to this node.
*/
final int high;
/**
* The number of samples
*/
final int samples;
@Nullable
int[] constFeatures;
public TrainNode(@Nonnull Node node, int depth, int low, int high, int samples) {
this(node, depth, low, high, samples, new int[0]);
}
public TrainNode(@Nonnull Node node, int depth, int low, int high, int samples,
@Nonnull int[] constFeatures) {
if (low >= high) {
throw new IllegalArgumentException(
"Unexpected condition was met. low=" + low + ", high=" + high);
}
this.node = node;
this.depth = depth;
this.low = low;
this.high = high;
this.samples = samples;
this.constFeatures = constFeatures;
}
@Override
public int compareTo(TrainNode a) {
return (int) Math.signum(a.node.splitScore - node.splitScore);
}
/**
* Finds the best attribute to split on at the current node.
*
* @return true if a split exists that reduces impurity, false otherwise.
*/
public boolean findBestSplit() {
// avoid split if tree depth is larger than threshold
if (depth >= _maxDepth) {
return false;
}
// avoid split if the number of samples is less than threshold
if (samples <= _minSamplesSplit) {
return false;
}
// Sample count in each class.
final int[] count = new int[_k];
final boolean pure = countSamples(count);
if (pure) {// if all instances have same label, stop splitting.
return false;
}
final int[] constFeatures_ = this.constFeatures; // this.constFeatures may be replaced in findBestSplit below, which is accepted
final double impurity = impurity(count, samples, _rule);
final int[] falseCount = new int[_k];
for (int varJ : variableIndex()) {
if (ArrayUtils.contains(constFeatures_, varJ)) {
continue; // skip constant features
}
final Node split = findBestSplit(samples, count, falseCount, impurity, varJ);
if (split.splitScore > node.splitScore) {
node.splitFeature = split.splitFeature;
node.quantitativeFeature = split.quantitativeFeature;
node.splitValue = split.splitValue;
node.splitScore = split.splitScore;
}
}
return node.splitFeature != -1;
}
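/**
 * Returns the candidate feature indices for splitting this node, drawn by reservoir
 * sampling so that at most _numVars features are returned. For a sparse matrix, only
 * columns that actually occur in this node's rows are candidates; for a dense matrix,
 * every column is a candidate.
 */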
@Nonnull
private int[] variableIndex() {
final Matrix X = _X;
final IntReservoirSampler sampler = new IntReservoirSampler(_numVars, _rnd.nextLong());
if (X.isSparse()) {
// sample columns from sampled examples
final RoaringBitmap cols = new RoaringBitmap();
final VectorProcedure proc = new VectorProcedure() {
public void apply(final int col) {
cols.add(col);
}
};
final int[] sampleIndex = _sampleIndex;
for (int i = low, end = high; i < end; i++) {
int row = sampleIndex[i];
assert (_samples[row] != 0) : row;
X.eachColumnIndexInRow(row, proc);
}
cols.forEach(new IntConsumer() {
public void accept(final int k) {
sampler.add(k);
}
});
} else {
final int ncols = X.numColumns();
for (int i = 0; i < ncols; i++) {
sampler.add(i);
}
}
return sampler.getSample();
}
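/**
 * Tallies the per-class sample counts of this node into {@code count} and returns true if
 * the node is pure, i.e., all of its samples share the same label.
 */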
private boolean countSamples(@Nonnull final int[] count) {
final int[] sampleIndex = _sampleIndex;
final int[] samples = _samples;
final int[] y = _y;
boolean pure = true;
for (int i = low, end = high, label = -1; i < end; i++) {
int index = sampleIndex[i];
int y_i = y[index];
count[y_i] += samples[index];
if (label == -1) {
label = y_i;
} else if (y_i != label) {
pure = false;
}
}
return pure;
}
/**
* Finds the best split cutoff for attribute j at the current node.
*
* @param n the number of instances in this node.
* @param count the sample count in each class.
* @param falseCount an array to store sample count in each class for false child node.
* @param impurity the impurity of this node.
* @param j the attribute index to split on.
*/
private Node findBestSplit(final int n, final int[] count, final int[] falseCount,
final double impurity, final int j) {
final int[] samples = _samples;
final int[] sampleIndex = _sampleIndex;
final Matrix X = _X;
final int[] y = _y;
final int classes = _k;
final Node splitNode = new Node();
if (_nominalAttrs.contains(j)) {
final Int2ObjectMap<int[]> trueCount = new Int2ObjectOpenHashMap<int[]>();
int countNaN = 0;
for (int i = low, end = high; i < end; i++) {
final int index = sampleIndex[i];
final int numSamples = samples[index];
if (numSamples == 0) {
continue;
}
final double v = X.get(index, j, Double.NaN);
if (Double.isNaN(v)) {
countNaN++;
continue;
}
int x_ij = (int) v;
int[] tc_x = trueCount.get(x_ij);
if (tc_x == null) {
tc_x = new int[classes];
trueCount.put(x_ij, tc_x);
}
int y_i = y[index];
tc_x[y_i] += numSamples;
}
final int countDistinctX = trueCount.size() + (countNaN == 0 ? 0 : 1);
if (countDistinctX <= 1) { // mark as a constant feature
this.constFeatures = ArrayUtils.sortedArraySet(constFeatures, j);
}
for (Int2ObjectMap.Entry<int[]> e : trueCount.int2ObjectEntrySet()) {
final int l = e.getIntKey();
final int[] trueCount_l = e.getValue();
final int tc = Math.sum(trueCount_l);
final int fc = n - tc;
// skip splitting this feature.
if (tc < _minSamplesSplit || fc < _minSamplesSplit) {
continue;
}
for (int k = 0; k < classes; k++) {
falseCount[k] = count[k] - trueCount_l[k];
}
final double gain =
impurity - (double) tc / n * impurity(trueCount_l, tc, _rule)
- (double) fc / n * impurity(falseCount, fc, _rule);
if (gain > splitNode.splitScore) {
// new best split
splitNode.splitFeature = j;
splitNode.quantitativeFeature = false;
splitNode.splitValue = l;
splitNode.splitScore = gain;
}
}
} else {
final int[] trueCount = new int[classes];
final MutableInt countNaN = new MutableInt(0);
final MutableInt replaceCount = new MutableInt(0);
_order.eachNonNullInColumn(j, low, high, new Consumer() {
double prevx = Double.NaN, lastx = Double.NaN;
int prevy = -1;
@Override
public void accept(int pos, final int i) {
final int numSamples = samples[i];
if (numSamples == 0) {
return;
}
final double x_ij = X.get(i, j, Double.NaN);
if (Double.isNaN(x_ij)) {
countNaN.incr();
return;
}
if (lastx != x_ij) {
lastx = x_ij;
replaceCount.incr();
}
final int y_i = y[i];
if (Double.isNaN(prevx) || x_ij == prevx || y_i == prevy) {
prevx = x_ij;
prevy = y_i;
trueCount[y_i] += numSamples;
return;
}
final int tc = Math.sum(trueCount);
final int fc = n - tc;
// skip splitting this feature.
if (tc < _minSamplesSplit || fc < _minSamplesSplit) {
prevx = x_ij;
prevy = y_i;
trueCount[y_i] += numSamples;
return;
}
for (int l = 0; l < classes; l++) {
falseCount[l] = count[l] - trueCount[l];
}
final double gain =
impurity - (double) tc / n * impurity(trueCount, tc, _rule)
- (double) fc / n * impurity(falseCount, fc, _rule);
if (gain > splitNode.splitScore) {
// new best split
splitNode.splitFeature = j;
splitNode.quantitativeFeature = true;
splitNode.splitValue = (x_ij + prevx) / 2.d;
splitNode.splitScore = gain;
}
prevx = x_ij;
prevy = y_i;
trueCount[y_i] += numSamples;
}//apply()
});
final int countDistinctX = replaceCount.get() + (countNaN.get() == 0 ? 0 : 1);
if (countDistinctX <= 1) { // mark as a constant feature
this.constFeatures = ArrayUtils.sortedArraySet(constFeatures, j);
}
}
return splitNode;
}
/**
* Splits the node into two children nodes. Returns true if the split succeeds.
*
* @return true if a split occurred; false if the node was turned into a leaf.
*/
public boolean split(@Nullable final PriorityQueue<TrainNode> nextSplits) {
if (node.splitFeature < 0) {
throw new IllegalStateException("Split a node with invalid feature.");
}
final IntPredicate goesLeft = getPredicate();
// split samples
final int tc, fc, pivot;
final double[] trueChildPosteriori = new double[_k],
falseChildPosteriori = new double[_k];
{
MutableInt tc_ = new MutableInt(0);
MutableInt fc_ = new MutableInt(0);
pivot = splitSamples(tc_, fc_, trueChildPosteriori, falseChildPosteriori, goesLeft);
tc = tc_.get();
fc = fc_.get();
}
if (tc < _minSamplesLeaf || fc < _minSamplesLeaf) {
node.markAsLeaf();
return false;
}
for (int i = 0; i < _k; i++) {
trueChildPosteriori[i] /= tc; // divide by zero never happens
falseChildPosteriori[i] /= fc;
}
partitionOrder(low, pivot, high, goesLeft);
int leaves = 0;
node.trueChild = new Node(trueChildPosteriori);
TrainNode trueChild =
new TrainNode(node.trueChild, depth + 1, low, pivot, tc, constFeatures.clone());
node.falseChild = new Node(falseChildPosteriori);
TrainNode falseChild =
new TrainNode(node.falseChild, depth + 1, pivot, high, fc, constFeatures);
this.constFeatures = null;
if (tc >= _minSamplesSplit && trueChild.findBestSplit()) {
if (nextSplits != null) {
nextSplits.add(trueChild);
} else {
if (trueChild.split(null) == false) {
leaves++;
}
}
} else {
leaves++;
}
if (fc >= _minSamplesSplit && falseChild.findBestSplit()) {
if (nextSplits != null) {
nextSplits.add(falseChild);
} else {
if (falseChild.split(null) == false) {
leaves++;
}
}
} else {
leaves++;
}
// Prune meaningless branches
if (leaves == 2) {// both left and right child are leaf node
if (node.trueChild.output == node.falseChild.output) {// found a meaningless branch
node.markAsLeaf();
return false;
}
}
_importance.incr(node.splitFeature, node.splitScore);
if (nextSplits == null) {
// For depth-first splitting, a posteriori is not needed for non-leaf nodes
node.posteriori = null;
}
return true;
}
/**
* @return the pivot position separating true-branch samples from false-branch samples
*/
private int splitSamples(@Nonnull final MutableInt tc, @Nonnull final MutableInt fc,
@Nonnull final double[] trueChildPosteriori,
@Nonnull final double[] falseChildPosteriori,
@Nonnull final IntPredicate goesLeft) {
final int[] sampleIndex = _sampleIndex;
final int[] samples = _samples;
final int[] y = _y;
int pivot = low;
for (int k = low, end = high; k < end; k++) {
final int i = sampleIndex[k];
final int numSamples = samples[i];
final int yi = y[i];
if (goesLeft.test(i)) {
tc.addValue(numSamples);
trueChildPosteriori[yi] += numSamples;
pivot++;
} else {
fc.addValue(numSamples);
falseChildPosteriori[yi] += numSamples;
}
}
return pivot;
}
/**
 * Modifies {@link #_order} and {@link #_sampleIndex} by partitioning the range from low
 * (inclusive) to high (exclusive) so that all elements i for which goesLeft(i) is true come
 * before all elements for which it is false, but element ordering is otherwise preserved.
 * The number of true values returned by goesLeft must equal pivot - low.
 *
 * @param low the low bound of the segment of the order arrays which will be partitioned.
 * @param pivot where the partition's split point will end up.
 * @param high the high bound of the segment of the order arrays which will be partitioned.
 * @param goesLeft whether an element goes to the left side or the right side of the
 *        partition.
 */
private void partitionOrder(final int low, final int pivot, final int high,
@Nonnull final IntPredicate goesLeft) {
final int[] buf = new int[high - pivot];
_order.eachRow(new Consumer() {
@Override
public void accept(int col, @Nonnull final SparseIntArray row) {
partitionArray(row, low, pivot, high, goesLeft, buf);
}
});
partitionArray(_sampleIndex, low, pivot, high, goesLeft, buf);
}
@Nonnull
private IntPredicate getPredicate() {
if (node.quantitativeFeature) {
return new IntPredicate() {
@Override
public boolean test(int i) {
return _X.get(i, node.splitFeature, Double.NaN) <= node.splitValue;
}
};
} else {
return new IntPredicate() {
@Override
public boolean test(int i) {
return _X.get(i, node.splitFeature, Double.NaN) == node.splitValue;
}
};
}
}
}
private static void partitionArray(@Nonnull final SparseIntArray a, final int low,
final int pivot, final int high, @Nonnull final IntPredicate goesLeft,
@Nonnull final int[] buf) {
final int[] rowIndexes = a.keys();
final int[] rowPtrs = a.values();
final int size = a.size();
final int startPos = ArrayUtils.insertionPoint(rowIndexes, size, low);
final int endPos = ArrayUtils.insertionPoint(rowIndexes, size, high);
int pos = startPos, k = 0, j = low;
for (int i = startPos; i < endPos; i++) {
final int rowPtr = rowPtrs[i];
if (goesLeft.test(rowPtr)) {
rowIndexes[pos] = j;
rowPtrs[pos] = rowPtr;
pos++;
j++;
} else {
if (k >= buf.length) {
throw new IndexOutOfBoundsException(String.format(
"low=%d, pivot=%d, high=%d, a.size()=%d, buf.length=%d, i=%d, j=%d, k=%d, startPos=%d, endPos=%d\na=%s\nbuf=%s",
low, pivot, high, a.size(), buf.length, i, j, k, startPos, endPos,
a.toString(), Arrays.toString(buf)));
}
buf[k++] = rowPtr;
}
}
for (int i = 0; i < k; i++) {
rowIndexes[pos] = pivot + i;
rowPtrs[pos] = buf[i];
pos++;
}
if (pos != endPos) {
throw new IllegalStateException(
String.format("pos=%d, startPos=%d, endPos=%d, k=%d\na=%s", pos, startPos, endPos,
k, a.toString()));
}
}
/**
* Modifies an array in-place by partitioning the range from low (inclusive) to high (exclusive)
* so that all elements i for which goesLeft(i) is true come before all elements for which it is
* false, but element ordering is otherwise preserved. The number of true values returned by
* goesLeft must equal pivot - low. buf is scratch space large enough (i.e., at least
* high - pivot long) to hold all elements for which goesLeft is false.
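* For example, partitioning a = {9, 4, 7, 2} over [0, 4) with pivot = 2 and a goesLeft that
* accepts even values stably yields a = {4, 2, 9, 7}.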
*/
private static void partitionArray(@Nonnull final int[] a, final int low, final int pivot,
final int high, @Nonnull final IntPredicate goesLeft, @Nonnull final int[] buf) {
int j = low;
int k = 0;
for (int i = low; i < high; i++) {
if (i >= a.length) {
throw new IndexOutOfBoundsException(String.format(
"low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d", low,
pivot, high, a.length, buf.length, i, j, k));
}
final int rowPtr = a[i];
if (goesLeft.test(rowPtr)) {
a[j++] = rowPtr;
} else {
if (k >= buf.length) {
throw new IndexOutOfBoundsException(String.format(
"low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d",
low, pivot, high, a.length, buf.length, i, j, k));
}
buf[k++] = rowPtr;
}
}
if (k != high - pivot || j != pivot) {
throw new IndexOutOfBoundsException(
String.format("low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, j=%d, k=%d",
low, pivot, high, a.length, buf.length, j, k));
}
System.arraycopy(buf, 0, a, pivot, k);
}
/**
* Returns the impurity of a node.
*
* @param count the sample count in each class.
* @param n the number of samples in the node.
* @param rule the rule for splitting a node.
* @return the impurity of a node
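* For example, with count = {6, 2} and n = 8: GINI yields 1 - (0.75^2 + 0.25^2) = 0.375,
* ENTROPY yields -(0.75 log2 0.75 + 0.25 log2 0.25) ≈ 0.811, and CLASSIFICATION_ERROR
* yields 1 - max(0.75, 0.25) = 0.25.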
*/
private static double impurity(@Nonnull final int[] count, final int n,
@Nonnull final SplitRule rule) {
double impurity = 0.0;
switch (rule) {
case GINI: {
impurity = 1.0;
for (int count_i : count) {
if (count_i > 0) {
double p = (double) count_i / n;
impurity -= p * p;
}
}
break;
}
case ENTROPY: {
for (int count_i : count) {
if (count_i > 0) {
double p = (double) count_i / n;
impurity -= p * Math.log2(p);
}
}
break;
}
case CLASSIFICATION_ERROR: {
impurity = 0.d;
for (int count_i : count) {
if (count_i > 0) {
impurity = Math.max(impurity, (double) count_i / n);
}
}
impurity = Math.abs(1.d - impurity);
break;
}
}
return impurity;
}
/**
* Prunes redundant leaves from the tree. In some cases, a node is split into two leaves that
* get assigned the same label, so this recursively combines leaves when it notices this
* situation.
*/
private static void pruneRedundantLeaves(@Nonnull final Node node, @Nonnull Vector importance) {
if (node.isLeaf()) {
return;
}
// The children might not be leaves now, but might collapse into leaves given the chance.
pruneRedundantLeaves(node.trueChild, importance);
pruneRedundantLeaves(node.falseChild, importance);
if (node.trueChild.isLeaf() && node.falseChild.isLeaf()
&& node.trueChild.output == node.falseChild.output) {
node.trueChild = null;
node.falseChild = null;
importance.decr(node.splitFeature, node.splitScore);
} else {
// a posteriori is not needed for non-leaf nodes
node.posteriori = null;
}
}
public DecisionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull int[] y,
int maxLeafNodes) {
this(nominalAttrs, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafNodes, 2, 1, null, SplitRule.GINI, null);
}
public DecisionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull int[] y,
int maxLeafNodes, @Nullable PRNG rand) {
this(nominalAttrs, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafNodes, 2, 1, null, SplitRule.GINI, rand);
}
/**
* Constructor. Learns a classification tree for random forest.
*
* @param nominalAttrs the attribute properties.
* @param x the training instances.
* @param y the response variable.
* @param numVars the number of input variables to pick to split on at each node. It seems that
* dim/3 gives generally good performance, where dim is the number of variables.
* @param maxDepth the maximum depth of the tree.
* @param maxLeafNodes the maximum number of leaf nodes in the tree.
* @param minSamplesSplit the minimum number of elements in a node required for a split
* @param minSamplesLeaf the minimum number of samples in a leaf node
* @param samples the sample set of instances for stochastic learning. samples[i] is the number
* of sampling for instance i.
* @param rule the splitting rule.
* @param rand random number generator
*/
public DecisionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull int[] y,
int numVars, int maxDepth, int maxLeafNodes, int minSamplesSplit, int minSamplesLeaf,
@Nullable int[] samples, @Nonnull SplitRule rule, @Nullable PRNG rand) {
checkArgument(x, y, numVars, maxDepth, maxLeafNodes, minSamplesSplit, minSamplesLeaf);
this._X = x;
this._y = y;
this._k = Math.max(y) + 1;
if (_k < 2) {
throw new IllegalArgumentException("Only one class or negative class labels.");
}
if (nominalAttrs == null) {
nominalAttrs = new RoaringBitmap();
}
this._nominalAttrs = nominalAttrs;
this._numVars = numVars;
this._maxDepth = maxDepth;
// A split produces two leaves, so the leaf-size constraint can only be met when
// min_samples_split >= 2 * min_samples_leaf; i.e., a split only happens in nodes holding
// at least 2 * min_samples_leaf samples.
if (minSamplesSplit < minSamplesLeaf * 2) {
if (logger.isInfoEnabled()) {
logger.info(String.format(
"min_sample_leaf = %d replaces min_sample_split = %d with min_sample_split = %d",
minSamplesLeaf, minSamplesSplit, minSamplesLeaf * 2));
}
minSamplesSplit = minSamplesLeaf * 2;
}
this._minSamplesSplit = minSamplesSplit;
this._minSamplesLeaf = minSamplesLeaf;
this._rule = rule;
this._importance = x.isSparse() ? new SparseVector() : new DenseVector(x.numColumns());
this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;
final int n = y.length;
final int[] count = new int[_k];
final int[] sampleIndex;
int totalNumSamples = 0;
if (samples == null) {
samples = new int[n];
sampleIndex = new int[n];
for (int i = 0; i < n; i++) {
samples[i] = 1;
count[y[i]]++;
sampleIndex[i] = i;
}
totalNumSamples = n;
} else {
final IntArrayList positions = new IntArrayList(n);
for (int i = 0; i < n; i++) {
final int sample = samples[i];
if (sample != 0) {
count[y[i]] += sample;
positions.add(i);
totalNumSamples += sample;
}
}
sampleIndex = positions.toArray(true);
}
this._samples = samples;
this._order = SmileExtUtils.sort(nominalAttrs, x, samples);
this._sampleIndex = sampleIndex;
final double[] posteriori = new double[_k];
for (int i = 0; i < _k; i++) {
posteriori[i] = (double) count[i] / n;
}
this._root = new Node(Math.whichMax(count), posteriori);
final TrainNode trainRoot =
new TrainNode(_root, 1, 0, _sampleIndex.length, totalNumSamples);
if (maxLeafNodes == Integer.MAX_VALUE) { // depth-first split
if (trainRoot.findBestSplit()) {
trainRoot.split(null);
}
} else { // best-first split
// Priority queue for best-first tree growing.
final PriorityQueue<TrainNode> nextSplits = new PriorityQueue<TrainNode>();
// Now add splits to the tree until max tree size is reached
if (trainRoot.findBestSplit()) {
nextSplits.add(trainRoot);
}
// Pop best leaf from priority queue, split it, and push
// children nodes into the queue if possible.
for (int leaves = 1; leaves < maxLeafNodes; leaves++) {
// parent is the leaf to split
TrainNode node = nextSplits.poll();
if (node == null) {
break;
}
if (!node.split(nextSplits)) { // Split the parent node into two children nodes
leaves--;
}
}
pruneRedundantLeaves(_root, _importance);
}
}
@VisibleForTesting
Node getRootNode() {
return _root;
}
private static void checkArgument(@Nonnull Matrix x, @Nonnull int[] y, int numVars,
int maxDepth, int maxLeafNodes, int minSamplesSplit, int minSamplesLeaf) {
if (x.numRows() != y.length) {
throw new IllegalArgumentException(
String.format("The sizes of X and Y don't match: %d != %d", x.numRows(), y.length));
}
if (y.length == 0) {
throw new IllegalArgumentException("No training example given");
}
if (numVars <= 0 || numVars > x.numColumns()) {
throw new IllegalArgumentException(
"Invalid number of variables to split on at a node of the tree: " + numVars);
}
if (maxDepth < 2) {
throw new IllegalArgumentException("maxDepth should be greater than 1: " + maxDepth);
}
if (maxLeafNodes < 2) {
throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafNodes);
}
if (minSamplesSplit < 2) {
throw new IllegalArgumentException(
"Invalid minimum number of samples required to split an internal node: "
+ minSamplesSplit);
}
if (minSamplesLeaf < 1) {
throw new IllegalArgumentException(
"Invalid minimum size of leaf nodes: " + minSamplesLeaf);
}
}
/**
* Returns the variable importance. Every time a split of a node is made on a variable, the
* (Gini, information gain, etc.) impurity criterion for the two descendant nodes is less than
* for the parent node. Adding up the decreases for each individual variable over the tree
* gives a simple measure of variable importance.
*
* @return the variable importance
*/
@Nonnull
public Vector importance() {
return _importance;
}
@VisibleForTesting
public int predict(@Nonnull final double[] x) {
return predict(new DenseVector(x));
}
@Override
public int predict(@Nonnull final Vector x) {
return _root.predict(x);
}
public void predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) {
_root.predict(x, handler);
}
/**
* Predicts the class label of an instance and also calculates a posteriori probabilities. Not
* supported.
*/
public int predict(Vector x, double[] posteriori) {
throw new UnsupportedOperationException("Not supported.");
}
@Nonnull
public String predictJsCodegen(@Nonnull final String[] featureNames,
@Nonnull final String[] classNames) {
StringBuilder buf = new StringBuilder(1024);
_root.exportJavascript(buf, featureNames, classNames, 0);
return buf.toString();
}
@Deprecated
@Nonnull
public String predictOpCodegen(@Nonnull final String sep) {
List<String> opslist = new ArrayList<String>();
_root.opCodegen(opslist, 0);
opslist.add("call end");
String scripts = StringUtils.concat(opslist, sep);
return scripts;
}
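/**
 * Serializes the root node of the tree into a byte array, optionally compressed. A
 * round-trip sketch: {@code byte[] b = tree.serialize(true);} followed by
 * {@code Node root = DecisionTree.deserialize(b, b.length, true);}.
 */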
@Nonnull
public byte[] serialize(boolean compress) throws HiveException {
try {
if (compress) {
return ObjectUtils.toCompressedBytes(_root);
} else {
return ObjectUtils.toBytes(_root);
}
} catch (IOException ioe) {
throw new HiveException("IOException cause while serializing DecisionTree object", ioe);
} catch (Exception e) {
throw new HiveException("Exception cause while serializing DecisionTree object", e);
}
}
@Nonnull
public static Node deserialize(@Nonnull final byte[] serializedObj, final int length,
final boolean compressed) throws HiveException {
final Node root = new Node();
try {
if (compressed) {
ObjectUtils.readCompressedObject(serializedObj, 0, length, root);
} else {
ObjectUtils.readObject(serializedObj, length, root);
}
} catch (IOException ioe) {
throw new HiveException("IOException cause while deserializing DecisionTree object",
ioe);
} catch (Exception e) {
throw new HiveException("Exception cause while deserializing DecisionTree object", e);
}
return root;
}
@Override
public String toString() {
return _root == null ? "" : predictJsCodegen(null, null);
}
}