All downloads are free. The search and download features use the official Maven repository.

org.apache.spark.ml.h2o.models.H2OGBM.scala Maven / Gradle / Ivy

There is a newer version: 1.6.8
Show newest version
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.h2o.models

import hex.schemas.GBMV3.GBMParametersV3
import hex.tree.SharedTreeModel.SharedTreeParameters
import hex.tree.gbm.GBMModel.GBMParameters
import hex.tree.gbm.{GBM, GBMModel}
import org.apache.spark.annotation.Since
import org.apache.spark.h2o.H2OContext
import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable}
import org.apache.spark.sql.SQLContext

/**
  * Spark ML pipeline model wrapping an H2O GBM model.
  *
  * @param model      H2O GBM model wrapped by this pipeline stage
  * @param uid        unique identifier of this pipeline stage
  * @param h2oContext H2O context handed to the parent `H2OModel`
  * @param sqlContext Spark SQL context handed to the parent `H2OModel`
  */
class H2OGBMModel(model: GBMModel, override val uid: String)
                 (h2oContext: H2OContext, sqlContext: SQLContext)
  extends H2OModel[H2OGBMModel, GBMModel](model, h2oContext, sqlContext)
    with MLWritable {

  /** Wraps `model` under a freshly generated uid carrying the "gbmModel" prefix. */
  def this(model: GBMModel)(implicit h2oContext: H2OContext, sqlContext: SQLContext) =
    this(model, Identifiable.randomUID("gbmModel"))(h2oContext, sqlContext)

  /** Default file name, taken from the companion object. */
  override def defaultFileName: String = H2OGBMModel.defaultFileName
}

/**
  * Companion of the `H2OGBMModel` class: exposes the default file name and the
  * ML reader used to load an `H2OGBMModel`.
  */
object H2OGBMModel extends MLReadable[H2OGBMModel] {
  /** Default file name associated with a GBM model. */
  val defaultFileName = "gbm_model"

  /**
    * Creates a reader whose `make` hook builds an `H2OGBMModel` from the
    * H2O model, uid and contexts it is given.
    */
  @Since("1.6.0")
  override def read: MLReader[H2OGBMModel] =
    new H2OModelReader[H2OGBMModel, GBMModel](defaultFileName) {
      // The implicit parameter has to be spelled `sqlContext` so that the call below
      // forwards the argument received here rather than resolving `sqlContext` to
      // some other definition that happens to be in scope.
      override protected def make(model: GBMModel, uid: String)
                                 (implicit h2oContext: H2OContext, sqlContext: SQLContext): H2OGBMModel =
        new H2OGBMModel(model, uid)(h2oContext, sqlContext)
    }

  /** Loads an `H2OGBMModel` from `path` by delegating to the inherited `load`. */
  @Since("1.6.0")
  override def load(path: String): H2OGBMModel = super.load(path)
}

/**
  * Spark ML pipeline stage that trains an H2O GBM model.
  *
  * @param parameters optional H2O GBM parameters passed to the parent `H2OAlgorithm`
  * @param uid        unique identifier of this pipeline stage
  * @param h2oContext implicit H2O context, also used when wrapping the trained model
  * @param sqlContext implicit Spark SQL context, also used when wrapping the trained model
  */
class H2OGBM(parameters: Option[GBMParameters], override val uid: String)
            (implicit h2oContext: H2OContext, sqlContext: SQLContext)
  extends H2OAlgorithm[GBMParameters, H2OGBMModel](parameters)
    with H2OGBMParams {

  type SELF = H2OGBM

  /** Default file name, taken from the companion object. */
  override def defaultFileName: String = H2OGBM.defaultFileName

  /**
    * Builds an H2O `GBM` from `params`, starts training, waits for the result via
    * `get()` and wraps the resulting model in an `H2OGBMModel`.
    *
    * @param params H2O GBM parameters to train with
    * @return the trained model wrapped as a Spark ML pipeline model
    */
  override def trainModel(params: GBMParameters): H2OGBMModel = {
    val model = new GBM(params).trainModel().get()
    new H2OGBMModel(model)
  }

  // Generated uids use the "gbm" prefix so they identify this algorithm
  // (the earlier "dl" prefix did not match GBM).

  /** Creates the stage without explicit parameters and with a generated uid. */
  def this()(implicit h2oContext: H2OContext, sqlContext: SQLContext) =
    this(None, Identifiable.randomUID("gbm"))

  /** Creates the stage from `parameters` (wrapped with `Option(...)`) and a generated uid. */
  def this(parameters: GBMParameters)(implicit h2oContext: H2OContext, sqlContext: SQLContext) =
    this(Option(parameters), Identifiable.randomUID("gbm"))

  /** Creates the stage from `parameters` (wrapped with `Option(...)`) and the given `uid`. */
  def this(parameters: GBMParameters, uid: String)(implicit h2oContext: H2OContext, sqlContext: SQLContext) =
    this(Option(parameters), uid)
}

/**
  * Companion of the `H2OGBM` class: holds the default file name and the ML reader
  * used to load an `H2OGBM` stage.
  */
object H2OGBM extends MLReadable[H2OGBM] {

  // Default file name for GBM parameters; also read by the companion class.
  private final val defaultFileName = "gbm_params"

  /** Creates an `H2OAlgorithmReader` configured with the default file name. */
  @Since("1.6.0")
  override def read: MLReader[H2OGBM] = {
    new H2OAlgorithmReader[H2OGBM, GBMParameters](defaultFileName)
  }

  /** Loads an `H2OGBM` from `path` by delegating to the inherited `load`. */
  @Since("1.6.0")
  override def load(path: String): H2OGBM = {
    super.load(path)
  }
}


/**
  * GBM-specific parameters exposed through Spark's parameter API, added on top of
  * the tree parameters inherited from `H2OTreeParams`.
  */
trait H2OGBMParams extends H2OTreeParams[GBMParameters] {

  /** H2O schema type paired with the GBM parameters. */
  type H2O_SCHEMA = GBMParametersV3

  // Class tags for the H2O parameter type and its schema type.
  protected def paramTag = reflect.classTag[GBMParameters]
  protected def schemaTag = reflect.classTag[H2O_SCHEMA]

  /*
   * Parameter declarations. Every parameter should be declared here, together with
   * its documentation and an explanation of its default value.
   */
  final val learnRate = doubleParam("learnRate")
  final val learnRateAnnealing = doubleParam("learnRateAnnealing")
  final val colSampleRate = doubleParam("colSampleRate")
  final val maxAbsLeafnodePred = doubleParam("maxAbsLeafnodePred")
  final val responseColumn = param[String]("responseColumn")

  // Each default is read from the matching field of `parameters`.
  setDefault(
    learnRate -> parameters._learn_rate,
    learnRateAnnealing -> parameters._learn_rate_annealing,
    colSampleRate -> parameters._col_sample_rate,
    maxAbsLeafnodePred -> parameters._max_abs_leafnode_pred,
    responseColumn -> parameters._response_column)
}

/**
  * Parameters shared by H2O tree-based algorithms, exposed through Spark's
  * parameter API.
  *
  * @tparam P concrete H2O parameter type, a subtype of `SharedTreeParameters`
  */
trait H2OTreeParams[P <: SharedTreeParameters] extends H2OParams[P] {
  final val ntrees = intParam("ntrees")
  final val maxDepth = intParam("maxDepth")
  // Registered as "minRows" (not "maxRows") so that the parameter name agrees with
  // the val and with the `_min_rows` field supplying its default below.
  final val minRows = doubleParam("minRows")
  final val nbins = intParam("nbins")
  final val nbinsCat = intParam("nbinsCat")
  final val minSplitImprovement = doubleParam("minSplitImprovement")
  final val r2Stopping = doubleParam("r2Stopping")
  final val seed = longParam("seed")
  final val nbinsTopLevel = intParam("nbinsTopLevel")
  final val buildTreeOneNode = booleanParam("buildTreeOneNode")
  final val scoreTreeInterval = intParam("scoreTreeInterval")
  final val initialScoreInterval = intParam("initialScoreInterval")
  final val sampleRate = doubleParam("sampleRate")

  // Each default is read from the matching field of `parameters`.
  setDefault(
    ntrees -> parameters._ntrees,
    maxDepth -> parameters._max_depth,
    minRows -> parameters._min_rows,
    nbins -> parameters._nbins,
    nbinsCat -> parameters._nbins_cats,
    minSplitImprovement -> parameters._min_split_improvement,
    r2Stopping -> parameters._r2_stopping,
    seed -> parameters._seed,
    nbinsTopLevel -> parameters._nbins_top_level,
    buildTreeOneNode -> parameters._build_tree_one_node,
    scoreTreeInterval -> parameters._score_tree_interval,
    initialScoreInterval -> parameters._initial_score_interval,
    sampleRate -> parameters._sample_rate)

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy