All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.hydrosphere.spark_ml_serving.preprocessors.LocalChiSqSelectorModel.scala Maven / Gradle / Ivy

package io.hydrosphere.spark_ml_serving.preprocessors

import io.hydrosphere.spark_ml_serving.TypedTransformerConverter
import io.hydrosphere.spark_ml_serving.common._
import org.apache.spark.ml.feature.ChiSqSelectorModel
import org.apache.spark.mllib.feature.{ChiSqSelectorModel => OldChiSqSelectorModel}

/**
  * Local counterpart of [[ChiSqSelectorModel]] that applies the selection to
  * [[LocalData]] columns.
  *
  * @param sparkTransformer the Spark ML model being wrapped
  */
class LocalChiSqSelectorModel(override val sparkTransformer: ChiSqSelectorModel)
  extends LocalTransformer[ChiSqSelectorModel] {

  /**
    * The mllib selector stored inside the wrapped model.
    *
    * It is read reflectively from the model's private `chiSqSelector` field
    * (looked up under its compiler-mangled name), and resolved lazily on first use.
    */
  lazy val parent: OldChiSqSelectorModel = {
    val selectorField = sparkTransformer.getClass.getDeclaredField(
      "org$apache$spark$ml$feature$ChiSqSelectorModel$$chiSqSelector"
    )
    selectorField.setAccessible(true)
    selectorField.get(sparkTransformer).asInstanceOf[OldChiSqSelectorModel]
  }

  /**
    * Runs the selector over the configured features column.
    *
    * @param localData input data
    * @return `localData` with the output column added, or `localData` itself
    *         when the features column is not present
    */
  override def transform(localData: LocalData): LocalData = {
    import io.hydrosphere.spark_ml_serving.common.utils.DataUtils._

    localData.column(sparkTransformer.getFeaturesCol).fold(localData) { features =>
      // Resolve the output name first, then transform every input vector.
      val outputName = sparkTransformer.getOutputCol
      val selected   = features.data.mapToMlLibVectors.map(vec => parent.transform(vec).toList)
      localData.withColumn(LocalDataColumn(outputName, selected))
    }
  }
}

/**
  * Loader and converter for [[ChiSqSelectorModel]]: rebuilds the Spark model from
  * stored metadata and data, and wraps it into a [[LocalChiSqSelectorModel]].
  */
object LocalChiSqSelectorModel
  extends SimpleModelLoader[ChiSqSelectorModel]
  with TypedTransformerConverter[ChiSqSelectorModel] {

  /**
    * Reconstructs a [[ChiSqSelectorModel]] through reflection.
    *
    * The first value of the `selectedFeatures` column of `data` is used as the
    * indices for the underlying mllib model; `featuresCol`, `outputCol`,
    * `labelCol` and `numTopFeatures` are taken from `metadata.paramMap`.
    *
    * @param metadata model metadata providing the uid and the parameter map
    * @param data     model data holding the `selectedFeatures` column
    * @return the rebuilt model with its parameters set
    * @throws NoSuchElementException if the `selectedFeatures` column is missing or empty
    */
  override def build(metadata: Metadata, data: LocalData): ChiSqSelectorModel = {
    // Both constructors are looked up as declared constructors and made accessible,
    // since they are not reachable through a regular `new` from this package.
    val parentConstructor =
      classOf[OldChiSqSelectorModel].getDeclaredConstructor(classOf[Array[Int]])
    parentConstructor.setAccessible(true)

    val selectedFeatures = data
      .column("selectedFeatures")
      .getOrElse(
        throw new NoSuchElementException(
          "Column 'selectedFeatures' is missing from ChiSqSelectorModel data"
        )
      )
      .data
      .headOption
      .getOrElse(
        throw new NoSuchElementException(
          "Column 'selectedFeatures' is empty in ChiSqSelectorModel data"
        )
      )
      .asInstanceOf[Seq[Int]]
    val oldModel = parentConstructor.newInstance(selectedFeatures.toArray)

    val constructor = classOf[ChiSqSelectorModel]
      .getDeclaredConstructor(classOf[String], classOf[OldChiSqSelectorModel])
    constructor.setAccessible(true)

    val model = constructor.newInstance(metadata.uid, oldModel)
    // Every setter returns the same instance, so the calls can be chained
    // without reassigning a mutable variable.
    model
      .setFeaturesCol(metadata.paramMap("featuresCol").asInstanceOf[String])
      .setOutputCol(metadata.paramMap("outputCol").asInstanceOf[String])
      .set(model.labelCol, metadata.paramMap("labelCol").asInstanceOf[String])
      .set(model.numTopFeatures, metadata.paramMap("numTopFeatures").toString.toInt)
  }

  /**
    * Wraps a Spark [[ChiSqSelectorModel]] into its local counterpart.
    *
    * @param transformer the Spark model to wrap
    * @return a [[LocalChiSqSelectorModel]] backed by `transformer`
    */
  override implicit def toLocal(transformer: ChiSqSelectorModel): LocalChiSqSelectorModel =
    new LocalChiSqSelectorModel(transformer)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy