All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.core.serialize.params.ArrayMapParam.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package org.apache.spark.ml.param

import org.apache.spark.ml.util.Identifiable
import spray.json.{DefaultJsonProtocol, _}

import scala.collection.immutable.Map

object ArrayMapJsonProtocol extends DefaultJsonProtocol {

  implicit object MapJsonFormat extends JsonFormat[Map[String, Any]] {
    def write(m: Map[String, Any]): JsValue = {
      JsObject(m.mapValues {
        case v: Int => JsNumber(v)
        case v: Double => JsNumber(v)
        case v: String => JsString(v)
        case true => JsTrue
        case false => JsFalse
        case v: Map[_, _] => write(v.asInstanceOf[Map[String, Any]])
        case default => serializationError(s"Unable to serialize $default")
      })
    }

    def read(value: JsValue): Map[String, Any] = value.asInstanceOf[JsObject].fields.map(kvp => {
      val convValue = kvp._2 match {
        case JsNumber(n) => if (n.isValidInt) n.intValue().asInstanceOf[Any] else n.toDouble.asInstanceOf[Any]
        case JsString(s) => s
        case JsTrue => true
        case JsFalse => false
        case v: JsValue => read(v)
        case default => deserializationError(s"Unable to deserialize $default")
      }
      (kvp._1, convValue)
    })
  }

}

/** Param for Array of stage parameter maps. */
class ArrayMapParam(parent: String, name: String, doc: String, isValid: Array[Map[String, Any]] => Boolean)
  extends Param[Array[Map[String, Any]]](parent, name, doc, isValid) {

  import ArrayMapJsonProtocol._

  def this(parent: Params, name: String, doc: String, isValid: Array[Map[String, Any]] => Boolean) =
      this(parent.uid, name, doc, isValid)

    def this(parent: Params, name: String, doc: String) =
      this(parent, name, doc, ParamValidators.alwaysTrue)

    def this(parent: String, name: String, doc: String) =
      this(parent, name, doc, ParamValidators.alwaysTrue)

    def this(parent: Identifiable, name: String, doc: String, isValid: Array[Map[String, Any]] => Boolean) =
      this(parent.uid, name, doc, isValid)

    def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

    /** Creates a param pair with the given value (for Java). */
    override def w(value: Array[Map[String, Any]]): ParamPair[Array[Map[String, Any]]] = super.w(value)

    override def jsonEncode(value: Array[Map[String, Any]]): String = {
      val json = value.toSeq.toJson
      json.prettyPrint
    }

    override def jsonDecode(json: String): Array[Map[String, Any]] = {
      val jsonValue = json.parseJson
      jsonValue.convertTo[Seq[Map[String, Any]]].toArray
    }

  }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy