// org.apache.spark.ml.tree.treeParams.scala Maven / Gradle / Ivy
// The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.tree
import java.util.Locale
import scala.util.Try
import org.apache.spark.annotation.Since
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.classification.ProbabilisticClassifierParams
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
 * Shared parameters for algorithms built on decision trees.
 *
 * Note: Kept private for now; this may be made public in the future.
 */
private[ml] trait DecisionTreeParams extends PredictorParams
  with HasCheckpointInterval with HasSeed with HasWeightCol {

  /**
   * Name of the column holding leaf indices.
   * Predicted leaf index of each instance in each tree by preorder.
   * (default = "")
   * @group param
   */
  @Since("3.0.0")
  final val leafCol: Param[String] = new Param[String](
    this,
    "leafCol",
    "Leaf indices column name. Predicted leaf index of each instance in each tree by preorder")

  /**
   * Maximum depth of the tree (nonnegative).
   * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
   * The validator accepts values in [0, 30].
   * (default = 5)
   * @group param
   */
  final val maxDepth: IntParam = new IntParam(
    this,
    "maxDepth",
    "Maximum depth of the tree. (Nonnegative) E.g., depth 0 means 1 leaf node; " +
      "depth 1 means 1 internal node + 2 leaf nodes. Must be in range [0, 30].",
    ParamValidators.inRange(0, 30))

  /**
   * Maximum number of bins used for discretizing continuous features and for choosing how to
   * split on features at each node. More bins give higher granularity.
   * Must be at least 2 and at least number of categories in any categorical feature.
   * (default = 32)
   * @group param
   */
  final val maxBins: IntParam = new IntParam(
    this,
    "maxBins",
    "Max number of bins for discretizing continuous features. " +
      "Must be at least 2 and at least number of categories for any categorical feature.",
    ParamValidators.gtEq(2))

  /**
   * Minimum number of instances each child must have after split.
   * If a split causes the left or right child to have fewer than minInstancesPerNode,
   * the split will be discarded as invalid.
   * Must be at least 1.
   * (default = 1)
   * @group param
   */
  final val minInstancesPerNode: IntParam = new IntParam(
    this,
    "minInstancesPerNode",
    "Minimum number of instances each child must have after split. " +
      "If a split causes the left or right child to have fewer than minInstancesPerNode, " +
      "the split will be discarded as invalid. Must be at least 1.",
    ParamValidators.gtEq(1))

  /**
   * Minimum fraction of the weighted sample count that each child must have after split.
   * If a split causes the fraction of the total weight in the left or right child to be less
   * than minWeightFractionPerNode, the split will be discarded as invalid.
   * Should be in the interval [0.0, 0.5).
   * (default = 0.0)
   * @group param
   */
  final val minWeightFractionPerNode: DoubleParam = new DoubleParam(
    this,
    "minWeightFractionPerNode",
    "Minimum fraction of the weighted sample count that each child must have after split. " +
      "If a split causes the fraction of the total weight in the left or right child to be " +
      "less than minWeightFractionPerNode, the split will be discarded as invalid. " +
      "Should be in interval [0.0, 0.5)",
    ParamValidators.inRange(0.0, 0.5, lowerInclusive = true, upperInclusive = false))

  /**
   * Minimum information gain for a split to be considered at a tree node.
   * Should be at least 0.0.
   * (default = 0.0)
   * @group param
   */
  final val minInfoGain: DoubleParam = new DoubleParam(
    this,
    "minInfoGain",
    "Minimum information gain for a split to be considered at a tree node.",
    ParamValidators.gtEq(0.0))

  /**
   * Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be
   * split per iteration, and its aggregates may exceed this size.
   * (default = 256 MB)
   * @group expertParam
   */
  final val maxMemoryInMB: IntParam = new IntParam(
    this,
    "maxMemoryInMB",
    "Maximum memory in MB allocated to histogram aggregation.",
    ParamValidators.gtEq(0))

  /**
   * If false, the algorithm will pass trees to executors to match instances with nodes.
   * If true, the algorithm will cache node IDs for each instance.
   * Caching can speed up training of deeper trees. Users can set how often should the
   * cache be checkpointed or disable it by setting checkpointInterval.
   * (default = false)
   * @group expertParam
   */
  final val cacheNodeIds: BooleanParam = new BooleanParam(
    this,
    "cacheNodeIds",
    "If false, the algorithm will pass trees to executors to match instances with nodes. " +
      "If true, the algorithm will cache node IDs for each instance. " +
      "Caching can speed up training of deeper trees.")

  // Defaults are registered only after every Param above has been initialized.
  setDefault(
    leafCol -> "",
    maxDepth -> 5,
    maxBins -> 32,
    minInstancesPerNode -> 1,
    minWeightFractionPerNode -> 0.0,
    minInfoGain -> 0.0,
    maxMemoryInMB -> 256,
    cacheNodeIds -> false,
    checkpointInterval -> 10)

  /** @group setParam */
  @Since("3.0.0")
  final def setLeafCol(value: String): this.type = set(leafCol, value)

  /** @group getParam */
  @Since("3.0.0")
  final def getLeafCol: String = $(leafCol)

  /** @group getParam */
  final def getMaxDepth: Int = $(maxDepth)

  /** @group getParam */
  final def getMaxBins: Int = $(maxBins)

  /** @group getParam */
  final def getMinInstancesPerNode: Int = $(minInstancesPerNode)

  /** @group getParam */
  final def getMinWeightFractionPerNode: Double = $(minWeightFractionPerNode)

  /** @group getParam */
  final def getMinInfoGain: Double = $(minInfoGain)

  /** @group expertGetParam */
  final def getMaxMemoryInMB: Int = $(maxMemoryInMB)

  /** @group expertGetParam */
  final def getCacheNodeIds: Boolean = $(cacheNodeIds)

  /**
   * (private[ml]) Build a Strategy for the old API: start from the default strategy for
   * `oldAlgo`, then overwrite its fields with the arguments and the current param values.
   */
  private[ml] def getOldStrategy(
      categoricalFeatures: Map[Int, Int],
      numClasses: Int,
      oldAlgo: OldAlgo.Algo,
      oldImpurity: OldImpurity,
      subsamplingRate: Double): OldStrategy = {
    val oldStrategy = OldStrategy.defaultStrategy(oldAlgo)

    // Settings supplied by the caller.
    oldStrategy.impurity = oldImpurity
    oldStrategy.numClasses = numClasses
    oldStrategy.categoricalFeaturesInfo = categoricalFeatures
    oldStrategy.subsamplingRate = subsamplingRate

    // Settings read from this instance's params.
    oldStrategy.maxDepth = $(maxDepth)
    oldStrategy.maxBins = $(maxBins)
    oldStrategy.minInstancesPerNode = $(minInstancesPerNode)
    oldStrategy.minWeightFractionPerNode = $(minWeightFractionPerNode)
    oldStrategy.minInfoGain = $(minInfoGain)
    oldStrategy.maxMemoryInMB = $(maxMemoryInMB)
    oldStrategy.useNodeIdCache = $(cacheNodeIds)
    oldStrategy.checkpointInterval = getCheckpointInterval

    oldStrategy
  }
}
/**
 * Parameters for Decision Tree-based classification algorithms.
 */
private[ml] trait TreeClassifierParams extends Params {

  /**
   * Criterion used for information gain calculation (case-insensitive).
   * This impurity type is used in DecisionTreeClassifier and RandomForestClassifier.
   * Supported: "entropy" and "gini".
   * (default = gini)
   * @group param
   */
  final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
    " information gain calculation (case-insensitive). Supported options:" +
    s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
    (value: String) =>
      TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))

  setDefault(impurity -> "gini")

  /**
   * Current impurity setting, lowercased with `Locale.ROOT`.
   * @group getParam
   */
  final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT)

  /**
   * Convert new impurity to old impurity.
   *
   * @throws RuntimeException if the lowercased impurity is neither "entropy" nor "gini"
   */
  private[ml] def getOldImpurity: OldImpurity = {
    getImpurity match {
      case "entropy" => OldEntropy
      case "gini" => OldGini
      case unrecognized =>
        // Should never happen because of the validator on the param.
        // Report the offending value itself: interpolating `impurity` would print the
        // Param object rather than the string that failed to match.
        throw new RuntimeException(
          s"TreeClassifierParams was given unrecognized impurity: $unrecognized.")
    }
  }
}
private[ml] object TreeClassifierParams {
  /** Impurity names accepted by tree classifiers; every entry is lowercased. */
  final val supportedImpurities: Array[String] =
    Array("entropy", "gini").map(name => name.toLowerCase(Locale.ROOT))
}
/**
 * Parameters for a single decision tree classifier: the shared tree params, the classifier
 * impurity param and the probabilistic classifier params.
 */
private[ml] trait DecisionTreeClassifierParams
  extends DecisionTreeParams with TreeClassifierParams with ProbabilisticClassifierParams {

  /**
   * Runs the inherited schema validation, then appends the leaf column as a `DoubleType`
   * column when `leafCol` is set to a non-empty name.
   */
  override protected def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
    val validated = super.validateAndTransformSchema(schema, fitting, featuresDataType)
    if ($(leafCol).isEmpty) {
      validated
    } else {
      SchemaUtils.appendColumn(validated, $(leafCol), DoubleType)
    }
  }
}
/**
 * Mixin providing the `impurity` param for algorithms whose trees use variance impurity.
 */
private[ml] trait HasVarianceImpurity extends Params {

  /**
   * Criterion used for information gain calculation (case-insensitive).
   * This impurity type is used in DecisionTreeRegressor, RandomForestRegressor, GBTRegressor
   * and GBTClassifier (since GBTClassificationModel is internally composed of
   * DecisionTreeRegressionModels).
   * Supported: "variance".
   * (default = variance)
   * @group param
   */
  final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
    " information gain calculation (case-insensitive). Supported options:" +
    s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}",
    (value: String) =>
      HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))

  setDefault(impurity -> "variance")

  /**
   * Current impurity setting, lowercased with `Locale.ROOT`.
   * @group getParam
   */
  final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT)

  /**
   * Convert new impurity to old impurity.
   *
   * @throws RuntimeException if the lowercased impurity is not "variance"
   */
  private[ml] def getOldImpurity: OldImpurity = {
    getImpurity match {
      case "variance" => OldVariance
      case unrecognized =>
        // Should never happen because of the validator on the param.
        // Report the offending value itself: interpolating `impurity` would print the
        // Param object rather than the string that failed to match.
        throw new RuntimeException(
          s"TreeRegressorParams was given unrecognized impurity: $unrecognized")
    }
  }
}
private[ml] object HasVarianceImpurity {
  /** Impurity names accepted by `HasVarianceImpurity`; every entry is lowercased. */
  final val supportedImpurities: Array[String] =
    Array("variance").map(name => name.toLowerCase(Locale.ROOT))
}
/**
 * Parameters for Decision Tree-based regression algorithms.
 *
 * Declares no members of its own; everything is inherited from `HasVarianceImpurity`.
 */
private[ml] trait TreeRegressorParams extends HasVarianceImpurity
/**
 * Parameters for a single decision tree regressor: the shared tree params, the regressor
 * impurity param and the variance column param.
 */
private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
  with TreeRegressorParams with HasVarianceCol {

  /**
   * Runs the inherited schema validation, then appends, in this order, a `DoubleType`
   * variance column (when `varianceCol` is defined and non-empty) and a `DoubleType` leaf
   * column (when `leafCol` is non-empty).
   */
  override protected def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
    val validated = super.validateAndTransformSchema(schema, fitting, featuresDataType)

    val wantsVariance = isDefined(varianceCol) && $(varianceCol).nonEmpty
    val withVariance =
      if (wantsVariance) SchemaUtils.appendColumn(validated, $(varianceCol), DoubleType)
      else validated

    if ($(leafCol).isEmpty) {
      withVariance
    } else {
      SchemaUtils.appendColumn(withVariance, $(leafCol), DoubleType)
    }
  }
}
private[spark] object TreeEnsembleParams {
  /** Named feature-subset strategies accepted by tree ensembles; every entry is lowercased. */
  final val supportedFeatureSubsetStrategies: Array[String] =
    Array("auto", "all", "onethird", "sqrt", "log2").map(name => name.toLowerCase(Locale.ROOT))
}
/**
 * Parameters shared by ensembles of decision trees.
 *
 * Note: Kept private for now; this may be made public in the future.
 */
private[ml] trait TreeEnsembleParams extends DecisionTreeParams {

  /**
   * Fraction of the training data used for learning each decision tree, in range (0, 1].
   * (default = 1.0)
   * @group param
   */
  final val subsamplingRate: DoubleParam = new DoubleParam(
    this,
    "subsamplingRate",
    "Fraction of the training data used for learning each decision tree, in range (0, 1].",
    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))

  /** @group getParam */
  final def getSubsamplingRate: Double = $(subsamplingRate)

  /**
   * Create a Strategy instance to use with the old API, using the current
   * `subsamplingRate` param as the strategy's subsampling rate.
   * NOTE: The caller should set impurity and seed.
   */
  private[ml] def getOldStrategy(
      categoricalFeatures: Map[Int, Int],
      numClasses: Int,
      oldAlgo: OldAlgo.Algo,
      oldImpurity: OldImpurity): OldStrategy = {
    val rate = getSubsamplingRate
    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, rate)
  }

  /**
   * The number of features to consider for splits at each tree node.
   * Supported options:
   *  - "auto": Choose automatically for task:
   *            If numTrees == 1, set to "all."
   *            If numTrees greater than 1 (forest), set to "sqrt" for classification and
   *              to "onethird" for regression.
   *  - "all": use all features
   *  - "onethird": use 1/3 of the features
   *  - "sqrt": use sqrt(number of features)
   *  - "log2": use log2(number of features)
   *  - "n": when n is in the range (0, 1.0], use n * number of features. When n
   *         is in the range (1, number of features), use n features.
   * (default = "auto")
   *
   * These various settings are based on the following references:
   *  - log2: tested in Breiman (2001)
   *  - sqrt: recommended by Breiman manual for random forests
   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
   *    package.
   * @see Breiman (2001)
   * @see Breiman manual for random forests
   *
   * @group param
   */
  final val featureSubsetStrategy: Param[String] = new Param[String](
    this,
    "featureSubsetStrategy",
    "The number of features to consider for splits at each tree node. Supported options: " +
      TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ") +
      ", (0.0-1.0], [1-n].",
    (value: String) => {
      // Accepted forms: a known strategy name (case-insensitive), a positive integer,
      // or a fraction in (0, 1]. The numeric checks are only evaluated when needed.
      val isNamedStrategy = TreeEnsembleParams.supportedFeatureSubsetStrategies.contains(
        value.toLowerCase(Locale.ROOT))
      def isPositiveInt: Boolean = Try(value.toInt).filter(count => count > 0).isSuccess
      def isUnitFraction: Boolean =
        Try(value.toDouble).filter(frac => frac > 0 && frac <= 1.0).isSuccess
      isNamedStrategy || isPositiveInt || isUnitFraction
    })

  /**
   * Current feature subset strategy, lowercased with `Locale.ROOT`.
   * @group getParam
   */
  final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)

  // Defaults are registered only after both Params above have been initialized.
  setDefault(subsamplingRate -> 1.0, featureSubsetStrategy -> "auto")
}
/**
 * Parameters for classification algorithms built on ensembles of decision trees.
 */
private[ml] trait TreeEnsembleClassifierParams
  extends TreeEnsembleParams with ProbabilisticClassifierParams {

  /**
   * Runs the inherited schema validation, then appends the leaf column as a vector column
   * (`VectorUDT`) when `leafCol` is set to a non-empty name.
   */
  override protected def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
    val validated = super.validateAndTransformSchema(schema, fitting, featuresDataType)
    if ($(leafCol).isEmpty) {
      validated
    } else {
      SchemaUtils.appendColumn(validated, $(leafCol), new VectorUDT)
    }
  }
}
/**
 * Parameters for regression algorithms built on ensembles of decision trees.
 */
private[ml] trait TreeEnsembleRegressorParams
  extends TreeEnsembleParams {

  /**
   * Runs the inherited schema validation, then appends the leaf column as a vector column
   * (`VectorUDT`) when `leafCol` is set to a non-empty name.
   */
  override protected def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
    val validated = super.validateAndTransformSchema(schema, fitting, featuresDataType)
    if ($(leafCol).isEmpty) {
      validated
    } else {
      SchemaUtils.appendColumn(validated, $(leafCol), new VectorUDT)
    }
  }
}
/**
 * Parameters specific to Random Forest algorithms.
 */
private[ml] trait RandomForestParams extends TreeEnsembleParams {

  /**
   * Number of trees to train (at least 1).
   * If 1, then no bootstrapping is used. If greater than 1, then bootstrapping is done.
   * TODO: Change to always do bootstrapping (simpler). SPARK-7130
   * (default = 20)
   *
   * Note: This cannot live in TreeEnsembleParams (i.e. be shared by GBT and RF) because the
   * param `maxIter` controls how many trees a GBT has. The semantics in the algorithms
   * are a bit different.
   * @group param
   */
  final val numTrees: IntParam = new IntParam(
    this,
    "numTrees",
    "Number of trees to train (at least 1)",
    ParamValidators.gtEq(1))

  /** @group getParam */
  final def getNumTrees: Int = $(numTrees)

  /**
   * Whether bootstrap samples are used when building trees.
   * (default = true)
   * @group expertParam
   */
  @Since("3.0.0")
  final val bootstrap: BooleanParam = new BooleanParam(
    this,
    "bootstrap",
    "Whether bootstrap samples are used when building trees.")

  /** @group expertGetParam */
  @Since("3.0.0")
  final def getBootstrap: Boolean = $(bootstrap)

  // Defaults are registered only after both Params above have been initialized.
  setDefault(numTrees -> 20, bootstrap -> true)
}
/**
 * Parameters for Random Forest classification. Declares no members of its own; it only
 * combines RandomForestParams, TreeEnsembleClassifierParams and TreeClassifierParams.
 */
private[ml] trait RandomForestClassifierParams
extends RandomForestParams with TreeEnsembleClassifierParams with TreeClassifierParams
/**
 * Parameters for Random Forest regression. Declares no members of its own; it only
 * combines RandomForestParams, TreeEnsembleRegressorParams and TreeRegressorParams.
 */
private[ml] trait RandomForestRegressorParams
extends RandomForestParams with TreeEnsembleRegressorParams with TreeRegressorParams
/**
 * Parameters for Gradient-Boosted Tree algorithms.
 *
 * Note: Marked as private since this may be made public in the future.
 */
private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize
  with HasValidationIndicatorCol {

  /**
   * Threshold for stopping early when fit with validation is used.
   * (This parameter is ignored when fit without validation is used.)
   * The decision to stop early is decided based on this logic:
   * If the current loss on the validation set is greater than 0.01, the diff
   * of validation error is compared to relative tolerance which is
   * validationTol * (current loss on the validation set).
   * If the current loss on the validation set is less than or equal to 0.01,
   * the diff of validation error is compared to absolute tolerance which is
   * validationTol * 0.01.
   * (default = 0.01)
   * @group param
   * @see validationIndicatorCol
   */
  @Since("2.4.0")
  final val validationTol: DoubleParam = new DoubleParam(this, "validationTol",
    // Each fragment ends with a space so the sentences do not run together
    // in the rendered param documentation.
    "Threshold for stopping early when fit with validation is used. " +
    "If the error rate on the validation input changes by less than the validationTol, " +
    "then learning will stop early (before `maxIter`). " +
    "This parameter is ignored when fit without validation is used.",
    ParamValidators.gtEq(0.0)
  )

  /** @group getParam */
  @Since("2.4.0")
  final def getValidationTol: Double = $(validationTol)

  /**
   * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking
   * the contribution of each estimator.
   * (default = 0.1)
   * @group param
   */
  final override val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " +
    "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.",
    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))

  // Registered after the Params above are initialized; also overrides the inherited
  // featureSubsetStrategy default ("auto") with "all".
  setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01, featureSubsetStrategy -> "all")

  /**
   * (private[ml]) Create a BoostingStrategy instance to use with the old API.
   *
   * The underlying tree strategy is always built with `numClasses = 2` and variance impurity;
   * the loss, iteration count, step size and validation tolerance come from this instance.
   */
  private[ml] def getOldBoostingStrategy(
      categoricalFeatures: Map[Int, Int],
      oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
    // NOTE: The old API does not support "seed" so we ignore it.
    new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize, getValidationTol)
  }

  /** Get old Gradient Boosting Loss type */
  private[ml] def getOldLossType: OldLoss
}
private[ml] object GBTClassifierParams {
  /** Loss names accepted by the GBT classifier: logistic. Every entry is lowercased. */
  final val supportedLossTypes: Array[String] =
    Array("logistic").map(name => name.toLowerCase(Locale.ROOT))
}
/**
 * Parameters for Gradient-Boosted Tree classification: the GBT params, the ensemble
 * classifier params and the variance impurity param, plus the classifier loss type.
 */
private[ml] trait GBTClassifierParams
  extends GBTParams with TreeEnsembleClassifierParams with HasVarianceImpurity {

  /**
   * Loss function which GBT tries to minimize. (case-insensitive)
   * Supported: "logistic"
   * (default = logistic)
   * @group param
   */
  val lossType: Param[String] = new Param[String](
    this,
    "lossType",
    "Loss function which GBT tries to minimize (case-insensitive). Supported options: " +
      GBTClassifierParams.supportedLossTypes.mkString(", "),
    (value: String) =>
      GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT)))

  setDefault(lossType -> "logistic")

  /**
   * Current loss type, lowercased with `Locale.ROOT`.
   * @group getParam
   */
  def getLossType: String = $(lossType).toLowerCase(Locale.ROOT)

  /** (private[ml]) Convert new loss to old loss. */
  override private[ml] def getOldLossType: OldClassificationLoss = getLossType match {
    case "logistic" => OldLogLoss
    case unsupported =>
      // Should never happen because of the validator on the param.
      throw new RuntimeException(s"GBTClassifier was given bad loss type: $unsupported")
  }
}
private[ml] object GBTRegressorParams {
  /**
   * Loss names accepted by the GBT regressor: squared (L2), absolute (L1).
   * Every entry is lowercased.
   */
  final val supportedLossTypes: Array[String] =
    Array("squared", "absolute").map(name => name.toLowerCase(Locale.ROOT))
}
/**
 * Parameters for Gradient-Boosted Tree regression: the GBT params, the ensemble regressor
 * params and the regressor impurity param, plus the regressor loss type.
 */
private[ml] trait GBTRegressorParams
  extends GBTParams with TreeEnsembleRegressorParams with TreeRegressorParams {

  /**
   * Loss function which GBT tries to minimize. (case-insensitive)
   * Supported: "squared" (L2) and "absolute" (L1)
   * (default = squared)
   * @group param
   */
  val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
    " tries to minimize (case-insensitive). Supported options:" +
    s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}",
    (value: String) =>
      GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT)))

  setDefault(lossType -> "squared")

  /**
   * Current loss type, lowercased with `Locale.ROOT`.
   * @group getParam
   */
  def getLossType: String = $(lossType).toLowerCase(Locale.ROOT)

  /** (private[ml]) Convert new loss to old loss. */
  override private[ml] def getOldLossType: OldLoss = {
    convertToOldLossType(getLossType)
  }

  /**
   * Map a loss name to the old-API loss object. The match is exact, so `loss` is expected
   * to already be lowercase ("squared" or "absolute").
   *
   * @param loss loss name to convert
   * @throws RuntimeException if `loss` is not one of the supported names
   */
  private[ml] def convertToOldLossType(loss: String): OldLoss = {
    loss match {
      case "squared" => OldSquaredError
      case "absolute" => OldAbsoluteError
      case unsupported =>
        // Should never happen because of the validator on the param.
        // Report the argument that failed to match; the current `lossType` param may hold a
        // different, valid value when this method is called with an arbitrary string.
        throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $unsupported")
    }
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy