All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.hydrosphere.spark_ml_serving.preprocessors.LocalVectorIndexerModel.scala Maven / Gradle / Ivy

package io.hydrosphere.spark_ml_serving.preprocessors

import java.lang.reflect.InvocationTargetException

import io.hydrosphere.spark_ml_serving.TypedTransformerConverter
import io.hydrosphere.spark_ml_serving.common.utils.DataUtils._
import io.hydrosphere.spark_ml_serving.common._
import io.hydrosphere.spark_ml_serving.common.utils.DataUtils
import org.apache.spark.ml.feature.VectorIndexerModel
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}

/** Local (Spark-context-free) counterpart of Spark's [[VectorIndexerModel]].
  *
  * For every vector in the input column, each feature index listed in
  * `sparkTransformer.categoryMaps` is replaced by the category index its value maps to;
  * all other features are copied through unchanged.
  *
  * @param sparkTransformer fitted Spark model supplying `numFeatures`, `categoryMaps`
  *                         and the input/output column names
  */
class LocalVectorIndexerModel(override val sparkTransformer: VectorIndexerModel)
  extends LocalTransformer[VectorIndexerModel] {

  /** Indexes the vectors of the input column and adds them as the output column.
    *
    * @param localData data expected to contain `sparkTransformer.getInputCol`
    * @return `localData` with the output column added, or `localData` unchanged when the
    *         input column is absent
    * @throws AssertionError if a vector's size differs from `numFeatures`
    *                        (when assertions are enabled)
    * @throws NoSuchElementException if a categorical value has no entry in its category map
    *                                (raised by `Map.apply`)
    */
  override def transform(localData: LocalData): LocalData = {
    val transformFunc: Vector => Vector = {
      // Computed once per transform call instead of once per vector.
      val sortedCatFeatureIndices = sparkTransformer.categoryMaps.keys.toArray.sorted
      val localVectorMap          = sparkTransformer.categoryMaps
      val localNumFeatures        = sparkTransformer.numFeatures
      val f: Vector => Vector = { (v: Vector) =>
        assert(
          v.size == localNumFeatures,
          "VectorIndexerModel expected vector of length" +
            s" $localNumFeatures but found length ${v.size}"
        )
        v match {
          case dv: DenseVector =>
            // Work on a copy so the caller's vector is never mutated.
            val tmpv = dv.copy
            localVectorMap.foreach {
              case (featureIndex, categoryMap) =>
                tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
            }
            tmpv
          case sv: SparseVector =>
            // We use the fact that categorical value 0 is always mapped to index 0, so only
            // the explicitly stored entries need remapping. The loop below is a two-pointer
            // merge of two sorted index sequences (assumes sparse indices are ascending).
            val tmpv          = sv.copy
            var catFeatureIdx = 0 // index into sortedCatFeatureIndices
            var k             = 0 // index into non-zero elements of sparse vector
            while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) {
              val featureIndex = sortedCatFeatureIndices(catFeatureIdx)
              if (featureIndex < tmpv.indices(k)) {
                catFeatureIdx += 1
              } else if (featureIndex > tmpv.indices(k)) {
                k += 1
              } else {
                tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
                catFeatureIdx += 1
                k += 1
              }
            }
            tmpv
        }
      }
      f
    }
    localData.column(sparkTransformer.getInputCol) match {
      case Some(column) =>
        val newColumn = LocalDataColumn(
          sparkTransformer.getOutputCol,
          column.data.mapToMlVectors.map(transformFunc(_).toList)
        )
        localData.withColumn(newColumn)

      case None => localData
    }
  }
}

/** Loader and Spark-to-local converter for [[LocalVectorIndexerModel]]. */
object LocalVectorIndexerModel
  extends SimpleModelLoader[VectorIndexerModel]
  with TypedTransformerConverter[VectorIndexerModel] {

  /** Rebuilds a Spark [[VectorIndexerModel]] from saved metadata and model data.
    *
    * The model's constructor is not publicly accessible, so it is looked up and invoked
    * via reflection.
    *
    * @param metadata saved metadata; supplies the uid and the `inputCol`/`outputCol` params
    * @param data     saved model data; must contain `categoryMaps` and `numFeatures` columns
    * @return the reconstructed model with its input and output columns set
    * @throws NoSuchElementException if a required data column/value or param is missing
    */
  override def build(metadata: Metadata, data: LocalData): VectorIndexerModel = {
    val ctor = classOf[VectorIndexerModel].getDeclaredConstructor(
      classOf[String],
      classOf[Int],
      classOf[Map[Int, Map[Double, Int]]]
    )
    ctor.setAccessible(true)
    val categoryMaps = DataUtils.kludgeForVectorIndexer(
      firstColumnValue(data, "categoryMaps").asInstanceOf[Map[String, Any]]
    )
    // Reflection unboxes the Integer into the constructor's primitive `int` parameter.
    val numFeatures = firstColumnValue(data, "numFeatures").asInstanceOf[java.lang.Integer]
    try {
      ctor
        .newInstance(metadata.uid, numFeatures, categoryMaps)
        .setInputCol(metadata.paramMap("inputCol").asInstanceOf[String])
        .setOutputCol(metadata.paramMap("outputCol").asInstanceOf[String])
    } catch {
      // Surface the constructor's own failure rather than the reflective wrapper.
      case e: InvocationTargetException => throw e.getTargetException
    }
  }

  /** Returns the first value stored in column `columnName` of `data`.
    *
    * @throws NoSuchElementException if the column is missing or holds no values
    */
  private def firstColumnValue(data: LocalData, columnName: String): Any = {
    val value: Option[Any] = data.column(columnName).flatMap(_.data.headOption)
    value.getOrElse(
      throw new NoSuchElementException(
        s"VectorIndexerModel data has no value in column '$columnName'"
      )
    )
  }

  /** Wraps a fitted Spark [[VectorIndexerModel]] in its local counterpart. */
  override implicit def toLocal(
    transformer: VectorIndexerModel
  ): LocalVectorIndexerModel = new LocalVectorIndexerModel(transformer)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy