import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
import{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.linalg.{Vector, VectorUDT}
import{BooleanParam, Param, ParamMap, ParamValidators}
import{HasFeaturesCol, HasHandleInvalid, HasLabelCol}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Compatible
/**
 * Base trait for [[RFormula]] and [[RFormulaModel]].
 */
private[sona] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with HasHandleInvalid {

  /**
   * R formula parameter. The formula is provided in string form.
   * @group param
   */
  val formula: Param[String] = new Param(this, "formula", "R model formula")

  /** @group getParam */
  def getFormula: String = $(formula)

  /**
   * Force to index label whether it is numeric or string type.
   * Usually we index label only when it is string type.
   * If the formula was used by classification algorithms,
   * we can force to index label even it is numeric type by setting this param with true.
   * Default: false.
   * @group param
   */
  val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel",
    "Force to index label whether it is numeric or string")
  setDefault(forceIndexLabel -> false)

  /** @group getParam */
  def getForceIndexLabel: Boolean = $(forceIndexLabel)

  /**
   * Param for how to handle invalid data (unseen or NULL values) in features and label column
   * of string type. Options are 'skip' (filter out rows with invalid data),
   * 'error' (throw an error), or 'keep' (put invalid data in a special additional
   * bucket, at index numLabels).
   * Default: "error"
   * @group param
   */
  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
    "How to handle invalid data (unseen or NULL values) in features and label column of string " +
    "type. Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
    // Reject values StringIndexer does not understand at set-time rather than at fit-time.
    ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
  setDefault(handleInvalid, StringIndexer.ERROR_INVALID)

  /**
   * Param for how to order categories of a string FEATURE column used by `StringIndexer`.
   * The last category after ordering is dropped when encoding strings.
   * Supported options: 'frequencyDesc', 'frequencyAsc', 'alphabetDesc', 'alphabetAsc'.
   * The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', `RFormula`
   * drops the same category as R when encoding strings.
   *
   * The options are explained using an example `'b', 'a', 'b', 'a', 'c', 'b'`:
   * {{{
   * +-----------------+---------------------------------------+----------------------------------+
   * |      Option     | Category mapped to 0 by StringIndexer |  Category dropped by RFormula    |
   * +-----------------+---------------------------------------+----------------------------------+
   * | 'frequencyDesc' | most frequent category ('b')          | least frequent category ('c')    |
   * | 'frequencyAsc'  | least frequent category ('c')         | most frequent category ('b')     |
   * | 'alphabetDesc'  | last alphabetical category ('c')      | first alphabetical category ('a')|
   * | 'alphabetAsc'   | first alphabetical category ('a')     | last alphabetical category ('c') |
   * +-----------------+---------------------------------------+----------------------------------+
   * }}}
   * Note that this ordering option is NOT used for the label column. When the label column is
   * indexed, it uses the default descending frequency ordering in `StringIndexer`.
   *
   * @group param
   */
  final val stringIndexerOrderType: Param[String] = new Param(this, "stringIndexerOrderType",
    "How to order categories of a string FEATURE column used by StringIndexer. " +
    "The last category after ordering is dropped when encoding strings. " +
    s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}. " +
    "The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', " +
    "RFormula drops the same category as R when encoding strings.",
    // Same option list that the description above advertises.
    ParamValidators.inArray(StringIndexer.supportedStringOrderType))
  setDefault(stringIndexerOrderType, StringIndexer.frequencyDesc)

  /** @group getParam */
  def getStringIndexerOrderType: String = $(stringIndexerOrderType)

  /**
   * Returns true if `schema` already has a column whose name equals the current `labelCol`.
   *
   * @param schema the schema to inspect
   */
  protected def hasLabelCol(schema: StructType): Boolean = {
    schema.fieldNames.contains($(labelCol))
  }
}
// NOTE(review): the `/**` opener and `*/` closer of this class scaladoc are missing, so the
// `*`-prefixed lines below are not parsed as a comment; restore the delimiters. The link that
// should follow "the R formula docs here:" also appears to be missing.
* :: Experimental ::
* Implements the transforms required for fitting a dataset against an R model formula. Currently
* we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see
* the R formula docs here:
* The basic operators are:
* - `~` separate target and terms
* - `+` concat terms, "+ 0" means removing intercept
* - `-` remove a term, "- 1" means removing intercept
* - `:` interaction (multiplication for numeric values, or binarized categorical values)
* - `.` all columns except target
* Suppose `a` and `b` are double columns, we use the following simple examples
* to illustrate the effect of `RFormula`:
* - `y ~ a + b` means model `y ~ w0 + w1 * a + w2 * b` where `w0` is the intercept and `w1, w2`
* are coefficients.
* - `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3`
* are coefficients.
* RFormula produces a vector column of features and a double or string column of label.
* Like when formulas are used in R for linear regression, string input columns will be one-hot
* encoded, and numeric columns will be cast to doubles.
* If the label column is of type string, it will be first transformed to double with
* `StringIndexer`. If the label column does not exist in the DataFrame, the output label column
* will be created from the specified response variable in the formula.
// Estimator for R-style formulas: `fit` builds a Pipeline of encoding stages, fits it, and
// returns an RFormulaModel wrapping the fitted PipelineModel.
class RFormula(override val uid: String)
extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {
// Auxiliary constructor: generates a random uid with the "rFormula" prefix.
def this() = this(Identifiable.randomUID("rFormula"))
// NOTE(review): the setFormula scaladoc below has also lost its `/**` / `*/` delimiters.
* Sets the formula to use for this transformer. Must be called before use.
* @group setParam
* @param value an R formula in string form (e.g. "y ~ x + z")
def setFormula(value: String): this.type = set(formula, value)
/** @group setParam */
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
/** @group setParam */
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
/** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)
/** @group setParam */
def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value)
/** @group setParam */
def setStringIndexerOrderType(value: String): this.type = set(stringIndexerOrderType, value)
/** Whether the formula specifies fitting an intercept. */
private[sona] def hasIntercept: Boolean = {
require(isDefined(formula), "Formula must be defined first.")
// NOTE(review): the body stops after this require — the Boolean result (presumably
// RFormulaParser.parse($(formula)).hasIntercept) and the closing brace appear to be missing.
// Fits the formula: validates the schema, resolves the formula against it, appends the encoder
// stages, fits them together as one Pipeline and copies this estimator's params to the model.
override def fit(dataset: Dataset[_]): RFormulaModel = {
transformSchema(dataset.schema, logging = true)
require(isDefined(formula), "Formula must be defined first.")
val parsedFormula = RFormulaParser.parse($(formula))
val resolvedFormula = parsedFormula.resolve(dataset.schema)
// Stages are appended in order and fitted together as a single Pipeline at the end.
val encoderStages = ArrayBuffer[PipelineStage]()
// (indexed input column -> one-hot output column) pairs, encoded by one shared encoder below.
val oneHotEncodeColumns = ArrayBuffer[(String, String)]()
// Temp-column prefix -> user-facing prefix; handed to VectorAttributeRewriter below.
val prefixesToRewrite = mutable.Map[String, String]()
// Every generated temp column; all of them are dropped by ColumnPruner below.
val tempColumns = ArrayBuffer[String]()
// Generates a random column name with the given prefix and records it for pruning.
// NOTE(review): the value to return (`col`) and the closing brace appear to be missing.
def tmpColumn(category: String): String = {
val col = Identifiable.randomUID(category)
tempColumns += col
// First we index each string column referenced by the input terms.
// NOTE(review): the collection mapped here (presumably the distinct terms of resolvedFormula),
// the closing braces and the trailing `.toMap` are missing; this does not compile as written.
val indexed: Map[String, String] = { term =>
dataset.schema(term).dataType match {
case _: StringType =>
val indexCol = tmpColumn("stridx")
// NOTE(review): the setter chain configuring this StringIndexer (input/output columns,
// ordering, handleInvalid) appears to be missing.
encoderStages += new StringIndexer()
prefixesToRewrite(indexCol + "_") = term + "_"
(term, indexCol)
case _: VectorUDT =>
// Vector size comes from the attribute metadata unless it is unknown (size < 0).
// NOTE(review): the next line is garbled — the expression that presumably reads the size
// from the first row is truncated, and the else-branch value is missing.
val group = AttributeGroup.fromStructField(dataset.schema(term))
val size = if (group.size < 0) {[Vector](0).size
} else {
// NOTE(review): the VectorSizeHint setter chain (input column, size, ...) appears missing.
encoderStages += new VectorSizeHint(uid)
(term, term)
case _ =>
(term, term)
// Then we handle one-hot encoding and interactions between terms.
var keepReferenceCategory = false
// NOTE(review): the collection mapped here (presumably resolvedFormula.terms) is missing, and
// the cases below have lost their result expressions and closing braces.
val encodedTerms = {
case Seq(term) if dataset.schema(term).dataType == StringType =>
val encodedCol = tmpColumn("onehot")
// Formula w/o intercept, one of the categories in the first category feature is
// being used as reference category, we will not drop any category for that feature.
if (!hasIntercept && !keepReferenceCategory) {
// NOTE(review): this encoder's setter chain appears to be missing.
encoderStages += new OneHotEncoderEstimator(uid)
keepReferenceCategory = true
} else {
oneHotEncodeColumns += indexed(term) -> encodedCol
prefixesToRewrite(encodedCol + "_") = term + "_"
case Seq(term) =>
case terms =>
val interactionCol = tmpColumn("interaction")
// NOTE(review): the Interaction setter chain (input/output columns) appears to be missing.
encoderStages += new Interaction()
// Interaction columns get no user-facing prefix: their temp prefix is simply stripped.
prefixesToRewrite(interactionCol + "_") = ""
// Encode all string terms collected above with one shared one-hot encoder.
if (oneHotEncodeColumns.nonEmpty) {
val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip
// NOTE(review): setter chain missing; `inputCols`/`outputCols` are otherwise unused.
encoderStages += new OneHotEncoderEstimator(uid)
// NOTE(review): VectorAssembler setter chain missing (presumably encodedTerms -> featuresCol).
encoderStages += new VectorAssembler(uid)
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
encoderStages += new ColumnPruner(tempColumns.toSet)
// Index the label when it is an existing string column, or whenever forceIndexLabel is set.
if ((dataset.schema.fieldNames.contains(resolvedFormula.label) &&
dataset.schema(resolvedFormula.label).dataType == StringType) || $(forceIndexLabel)) {
// NOTE(review): the label StringIndexer's setter chain appears to be missing.
encoderStages += new StringIndexer()
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
// optimistic schema; does not contain any ML attributes
// Appends a nullable features vector column and, when the label column is absent, a nullable
// DoubleType label column. Rejects forceIndexLabel = true if the label column already exists.
override def transformSchema(schema: StructType): StructType = {
require(!hasLabelCol(schema) || !$(forceIndexLabel),
"If label column already exists, forceIndexLabel can not be set with true.")
if (hasLabelCol(schema)) {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
} else {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+
StructField($(labelCol), DoubleType, true))
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)"
/**
 * Companion object of [[RFormula]]: loads estimators persisted through
 * `DefaultParamsWritable`.
 */
object RFormula extends DefaultParamsReadable[RFormula] {

  /**
   * Loads an [[RFormula]] previously saved to `path`.
   *
   * @param path directory the estimator was saved to
   */
  override def load(path: String): RFormula = super.load(path)
}
/**
 * :: Experimental ::
 * Model fitted by [[RFormula]]. Fitting is required to determine the factor levels of
 * formula terms.
 *
 * @param resolvedFormula the formula resolved against the training schema; an empty `label`
 *                        means the formula has no response variable
 * @param pipelineModel   the fitted encoding pipeline that produces the features column
 */
class RFormulaModel private[sona](
    override val uid: String,
    private[sona] val resolvedFormula: ResolvedRFormula,
    private[sona] val pipelineModel: PipelineModel)
  extends Model[RFormulaModel] with RFormulaBase with MLWritable {

  /**
   * Runs the fitted encoding pipeline on `dataset`, then appends the label column if needed.
   *
   * @throws IllegalArgumentException if the features column already exists, an existing label
   *                                  column is not numeric, or the response column has an
   *                                  unsupported type
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    checkCanTransform(dataset.schema)
    transformLabel(pipelineModel.transform(dataset))
  }

  /**
   * Output schema of [[transform]]: the pipeline's output schema plus, when the response column
   * exists in `schema` and no label column has been produced yet, a DoubleType label column.
   */
  override def transformSchema(schema: StructType): StructType = {
    checkCanTransform(schema)
    val withFeatures = pipelineModel.transformSchema(schema)
    if (resolvedFormula.label.isEmpty || hasLabelCol(withFeatures)) {
      // No response, or the pipeline already produced the label (e.g. an indexed string label).
      withFeatures
    } else if (schema.fieldNames.contains(resolvedFormula.label)) {
      val nullable = schema(resolvedFormula.label).dataType match {
        case _: NumericType | BooleanType => false
        case _ => true
      }
      StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
    } else {
      // Ignore the label field. This is a hack so that this transformer can also work on test
      // datasets in a Pipeline.
      withFeatures
    }
  }

  override def copy(extra: ParamMap): RFormulaModel = {
    val copied = new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(parent)
    copyValues(copied, extra)
  }

  override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"

  /**
   * Adds `labelCol` by casting the formula's response column to double. The input is returned
   * unchanged when the formula has no response, the label column already exists, or the
   * response column is absent.
   *
   * @throws IllegalArgumentException if the response column is neither numeric nor boolean
   */
  private def transformLabel(dataset: Dataset[_]): DataFrame = {
    val labelName = resolvedFormula.label
    if (labelName.isEmpty || hasLabelCol(dataset.schema)) {
      dataset.toDF()
    } else if (dataset.schema.fieldNames.contains(labelName)) {
      dataset.schema(labelName).dataType match {
        case _: NumericType | BooleanType =>
          dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
        case other =>
          throw new IllegalArgumentException("Unsupported type for label: " + other)
      }
    } else {
      // Ignore the label field. This is a hack so that this transformer can also work on test
      // datasets in a Pipeline.
      dataset.toDF()
    }
  }

  /**
   * Validates that the output columns can be added to `schema`: the features column must not
   * exist yet, and an existing label column must be numeric.
   *
   * @throws IllegalArgumentException if either requirement is violated
   */
  private def checkCanTransform(schema: StructType): Unit = {
    val columnNames = schema.fieldNames
    require(!columnNames.contains($(featuresCol)), "Features column already exists.")
    require(
      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType],
      s"Label column already exists and is not of type ${Compatible.numericTypeSimpleString}.")
  }

  override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
}
// Companion object of RFormulaModel: supplies the MLReader used by `load` and hosts the
// writer/reader implementations.
object RFormulaModel extends MLReadable[RFormulaModel] {
override def read: MLReader[RFormulaModel] = new RFormulaModelReader
override def load(path: String): RFormulaModel = super.load(path)
/** [[MLWriter]] instance for [[RFormulaModel]] */
private[RFormulaModel] class RFormulaModelWriter(instance: RFormulaModel) extends MLWriter {
// Intended layout under `path` (per the comments below): param metadata, the resolved formula
// under "data", and the fitted PipelineModel under "pipelineModel".
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: resolvedFormula
val dataPath = new Path(path, "data").toString
// NOTE(review): nothing visible writes instance.resolvedFormula to dataPath; the write
// statement appears to be missing.
// Save pipeline model
val pmPath = new Path(path, "pipelineModel").toString
// NOTE(review): likewise, nothing visible saves instance.pipelineModel to pmPath, and the
// closing braces of saveImpl and this class are missing.
private class RFormulaModelReader extends MLReader[RFormulaModel] {
/** Checked against metadata when loading model */
private val className = classOf[RFormulaModel].getName
// Reverses RFormulaModelWriter: reads metadata, the (label, terms, hasIntercept) row from
// "data" and the PipelineModel from "pipelineModel", then rebuilds the model with the saved uid.
override def load(path: String): RFormulaModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
// NOTE(review): the source of this select (presumably a read of dataPath) is missing from the
// expression below.
val data ="label", "terms", "hasIntercept").head()
val label = data.getString(0)
val terms = data.getAs[Seq[Seq[String]]](1)
val hasIntercept = data.getBoolean(2)
val resolvedRFormula = ResolvedRFormula(label, terms, hasIntercept)
val pmPath = new Path(path, "pipelineModel").toString
val pipelineModel = PipelineModel.load(pmPath)
val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel)
// NOTE(review): the method ends here without applying the saved params from `metadata` to
// `model` or returning it; the tail of load and the closing braces appear to be missing.
/**
 * Utility transformer for removing temporary columns from a DataFrame.
 * TODO(ekl) make this a public transformer
 *
 * @param columnsToPrune names of the columns to drop; names not present in the input are ignored
 */
private class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
  extends Transformer with MLWritable {

  def this(columnsToPrune: Set[String]) =
    this(Identifiable.randomUID("columnPruner"), columnsToPrune)

  /** Selects every column of `dataset` that is not listed in `columnsToPrune`, in input order. */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
    dataset.select(columnsToKeep.map(dataset.col): _*)
  }

  /** Same filtering as [[transform]], applied to the schema. */
  override def transformSchema(schema: StructType): StructType =
    StructType(schema.fields.filter(field => !columnsToPrune.contains(field.name)))

  override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)

  override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
}
// Companion object of ColumnPruner: supplies the MLReader used by `load` and hosts the
// writer/reader implementations.
private object ColumnPruner extends MLReadable[ColumnPruner] {
override def read: MLReader[ColumnPruner] = new ColumnPrunerReader
override def load(path: String): ColumnPruner = super.load(path)
/** [[MLWriter]] instance for [[ColumnPruner]] */
private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter {
// Row schema for the persisted model data.
private case class Data(columnsToPrune: Seq[String])
// Intended layout under `path`: param metadata plus the pruned column names under "data".
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: columnsToPrune
val data = Data(instance.columnsToPrune.toSeq)
val dataPath = new Path(path, "data").toString
// NOTE(review): nothing visible writes `data` to dataPath; the write statement and the
// closing braces of saveImpl and this class appear to be missing.
private class ColumnPrunerReader extends MLReader[ColumnPruner] {
/** Checked against metadata when loading model */
private val className = classOf[ColumnPruner].getName
// Reverses ColumnPrunerWriter: reads metadata and the column names from "data" and rebuilds
// the pruner with the saved uid.
override def load(path: String): ColumnPruner = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
// NOTE(review): the source of this select (presumably a read of dataPath) is missing.
val data ="columnsToPrune").head()
val columnsToPrune = data.getAs[Seq[String]](0).toSet
val pruner = new ColumnPruner(metadata.uid, columnsToPrune)
// NOTE(review): the method ends without applying params from `metadata` or returning
// `pruner`; the tail of load and the closing braces appear to be missing.
/**
 * Utility transformer that rewrites Vector attribute names via substring replacement. For
 * example, it can rewrite attribute names starting with 'foo_' to start with 'bar_' instead.
 *
 * @param vectorCol name of the vector column to rewrite.
 * @param prefixesToRewrite the map of strings to their replacement values. For each named
 *                          attribute of vectorCol, every (key -> value) pair is applied in the
 *                          map's iteration order, and EVERY occurrence of the key in the name is
 *                          replaced by the value (not only a leading one). Names of interaction
 *                          attributes rely on this, because they join several input names.
 */
private class VectorAttributeRewriter(
    override val uid: String,
    val vectorCol: String,
    val prefixesToRewrite: Map[String, String])
  extends Transformer with MLWritable {

  def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
    this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)

  /**
   * Returns `dataset` with `vectorCol` moved to the last position and its attribute names
   * rewritten. Unnamed attributes are kept as they are. If the column carries no per-attribute
   * metadata, there are no names to rewrite and its metadata is kept unchanged.
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val field = dataset.schema(vectorCol)
    val group = AttributeGroup.fromStructField(field)
    val metadata = group.attributes match {
      case Some(attrs) =>
        val renamed = attrs.map { attr =>
          attr.name match {
            case Some(original) =>
              val name = prefixesToRewrite.foldLeft(original) { case (curName, (from, to)) =>
                curName.replace(from, to)
              }
              attr.withName(name)
            case None =>
              attr
          }
        }
        new AttributeGroup(vectorCol, renamed).toMetadata
      case None =>
        field.metadata
    }
    val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
    val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
    dataset.select(otherCols :+ rewrittenCol: _*)
  }

  /** Same column order as [[transform]]: every other field first, then `vectorCol`. */
  override def transformSchema(schema: StructType): StructType = {
    val (target, others) = schema.fields.partition(_.name == vectorCol)
    StructType(others ++ target)
  }

  override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)

  override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
}
// Companion object of VectorAttributeRewriter: supplies the MLReader used by `load` and hosts
// the writer/reader implementations.
private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
override def load(path: String): VectorAttributeRewriter = super.load(path)
/** [[MLWriter]] instance for [[VectorAttributeRewriter]] */
// NOTE(review): unlike the RFormulaModel and ColumnPruner writers in this file, this writer
// is not marked private[...]; consider private[VectorAttributeRewriter] for consistency.
class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter {
// Row schema for the persisted model data.
private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String])
// Intended layout under `path`: param metadata plus (vectorCol, prefixesToRewrite) under "data".
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: vectorCol, prefixesToRewrite
val data = Data(instance.vectorCol, instance.prefixesToRewrite)
val dataPath = new Path(path, "data").toString
// NOTE(review): nothing visible writes `data` to dataPath; the write statement and the
// closing braces of saveImpl and this class appear to be missing.
private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] {
/** Checked against metadata when loading model */
private val className = classOf[VectorAttributeRewriter].getName
// Reverses VectorAttributeRewriterWriter: reads metadata and the (vectorCol,
// prefixesToRewrite) row from "data" and rebuilds the rewriter with the saved uid.
override def load(path: String): VectorAttributeRewriter = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
// NOTE(review): the source of this select (presumably a read of dataPath) is missing.
val data ="vectorCol", "prefixesToRewrite").head()
val vectorCol = data.getString(0)
val prefixesToRewrite = data.getAs[Map[String, String]](1)
val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
// NOTE(review): the method ends without applying params from `metadata` or returning
// `rewriter`; the tail of load and the closing braces appear to be missing.
