All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.delta.sql.parser.DeltaSqlParser.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * This file contains code from the Apache Spark project (original license above).
 * It contains modifications, which are licensed as follows:
 */

/*
 * Copyright (2020) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.delta.sql.parser

import java.util.Locale

import scala.collection.JavaConverters._

import org.apache.spark.sql.delta.commands.DeltaGenerateCommand
import io.delta.sql.parser.DeltaSqlBaseParser._
import io.delta.tables.execution.{ DescribeDeltaHistoryCommand, VacuumTableCommand }
import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{ Interval, ParseCancellationException }
import org.antlr.v4.runtime.tree._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.{ FunctionIdentifier, TableIdentifier }
import org.apache.spark.sql.catalyst.parser.{ ParseErrorListener, ParseException, ParserInterface }
import org.apache.spark.sql.catalyst.parser.ParserUtils.{ string, withOrigin }
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.delta.commands.DescribeDeltaDetailCommand
import org.apache.spark.sql.delta.commands.ConvertToDeltaCommand
import org.apache.spark.sql.types._

/**
  * A SQL parser that tries to parse Delta commands. If it fails to parse the SQL text, it will
  * forward the call to `delegate`.
  */
class DeltaSqlParser(val delegate: ParserInterface) extends ParserInterface {
  private val astBuilder = new DeltaSqlAstBuilder

  /**
    * Parse `sqlText` with the Delta grammar. A [[LogicalPlan]] produced by the AST builder is
    * returned as is; any other visit result (including `null`) hands the text to `delegate`.
    */
  override def parsePlan(sqlText: String): LogicalPlan =
    parse(sqlText) { deltaParser =>
      val visited = astBuilder.visit(deltaParser.singleStatement())
      visited match {
        case deltaPlan: LogicalPlan => deltaPlan
        case _                      => delegate.parsePlan(sqlText)
      }
    }

  // scalastyle:off line.size.limit
  /**
    * Fork from `org.apache.spark.sql.catalyst.parser.AbstractSqlParser#parse(java.lang.String, scala.Function1)`.
    *
    * Sets up a lexer and parser over an upper-casing view of `command`, then runs `toResult`
    * in SLL prediction mode. If that raises a `ParseCancellationException`, the token stream is
    * rewound, the parser is reset and `toResult` runs again in LL mode.
    *
    * A [[ParseException]] that carries no command is rethrown with `command` attached; any other
    * [[AnalysisException]] is rethrown as a [[ParseException]] located at its line and start
    * position.
    *
    * @param command  the SQL text to parse
    * @param toResult function applied to the configured parser to produce the result
    * @see https://github.com/apache/spark/blob/v2.4.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala#L81
    */
  // scalastyle:on
  protected def parse[T](command: String)(toResult: DeltaSqlBaseParser => T): T = {
    val upperCased = new UpperCaseCharStream(CharStreams.fromString(command))
    val sqlLexer   = new DeltaSqlBaseLexer(upperCased)
    sqlLexer.removeErrorListeners()
    sqlLexer.addErrorListener(ParseErrorListener)

    val tokens    = new CommonTokenStream(sqlLexer)
    val sqlParser = new DeltaSqlBaseParser(tokens)
    sqlParser.addParseListener(PostProcessor)
    sqlParser.removeErrorListeners()
    sqlParser.addErrorListener(ParseErrorListener)

    // Switch the prediction mode, then run the caller's function on the parser.
    def attempt(mode: PredictionMode): T = {
      sqlParser.getInterpreter.setPredictionMode(mode)
      toResult(sqlParser)
    }

    try {
      // SLL is potentially faster, so it is tried first.
      try attempt(PredictionMode.SLL)
      catch {
        case _: ParseCancellationException =>
          // Rewind the input and start over with LL prediction.
          tokens.seek(0)
          sqlParser.reset()
          attempt(PredictionMode.LL)
      }
    } catch {
      // ParseException must be matched before AnalysisException, which is also handled below.
      case parseError: ParseException       =>
        throw (if (parseError.command.isDefined) parseError else parseError.withCommand(command))
      case analysisError: AnalysisException =>
        val where = Origin(analysisError.line, analysisError.startPosition)
        throw new ParseException(Option(command), analysisError.message, where, where)
    }
  }

  /** Forwarded unchanged to `delegate`. */
  override def parseExpression(sqlText: String): Expression =
    delegate.parseExpression(sqlText)

  /** Forwarded unchanged to `delegate`. */
  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    delegate.parseTableIdentifier(sqlText)

  /** Forwarded unchanged to `delegate`. */
  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    delegate.parseFunctionIdentifier(sqlText)

  /** Forwarded unchanged to `delegate`. */
  override def parseTableSchema(sqlText: String): StructType =
    delegate.parseTableSchema(sqlText)

  /** Forwarded unchanged to `delegate`. */
  override def parseDataType(sqlText: String): DataType =
    delegate.parseDataType(sqlText)
}

/**
  * Define how to convert an AST generated from `DeltaSqlBase.g4` to a `LogicalPlan`. The parent
  * class `DeltaSqlBaseBaseVisitor` defines all visitXXX methods generated from `#` instructions in
  * `DeltaSqlBase.g4` (such as `#vacuumTable`).
  */
class DeltaSqlAstBuilder extends DeltaSqlBaseBaseVisitor[AnyRef] {

  /**
    * Build a [[VacuumTableCommand]] from a `VACUUM` statement. Example SQL:
    * {{{
    *   VACUUM ('/path/to/dir' | delta.`/path/to/dir`) [RETAIN number HOURS] [DRY RUN];
    * }}}
    */
  override def visitVacuumTable(ctx: VacuumTableContext): AnyRef = withOrigin(ctx) {
    val targetPath     = Option(ctx.path).map(string)
    val targetTable    = Option(ctx.table).map(visitTableIdentifier)
    val retainedNumber = Option(ctx.number).map(_.getText.toDouble)
    // True whenever the RUN keyword appears in the statement.
    val runKeywordSeen = ctx.RUN != null
    VacuumTableCommand(targetPath, targetTable, retainedNumber, runKeywordSeen)
  }

  /** Build a [[DescribeDeltaDetailCommand]] from the optional path and optional table name. */
  override def visitDescribeDeltaDetail(ctx: DescribeDeltaDetailContext): LogicalPlan = withOrigin(ctx) {
    val targetPath  = Option(ctx.path).map(string)
    val targetTable = Option(ctx.table).map(visitTableIdentifier)
    DescribeDeltaDetailCommand(targetPath, targetTable)
  }

  /**
    * Build a [[DescribeDeltaHistoryCommand]] from the optional path, optional table name and
    * optional integer limit.
    */
  override def visitDescribeDeltaHistory(ctx: DescribeDeltaHistoryContext): LogicalPlan = withOrigin(ctx) {
    val targetPath  = Option(ctx.path).map(string)
    val targetTable = Option(ctx.table).map(visitTableIdentifier)
    val rowLimit    = Option(ctx.limit).map(_.getText.toInt)
    DescribeDeltaHistoryCommand(targetPath, targetTable, rowLimit)
  }

  /** Build a [[DeltaGenerateCommand]] from the mode name text and the table identifier. */
  override def visitGenerate(ctx: GenerateContext): LogicalPlan = withOrigin(ctx) {
    val mode   = ctx.modeName.getText
    val target = visitTableIdentifier(ctx.table)
    DeltaGenerateCommand(modeName = mode, tableId = target)
  }

  /**
    * Build a [[ConvertToDeltaCommand]] for the named table. A column type list, when present,
    * is passed along as a [[StructType]]; the last argument is always `None`.
    */
  override def visitConvert(ctx: ConvertContext): LogicalPlan = withOrigin(ctx) {
    val target         = visitTableIdentifier(ctx.table)
    val declaredSchema = Option(ctx.colTypeList).map(columns => StructType(visitColTypeList(columns)))
    ConvertToDeltaCommand(target, declaredSchema, None)
  }

  /** Visit the wrapped statement and cast whatever it produced to a [[LogicalPlan]]. */
  override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
    val visited = visit(ctx.statement)
    visited.asInstanceOf[LogicalPlan]
  }

  /**
    * Turn a qualified name of one part (`table`) or two parts (`db.table`) into a
    * [[TableIdentifier]].
    *
    * @throws ParseException if the name has any other number of parts
    */
  protected def visitTableIdentifier(ctx: QualifiedNameContext): TableIdentifier = withOrigin(ctx) {
    ctx.identifier.asScala.toList match {
      case table :: Nil             => TableIdentifier(table.getText)
      case database :: table :: Nil => TableIdentifier(table.getText, Some(database.getText))
      case _                        => throw new ParseException(s"Illegal table name ${ctx.getText}", ctx)
    }
  }

  /** Pass-through statements produce no plan: the result is always `null`. */
  override def visitPassThrough(ctx: PassThroughContext): LogicalPlan = null

  /** Convert every column definition in the list to a [[StructField]], in order. */
  override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
    for (column <- ctx.colType().asScala) yield visitColType(column)
  }

  /**
    * Convert one column definition to a [[StructField]]. The field is nullable unless the
    * definition contains `NOT`. If `HiveStringType.replaceCharType` changes the data type, the
    * original type's catalog string is kept in the field metadata under `HIVE_TYPE_STRING`.
    */
  override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
    val declaredType = typedVisit[DataType](ctx.dataType)
    val storedType   = HiveStringType.replaceCharType(declaredType)

    // Add Hive type string to metadata.
    val metadata = new MetadataBuilder
    if (declaredType != storedType) {
      metadata.putString(HIVE_TYPE_STRING, declaredType.catalogString)
    }

    val isNullable = ctx.NOT == null
    StructField(ctx.colName.getText, storedType, nullable = isNullable, metadata.build())
  }

  /** Visit `ctx` with this visitor and cast the result to `T`. */
  protected def typedVisit[T](ctx: ParseTree): T = {
    val visited = ctx.accept(this)
    visited.asInstanceOf[T]
  }

  /**
    * Resolve a primitive type name, lower-cased with `Locale.ROOT`, and its integer arguments
    * to a [[DataType]].
    *
    * @throws ParseException if the name/argument combination is not one of the supported types
    */
  override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
    val typeName = ctx.identifier.getText.toLowerCase(Locale.ROOT)
    val typeArgs = ctx.INTEGER_VALUE().asScala.toList

    def intArg(node: TerminalNode): Int = node.getText.toInt

    // Shared failure path for every unsupported name/argument combination.
    def unsupported: Nothing = {
      val rendered =
        if (typeArgs.isEmpty) typeName else s"$typeName(${typeArgs.mkString(",")})"
      throw new ParseException(s"DataType $rendered is not supported.", ctx)
    }

    // Dispatch on the number of arguments first, then on the type name.
    typeArgs match {
      case Nil                    =>
        typeName match {
          case "boolean"            => BooleanType
          case "tinyint" | "byte"   => ByteType
          case "smallint" | "short" => ShortType
          case "int" | "integer"    => IntegerType
          case "bigint" | "long"    => LongType
          case "float"              => FloatType
          case "double"             => DoubleType
          case "date"               => DateType
          case "timestamp"          => TimestampType
          case "string"             => StringType
          case "binary"             => BinaryType
          case "decimal"            => DecimalType.USER_DEFAULT
          case "interval"           => CalendarIntervalType
          case _                    => unsupported
        }
      case List(only)             =>
        typeName match {
          case "char"    => CharType(intArg(only))
          case "varchar" => VarcharType(intArg(only))
          case "decimal" => DecimalType(intArg(only), 0)
          case _         => unsupported
        }
      case List(precision, scale) =>
        typeName match {
          case "decimal" => DecimalType(intArg(precision), intArg(scale))
          case _         => unsupported
        }
      case _                      => unsupported
    }
  }
}

// scalastyle:off line.size.limit
/**
  * Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
  *
  * @see https://github.com/apache/spark/blob/v2.4.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala#L157
  */
// scalastyle:on
/**
  * A `CharStream` that forwards every operation to `wrapped`, except that `LA` upper-cases the
  * look-ahead character. `getText` returns the wrapped stream's text without case conversion.
  */
class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
  override def consume(): Unit = wrapped.consume()

  override def getSourceName(): String = wrapped.getSourceName()

  override def index(): Int = wrapped.index()

  override def mark(): Int = wrapped.mark()

  override def release(marker: Int): Unit = wrapped.release(marker)

  override def seek(where: Int): Unit = wrapped.seek(where)

  override def size(): Int = wrapped.size()

  /** Text of `interval` from the wrapped stream, or `""` for an empty stream or inverted interval. */
  override def getText(interval: Interval): String = {
    // ANTLR 4.7's CodePointCharStream implementations have bugs when
    // getText() is called with an empty stream, or intervals where
    // the start > end. See
    // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix
    // that is not yet in a released ANTLR artifact.
    val span = interval.b - interval.a
    if (size() <= 0 || span < 0) ""
    else wrapped.getText(interval)
  }

  /** Look-ahead value from the wrapped stream, upper-cased unless it is `0` or `IntStream.EOF`. */
  override def LA(i: Int): Int = {
    val codePoint  = wrapped.LA(i)
    val isSentinel = codePoint == 0 || codePoint == IntStream.EOF
    if (isSentinel) codePoint else Character.toUpperCase(codePoint)
  }
}

// scalastyle:off line.size.limit
/**
  * Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
  *
  * @see https://github.com/apache/spark/blob/v2.4.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala#L248
  */
// scalastyle:on
/**
  * Parse listener that replaces quoted identifiers and non-reserved keywords in the parse tree
  * with plain `IDENTIFIER` tokens.
  */
case object PostProcessor extends DeltaSqlBaseBaseListener {

  /** Remove the back ticks from an Identifier. */
  override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit =
    replaceTokenByIdentifier(ctx, 1) { unquoted =>
      // Remove the double back ticks in the string.
      val collapsed = unquoted.getText.replace("``", "`")
      unquoted.setText(collapsed)
      unquoted
    }

  /** Treat non-reserved keywords as Identifiers. */
  override def exitNonReserved(ctx: NonReservedContext): Unit =
    replaceTokenByIdentifier(ctx, 0)(identity)

  /**
    * Replace the last child of `ctx`'s parent with a terminal node holding a new `IDENTIFIER`
    * token. The new token covers the same input as the first child of `ctx`, narrowed by
    * `stripMargins` characters at each end, and is passed through `transform` before insertion.
    */
  private def replaceTokenByIdentifier(ctx: ParserRuleContext, stripMargins: Int)(
    transform: CommonToken => CommonToken = identity
  ): Unit = {
    val enclosing = ctx.getParent
    enclosing.removeLastChild()

    val original = ctx.getChild(0).getPayload.asInstanceOf[Token]
    val origin: org.antlr.v4.runtime.misc.Pair[TokenSource, CharStream] =
      new org.antlr.v4.runtime.misc.Pair(original.getTokenSource, original.getInputStream)
    val firstIndex  = original.getStartIndex + stripMargins
    val lastIndex   = original.getStopIndex - stripMargins
    val replacement =
      new CommonToken(origin, DeltaSqlBaseParser.IDENTIFIER, original.getChannel, firstIndex, lastIndex)

    enclosing.addChild(new TerminalNodeImpl(transform(replacement)))
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy