tech.sourced.gitbase.spark.rule.PushdownTree.scala

package tech.sourced.gitbase.spark.rule

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{
  Alias, And, AttributeReference, Expression, Literal, NamedExpression
}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{StructField, StructType}
import tech.sourced.gitbase.spark._

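/**
  * Rule that pushes projections, filters, local sorts and limits from the
  * Spark logical plan down into the gitbase query executed by
  * [[DefaultReader]], so that gitbase servers do the work instead of Spark.
  */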
object PushdownTree extends Rule[LogicalPlan] {

  /** @inheritdoc */
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
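    // Merge nested projections whose combined output contains duplicate
    // names, then repair the attribute references of the merged projection.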
    case logical.Project(exp1, logical.Project(exp2, child))
      if containsDuplicates(exp1 ++ exp2) =>
      fixAttributeReferences(logical.Project(exp1 ++ exp2, child))
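    // A projection with no expressions is a no-op; drop it.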
    case logical.Project(Seq(), child) => child
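    // Push supported projection expressions into the reader's query;
    // expressions that cannot be compiled stay in a Spark-side projection.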
    case n@logical.Project(
        projectExpressions,
        r@DataSourceV2Relation(_, DefaultReader(servers, _, query))) =>
      if (containsGroupBy(query)) {
        r
      } else if (containsDuplicates(projectExpressions)) {
        fixAttributeReferences(n)
      } else if (!query.hasProjection) {
        // Split expressions in supported and unsupported to push down
        // only the supported ones and keep upside the unsupported
        val (supportedProject, unsupportedProject) =
          splitSupportedAndUnsupportedExpressions(projectExpressions)

        val newSchema = StructType(
          supportedProject.map(e => StructField(e.name, e.dataType, e.nullable, e.metadata))
        )
        val newOutput = supportedProject.map(e =>
          AttributeReference(e.name, e.dataType, e.nullable, e.metadata)(e.exprId, e.qualifier)
        )

        logical.Project(unsupportedProject,
          DataSourceV2Relation(
            newOutput,
            DefaultReader(
              servers,
              newSchema,
              Project(supportedProject, query)
            )
          )
        )
      } else {
        n
      }

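    // Push a filter down when all of its predicates can be compiled.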
    case n@logical.Filter(
        expr,
        DataSourceV2Relation(out, DefaultReader(servers, schema, query))) =>
      if (!canBeHandled(Seq(expr))) {
        fixAttributeReferences(n)
      } else {
        DataSourceV2Relation(
          out,
          DefaultReader(
            servers,
            schema,
            Filter(splitExpressions(expr), query)
          )
        )
      }

    // We should only push down local sorts.
    case n@logical.Sort(
        order,
        false,
        DataSourceV2Relation(out, DefaultReader(servers, schema, query))) =>
      if (!canBeHandled(order.map(_.child))) {
        fixAttributeReferences(n)
      } else {
        DataSourceV2Relation(
          out,
          DefaultReader(
            servers,
            schema,
            Sort(order, query)
          )
        )
      }

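    // Push a literal limit down into the reader's query.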
    case logical.LocalLimit(
        limit: Literal,
        DataSourceV2Relation(out, DefaultReader(servers, schema, query))) =>
      val limitNumber = limit.value match {
        case n: Int => n.toLong
        case n: Long => n
        case _ => throw new SparkException("limit literal should be a number")
      }

      DataSourceV2Relation(
        out,
        DefaultReader(
          servers,
          schema,
          Limit(limitNumber, query)
        )
      )

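    // Relations are left untouched.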
    case node: DataSourceV2Relation => node

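    // Repair attribute references on every other node.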
    case node => fixAttributeReferences(node)
  }

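  /** Wraps the given expression in an [[Alias]] unless it is already named. */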
  private def toNamedExpression(expression: Expression): NamedExpression = {
    expression match {
      case e: NamedExpression => e
      case e => Alias(e, e.sql)()
    }
  }

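  /** Recursively splits a conjunction into its individual predicates. */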
  private def splitExpressions(expression: Expression): Seq[Expression] = {
    expression match {
      case And(left, right) =>
        splitExpressions(left) ++ splitExpressions(right)
      case e => Seq(e)
    }
  }

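  /** Returns true if every expression can be compiled by [[QueryBuilder]], i.e. pushed down. */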
  private def canBeHandled(exprs: Seq[Expression]): Boolean = {
    exprs.flatMap(x => QueryBuilder.compileExpression(x)).length == exprs.length
  }

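  /** Returns true if two or more of the given expressions share an output name. */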
  private def containsDuplicates(exprs: Seq[NamedExpression]): Boolean = {
    exprs.groupBy(_.name).values.map(_.length).exists(_ > 1)
  }

  /**
    * Split expressions in supported and unsupported to push down
    * only the supported ones and keep upside the unsupported
    */
  private def splitSupportedAndUnsupportedExpressions(expressions: Seq[Expression])
    : (Seq[NamedExpression], Seq[NamedExpression]) = {
    import scala.collection.mutable
    val unsupported = mutable.Buffer[NamedExpression]()
    val supported = mutable.Buffer[NamedExpression]()

    expressions.foreach(exp => {
      if (canBeHandled(exp :: Nil)) {
        val e = toNamedExpression(exp)
        supported += e
        unsupported += AttributeReference(
          e.name,
          e.dataType,
          e.nullable,
          e.metadata
        )(e.exprId, e.qualifier)
      } else if (canBeHandled(exp.children)) {
        unsupported += toNamedExpression(exp.withNewChildren(exp.children.map(c => {
          val child = toNamedExpression(c)
          supported += child
          AttributeReference(
            child.name,
            child.dataType,
            child.nullable,
            child.metadata
          )(child.exprId, child.qualifier)
        })))
      } else if (exp.children.nonEmpty) {
        unsupported += toNamedExpression(exp.withNewChildren(exp.children.map(c => {
          val (s, u) = splitSupportedAndUnsupportedExpressions(c :: Nil)
          if (u.length != 1) {
            throw new SparkException("invalid number of unsupported expressions obtained")
          }

          supported ++= s
          u.head match {
            case e: Alias => e.child // remove aliases from non topmost expressions
            case o => o
          }
        })))
      } else {
        unsupported += toNamedExpression(exp)
      }
    })

    if (unsupported.length != expressions.length) {
      throw new SparkException("unsupported exceptions don't match input expressions")
    }

    (supported.toList, unsupported.toList)
  }

}
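
For reference, the sketch below shows one way a Rule[LogicalPlan] such as
PushdownTree can be attached to a SparkSession through Spark's experimental
optimizer hook. It is illustrative only: the object and application names are
hypothetical, it assumes gitbase-spark is on the classpath, and the library
may already register the rule itself when the data source is used.

// Illustrative only, not part of the library.
import org.apache.spark.sql.SparkSession
import tech.sourced.gitbase.spark.rule.PushdownTree

object PushdownTreeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("pushdown-tree-example")
      .master("local[*]")
      .getOrCreate()

    // Rules added here run as an extra optimizer batch on every query,
    // which is when PushdownTree rewrites the logical plan.
    spark.experimental.extraOptimizations = Seq(PushdownTree)

    spark.stop()
  }
}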