/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.tencent.angel.sona.ml.feature
import java.{util => ju}
import org.apache.spark.SparkException
import com.tencent.angel.sona.ml.Model
import com.tencent.angel.sona.ml.attribute.NominalAttribute
import com.tencent.angel.sona.ml.param._
import com.tencent.angel.sona.ml.param.shared._
import com.tencent.angel.sona.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.sql.util.SONASchemaUtils
import org.apache.spark.util.DatasetUtil
/**
* `Bucketizer` maps a column of continuous features to a column of feature buckets.
*
* Since 2.3.0,
* `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
* when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
* `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
* columns.
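 *
 * An illustrative usage sketch (the DataFrame `df` and column names here are hypothetical):
 * {{{
 *   val bucketizer = new Bucketizer()
 *     .setInputCol("rawFeature")
 *     .setOutputCol("bucketedFeature")
 *     .setSplits(Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity))
 *   val bucketed = bucketizer.transform(df)
 * }}}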
*/
final class Bucketizer(override val uid: String)
extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
with HasInputCols with HasOutputCols with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("bucketizer"))
/**
* Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
* A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which
* also includes y. Splits should be of length greater than or equal to 3 and strictly increasing.
* Values at -inf, inf must be explicitly provided to cover all Double values;
* otherwise, values outside the splits specified will be treated as errors.
*
* See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
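 * For example, splits of `Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity)`
 * define three buckets: [-inf, 0.0), [0.0, 10.0) and [10.0, +inf].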
*
* @group param
*/
val splits: DoubleArrayParam = new DoubleArrayParam(this, "splits",
"Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
"buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " +
"bucket, which also includes y. The splits should be of length >= 3 and strictly " +
"increasing. Values at -inf, inf must be explicitly provided to cover all Double values; " +
"otherwise, values outside the splits specified will be treated as errors.",
Bucketizer.checkSplits)
/** @group getParam */
def getSplits: Array[Double] = $(splits)
/** @group setParam */
def setSplits(value: Array[Double]): this.type = set(splits, value)
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
/**
* Param for how to handle invalid entries. Options are 'skip' (filter out rows with
* invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special
 * additional bucket). Note that in the multiple-column case, the invalid handling is applied
 * to all columns: for 'error' an error is thrown if an invalid value is found in any column,
 * and for 'skip' a row is dropped if any of its input columns contains an invalid value.
* Default: "error"
*
* @group param
*/
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"how to handle invalid entries. Options are skip (filter out rows with invalid values), " +
"error (throw an error), or keep (keep invalid values in a special additional bucket).",
ParamValidators.inArray(Bucketizer.supportedHandleInvalids))
/** @group setParam */
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
/**
 * Parameter for specifying multiple sets of split points, one per input column. Each element in
 * this array is used to map the corresponding input column's continuous features into buckets.
*
* @group param
*/
val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray",
"The array of split points for mapping continuous features into buckets for multiple " +
"columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by " +
"splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " +
"The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
"explicitly provided to cover all Double values; otherwise, values outside the splits " +
"specified will be treated as errors.",
Bucketizer.checkSplitsArray)
/** @group getParam */
def getSplitsArray: Array[Array[Double]] = $(splitsArray)
/** @group setParam */
def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)
/** @group setParam */
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
/** @group setParam */
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema)
val (inputColumns, outputColumns) = if (isSet(inputCols)) {
($(inputCols).toSeq, $(outputCols).toSeq)
} else {
(Seq($(inputCol)), Seq($(outputCol)))
}
val (filteredDataset, keepInvalid) = {
if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
// "skip" NaN option is set, will filter out NaN values in the dataset
(dataset.na.drop(inputColumns).toDF(), false)
} else {
(dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
}
}
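    // One set of split points per column: splitsArray for multi-column usage, splits otherwise.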
val seqOfSplits = if (isSet(inputCols)) {
$(splitsArray).toSeq
} else {
Seq($(splits))
}
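    // One UDF per column, each mapping a raw value to its bucket index via binary search.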
    val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.map { splitPoints =>
      udf { (feature: Double) =>
        Bucketizer.binarySearchForBuckets(splitPoints, feature, keepInvalid)
      }
}
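    // Cast each input column to Double and apply the corresponding bucketizer UDF.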
val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
}
val metadata = outputColumns.map { col =>
transformedSchema(col).metadata
}
    // Append each bucketized output column together with its ML metadata. Columns are added
    // one at a time through DatasetUtil.withColumn, mirroring the effect of
    // filteredDataset.withColumns(outputColumns, newCols, metadata).
    var finalDataset = filteredDataset
    newCols.indices.foreach { index =>
      finalDataset = DatasetUtil.withColumn(finalDataset, outputColumns(index), newCols(index), metadata(index))
    }
    finalDataset
}
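  /**
   * Builds the StructField for an output column, attaching a NominalAttribute whose values
   * describe the bucket ranges (e.g. "0.0, 10.0") derived from consecutive split points.
   */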
private def prepOutputField(splits: Array[Double], outputCol: String): StructField = {
val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray
val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true),
values = Some(buckets))
attr.toStructField()
}
override def transformSchema(schema: StructType): StructType = {
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
Seq(outputCols, splitsArray))
if (isSet(inputCols)) {
require(getInputCols.length == getOutputCols.length &&
getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " +
s"equal lengths, but they have different lengths: " +
s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")
var transformedSchema = schema
$(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
SONASchemaUtils.checkNumericType(transformedSchema, inputCol)
transformedSchema = SONASchemaUtils.appendColumn(transformedSchema,
prepOutputField($(splitsArray)(idx), outputCol))
}
transformedSchema
} else {
SONASchemaUtils.checkNumericType(schema, $(inputCol))
SONASchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol)))
}
}
override def copy(extra: ParamMap): Bucketizer = {
defaultCopy[Bucketizer](extra).setParent(parent)
}
}
object Bucketizer extends DefaultParamsReadable[Bucketizer] {
private[sona] val SKIP_INVALID: String = "skip"
private[sona] val ERROR_INVALID: String = "error"
private[sona] val KEEP_INVALID: String = "keep"
private[sona] val supportedHandleInvalids: Array[String] =
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
/**
* We require splits to be of length >= 3 and to be in strictly increasing order.
* No NaN split should be accepted.
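 * For example, `Array(0.0, 1.0, 2.0)` is accepted, while `Array(0.0, 1.0)` (too short) and
 * `Array(0.0, 0.0, 1.0)` (not strictly increasing) are rejected.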
*/
private[sona] def checkSplits(splits: Array[Double]): Boolean = {
if (splits.length < 3) {
false
} else {
var i = 0
val n = splits.length - 1
while (i < n) {
if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false
i += 1
}
!splits(n).isNaN
}
}
/**
 * Check each set of splits in the splits array.
*/
private[sona] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = {
splitsArray.forall(checkSplits(_))
}
/**
 * Binary search over the split points to place a data point into its bucket.
*
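 * For example, with splits `Array(-1.0, 0.0, 1.0)` there are two buckets: a feature of -0.5
 * maps to bucket 0 (the range [-1.0, 0.0)), and a feature of 1.0 (the last split) maps to
 * bucket 1.
 *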
* @param splits array of split points
* @param feature data point
* @param keepInvalid NaN flag.
* Set "true" to make an extra bucket for NaN values;
* Set "false" to report an error for NaN values
* @return bucket for each data point
* @throws SparkException if a feature is < splits.head or > splits.last
*/
private[sona] def binarySearchForBuckets(
splits: Array[Double],
feature: Double,
keepInvalid: Boolean): Double = {
if (feature.isNaN) {
if (keepInvalid) {
splits.length - 1
} else {
throw new SparkException("Bucketizer encountered NaN value. To handle or skip NaNs," +
" try setting Bucketizer.handleInvalid.")
}
} else if (feature == splits.last) {
splits.length - 2
} else {
val idx = ju.Arrays.binarySearch(splits, feature)
if (idx >= 0) {
idx
} else {
val insertPos = -idx - 1
if (insertPos == 0 || insertPos == splits.length) {
throw new SparkException(s"Feature value $feature out of Bucketizer bounds" +
s" [${splits.head}, ${splits.last}]. Check your features, or loosen " +
s"the lower/upper bound constraints.")
} else {
insertPos - 1
}
}
}
}
override def load(path: String): Bucketizer = super.load(path)
}