All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.h2o.models.H2OModel.scala Maven / Gradle / Ivy

There is a newer version: 1.6.8
Show newest version
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.h2o.models

import java.io.File

import hex.Model
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.h2o._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util._
import org.apache.spark.ml.{Model => SparkModel}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, SQLContext}
import water.app.ModelSerializationSupport

import scala.reflect.ClassTag

/**
  * Shared implementation for H2O model pipelines.
  *
  * Wraps an H2O model so that it can be used as a Spark ML model: `transform` scores a
  * DataFrame with the wrapped model, and `write` returns a writer for persisting it.
  *
  * @param model      the wrapped H2O model used for scoring
  * @param h2oContext context used to convert between Spark DataFrames and H2O frames
  * @param sqlContext SQL context handed to the H2O-frame-to-DataFrame conversion
  * @tparam S the concrete subclass type; `copy` returns it and `write` casts `this` to it
  * @tparam M the type of the wrapped H2O model
  */
abstract class H2OModel[S <: H2OModel[S, M],
                        M <: Model[_, _, _ <: Model.Output]]
                        (val model: M, h2oContext: H2OContext, sqlContext: SQLContext)
  extends SparkModel[S] with MLWritable {

  /** Copies this model via `defaultCopy`, applying the extra params. */
  override def copy(extra: ParamMap): S = defaultCopy(extra)

  /**
    * Scores `dataset` with the wrapped H2O model.
    *
    * The dataset is converted to an H2O frame, scored, and the resulting prediction frame is
    * converted back to a DataFrame.
    *
    * @param dataset input data to score
    * @return the prediction frame as a DataFrame
    */
  override def transform(dataset: DataFrame): DataFrame = {
    val frame: H2OFrame = h2oContext.asH2OFrame(dataset)
    val prediction = model.score(frame)
    h2oContext.asDataFrame(prediction)(sqlContext)
  }

  /**
    * Builds a schema with one nullable `DoubleType` field per entry of `model._output._names`.
    *
    * The incoming `schema` is not consulted.
    *
    * @param schema input schema (unused)
    * @return schema derived from the wrapped model's output names
    */
  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    // The metadata argument is deliberately left at its default (an empty Metadata object)
    // instead of null, so that consumers which dereference the field's metadata (for example
    // when serializing the schema) do not hit a NullPointerException.
    StructType(model._output._names.map { name =>
      StructField(name, DoubleType, nullable = true)
    })
  }

  /**
    * Returns a writer that persists this model.
    *
    * `this` is cast to `S`; this relies on subclasses passing their own type as `S`.
    */
  @Since("1.6.0")
  override def write: MLWriter =  new H2OModelWriter[S](this.asInstanceOf[S])

  /** Name of the file, relative to the save path, under which the H2O model is stored. */
  def defaultFileName: String
}

/**
  * Writer for [[H2OModel]] instances.
  *
  * @param instance the model to persist
  * @tparam T concrete type of the model being written
  */
private[models] class H2OModelWriter[T <: H2OModel[T, _ <: Model[_, _, _ <: Model.Output]]](instance: T) extends MLWriter {

  /**
    * Saves the instance under `path`: first the metadata via `DefaultParamsWriter`, then the
    * H2O model itself into the file named by `instance.defaultFileName` inside `path`.
    *
    * @param path target directory of the save
    */
  @Since("1.6.0")
  override protected def saveImpl(path: String): Unit = {
    // Metadata goes first, exactly as before; the H2O model export follows.
    DefaultParamsWriter.saveMetadata(instance, path, sc)

    val destination = new File(path, instance.defaultFileName).toURI
    ModelSerializationSupport.exportH2OModel(instance.model, destination)
  }
}

/**
  * Base reader for [[H2OModel]] subclasses.
  *
  * @param defaultFileName name of the file, relative to the load path, holding the H2O model
  * @tparam T concrete pipeline model type produced by this reader
  * @tparam M type of the wrapped H2O model
  */
private[models] abstract class H2OModelReader[T <: H2OModel[T, M] : ClassTag, M <: Model[_, _, _ <: Model.Output]]
                                              (val defaultFileName: String) extends MLReader[T] {

  // Runtime class name of T, passed to the metadata loader.
  private val className = implicitly[ClassTag[T]].runtimeClass.getName

  /**
    * Loads a model from `path`.
    *
    * Steps, in order: load the metadata, load the H2O model from `defaultFileName` inside
    * `path`, obtain the H2OContext and an SQLContext, build the pipeline model through `make`,
    * and finally apply the stored params to it.
    *
    * @param path directory the model was saved to
    * @return the restored pipeline model
    * @throws RuntimeException if `H2OContext.get()` yields no context
    */
  @Since("1.6.0")
  override def load(path: String): T = {
    val storedMetadata = DefaultParamsReader.loadMetadata(path, sc, className)

    val modelUri = new File(path, defaultFileName).toURI
    val rawModel = ModelSerializationSupport.loadH2OModel[M](modelUri)

    implicit val runningContext: H2OContext = H2OContext.get().getOrElse {
      throw new RuntimeException("H2OContext has to be started in order to use H2O pipelines elements")
    }
    val sql = SQLContext.getOrCreate(sc)

    val restored = make(rawModel, storedMetadata.uid)(runningContext, sql)
    DefaultParamsReader.getAndSetParams(restored, storedMetadata)
    restored
  }

  /**
    * Creates the concrete pipeline model around a loaded H2O model.
    *
    * @param model the loaded H2O model
    * @param uid   uid taken from the loaded metadata
    * @return a new pipeline model instance
    */
  protected def make(model: M, uid: String)(implicit h2oContext: H2OContext, sqLContext: SQLContext): T
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy