
package tech.sourced.gitbase.spark.rule

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.{Alias,
  AttributeReference, Divide, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import tech.sourced.gitbase.spark._
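
/**
  * Optimizer rule that pushes aggregations down to the gitbase servers when
  * every grouping and aggregate expression can be compiled to gitbase SQL and
  * the underlying query does not already contain a GROUP BY.
  *
  * COUNT, MIN, MAX and SUM are pushed down directly; AVG(x) is decomposed
  * into a pushed-down COUNT(x) and SUM(x) that are recombined with a division
  * on the Spark side.
  *
  * Illustrative example (table and column names are placeholders, not taken
  * from this file): for
  *
  * {{{
  *   SELECT repository_id, AVG(size) FROM files GROUP BY repository_id
  * }}}
  *
  * gitbase would compute COUNT(size) and SUM(size) per repository_id, and the
  * remaining Spark-side aggregate would divide the pushed-down SUM by the
  * pushed-down COUNT.
  */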
object PushdownAggregations extends Rule[LogicalPlan] {

  /** @inheritdoc */
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    // Ignore aggregates with no aggregate expressions.
    case n@logical.Aggregate(_, Nil, _) => n

    case n@logical.Aggregate(grouping, aggregate,
    DataSourceV2Relation(_, DefaultReader(servers, _, query))) =>
      // Bail out (non-local return from apply) if any expression cannot be
      // compiled to gitbase SQL or the query already contains a GROUP BY.
      if (!canBeHandled(grouping) || !canBeHandled(aggregate) || containsGroupBy(query)) {
        return fixAttributeReferences(n)
      }
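
      // Build the list of expressions to push down: COUNT, MIN, MAX and SUM go
      // down as-is, AVG is decomposed into a COUNT and a SUM, and plain column
      // references are kept. Everything is named, deduplicated by name and
      // sorted so the pushed-down output is deterministic.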
      val transformedAggregate = aggregate.flatMap(_.flatMap {
        case e@(_: Count | _: Min | _: Max | _: Sum) => Seq(e)
        case Average(child) => Seq(Count(child), Sum(child))
        case r: AttributeReference => Seq(r)
        case _ => Seq()
      }).map {
        case n: NamedExpression => n
        case e => Alias(e, e.toString())()
      }.groupBy(_.name).values.map(_.head).toSeq.sortBy(_.name)
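
      // Grouping attributes that are not already part of the pushed-down
      // aggregations must be added to the pushed-down projection as well,
      // since the upstream aggregate still groups by them.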
      val missingAttrs = grouping.flatMap(_.flatMap {
        case a: AttributeReference =>
          if (transformedAggregate.exists(n => n.name == a.name)) {
            None
          } else {
            Some(a)
          }
        case _ => None
      }).groupBy(_.name).values.map(_.head).toSeq.sortBy(_.name)

      val pushedDownAggregate = transformedAggregate ++ missingAttrs
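
      // Output attributes of the new relation, one per pushed-down expression,
      // keeping names, types and expression ids.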
      val newOut = pushedDownAggregate.map { e =>
        AttributeReference(e.name, e.dataType, e.nullable, e.metadata)(e.exprId, e.qualifier)
      }
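
      // Rewrite the upstream aggregate to consume the pushed-down results:
      // COUNT(x) becomes a SUM over the pushed-down counts, AVG(x) becomes the
      // pushed-down SUM(x) divided by the pushed-down COUNT(x), and MIN, MAX
      // and SUM are re-applied on top of their pushed-down counterparts.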
      val newAggregate = aggregate.map(e => e.transformUp {
        case Count(child) =>
          val cnt = newOut.find(_.name == Count(child.head).toString)
            .getOrElse(throw new SparkException(
              "This is likely a bug. Could not find matching COUNT" +
                s" to be pushed down for COUNT(${child.head})"))
          Sum(cnt)
        case Average(child) =>
          val sum = newOut.find(_.name == Sum(child).toString)
            .getOrElse(throw new SparkException(
              "This is likely a bug. Could not find matching SUM" +
                s" to be pushed down for AVG($child)"))
          val cnt = newOut.find(_.name == Count(child).toString)
            .getOrElse(throw new SparkException(
              "This is likely a bug. Could not find matching COUNT" +
                s" to be pushed down for AVG($child)"))
          Divide(sum, cnt)
        case n@Min(child) =>
          val min = newOut.find(_.name == n.toString)
            .getOrElse(throw new SparkException(
              "This is likely a bug. Could not find matching MIN" +
                s" to be pushed down for MIN($child)"))
          Min(min)
        case n@Max(child) =>
          val max = newOut.find(_.name == n.toString)
            .getOrElse(throw new SparkException(
              "This is likely a bug. Could not find matching MAX" +
                s" to be pushed down for MAX($child)"))
          Max(max)
        case n@Sum(child) =>
          val sum = newOut.find(_.name == n.toString)
            .getOrElse(throw new SparkException(
              "This is likely a bug. Could not find matching SUM" +
                s" to be pushed down for SUM($child)"))
          Sum(sum)
        case expr => expr
      }.asInstanceOf[NamedExpression])
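
      // Rebuild the plan: the relation now evaluates the GROUP BY on the
      // gitbase servers, and the Spark-side aggregate combines its results.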
      logical.Aggregate(
        grouping,
        newAggregate,
        DataSourceV2Relation(
          newOut,
          DefaultReader(
            servers,
            attributesToSchema(newOut),
            GroupBy(pushedDownAggregate, grouping, query)
          )
        )
      )

    case node: DataSourceV2Relation => node

    case node => fixAttributeReferences(node)
  }
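
  /**
    * An expression list can be handled only if every single expression in it
    * compiles to a gitbase SQL expression via [[QueryBuilder.compileExpression]].
    */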
  private def canBeHandled(exprs: Seq[Expression]): Boolean = {
    exprs.flatMap(x => QueryBuilder.compileExpression(x)).length == exprs.length
  }
}
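
// A minimal usage sketch (assumption: this registration snippet is not part of
// the original file). The rule is a plain `Rule[LogicalPlan]`, so it can be
// registered on an existing SparkSession as an extra optimizer rule:
//
//   import org.apache.spark.sql.SparkSession
//
//   val spark = SparkSession.builder().getOrCreate()
//   spark.experimental.extraOptimizations ++= Seq(PushdownAggregations)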