All downloads are free. Search and download functionality uses the official Maven repository.

com.microsoft.ml.spark.train.AutoTrainedModel.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.train

import com.microsoft.ml.spark.core.serialize.ConstructorWritable
import org.apache.spark.ml.{Model, PipelineModel, Transformer}
import org.apache.spark.ml.param.ParamMap

/** Defines common inheritance and functions across auto trained models.
  *
  * An auto trained model wraps a fitted [[PipelineModel]]; the last stage of that
  * pipeline is treated as the underlying trained model.
  *
  * @param model The fitted pipeline whose final stage is the underlying model.
  * @tparam TrainedModel The concrete model type, used for the self-referential
  *                      [[Model]] and [[ConstructorWritable]] type bounds.
  */
abstract class AutoTrainedModel[TrainedModel <: Model[TrainedModel]](val model: PipelineModel)
  extends Model[TrainedModel] with ConstructorWritable[TrainedModel] {

  /** Final stage of the wrapped pipeline.
    *
    * Uses `lastOption` rather than `last` so that an empty pipeline produces an
    * explanatory error instead of a generic "empty collection" failure. The exception
    * type (`NoSuchElementException`) is the same one `last` would have thrown.
    *
    * @throws NoSuchElementException if the wrapped pipeline has no stages.
    */
  private def lastStage: Transformer =
    model.stages.lastOption.getOrElse(
      throw new NoSuchElementException(
        s"${getClass.getName} wraps a PipelineModel with no stages; " +
          "cannot retrieve the underlying model"))

  /** Retrieve the param map from the underlying model.
    * @return The param map extracted from the last stage of the wrapped pipeline.
    * @throws NoSuchElementException if the wrapped pipeline has no stages.
    */
  def getParamMap: ParamMap = lastStage.extractParamMap()

  /** Retrieve the underlying model.
    * @return The last stage of the wrapped pipeline.
    * @throws NoSuchElementException if the wrapped pipeline has no stages.
    */
  def getModel: Transformer = lastStage
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy