
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import scala.collection.mutable
import org.apache.spark.Logging
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impl._
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
/**
* A class which implements a decision tree learning algorithm for classification and regression.
* It supports both continuous and categorical features.
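*
* For example, a minimal sketch of training a model with this class (a hedged example,
* assuming `trainingData: RDD[LabeledPoint]` is already defined):
* {{{
*   import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
*   import org.apache.spark.mllib.tree.impurity.Gini
*
*   val strategy = new Strategy(Algo.Classification, Gini, maxDepth = 5)
*   val model = new DecisionTree(strategy).run(trainingData)
* }}}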
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of algorithm (classification, regression, etc.), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
*/
@Since("1.0.0")
class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
extends Serializable with Logging {
strategy.assertValid()
/**
* Method to train a decision tree model over an RDD.
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return DecisionTreeModel that can be used for prediction
*/
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
val rfModel = rf.run(input)
rfModel.trees(0)
}
}
@Since("1.0.0")
object DecisionTree extends Serializable with Logging {
/**
* Method to train a decision tree model.
* The method supports binary and multiclass classification and regression.
*
* Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
* and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
* is recommended to clearly separate classification and regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of algorithm (classification, regression, etc.), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
* @return DecisionTreeModel that can be used for prediction
*/
@Since("1.0.0")
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
new DecisionTree(strategy).run(input)
}
/**
* Method to train a decision tree model.
* The method supports binary and multiclass classification and regression.
*
* Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
* and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
* is recommended to clearly separate classification and regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param algo algorithm, classification or regression
* @param impurity impurity criterion used for information gain calculation
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* @return DecisionTreeModel that can be used for prediction
*/
@Since("1.0.0")
def train(
input: RDD[LabeledPoint],
algo: Algo,
impurity: Impurity,
maxDepth: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth)
new DecisionTree(strategy).run(input)
}
/**
* Method to train a decision tree model.
* The method supports binary and multiclass classification and regression.
*
* Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
* and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
* is recommended to clearly separate classification and regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param algo algorithm, classification or regression
* @param impurity impurity criterion used for information gain calculation
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* @param numClasses number of classes for classification. Default value of 2.
* @return DecisionTreeModel that can be used for prediction
*/
@Since("1.2.0")
def train(
input: RDD[LabeledPoint],
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClasses: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClasses)
new DecisionTree(strategy).run(input)
}
/**
* Method to train a decision tree model.
* The method supports binary and multiclass classification and regression.
*
* Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
* and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
* is recommended to clearly separate classification and regression.
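*
* A hedged sketch of a call with one categorical feature of arity 3 at index 0, assuming
* `data: RDD[LabeledPoint]` is defined and `Algo._`, `QuantileStrategy._`, and `Gini` are
* imported as in this file:
* {{{
*   val model = DecisionTree.train(data, Classification, Gini, 5, 2, 32, Sort, Map(0 -> 3))
* }}}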
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* @param numClasses number of classes for classification. Default value of 2.
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo Map storing arity of categorical features.
* E.g., an entry (n -> k) indicates that feature n is categorical
* with k categories indexed from 0: {0, 1, ..., k-1}.
* @return DecisionTreeModel that can be used for prediction
*/
@Since("1.0.0")
def train(
input: RDD[LabeledPoint],
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClasses: Int,
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
new DecisionTree(strategy).run(input)
}
/**
* Method to train a decision tree model for binary or multiclass classification.
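*
* A minimal sketch, assuming `data: RDD[LabeledPoint]` is defined and feature 0 is
* categorical with 4 categories while all other features are continuous:
* {{{
*   val model = DecisionTree.trainClassifier(data, numClasses = 2,
*     categoricalFeaturesInfo = Map(0 -> 4), impurity = "gini", maxDepth = 5, maxBins = 32)
* }}}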
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* Labels should take values {0, 1, ..., numClasses-1}.
* @param numClasses number of classes for classification.
* @param categoricalFeaturesInfo Map storing arity of categorical features.
* E.g., an entry (n -> k) indicates that feature n is categorical
* with k categories indexed from 0: {0, 1, ..., k-1}.
* @param impurity Criterion used for information gain calculation.
* Supported values: "gini" (recommended) or "entropy".
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* (suggested value: 5)
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 32)
* @return DecisionTreeModel that can be used for prediction
*/
@Since("1.1.0")
def trainClassifier(
input: RDD[LabeledPoint],
numClasses: Int,
categoricalFeaturesInfo: Map[Int, Int],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
val impurityType = Impurities.fromString(impurity)
train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
categoricalFeaturesInfo)
}
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
*/
@Since("1.1.0")
def trainClassifier(
input: JavaRDD[LabeledPoint],
numClasses: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
trainClassifier(input.rdd, numClasses,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
impurity, maxDepth, maxBins)
}
/**
* Method to train a decision tree model for regression.
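*
* A minimal sketch, assuming `data: RDD[LabeledPoint]` is defined and all features are
* continuous:
* {{{
*   val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo = Map[Int, Int](),
*     impurity = "variance", maxDepth = 5, maxBins = 32)
* }}}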
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* Labels are real numbers.
* @param categoricalFeaturesInfo Map storing arity of categorical features.
* E.g., an entry (n -> k) indicates that feature n is categorical
* with k categories indexed from 0: {0, 1, ..., k-1}.
* @param impurity Criterion used for information gain calculation.
* Supported values: "variance".
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* (suggested value: 5)
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 32)
* @return DecisionTreeModel that can be used for prediction
*/
@Since("1.1.0")
def trainRegressor(
input: RDD[LabeledPoint],
categoricalFeaturesInfo: Map[Int, Int],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
val impurityType = Impurities.fromString(impurity)
train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo)
}
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
*/
@Since("1.1.0")
def trainRegressor(
input: JavaRDD[LabeledPoint],
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
trainRegressor(input.rdd,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
impurity, maxDepth, maxBins)
}
/**
* Get the node index corresponding to this data point.
* This function mimics prediction, passing an example from the root node down to a leaf
* or unsplit node; that node's index is returned.
*
* @param node Node in tree from which to classify the given data point.
* @param binnedFeatures Binned feature vector for data point.
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param unorderedFeatures Set of indices of unordered features.
* @return Leaf index if the data point reaches a leaf.
* Otherwise, last node reachable in tree matching this example.
* Note: This is the global node index, i.e., the index used in the tree.
* This index is different from the index used during training a particular
* group of nodes on one call to [[findBestSplits()]].
*/
private def predictNodeIndex(
node: Node,
binnedFeatures: Array[Int],
bins: Array[Array[Bin]],
unorderedFeatures: Set[Int]): Int = {
if (node.isLeaf || node.split.isEmpty) {
// Node is either leaf, or has not yet been split.
node.id
} else {
val featureIndex = node.split.get.feature
val splitLeft = node.split.get.featureType match {
case Continuous => {
val binIndex = binnedFeatures(featureIndex)
val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
// bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
// We do not need to check lowSplit since bins are separated by splits.
featureValueUpperBound <= node.split.get.threshold
}
case Categorical => {
val featureValue = binnedFeatures(featureIndex)
node.split.get.categories.contains(featureValue)
}
case _ => throw new RuntimeException("predictNodeIndex failed for unknown feature type.")
}
if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
// Return index from next layer of nodes to train
if (splitLeft) {
Node.leftChildIndex(node.id)
} else {
Node.rightChildIndex(node.id)
}
} else {
if (splitLeft) {
predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures)
} else {
predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures)
}
}
}
}
/**
* Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
*
* For ordered features, a single bin is updated.
* For unordered features, bins correspond to subsets of categories; either the left or right bin
* for each subset is updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (feature, bin).
* @param treePoint Data point being aggregated.
* @param splits possible splits indexed (numFeatures)(numSplits)
* @param unorderedFeatures Set of indices of unordered features.
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
splits: Array[Array[Split]],
unorderedFeatures: Set[Int],
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
// Use subsampled features
featuresForNode.get.size
} else {
// Use all features
agg.metadata.numFeatures
}
// Iterate over features.
var featureIndexIdx = 0
while (featureIndexIdx < numFeaturesPerNode) {
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
if (unorderedFeatures.contains(featureIndex)) {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
agg.getLeftRightFeatureOffsets(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
} else {
agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
}
splitIndex += 1
}
} else {
// Ordered feature
val binIndex = treePoint.binnedFeatures(featureIndex)
agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
}
featureIndexIdx += 1
}
}
/**
* Helper for binSeqOp, for regression and for classification with only ordered features.
*
* For each feature, the sufficient statistics of one bin are updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (feature, bin).
* @param treePoint Data point being aggregated.
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val label = treePoint.label
// Iterate over features.
if (featuresForNode.nonEmpty) {
// Use subsampled features
var featureIndexIdx = 0
while (featureIndexIdx < featuresForNode.get.size) {
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
agg.update(featureIndexIdx, binIndex, label, instanceWeight)
featureIndexIdx += 1
}
} else {
// Use all features
val numFeatures = agg.metadata.numFeatures
var featureIndex = 0
while (featureIndex < numFeatures) {
val binIndex = treePoint.binnedFeatures(featureIndex)
agg.update(featureIndex, binIndex, label, instanceWeight)
featureIndex += 1
}
}
}
/**
* Given a group of nodes, this finds the best split for each node.
*
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
* @param metadata Learning and dataset metadata
* @param topNodes Root node for each tree. Used for matching instances with nodes.
* @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
* @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
* where nodeIndexInfo stores the index in the group and the
* feature subsets (if using feature subsets).
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
* @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
* each value in the array is the data point's node Id
* for a corresponding tree. This is used to prevent the need
* to pass the entire tree to the executors during
* the node stat aggregation phase.
*/
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
metadata: DecisionTreeMetadata,
topNodes: Array[Node],
nodesForGroup: Map[Int, Array[Node]],
treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
nodeQueue: mutable.Queue[(Int, Node)],
timer: TimeTracker = new TimeTracker,
nodeIdCache: Option[NodeIdCache] = None): Unit = {
/*
* High-level descriptions of the best-split optimizations are noted here.
*
* *Group-wise training*
* We perform bin calculations for groups of nodes to reduce the number of
* passes over the data. Each iteration requires more computation and storage,
* but saves several iterations over the data.
*
* *Bin-wise computation*
* We use a bin-wise best split computation strategy instead of a straightforward best split
* computation strategy. Instead of analyzing each sample for contribution to the left/right
* child node impurity of every split, we first categorize each feature of a sample into a
* bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
* to calculate information gain for each split.
*
* *Aggregation over partitions*
* Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
* the number of splits in advance. Thus, we store the aggregates (at the appropriate
* indices) in a single array for all bins and rely upon the RDD aggregate method to
* drastically reduce the communication overhead.
*/
// numNodes: Number of nodes in this group
val numNodes = nodesForGroup.values.map(_.size).sum
logDebug("numNodes = " + numNodes)
logDebug("numFeatures = " + metadata.numFeatures)
logDebug("numClasses = " + metadata.numClasses)
logDebug("isMulticlass = " + metadata.isMulticlass)
logDebug("isMulticlassWithCategoricalFeatures = " +
metadata.isMulticlassWithCategoricalFeatures)
logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
/**
* Performs a sequential aggregation over a partition for a particular tree and node.
*
* For each feature, the aggregate sufficient statistics are updated for the relevant
* bins.
*
* @param treeIndex Index of the tree that we want to perform aggregation for.
* @param nodeInfo The node info for the tree node.
* @param agg Array storing aggregate calculation, with a set of sufficient statistics
* for each (node, feature, bin).
* @param baggedPoint Data point being aggregated.
*/
def nodeBinSeqOp(
treeIndex: Int,
nodeInfo: RandomForest.NodeIndexInfo,
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Unit = {
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
metadata.unorderedFeatures, instanceWeight, featuresForNode)
}
}
}
/**
* Performs a sequential aggregation over a partition.
*
* Each data point contributes to one node. For each feature,
* the aggregate sufficient statistics are updated for the relevant bins.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (node, feature, bin).
* @param baggedPoint Data point being aggregated.
* @return agg
*/
def binSeqOp(
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
agg
}
/**
* Do the same thing as binSeqOp, but with nodeIdCache.
*/
def binSeqOpWithNodeIdCache(
agg: Array[DTStatsAggregator],
dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val baggedPoint = dataPoint._1
val nodeIdsForTrees = dataPoint._2
val nodeIndex = nodeIdsForTrees(treeIndex)
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
agg
}
/**
* Get a map: node index in group --> feature indices.
* This is a shortcut for finding the feature indices for a node, given its index in the group.
* @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo.
* @return Map: node index in group --> feature indices, or None if not subsampling features.
*/
def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]])
: Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) {
None
} else {
val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
assert(nodeIndexInfo.featureSubset.isDefined)
mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
}
}
Some(mutableNodeToFeatures.toMap)
}
// array of nodes to train indexed by node index in group
val nodes = new Array[Node](numNodes)
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
}
}
// Calculate best splits for all nodes in the group
timer.start("chooseSplits")
// In each partition, iterate over all instances and compute the aggregate stats for each
// node, yielding a (nodeIndex, nodeAggregateStats) pair for each node.
// After a `reduceByKey` operation,
// the stats of each node are shuffled to a particular partition and combined together;
// the best splits for the nodes are then found there.
// Finally, only the best splits are collected to the driver to construct the decision tree.
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold the aggregate stats for each node;
// each node has its own DTStatsAggregator.
val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
}
new DTStatsAggregator(metadata, featuresForNode)
}
// iterate over all instances in the current partition and update the aggregate stats
points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
// transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with those from other partitions using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
}
} else {
input.mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold the aggregate stats for each node;
// each node has its own DTStatsAggregator.
val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
}
new DTStatsAggregator(metadata, featuresForNode)
}
// iterate over all instances in the current partition and update the aggregate stats
points.foreach(binSeqOp(nodeStatsAggregators, _))
// transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with those from other partitions using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
}
}
val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
.map { case (nodeIndex, aggStats) =>
val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
nodeToFeatures(nodeIndex)
}
// find best split for each node
val (split: Split, stats: InformationGainStats, predict: Predict) =
binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
(nodeIndex, (split, stats, predict))
}.collectAsMap()
timer.stop("chooseSplits")
val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
Array.fill[mutable.Map[Int, NodeIndexUpdater]](
metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
} else {
null
}
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
val nodeIndex = node.id
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val (split: Split, stats: InformationGainStats, predict: Predict) =
nodeToBestSplits(aggNodeIndex)
logDebug("best split = " + split)
// Extract info for this node. Create children if not leaf.
val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
assert(node.id == nodeIndex)
node.predict = predict
node.isLeaf = isLeaf
node.stats = Some(stats)
node.impurity = stats.impurity
logDebug("Node = " + node)
if (!isLeaf) {
node.split = Some(split)
val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
if (nodeIdCache.nonEmpty) {
val nodeIndexUpdater = NodeIndexUpdater(
split = split,
nodeIndex = nodeIndex)
nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
}
// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.leftNode.get))
}
if (!rightChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.rightNode.get))
}
logDebug("leftChildIndex = " + node.leftNode.get.id +
", impurity = " + stats.leftImpurity)
logDebug("rightChildIndex = " + node.rightNode.get.id +
", impurity = " + stats.rightImpurity)
}
}
}
if (nodeIdCache.nonEmpty) {
// Update the cache if needed.
nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
}
}
/**
* Calculate the information gain for a given (feature, split) based upon left/right aggregates.
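* For intuition, a worked example with illustrative numbers: if the parent impurity is 0.5,
* the left child holds 30 of 100 instances with impurity 0.2, and the right child holds the
* remaining 70 with impurity 0.4, then gain = 0.5 - 0.3 * 0.2 - 0.7 * 0.4 = 0.16.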
* @param leftImpurityCalculator left child aggregates for this (feature, split)
* @param rightImpurityCalculator right child aggregates for this (feature, split)
* @return information gain and statistics for split
*/
private def calculateGainForSplit(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator,
metadata: DecisionTreeMetadata,
impurity: Double): InformationGainStats = {
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
// If left child or right child doesn't satisfy minimum instances per node,
// then this split is invalid, return invalid information gain stats.
if ((leftCount < metadata.minInstancesPerNode) ||
(rightCount < metadata.minInstancesPerNode)) {
return InformationGainStats.invalidInformationGainStats
}
val totalCount = leftCount + rightCount
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
val leftWeight = leftCount / totalCount.toDouble
val rightWeight = rightCount / totalCount.toDouble
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
// if information gain doesn't satisfy minimum information gain,
// then this split is invalid, return invalid information gain stats.
if (gain < metadata.minInfoGain) {
return InformationGainStats.invalidInformationGainStats
}
// calculate the left and right predictions
val leftPredict = calculatePredict(leftImpurityCalculator)
val rightPredict = calculatePredict(rightImpurityCalculator)
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
leftPredict, rightPredict)
}
private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
val predict = impurityCalculator.predict
val prob = impurityCalculator.prob(predict)
new Predict(predict, prob)
}
/**
* Calculate the predict value for the current node, given stats of any split.
* Note that this function is called only once for each node.
* @param leftImpurityCalculator left node aggregates for a split
* @param rightImpurityCalculator right node aggregates for a split
* @return predict value and impurity for current node
*/
private def calculatePredictImpurity(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
val predict = calculatePredict(parentNodeAgg)
val impurity = parentNodeAgg.calculate()
(predict, impurity)
}
/**
* Find the best split for a node.
* @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
node: Node): (Split, InformationGainStats, Predict) = {
// For the top node (level 0), predict and impurity are computed lazily from the first
// split's aggregates below; otherwise they are taken from the node itself.
val level = Node.indexToLevel(node.id)
var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
None
} else {
Some((node.predict, node.impurity))
}
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
}
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { case splitIdx =>
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats =
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)
/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
* splits are considered. (With K categories, we consider K - 1 possible splits.)
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = Range(0, numBins).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
categoryStats.calculate()
} else if (binAggregates.metadata.isClassification) {
// For categorical variables in binary classification,
// the bins are ordered by the count of class 1.
categoryStats.stats(1)
} else {
// For categorical variables in regression,
// the bins are ordered by the prediction.
categoryStats.predict
}
} else {
Double.MaxValue
}
(featureValue, centroid)
}
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
logDebug("Sorted centroids for categorical variable = " +
categoriesSortedByCentroid.mkString(","))
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
var splitIndex = 0
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
}
// lastCategory = index of bin with total aggregates for this (node, feature)
val lastCategory = categoriesSortedByCentroid.last._1
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val featureValue = categoriesSortedByCentroid(splitIndex)._1
val leftChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)
(bestSplit, bestSplitStats, predictWithImpurity.get._1)
}
/**
* Returns splits and bins for decision tree calculation.
* Continuous and categorical features are handled differently.
*
* Continuous features:
* For each feature, there are numBins - 1 possible splits representing the possible binary
* decisions at each node in the tree.
* This finds locations (feature values) for splits using a subsample of the data.
*
* Categorical features:
* For each feature, there is 1 bin per split.
* Splits and bins are handled in 2 ways:
* (a) "unordered features"
* For multiclass classification with a low-arity feature
* (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
* the feature is split based on subsets of categories.
* (b) "ordered features"
* For regression and binary classification,
* and for multiclass classification with a high-arity feature,
* there is one bin per category.
*
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @param metadata Learning and dataset metadata
* @return A tuple of (splits, bins).
* Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
* of size (numFeatures, numSplits).
* Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
* of size (numFeatures, numBins).
*/
protected[tree] def findSplitsBins(
input: RDD[LabeledPoint],
metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
logDebug("isMulticlass = " + metadata.isMulticlass)
val numFeatures = metadata.numFeatures
// Sample the input only if there are continuous features.
val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
val sampledInput = if (continuousFeatures.nonEmpty) {
// Calculate the number of samples for approximate quantile calculation.
val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
val fraction = if (requiredSamples < metadata.numExamples) {
requiredSamples.toDouble / metadata.numExamples
} else {
1.0
}
logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
} else {
input.sparkContext.emptyRDD[LabeledPoint]
}
metadata.quantileStrategy match {
case Sort =>
findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
case MinMax =>
throw new UnsupportedOperationException("minmax not supported yet.")
case ApproxHist =>
throw new UnsupportedOperationException("approximate histogram not supported yet.")
}
}
private def findSplitsBinsBySorting(
input: RDD[LabeledPoint],
metadata: DecisionTreeMetadata,
continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], Array[Array[Bin]]) = {
def findSplits(
featureIndex: Int,
featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
val splits = {
val featureSplits = findSplitsForContinuousFeature(
featureSamples.toArray,
metadata,
featureIndex)
logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")
featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil))
}
val bins = {
val lowSplit = new DummyLowSplit(featureIndex, Continuous)
val highSplit = new DummyHighSplit(featureIndex, Continuous)
// tack the dummy splits on either side of the computed splits
val allSplits = lowSplit +: splits.toSeq :+ highSplit
// slide across the split points pairwise to allocate the bins
allSplits.sliding(2).map {
case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue)
}.toArray
}
(featureIndex, (splits, bins))
}
val continuousSplits = {
// reduce the parallelism for split computations when there are fewer
// continuous features than input partitions. this prevents tasks from
// being spun up that will definitely do no work.
val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
input
.flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
.groupByKey(numPartitions)
.map { case (k, v) => findSplits(k, v) }
.collectAsMap()
}
val numFeatures = metadata.numFeatures
val (splits, bins) = Range(0, numFeatures).unzip {
case i if metadata.isContinuous(i) =>
val (split, bin) = continuousSplits(i)
metadata.setNumSplits(i, split.length)
(split, bin)
case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
// Unordered features
// 2^(featureArity - 1) - 1 combinations
val featureArity = metadata.featureArity(i)
val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
new Split(i, Double.MinValue, Categorical, categories)
}
// For unordered categorical features, there is no need to construct the bins,
// since there is a one-to-one correspondence between the splits and the bins.
(split.toArray, Array.empty[Bin])
case i if metadata.isCategorical(i) =>
// Ordered features
// Bins correspond to feature values, so we do not need to compute splits or bins
// beforehand. Splits are constructed as needed during training.
(Array.empty[Split], Array.empty[Bin])
}
(splits.toArray, bins.toArray)
}
/**
* Extracts the list of eligible categories given an index, by finding the positions of
* ones in the binary representation of the input. For example, if the binary
* representation of a number is 01101 (13 in decimal), the output list should be
* (3.0, 2.0, 0.0). The maxFeatureValue gives the number of rightmost digits tested for ones.
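* A hedged usage sketch of the expected behavior:
* {{{
*   extractMultiClassCategories(13, 5)  // List(3.0, 2.0, 0.0)
* }}}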
*/
private[tree] def extractMultiClassCategories(
input: Int,
maxFeatureValue: Int): List[Double] = {
var categories = List[Double]()
var j = 0
var bitShiftedInput = input
while (j < maxFeatureValue) {
if (bitShiftedInput % 2 != 0) {
// updating the list of categories.
categories = j.toDouble :: categories
}
// Right shift by one
bitShiftedInput = bitShiftedInput >> 1
j += 1
}
categories
}
/**
* Find splits for a continuous feature
* NOTE: Returned number of splits is set based on `featureSamples` and
* could be different from the specified `numSplits`.
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
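* A worked sketch with illustrative inputs: for featureSamples = (1, 1, 2, 2, 3, 3, 4, 4)
* and numSplits = 3, the stride is 8 / 4 = 2.0 and the returned thresholds are
* (1.0, 2.0, 3.0).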
* @param featureSamples feature values of each sample
* @param metadata decision tree metadata
* NOTE: `metadata.numBins` will be changed accordingly
* if there are not enough splits to be found
* @param featureIndex feature index to find splits
* @return array of splits
*/
private[tree] def findSplitsForContinuousFeature(
featureSamples: Array[Double],
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
require(metadata.isContinuous(featureIndex),
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
val splits = {
val numSplits = metadata.numSplits(featureIndex)
// get count for each distinct value
val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
m + ((x, m.getOrElse(x, 0) + 1))
}
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
// if the number of possible splits is at most numSplits, return all distinct values as splits
val possibleSplits = valueCounts.length
if (possibleSplits <= numSplits) {
valueCounts.map(_._1)
} else {
// stride between splits
val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
logDebug("stride = " + stride)
// iterate `valueCount` to find splits
val splitsBuilder = Array.newBuilder[Double]
var index = 1
// currentCount: sum of counts of values that have been visited
var currentCount = valueCounts(0)._2
// targetCount: target value for `currentCount`.
// If `currentCount` is the closest value to `targetCount`,
// then the current value is a split threshold.
// After a split threshold is found, `targetCount` is increased by `stride`.
var targetCount = stride
while (index < valueCounts.length) {
val previousCount = currentCount
currentCount += valueCounts(index)._2
val previousGap = math.abs(previousCount - targetCount)
val currentGap = math.abs(currentCount - targetCount)
// If adding the count of the current value to currentCount
// makes the gap between currentCount and targetCount larger,
// the previous value is a split threshold.
if (previousGap < currentGap) {
splitsBuilder += valueCounts(index - 1)._1
targetCount += stride
}
index += 1
}
splitsBuilder.result()
}
}
// TODO: Do not fail; just ignore the useless feature.
assert(splits.length > 0,
s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
" Please remove this feature and then try again.")
// the split metadata must be updated on the driver
splits
}
}