ai.h2o.sparkling.ml.features.H2OTargetEncoder.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.h2o.sparkling.ml.features
import ai.h2o.sparkling.H2OContext
import ai.h2o.sparkling.backend.utils.RestCommunication
import ai.h2o.sparkling.ml.internals.H2OModel
import ai.h2o.sparkling.ml.models.{H2OTargetEncoderBase, H2OTargetEncoderModel}
import ai.h2o.sparkling.ml.params.{EnumParamValidator, H2OTargetEncoderProblemType}
import ai.h2o.sparkling.ml.utils.EstimatorCommonUtils
import ai.h2o.targetencoding._
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{BooleanType, DoubleType, StringType}
/**
  * Spark ML `Estimator` that trains an H2O target encoder via the
  * `/3/ModelBuilders/targetencoder` REST endpoint and wraps the result in an
  * [[H2OTargetEncoderModel]].
  *
  * @param uid unique identifier of this pipeline stage
  */
class H2OTargetEncoder(override val uid: String)
  extends Estimator[H2OTargetEncoderModel]
  with H2OTargetEncoderBase
  with DefaultParamsWritable
  with RestCommunication
  with EstimatorCommonUtils {

  /** Creates the estimator with a randomly generated uid prefixed by `H2OTargetEncoder`. */
  def this() = this(Identifiable.randomUID("H2OTargetEncoder"))

  /**
    * Trains the target encoder on the given dataset.
    *
    * When no input columns are configured, nothing is trained and the returned model wraps `None`.
    * Otherwise the dataset is uploaded as an H2O frame, the input columns (and the label column when it
    * is treated as categorical) are converted to categorical, and the target encoder is trained.
    * The temporary H2O frame is deleted afterwards, even if the conversion or the training fails.
    *
    * @param dataset training data containing the input columns and the label column
    * @return a model whose parent is this estimator and which carries a copy of this estimator's params
    */
  override def fit(dataset: Dataset[_]): H2OTargetEncoderModel = {
    val h2oModel = if (getInputCols().isEmpty) {
      None
    } else {
      val h2oContext = H2OContext.ensure(
        "H2OContext needs to be created in order to use target encoding. Please create one as H2OContext.getOrCreate().")
      val problemType = H2OTargetEncoderProblemType.valueOf(getProblemType())
      // For regression the label column is cast to double before the data is handed over to H2O.
      val toConvertDF = if (problemType == H2OTargetEncoderProblemType.Regression) {
        dataset.withColumn(getLabelCol(), col(getLabelCol()).cast(DoubleType))
      } else {
        dataset.toDF()
      }
      val input = h2oContext.asH2OFrame(toConvertDF)
      val targetEncoderModelId = try {
        val distinctInputCols = getInputCols().flatten.distinct
        val toCategorical = if (isLabelColCategorical(problemType, dataset)) {
          distinctInputCols :+ getLabelCol()
        } else {
          distinctInputCols
        }
        val converted = input.convertColumnsToCategorical(toCategorical)
        // NOTE(review): the `noise` param is not among the training parameters sent here;
        // presumably it is applied elsewhere - confirm.
        val params = Map(
          "data_leakage_handling" -> getHoldoutStrategy(),
          "blending" -> getBlendedAvgEnabled(),
          "inflection_point" -> getBlendedAvgInflectionPoint(),
          "smoothing" -> getBlendedAvgSmoothing(),
          "response_column" -> getLabelCol(),
          "fold_column" -> getFoldCol(),
          "columns_to_encode" -> getInputCols(),
          "seed" -> getNoiseSeed(),
          "training_frame" -> converted.frameId)
        trainAndGetDestinationKey("/3/ModelBuilders/targetencoder", params)
      } finally {
        // Release the temporary frame regardless of the training outcome so that a failed
        // training does not leave the uploaded data behind.
        input.delete()
      }
      Some(H2OModel(targetEncoderModelId))
    }
    val model = new H2OTargetEncoderModel(uid, h2oModel).setParent(this)
    copyValues(model)
  }

  /**
    * Decides whether the label column should be converted to categorical.
    *
    * That is the case for the `Classification` problem type, and for `Auto` when the label column
    * of the dataset has a string or boolean data type.
    *
    * @param problemType configured problem type
    * @param dataset     dataset whose schema is inspected when the problem type is `Auto`
    */
  private def isLabelColCategorical(problemType: H2OTargetEncoderProblemType, dataset: Dataset[_]): Boolean = {
    problemType == H2OTargetEncoderProblemType.Classification ||
    (problemType == H2OTargetEncoderProblemType.Auto && {
      dataset.select(col(getLabelCol())).schema.fields.head.dataType match {
        case _: StringType | _: BooleanType => true
        case _ => false
      }
    })
  }

  /** Creates a copy of this estimator with the extra params applied (delegates to `defaultCopy`). */
  override def copy(extra: ParamMap): H2OTargetEncoder = defaultCopy(extra)

  //
  // Parameter Setters
  //

  /** Sets the name of the fold column. */
  def setFoldCol(value: String): this.type = set(foldCol, value)

  /** Sets the name of the label (response) column. */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /** Sets the input columns; each column is wrapped into its own single-element group. */
  def setInputCols(values: Array[String]): this.type = set(inputCols, values.map(v => Array(v)))

  /** Sets the input columns as explicit groups of columns. */
  def setInputCols(groups: Array[Array[String]]): this.type = set(inputCols, groups)

  /** Sets the names of the output columns. */
  def setOutputCols(values: Array[String]): this.type = set(outputCols, values)

  /**
    * Sets the holdout strategy. The value is passed through `EnumParamValidator` against
    * `TargetEncoderModel.DataLeakageHandlingStrategy` before it is stored.
    */
  def setHoldoutStrategy(value: String): this.type = {
    set(
      holdoutStrategy,
      EnumParamValidator.getValidatedEnumValue[TargetEncoderModel.DataLeakageHandlingStrategy](value))
  }

  /** Enables or disables the blended average. */
  def setBlendedAvgEnabled(value: Boolean): this.type = set(blendedAvgEnabled, value)

  /** Sets the inflection point of the blended average. */
  def setBlendedAvgInflectionPoint(value: Double): this.type = set(blendedAvgInflectionPoint, value)

  /**
    * Sets the smoothing of the blended average.
    *
    * @throws IllegalArgumentException if the value is not strictly positive
    */
  def setBlendedAvgSmoothing(value: Double): this.type = {
    require(value > 0.0, "The smoothing value has to be a positive number.")
    set(blendedAvgSmoothing, value)
  }

  /**
    * Sets the amount of noise.
    *
    * @throws IllegalArgumentException if the value is negative
    */
  def setNoise(value: Double): this.type = {
    require(value >= 0.0, "Noise can't be a negative value.")
    set(noise, value)
  }

  /** Sets the noise seed; it is sent to the backend as the `seed` training parameter. */
  def setNoiseSeed(value: Long): this.type = set(noiseSeed, value)

  /**
    * Sets the problem type. The value is passed through `EnumParamValidator` against
    * `H2OTargetEncoderProblemType` before it is stored.
    */
  def setProblemType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[H2OTargetEncoderProblemType](value)
    set(problemType, validated)
  }
}
/** Companion object mixing in Spark ML's `DefaultParamsReadable` for [[H2OTargetEncoder]]. */
object H2OTargetEncoder extends DefaultParamsReadable[H2OTargetEncoder]
© 2015 - 2024 Weber Informatics LLC | Privacy Policy