All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.h2o.models.H2ODeepLearning.scala Maven / Gradle / Ivy

There is a newer version: 1.6.8
Show newest version
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.h2o.models

import hex.deeplearning.{DeepLearning, DeepLearningModel}
import hex.deeplearning.DeepLearningModel.DeepLearningParameters
import hex.schemas.DeepLearningV3.DeepLearningParametersV3
import org.apache.spark.annotation.Since
import org.apache.spark.h2o.H2OContext
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.SQLContext


/**
 * Spark ML pipeline wrapper around an H2O [[DeepLearningModel]].
 *
 * @param model      the underlying H2O deep learning model
 * @param uid        unique identifier of this pipeline component
 * @param h2oContext H2O context forwarded to the [[H2OModel]] base class
 * @param sqlContext SQL context forwarded to the [[H2OModel]] base class
 */
class H2ODeepLearningModel(
    model: DeepLearningModel,
    override val uid: String
)(
    h2oContext: H2OContext,
    sqlContext: SQLContext
) extends H2OModel[H2ODeepLearningModel, DeepLearningModel](model, h2oContext, sqlContext)
  with MLWritable {

  /**
   * Wraps `model` under a uid generated by `Identifiable.randomUID("dlModel")`,
   * taking both contexts from implicit scope.
   */
  def this(model: DeepLearningModel)(implicit h2oContext: H2OContext, sqlContext: SQLContext) =
    this(model, Identifiable.randomUID("dlModel"))(h2oContext, sqlContext)

  /** File name for this model; delegates to the constant held by the companion object. */
  override def defaultFileName: String = H2ODeepLearningModel.defaultFileName
}

/**
 * Companion object exposing the `MLReadable` entry points (`read` / `load`)
 * for [[H2ODeepLearningModel]].
 */
object H2ODeepLearningModel extends MLReadable[H2ODeepLearningModel] {

  /** File name handed to the reader created in [[read]] and returned by the model's `defaultFileName`. */
  val defaultFileName: String = "dl_model"

  /**
   * Creates a reader that rebuilds an [[H2ODeepLearningModel]] from a loaded
   * H2O [[DeepLearningModel]] and its uid.
   */
  @Since("1.6.0")
  override def read: MLReader[H2ODeepLearningModel] =
    new H2OModelReader[H2ODeepLearningModel, DeepLearningModel](defaultFileName) {
      // The implicit parameter used to be misspelled `sqLContext`, so it was never used and the
      // `sqlContext` below resolved to a different, outer-scope member (presumably the reader's
      // own). With the name corrected, the context supplied by the caller is the one forwarded.
      override protected def make(model: DeepLearningModel, uid: String)
                                 (implicit h2oContext: H2OContext, sqlContext: SQLContext): H2ODeepLearningModel =
        new H2ODeepLearningModel(model, uid)(h2oContext, sqlContext)
    }

  /** Loads a model from `path`; delegates to the inherited `MLReadable.load`. */
  @Since("1.6.0")
  override def load(path: String): H2ODeepLearningModel = super.load(path)
}


/**
  * Pipeline stage that trains an H2O deep learning model and wraps the result
  * in an [[H2ODeepLearningModel]].
  *
  * NOTE(review): according to the original description, if a training key has been
  * set via `setTrainKey`, the frame with that key is used as the training frame;
  * otherwise the frame produced by the previous stage is used. That selection is not
  * implemented in this class - confirm against `H2OAlgorithm`.
  *
  * @param parameters optional H2O parameters forwarded to the [[H2OAlgorithm]] base class
  * @param uid        unique identifier of this pipeline stage
  */
class H2ODeepLearning(
    parameters: Option[DeepLearningParameters],
    override val uid: String
)(implicit h2oContext: H2OContext, sqlContext: SQLContext)
  extends H2OAlgorithm[DeepLearningParameters, H2ODeepLearningModel](parameters)
  with H2ODeepLearningParams {

  type SELF = H2ODeepLearning

  /** Creates a stage without explicit parameters and with a generated "dl" uid. */
  def this()(implicit h2oContext: H2OContext, sqlContext: SQLContext) =
    this(None, Identifiable.randomUID("dl"))

  /** Creates a stage from the given parameters (a `null` value becomes `None`) with a generated "dl" uid. */
  def this(parameters: DeepLearningParameters)(implicit h2oContext: H2OContext, sqlContext: SQLContext) =
    this(Option(parameters), Identifiable.randomUID("dl"))

  /** Creates a stage from the given parameters (a `null` value becomes `None`) and an explicit uid. */
  def this(parameters: DeepLearningParameters, uid: String)
          (implicit h2oContext: H2OContext, sqlContext: SQLContext) =
    this(Option(parameters), uid)

  /** File name for this stage; delegates to the constant held by the companion object. */
  override def defaultFileName: String = H2ODeepLearning.defaultFileName

  /**
    * Builds an H2O `DeepLearning` job from `params`, waits for its result via `get()`
    * and wraps the resulting model.
    *
    * @param params H2O deep learning parameters to train with
    * @return the trained model wrapped as an [[H2ODeepLearningModel]]
    */
  override def trainModel(params: DeepLearningParameters): H2ODeepLearningModel = {
    val trainingJob = new DeepLearning(params).trainModel()
    new H2ODeepLearningModel(trainingJob.get())
  }

  // Each setter below records the value in the Spark param and mirrors it into
  // the corresponding field of the H2O parameters object returned by `getParams`.

  /** @group setParam */
  def setEpochs(value: Double) = set(epochs, value) {
    getParams._epochs = value
  }

  /** @group setParam */
  def setL1(value: Double) = set(l1, value) {
    getParams._l1 = value
  }

  /** @group setParam */
  def setL2(value: Double) = set(l2, value) {
    getParams._l2 = value
  }

  /** @group setParam */
  def setHidden(value: Array[Int]) = set(hidden, value) {
    getParams._hidden = value
  }

  /** @group setParam */
  def setResponseColumn(value: String) = set(responseColumn, value) {
    getParams._response_column = value
  }
}

/**
  * Companion object exposing the `MLReadable` entry points (`read` / `load`)
  * for [[H2ODeepLearning]].
  */
object H2ODeepLearning extends MLReadable[H2ODeepLearning] {

  // File name passed to the reader built in `read`.
  private final val defaultFileName = "dl_params"

  /** Creates an `H2OAlgorithmReader` configured with this object's default file name. */
  @Since("1.6.0")
  override def read: MLReader[H2ODeepLearning] = {
    new H2OAlgorithmReader[H2ODeepLearning, DeepLearningParameters](defaultFileName)
  }

  /** Loads a stage from `path`; delegates to the inherited `MLReadable.load`. */
  @Since("1.6.0")
  override def load(path: String): H2ODeepLearning = {
    super.load(path)
  }
}
/**
  * Spark ML params for H2O deep learning. Each param declared here has its default
  * taken from the matching field of the H2O `DeepLearningParameters` object
  * exposed as `parameters`.
  */
trait H2ODeepLearningParams extends H2OParams[DeepLearningParameters] {

  /** H2O schema class associated with [[DeepLearningParameters]]. */
  type H2O_SCHEMA = DeepLearningParametersV3

  /** Class tag of the H2O parameters type. */
  protected def paramTag: reflect.ClassTag[DeepLearningParameters] =
    reflect.classTag[DeepLearningParameters]

  /** Class tag of the H2O schema type. */
  protected def schemaTag: reflect.ClassTag[H2O_SCHEMA] =
    reflect.classTag[H2O_SCHEMA]

  //
  // Param declarations. New params belong here, together with their
  // documentation and an explanation of their default values.
  //
  final val epochs = doubleParam("epochs")
  final val l1 = doubleParam("l1")
  final val l2 = doubleParam("l2")
  final val hidden = new IntArrayParam(this, "hidden", doc("hidden"))
  final val responseColumn = param[String]("responseColumn")

  // Defaults come from the fields of the same name on `parameters`.
  setDefault(
    epochs -> parameters._epochs,
    l1 -> parameters._l1,
    l2 -> parameters._l2,
    hidden -> parameters._hidden,
    responseColumn -> parameters._response_column
  )

  /** @group getParam */
  def getEpochs: Double = $(epochs)

  /** @group getParam */
  def getL1: Double = $(l1)

  /** @group getParam */
  def getL2: Double = $(l2)

  /** @group getParam */
  def getHidden: Array[Int] = $(hidden)

  /** @group getParam */
  def getResponseColumn: String = $(responseColumn)

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy