All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.h2o.sparkling.ml.params.AlgoParam.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ai.h2o.sparkling.ml.params

import ai.h2o.sparkling.ml.algos._
import hex.Model
import org.apache.spark.ml.param.{Param, ParamPair, Params}
import org.json4s.JsonAST.JNull
import org.json4s.JsonDSL._
import org.json4s._
import org.json4s.jackson.JsonMethods.{compact, parse, render}

class AlgoParam(parent: Params, name: String, doc: String, isValid: H2OAlgorithm[_ <: Model.Parameters] => Boolean)
  extends Param[H2OAlgorithm[_ <: Model.Parameters]](parent, name, doc, isValid) {

  def this(parent: Params, name: String, doc: String) =
    this(parent, name, doc, _ => true)

  override def jsonEncode(value: H2OAlgorithm[_ <: Model.Parameters]): String = {
    val encoded = AlgoParam.jsonEncode(value)
    compact(render(encoded))
  }

  override def jsonDecode(json: String): H2OAlgorithm[_ <: Model.Parameters] = {
    val parsed = parse(json)
    AlgoParam.jsonDecode(parsed)
  }
}

object AlgoParam {

  def jsonEncode(value: H2OAlgorithm[_ <: Model.Parameters]): JValue = {
    if (value == null) {
      JNull
    } else {
      val algoClassName = value.getClass.getName
      val uid = value.uid
      val params = value.extractParamMap().toSeq
      val jsonParams = render(params.map {
        case ParamPair(p, v) =>
          p.name -> parse(p.jsonEncode(v))
      }.toList)

      ("class" -> algoClassName) ~ ("uid" -> uid) ~ ("paramMap" -> jsonParams)
    }
  }

  def jsonDecode(parsed: JValue): H2OAlgorithm[_ <: Model.Parameters] = {
    if (parsed == JNull) {
      null
    } else {
      implicit val format = DefaultFormats
      val className = (parsed \ "class").extract[String]
      val uid = (parsed \ "uid").extract[String]
      val jsonParams = parsed \ "paramMap"

      val algo = createH2OAlgoInstance(className, uid)
      jsonParams match {
        case JObject(pairs) =>
          pairs.foreach {
            case (paramName, jsonValue) =>
              val param = algo.getParam(paramName)
              val value = param.jsonDecode(compact(render(jsonValue)))
              algo.set(param, value)
          }
        case _ => throw new RuntimeException("Invalid JSON parameters")
      }
      algo
    }
  }

  private def createH2OAlgoInstance(algoName: String, uid: String): H2OAlgorithm[_ <: Model.Parameters] = {
    val cls = Class.forName(algoName)
    cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[H2OAlgorithm[_ <: Model.Parameters]]
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy