All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.odkl.JacksonParam.scala Maven / Gradle / Ivy

package org.apache.spark.ml.odkl

import com.fasterxml.jackson.core.{JsonGenerator, JsonParser}
import com.fasterxml.jackson.databind.deser.std.PrimitiveArrayDeserializers
import com.fasterxml.jackson.databind.module.SimpleModule
import com.fasterxml.jackson.databind._
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import odkl.analysis.spark.util.Logging
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.DenseVector

import scala.reflect.ClassTag

/**
  * ml.odkl is an extension to the Spark ML package with the intention to
  * 1. Provide a modular structure with shared and tested common code
  * 2. Add the ability to create train-only transformations (for better prediction performance)
  * 3. Unify extra information generation by the model fitters
  * 4. Support combined models with an option for parallel training.
  *
  * This particular file contains a utility for serializing complex parameters using Jackson (it handles
  * a few types automatically which cannot be handled by json4s)
  */


class JacksonParam[T] (
                        parent: String,
                        name: String,
                        doc: String,
                        isValid: T => Boolean,
                        default: Option[T]
                      )(implicit ct: ClassTag[T])
  extends Param[T](parent, name, doc, isValid) with Logging {

  def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean)(implicit ct: ClassTag[T]) = {
    this(parent.uid, name, doc, isValid, None)
  }

  def this(parent: String, name: String, doc: String)(implicit ct: ClassTag[T]) = {
    this(parent, name, doc, (x: T) => true, None)
  }

  def this(parent: Identifiable, name: String, doc: String)(implicit ct: ClassTag[T]) = {
    this(parent.uid, name, doc)
  }

  override def jsonEncode(value: T): String = {
    JacksonParam.objectMapper.writeValueAsString(value)
  }

  override def jsonDecode(json: String): T = {
    try {
      JacksonParam.objectMapper.readValue[T](json, ct.runtimeClass.asInstanceOf[Class[T]])
    } catch {
      case e: Throwable =>
        logError(s"Failed to read param $name from data $json due error", e)
        default.get
    }
  }

}

/**
  * Factories for [[JacksonParam]] instances together with the Jackson mapper they share.
  */
object JacksonParam extends Serializable {

  /**
    * Creates a parameter owned by `parent`, with an optional validator and an optional
    * fallback value used when decoding fails.
    */
  def apply[T](parent: Identifiable, name: String, doc: String,
               isValid: (T) => Boolean = (x: T) => true,
               default: Option[T] = None)(implicit ct: ClassTag[T]) =
    new JacksonParam[T](parent.uid, name, doc, isValid, default)

  /**
    * Creates a parameter holding a string-keyed map; unless overridden, the decode
    * fallback is an empty map.
    */
  def mapParam[V](parent: Identifiable, name: String, doc: String,
                  isValid: (Map[String, V]) => Boolean = (x: Map[String, V]) => true,
                  default: Option[Map[String, V]] = Some(Map[String,V]()))(implicit ct: ClassTag[Map[String, V]]) =
    new JacksonParam[Map[String, V]](parent.uid, name, doc, isValid, default)

  /**
    * Creates a parameter holding an array; any value is accepted and the decode fallback
    * is an empty array.
    */
  def arrayParam[V](parent: Identifiable, name: String, doc: String)(implicit ct: ClassTag[Array[V]], cv: ClassTag[V]) = {
    val acceptAll = (_: Array[V]) => true
    val emptyFallback = Some[Array[V]](Array[V]())
    new JacksonParam[Array[V]](parent.uid, name, doc, acceptAll, emptyFallback)
  }

  /** Mapper shared by all parameters; built on first access. */
  lazy val objectMapper: ObjectMapper = createDefaultMapper

  /** Builds a mapper with the Scala module and the Spark-specific module registered. */
  private def createDefaultMapper: ObjectMapper =
    new ObjectMapper()
      .registerModule(DefaultScalaModule)
      .registerModule(sparkModule)

  /** Builds a fresh Jackson module carrying the [[DenseVector]] serializer and deserializer. */
  def sparkModule: Module =
    new SimpleModule("SparkJson")
      .addSerializer(classOf[DenseVector], DenseVectorSerializer)
      .addDeserializer(classOf[DenseVector], DenseVectorDeserializer)

  /** Writes a [[DenseVector]] as a JSON array of numbers, one entry per component. */
  private object DenseVectorSerializer extends JsonSerializer[DenseVector] {

    override def serialize(value: DenseVector, gen: JsonGenerator, serializers: SerializerProvider): Unit = {
      val length = value.size
      gen.writeStartArray(length)
      (0 until length).foreach { index => gen.writeNumber(value(index)) }
      gen.writeEndArray()
    }
  }

  /** Reads a [[DenseVector]] by delegating to Jackson's primitive double array deserializer. */
  private object DenseVectorDeserializer extends JsonDeserializer[DenseVector] {

    // Deserializer Jackson provides for the primitive double type.
    private val arrayDeserializer = PrimitiveArrayDeserializers.forType(java.lang.Double.TYPE)

    override def deserialize(p: JsonParser, ctxt: DeserializationContext): DenseVector =
      new DenseVector(arrayDeserializer.deserialize(p, ctxt).asInstanceOf[Array[Double]])
  }

}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy