All downloads are free. Search and download functionality uses the official Maven repository.

io.hydrosphere.spark_ml_serving.ModelConversions.scala Maven / Gradle / Ivy

There is a newer version: 0.3.3
Show newest version
package io.hydrosphere.spark_ml_serving

import io.hydrosphere.spark_ml_serving.classification._
import io.hydrosphere.spark_ml_serving.clustering._
import io.hydrosphere.spark_ml_serving.common._
import io.hydrosphere.spark_ml_serving.preprocessors._
import io.hydrosphere.spark_ml_serving.regression._
import org.apache.spark.ml.classification._
import org.apache.spark.ml.clustering._
import org.apache.spark.ml.feature._
import org.apache.spark.ml.regression._
import org.apache.spark.ml.{PipelineModel, Transformer}

object ModelConversions {

  /**
    * Resolves the local (Spark-free) counterpart for a Spark ML model object.
    *
    * `m` is matched against the singleton objects of the supported Spark ML
    * classes (e.g. `PipelineModel`, `KMeansModel`, `Tokenizer`), and the
    * corresponding `Local*` object is returned. A value that already is a
    * [[LocalModel]] is passed through unchanged.
    *
    * NOTE(review): this is an implicit conversion from `Any`, so the compiler may
    * apply it to any value wherever a `LocalModel[T]` is expected. Consider
    * making call sites explicit if the public API allows it.
    *
    * @param m the Spark ML model object (or an already-local model) to convert
    * @tparam T the Spark `Transformer` type the returned `LocalModel` is typed with
    * @return the `LocalModel` that `m` is mapped to below
    * @throws IllegalArgumentException if `m` is `null` or is not a supported model
    */
  implicit def sparkToLocal[T <: Transformer](m: Any): LocalModel[T] = {
    m match {
      // Guard explicitly: the fallback's `m.getClass` would otherwise throw an NPE
      // instead of reporting an unknown model.
      case null => throw new IllegalArgumentException("Unknown model: null")

      case _: PipelineModel.type => LocalPipelineModel

      // Already local: pass through. `T` is erased at runtime, so only the
      // `LocalModel` class itself can be checked here.
      case x: LocalModel[T @unchecked] => x

      // Classification models
      case _: DecisionTreeClassificationModel.type => LocalDecisionTreeClassificationModel
      case _: MultilayerPerceptronClassificationModel.type => LocalMultilayerPerceptronClassificationModel
      case _: NaiveBayesModel.type => LocalNaiveBayes
      case _: RandomForestClassificationModel.type => LocalRandomForestClassificationModel
      case _: LogisticRegressionModel.type => LocalLogisticRegressionModel
      case _: GBTClassificationModel.type => LocalGBTClassificationModel

      // Clustering models
      case _: GaussianMixtureModel.type => LocalGaussianMixtureModel
      case _: KMeansModel.type => LocalKMeansModel

      // Preprocessing
      case _: Binarizer.type => LocalBinarizer
      case _: CountVectorizerModel.type => LocalCountVectorizerModel
      case _: DCT.type => LocalDCT
      case _: HashingTF.type => LocalHashingTF
      case _: IndexToString.type => LocalIndexToString
      case _: MaxAbsScalerModel.type => LocalMaxAbsScalerModel
      case _: MinMaxScalerModel.type => LocalMinMaxScalerModel
      case _: NGram.type => LocalNGram
      case _: Normalizer.type => LocalNormalizer
      case _: OneHotEncoder.type => LocalOneHotEncoder
      case _: PCAModel.type => LocalPCAModel
      case _: PolynomialExpansion.type => LocalPolynomialExpansion
      case _: StandardScalerModel.type => LocalStandardScalerModel
      case _: StopWordsRemover.type => LocalStopWordsRemover
      case _: StringIndexerModel.type => LocalStringIndexerModel
      case _: Tokenizer.type => LocalTokenizer
      case _: VectorIndexerModel.type => LocalVectorIndexerModel
      case _: Word2VecModel.type => LocalWord2VecModel
      case _: IDFModel.type => LocalIDF
      case _: ChiSqSelectorModel.type => LocalChiSqSelectorModel
      case _: RegexTokenizer.type => LocalRegexTokenizer

      // Regression
      case _: DecisionTreeRegressionModel.type => LocalDecisionTreeRegressionModel
      case _: LinearRegressionModel.type => LocalLinearRegressionModel
      case _: RandomForestRegressionModel.type => LocalRandomForestRegressionModel
      case _: GBTRegressionModel.type => LocalGBTRegressor

      // The `KMeans` estimator object also resolves to the local k-means model.
      case _: KMeans.type => LocalKMeansModel

      // IllegalArgumentException extends Exception, so existing handlers still catch it.
      case _ => throw new IllegalArgumentException(s"Unknown model: ${m.getClass}")
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy