All Downloads are FREE. Search and download functionalities are using the official Maven repository.

frameless.ml.internals.TreesInputsChecker.scala Maven / Gradle / Ivy

The newest version!
package frameless
package ml
package internals

import shapeless.ops.hlist.Length
import shapeless.{HList, LabelledGeneric, Nat, Witness}
import org.apache.spark.ml.linalg._

import scala.annotation.implicitNotFound

/**
  * Type-class evidence that `Inputs` is an acceptable input type for the tree-based ML
  * algorithms (decision trees, random forests, gradient-boosted trees).
  *
  * An instance exposes the names of the two fields of `Inputs`: the label field and the
  * features field.
  *
  * @tparam Inputs the record type describing the input columns
  */
@implicitNotFound(
  "Cannot prove that ${Inputs} is a valid input type. " +
    "Input type must only contain a field of type Double (the label) and a field of type " +
    "org.apache.spark.ml.linalg.Vector (the features)."
)
trait TreesInputsChecker[Inputs] {

  /** Name of the `Double` field of `Inputs` that holds the label. */
  val labelCol: String

  /** Name of the `org.apache.spark.ml.linalg.Vector` field of `Inputs` that holds the features. */
  val featuresCol: String
}

object TreesInputsChecker {

  /**
    * Derives a `TreesInputsChecker` for any type `Inputs` whose labelled generic
    * representation has exactly two fields: one whose value type is `Double` (the label) and
    * one whose value type is `Vector` (the features). Here `Vector` resolves to
    * `org.apache.spark.ml.linalg.Vector` via this file's `org.apache.spark.ml.linalg._` import.
    *
    * @tparam Inputs    the input record type (e.g. a case class)
    * @tparam InputsRec labelled HList representation of `Inputs`
    * @tparam LabelK    singleton `Symbol` key of the label field
    * @tparam FeaturesK singleton `Symbol` key of the features field
    * @param i0 labelled generic representation of `Inputs`
    * @param i1 proof that the representation has exactly two fields
    * @param i2 `SelectorByValue` evidence tying the `Double` value type to key `LabelK`
    *           (assumed to locate the field whose value type is `Double`)
    * @param i3 runtime witness of `LabelK`, used to recover the label field name
    * @param i4 `SelectorByValue` evidence tying the `Vector` value type to key `FeaturesK`
    *           (assumed to locate the field whose value type is `Vector`)
    * @param i5 runtime witness of `FeaturesK`, used to recover the features field name
    * @return a checker whose `labelCol` and `featuresCol` are the names of those two fields
    */
  implicit def checkTreesInputs[
      Inputs,
      InputsRec <: HList,
      LabelK <: Symbol,
      FeaturesK <: Symbol
  ](
    implicit
    // i0, i1, i2 and i4 are never read at runtime: they exist only to constrain implicit
    // derivation so that an instance is produced solely for well-shaped inputs.
    i0: LabelledGeneric.Aux[Inputs, InputsRec],
    i1: Length.Aux[InputsRec, Nat._2],
    i2: SelectorByValue.Aux[InputsRec, Double, LabelK],
    i3: Witness.Aux[LabelK],
    i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
    i5: Witness.Aux[FeaturesK]
  ): TreesInputsChecker[Inputs] =
    new TreesInputsChecker[Inputs] {
      // The witnesses are already bound as i3/i5; read them directly rather than running a
      // second implicit search with `implicitly`.
      val labelCol: String = i3.value.name
      val featuresCol: String = i5.value.name
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy