All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.ml.feature.QuantileDiscretizer.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml.feature

import com.tencent.angel.sona.ml.Estimator
import com.tencent.angel.sona.ml.attribute.NominalAttribute
import com.tencent.angel.sona.ml.param.{DoubleParam, IntArrayParam, IntParam, Param, ParamMap, ParamValidators, Params}
import com.tencent.angel.sona.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import com.tencent.angel.sona.ml.util._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.SONASchemaUtils


/**
  * Params for [[QuantileDiscretizer]].
  */
private[sona] trait QuantileDiscretizerBase extends Params
  with HasHandleInvalid with HasInputCol with HasOutputCol {

  /**
    * How many buckets (quantiles, or categories) the data points are grouped into.
    * The value is validated to be at least 2.
    *
    * [[handleInvalid]] can optionally add one more bucket that collects NaN values.
    *
    * default: 2
    *
    * @group param
    */
  val numBuckets = new IntParam(
    this,
    "numBuckets",
    "Number of buckets (quantiles, or categories) into which data points are grouped. " +
      "Must be >= 2.",
    ParamValidators.gtEq(2))
  setDefault(numBuckets, 2)

  /** @group getParam */
  def getNumBuckets: Int = getOrDefault(numBuckets)

  /**
    * Per-column bucket counts for the multiple-columns case. Every entry is validated to be
    * at least 2. No default is registered for this param.
    *
    * [[handleInvalid]] can optionally add one more bucket that collects NaN values.
    *
    * @group param
    */
  val numBucketsArray = new IntArrayParam(
    this,
    "numBucketsArray",
    "Array of number of buckets (quantiles, or categories) into which data points are " +
      "grouped. This is for multiple columns input. If transforming multiple columns and " +
      "numBucketsArray is not set, but numBuckets is set, then numBuckets will be applied " +
      "across all columns.",
    (bucketCounts: Array[Int]) => bucketCounts.forall(ParamValidators.gtEq(2)))

  /** @group getParam */
  def getNumBucketsArray: Array[Int] = $(numBucketsArray)

  /**
    * Relative target precision handed to the approximate quantile algorithm (see the
    * documentation of `org.apache.spark.sql.DataFrameStatFunctions.approxQuantile`).
    * Validated to lie in the range [0, 1]. In the multiple-columns case the same value is
    * used for every column.
    *
    * default: 0.001
    *
    * @group param
    */
  val relativeError = new DoubleParam(
    this,
    "relativeError",
    "The relative target precision for the approximate quantile algorithm used to " +
      "generate buckets. Must be in the range [0, 1].",
    ParamValidators.inRange(0.0, 1.0))
  setDefault(relativeError, 0.001)

  /** @group getParam */
  def getRelativeError: Double = getOrDefault(relativeError)

  /**
    * Policy for invalid entries. Accepted values are the ones listed in
    * `Bucketizer.supportedHandleInvalids`: 'skip' (filter out rows with invalid values),
    * 'error' (throw an error), or 'keep' (keep invalid values in a special additional
    * bucket). In the multiple-columns case the policy covers all columns: 'error' fails if
    * any column holds an invalid value, 'skip' drops rows with an invalid value in any
    * column, and so on.
    *
    * Default: `Bucketizer.ERROR_INVALID`
    *
    * @group param
    */
  override val handleInvalid: Param[String] = new Param[String](
    this,
    "handleInvalid",
    "how to handle invalid entries. Options are skip (filter out rows with invalid " +
      "values), error (throw an error), or keep (keep invalid values in a special " +
      "additional bucket).",
    ParamValidators.inArray(Bucketizer.supportedHandleInvalids))
  setDefault(handleInvalid -> Bucketizer.ERROR_INVALID)

}

/**
  * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
  * categorical features. The number of bins can be set using the `numBuckets` parameter. It is
  * possible that the number of buckets used will be smaller than this value, for example, if there
  * are too few distinct values of the input to create enough distinct quantiles.
  * Since 2.3.0, `QuantileDiscretizer` can map multiple columns at once by setting the `inputCols`
  * parameter. If both of the `inputCol` and `inputCols` parameters are set, an Exception will be
  * thrown. To specify the number of buckets for each column, the `numBucketsArray` parameter can
  * be set, or if the number of buckets should be the same across columns, `numBuckets` can be
  * set as a convenience.
  *
  * NaN handling:
  * null and NaN values in the column are ignored during `QuantileDiscretizer` fitting. This
  * will produce a `Bucketizer` model for making predictions. During the transformation,
  * `Bucketizer` will raise an error when it finds NaN values in the dataset, but the user can
  * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`.
  * If the user chooses to keep NaN values, they will be handled specially and placed into their own
  * bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3],
  * but NaNs will be counted in a special bucket[4].
  *
  * Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for
  * `org.apache.spark.sql.DataFrameStatFunctions.approxQuantile`
  * for a detailed description). The precision of the approximation can be controlled with the
  * `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`,
  * covering all real values.
  */

final class QuantileDiscretizer(override val uid: String)
  extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable
    with HasInputCols with HasOutputCols {

  def this() = this(Identifiable.randomUID("quantileDiscretizer"))

  /** @group setParam */
  def setRelativeError(value: Double): this.type = set(relativeError, value)

  /** @group setParam */
  def setNumBuckets(value: Int): this.type = set(numBuckets, value)

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

  /** @group setParam */
  def setNumBucketsArray(value: Array[Int]): this.type = set(numBucketsArray, value)

  /** @group setParam */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

  /**
    * Resolves the configured input and output column names.
    *
    * Exactly one of the two configurations must be set: `inputCol`/`outputCol` (single
    * column) or `inputCols`/`outputCols` (multiple columns).
    *
    * @return the input column names and the output column names, position-aligned
    * @throws IllegalArgumentException if the single- and multi-column params are mixed or
    *                                  incomplete, or if `inputCols` and `outputCols`
    *                                  differ in length
    */
  private[sona] def getInOutCols: (Array[String], Array[String]) = {
    val singleColumnMode =
      isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)
    val multiColumnMode =
      !isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)

    require(singleColumnMode || multiColumnMode,
      "QuantileDiscretizer only supports setting either inputCol/outputCol or " +
        "inputCols/outputCols."
    )

    if (singleColumnMode) {
      (Array($(inputCol)), Array($(outputCol)))
    } else {
      require($(inputCols).length == $(outputCols).length,
        "inputCols number do not match outputCols")
      ($(inputCols), $(outputCols))
    }
  }

  /**
    * Validates the input columns and appends one nominal output field per output column.
    *
    * @param schema schema of the dataset to be fitted/transformed
    * @return `schema` followed by the new output fields, in `outputCols` order
    * @throws IllegalArgumentException if an output column name already exists in `schema`
    *                                  (input columns are checked by
    *                                  `SONASchemaUtils.checkNumericType`)
    */
  override def transformSchema(schema: StructType): StructType = {
    val (inputColNames, outputColNames) = getInOutCols
    val existingFields = schema.fields
    val appendedFields = inputColNames.zip(outputColNames).map {
      case (inputColName, outputColName) =>
        SONASchemaUtils.checkNumericType(schema, inputColName)
        require(existingFields.forall(_.name != outputColName),
          s"Output column $outputColName already exists.")
        NominalAttribute.defaultAttr.withName(outputColName).toStructField()
    }
    StructType(existingFields ++ appendedFields)
  }

  /**
    * Computes approximate quantiles for the configured column(s) and returns a
    * [[Bucketizer]] whose splits are those quantiles.
    *
    * @param dataset dataset holding the input column(s)
    * @return a `Bucketizer` with this estimator as parent and this estimator's param
    *         values copied onto it
    * @throws IllegalArgumentException if `numBucketsArray` is set and its length differs
    *                                  from `inputCols`, or if no quantiles could be
    *                                  computed for a column
    */
  override def fit(dataset: Dataset[_]): Bucketizer = {
    transformSchema(dataset.schema, logging = true)
    val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
    if (isSet(inputCols)) {
      val splitsArray = if (isSet(numBucketsArray)) {
        // Without this check the zip below would silently drop the surplus entries.
        require($(numBucketsArray).length == $(inputCols).length,
          s"numBucketsArray length (${$(numBucketsArray).length}) must match " +
            s"inputCols length (${$(inputCols).length}).")

        val probArrayPerCol = $(numBucketsArray).map(n => quantileProbabilities(n))

        // Every column is queried with the union of all requested probabilities; each
        // column then keeps only the quantiles that belong to its own probability grid.
        val probabilityArray = probArrayPerCol.flatten.sorted.distinct
        // NOTE(review): the original source carried a commented-out single call to the
        // multi-column `approxQuantile($(inputCols), ...)` overload; confirm whether it
        // can be restored instead of querying column by column.
        val splitsArrayRaw = $(inputCols).map { colName =>
          dataset.stat.approxQuantile(colName, probabilityArray, $(relativeError))
        }

        splitsArrayRaw.zip(probArrayPerCol).map { case (splits, probs) =>
          val wanted = probs.toSet
          probabilityArray.zip(splits).collect { case (p, s) if wanted(p) => s }
        }
      } else {
        val probabilities = quantileProbabilities($(numBuckets))
        $(inputCols).map { colName =>
          dataset.stat.approxQuantile(colName, probabilities, $(relativeError))
        }
      }
      bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits))
    } else {
      val splits = dataset.stat.approxQuantile($(inputCol),
        quantileProbabilities($(numBuckets)), $(relativeError))
      bucketizer.setSplits(getDistinctSplits(splits))
    }
    copyValues(bucketizer.setParent(this))
  }

  /**
    * Evenly spaced probabilities `0, 1/n, ..., 1` for `n` buckets.
    *
    * Built from integer indices rather than a `Double` range with a fractional step:
    * rounding of the step can make such a range end slightly below 1.0 or come out one
    * element short, which would silently drop a bucket.
    */
  private def quantileProbabilities(numOfBuckets: Int): Array[Double] =
    Array.tabulate(numOfBuckets + 1)(i => i.toDouble / numOfBuckets)

  /**
    * Turns raw quantiles into bucket boundaries: the outermost values are replaced by
    * `-Infinity` / `+Infinity`, duplicates are removed and the result is sorted.
    *
    * The input array is not modified.
    *
    * @param splits quantiles as returned by `approxQuantile`
    * @return sorted, de-duplicated splits
    * @throws IllegalArgumentException if `splits` is empty
    */
  private def getDistinctSplits(splits: Array[Double]): Array[Double] = {
    require(splits.nonEmpty,
      "No quantiles were computed; the input column may be empty or contain only " +
        "null/NaN values.")

    // -0.0 == 0.0 is true, so this maps both zeros to +0.0 and leaves everything else
    // (including NaN) untouched. Depending on the Scala version `distinct` may keep both
    // zeros, and since 0.0 > -0.0 is false the splits would not be strictly increasing.
    val bounded = splits.map(v => if (v == 0.0) 0.0 else v)
    bounded(0) = Double.NegativeInfinity
    bounded(bounded.length - 1) = Double.PositiveInfinity

    val distinctSplits = bounded.distinct
    if (bounded.length != distinctSplits.length) {
      log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
        s" buckets as a result.")
    }
    distinctSplits.sorted
  }

  override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
}


/** Companion of [[QuantileDiscretizer]]; mixes in `DefaultParamsReadable` and `Logging`. */
object QuantileDiscretizer
  extends DefaultParamsReadable[QuantileDiscretizer]
    with Logging {

  /**
    * Loads a [[QuantileDiscretizer]] from `path` by delegating to the inherited `load`.
    *
    * @param path location the estimator was saved to
    */
  override def load(path: String): QuantileDiscretizer = {
    super.load(path)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy