All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.ml.feature.RFormulaParser.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml.feature

import org.apache.spark.linalg.VectorUDT

import scala.collection.mutable
import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.sql.types._

/**
  * Represents a parsed R formula.
  */
private[sona] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
  /**
    * Resolves formula terms into column names. A schema is necessary for inferring the meaning
    * of the special '.' term. Duplicate terms will be removed during resolution.
    */
  def resolve(schema: StructType): ResolvedRFormula = {
    val dotTerms = expandDot(schema)
    var includedTerms = Seq[Seq[String]]()
    terms.foreach {
      case col: ColumnRef =>
        includedTerms :+= Seq(col.value)
      case ColumnInteraction(cols) =>
        includedTerms ++= expandInteraction(schema, cols)
      case Dot =>
        includedTerms ++= dotTerms.map(Seq(_))
      case Deletion(term: Term) =>
        term match {
          case inner: ColumnRef =>
            includedTerms = includedTerms.filter(_ != Seq(inner.value))
          case ColumnInteraction(cols) =>
            val fromInteraction = expandInteraction(schema, cols).map(_.toSet)
            includedTerms = includedTerms.filter(t => !fromInteraction.contains(t.toSet))
          case Dot =>
            // e.g. "- .", which removes all first-order terms
            includedTerms = includedTerms.filter {
              case Seq(t) => !dotTerms.contains(t)
              case _ => true
            }
          case _: Deletion =>
            throw new RuntimeException("Deletion terms cannot be nested")
          case _: Intercept =>
        }
      case _: Intercept =>
    }
    ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept)
  }

  /** Whether this formula specifies fitting with response variable. */
  def hasLabel: Boolean = label.value.nonEmpty

  /** Whether this formula specifies fitting with an intercept term. */
  def hasIntercept: Boolean = {
    var intercept = true
    terms.foreach {
      case Intercept(enabled) =>
        intercept = enabled
      case Deletion(Intercept(enabled)) =>
        intercept = !enabled
      case _ =>
    }
    intercept
  }

  // expands the Dot operators in interaction terms
  private def expandInteraction(schema: StructType, terms: Seq[InteractableTerm]): Seq[Seq[String]] = {
    if (terms.isEmpty) {
      return Seq(Nil)
    }

    val rest = expandInteraction(schema, terms.tail)
    val validInteractions = (terms.head match {
      case Dot =>
        expandDot(schema).flatMap { t =>
          rest.map { r =>
            Seq(t) ++ r
          }
        }
      case ColumnRef(value) =>
        rest.map(Seq(value) ++ _)
    }).map(_.distinct)

    // Deduplicates feature interactions, for example, a:b is the same as b:a.
    val seen = mutable.Set[Set[String]]()
    validInteractions.flatMap {
      case t if seen.contains(t.toSet) =>
        None
      case t =>
        seen += t.toSet
        Some(t)
    }.sortBy(_.length)
  }

  // the dot operator excludes complex column types
  private def expandDot(schema: StructType): Seq[String] = {
    schema.fields.filter(_.dataType match {
      case _: NumericType | StringType | BooleanType | _: VectorUDT => true
      case _ => false
    }).map(_.name).filter(_ != label.value)
  }
}

/**
  * Represents a fully evaluated and simplified R formula.
  *
  * @param label        the column name of the R formula label (response variable).
  * @param terms        the simplified terms of the R formula. Interactions terms are represented as Seqs
  *                     of column names; non-interaction terms as length 1 Seqs.
  * @param hasIntercept whether the formula specifies fitting with an intercept.
  */
private[sona] case class ResolvedRFormula(label: String, terms: Seq[Seq[String]], hasIntercept: Boolean) {
  override def toString: String = {
    val ts = terms.map {
      case t if t.length > 1 =>
        s"${t.mkString("{", ",", "}")}"
      case t =>
        t.mkString
    }
    val termStr = ts.mkString("[", ",", "]")
    s"ResolvedRFormula(label=$label, terms=$termStr, hasIntercept=$hasIntercept)"
  }
}

/**
  * R formula terms. See the R formula docs here for more information:
  * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
  */
private[sona] sealed trait Term

/** A term that may be part of an interaction, e.g. 'x' in 'x:y' */
private[sona] sealed trait InteractableTerm extends Term

/* R formula reference to all available columns, e.g. "." in a formula */
private[sona] case object Dot extends InteractableTerm

/* R formula reference to a column, e.g. "+ Species" in a formula */
private[sona] case class ColumnRef(value: String) extends InteractableTerm

/* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */
private[sona] case class ColumnInteraction(terms: Seq[InteractableTerm]) extends Term

/* R formula intercept toggle, e.g. "+ 0" in a formula */
private[sona] case class Intercept(enabled: Boolean) extends Term

/* R formula deletion of a variable, e.g. "- Species" in a formula */
private[sona] case class Deletion(term: Term) extends Term

/**
  * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.', ':'.
  */
private[sona] object RFormulaParser extends RegexParsers {
  private val intercept: Parser[Intercept] =
    "([01])".r ^^ { case a => Intercept(a == "1") }

  private val columnRef: Parser[ColumnRef] =
    "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }

  private val empty: Parser[ColumnRef] = "" ^^ { case a => ColumnRef("") }

  private val label: Parser[ColumnRef] = columnRef | empty

  private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot }

  private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef | dot, ":")

  private val term: Parser[Term] = intercept |
    interaction ^^ { case terms => ColumnInteraction(terms) } | dot | columnRef

  private val terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
    case op ~ list => list.foldLeft(List(op)) {
      case (left, "+" ~ right) => left ++ Seq(right)
      case (left, "-" ~ right) => left ++ Seq(Deletion(right))
    }
  }

  private val formula: Parser[ParsedRFormula] =
    (label ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }

  def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
    case Success(result, _) => result
    case failure: NoSuccess => throw new IllegalArgumentException(
      "Could not parse formula: " + value)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy