All Downloads are FREE. Search and download functionalities are using the official Maven repository.

tech.sourced.gitbase.spark.rule.package.scala Maven / Gradle / Ivy

package tech.sourced.gitbase.spark

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{StructField, StructType}

/**
  * Rules over Catalyst logical plans (`Rule[LogicalPlan]`) together with the
  * helpers those rules share.
  */
package object rule {

  /**
    * Lists every logical-plan rule exposed by this package.
    *
    * @return `AddSource`, `PushdownJoins`, `PushdownTree` and
    *         `PushdownAggregations`, in exactly that order
    */
  def getAll: Seq[Rule[LogicalPlan]] =
    List(AddSource, PushdownJoins, PushdownTree, PushdownAggregations)

  /**
    * Builds a schema containing one field per attribute, keeping the input order.
    *
    * Every field copies the attribute's name, data type, nullability and metadata.
    *
    * @param attributes attributes to turn into schema fields
    * @return a [[StructType]] whose fields mirror `attributes`
    */
  private[rule] def attributesToSchema(attributes: Seq[AttributeReference]): StructType = {
    val fields: Seq[StructField] = attributes.map { attr =>
      StructField(attr.name, attr.dataType, attr.nullable, attr.metadata)
    }
    StructType(fields.toArray)
  }

  /**
    * Tells whether searching `node` with `transformSingleDown` yields a `GroupBy`.
    *
    * @param node root of the tree to inspect
    * @return true when a `GroupBy` node was found
    */
  private[rule] def containsGroupBy(node: Node): Boolean = {
    val groupBy = node.transformSingleDown {
      case found: GroupBy => Some(found)
      case _ => None
    }
    groupBy.isDefined
  }

  /**
    * Rebinds the attribute references in `plan`'s expressions to attributes
    * exposed by the plan's children.
    *
    * Each child is searched with `find`, which visits parents before their
    * children. The first node that matches decides what the child offers:
    *  - a `Project` offers its projected attributes;
    *  - a `DataSourceV2Relation` backed by a `DefaultReader` offers its output;
    *  - a `Join` offers nothing.
    *
    * Every `Attribute` in `plan` is then replaced, bottom-up, by an offered
    * attribute with the same name. An attribute whose `getSource` value matches
    * is preferred; otherwise the first same-named one is used. Attributes that
    * have no same-named counterpart are left unchanged.
    *
    * @param plan plan whose expressions should be rebound
    * @return `plan` with its attribute references rebound
    */
  private[rule] def fixAttributeReferences(plan: LogicalPlan): LogicalPlan = {
    import JoinOptimizer._

    val availableAttrs: Seq[Attribute] = plan.children.flatMap { child =>
      val boundary = child.find {
        case _: logical.Project | _: logical.Join => true
        case DataSourceV2Relation(_, _: DefaultReader) => true
        case _ => false
      }
      // A Join reached before any projection/relation contributes no attributes.
      boundary.fold(Seq.empty[Attribute]) {
        case logical.Project(projectList, _) => projectList.map(_.toAttribute)
        case DataSourceV2Relation(output, _: DefaultReader) => output
        case _ => Seq.empty[Attribute]
      }
    }

    // Expressions not matched here are kept as-is by transformExpressionsUp.
    plan.transformExpressionsUp {
      case attr: Attribute =>
        val sameName = availableAttrs.filter(_.name == attr.name)
        sameName
          .find(candidate => getSource(attr) == getSource(candidate))
          .orElse(sameName.headOption)
          .getOrElse(attr)
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy