// com.tencent.angel.sona.ml.feature.RFormula.scala Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.tencent.angel.sona.ml.feature
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
import com.tencent.angel.sona.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import com.tencent.angel.sona.ml.attribute.AttributeGroup
import org.apache.spark.linalg.{Vector, VectorUDT}
import com.tencent.angel.sona.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import com.tencent.angel.sona.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol}
import com.tencent.angel.sona.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Compatible
/**
* Base trait for [[RFormula]] and [[RFormulaModel]].
*/
private[sona] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with HasHandleInvalid {

  /**
   * The R model formula, supplied as a string such as `"y ~ a + b"`.
   *
   * @group param
   */
  val formula: Param[String] = new Param(this, "formula", "R model formula")

  /** @group getParam */
  def getFormula: String = $(formula)

  /**
   * Whether to run the label through `StringIndexer` regardless of its type.
   *
   * Normally only a string label is indexed. Setting this to true also indexes a numeric label,
   * which is useful when the formula feeds a classification algorithm.
   * Default: false.
   *
   * @group param
   */
  val forceIndexLabel: BooleanParam =
    new BooleanParam(this, "forceIndexLabel", "Force to index label whether it is numeric or string")
  setDefault(forceIndexLabel -> false)

  /** @group getParam */
  def getForceIndexLabel: Boolean = $(forceIndexLabel)

  /**
   * How invalid data (unseen or NULL values) in string-typed feature and label columns is handled.
   *
   * Accepted values are those in `StringIndexer.supportedHandleInvalids`:
   *  - 'skip': filter out rows with invalid data;
   *  - 'error': throw an error;
   *  - 'keep': put invalid data in a special additional bucket, at index numLabels.
   *
   * Default: "error".
   *
   * @group param
   */
  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
    "How to handle invalid data (unseen or NULL values) in features and label column of string " +
      "type. Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
      "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
    ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
  setDefault(handleInvalid -> StringIndexer.ERROR_INVALID)

  /**
   * Ordering `StringIndexer` applies to the categories of a string FEATURE column.
   *
   * After ordering, the last category is the one dropped during one-hot encoding.
   * Supported options: 'frequencyDesc', 'frequencyAsc', 'alphabetDesc', 'alphabetAsc'.
   * The default is 'frequencyDesc'. Choosing 'alphabetDesc' makes `RFormula` drop the same
   * category R would drop.
   *
   * Worked example for the values `'b', 'a', 'b', 'a', 'c', 'b'`:
   * {{{
   * +-----------------+---------------------------------------+----------------------------------+
   * | Option          | Category mapped to 0 by StringIndexer | Category dropped by RFormula     |
   * +-----------------+---------------------------------------+----------------------------------+
   * | 'frequencyDesc' | most frequent category ('b')          | least frequent category ('c')    |
   * | 'frequencyAsc'  | least frequent category ('c')         | most frequent category ('b')     |
   * | 'alphabetDesc'  | last alphabetical category ('c')      | first alphabetical category ('a')|
   * | 'alphabetAsc'   | first alphabetical category ('a')     | last alphabetical category ('c') |
   * +-----------------+---------------------------------------+----------------------------------+
   * }}}
   * This ordering is NOT applied to the label column. An indexed label always uses
   * `StringIndexer`'s default descending-frequency order.
   *
   * @group param
   */
  final val stringIndexerOrderType: Param[String] = new Param(this, "stringIndexerOrderType",
    "How to order categories of a string FEATURE column used by StringIndexer. " +
      "The last category after ordering is dropped when encoding strings. " +
      s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}. " +
      "The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', " +
      "RFormula drops the same category as R when encoding strings.",
    ParamValidators.inArray(StringIndexer.supportedStringOrderType))
  setDefault(stringIndexerOrderType -> StringIndexer.frequencyDesc)

  /** @group getParam */
  def getStringIndexerOrderType: String = $(stringIndexerOrderType)

  /** True when `schema` already has a field named like the configured label column. */
  protected def hasLabelCol(schema: StructType): Boolean =
    schema.fieldNames.contains($(labelCol))
}
/**
* :: Experimental ::
* Implements the transforms required for fitting a dataset against an R model formula. Currently
* we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see
* the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
*
* The basic operators are:
* - `~` separate target and terms
* - `+` concat terms, "+ 0" means removing intercept
* - `-` remove a term, "- 1" means removing intercept
* - `:` interaction (multiplication for numeric values, or binarized categorical values)
* - `.` all columns except target
*
* Suppose `a` and `b` are double columns, we use the following simple examples
* to illustrate the effect of `RFormula`:
* - `y ~ a + b` means model `y ~ w0 + w1 * a + w2 * b` where `w0` is the intercept and `w1, w2`
* are coefficients.
* - `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3`
* are coefficients.
*
* RFormula produces a vector column of features and a double or string column of label.
* Like when formulas are used in R for linear regression, string input columns will be one-hot
* encoded, and numeric columns will be cast to doubles.
* If the label column is of type string, it will be first transformed to double with
* `StringIndexer`. If the label column does not exist in the DataFrame, the output label column
* will be created from the specified response variable in the formula.
*/
class RFormula(override val uid: String)
  extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("rFormula"))

  /**
   * Sets the formula to use for this transformer. Must be called before use.
   *
   * @group setParam
   * @param value an R formula in string form (e.g. "y ~ x + z")
   */
  def setFormula(value: String): this.type = set(formula, value)

  /** @group setParam */
  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

  /** @group setParam */
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /** @group setParam */
  def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value)

  /** @group setParam */
  def setStringIndexerOrderType(value: String): this.type = set(stringIndexerOrderType, value)

  /**
   * Whether the formula specifies fitting an intercept.
   *
   * @throws IllegalArgumentException if [[formula]] has not been set.
   */
  private[sona] def hasIntercept: Boolean = {
    require(isDefined(formula), "Formula must be defined first.")
    RFormulaParser.parse($(formula)).hasIntercept
  }

  /**
   * Builds and fits the encoding pipeline described by [[formula]] against `dataset`.
   *
   * The pipeline stages, in the order they are appended:
   *  - a `StringIndexer` for each distinct string term and a `VectorSizeHint` for each vector term;
   *  - one-hot encoding for single string terms and an `Interaction` for each multi-column term;
   *  - a `VectorAssembler` writing [[featuresCol]];
   *  - a [[VectorAttributeRewriter]] that maps temporary column prefixes back to the term names;
   *  - a [[ColumnPruner]] that drops every temporary column;
   *  - a `StringIndexer` writing [[labelCol]], added only when the label is a string column or
   *    [[forceIndexLabel]] is set.
   *
   * @throws IllegalArgumentException if [[formula]] is unset, or if the checks in
   *                                  [[transformSchema]] fail.
   */
  override def fit(dataset: Dataset[_]): RFormulaModel = {
    transformSchema(dataset.schema, logging = true)
    require(isDefined(formula), "Formula must be defined first.")
    val parsedFormula = RFormulaParser.parse($(formula))
    val resolvedFormula = parsedFormula.resolve(dataset.schema)
    // Loop-invariant: read once from the formula parsed above. The `hasIntercept` method would
    // re-parse the formula for every string term.
    val formulaHasIntercept = parsedFormula.hasIntercept
    val encoderStages = ArrayBuffer[PipelineStage]()

    val oneHotEncodeColumns = ArrayBuffer[(String, String)]()
    val prefixesToRewrite = mutable.Map[String, String]()
    val tempColumns = ArrayBuffer[String]()
    // Allocates a unique temporary column name and records it so ColumnPruner can drop it later.
    def tmpColumn(category: String): String = {
      val col = Identifiable.randomUID(category)
      tempColumns += col
      col
    }

    // First we index each string column referenced by the input terms.
    val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term =>
      dataset.schema(term).dataType match {
        case _: StringType =>
          val indexCol = tmpColumn("stridx")
          encoderStages += new StringIndexer()
            .setInputCol(term)
            .setOutputCol(indexCol)
            .setStringOrderType($(stringIndexerOrderType))
            .setHandleInvalid($(handleInvalid))
          prefixesToRewrite(indexCol + "_") = term + "_"
          (term, indexCol)
        case _: VectorUDT =>
          // Prefer the size recorded in the column's ML attributes; otherwise peek at the first row.
          val group = AttributeGroup.fromStructField(dataset.schema(term))
          val size = if (group.size < 0) {
            dataset.select(term).first().getAs[Vector](0).size
          } else {
            group.size
          }
          encoderStages += new VectorSizeHint(uid)
            .setHandleInvalid("optimistic")
            .setInputCol(term)
            .setSize(size)
          (term, term)
        case _ =>
          (term, term)
      }
    }.toMap

    // Then we handle one-hot encoding and interactions between terms.
    var keepReferenceCategory = false
    val encodedTerms = resolvedFormula.terms.map {
      case Seq(term) if dataset.schema(term).dataType == StringType =>
        val encodedCol = tmpColumn("onehot")
        // Formula w/o intercept, one of the categories in the first category feature is
        // being used as reference category, we will not drop any category for that feature.
        if (!formulaHasIntercept && !keepReferenceCategory) {
          encoderStages += new OneHotEncoderEstimator(uid)
            .setInputCols(Array(indexed(term)))
            .setOutputCols(Array(encodedCol))
            .setDropLast(false)
          keepReferenceCategory = true
        } else {
          oneHotEncodeColumns += indexed(term) -> encodedCol
        }
        prefixesToRewrite(encodedCol + "_") = term + "_"
        encodedCol
      case Seq(term) =>
        term
      case terms =>
        val interactionCol = tmpColumn("interaction")
        encoderStages += new Interaction()
          .setInputCols(terms.map(indexed).toArray)
          .setOutputCol(interactionCol)
        prefixesToRewrite(interactionCol + "_") = ""
        interactionCol
    }

    // All remaining single string terms share one encoder that drops the last category.
    if (oneHotEncodeColumns.nonEmpty) {
      val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip
      encoderStages += new OneHotEncoderEstimator(uid)
        .setInputCols(inputCols)
        .setOutputCols(outputCols)
        .setDropLast(true)
    }

    encoderStages += new VectorAssembler(uid)
      .setInputCols(encodedTerms.toArray)
      .setOutputCol($(featuresCol))
      .setHandleInvalid($(handleInvalid))
    encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
    encoderStages += new ColumnPruner(tempColumns.toSet)

    if ((dataset.schema.fieldNames.contains(resolvedFormula.label) &&
      dataset.schema(resolvedFormula.label).dataType == StringType) || $(forceIndexLabel)) {
      encoderStages += new StringIndexer()
        .setInputCol(resolvedFormula.label)
        .setOutputCol($(labelCol))
        .setHandleInvalid($(handleInvalid))
    }

    val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
    copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
  }

  /**
   * Optimistic output schema: it does not contain any ML attributes.
   *
   * Appends a nullable vector [[featuresCol]]. When [[labelCol]] is not already present, it also
   * appends a nullable double [[labelCol]].
   *
   * @throws IllegalArgumentException if the label column already exists while
   *                                  [[forceIndexLabel]] is true.
   */
  override def transformSchema(schema: StructType): StructType = {
    require(!hasLabelCol(schema) || !$(forceIndexLabel),
      "If label column already exists, forceIndexLabel can not be set with true.")
    if (hasLabelCol(schema)) {
      StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
    } else {
      StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+
        StructField($(labelCol), DoubleType, true))
    }
  }

  override def copy(extra: ParamMap): RFormula = defaultCopy(extra)

  override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)"
}
object RFormula extends DefaultParamsReadable[RFormula] {

  /** Reads back an [[RFormula]] estimator that was previously saved under `path`. */
  override def load(path: String): RFormula = {
    val estimator: RFormula = super.load(path)
    estimator
  }
}
/**
* :: Experimental ::
* Model fitted by [[RFormula]]. Fitting is required to determine the factor levels of
* formula terms.
*/
class RFormulaModel private[sona](
    override val uid: String,
    private[sona] val resolvedFormula: ResolvedRFormula,
    private[sona] val pipelineModel: PipelineModel)
  extends Model[RFormulaModel] with RFormulaBase with MLWritable {

  /**
   * Runs the fitted encoding pipeline and then derives the label column when needed.
   *
   * @throws IllegalArgumentException if the features column already exists, or if the label
   *                                  column exists but is not numeric.
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    checkCanTransform(dataset.schema)
    transformLabel(pipelineModel.transform(dataset))
  }

  /**
   * Output schema of [[transform]].
   *
   * A label column is appended only when the formula names a label, [[labelCol]] is not already
   * produced by the pipeline, and the formula's label column exists in the input.
   */
  override def transformSchema(schema: StructType): StructType = {
    checkCanTransform(schema)
    val withFeatures = pipelineModel.transformSchema(schema)
    if (resolvedFormula.label.isEmpty || hasLabelCol(withFeatures)) {
      withFeatures
    } else if (schema.exists(_.name == resolvedFormula.label)) {
      val labelField = schema(resolvedFormula.label)
      val nullable = labelField.dataType match {
        // transformLabel casts these to double. The cast keeps the input's nullability, so report
        // that instead of claiming the column is never null.
        case _: NumericType | BooleanType => labelField.nullable
        case _ => true
      }
      StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
    } else {
      // Ignore the label field. This is a hack so that this transformer can also work on test
      // datasets in a Pipeline.
      withFeatures
    }
  }

  override def copy(extra: ParamMap): RFormulaModel = {
    val copied = new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(parent)
    copyValues(copied, extra)
  }

  override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"

  /**
   * Adds [[labelCol]] by casting the formula's label column to double.
   *
   * Nothing is added if the formula has no label, if [[labelCol]] already exists, or if the
   * label column is absent from `dataset`.
   *
   * @throws IllegalArgumentException if the label column has a non-numeric, non-boolean type.
   */
  private def transformLabel(dataset: Dataset[_]): DataFrame = {
    val labelName = resolvedFormula.label
    if (labelName.isEmpty || hasLabelCol(dataset.schema)) {
      dataset.toDF
    } else if (dataset.schema.exists(_.name == labelName)) {
      dataset.schema(labelName).dataType match {
        case _: NumericType | BooleanType =>
          dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
        case other =>
          throw new IllegalArgumentException("Unsupported type for label: " + other)
      }
    } else {
      // Ignore the label field. This is a hack so that this transformer can also work on test
      // datasets in a Pipeline.
      dataset.toDF
    }
  }

  /**
   * Validates that `schema` can be transformed: the features column must not exist yet, and an
   * existing label column must be numeric.
   *
   * @throws IllegalArgumentException when either check fails.
   */
  private def checkCanTransform(schema: StructType): Unit = {
    val columnNames = schema.map(_.name)
    require(!columnNames.contains($(featuresCol)), "Features column already exists.")
    require(
      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType],
      s"Label column already exists and is not of type ${Compatible.numericTypeSimpleString}.")
  }

  override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
}
object RFormulaModel extends MLReadable[RFormulaModel] {

  override def read: MLReader[RFormulaModel] = new RFormulaModelReader

  override def load(path: String): RFormulaModel = super.load(path)

  /**
   * [[MLWriter]] instance for [[RFormulaModel]].
   *
   * Layout: metadata and params at `path`, the resolved formula as parquet under `data`, and the
   * fitted encoder pipeline under `pipelineModel`.
   */
  private[RFormulaModel] class RFormulaModelWriter(instance: RFormulaModel) extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val dataDir = new Path(path, "data").toString
      val pipelineDir = new Path(path, "pipelineModel").toString
      // Single-row table holding the resolved formula.
      val formulaDF = sparkSession.createDataFrame(Seq(instance.resolvedFormula))
      formulaDF.repartition(1).write.parquet(dataDir)
      // The fitted pipeline is persisted through its own writer.
      instance.pipelineModel.save(pipelineDir)
    }
  }

  /** Reads the layout produced by [[RFormulaModelWriter]]. */
  private class RFormulaModelReader extends MLReader[RFormulaModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[RFormulaModel].getName

    override def load(path: String): RFormulaModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val row = sparkSession.read
        .parquet(new Path(path, "data").toString)
        .select("label", "terms", "hasIntercept")
        .head()
      val formula = ResolvedRFormula(
        row.getString(0),
        row.getAs[Seq[Seq[String]]](1),
        row.getBoolean(2))
      val pipeline = PipelineModel.load(new Path(path, "pipelineModel").toString)
      val model = new RFormulaModel(metadata.uid, formula, pipeline)
      metadata.getAndSetParams(model)
      model
    }
  }
}
/**
* Utility transformer for removing temporary columns from a DataFrame.
* TODO(ekl) make this a public transformer
*/
private class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
  extends Transformer with MLWritable {

  def this(columnsToPrune: Set[String]) =
    this(Identifiable.randomUID("columnPruner"), columnsToPrune)

  /** Keeps every column whose name is not in [[columnsToPrune]], in the original order. */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val retained = dataset.columns.filterNot(columnsToPrune).map(name => dataset.col(name))
    dataset.select(retained: _*)
  }

  /** The input schema minus every field named in [[columnsToPrune]]. */
  override def transformSchema(schema: StructType): StructType =
    StructType(schema.fields.filterNot(field => columnsToPrune(field.name)))

  override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)

  override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
}
private object ColumnPruner extends MLReadable[ColumnPruner] {

  override def read: MLReader[ColumnPruner] = new ColumnPrunerReader

  override def load(path: String): ColumnPruner = super.load(path)

  /**
   * [[MLWriter]] instance for [[ColumnPruner]].
   *
   * Writes metadata and params at `path`, plus the pruned column names as parquet under `data`.
   */
  private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter {

    private case class Data(columnsToPrune: Seq[String])

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val dataDir = new Path(path, "data").toString
      val rows = Seq(Data(instance.columnsToPrune.toSeq))
      sparkSession.createDataFrame(rows).repartition(1).write.parquet(dataDir)
    }
  }

  /** Reads the layout produced by [[ColumnPrunerWriter]]. */
  private class ColumnPrunerReader extends MLReader[ColumnPruner] {

    /** Checked against metadata when loading model */
    private val className = classOf[ColumnPruner].getName

    override def load(path: String): ColumnPruner = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val row = sparkSession.read
        .parquet(new Path(path, "data").toString)
        .select("columnsToPrune")
        .head()
      val pruner = new ColumnPruner(metadata.uid, row.getAs[Seq[String]](0).toSet)
      metadata.getAndSetParams(pruner)
      pruner
    }
  }
}
/**
 * Utility transformer that rewrites Vector attribute names by string replacement. For example,
 * it can rewrite attribute names containing 'foo_' so that they contain 'bar_' instead.
 *
 * Each key is replaced wherever it occurs in a name (via `String.replace`), not only at the start.
 * NOTE(review): this looks deliberate. In `RFormula.fit`, several rewritten temporary columns can
 * feed a single `Interaction`, so a key may appear in the middle of an attribute name. Confirm
 * against `Interaction`'s attribute naming before tightening this to a true prefix match.
 * Replacements are applied in the iteration order of the map, which is unspecified.
 *
 * The rewritten column is always moved to the last position, matching [[transformSchema]].
 *
 * @param vectorCol name of the vector column to rewrite.
 * @param prefixesToRewrite the map of strings to their replacement values. In each attribute name
 *                          of vectorCol, every occurrence of each key is replaced by its value.
 */
private class VectorAttributeRewriter(
    override val uid: String,
    val vectorCol: String,
    val prefixesToRewrite: Map[String, String])
  extends Transformer with MLWritable {

  def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
    this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)

  /** Applies every replacement in [[prefixesToRewrite]] to one attribute name. */
  private def rewriteName(name: String): String =
    prefixesToRewrite.foldLeft(name) { case (curName, (from, to)) =>
      curName.replace(from, to)
    }

  override def transform(dataset: Dataset[_]): DataFrame = {
    val field = dataset.schema(vectorCol)
    val metadata = AttributeGroup.fromStructField(field).attributes match {
      case Some(attrs) =>
        // Unnamed attributes are kept as they are; named ones get their names rewritten.
        val renamed = attrs.map(attr => attr.name.fold(attr)(n => attr.withName(rewriteName(n))))
        new AttributeGroup(vectorCol, renamed).toMetadata
      case None =>
        // Without per-attribute metadata there are no names to rewrite. Keep the existing
        // metadata instead of failing on `Option.get`.
        field.metadata
    }
    val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
    val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
    dataset.select(otherCols :+ rewrittenCol: _*)
  }

  /** Same fields as the input, with [[vectorCol]] moved to the end. */
  override def transformSchema(schema: StructType): StructType = {
    StructType(
      schema.fields.filter(_.name != vectorCol) ++
        schema.fields.filter(_.name == vectorCol))
  }

  override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)

  override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
}
private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {

  override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader

  override def load(path: String): VectorAttributeRewriter = super.load(path)

  /**
   * [[MLWriter]] instance for [[VectorAttributeRewriter]].
   *
   * Writes metadata and params at `path`. `vectorCol` and `prefixesToRewrite` are stored as a
   * single-row parquet table under `data`.
   */
  private[VectorAttributeRewriter]
  class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter {

    private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String])

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val dataDir = new Path(path, "data").toString
      val rows = Seq(Data(instance.vectorCol, instance.prefixesToRewrite))
      sparkSession.createDataFrame(rows).repartition(1).write.parquet(dataDir)
    }
  }

  /** Reads the layout produced by [[VectorAttributeRewriterWriter]]. */
  private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] {

    /** Checked against metadata when loading model */
    private val className = classOf[VectorAttributeRewriter].getName

    override def load(path: String): VectorAttributeRewriter = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val row = sparkSession.read
        .parquet(new Path(path, "data").toString)
        .select("vectorCol", "prefixesToRewrite")
        .head()
      val rewriter = new VectorAttributeRewriter(
        metadata.uid,
        row.getString(0),
        row.getAs[Map[String, String]](1))
      metadata.getAndSetParams(rewriter)
      rewriter
    }
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy