All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator.scala Maven / Gradle / Ivy

There is a newer version: 0.90
Show newest version
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable

import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.spark.ml.Predictor
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.FloatType
import org.apache.spark.sql.{Dataset, Row}
import org.json4s.DefaultFormats

/**
 * XGBoost Estimator to produce a XGBoost model.
 *
 * At construction time every entry of `xgboostParams` whose key equals the name of a
 * param declared on this estimator is copied onto that param (numeric and boolean params
 * are parsed from the value's string form); entries with unknown keys are ignored.
 * If `eval_metric` is not supplied, a default derived from the objective is set.
 *
 * @param uid           unique identifier of this estimator instance
 * @param xgboostParams user supplied XGBoost parameters, keyed by parameter name
 */
class XGBoostEstimator private[spark](
  override val uid: String, xgboostParams: Map[String, Any])
  extends Predictor[Vector, XGBoostEstimator, XGBoostModel]
  with LearningTaskParams with GeneralParams with BoosterParams with MLWritable {

  /** Creates an estimator with a random uid and the given XGBoost parameters. */
  def this(xgboostParams: Map[String, Any]) =
    this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams)

  /** Creates an estimator with the given uid and no user supplied parameters. */
  def this(uid: String) = this(uid, Map[String, Any]())

  /**
   * Picks the evaluation metric used when the caller did not provide `eval_metric`.
   *
   * Called in fromXGBParamMapToParams only when eval_metric is not defined.
   *
   * @return "rmse" when no objective is given or for non-ranking regression, "map" for
   *         objectives starting with "rank", "merror" for classification objectives
   *         starting with "multi", and "error" for any other classification objective
   */
  private def setupDefaultEvalMetric(): String = {
    // "objective" wins over "obj_type"; a key that is present but mapped to null is
    // treated exactly like a missing objective.
    val objective = xgboostParams.get("objective")
      .orElse(xgboostParams.get("obj_type"))
      .flatMap(Option(_))
    objective.map(_.toString) match {
      case None =>
        "rmse"
      case Some(objectiveName) =>
        // compute default metric based on specified objective
        if (XGBoost.isClassificationTask(xgboostParams)) {
          // default metric for classification: multi-class vs. binary
          if (objectiveName.startsWith("multi")) {
            "merror"
          } else {
            "error"
          }
        } else {
          // default metric for regression or ranking
          if (objectiveName.startsWith("rank")) {
            "map"
          } else {
            "rmse"
          }
        }
    }
  }

  /**
   * Copies the user supplied parameter map onto the params declared on this estimator.
   *
   * Values targeting typed numeric/boolean params are converted through their string
   * form, so e.g. `"0.1"`, `0.1f` and `0.1d` are all accepted for a double param.
   * A malformed value therefore fails here with the parser's exception.
   */
  private def fromXGBParamMapToParams(): Unit = {
    for ((paramName, paramValue) <- xgboostParams) {
      params.find(_.name == paramName) match {
        case None =>
          // not a param known to this estimator: silently ignored
        case Some(_: DoubleParam) =>
          set(paramName, paramValue.toString.toDouble)
        case Some(_: BooleanParam) =>
          set(paramName, paramValue.toString.toBoolean)
        case Some(_: IntParam) =>
          set(paramName, paramValue.toString.toInt)
        case Some(_: FloatParam) =>
          set(paramName, paramValue.toString.toFloat)
        case Some(_: LongParam) =>
          // Same coercion as the other numeric params; without it a long-typed param
          // would end up holding whatever boxed type the caller passed (e.g. an Int).
          set(paramName, paramValue.toString.toLong)
        case Some(_: Param[_]) =>
          set(paramName, paramValue)
      }
    }
    if (!xgboostParams.contains("eval_metric")) {
      set("eval_metric", setupDefaultEvalMetric())
    }
  }

  // Runs as part of construction so the params reflect `xgboostParams` immediately.
  fromXGBParamMapToParams()

  /**
   * Builds the parameter map handed to XGBoost from the current value of every param
   * declared on this estimator.
   *
   * @return a map from param name to its current value; "num_class" is removed when the
   *         task is not a classification task or when `numClasses` is 2
   */
  private[spark] def fromParamsToXGBParamMap: Map[String, Any] = {
    // Mutable accumulator kept strictly local; only the immutable snapshot escapes.
    val collected = new mutable.HashMap[String, Any]()
    for (param <- params) {
      collected += param.name -> $(param)
    }
    val allParams = collected.toMap
    val keepNumClass = XGBoost.isClassificationTask(allParams) && $(numClasses) != 2
    if (keepNumClass) {
      allParams
    } else {
      allParams - "num_class"
    }
  }

  /**
   * Guarantees that the base-margin and weight columns exist in the dataset.
   *
   * A missing base-margin column is added filled with `Float.NaN`, a missing weight
   * column is added filled with `1.0`. Presence is checked against the columns of the
   * incoming dataset for both columns.
   */
  private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
    val existingColumns = trainingSet.columns
    val withBaseMargin: Dataset[_] =
      if (existingColumns.contains($(baseMarginCol))) {
        trainingSet
      } else {
        trainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
      }
    if (existingColumns.contains($(weightCol))) {
      withBaseMargin
    } else {
      withBaseMargin.withColumn($(weightCol), lit(1.0))
    }
  }

  /**
   * produce a XGBoostModel by fitting the given dataset
   *
   * The features, label, base-margin and weight columns are selected (the last three
   * cast to float), converted to XGBoost labeled points and passed to
   * `XGBoost.trainDistributed` together with the derived parameter map.
   *
   * @param trainingSet dataset holding at least the features and label columns
   * @return the trained model with this estimator as parent and its param values copied
   *         over; for classification tasks `numOfClasses` is set from `numClasses`
   */
  override def train(trainingSet: Dataset[_]): XGBoostModel = {
    // NOTE(review): the Row pattern below only matches non-null Float values, so a row
    // with a null label/base margin/weight fails with a MatchError inside the Spark task.
    val instances = ensureColumns(trainingSet).select(
      col($(featuresCol)),
      col($(labelCol)).cast(FloatType),
      col($(baseMarginCol)).cast(FloatType),
      col($(weightCol)).cast(FloatType)
    ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
      val (indices, values) = features match {
        case sparse: SparseVector => (sparse.indices, sparse.values.map(_.toFloat))
        // dense vectors carry no index array; presumably null indices mean "dense" to
        // XGBLabeledPoint - confirm against its definition
        case dense: DenseVector => (null, dense.values.map(_.toFloat))
      }
      XGBLabeledPoint(label, indices, values, baseMargin = baseMargin, weight = weight)
    }
    // Kept after the (lazy) RDD definition so failures on an invalid schema surface
    // exactly as before.
    transformSchema(trainingSet.schema, logging = true)
    val derivedXGBoosterParamMap = fromParamsToXGBParamMap
    val trainedModel = XGBoost.trainDistributed(instances, derivedXGBoosterParamMap,
      $(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
      $(missing)).setParent(this)
    val returnedModel = copyValues(trainedModel, extractParamMap())
    if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
      returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
    }
    returnedModel
  }

  /** Returns a copy of this estimator with `extra` applied on top of its params. */
  override def copy(extra: ParamMap): XGBoostEstimator = {
    defaultCopy(extra).asInstanceOf[XGBoostEstimator]
  }

  /** Returns the writer used to persist this estimator. */
  override def write: MLWriter = new XGBoostEstimator.XGBoostEstimatorWriter(this)
}

/**
 * Companion providing persistence (save/load) support for [[XGBoostEstimator]].
 */
object XGBoostEstimator extends MLReadable[XGBoostEstimator] {

  /** Returns the reader used to load a persisted estimator. */
  override def read: MLReader[XGBoostEstimator] = new XGBoostEstimatorReader

  /** Loads an estimator previously saved at `path`. */
  override def load(path: String): XGBoostEstimator = super.load(path)

  /**
   * Writer that persists the estimator's metadata through `DefaultXGBoostParamsWriter`.
   *
   * Saving is rejected when a custom evaluator or a custom objective is set.
   */
  private[XGBoostEstimator] class XGBoostEstimatorWriter(instance: XGBoostEstimator)
    extends MLWriter {
    override protected def saveImpl(path: String): Unit = {
      // Build the parameter map once; it walks every param of the estimator, so there
      // is no point in rebuilding it for each of the two lookups.
      val xgbParams = instance.fromParamsToXGBParamMap
      require(xgbParams("custom_eval") == null && xgbParams("custom_obj") == null,
        "we do not support persist XGBoostEstimator with customized evaluator and objective" +
          " function for now")
      implicit val format = DefaultFormats
      implicit val sc = super.sparkSession.sparkContext
      DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
    }
  }

  /**
   * Reader that re-creates an estimator from metadata stored at a path.
   */
  private class XGBoostEstimatorReader extends MLReader[XGBoostEstimator] {

    /**
     * Instantiates the class recorded in the metadata through its single-`String`
     * (uid) constructor, then restores the persisted param values onto it.
     */
    override def load(path: String): XGBoostEstimator = {
      val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
      val cls = Utils.classForName(metadata.className)
      val instance =
        cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
      DefaultXGBoostParamsReader.getAndSetParams(instance, metadata)
      instance.asInstanceOf[XGBoostEstimator]
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy