All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.ml.feature.ChiSqSelector.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml.feature

import org.apache.hadoop.fs.Path
import com.tencent.angel.sona.ml
import com.tencent.angel.sona.ml.{Estimator, Model}
import com.tencent.angel.sona.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.linalg
import org.apache.spark.linalg._
import com.tencent.angel.sona.ml.param.{DoubleParam, IntParam, Param, ParamMap, ParamValidators, Params}
import com.tencent.angel.sona.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasOutputCol}
import com.tencent.angel.sona.ml.stat.Statistics
import com.tencent.angel.sona.ml.stat.test.ChiSqTestResult
import com.tencent.angel.sona.ml.util._
import org.apache.spark.sql.util.SONASchemaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.util.DatasetUtil

import scala.collection.mutable

/**
  * Params shared by [[ChiSqSelector]] and [[ChiSqSelectorModel]].
  *
  * `selectorType` decides which one of `numTopFeatures`, `percentile`, `fpr`, `fdr` and `fwe`
  * is used when fitting.
  */
private[sona] trait ChiSqSelectorParams extends Params
  with HasFeaturesCol with HasOutputCol with HasLabelCol {

  /**
    * How many features to keep, taking those with the smallest p-values first. When there are
    * fewer features than this, all of them are kept.
    * Used only when selectorType = "numTopFeatures". Must be at least 1; defaults to 50.
    *
    * @group param
    */
  final val numTopFeatures: IntParam = new IntParam(this, "numTopFeatures",
    "Number of features that selector will select, ordered by ascending p-value. " +
      "If the number of features is < numTopFeatures, then this will select all features.",
    ParamValidators.gtEq(1))

  /**
    * Fraction of all features to keep, taking those with the smallest p-values first.
    * Used only when selectorType = "percentile". Validated to lie between 0 and 1;
    * defaults to 0.1.
    *
    * @group param
    */
  final val percentile: DoubleParam = new DoubleParam(this, "percentile",
    "Percentile of features that selector will select, ordered by ascending p-value.",
    ParamValidators.inRange(0, 1))

  /**
    * Largest p-value a feature may have and still be kept.
    * Used only when selectorType = "fpr". Validated to lie between 0 and 1; defaults to 0.05.
    *
    * @group param
    */
  final val fpr: DoubleParam = new DoubleParam(this, "fpr",
    "The highest p-value for features to be kept.", ParamValidators.inRange(0, 1))

  /**
    * Upper bound on the expected false discovery rate.
    * Used only when selectorType = "fdr". Validated to lie between 0 and 1; defaults to 0.05.
    *
    * @group param
    */
  final val fdr: DoubleParam = new DoubleParam(this, "fdr",
    "The upper bound of the expected false discovery rate.", ParamValidators.inRange(0, 1))

  /**
    * Upper bound on the expected family-wise error rate.
    * Used only when selectorType = "fwe". Validated to lie between 0 and 1; defaults to 0.05.
    *
    * @group param
    */
  final val fwe: DoubleParam = new DoubleParam(this, "fwe",
    "The upper bound of the expected family-wise error rate.", ParamValidators.inRange(0, 1))

  /**
    * Which selection method to apply. Accepted values are listed in
    * `ChiSqSelector.supportedSelectorTypes`; defaults to "numTopFeatures".
    *
    * @group param
    */
  final val selectorType: Param[String] = new Param[String](this, "selectorType",
    s"The selector type of the ChisqSelector. Supported options: ${
      ChiSqSelector.supportedSelectorTypes.mkString(", ")
    }",
    ParamValidators.inArray[String](ChiSqSelector.supportedSelectorTypes))

  // Defaults are registered only after every param above has been initialized.
  setDefault(numTopFeatures -> 50)
  setDefault(percentile -> 0.1)
  setDefault(fpr -> 0.05)
  setDefault(fdr -> 0.05)
  setDefault(fwe -> 0.05)
  setDefault(selectorType -> ChiSqSelector.NumTopFeatures)

  /** @group getParam */
  def getNumTopFeatures: Int = $(numTopFeatures)

  /** @group getParam */
  def getPercentile: Double = $(percentile)

  /** @group getParam */
  def getFpr: Double = $(fpr)

  /** @group getParam */
  def getFdr: Double = $(fdr)

  /** @group getParam */
  def getFwe: Double = $(fwe)

  /** @group getParam */
  def getSelectorType: String = $(selectorType)
}

/**
  * Chi-Squared feature selection, which selects categorical features to use for predicting a
  * categorical label.
  * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`,
  * `fdr`, `fwe`.
  *  - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test.
  *  - `percentile` is similar but chooses a fraction of all features instead of a fixed number.
  *  - `fpr` chooses all features whose p-values are below a threshold, thus controlling the false
  *    positive rate of selection.
  *  - `fdr` uses the Benjamini-Hochberg procedure
  *    (https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure)
  *    to choose all features whose false discovery rate is below a threshold.
  *  - `fwe` chooses all features whose p-values are below a threshold. The threshold is scaled by
  *    1/numFeatures, thus controlling the family-wise error rate of selection.
  * By default, the selection method is `numTopFeatures`, with the default number of top features
  * set to 50.
  */
final class ChiSqSelector(override val uid: String)
  extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("chiSqSelector"))

  /** @group setParam */
  def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)

  /** @group setParam */
  def setPercentile(value: Double): this.type = set(percentile, value)

  /** @group setParam */
  def setFpr(value: Double): this.type = set(fpr, value)

  /** @group setParam */
  def setFdr(value: Double): this.type = set(fdr, value)

  /** @group setParam */
  def setFwe(value: Double): this.type = set(fwe, value)

  /** @group setParam */
  def setSelectorType(value: String): this.type = set(selectorType, value)

  /** @group setParam */
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /**
    * Runs a chi-squared test of every feature against the label and keeps the features chosen
    * by the configured `selectorType`.
    *
    * @param dataset input holding a numeric `labelCol` and a vector `featuresCol`
    * @return a model whose `selectedFeatures` are the indices of the kept features
    * @throws IllegalStateException if `selectorType` holds an unsupported value
    */
  override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
    transformSchema(dataset.schema, logging = true)
    val input: RDD[LabeledPoint] =
      dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
        case Row(label: Double, features: linalg.Vector) =>
          ml.feature.LabeledPoint(label, features)
      }

    // One test result per feature, paired with that feature's column index.
    val chiSqTestResult = Statistics.chiSqTest(input).zipWithIndex
    val numFeatures = chiSqTestResult.length
    val features = $(selectorType) match {
      case ChiSqSelector.NumTopFeatures =>
        sortByPValue(chiSqTestResult).take($(numTopFeatures))
      case ChiSqSelector.Percentile =>
        sortByPValue(chiSqTestResult).take(percentileCount(numFeatures, $(percentile)))
      case ChiSqSelector.FPR =>
        val threshold = $(fpr)
        chiSqTestResult.filter { case (res, _) => res.pValue < threshold }
      case ChiSqSelector.FDR =>
        // Benjamini-Hochberg procedure: find the largest 1-based rank k whose p-value is
        // <= fdr * k / numFeatures, then keep the k smallest p-values.
        // https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure
        val sorted = sortByPValue(chiSqTestResult)
        val alpha = $(fdr)
        val lastAccepted = sorted.indices.reverse.find { i =>
          sorted(i)._1.pValue <= alpha * (i + 1) / numFeatures
        }
        sorted.take(lastAccepted.fold(0)(_ + 1))
      case ChiSqSelector.FWE =>
        // Bonferroni-style correction: the fwe threshold is divided by the number of features.
        val threshold = $(fwe) / numFeatures
        chiSqTestResult.filter { case (res, _) => res.pValue < threshold }
      case errorType =>
        throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
    }
    val indices = features.map { case (_, index) => index }

    copyValues(new ChiSqSelectorModel(uid, indices).setParent(this))
  }

  /** Orders (test result, feature index) pairs by ascending p-value. */
  private def sortByPValue(
      results: Array[(ChiSqTestResult, Int)]): Array[(ChiSqTestResult, Int)] =
    results.sortBy { case (res, _) => res.pValue }

  /**
    * Number of features making up `fraction` of `numFeatures`, rounded down.
    *
    * The product is formed in decimal rather than binary floating point so representation
    * error cannot drop a feature: as Doubles, 100 * 0.29 is 28.999999999999996, which would
    * truncate to 28 instead of 29.
    */
  private def percentileCount(numFeatures: Int, fraction: Double): Int =
    (scala.math.BigDecimal(numFeatures) * scala.math.BigDecimal(fraction)).toInt

  /**
    * Validates the input columns and appends the vector `outputCol`. Logs a warning for every
    * selection param that is explicitly set but ignored by the current `selectorType`.
    */
  override def transformSchema(schema: StructType): StructType = {
    val otherPairs = ChiSqSelector.supportedSelectorTypes.filter(_ != $(selectorType))
    otherPairs.foreach { paramName: String =>
      if (isSet(getParam(paramName))) {
        logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
      }
    }
    SONASchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
    SONASchemaUtils.checkNumericType(schema, $(labelCol))
    SONASchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
  }

  override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra)
}


object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] {

  /** `selectorType` value: keep a fixed number of features with the smallest p-values. */
  val NumTopFeatures: String = "numTopFeatures"

  /** `selectorType` value: keep a fraction of all features with the smallest p-values. */
  val Percentile: String = "percentile"

  /** `selectorType` value: keep features whose p-value is below the `fpr` threshold. */
  val FPR: String = "fpr"

  /** `selectorType` value: Benjamini-Hochberg selection bounded by the `fdr` param. */
  val FDR: String = "fdr"

  /** `selectorType` value: keep features whose p-value is below `fwe` / number of features. */
  val FWE: String = "fwe"

  /**
    * Every value accepted by the `selectorType` param. Each entry is also the name of the
    * param that the corresponding selection method reads.
    */
  val supportedSelectorTypes: Array[String] = Array[String](
    NumTopFeatures,
    Percentile,
    FPR,
    FDR,
    FWE)

  override def load(path: String): ChiSqSelector = super.load(path)
}

/**
  * Model fitted by [[ChiSqSelector]].
  *
  * @param selectedFeatures indices of the features kept by the selector, in the order the
  *                         selector produced them
  */
final class ChiSqSelectorModel private[angel](
    override val uid: String,
    val selectedFeatures: Array[Int])
  extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable {

  import ChiSqSelectorModel._

  /** @group setParam */
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
    * Appends `outputCol`, holding each row's feature vector restricted to the selected
    * features, together with per-feature attribute metadata.
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val transformedSchema = transformSchema(dataset.schema, logging = true)
    val newField = transformedSchema.last

    // TODO: Make the transformer natively in ml framework to avoid extra conversion.
    val transformer: linalg.Vector => linalg.Vector = v => compress(v)

    val selector = udf(transformer)
    DatasetUtil.withColumn(dataset, $(outputCol), selector(col($(featuresCol))), newField.metadata)
  }

  override def transformSchema(schema: StructType): StructType = {
    SONASchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
    val newField = prepOutputField(schema)
    val outputFields = schema.fields :+ newField
    StructType(outputFields)
  }

  /**
    * Prepare the output column field, including per-feature metadata.
    *
    * The input column's attributes for the selected indices are carried over when present.
    * Otherwise, one default nominal attribute is generated per selected feature.
    */
  private def prepOutputField(schema: StructType): StructField = {
    val selector = selectedFeatures.toSet
    val origAttrGroup = AttributeGroup.fromStructField(schema($(featuresCol)))
    val featureAttributes: Array[Attribute] = origAttrGroup.attributes match {
      case Some(attrs) =>
        attrs.zipWithIndex.collect { case (attr, idx) if selector.contains(idx) => attr }
      case None =>
        Array.fill[Attribute](selector.size)(NominalAttribute.defaultAttr)
    }
    val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
    newAttributeGroup.toStructField
  }

  override def copy(extra: ParamMap): ChiSqSelectorModel = {
    val copied = new ChiSqSelectorModel(uid, selectedFeatures)
    copyValues(copied, extra).setParent(parent)
  }

  override def write: MLWriter = new ChiSqSelectorModelWriter(this)

  // Selected indices in ascending order; compress walks these in lockstep with sparse indices.
  private val filterIndices = selectedFeatures.sorted

  protected def formatVersion: String = "1.0"

  /**
    * Returns a vector with features filtered.
    * The output has `filterIndices.length` entries. Entry j holds the input's value at
    * feature index `filterIndices(j)`, so selected features appear in ascending order of
    * their original index.
    * Might be moved to Vector as .slice
    *
    * @param features vector
    * @throws UnsupportedOperationException for vector types other than the sparse and dense
    *                                       variants matched below
    */
  private def compress(features: linalg.Vector): linalg.Vector = features match {
    case IntSparseVector(_, indices, values) =>
      val (newIndices, newValues) = sliceSparse(indices.length, i => indices(i), values)
      // TODO: Sparse representation might be ineffective if (newSize ~= newValues.size)
      Vectors.sparse(filterIndices.length, newIndices, newValues)
    case LongSparseVector(_, indices, values) =>
      val (newIndices, newValues) = sliceSparse(indices.length, i => indices(i), values)
      // TODO: Sparse representation might be ineffective if (newSize ~= newValues.size)
      Vectors.sparse(filterIndices.length, newIndices.map(_.toLong), newValues)
    case DenseVector(_) =>
      val values = features.toArray
      Vectors.dense(filterIndices.map(i => values(i)))
    case other =>
      throw new UnsupportedOperationException(
        s"Only sparse and dense vectors are supported but got ${other.getClass}.")
  }

  /**
    * Keeps the stored entries of a sparse vector whose feature index is in `filterIndices`.
    *
    * Walks the stored indices and `filterIndices` in a single merge pass. This assumes the
    * stored indices are in ascending order (NOTE(review): confirm the sparse vector types
    * guarantee this).
    *
    * @param nnz     number of stored entries
    * @param indexAt feature index of the i-th stored entry
    * @param values  stored values, parallel to the stored indices
    * @return positions within `filterIndices` (the new feature indices) and the matching values
    */
  private def sliceSparse(
      nnz: Int,
      indexAt: Int => Long,
      values: Array[Double]): (Array[Int], Array[Double]) = {
    val newIndices = new mutable.ArrayBuilder.ofInt
    val newValues = new mutable.ArrayBuilder.ofDouble
    var i = 0
    var j = 0
    while (i < nnz && j < filterIndices.length) {
      val stored = indexAt(i)
      val wanted = filterIndices(j)
      if (stored == wanted) {
        newIndices += j
        newValues += values(i)
        i += 1
        j += 1
      } else if (stored > wanted) {
        j += 1
      } else {
        i += 1
      }
    }
    (newIndices.result(), newValues.result())
  }
}


object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {

  /**
    * Persists a model in two parts: the param metadata, via DefaultParamsWriter, and a
    * single-partition parquet file under `<path>/data` with one row holding `selectedFeatures`.
    */
  private[ChiSqSelectorModel]
  class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter {

    private case class Data(selectedFeatures: Seq[Int])

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val payload = Data(instance.selectedFeatures.toSeq)
      sparkSession
        .createDataFrame(Seq(payload))
        .repartition(1)
        .write
        .parquet(new Path(path, "data").toString)
    }
  }

  /**
    * Rebuilds a model from the layout written by ChiSqSelectorModelWriter. It reads the
    * metadata, then the first row of `<path>/data`, then restores the params onto the model.
    */
  private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] {

    private val className = classOf[ChiSqSelectorModel].getName

    override def load(path: String): ChiSqSelectorModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val storedIndices: Array[Int] = sparkSession.read
        .parquet(new Path(path, "data").toString)
        .select("selectedFeatures")
        .head()
        .getAs[Seq[Int]](0)
        .toArray
      val restored = new ChiSqSelectorModel(metadata.uid, storedIndices)
      metadata.getAndSetParams(restored)
      restored
    }
  }

  override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader

  override def load(path: String): ChiSqSelectorModel = super.load(path)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy