All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.extensions.TiParser.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2021 PingCAP, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.extensions

import org.apache.spark.sql.{SparkSession, TiContext}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{
  InsertIntoStatement,
  LogicalPlan,
  SubqueryAlias,
  With
}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.execution.command.{
  CacheTableCommand,
  CreateViewCommand,
  ExplainCommand,
  UncacheTableCommand
}
import org.apache.spark.sql.types.{DataType, StructType}

import java.util

/**
 * A [[ParserInterface]] installed by the TiSpark extensions that post-processes parsed
 * logical plans so that unqualified (single-part) table identifiers are qualified with
 * the current TiDB catalog database. This steers Spark's relation resolution toward the
 * TiDB catalog instead of the default session catalog lookup.
 *
 * @param getOrCreateTiContext factory producing (or reusing) the [[TiContext]] for a session
 * @param sparkSession         the active session, passed to `getOrCreateTiContext`
 * @param delegate             the previously-installed parser in the extension chain
 *                             (NOTE(review): not invoked anywhere in this class — all
 *                             parsing goes through the internal SparkSqlParser; confirm
 *                             this is intentional)
 */
case class TiParser(
    getOrCreateTiContext: SparkSession => TiContext,
    sparkSession: SparkSession,
    delegate: ParserInterface)
    extends ParserInterface {
  // Lazily resolved so the TiContext is only created on first use of this parser.
  private lazy val tiContext = getOrCreateTiContext(sparkSession)
  // The actual SQL parser; this class only rewrites the plans/identifiers it produces.
  private lazy val internal = new SparkSqlParser()

  // Per-thread set of lower-cased CTE names seen while transforming the current plan.
  // Identifiers matching a known CTE name must NOT be qualified (they are not catalog
  // tables). ThreadLocal because parsers may be shared across concurrent sessions.
  private val cteTableNames = new ThreadLocal[java.util.Set[String]] {
    override def initialValue(): util.Set[String] = new util.HashSet[String]()
  }

  /**
   * WAR to lead Spark to consider this relation being on local files.
   * Otherwise Spark will lookup this relation in his session catalog.
   * CHECK Spark [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations]] for details.
   *
   * Applied top-down via `plan.transform`; each case either qualifies an eligible
   * identifier in place or recurses the transformation into nested plans (CTE bodies,
   * view definitions, EXPLAIN targets, cached plans, subquery expressions).
   */
  private val qualifyTableIdentifier: PartialFunction[LogicalPlan, LogicalPlan] = {
    // Plain relation with a single-part name that is neither a temp view nor a CTE:
    // prepend the current database.
    case r @ UnresolvedRelation(tableIdentifier, _, _) if needQualify(tableIdentifier) =>
      r.copy(qualifyTableIdentifierInternal(tableIdentifier))
    case i @ InsertIntoStatement(r @ UnresolvedRelation(tableIdentifier, _, _), _, _, _, _, _)
        if needQualify(tableIdentifier) =>
      // When getting temp view, we leverage legacy catalog.
      i.copy(r.copy(qualifyTableIdentifierInternal(tableIdentifier)))
    case w @ With(_, cteRelations) =>
      // Record every CTE name (lower-cased) so references to them are left unqualified,
      // then recurse into each CTE body.
      for (x <- cteRelations) {
        cteTableNames.get().add(x._1.toLowerCase())
      }
      w.copy(cteRelations = cteRelations
        .map(p => (p._1, p._2.transform(qualifyTableIdentifier).asInstanceOf[SubqueryAlias])))
    case cv @ CreateViewCommand(_, _, _, _, _, child, _, _, _) =>
      // Qualify identifiers inside the view's defining query.
      cv.copy(child = child transform qualifyTableIdentifier)
    case e @ ExplainCommand(plan, _) =>
      // Qualify identifiers inside the plan being explained.
      e.copy(logicalPlan = plan transform qualifyTableIdentifier)
    case c @ CacheTableCommand(tableIdentifier, plan, _, _, _)
        if plan.isEmpty && needQualify(tableIdentifier) =>
      // Caching an unqualified catalog table.
      c.copy(qualifyTableIdentifierInternal(tableIdentifier))
    case c @ CacheTableCommand(_, plan, _, _, _) if plan.isDefined =>
      // CACHE TABLE ... AS SELECT: qualify identifiers inside the SELECT.
      c.copy(plan = Some(plan.get transform qualifyTableIdentifier))
    case u @ UncacheTableCommand(tableIdentifier, _) if needQualify(tableIdentifier) =>
      // Uncaching an unqualified catalog table.
      u.copy(qualifyTableIdentifierInternal(tableIdentifier))
    case logicalPlan =>
      // Fallback: recurse into subquery expressions attached to any other node.
      logicalPlan transformExpressionsUp {
        case s: SubqueryExpression =>
          // Snapshot the CTE names visible before entering the subquery so that any
          // CTEs defined inside it can be discarded afterwards.
          val cteNamesBeforeSubQuery = new util.HashSet[String]()
          cteNamesBeforeSubQuery.addAll(cteTableNames.get())
          val newPlan = s.withNewPlan(s.plan transform qualifyTableIdentifier)
          // CTE table names defined inside the subquery must not be visible outside it:
          // restore the pre-subquery snapshot.
          cteTableNames.get().clear()
          cteTableNames.get().addAll(cteNamesBeforeSubQuery)
          newPlan
      }
  }

  /**
   * Parses `sqlText` with the internal SparkSqlParser, then rewrites the resulting plan
   * to qualify eligible table identifiers. The per-thread CTE name set is cleared
   * between the parse and the transform so state from a previous statement on the same
   * thread cannot leak into this one.
   */
  override def parsePlan(sqlText: String): LogicalPlan = {
    val plan = internal.parsePlan(sqlText)
    cteTableNames.get().clear()
    plan.transform(qualifyTableIdentifier)
  }

  // The remaining ParserInterface methods need no identifier rewriting and are
  // forwarded verbatim to the internal SparkSqlParser.

  override def parseExpression(sqlText: String): Expression =
    internal.parseExpression(sqlText)

  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    internal.parseTableIdentifier(sqlText)

  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    internal.parseFunctionIdentifier(sqlText)

  override def parseTableSchema(sqlText: String): StructType =
    internal.parseTableSchema(sqlText)

  override def parseDataType(sqlText: String): DataType =
    internal.parseDataType(sqlText)

  /**
   * Prepends the TiDB catalog's current database to a single-part identifier
   * (e.g. `t` becomes `Seq(currentDb, t)`); multi-part identifiers are returned as-is.
   */
  private def qualifyTableIdentifierInternal(tableIdentifier: Seq[String]): Seq[String] = {
    if (tableIdentifier.size == 1) {
      tiContext.tiCatalog.getCurrentDatabase :: tableIdentifier.toList
    } else {
      tableIdentifier
    }
  }

  /**
   * An identifier needs qualification only when it is single-part, is not a registered
   * temp view in the session catalog, and does not (case-insensitively) name a CTE
   * recorded for the current thread.
   */
  private def needQualify(tableIdentifier: Seq[String]): Boolean = {
    tableIdentifier.size == 1 && tiContext.sessionCatalog
      .getTempView(tableIdentifier.head)
      .isEmpty && !cteTableNames.get().contains(tableIdentifier.head.toLowerCase())
  }

  override def parseMultipartIdentifier(sqlText: String): Seq[String] =
    internal.parseMultipartIdentifier(sqlText)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy