/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * This file contains code from the Apache Spark project (original license above).
 * It contains modifications, which are licensed as follows:
 */

/*
 * Copyright (2020) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.delta.sql.parser

import java.util.Locale

import scala.collection.JavaConverters._

import io.delta.sql.parser.DeltaSqlBaseParser._
import io.delta.tables.execution.{DescribeDeltaHistoryCommand, VacuumTableCommand}
import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
import org.antlr.v4.runtime.tree._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface}
import org.apache.spark.sql.catalyst.parser.ParserUtils.{string, withOrigin}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.delta.commands.ConvertToDeltaCommand
import org.apache.spark.sql.delta.commands.DeltaGenerateCommand
import org.apache.spark.sql.delta.commands.DescribeDeltaDetailCommand
import org.apache.spark.sql.types._

/**
 * A SQL parser that tries to parse Delta commands. If it fails to parse the SQL text, it
 * forwards the call to `delegate`.
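 *
 * A minimal usage sketch (illustrative only; `spark` is an assumed active `SparkSession`,
 * and in practice the parser is installed through a session extension rather than built
 * by hand):
 * {{{
 *   val parser = new DeltaSqlParser(spark.sessionState.sqlParser)
 *   val plan = parser.parsePlan("VACUUM delta.`/tmp/events` RETAIN 168 HOURS DRY RUN")
 * }}}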
*/
class DeltaSqlParser(val delegate: ParserInterface) extends ParserInterface {
  private val builder = new DeltaSqlAstBuilder

  override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
    builder.visit(parser.singleStatement()) match {
      case plan: LogicalPlan => plan
      case _ => delegate.parsePlan(sqlText)
    }
  }

  // scalastyle:off line.size.limit
  /**
   * Fork from `org.apache.spark.sql.catalyst.parser.AbstractSqlParser#parse(java.lang.String, scala.Function1)`.
   *
   * @see https://github.com/apache/spark/blob/v2.4.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala#L81
   */
  // scalastyle:on
  protected def parse[T](command: String)(toResult: DeltaSqlBaseParser => T): T = {
    val lexer = new DeltaSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
    lexer.removeErrorListeners()
    lexer.addErrorListener(ParseErrorListener)

    val tokenStream = new CommonTokenStream(lexer)
    val parser = new DeltaSqlBaseParser(tokenStream)
    parser.addParseListener(PostProcessor)
    parser.removeErrorListeners()
    parser.addErrorListener(ParseErrorListener)

    try {
      try {
        // First, try parsing with the potentially faster SLL mode.
        parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
        toResult(parser)
      } catch {
        case e: ParseCancellationException =>
          // If that fails, try again with the slower but more capable LL mode.
          tokenStream.seek(0) // rewind input stream
          parser.reset()
          parser.getInterpreter.setPredictionMode(PredictionMode.LL)
          toResult(parser)
      }
    } catch {
      case e: ParseException if e.command.isDefined =>
        throw e
      case e: ParseException =>
        throw e.withCommand(command)
      case e: AnalysisException =>
        val position = Origin(e.line, e.startPosition)
        throw new ParseException(Option(command), e.message, position, position)
    }
  }

  override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)

  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    delegate.parseTableIdentifier(sqlText)

  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    delegate.parseFunctionIdentifier(sqlText)

  override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText)

  override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)
}

/**
 * Define how to convert an AST generated from `DeltaSqlBase.g4` to a `LogicalPlan`. The parent
 * class `DeltaSqlBaseBaseVisitor` defines all visitXXX methods generated from `#` instructions in
 * `DeltaSqlBase.g4` (such as `#vacuumTable`).
 */
class DeltaSqlAstBuilder extends DeltaSqlBaseBaseVisitor[AnyRef] {

  /**
   * Create a [[VacuumTableCommand]] logical plan. Example SQL:
   * {{{
   *   VACUUM ('/path/to/dir' | delta.`/path/to/dir`) [RETAIN number HOURS] [DRY RUN];
   * }}}
   */
  override def visitVacuumTable(ctx: VacuumTableContext): AnyRef = withOrigin(ctx) {
    VacuumTableCommand(
      Option(ctx.path).map(string),
      Option(ctx.table).map(visitTableIdentifier),
      Option(ctx.number).map(_.getText.toDouble),
      ctx.RUN != null)
  }
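
  /**
   * Create a [[DescribeDeltaDetailCommand]] logical plan. Example SQL (the grammar also
   * accepts the DESC abbreviation):
   * {{{
   *   DESCRIBE DETAIL ('/path/to/dir' | delta.`/path/to/dir`);
   * }}}
   */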
  override def visitDescribeDeltaDetail(
      ctx: DescribeDeltaDetailContext): LogicalPlan = withOrigin(ctx) {
    DescribeDeltaDetailCommand(
      Option(ctx.path).map(string),
      Option(ctx.table).map(visitTableIdentifier))
  }
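
  /**
   * Create a [[DescribeDeltaHistoryCommand]] logical plan. Example SQL (DESC also works):
   * {{{
   *   DESCRIBE HISTORY ('/path/to/dir' | delta.`/path/to/dir`) [LIMIT number];
   * }}}
   */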
  override def visitDescribeDeltaHistory(
      ctx: DescribeDeltaHistoryContext): LogicalPlan = withOrigin(ctx) {
    DescribeDeltaHistoryCommand(
      Option(ctx.path).map(string),
      Option(ctx.table).map(visitTableIdentifier),
      Option(ctx.limit).map(_.getText.toInt))
  }
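
  /**
   * Create a [[DeltaGenerateCommand]] logical plan. Example SQL (modeName is an identifier
   * interpreted by the command, e.g. symlink_format_manifest):
   * {{{
   *   GENERATE modeName FOR TABLE tableName;
   * }}}
   */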
  override def visitGenerate(ctx: GenerateContext): LogicalPlan = withOrigin(ctx) {
    DeltaGenerateCommand(
      modeName = ctx.modeName.getText,
      tableId = visitTableIdentifier(ctx.table))
  }
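
  /**
   * Create a [[ConvertToDeltaCommand]] logical plan. Example SQL (the optional colTypeList
   * supplies the partition schema; the PARTITIONED BY spelling follows `DeltaSqlBase.g4`):
   * {{{
   *   CONVERT TO DELTA tableName [PARTITIONED BY (col1 INT, col2 STRING)];
   * }}}
   */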
  override def visitConvert(ctx: ConvertContext): LogicalPlan = withOrigin(ctx) {
    ConvertToDeltaCommand(
      visitTableIdentifier(ctx.table),
      Option(ctx.colTypeList).map(colTypeList => StructType(visitColTypeList(colTypeList))),
      None)
  }

  override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
    visit(ctx.statement).asInstanceOf[LogicalPlan]
  }
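
  /** Build a [[TableIdentifier]] from a qualified name of the form `table` or `db.table`. */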
  protected def visitTableIdentifier(ctx: QualifiedNameContext): TableIdentifier = withOrigin(ctx) {
    ctx.identifier.asScala match {
      case Seq(tbl) => TableIdentifier(tbl.getText)
      case Seq(db, tbl) => TableIdentifier(tbl.getText, Some(db.getText))
      case _ => throw new ParseException(s"Illegal table name ${ctx.getText}", ctx)
    }
  }

  override def visitPassThrough(ctx: PassThroughContext): LogicalPlan = null

  override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
    ctx.colType().asScala.map(visitColType)
  }
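
  /**
   * Create a [[StructField]] from a column definition such as `col INT` or `col STRING NOT NULL`,
   * stashing the raw Hive CHAR/VARCHAR type string in the field metadata when it is replaced by
   * a plain StringType.
   */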
  override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
    import ctx._

    val builder = new MetadataBuilder
    // Add Hive type string to metadata.
    val rawDataType = typedVisit[DataType](ctx.dataType)
    val cleanedDataType = HiveStringType.replaceCharType(rawDataType)
    if (rawDataType != cleanedDataType) {
      builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString)
    }

    StructField(ctx.colName.getText, cleanedDataType, nullable = NOT == null, builder.build())
  }

  protected def typedVisit[T](ctx: ParseTree): T = {
    ctx.accept(this).asInstanceOf[T]
  }
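
  /**
   * Resolve a primitive data type from its lower-cased name and optional integer parameters,
   * e.g. `decimal(10, 2)` becomes `DecimalType(10, 2)` and `varchar(32)` becomes
   * `VarcharType(32)`.
   */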
  override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
    val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT)
    (dataType, ctx.INTEGER_VALUE().asScala.toList) match {
      case ("boolean", Nil) => BooleanType
      case ("tinyint" | "byte", Nil) => ByteType
      case ("smallint" | "short", Nil) => ShortType
      case ("int" | "integer", Nil) => IntegerType
      case ("bigint" | "long", Nil) => LongType
      case ("float", Nil) => FloatType
      case ("double", Nil) => DoubleType
      case ("date", Nil) => DateType
      case ("timestamp", Nil) => TimestampType
      case ("string", Nil) => StringType
      case ("char", length :: Nil) => CharType(length.getText.toInt)
      case ("varchar", length :: Nil) => VarcharType(length.getText.toInt)
      case ("binary", Nil) => BinaryType
      case ("decimal", Nil) => DecimalType.USER_DEFAULT
      case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
      case ("decimal", precision :: scale :: Nil) =>
        DecimalType(precision.getText.toInt, scale.getText.toInt)
      case ("interval", Nil) => CalendarIntervalType
      case (dt, params) =>
        val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt
        throw new ParseException(s"DataType $dtStr is not supported.", ctx)
    }
  }
}

// scalastyle:off line.size.limit
/**
 * Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
 *
 * @see https://github.com/apache/spark/blob/v2.4.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala#L157
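 *
 * A minimal behavioral sketch (illustrative only): lookahead is upper-cased so the grammar
 * can define keywords in upper case, while `getText` still returns the original casing.
 * {{{
 *   val s = new UpperCaseCharStream(CharStreams.fromString("vacuum"))
 *   s.LA(1) == 'V'.toInt                       // lookahead sees upper case
 *   s.getText(new Interval(0, 5)) == "vacuum"  // underlying text is unchanged
 * }}}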
*/
// scalastyle:on
class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
  override def consume(): Unit = wrapped.consume
  override def getSourceName(): String = wrapped.getSourceName
  override def index(): Int = wrapped.index
  override def mark(): Int = wrapped.mark
  override def release(marker: Int): Unit = wrapped.release(marker)
  override def seek(where: Int): Unit = wrapped.seek(where)
  override def size(): Int = wrapped.size

  override def getText(interval: Interval): String = {
    // ANTLR 4.7's CodePointCharStream implementations have bugs when
    // getText() is called with an empty stream, or intervals where
    // the start > end. See
    // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix
    // that is not yet in a released ANTLR artifact.
    if (size() > 0 && (interval.b - interval.a >= 0)) {
      wrapped.getText(interval)
    } else {
      ""
    }
  }

  override def LA(i: Int): Int = {
    val la = wrapped.LA(i)
    if (la == 0 || la == IntStream.EOF) la
    else Character.toUpperCase(la)
  }
}

// scalastyle:off line.size.limit
/**
 * Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
 *
 * @see https://github.com/apache/spark/blob/v2.4.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala#L248
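 *
 * For example, a quoted identifier with escaped back ticks, such as `my``table`, is collapsed
 * into a single IDENTIFIER token with the text my`table, and a non-reserved keyword used as a
 * name is re-tagged as a plain identifier.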
*/
// scalastyle:on
case object PostProcessor extends DeltaSqlBaseBaseListener {

  /** Remove the back ticks from an Identifier. */
  override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
    replaceTokenByIdentifier(ctx, 1) { token =>
      // Remove the double back ticks in the string.
      token.setText(token.getText.replace("``", "`"))
      token
    }
  }

  /** Treat non-reserved keywords as Identifiers. */
  override def exitNonReserved(ctx: NonReservedContext): Unit = {
    replaceTokenByIdentifier(ctx, 0)(identity)
  }

  private def replaceTokenByIdentifier(ctx: ParserRuleContext, stripMargins: Int)(
      f: CommonToken => CommonToken = identity): Unit = {
    val parent = ctx.getParent
    parent.removeLastChild()
    val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
    val newToken = new CommonToken(
      new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
      DeltaSqlBaseParser.IDENTIFIER,
      token.getChannel,
      token.getStartIndex + stripMargins,
      token.getStopIndex - stripMargins)
    parent.addChild(new TerminalNodeImpl(f(newToken)))
  }
}