
tech.sourced.gitbase.spark.rule.PushdownJoins.scala

package tech.sourced.gitbase.spark.rule

import java.util.NoSuchElementException

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.Metadata
import tech.sourced.gitbase.spark._

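/**
  * Rule that pushes joins between gitbase relations down into the data source,
  * so the join is executed by gitbase itself instead of Spark. It also
  * deduplicates columns with the same name in the resulting relation.
  */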
object PushdownJoins extends Rule[LogicalPlan] {
  /** @inheritdoc */
  def apply(plan: LogicalPlan): LogicalPlan = {
    val schema = plan.schema
    val result = plan transformUp {
      // Joins are only applicable per repository, so the join can be pushed
      // down completely into the data source.
      case q: logical.Join =>
        val jd = JoinOptimizer.getJoinData(q)
        if (!jd.valid) {
          q
        } else {
          jd match {
            case JoinData(Some(source), _, filters, projectExprs, attributes, servers, _) =>
              val node = DataSourceV2Relation(
                attributes,
                DefaultReader(
                  servers,
                  attributesToSchema(attributes),
                  source
                )
              )

              val filteredNode = filters match {
                case Some(filter) => logical.Filter(filter, node)
                case None => node
              }

              // If the projection is empty, project the original schema.
              if (projectExprs.nonEmpty) {
                logical.Project(projectExprs, filteredNode)
              } else {
                logical.Project(
                  attributes,
                  filteredNode
                )
              }
            case _ => q
          }
        }

      // Remove two consecutive Projects and replace them with the outermost one.
      case logical.Project(list, logical.Project(_, child)) =>
        logical.Project(list, child)
    } transformUp {
      // Deduplicate columns with the same name. Joined gitbase tables will
      // always have the same value in columns with the same name, so it's
      // safe to deduplicate.
      case DataSourceV2Relation(out, DefaultReader(servers, _, source)) =>
        val names = out.map(_.name).distinct.toBuffer
        val newOut = out.flatMap(x => {
          val idx = names.indexOf(x.name)
          if (idx >= 0) {
            names.remove(idx)
            Some(x)
          } else {
            None
          }
        })

        DataSourceV2Relation(
          newOut,
          DefaultReader(
            servers,
            attributesToSchema(newOut),
            source
          )
        )

      // Since we deduplicated, some attributes may no longer point to the
      // deduplicated column that remains. Replace those attributes with the
      // one that is available, preferring an exact match when possible.
      case n => fixAttributeReferences(n)
    }

    // After the deduplication, a SELECT * needs a new Project so the plan
    // keeps the same schema it had before.
    if (result.schema.length != schema.length) {
      fixAttributeReferences(logical.Project(
        schema.fields.map(col =>
          AttributeReference(col.name, col.dataType, col.nullable, col.metadata)()),
        result
      ))
    } else {
      result
    }
  }

}

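/**
  * Data gathered from a join that is a candidate for pushdown.
  *
  * @param source     query node to push down to the data source, if any
  * @param conditions join conditions
  * @param filter     filter expressions found in the plan
  * @param project    projected expressions
  * @param attributes output attributes of the relations involved
  * @param servers    gitbase servers involved in the relations
  * @param valid      whether the join can actually be pushed down
  */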
case class JoinData(source: Option[Node] = None,
                    conditions: Option[Expression] = None,
                    filter: Option[Expression] = None,
                    project: Seq[NamedExpression] = Nil,
                    attributes: Seq[AttributeReference] = Nil,
                    servers: Seq[GitbaseServer] = Nil,
                    valid: Boolean = false)

/**
  * Support methods for optimizing [[DefaultReader]]s.
  */
private[rule] object JoinOptimizer extends Logging {

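  /** Reports whether the given attribute has a source table set in its metadata. */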
  private[rule] def hasSource(attr: Attribute): Boolean =
    getSource(attr) != ""


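  /** Returns the source table of the given expression, or an empty string if it has none. */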
  private[rule] def getSource(attr: NamedExpression): String =
    getSource(attr.metadata)

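  /**
    * Returns the source table stored under [[Sources.SourceKey]] in the given
    * metadata, or an empty string if the key is not present.
    */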
  private[rule] def getSource(metadata: Metadata): String =
    try {
      metadata.getString(Sources.SourceKey)
    } catch {
      case _: NoSuchElementException => ""
    }

  /**
    * Returns the data needed to perform pushdown optimizations on the given join.
    *
    * @param j join to get the data from
    * @return join data
    */
  private[rule] def getJoinData(j: logical.Join): JoinData = {
    // both the left and the right side must end in a gitbase relation
    val leftRel = getGitbaseRelation(j.left)
    val rightRel = getGitbaseRelation(j.right)

    // Not a join between gitbase relations, or not a supported join type; nothing to optimize
    if (leftRel.isEmpty || rightRel.isEmpty || !isJoinSupported(j)) {
      logUnableToOptimize("It doesn't have gitbase relations in both sides, " +
        "or the Join type is not supported.")
      return JoinData()
    }

    // Check the join conditions: they must only reference the gitbase relations
    val unsupportedConditions = JoinOptimizer.getUnsupportedConditions(
      j,
      leftRel.get,
      rightRel.get
    )

    if (unsupportedConditions.nonEmpty) {
      logUnableToOptimize(s"Obtained unsupported conditions: $unsupportedConditions")
      return JoinData()
    }

    j.condition match {
      case Some(cond) =>
        val tables = getRelationTables(leftRel.get, rightRel.get)
        if (!conditionsAllowPushdown(cond, tables)) {
          logUnableToOptimize("Join conditions are not restricted by repository_id")
          return JoinData()
        }
      case None =>
    }

    val left = getLeafJoinData(j.left)
    val right = getLeafJoinData(j.right)
    mergeJoinData(Seq(
      JoinData(valid = true, conditions = j.condition),
      left,
      right
    ))
  }

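  /**
    * Collects the filter, projection, attributes, servers and source node
    * found on one side of a join, marking the result as invalid if a nested
    * join or any other unexpected node is encountered.
    */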
  private def getLeafJoinData(p: LogicalPlan): JoinData = {
    var filter: Option[Expression] = None
    var valid = true
    var project: Seq[NamedExpression] = Seq()
    var source: Option[Node] = None
    var attributes: Seq[AttributeReference] = Seq()
    var servers: Seq[GitbaseServer] = Seq()

    p.map {
      case jm @ logical.Join(_, _, _, _) =>
        logUnableToOptimize(s"Invalid node: $jm")
        valid = false
      case logical.Filter(cond, _) =>
        filter = Some(cond)
      case logical.Project(namedExpressions, _) =>
        project = namedExpressions
      case DataSourceV2Relation(out, DefaultReader(srvs, _, src)) =>
        source = Some(src)
        if (project.isEmpty) {
          project = out
        }
        attributes = out
        servers = srvs
      case other =>
        logUnableToOptimize(s"Invalid node: $other")
        valid = false
    }

    JoinData(source, None, filter, project, attributes, servers, valid)
  }

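  /** Returns the distinct table names referenced by the readers of both relations. */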
  private def getRelationTables(left: DataSourceV2Relation,
                                right: DataSourceV2Relation): Seq[String] = {
    val leftSource = left.reader.asInstanceOf[DefaultReader].node
    val rightSource = right.reader.asInstanceOf[DefaultReader].node
    (getSourceTables(leftSource) ++ getSourceTables(rightSource)).distinct
  }

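  /** Returns the table names referenced by the given query node, recursing into joins. */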
  private def getSourceTables(s: Node): Seq[String] = s match {
    case Table(t) => Seq(t)
    case Join(left, right, _) => (getSourceTables(left) ++ getSourceTables(right)).distinct
    case _ => Seq()
  }

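  /**
    * Reports whether the given condition restricts the join by repository_id,
    * that is, whether it contains an equality between the repository_id
    * columns of two different tables involved in the join.
    */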
  private def conditionsAllowPushdown(expression: Expression,
                                      tables: Seq[String]): Boolean = {
    expression.find {
      case EqualTo(left: Attribute, right: Attribute) =>
        val leftSource = left.metadata.getString(Sources.SourceKey)
        val rightSource = right.metadata.getString(Sources.SourceKey)
        left.name == "repository_id" && right.name == "repository_id" &&
          tables.contains(leftSource) && tables.contains(rightSource) &&
          leftSource != rightSource
      case And(left, right) =>
        conditionsAllowPushdown(left, tables) || conditionsAllowPushdown(right, tables)
      case _ => false
    }.isDefined
  }

  /**
    * Reduces all the given join data into a single JoinData.
    *
    * @param data sequence of join data to be merged
    * @return merged join data
    */
  private def mergeJoinData(data: Seq[JoinData]): JoinData = {
    val d = data.reduce((jd1, jd2) => {
      // get all filter expressions
      val filters: Option[Expression] = mixExpressions(
        jd1.filter,
        jd2.filter,
        And
      )
      // get all join conditions
      val conditions = mixExpressions(
        jd1.conditions,
        jd2.conditions,
        And
      )

      val source = (jd1.source, jd2.source) match {
        case (Some(s1), Some(s2)) => Some(Join(s1, s2, conditions))
        case (Some(s1), None) => Some(s1)
        case (None, Some(s1)) => Some(s1)
        case _ => None
      }

      JoinData(
        source,
        conditions,
        filters,
        jd1.project ++ jd2.project,
        jd1.attributes ++ jd2.attributes,
        (jd1.servers ++ jd2.servers).distinct,
        jd1.valid && jd2.valid
      )
    })

    if (d.source.isEmpty) {
      JoinData()
    } else {
      d
    }
  }

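  /** Join types that can be pushed down to the data source. */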
  private val supportedJoinTypes: Seq[JoinType] = Inner :: Nil

  /**
    * Reports whether the given join is supported.
    *
    * @param j join
    * @return whether the join type is supported
    */
  def isJoinSupported(j: logical.Join): Boolean = supportedJoinTypes.contains(j.joinType)

  /**
    * Retrieves all the unsupported conditions in the join.
    *
    * @param join  Join
    * @param left  left relation
    * @param right right relation
    * @return unsupported conditions
    */
  def getUnsupportedConditions(join: logical.Join,
                               left: DataSourceV2Relation,
                               right: DataSourceV2Relation): Set[Attribute] = {
    val leftReferences = left.references.baseSet
    val rightReferences = right.references.baseSet
    val joinReferences = join.references.baseSet
    (joinReferences -- leftReferences -- rightReferences).map(_.a)
  }

  /**
    * Combines the two given expressions with the given join function if both
    * exist, or returns whichever one exists otherwise.
    *
    * @param l            left expression
    * @param r            right expression
    * @param joinFunction function used to join them
    * @return an optional expression
    */
  def mixExpressions(l: Option[Expression],
                     r: Option[Expression],
                     joinFunction: (Expression, Expression) => Expression):
  Option[Expression] = {
    (l, r) match {
      case (Some(expr1), Some(expr2)) => Some(joinFunction(expr1, expr2))
      case (None, None) => None
      case (le, None) => le
      case (None, re) => re
    }
  }

  /**
    * Returns the first gitbase relation found in the given logical plan, if any.
    *
    * @param lp logical plan
    * @return gitbase relation, or None if there is no such relation
    */
  def getGitbaseRelation(lp: LogicalPlan): Option[DataSourceV2Relation] =
    lp.find {
      case DataSourceV2Relation(_, _: DefaultReader) => true
      case _ => false
    } map (_.asInstanceOf[DataSourceV2Relation])

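  /**
    * Logs a banner explaining that the join could not be optimized, together
    * with the given reason wrapped to fit the banner width.
    */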
  private def logUnableToOptimize(msg: String = ""): Unit = {
    logError("*" * 80)
    logError("* This Join could not be optimized. This might severely impact the performance *")
    logError("* of your query. This happened because there is an unexpected node between the *")
    logError("* two relations of a Join, such as Limit or another kind of unknown relation.  *")
    logError("* Note that this will not stop your query or make it fail, only make it slow.  *")
    logError("*" * 80)
    if (msg.nonEmpty) {
      def split(str: String): Seq[String] = {
        if (str.length > 76) {
          Seq(str.substring(0, 76)) ++ split(str.substring(76))
        } else {
          Seq(str)
        }
      }

      logError(s"* Reason:${" " * 70}*")
      msg.lines.flatMap(split)
        .map(line => s"* $line${" " * (76 - line.length)} *")
        .foreach(logError(_))
      logError("*" * 80)
    }
  }

}
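
// Illustrative sketch only, not part of this file's API: a Rule[LogicalPlan]
// such as PushdownJoins can be registered as an extra Catalyst optimization,
// for example:
//
//   spark.experimental.extraOptimizations ++= Seq(PushdownJoins)
//
// gitbase-spark may wire this rule in elsewhere; the snippet above only shows
// how a custom rule plugs into the Catalyst optimizer.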
