com.microsoft.ml.spark.automl.ParamSpace.scala Maven / Gradle / Ivy
The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.
package com.microsoft.ml.spark.automl
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, ParamSpace}
/** Represents a distribution of values.
* @tparam T The type T of the values generated.
*/
abstract class Dist[T] {
def getNext(): T
def getParamPair(param: Param[_]): ParamPair[_] = {
ParamPair(param.asInstanceOf[Param[T]], getNext())
}
}
/** Represents a parameter grid for tuning with discrete values.
* Can be generated with the ParamGridBuilder.
* @param paramValues The parameter values generated by ParamGridBuilder.
*/
class GridSpace(val paramValues: Array[ParamMap]) extends ParamSpace {
override def paramMaps: Iterator[ParamMap] = paramValues.toIterator
}
/** Represents a generator of parameters with specified distributions added by the HyperparamBuilder.
* @param paramDistributions A list of parameters and their distributions generated by HyperparamBuilder.
*/
class RandomSpace(paramDistributions: Array[(Param[_], Dist[_])]) extends ParamSpace {
val paramMaps = new Iterator[ParamMap] {
override def hasNext: Boolean = true
override def next(): ParamMap =
ParamMap(paramDistributions.map(paramDist => paramDist._2.getParamPair(paramDist._1)): _*)
}
}