// All downloads are free. The search and download functionality uses the official Maven repository.

// com.netease.arctic.spark.sql.catalyst.parser.ArcticSqlExtensionsParser.scala (Maven / Gradle / Ivy)

// The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.netease.arctic.spark.sql.catalyst.parser

import com.netease.arctic.spark.sql.catalyst.plans
import com.netease.arctic.spark.sql.parser._
import com.netease.arctic.spark.table.ArcticSparkTable
import com.netease.arctic.spark.util.ArcticSparkUtils
import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
import org.antlr.v4.runtime.tree.TerminalNodeImpl
import org.apache.spark.sql.arctic.parser.ArcticExtendSparkSqlAstBuilder
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.{NonReservedContext, QuotedIdentifierContext}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MergeIntoTable}
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, TableIdentifier}
import org.apache.spark.sql.connector.catalog.{Table, TableCatalog}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{AnalysisException, SparkSession}

import java.util.Locale
import scala.collection.JavaConverters.seqAsJavaListConverter
import scala.util.Try

/**
 * Parser extension that handles Arctic-specific SQL itself and forwards everything else
 * to the wrapped `delegate` parser.
 *
 * Two kinds of statements are parsed here (see `buildLexer`): `CREATE TABLE ... USING ARCTIC`
 * statements carrying a `PRIMARY KEY` clause, and `MIGRATE ... TO ARCTIC` commands. Plans
 * produced by the delegate are post-processed so that a `MergeIntoTable` whose target loads
 * as an [[ArcticSparkTable]] becomes a `MergeIntoArcticTable`.
 *
 * @param delegate parser used for every statement and fragment not handled here
 */
class ArcticSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface with SQLConfHelper {

  // One AST builder per supported grammar; instantiated on first use.
  private lazy val createTableAstBuilder = new ArcticExtendSparkSqlAstBuilder(delegate)
  private lazy val arcticCommandAstVisitor = new ArcticCommandAstParser()

  /** Parses `sqlText` into a DataType by forwarding to the delegate. */
  override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)

  /**
   * Raw DataType parsing (without CHAR/VARCHAR replacement) is not supported by this parser.
   *
   * @throws UnsupportedOperationException on every call
   */
  def parseRawDataType(sqlText: String): DataType = throw new UnsupportedOperationException()

  /** Parses `sqlText` into an Expression by forwarding to the delegate. */
  override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)

  /** Parses `sqlText` into a TableIdentifier by forwarding to the delegate. */
  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    delegate.parseTableIdentifier(sqlText)

  /** Parses `sqlText` into a FunctionIdentifier by forwarding to the delegate. */
  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    delegate.parseFunctionIdentifier(sqlText)

  /** Parses `sqlText` into the parts of a multi-part identifier by forwarding to the delegate. */
  override def parseMultipartIdentifier(sqlText: String): Seq[String] =
    delegate.parseMultipartIdentifier(sqlText)

  /**
   * Parses a comma separated list of field definitions into a StructType by forwarding
   * to the delegate.
   */
  override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText)

  /** Lower-cases (root locale), trims, and collapses every whitespace run to a single space. */
  private def collapseWhitespace(sqlText: String): String =
    sqlText.toLowerCase(Locale.ROOT).trim().replaceAll("\\s+", " ")

  /**
   * Tells whether the statement looks like an Arctic migrate command.
   *
   * This is a substring check on the normalised text, not a parse.
   *
   * @return true when the normalised text contains both "migrate" and "to arctic"
   */
  def isArcticCommand(sqlText: String): Boolean = {
    val statement = collapseWhitespace(sqlText)
    Seq("migrate", "to arctic").forall(fragment => statement.contains(fragment))
  }

  /**
   * Tells whether the statement looks like an Arctic CREATE TABLE with a primary key.
   *
   * This is a substring check on the normalised text, not a parse.
   *
   * @return true when the normalised text contains "create table", "using arctic" and
   *         "primary key"
   */
  def isArcticExtendSparkStatement(sqlText: String): Boolean = {
    val statement = collapseWhitespace(sqlText)
    Seq("create table", "using arctic", "primary key").forall(fragment => statement.contains(fragment))
  }

  /**
   * Chooses the lexer matching the statement type.
   *
   * @return the extended Spark SQL lexer for Arctic CREATE TABLE statements, the command
   *         lexer for Arctic commands, or None when the statement is neither
   */
  def buildLexer(sql: String): Option[Lexer] = {
    // Lazy so that no character stream is built for statements handled by the delegate.
    lazy val upperCased = new UpperCaseCharStream(CharStreams.fromString(sql))
    if (isArcticExtendSparkStatement(sql)) Some(new ArcticExtendSparkSqlLexer(upperCased))
    else if (isArcticCommand(sql)) Some(new ArcticSqlCommandLexer(upperCased))
    else None
  }

  /**
   * Creates the ANTLR parser that pairs with `lexer`, reading tokens from `stream`.
   *
   * @throws IllegalStateException when `lexer` is not one of the two Arctic lexers
   */
  def buildAntlrParser(stream: TokenStream, lexer: Lexer): Parser = lexer match {
    case _: ArcticExtendSparkSqlLexer =>
      val extendParser = new ArcticExtendSparkSqlParser(stream)
      // Carry the relevant session settings over to the generated parser.
      extendParser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled
      extendParser.SQL_standard_keyword_behavior = conf.ansiEnabled
      extendParser
    case _: ArcticSqlCommandLexer =>
      new ArcticSqlCommandParser(stream)
    case _ =>
      throw new IllegalStateException("no suitable parser found")
  }

  /** Runs the `arcticCommand` rule of `parser` and turns the parse tree into a LogicalPlan. */
  def toLogicalResult(parser: Parser): LogicalPlan = parser match {
    case extendParser: ArcticExtendSparkSqlParser =>
      createTableAstBuilder.visitArcticCommand(extendParser.arcticCommand())
    case commandParser: ArcticSqlCommandParser =>
      arcticCommandAstVisitor.visitArcticCommand(commandParser.arcticCommand())
  }

  /**
   * Parse a string to a LogicalPlan.
   *
   * Statements for which `buildLexer` yields a lexer are parsed with the matching Arctic
   * grammar. All other statements are parsed by the delegate, after which MERGE INTO plans
   * targeting an Arctic table are rewritten.
   */
  override def parsePlan(sqlText: String): LogicalPlan = buildLexer(sqlText) match {
    case Some(lexer) => parseWithArcticGrammar(sqlText, lexer)
    case None => replaceMergeIntoCommands(delegate.parsePlan(sqlText))
  }

  /**
   * Parses `sqlText` with the Arctic grammar belonging to `lexer`.
   *
   * A ParseException without a command gets `sqlText` attached; an AnalysisException is
   * rethrown as a ParseException positioned at the exception's line and start position.
   */
  private def parseWithArcticGrammar(sqlText: String, lexer: Lexer): LogicalPlan = {
    lexer.removeErrorListeners()
    lexer.addErrorListener(ArcticParseErrorListener)

    val tokens = new CommonTokenStream(lexer)
    val antlrParser = buildAntlrParser(tokens, lexer)
    antlrParser.removeErrorListeners()
    antlrParser.addErrorListener(ArcticParseErrorListener)

    def parseIn(mode: PredictionMode): LogicalPlan = {
      antlrParser.getInterpreter.setPredictionMode(mode)
      toLogicalResult(antlrParser)
    }

    try {
      try {
        // SLL is potentially faster, so it is attempted first.
        parseIn(PredictionMode.SLL)
      } catch {
        case _: ParseCancellationException =>
          // Rewind the token stream and the parser, then retry in LL mode.
          tokens.seek(0)
          antlrParser.reset()
          parseIn(PredictionMode.LL)
      }
    } catch {
      case e: ParseException =>
        throw (if (e.command.isDefined) e else e.withCommand(sqlText))
      case e: AnalysisException =>
        val position = Origin(e.line, e.startPosition)
        throw new ParseException(Option(sqlText), e.message, position, position)
    }
  }

  /** Swaps every MergeIntoTable whose target is an Arctic table for a MergeIntoArcticTable. */
  private def replaceMergeIntoCommands(plan: LogicalPlan): LogicalPlan =
    plan.resolveOperatorsDown {
      case merge @ MergeIntoTable(UnresolvedArcticTable(target), _, _, _, _) =>
        plans.MergeIntoArcticTable(
          target,
          merge.sourceTable,
          merge.mergeCondition,
          merge.matchedActions,
          merge.notMatchedActions)
    }

  /**
   * Extractor for a plan that, once subquery aliases are eliminated, is an
   * UnresolvedRelation whose identifier loads as an [[ArcticSparkTable]].
   * The extracted value is the plan as given, aliases included.
   */
  object UnresolvedArcticTable {

    def unapply(plan: LogicalPlan): Option[LogicalPlan] =
      EliminateSubqueryAliases(plan) match {
        case UnresolvedRelation(nameParts, _, _) if isArcticKeyedTable(nameParts) => Some(plan)
        case _ => None
      }

    /**
     * Resolves the identifier against the active session and checks the loaded table.
     * A failure while loading the table is treated as "not an Arctic table".
     */
    private def isArcticKeyedTable(multipartIdent: Seq[String]): Boolean = {
      val resolved = ArcticSparkUtils.tableCatalogAndIdentifier(SparkSession.active, multipartIdent.asJava)
      resolved.catalog match {
        case tableCatalog: TableCatalog =>
          Try(tableCatalog.loadTable(resolved.identifier))
            .toOption
            .exists(loaded => isArcticKeyedTable(loaded))
        case _ =>
          false
      }
    }

    /** True exactly when `table` is an [[ArcticSparkTable]]. */
    private def isArcticKeyedTable(table: Table): Boolean = table match {
      case _: ArcticSparkTable => true
      case _ => false
    }
  }
}

/**
 * Character stream whose lookahead is upper-cased while the text it returns is left as
 * written; every other operation is forwarded unchanged to `wrapped`.
 *
 * Copied from Apache Spark's parser to avoid a dependency on Spark internals.
 */
class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
  override def consume(): Unit = wrapped.consume()

  override def getSourceName(): String = wrapped.getSourceName()

  override def index(): Int = wrapped.index()

  override def mark(): Int = wrapped.mark()

  override def release(marker: Int): Unit = wrapped.release(marker)

  override def seek(where: Int): Unit = wrapped.seek(where)

  override def size(): Int = wrapped.size()

  /** Returns the text of `interval` exactly as the wrapped stream holds it (no case change). */
  override def getText(interval: Interval): String = wrapped.getText(interval)

  // scalastyle:off
  /** Looks ahead `i` symbols; 0 and EOF are returned as-is, anything else is upper-cased. */
  override def LA(i: Int): Int = {
    val symbol = wrapped.LA(i)
    val passThrough = symbol == 0 || symbol == IntStream.EOF
    if (passThrough) symbol else Character.toUpperCase(symbol)
  }
  // scalastyle:on
}

/**
 * The post-processor validates and cleans up the parse tree during the parse process.
 *
 * NOTE(review): the exit methods below take context types imported from
 * IcebergSqlExtensionsParser and are not marked `override`; confirm that they are actually
 * invoked for the Arctic grammar and that this listener is registered on the parser.
 */
case object ArcticSqlExtensionsPostProcessor extends ArcticExtendSparkSqlBaseListener {

  /** Replaces a quoted identifier with a plain identifier token, dropping the enclosing back ticks. */
  def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit =
    replaceTokenByIdentifier(ctx, 1) { unquoted =>
      // Inside the quotes a doubled back tick stands for a single one.
      unquoted.setText(unquoted.getText.replace("``", "`"))
      unquoted
    }

  /** Replaces a non-reserved keyword with an identifier token covering the same text. */
  def exitNonReserved(ctx: NonReservedContext): Unit =
    replaceTokenByIdentifier(ctx, 0)(identity)

  /**
   * Removes the last child of `ctx`'s parent and appends an IDENTIFIER terminal built from
   * the first child of `ctx`.
   *
   * @param ctx          rule context whose first child carries the token to convert
   * @param stripMargins number of characters cut from both ends of the token's span
   * @param f            final adjustment applied to the new token before it is attached
   */
  private def replaceTokenByIdentifier(
    ctx: ParserRuleContext,
    stripMargins: Int
  )(
    f: CommonToken => CommonToken = identity
  ): Unit = {
    val enclosing = ctx.getParent
    enclosing.removeLastChild()
    val source = ctx.getChild(0).getPayload.asInstanceOf[Token]
    val firstIndex = source.getStartIndex + stripMargins
    val lastIndex = source.getStopIndex - stripMargins
    val identifier = new CommonToken(
      new org.antlr.v4.runtime.misc.Pair(source.getTokenSource, source.getInputStream),
      ArcticExtendSparkSqlParser.IDENTIFIER,
      source.getChannel,
      firstIndex,
      lastIndex)
    enclosing.addChild(new TerminalNodeImpl(f(identifier)))
  }
}

/**
 * ANTLR error listener that turns every reported syntax error into a Spark
 * [[ParseException]] carrying the position of the offending input.
 *
 * Partially copied from Apache Spark's parser to avoid a dependency on Spark internals.
 */
case object ArcticParseErrorListener extends BaseErrorListener {

  /**
   * Raises a [[ParseException]] for the reported syntax error.
   *
   * @param recognizer         lexer or parser that detected the error (unused)
   * @param offendingSymbol    offending token, when one is available
   * @param line               line number of the error
   * @param charPositionInLine position within the line, used when there is no CommonToken
   * @param msg                error message produced by ANTLR
   * @param e                  underlying recognition exception (unused)
   * @throws ParseException on every call, with no command text attached
   */
  override def syntaxError(
    recognizer: Recognizer[_, _],
    offendingSymbol: scala.Any,
    line: Int,
    charPositionInLine: Int,
    msg: String,
    e: RecognitionException
  ): Unit = {
    val (start, stop) = offendingSymbol match {
      case token: CommonToken =>
        // Span the whole offending token.
        val start = Origin(Some(line), Some(token.getCharPositionInLine))
        val length = token.getStopIndex - token.getStartIndex + 1
        val stop = Origin(Some(line), Some(token.getCharPositionInLine + length))
        (start, stop)
      case _ =>
        // Without a token only a single position is known.
        val start = Origin(Some(line), Some(charPositionInLine))
        (start, start)
    }
    // Previously the computed positions were dropped and the error was silently ignored.
    // The SQL text is unknown here, so the command is left empty for the caller to fill in.
    throw new ParseException(None, msg, start, stop)
  }
}






// © 2015 - 2025 Weber Informatics LLC | Privacy Policy