All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.core.serialize.params.MapArrayParam.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package org.apache.spark.ml.param

import spray.json.{DefaultJsonProtocol, _}
import scala.collection.JavaConverters._
import scala.collection.immutable.Map
import scala.collection.mutable

/** spray-json protocol that adds a format for `Map[String, Seq[String]]`.
  *
  * It extends `DefaultJsonProtocol`, so `import MapArrayJsonProtocol._` also brings the
  * default formats into implicit scope.
  */
object MapArrayJsonProtocol extends DefaultJsonProtocol {

  /** Converts between `Map[String, Seq[String]]` and a JSON object whose field values are
    * produced and consumed by `seqFormat[String]`.
    */
  implicit object MapJsonFormat extends JsonFormat[Map[String, Seq[String]]] {

    /** Serializes every entry of `m` into a field of a `JsObject`.
      *
      * @param m map to serialize
      * @return a `JsObject` with one field per map entry
      */
    def write(m: Map[String, Seq[String]]): JsValue = {
      // Build the element format once per call instead of once per entry.
      val valueFormat = seqFormat[String]
      // Use a strict `map` instead of `mapValues`: `mapValues` yields a lazy view, so the
      // conversion (and any serialization error) would be deferred and repeated on each access.
      val fields: Map[String, JsValue] = m.map { case (key, values) =>
        // A null sequence cannot be written; report it as a serialization error.
        if (values == null) serializationError(s"Unable to serialize $values")
        key -> valueFormat.write(values)
      }
      JsObject(fields)
    }

    /** Deserializes a JSON object into a map, reading each field value with
      * `seqFormat[String]`.
      *
      * @param value JSON value to read; must be a `JsObject`
      * @return the decoded map, one entry per JSON field
      */
    def read(value: JsValue): Map[String, Seq[String]] = value match {
      case JsObject(fields) =>
        val valueFormat = seqFormat[String]
        fields.map { case (key, element) => key -> valueFormat.read(element) }
      case other =>
        // Anything that is not a JSON object is reported through spray-json's own error
        // channel rather than surfacing as a ClassCastException from a blind cast.
        deserializationError(s"Unable to deserialize $other")
    }
  }

}

/** Param for Map of String to Seq of String.
  *
  * Values are encoded to and decoded from JSON with `MapArrayJsonProtocol.MapJsonFormat`.
  *
  * @param parent  forwarded to the `Param` constructor
  * @param name    forwarded to the `Param` constructor
  * @param doc     forwarded to the `Param` constructor
  * @param isValid validation predicate forwarded to the `Param` constructor
  */
class MapArrayParam(parent: Params, name: String, doc: String, isValid: Map[String, Seq[String]] => Boolean)
  extends Param[Map[String, Seq[String]]](parent, name, doc, isValid) {
  import MapArrayJsonProtocol._

  /** Creates the param with `ParamValidators.alwaysTrue` as its validator. */
  def this(parent: Params, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue)

  /** Creates a param pair with the given value (for Java).
    *
    * Every Java list is copied into an immutable `List`, so mutating the Java collections
    * after this call does not change the value held by the returned pair.
    *
    * @param value Java map from key to list of strings
    * @return a param pair holding an immutable Scala copy of `value`
    * @throws IllegalArgumentException if a key is mapped to a null list
    */
  def w(value: java.util.HashMap[String, java.util.List[String]]): ParamPair[Map[String, Seq[String]]] = {
    val converted: Map[String, Seq[String]] = value.asScala.map { case (key, list) =>
      require(list != null, s"Null list for key '$key' cannot be used as a param value")
      // `asScala` alone is only a live wrapper around the Java list; `toList` takes a copy.
      val copied: Seq[String] = list.asScala.toList
      key -> copied
    }.toMap
    w(converted)
  }

  /** Encodes the map as pretty-printed JSON using `MapJsonFormat`. */
  override def jsonEncode(value: Map[String, Seq[String]]): String =
    MapJsonFormat.write(value).prettyPrint

  /** Parses `json` and decodes it using `MapJsonFormat`. */
  override def jsonDecode(json: String): Map[String, Seq[String]] =
    MapJsonFormat.read(json.parseJson)

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy