All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.optimizer.Pushdown.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.optimizer

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LeafNode, LogicalPlan, UnaryNode, Union}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import streaming.dsl.ScriptSQLExec
import tech.mlsql.common.utils.log.Logging
import tech.mlsql.common.utils.serder.json.JSONTool

import scala.collection.mutable

object Pushdown extends Logging{
  /**
   * Entry point: tags the plan, collects the subtrees that can be pushed
   * down, and rewrites them.
   *
   * @param lp the logical plan to inspect
   * @return the rewritten plan, or `lp` itself when no subtree qualifies
   */
  def apply(lp: LogicalPlan): LogicalPlan = {
    val taggedRoot = tagTreeNode(lp)
    val candidates = findPushdownNodes(taggedRoot)
    if (candidates.isEmpty) lp
    else replacePushdownSubtree(taggedRoot, candidates)
  }

  /**
   * Recursively mirrors `lp` as a tree of [[TagTreeNode]]s, marking each node
   * with whether it can be pushed down and, if so, with which data source.
   *
   * Children are tagged first. A node is marked pushable only when:
   *  - it is a `LogicalRelation` whose source info is `Pushdownable` and
   *    accepts the relation via `canPushdown`, or
   *  - all of its children are pushable, their data sources are
   *    `Pushdownable` and `fastEquals` each other (for multi-child nodes), and
   *    the source reports `isSupport` for the node.
   * Every other node is tagged with a `NoPushdownSourceInfo` and `false`.
   *
   * @param lp the plan (sub)tree to tag
   * @return the tagged counterpart of `lp`
   */
  def tagTreeNode(lp: LogicalPlan):TagTreeNode = {
    val taggedChildren = lp.children.map(tagTreeNode)

    // Tag used for every node that is not marked pushable.
    def notPushable(node: LogicalPlan, kids: Seq[TagTreeNode]): TagTreeNode =
      TagTreeNode(node, new NoPushdownSourceInfo(Map()), kids, false)

    lp match {
      case lr: LogicalRelation =>
        val pds = PushdownSourceInfo.getPushdownSourceInfo(lr)
        pds match {
          case ds: Pushdownable if ds.canPushdown(lr) =>
            TagTreeNode(lr, pds, Nil, true)
          case _ =>
            notPushable(lr, Nil)
        }
      case lf: LeafNode =>
        notPushable(lf, Nil)
      case ur: UnaryNode if taggedChildren.head.canPushDown =>
        val childSource = taggedChildren.head.dataSource
        childSource match {
          case pushdown: Pushdownable if pushdown.isSupport(ur) =>
            TagTreeNode(ur, childSource, taggedChildren, true)
          case _ =>
            notPushable(ur, taggedChildren)
        }
      // Join is matched before the generic BinaryNode case so that isSupport
      // is invoked with the Join-typed value, exactly as before.
      case jn: Join if taggedChildren.head.canPushDown && taggedChildren.tail.head.canPushDown =>
        val ld = taggedChildren.head.dataSource
        val rd = taggedChildren.tail.head.dataSource
        (ld, rd) match {
          case (l: Pushdownable, r: Pushdownable) if l.fastEquals(r) && l.isSupport(jn) =>
            TagTreeNode(jn, ld, taggedChildren, true)
          case _ =>
            notPushable(jn, taggedChildren)
        }
      case bn: BinaryNode if taggedChildren.head.canPushDown && taggedChildren.tail.head.canPushDown =>
        val ld = taggedChildren.head.dataSource
        val rd = taggedChildren.tail.head.dataSource
        (ld, rd) match {
          case (l: Pushdownable, r: Pushdownable) if l.fastEquals(r) && l.isSupport(bn) =>
            TagTreeNode(bn, ld, taggedChildren, true)
          case _ =>
            notPushable(bn, taggedChildren)
        }
      case un: Union if isChildrenCanPushdown(taggedChildren) =>
        // Type-test the first child's source instead of casting it: the two
        // helper checks are vacuously true for a single-child Union, so a
        // non-pushdownable child could otherwise reach an unchecked cast.
        taggedChildren.headOption match {
          case Some(first) =>
            val firstSource = first.dataSource
            firstSource match {
              case pushdown: Pushdownable
                if isDataSourceEqual(taggedChildren) && pushdown.isSupport(un) =>
                TagTreeNode(un, firstSource, taggedChildren, true)
              case _ =>
                notPushable(un, taggedChildren)
            }
          case None =>
            notPushable(un, taggedChildren)
        }
      case _ =>
        notPushable(lp, taggedChildren)
    }
  }

  /**
   * Collects the topmost tagged nodes that are marked pushable and have at
   * least one child. The search does not descend below a collected node;
   * for every other node its children are examined in turn.
   *
   * @param tree root of the tagged tree
   * @return the set of subtree roots to replace (possibly empty)
   */
  def findPushdownNodes(tree: TagTreeNode): mutable.HashSet[TagTreeNode] = {
    val found = new mutable.HashSet[TagTreeNode]()

    // Depth-first worklist; the most recently added node is examined first.
    var pending: List[TagTreeNode] = List(tree)
    while (pending.nonEmpty) {
      val current = pending.head
      pending = pending.tail
      val isSubtreeRoot = current.canPushDown && current.children.nonEmpty
      if (isSubtreeRoot) {
        found += current
      } else {
        current.children.foreach { child => pending = child :: pending }
      }
    }
    found
  }

  /**
   * Replaces every plan node listed in `replaceNodes` with the plan produced
   * by its data source's `buildScan2`.
   *
   * The output attributes of the replacement are paired positionally with the
   * output attributes of the original subtree, and each matching
   * `AttributeReference` in the replacement is rewritten to carry the original
   * `ExprId`, so expressions above the subtree keep resolving.
   *
   * If rewriting a subtree throws an `Exception`, the error is logged and the
   * original subtree is kept. Timing and before/after plans are logged only
   * when [[isEnableLog]] returns true.
   *
   * @param tree         tagged tree whose `lp` is the plan to transform
   * @param replaceNodes tagged nodes whose plans should be replaced
   * @return the transformed logical plan
   */
  def replacePushdownSubtree(tree:TagTreeNode, replaceNodes:mutable.HashSet[TagTreeNode]): LogicalPlan ={
    val rootStartTime = System.currentTimeMillis()
    val enableLog = isEnableLog()

    // Built once up front: the partial function below runs for every node
    // visited by transformDown, so the lookup should not rebuild a set and
    // then scan replaceNodes again on each visit.
    val nodeByPlan = replaceNodes.map(node => node.lp -> node).toMap

    val newlp = tree.lp.transformDown{
      case lp: LogicalPlan if nodeByPlan.contains(lp) =>
        val startTime = System.currentTimeMillis()
        try {
          val treeNode = nodeByPlan(lp)
          val oldAttrs = lp.output
          // A failing cast is handled by the catch below and leaves the subtree as-is.
          val newsub = treeNode.dataSource.asInstanceOf[Pushdownable].buildScan2(lp)
          val newAttrs = newsub.output
          // Positional mapping: i-th new output attribute -> i-th old output attribute.
          val newIdToOldId: Map[ExprId, ExprId] =
            newAttrs.zip(oldAttrs).map { case (newAttr, oldAttr) => newAttr.exprId -> oldAttr.exprId }.toMap
          val newnewsub = newsub.transform {
            case x: LogicalPlan =>
              x.transformExpressions {
                case a: AttributeReference if newIdToOldId.contains(a.exprId) =>
                  a.copy()(exprId = newIdToOldId(a.exprId), a.qualifier)
              }
          }
          if(enableLog) {
            logInfo("----Old sub Logicplan:" + lp.toString())
            logInfo("----New sub Logicplan:" + newnewsub.toString())
          }
          newnewsub
        } catch {
          case ex:Exception =>
            logError("----Sub Logicplan Trans Error:"+ ex.getMessage)
            lp
        }finally{
          val endTime = System.currentTimeMillis()
          if(enableLog) {
            logInfo("----Sub Logicplan Trans Time Cost:" + (endTime - startTime) + "ms")
          }
        }
    }

    if(enableLog) {
      val rootEndTime = System.currentTimeMillis()
      logInfo("--Old Logicplan:" + tree.lp.toString())
      // Log the transformed plan here (previously the old plan was printed twice).
      logInfo("--New Logicplan:" + newlp.toString())
      logInfo("--Logicplan Trans Time Cost:" + (rootEndTime - rootStartTime) + "ms")
    }
    newlp
  }

  /**
   * Tells whether every child in `children` is marked pushable.
   *
   * Returns false for an empty sequence. The first child is checked
   * explicitly, so a single non-pushable child also yields false.
   *
   * @param children tagged children of a multi-child node
   * @return true only when `children` is non-empty and all are pushable
   */
  private def isChildrenCanPushdown(children: Seq[TagTreeNode]):Boolean ={
    children.nonEmpty && children.forall(_.canPushDown)
  }

  /**
   * Tells whether all children share the same pushdown-capable data source.
   *
   * The first child's data source must be `Pushdownable`; every other child's
   * data source must be `Pushdownable` too and `fastEquals` the first one.
   * Returns false for an empty sequence or when the first child's source is
   * not `Pushdownable`.
   *
   * @param children tagged children of a multi-child node
   * @return true when the children agree on one `Pushdownable` source
   */
  private def isDataSourceEqual(children: Seq[TagTreeNode]): Boolean ={
    children.headOption.map(_.dataSource) match {
      case Some(first: Pushdownable) =>
        children.tail.forall { node =>
          node.dataSource match {
            case other: Pushdownable => other.fastEquals(first)
            case _ => false
          }
        }
      case _ => false
    }
  }

  /**
   * Reads the `__PARAMS__` entry from the current script context's
   * user-defined parameters (defaulting to an empty JSON object) and parses
   * it as a string-to-string map.
   */
  def getHttpParams()={
    val rawJson = ScriptSQLExec.context().userDefinedParam.getOrElse("__PARAMS__", "{}")
    JSONTool.parseJson[Map[String, String]](rawJson)
  }

  /**
   * Returns the environment exposed by the exec listener of the current
   * script context.
   */
  def getScriptParams()={
    val listener = ScriptSQLExec.context().execListener
    listener.env()
  }

  /**
   * Decides whether plan-rewrite logging is enabled.
   *
   * Logging is on when `enableQueryWithIndexer` parses to true in either the
   * HTTP params or the script params. Any `Exception` raised while reading or
   * parsing the params is logged and treated as "disabled".
   *
   * @return true when the flag is set in either parameter source
   */
  def isEnableLog(): Boolean ={
    val flagName = "enableQueryWithIndexer"
    try {
      // Both sources are fetched before either flag is parsed.
      val fromHttp = getHttpParams()
      val fromScript = getScriptParams()
      if (fromHttp.getOrElse(flagName, "false").toBoolean) {
        true
      } else {
        fromScript.getOrElse(flagName, "false").toBoolean
      }
    } catch {
      case ex: Exception =>
        logError("----Get PARAMS Error:"+ ex.getMessage)
        false
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy