/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.tree
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
/**
* Decision tree node interface.
*/
sealed abstract class Node extends Serializable {
// TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
// code into the new API and deprecate the old API. SPARK-3727
/** Prediction a leaf node makes, or the prediction an internal node would make if it were a leaf node */
def prediction: Double
/** Impurity measure at this node (for training data) */
def impurity: Double
/**
* Statistics aggregated from training data at this node, used to compute prediction, impurity,
* and probabilities.
* For classification, the array of class counts must be normalized to a probability distribution.
*/
private[ml] def impurityStats: ImpurityCalculator
/** Recursive prediction helper method */
private[ml] def predictImpl(features: Vector): LeafNode
/** Recursive prediction helper method */
private[ml] def predictBinned(binned: Array[Int], splits: Array[Array[Split]]): LeafNode
/**
* Get the number of nodes in the tree below this node, including leaf nodes.
* E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
*/
private[tree] def numDescendants: Int
/**
* Recursive print function.
* @param indentFactor The number of spaces to add to each level of indentation.
*/
private[tree] def subtreeToString(indentFactor: Int = 0): String
/**
* Get depth of tree from this node.
* E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
*/
private[tree] def subtreeDepth: Int
/**
* Create a copy of this node in the old Node format, recursively creating child nodes as needed.
* @param id Node ID using old format IDs
*/
private[ml] def toOld(id: Int): OldNode
/**
* Trace down the tree, and return the largest feature index used in any split.
* @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
*/
private[ml] def maxSplitFeatureIndex(): Int
/** Returns a deep copy of the subtree rooted at this node. */
private[tree] def deepCopy(): Node
}
private[ml] object Node {
/**
* Create a new Node from the old Node format, recursively creating child nodes as needed.
*/
def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
if (oldNode.isLeaf) {
// TODO: Once the implementation has been moved to this API, then include sufficient
// statistics here.
new LeafNode(prediction = oldNode.predict.predict,
impurity = oldNode.impurity, impurityStats = null)
} else {
val gain = if (oldNode.stats.nonEmpty) {
oldNode.stats.get.gain
} else {
0.0
}
new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
}
}
}
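// Illustrative usage (a sketch; oldModel stands in for an
// org.apache.spark.mllib.tree.model.DecisionTreeModel, and Map.empty assumes
// all features are continuous):
//   val newRoot: Node = Node.fromOld(oldModel.topNode, categoricalFeatures = Map.empty)
// Caveat visible in fromOld above: it leaves impurityStats = null, so methods
// that read the stats, such as toOld (which calls impurityStats.prob), would
// NPE on a converted tree. Converting back with toOld(1) (1 is the old root
// id) is therefore only safe for trees trained through the new API.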
/**
* Decision tree leaf node.
* @param prediction Prediction this node makes
* @param impurity Impurity measure at this node (for training data)
*/
class LeafNode private[ml] (
override val prediction: Double,
override val impurity: Double,
override private[ml] val impurityStats: ImpurityCalculator) extends Node {
override def toString: String =
s"LeafNode(prediction = $prediction, impurity = $impurity)"
override private[ml] def predictImpl(features: Vector): LeafNode = this
override private[ml] def predictBinned(
binned: Array[Int],
splits: Array[Array[Split]]): LeafNode = this
override private[tree] def numDescendants: Int = 0
override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
val prefix: String = " " * indentFactor
prefix + s"Predict: $prediction\n"
}
override private[tree] def subtreeDepth: Int = 0
override private[ml] def toOld(id: Int): OldNode = {
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
impurity, isLeaf = true, None, None, None, None)
}
override private[ml] def maxSplitFeatureIndex(): Int = -1
override private[tree] def deepCopy(): Node = {
new LeafNode(prediction, impurity, impurityStats)
}
}
/**
* Internal Decision Tree node.
* @param prediction Prediction this node would make if it were a leaf node
* @param impurity Impurity measure at this node (for training data)
* @param gain Information gain value. Values less than 0 indicate missing values;
* this quirk will be removed with future updates.
* @param leftChild Left-hand child node
* @param rightChild Right-hand child node
* @param split Information about the test used to split to the left or right child.
*/
class InternalNode private[ml] (
override val prediction: Double,
override val impurity: Double,
val gain: Double,
val leftChild: Node,
val rightChild: Node,
val split: Split,
override private[ml] val impurityStats: ImpurityCalculator) extends Node {
// Note to developers: The constructor argument impurityStats should be reconsidered before we
// make the constructor public. We may be able to improve the representation.
override def toString: String = {
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
}
override private[ml] def predictImpl(features: Vector): LeafNode = {
var node: Node = this
while (node.isInstanceOf[InternalNode]) {
val n = node.asInstanceOf[InternalNode]
if (n.split.shouldGoLeft(features)) {
node = n.leftChild
} else {
node = n.rightChild
}
}
node.asInstanceOf[LeafNode]
}
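// Worked trace (illustrative values): if the root splits on feature 0 at
// threshold 0.5 and features(0) = 0.3, shouldGoLeft is true and the walk
// moves to leftChild; the loop exits at the first LeafNode reached, whose
// `prediction` is the tree's answer for this example.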
override private[ml] def predictBinned(
binned: Array[Int],
splits: Array[Array[Split]]): LeafNode = {
var node: Node = this
while (node.isInstanceOf[InternalNode]) {
val n = node.asInstanceOf[InternalNode]
val i = n.split.featureIndex
if (n.split.shouldGoLeft(binned(i), splits(i))) {
node = n.leftChild
} else {
node = n.rightChild
}
}
node.asInstanceOf[LeafNode]
}
override private[tree] def numDescendants: Int = {
2 + leftChild.numDescendants + rightChild.numDescendants
}
override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
val prefix: String = " " * indentFactor
prefix + s"If (${InternalNode.splitToString(split, left = true)})\n" +
leftChild.subtreeToString(indentFactor + 1) +
prefix + s"Else (${InternalNode.splitToString(split, left = false)})\n" +
rightChild.subtreeToString(indentFactor + 1)
}
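// Example output (illustrative) for a depth-1 tree splitting feature 0 at
// 0.5, given the one-space-per-level indentation above:
//   If (feature 0 <= 0.5)
//    Predict: 0.0
//   Else (feature 0 > 0.5)
//    Predict: 1.0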
override private[tree] def subtreeDepth: Int = {
1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth)
}
override private[ml] def toOld(id: Int): OldNode = {
assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+ " since the old API does not support deep trees.")
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity,
isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
Some(rightChild.toOld(OldNode.rightChildIndex(id))),
Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
new OldPredict(leftChild.prediction, prob = 0.0),
new OldPredict(rightChild.prediction, prob = 0.0))))
}
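// The assert above guards the old heap numbering: a node at depth d has an
// old-format id of up to 2^(d+1) - 1, and leftChildIndex doubles the id, so
// child ids overflow Int for trees deeper than roughly 30 levels.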
override private[ml] def maxSplitFeatureIndex(): Int = {
math.max(split.featureIndex,
math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
}
override private[tree] def deepCopy(): Node = {
new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(),
split, impurityStats)
}
}
private object InternalNode {
/**
* Helper method for [[Node.subtreeToString()]].
* @param split Split to print
* @param left True to describe the branch of the split going to the left child;
* false to describe the branch going to the right child.
*/
private def splitToString(split: Split, left: Boolean): String = {
val featureStr = s"feature ${split.featureIndex}"
split match {
case contSplit: ContinuousSplit =>
if (left) {
s"$featureStr <= ${contSplit.threshold}"
} else {
s"$featureStr > ${contSplit.threshold}"
}
case catSplit: CategoricalSplit =>
val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}")
if (left) {
s"$featureStr in $categoriesStr"
} else {
s"$featureStr not in $categoriesStr"
}
}
}
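// Sample outputs (illustrative values):
//   continuous, left = true:   "feature 0 <= 0.5"
//   continuous, left = false:  "feature 0 > 0.5"
//   categorical, left = false: "feature 1 not in {0.0,2.0}"
// leftCategories holds the category values routed to the left child, so the
// right-hand description is its complement ("not in").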
}
/**
* Version of a node used in learning. This uses vars so that we can modify nodes as we split the
* tree by adding children, etc.
*
* For now, we use node IDs. These will be kept internal since we hope to remove node IDs
* in the future, or at least change the indexing (so that we can support much deeper trees).
*
* This node can either be:
* - a leaf node, with leftChild, rightChild, split set to null, or
* - an internal node, with all values set
*
* @param id We currently use the same indexing as the old implementation in
* [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
* @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
* so that we do not need to consider splitting it further.
* @param stats Impurity statistics for this node.
*/
private[tree] class LearningNode(
var id: Int,
var leftChild: Option[LearningNode],
var rightChild: Option[LearningNode],
var split: Option[Split],
var isLeaf: Boolean,
var stats: ImpurityStats) extends Serializable {
def toNode: Node = toNode(prune = true)
/**
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
*/
def toNode(prune: Boolean = true): Node = {
if (leftChild.nonEmpty || rightChild.nonEmpty) {
assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
(leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
case (l, r) =>
new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
l, r, split.get, stats.impurityCalculator)
}
} else {
if (stats.valid) {
new LeafNode(stats.impurityCalculator.predict, stats.impurity,
stats.impurityCalculator)
} else {
// Here we want to keep the same behavior as the old mllib.DecisionTreeModel
new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
}
}
}
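// Pruning example (illustrative): with prune = true, a split whose two
// children are both leaves predicting, say, 1.0 contributes nothing, so the
// match above collapses it into a single LeafNode built from this node's own
// impurity stats; with prune = false the redundant InternalNode is kept.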
/**
* Get the node index corresponding to this data point.
* This function mimics prediction, passing an example from the root node down to a leaf
* or unsplit node; that node's index is returned.
*
* @param binnedFeatures Binned feature vector for data point.
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @return Leaf index if the data point reaches a leaf.
* Otherwise, last node reachable in tree matching this example.
* Note: This is the global node index, i.e., the index used in the tree.
* This index is different from the index used during training a particular
* group of nodes on one call to
* [[org.apache.spark.ml.tree.impl.RandomForest.findBestSplits()]].
*/
def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
var node = this
while (!node.isLeaf && node.split.nonEmpty) {
val split = node.split.get
val featureIndex = split.featureIndex
val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
if (node.leftChild.isEmpty) {
// Not yet split. Return the index of the child in the next layer of nodes to train.
if (splitLeft) {
return LearningNode.leftChildIndex(node.id)
} else {
return LearningNode.rightChildIndex(node.id)
}
} else {
if (splitLeft) {
node = node.leftChild.get
} else {
node = node.rightChild.get
}
}
}
node.id
}
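// Index semantics (illustrative): if the root (id 1) has chosen a split but
// its children have not been created yet, an example routed left returns
// LearningNode.leftChildIndex(1) = 2, naming the node to be trained in the
// next layer rather than an existing node's id.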
}
private[tree] object LearningNode {
/** Create a node with some of its fields set. Note that the isLeaf argument is currently ignored: new nodes are always created with isLeaf = false. */
def apply(
id: Int,
isLeaf: Boolean,
stats: ImpurityStats): LearningNode = {
new LearningNode(id, None, None, None, false, stats)
}
/** Create an empty node with the given node index. Values must be set later on. */
def emptyNode(nodeIndex: Int): LearningNode = {
new LearningNode(nodeIndex, None, None, None, false, null)
}
// The below indexing methods were copied from spark.mllib.tree.model.Node
/**
* Return the index of the left child of this node.
*/
def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
/**
* Return the index of the right child of this node.
*/
def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
/**
* Get the parent index of the given node, or 0 if it is the root.
*/
def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
/**
* Return the level of the tree that the given node is in (the root is at level 0).
*/
def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
throw new IllegalArgumentException("0 is not a valid node index.")
} else {
java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
}
/**
* Returns true if this is a left child.
* Note: Returns false for the root.
*/
def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
/**
* Return the maximum number of nodes which can be in the given level of the tree.
* @param level Level of tree (0 = root).
*/
def maxNodesInLevel(level: Int): Int = 1 << level
/**
* Return the index of the first node in the given level.
* @param level Level of tree (0 = root).
*/
def startIndexInLevel(level: Int): Int = 1 << level
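// Worked example of the 1-based heap indexing above:
//   level 0: node 1            startIndexInLevel(0) = 1, maxNodesInLevel(0) = 1
//   level 1: nodes 2, 3        leftChildIndex(1) = 2, rightChildIndex(1) = 3
//   level 2: nodes 4, 5, 6, 7  startIndexInLevel(2) = 4, maxNodesInLevel(2) = 4
// indexToLevel(5) = 2 (highestOneBit(5) = 4 = 1 << 2), parentIndex(5) = 2,
// and isLeftChild(5) = false since 5 is odd.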
/**
* Traces down from a root node to get the node with the given node index.
* This assumes the node exists.
*/
def getNode(nodeIndex: Int, rootNode: LearningNode): LearningNode = {
var tmpNode: LearningNode = rootNode
var levelsToGo = indexToLevel(nodeIndex)
while (levelsToGo > 0) {
if ((nodeIndex & (1 << (levelsToGo - 1))) == 0) {
tmpNode = tmpNode.leftChild.get
} else {
tmpNode = tmpNode.rightChild.get
}
levelsToGo -= 1
}
tmpNode
}
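// Trace (illustrative) for getNode(5, root): indexToLevel(5) = 2, so two
// steps are taken. Step 1: 5 & (1 << 1) = 0, go left (node 2).
// Step 2: 5 & (1 << 0) = 1, go right (node 5). In general, the bits of
// nodeIndex below its leading 1 spell out the root-to-node path, 0 = left
// and 1 = right, from the most significant of those bits down.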
}