All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.hydrosphere.mist.ml.loaders.LocalLogisticRegressionModel.scala Maven / Gradle / Ivy

package io.hydrosphere.mist.ml.loaders

import java.lang.Boolean

import io.hydrosphere.mist.ml.Metadata
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.linalg.{Matrix, SparseMatrix, Vector, Vectors}

object LocalLogisticRegressionModel extends LocalModel {
  override def localLoad(metadata: Metadata, data: Map[String, Any]): Transformer = {
    if (data.contains("coefficients")) {
      // Spark 2.0
      val constructor = classOf[LogisticRegressionModel].getDeclaredConstructor(classOf[String], classOf[Vector], classOf[Double])
      constructor.setAccessible(true)
      val coefficientsParams = data("coefficients").asInstanceOf[Map[String, Any]]
      val coefficients = Vectors.sparse(
        coefficientsParams("size").asInstanceOf[Int],
        coefficientsParams("indices").asInstanceOf[List[Int]].toArray[Int],
        coefficientsParams("values").asInstanceOf[List[Double]].toArray[Double]
      )
      constructor
        .newInstance(metadata.uid, coefficients, data("intercept").asInstanceOf[java.lang.Double])
        .setFeaturesCol(metadata.paramMap("featuresCol").asInstanceOf[String])
        .setPredictionCol(metadata.paramMap("predictionCol").asInstanceOf[String])
        .setProbabilityCol(metadata.paramMap("probabilityCol").asInstanceOf[String])
        .setThreshold(metadata.paramMap("threshold").asInstanceOf[Double])
    } else if (data.contains("coefficientMatrix")) {
      // Spark 2.1
      val constructor = classOf[LogisticRegressionModel].getDeclaredConstructor(classOf[String], classOf[Matrix], classOf[Vector], classOf[Int], java.lang.Boolean.TYPE)
      constructor.setAccessible(true)
      val coefficientMatrixParams = data("coefficientMatrix").asInstanceOf[Map[String, Any]]
      val coefficientMatrix = new SparseMatrix(
        coefficientMatrixParams("numRows").asInstanceOf[Int],
        coefficientMatrixParams("numCols").asInstanceOf[Int],
        coefficientMatrixParams("colPtrs").asInstanceOf[List[Int]].toArray[Int],
        coefficientMatrixParams("rowIndices").asInstanceOf[List[Int]].toArray[Int],
        coefficientMatrixParams("values").asInstanceOf[List[Double]].toArray[Double],
        coefficientMatrixParams("isTransposed").asInstanceOf[Boolean]
      )
      val interceptVectorParams = data("interceptVector").asInstanceOf[Map[String, Any]]
      val interceptVector = Vectors.dense(interceptVectorParams("values").asInstanceOf[List[Double]].toArray[Double])
      constructor
        .newInstance(metadata.uid, coefficientMatrix, interceptVector, new Integer(data("numFeatures").asInstanceOf[Int]), new Boolean(data("isMultinomial").asInstanceOf[Boolean]))
        .setFeaturesCol(metadata.paramMap("featuresCol").asInstanceOf[String])
        .setPredictionCol(metadata.paramMap("predictionCol").asInstanceOf[String])
        .setProbabilityCol(metadata.paramMap("probabilityCol").asInstanceOf[String])
        .setThreshold(metadata.paramMap("threshold").asInstanceOf[Double])
    } else {
      throw new Exception("Unknown LogisticRegressionModel implementation")
    }
  } 
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy