All downloads are free. Search and download functionality uses the official Maven repository.

io.hydrosphere.spark_ml_serving.regression.LocalLinearRegressionModel.scala Maven / Gradle / Ivy

package io.hydrosphere.spark_ml_serving.regression

import io.hydrosphere.spark_ml_serving.TypedTransformerConverter
import io.hydrosphere.spark_ml_serving.common._
import io.hydrosphere.spark_ml_serving.common.utils.DataUtils
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.LinearRegressionModel

/**
 * Local wrapper around a fitted Spark [[LinearRegressionModel]].
 *
 * This class declares no members of its own; all of its behavior is inherited from
 * [[LocalPredictionModel]].
 *
 * @param sparkTransformer the fitted Spark model being wrapped
 */
class LocalLinearRegressionModel(
  override val sparkTransformer: LinearRegressionModel
) extends LocalPredictionModel[LinearRegressionModel]

object LocalLinearRegressionModel
  extends SimpleModelLoader[LinearRegressionModel]
  with TypedTransformerConverter[LinearRegressionModel] {

  /**
   * Rebuilds a Spark [[LinearRegressionModel]] from saved metadata and model data.
   *
   * The model is created reflectively through its `(String, Vector, Double)` constructor.
   * NOTE(review): presumably reflection is used because that constructor is not accessible
   * from this package; confirm. The params read from `metadata.paramMap` are then set on
   * the new instance.
   *
   * @param metadata model metadata; supplies the model uid and the param values set below
   * @param data     model data; must contain `intercept` and `coefficients` columns
   * @return the reconstructed model with its params set
   * @throws NoSuchElementException if the `intercept` or `coefficients` column is missing or empty
   */
  override def build(metadata: Metadata, data: LocalData): LinearRegressionModel = {
    // Accept any boxed numeric (Double, Float, Integer, ...), not only java.lang.Double.
    val intercept       = firstValue(data, "intercept").asInstanceOf[Number].doubleValue()
    val coefficientsMap = firstValue(data, "coefficients").asInstanceOf[Map[String, Any]]
    val coefficients    = DataUtils.constructVector(coefficientsMap)

    // classOf[Double] is the primitive `double`, matching the constructor's Scala Double parameter.
    val ctor = classOf[LinearRegressionModel].getConstructor(
      classOf[String],
      classOf[Vector],
      classOf[Double]
    )
    // Reflection takes Object arguments; the boxed Double is unboxed back to `double` for the call.
    val inst = ctor.newInstance(metadata.uid, coefficients, java.lang.Double.valueOf(intercept))
    inst
      .set(inst.featuresCol, metadata.paramMap("featuresCol").asInstanceOf[String])
      .set(inst.predictionCol, metadata.paramMap("predictionCol").asInstanceOf[String])
      .set(inst.labelCol, metadata.paramMap("labelCol").asInstanceOf[String])
      .set(inst.elasticNetParam, metadata.paramMap("elasticNetParam").toString.toDouble)
      .set(inst.maxIter, metadata.paramMap("maxIter").asInstanceOf[Number].intValue())
      .set(inst.regParam, metadata.paramMap("regParam").toString.toDouble)
      .set(inst.solver, metadata.paramMap("solver").asInstanceOf[String])
      .set(inst.tol, metadata.paramMap("tol").toString.toDouble)
      .set(inst.standardization, metadata.paramMap("standardization").asInstanceOf[Boolean])
      .set(inst.fitIntercept, metadata.paramMap("fitIntercept").asInstanceOf[Boolean])
  }

  /**
   * Returns the first value stored in column `name` of `data`.
   *
   * @throws NoSuchElementException naming the column if it is absent or holds no values
   */
  private def firstValue(data: LocalData, name: String): Any =
    data
      .column(name)
      .flatMap(_.data.headOption)
      .getOrElse(
        throw new NoSuchElementException(
          s"LinearRegressionModel data has no value in column '$name'"
        )
      )

  /** Implicitly wraps a Spark [[LinearRegressionModel]] in a [[LocalLinearRegressionModel]]. */
  override implicit def toLocal(
    transformer: LinearRegressionModel
  ): LocalLinearRegressionModel = new LocalLinearRegressionModel(transformer)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy