All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.train.AutoTrainer.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.train

import com.microsoft.ml.spark.core.contracts.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.{ComplexParamsWritable, Estimator, Model}
import org.apache.spark.ml.param.{EstimatorParam, IntParam}

/** Defines common inheritance and parameters across trainers.
  */
trait AutoTrainer[TrainedModel <: Model[TrainedModel]] extends Estimator[TrainedModel]
  with HasLabelCol with ComplexParamsWritable with HasFeaturesCol {

  /** Doc for model to run.
    */
  def modelDoc: String

  /** Number of features to hash to
    * @group param
    */
  val numFeatures = new IntParam(this, "numFeatures", "Number of features to hash to")
  setDefault(numFeatures -> 0)

  /** @group getParam */
  def getNumFeatures: Int = $(numFeatures)
  /** @group setParam */
  def setNumFeatures(value: Int): this.type = set(numFeatures, value)

  /** Model to run.  See doc on derived classes.
    * @group param
    */
  val model = new EstimatorParam(this, "model", modelDoc)
  /** @group getParam */
  def getModel: Estimator[_ <: Model[_]] = $(model)
  /** @group setParam */
  def setModel(value: Estimator[_ <: Model[_]]): this.type = set(model, value)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy