/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// scalastyle:off println
package org.apache.spark.examples.ml

import java.util.Locale

import scala.collection.mutable

import scopt.OptionParser

import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * An example runner for Gradient-Boosted Trees (GBTs). Run with
* {{{
* ./bin/run-example ml.GBTExample [options]
* }}}
* Decision Trees and ensembles can take a large amount of memory. If the run-example command
* above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
* For local mode, run
* {{{
* ./bin/spark-submit --class org.apache.spark.examples.ml.GBTExample --driver-memory 1g
* [examples JAR path] [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
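 * For example, with illustrative option values and the sample dataset that ships with
 * the Spark source distribution:
 * {{{
 * ./bin/run-example ml.GBTExample --algo classification --maxIter 20 \
 *   data/mllib/sample_libsvm_data.txt
 * }}}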
*/
object GBTExample {
case class Params(
input: String = null,
testInput: String = "",
dataFormat: String = "libsvm",
algo: String = "classification",
maxDepth: Int = 5,
maxBins: Int = 32,
minInstancesPerNode: Int = 1,
minInfoGain: Double = 0.0,
maxIter: Int = 10,
fracTest: Double = 0.2,
cacheNodeIds: Boolean = false,
checkpointDir: Option[String] = None,
checkpointInterval: Int = 10) extends AbstractParams[Params]

  def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("GBTExample") {
head("GBTExample: an example Gradient-Boosted Trees app.")
opt[String]("algo")
.text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
.action((x, c) => c.copy(algo = x))
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
.action((x, c) => c.copy(maxDepth = x))
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
opt[Int]("minInstancesPerNode")
        .text(s"min number of instances each child node must have for a split to be made," +
          s" default: ${defaultParams.minInstancesPerNode}")
.action((x, c) => c.copy(minInstancesPerNode = x))
opt[Double]("minInfoGain")
.text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
.action((x, c) => c.copy(minInfoGain = x))
opt[Int]("maxIter")
.text(s"number of trees in ensemble, default: ${defaultParams.maxIter}")
.action((x, c) => c.copy(maxIter = x))
opt[Double]("fracTest")
        .text(s"fraction of data to hold out for testing. If option testInput is given, " +
          s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
opt[Boolean]("cacheNodeIds")
.text(s"whether to use node Id cache during training, " +
s"default: ${defaultParams.cacheNodeIds}")
.action((x, c) => c.copy(cacheNodeIds = x))
opt[String]("checkpointDir")
.text(s"checkpoint directory where intermediate node Id caches will be stored, " +
s"default: ${
defaultParams.checkpointDir match {
case Some(strVal) => strVal
case None => "None"
}
}")
.action((x, c) => c.copy(checkpointDir = Some(x)))
opt[Int]("checkpointInterval")
.text(s"how often to checkpoint the node Id cache, " +
s"default: ${defaultParams.checkpointInterval}")
.action((x, c) => c.copy(checkpointInterval = x))
opt[String]("testInput")
.text(s"input path to test dataset. If given, option fracTest is ignored." +
s" default: ${defaultParams.testInput}")
.action((x, c) => c.copy(testInput = x))
opt[String]("dataFormat")
.text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
.action((x, c) => c.copy(dataFormat = x))
      arg[String]("<input>")
.text("input path to labeled examples")
.required()
.action((x, c) => c.copy(input = x))
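      // checkConfig runs after all options are parsed, validating the assembled Params.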
checkConfig { params =>
if (params.fracTest < 0 || params.fracTest >= 1) {
failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
} else {
success
}
}
}
parser.parse(args, defaultParams) match {
case Some(params) => run(params)
case _ => sys.exit(1)
}
}

  def run(params: Params): Unit = {
val spark = SparkSession
.builder
.appName(s"GBTExample with $params")
.getOrCreate()
params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir)
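    // Locale.ROOT makes the lowercasing locale-independent (e.g. avoids the Turkish dotless i).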
val algo = params.algo.toLowerCase(Locale.ROOT)
println(s"GBTExample with parameters:\n$params")
// Load training and test data and cache it.
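    // DecisionTreeExample.loadDatasets reads the input in the given format and, when no
    // separate testInput is provided, randomly splits it into training/test by fracTest.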
val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(params.input,
params.dataFormat, params.testInput, algo, params.fracTest)
// Set up Pipeline
val stages = new mutable.ArrayBuffer[PipelineStage]()
// (1) For classification, re-index classes.
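    // GBTClassifier supports binary classification and expects labels indexed as 0.0/1.0;
    // StringIndexer assigns indices by frequency (most frequent label -> 0.0).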
val labelColName = if (algo == "classification") "indexedLabel" else "label"
if (algo == "classification") {
val labelIndexer = new StringIndexer()
.setInputCol("label")
.setOutputCol(labelColName)
stages += labelIndexer
}
// (2) Identify categorical features using VectorIndexer.
// Features with more than maxCategories values will be treated as continuous.
val featuresIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(10)
stages += featuresIndexer
// (3) Learn GBT.
    val gbt = algo match {
case "classification" =>
new GBTClassifier()
.setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
.setMinInstancesPerNode(params.minInstancesPerNode)
.setMinInfoGain(params.minInfoGain)
.setCacheNodeIds(params.cacheNodeIds)
.setCheckpointInterval(params.checkpointInterval)
.setMaxIter(params.maxIter)
case "regression" =>
new GBTRegressor()
.setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
.setMinInstancesPerNode(params.minInstancesPerNode)
.setMinInfoGain(params.minInfoGain)
.setCacheNodeIds(params.cacheNodeIds)
.setCheckpointInterval(params.checkpointInterval)
.setMaxIter(params.maxIter)
case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
    stages += gbt
val pipeline = new Pipeline().setStages(stages.toArray)
// Fit the Pipeline.
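    // fit() first fits the indexer stages, then trains the GBT on the transformed data,
    // bundling all fitted stages into a single PipelineModel.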
val startTime = System.nanoTime()
val pipelineModel = pipeline.fit(training)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
// Get the trained GBT from the fitted PipelineModel.
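    // The GBT is the last stage added above, so pipelineModel.stages.last is the trained model.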
algo match {
case "classification" =>
        val gbtModel = pipelineModel.stages.last.asInstanceOf[GBTClassificationModel]
        if (gbtModel.totalNumNodes < 30) {
          println(gbtModel.toDebugString) // Print full model.
        } else {
          println(gbtModel) // Print model summary.
        }
case "regression" =>
        val gbtModel = pipelineModel.stages.last.asInstanceOf[GBTRegressionModel]
        if (gbtModel.totalNumNodes < 30) {
          println(gbtModel.toDebugString) // Print full model.
        } else {
          println(gbtModel) // Print model summary.
        }
case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
// Evaluate model on training, test data.
algo match {
case "classification" =>
println("Training data results:")
DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
println("Test data results:")
DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
case "regression" =>
println("Training data results:")
DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
println("Test data results:")
DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
case _ =>
throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
spark.stop()
}
}
// scalastyle:on println