![JAR search and dependency download from the Maven repository](/logo.png)
com.tencent.angel.sona.ml.tuning.CrossValidator.scala Maven / Gradle / Ivy
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
package com.tencent.angel.sona.ml.tuning
import java.util.{Locale, List => JList}
import scala.collection.JavaConverters._
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import org.apache.hadoop.fs.Path
import com.tencent.angel.sona.ml.util._
import org.json4s.DefaultFormats
import org.apache.spark.internal.Logging
import com.tencent.angel.sona.ml.{Estimator, Model}
import com.tencent.angel.sona.ml.evaluation.Evaluator
import com.tencent.angel.sona.ml.param.{IntParam, ParamMap, ParamValidators}
import com.tencent.angel.sona.ml.param.shared.{HasCollectSubModels, HasParallelism}
import com.tencent.angel.sona.ml.util.Instrumentation.instrumented
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util._
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
private[sona] trait CrossValidatorParams extends ValidatorParams {
* Param for number of folds for cross validation. Must be >= 2.
* Default: 3
* @group param
val numFolds: IntParam = new IntParam(
  this,
  // Param name as it is registered on this instance.
  "numFolds",
  // Human-readable description; it states the ">= 2" requirement.
  "number of folds for cross validation (>= 2)",
  // Validity check supplied by ParamValidators.gtEq with a bound of 2.
  ParamValidators.gtEq(2))
/**
 * Reads the `numFolds` param of this instance through `$`.
 *
 * @group getParam
 */
def getNumFolds: Int = {
  $(numFolds)
}
// Register 3 as the default value of the numFolds param.
setDefault(numFolds -> 3)
* K-fold cross validation performs model selection by splitting the dataset into a set of
* non-overlapping randomly partitioned folds which are used as separate training and test datasets.
* E.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs,
* each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the
* test set exactly once.
class CrossValidator(override val uid: String)
extends Estimator[CrossValidatorModel]
with CrossValidatorParams with HasParallelism with HasCollectSubModels
with MLWritable with Logging {
/** Creates an instance whose uid is produced by `Identifiable.randomUID("cv")`. */
def this() = {
  this(Identifiable.randomUID("cv"))
}
/**
 * Stores `value` in the `estimator` param and returns this instance.
 *
 * @group setParam
 */
def setEstimator(value: Estimator[_]): this.type = {
  set(estimator, value)
}
/**
 * Stores `value` in the `estimatorParamMaps` param and returns this instance.
 *
 * @group setParam
 */
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = {
  set(estimatorParamMaps, value)
}
/**
 * Stores `value` in the `evaluator` param and returns this instance.
 *
 * @group setParam
 */
def setEvaluator(value: Evaluator): this.type = {
  set(evaluator, value)
}
/**
 * Stores `value` in the `numFolds` param and returns this instance.
 *
 * @group setParam
 */
def setNumFolds(value: Int): this.type = {
  set(numFolds, value)
}
/**
 * Stores `value` in the `seed` param and returns this instance.
 *
 * @group setParam
 */
def setSeed(value: Long): this.type = {
  set(seed, value)
}
* Set the maximum level of parallelism to evaluate models in parallel.
* Default is 1 for serial evaluation
* @group expertSetParam
def setParallelism(value: Int): this.type = {
  // Store the requested level in the `parallelism` param; `set` hands back this instance.
  set(parallelism, value)
}
* Whether to collect submodels when fitting. If set, we can get submodels from
* the returned model.
* Note: If this param is set, you can set the option "persistSubModels" to "true"
* before saving the returned model in order to save these submodels as well.
* See the documentation of
* {@link com.tencent.angel.sona.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter}
* for more information.
* @group expertSetParam
def setCollectSubModels(value: Boolean): this.type = {
  // Store the flag in the `collectSubModels` param; `set` hands back this instance.
  set(collectSubModels, value)
}
override def fit(dataset: Dataset[_]): CrossValidatorModel = instrumented { instr =>
val schema = dataset.schema
transformSchema(schema, logging = true)
val sparkSession = dataset.sparkSession
val est = $(estimator)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
// Create execution context based on $(parallelism)
val executionContext = getExecutionContext
instr.logParams(this, numFolds, seed, parallelism)
val collectSubModelsParam = $(collectSubModels)
var subModels: Option[Array[Array[Model[_]]]] = if (collectSubModelsParam) {
} else None
// Compute metrics for each model over each split
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.")
// Fit models in a Future for training in parallel
val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Double] {
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
if (collectSubModelsParam) {
subModels.get(splitIndex)(paramIndex) = model
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
instr.logDebug(s"Got metric $metric for model trained with $paramMap.")
// Wait for metrics to be calculated
val foldMetrics = foldMetricFutures.map(ThreadUtil.awaitResult(_, Duration.Inf))
// Unpersist training & validation set once all metrics have been produced
}.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits
instr.logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) =
if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
else metrics.zipWithIndex.minBy(_._1)
instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
instr.logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
copyValues(new CrossValidatorModel(uid, bestModel, metrics)
/** Delegates schema transformation to `transformSchemaImpl`. */
override def transformSchema(schema: StructType): StructType = {
  transformSchemaImpl(schema)
}
override def copy(extra: ParamMap): CrossValidator = {
val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
if (copied.isDefined(estimator)) {
if (copied.isDefined(evaluator)) {
// Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types.
// E.g., this may fail if a [[Param]] is an instance of an [[Estimator]].
// However, this case should be unusual.
override def write: MLWriter = {
  // Wrap this instance in the writer defined in the companion object.
  new CrossValidator.CrossValidatorWriter(this)
}
object CrossValidator extends MLReadable[CrossValidator] {
/** Returns a fresh `CrossValidatorReader`, the reader used for [[CrossValidator]]. */
override def read: MLReader[CrossValidator] = {
  new CrossValidatorReader
}
/** Loads a [[CrossValidator]] from `path` by deferring to the inherited `load`. */
override def load(path: String): CrossValidator = {
  super.load(path)
}
private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
/** Saves `instance` under `path` by delegating to `ValidatorParams.saveImpl`. */
override protected def saveImpl(path: String): Unit = {
  ValidatorParams.saveImpl(path, instance, sc)
}
private class CrossValidatorReader extends MLReader[CrossValidator] {
/** Name of the [[CrossValidator]] class; checked against metadata when loading. */
private val className: String = classOf[CrossValidator].getName
override def load(path: String): CrossValidator = {
implicit val format: DefaultFormats = DefaultFormats
val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val cv = new CrossValidator(metadata.uid)
metadata.getAndSetParams(cv, skipParams = Option(List("estimatorParamMaps")))
* CrossValidatorModel contains the model with the highest average cross-validation
* metric across folds and uses this model to transform input data. CrossValidatorModel
* also tracks the metrics for each param map evaluated.
* @param bestModel The best model selected from k-fold cross validation.
* @param avgMetrics Average cross-validation metrics for each paramMap in
* `CrossValidator.estimatorParamMaps`, in the corresponding order.
class CrossValidatorModel private[sona](
override val uid: String,
val bestModel: Model[_],
val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
/** A Python-friendly auxiliary constructor. */
private[sona] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = {
this(uid, bestModel, avgMetrics.asScala.toArray)
private var _subModels: Option[Array[Array[Model[_]]]] = None
private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]])
: CrossValidatorModel = {
_subModels = subModels
// A Python-friendly auxiliary method
private[tuning] def setSubModels(subModels: JList[JList[Model[_]]])
: CrossValidatorModel = {
_subModels = if (subModels != null) {
} else {
* @return submodels represented as a two-dimensional array. The index of the outer array
* is the fold index, and the index of the inner array corresponds to the ordering of
* estimatorParamMaps
* @throws IllegalArgumentException if subModels are not available. To retrieve subModels,
* make sure to set collectSubModels to true before fitting.
def subModels: Array[Array[Model[_]]] = {
require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " +
"to set collectSubModels to true before fitting.")
/** Whether this model currently holds sub-models, i.e. `_subModels` is not `None`. */
def hasSubModels: Boolean = _subModels.nonEmpty
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
override def transformSchema(schema: StructType): StructType = {
override def copy(extra: ParamMap): CrossValidatorModel = {
val copied = new CrossValidatorModel(
copyValues(copied, extra).setParent(parent)
override def write: CrossValidatorModel.CrossValidatorModelWriter = {
new CrossValidatorModel.CrossValidatorModelWriter(this)
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]])
: Option[Array[Array[Model[_]]]] = {
/** Returns a fresh `CrossValidatorModelReader`, the reader used for [[CrossValidatorModel]]. */
override def read: MLReader[CrossValidatorModel] = {
  new CrossValidatorModelReader
}
/** Loads a [[CrossValidatorModel]] from `path` by deferring to the inherited `load`. */
override def load(path: String): CrossValidatorModel = {
  super.load(path)
}
* Writer for CrossValidatorModel.
* @param instance CrossValidatorModel instance used to construct the writer
* CrossValidatorModelWriter supports an option "persistSubModels", with possible values
* "true" or "false". If you set the collectSubModels Param before fitting, then you can
* set "persistSubModels" to "true" in order to persist the subModels. By default,
* "persistSubModels" will be "true" when subModels are available and "false" otherwise.
* If subModels are not available, then setting "persistSubModels" to "true" will cause
* an exception.
final class CrossValidatorModelWriter private[tuning](
instance: CrossValidatorModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
val persistSubModelsParam = optionMap.getOrElse("persistsubmodels",
if (instance.hasSubModels) "true" else "false")
require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)),
s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " +
"values are \"true\" or \"false\"")
val persistSubModels = persistSubModelsParam.toBoolean
import org.json4s.JsonDSL._
val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~
("persistSubModels" -> persistSubModels)
ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
if (persistSubModels) {
require(instance.hasSubModels, "When persisting tuning models, you can only set " +
"persistSubModels to true if the tuning was done with collectSubModels set to true. " +
"To save the sub-models, try rerunning fitting with collectSubModels set to true.")
val subModelsPath = new Path(path, "subModels")
for (splitIndex <- 0 until instance.getNumFolds) {
val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}")
for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
val modelPath = new Path(splitPath, paramIndex.toString).toString
private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] {
/** Name of the [[CrossValidatorModel]] class; checked against metadata when loading. */
private val className: String = classOf[CrossValidatorModel].getName
override def load(path: String): CrossValidatorModel = {
implicit val format = DefaultFormats
val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
val persistSubModels = (metadata.metadata \ "persistSubModels")
val subModels: Option[Array[Array[Model[_]]]] = if (persistSubModels) {
val subModelsPath = new Path(path, "subModels")
val _subModels = Array.fill(numFolds)(Array.fill[Model[_]](
for (splitIndex <- 0 until numFolds) {
val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}")
for (paramIndex <- 0 until estimatorParamMaps.length) {
val modelPath = new Path(splitPath, paramIndex.toString).toString
_subModels(splitIndex)(paramIndex) =
DefaultParamsReader.loadParamsInstance(modelPath, sc)
} else None
val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
model.set(model.estimator, estimator)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps")))
© 2015 - 2025 Weber Informatics LLC | Privacy Policy