All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.spark.sql.jdbc.MySQLDialect.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.jdbc

import java.sql.{Connection, SQLException, Types}
import java.util
import java.util.Locale

import scala.collection.mutable.ArrayBuilder
import scala.util.control.NonFatal

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference, NullOrdering, SortDirection}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._

private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with NoLegacyJDBCError {

  // Accepts any JDBC URL that starts with "jdbc:mysql". The URL is lowercased with the ROOT
  // locale first, so the comparison is case-insensitive and independent of the default locale.
  override def canHandle(url : String): Boolean = {
    val normalizedUrl = url.toLowerCase(Locale.ROOT)
    normalizedUrl.startsWith("jdbc:mysql")
  }

  // Aggregate functions that are supported in general but rejected by this dialect's SQL
  // builder when they are requested with DISTINCT.
  private val distinctUnsupportedAggregateFunctions: Set[String] =
    Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")

  // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html
  private val supportedAggregateFunctions: Set[String] =
    distinctUnsupportedAggregateFunctions ++ Set("MAX", "MIN", "SUM", "COUNT", "AVG")

  // All function names this dialect reports as supported: the aggregates plus two date
  // functions.
  private val supportedFunctions: Set[String] =
    Set("DATE_ADD", "DATE_DIFF") ++ supportedAggregateFunctions

  // Exact (case-sensitive) membership test against the supported function names.
  override def isSupportedFunction(funcName: String): Boolean =
    supportedFunctions(funcName)

  // SQL builder that renders V2 expressions with MySQL-specific syntax where the generic
  // rendering of the parent builder is not used.
  class MySQLSQLBuilder extends JDBCSQLBuilder {
    // Rewrites the EXTRACT fields that get MySQL-specific SQL; every other field is delegated
    // to the parent builder.
    override def visitExtract(field: String, source: String): String = field match {
      case "DAY_OF_YEAR" =>
        s"DAYOFYEAR($source)"
      case "YEAR_OF_WEEK" =>
        s"EXTRACT(YEAR FROM $source)"
      case "DAY_OF_WEEK" =>
        // WEEKDAY counts from Monday = 0 while the ISO standard counts from Monday = 1,
        // hence the "+ 1".
        s"(WEEKDAY($source) + 1)"
      case _ =>
        super.visitExtract(field, source)
    }

    // Renders DATE_ADD and DATE_DIFF in MySQL syntax from their first two inputs; any other
    // function is delegated to the parent builder.
    override def visitSQLFunction(funcName: String, inputs: Array[String]): String =
      funcName match {
        case "DATE_ADD" =>
          val first = inputs(0)
          val second = inputs(1)
          s"DATE_ADD($first, INTERVAL $second DAY)"
        case "DATE_DIFF" =>
          val first = inputs(0)
          val second = inputs(1)
          s"DATEDIFF($first, $second)"
        case _ =>
          super.visitSQLFunction(funcName, inputs)
      }

    // Renders one ORDER BY item. The two combinations that match MySQL's default placement of
    // NULLs are emitted as a plain sort key; the other two are prefixed with a CASE expression
    // that sorts on "is NULL" first.
    override def visitSortOrder(
        sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
      val plainOrder = s"$sortKey $sortDirection"
      (sortDirection, nullOrdering) match {
        case (SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) |
             (SortDirection.DESCENDING, NullOrdering.NULLS_LAST) =>
          plainOrder
        case (SortDirection.ASCENDING, NullOrdering.NULLS_LAST) =>
          s"CASE WHEN $sortKey IS NULL THEN 1 ELSE 0 END, $plainOrder"
        case (SortDirection.DESCENDING, NullOrdering.NULLS_FIRST) =>
          s"CASE WHEN $sortKey IS NULL THEN 0 ELSE 1 END, $plainOrder"
      }
    }

    // The three LIKE-based predicates below drop the first and last character of `r`
    // (presumably the quotes of a string literal -- TODO confirm against the caller), escape
    // the remainder for use in a LIKE pattern, and differ only in where the % wildcard goes.
    override def visitStartsWith(l: String, r: String): String = {
      val inner = r.substring(1, r.length() - 1)
      val pattern = escapeSpecialCharsForLikePattern(inner)
      s"$l LIKE '$pattern%' ESCAPE '\\\\'"
    }

    override def visitEndsWith(l: String, r: String): String = {
      val inner = r.substring(1, r.length() - 1)
      val pattern = escapeSpecialCharsForLikePattern(inner)
      s"$l LIKE '%$pattern' ESCAPE '\\\\'"
    }

    override def visitContains(l: String, r: String): String = {
      val inner = r.substring(1, r.length() - 1)
      val pattern = escapeSpecialCharsForLikePattern(inner)
      s"$l LIKE '%$pattern%' ESCAPE '\\\\'"
    }

    // Fails for DISTINCT over the aggregates listed in distinctUnsupportedAggregateFunctions;
    // every other aggregate call is rendered by the parent builder.
    override def visitAggregateFunction(
        funcName: String, isDistinct: Boolean, inputs: Array[String]): String = {
      val distinctRejected =
        isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)
      if (!distinctRejected) {
        super.visitAggregateFunction(funcName, isDistinct, inputs)
      } else {
        throw new SparkUnsupportedOperationException(
          errorClass = "_LEGACY_ERROR_TEMP_3184",
          messageParameters = Map(
            "class" -> this.getClass.getSimpleName,
            "funcName" -> funcName))
      }
    }
  }

  // Compiles a V2 expression to MySQL SQL text. Any non-fatal failure is logged as a warning
  // and reported as None instead of being propagated.
  override def compileExpression(expr: Expression): Option[String] = {
    val builder = new MySQLSQLBuilder()
    try Some(builder.build(expr))
    catch {
      case NonFatal(cause) =>
        logWarning("Error occurs while compiling V2 expression", cause)
        None
    }
  }

  // Maps a JDBC column type reported by the driver to a Catalyst type for the cases that get
  // MySQL-specific treatment. Returns None when no case applies. The cases are order
  // sensitive: the first matching one wins.
  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    // Reads the "isSigned" flag from the metadata collected so far.
    def signed: Boolean = md.build().getBoolean("isSigned")

    // Marks the column with "binarylong" in its metadata, then picks the Catalyst type for a
    // multi-bit BIT column according to the legacy-mapping configuration.
    def bitArrayType: Option[DataType] = {
      md.putLong("binarylong", 1)
      Some(if (conf.legacyMySqlBitArrayMappingEnabled) LongType else BinaryType)
    }

    sqlType match {
      // MariaDB connector behaviour
      case Types.VARBINARY if "BIT".equalsIgnoreCase(typeName) && size != 1 =>
        bitArrayType
      // MySQL connector behaviour
      case Types.BIT if size > 1 =>
        bitArrayType
      // TINYTEXT is Types.VARCHAR(63) from mysql jdbc, but keep it AS-IS for historical reason
      case Types.VARCHAR if "TINYTEXT".equalsIgnoreCase(typeName) =>
        Some(StringType)
      // scalastyle:off line.size.limit
      // Some MySQL JDBC drivers convert JSON type into Types.VARCHAR(-1) or Types.CHAR(Int.Max).
      // MySQL Connector/J 5.x as an example:
      // https://github.com/mysql/mysql-connector-j/blob/release/5.1/src/com/mysql/jdbc/MysqlDefs.java#L295
      // Explicitly converts it into StringType here.
      // scalastyle:on line.size.limit
      case Types.VARCHAR | Types.CHAR if "JSON".equalsIgnoreCase(typeName) =>
        Some(StringType)
      // For the integral and floating-point cases below, the unsigned variant is mapped to
      // the next wider Catalyst type.
      case Types.TINYINT =>
        Some(if (signed) ByteType else ShortType)
      case Types.SMALLINT =>
        Some(if (signed) ShortType else IntegerType)
      // Signed values in [-8388608, 8388607] and unsigned values in [0, 16777215],
      // both of them fit IntegerType
      case Types.INTEGER if "MEDIUMINT UNSIGNED".equalsIgnoreCase(typeName) =>
        Some(IntegerType)
      case Types.REAL | Types.FLOAT =>
        Some(if (signed) FloatType else DoubleType)
      // scalastyle:off line.size.limit
      // In MYSQL, DATETIME is TIMESTAMP WITHOUT TIME ZONE
      // https://github.com/mysql/mysql-connector-j/blob/8.3.0/src/main/core-api/java/com/mysql/cj/MysqlType.java#L251
      // scalastyle:on line.size.limit
      case Types.TIMESTAMP if "DATETIME".equalsIgnoreCase(typeName) =>
        Some(getTimestampType(md.build()))
      case Types.TIMESTAMP if !conf.legacyMySqlTimestampNTZMappingEnabled =>
        Some(TimestampType)
      case _ =>
        None
    }
  }

  // Wraps an identifier in backticks, the MySQL identifier quote character. A backtick that
  // is part of the identifier itself is doubled, which is how MySQL expects the quote
  // character to be written inside a quoted identifier; without this the embedded backtick
  // would end the quoting early and yield malformed SQL.
  override def quoteIdentifier(colName: String): String = {
    val escaped = colName.replace("`", "``")
    s"`$escaped`"
  }

  // Reports whether `schema` is among the schemas returned by listSchemas. The comparison is
  // an exact string match against the first element of each returned entry.
  override def schemasExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = {
    val knownSchemas = listSchemas(conn, options)
    knownSchemas.exists(entry => entry.head == schema)
  }

  // Lists the schemas visible on the connection by running SHOW SCHEMAS. Each result entry
  // is a one-element array holding the value of the "Database" column. This is best effort:
  // a non-fatal failure is logged together with its cause, and whatever was collected before
  // the failure (possibly nothing) is returned.
  override def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = {
    val schemaBuilder = ArrayBuilder.make[Array[String]]
    try {
      JdbcUtils.executeQuery(conn, options, "SHOW SCHEMAS") { rs =>
        while (rs.next()) {
          schemaBuilder += Array(rs.getString("Database"))
        }
      }
    } catch {
      // NonFatal lets fatal errors and thread interruption propagate instead of being hidden.
      case NonFatal(e) =>
        // Keep the cause in the log so the reason for the failure can be diagnosed.
        logWarning("Cannot show schemas.", e)
    }
    schemaBuilder.result()
  }

  // Reports TRUNCATE TABLE as non-cascading for this dialect.
  override def isCascadingTruncateTable(): Option[Boolean] = {
    val cascades = false
    Some(cascades)
  }

  // See https://dev.mysql.com/doc/refman/8.0/en/alter-table.html
  // Builds the ALTER TABLE ... MODIFY COLUMN statement that changes a column's data type.
  // Only the column name is quoted here; the table name and the new type are used verbatim.
  override def getUpdateColumnTypeQuery(
      tableName: String,
      columnName: String,
      newDataType: String): String = {
    val quotedColumn = quoteIdentifier(columnName)
    s"ALTER TABLE $tableName MODIFY COLUMN $quotedColumn $newDataType"
  }

  // See Old Syntax: https://dev.mysql.com/doc/refman/5.6/en/alter-table.html
  // According to https://dev.mysql.com/worklog/task/?id=10761 old syntax works for
  // both versions of MySQL i.e. 5.x and 8.0
  // The old syntax requires us to have type definition. Since we do not have type
  // information, we throw the exception for old version.
  // Builds the RENAME COLUMN statement for database major version 8 and above, and throws
  // the "rename column unsupported for older MySQL" error for anything lower.
  override def getRenameColumnQuery(
      tableName: String,
      columnName: String,
      newName: String,
      dbMajorVersion: Int): String = {
    if (dbMajorVersion < 8) {
      throw QueryExecutionErrors.renameColumnUnsupportedForOlderMySQLError()
    }
    val quotedOldName = quoteIdentifier(columnName)
    val quotedNewName = quoteIdentifier(newName)
    s"ALTER TABLE $tableName RENAME COLUMN $quotedOldName TO $quotedNewName"
  }

  // See https://dev.mysql.com/doc/refman/8.0/en/alter-table.html
  // MySQL requires the column data type in order to change a column's nullability:
  // ALTER TABLE tbl_name MODIFY [COLUMN] col_name column_definition
  // column_definition:
  //    data_type [NOT NULL | NULL]
  // e.g. ALTER TABLE t1 MODIFY b INT NOT NULL;
  // We don't have column data type here, so throw Exception for now
  // Always throws: changing nullability is reported as unsupported by this dialect, so no
  // statement is ever returned.
  override def getUpdateColumnNullabilityQuery(
      tableName: String,
      columnName: String,
      isNullable: Boolean): String =
    throw QueryExecutionErrors.unsupportedUpdateColumnNullabilityError()

  // See https://dev.mysql.com/doc/refman/8.0/en/alter-table.html
  // Builds the ALTER TABLE statement that sets the table comment.
  override def getTableCommentQuery(table: String, comment: String): String = {
    // NOTE(review): `comment` is placed inside a single-quoted SQL literal as-is; confirm
    // that callers escape any single quote it may contain.
    val commentLiteral = s"'$comment'"
    s"ALTER TABLE $table COMMENT = $commentLiteral"
  }

  // Maps a Catalyst type to the JDBC type used for it in MySQL. The MySQL-specific mappings
  // are tried first; anything else falls back to the common JDBC mapping.
  override def getJDBCType(dt: DataType): Option[JdbcType] = {
    val mysqlSpecific: Option[JdbcType] = dt match {
      // See SPARK-35446: MySQL treats REAL as a synonym to DOUBLE by default
      // We override getJDBCType so that FloatType is mapped to FLOAT instead
      case FloatType => Some(JdbcType("FLOAT", Types.FLOAT))
      case StringType => Some(JdbcType("LONGTEXT", Types.LONGVARCHAR))
      case ByteType => Some(JdbcType("TINYINT", Types.TINYINT))
      case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
      // scalastyle:off line.size.limit
      // In MYSQL, DATETIME is TIMESTAMP WITHOUT TIME ZONE
      // https://github.com/mysql/mysql-connector-j/blob/8.3.0/src/main/core-api/java/com/mysql/cj/MysqlType.java#L251
      // scalastyle:on line.size.limit
      case TimestampNTZType if !conf.legacyMySqlTimestampNTZMappingEnabled =>
        Some(JdbcType("DATETIME", Types.TIMESTAMP))
      case _ => None
    }
    // orElse takes its argument by name, so the common mapping is only consulted when no
    // MySQL-specific mapping applied.
    mysqlSpecific.orElse(JdbcUtils.getCommonJDBCType(dt))
  }

  // Always throws: commenting on a schema (namespace) is reported as unsupported by this
  // dialect.
  override def getSchemaCommentQuery(schema: String, comment: String): String =
    throw QueryExecutionErrors.unsupportedCommentNamespaceError(schema)

  // Always throws: removing a schema (namespace) comment is reported as unsupported by this
  // dialect.
  override def removeSchemaCommentQuery(schema: String): String =
    throw QueryExecutionErrors.unsupportedRemoveNamespaceCommentError(schema)

  // CREATE INDEX syntax
  // https://dev.mysql.com/doc/refman/8.0/en/create-index.html
  // Builds the CREATE INDEX statement. Index, table and column names are quoted; only the
  // first field name of each column reference is used. The index type and the remaining
  // index properties come from JdbcUtils.processIndexProperties.
  override def createIndex(
      indexName: String,
      tableIdent: Identifier,
      columns: Array[NamedReference],
      columnsProperties: util.Map[NamedReference, util.Map[String, String]],
      properties: util.Map[String, String]): String = {
    // columnsProperties doesn't apply to MySQL so it is ignored
    val quotedColumns = columns.map(col => quoteIdentifier(col.fieldNames.head)).mkString(", ")
    val (indexType, indexPropertyList) = JdbcUtils.processIndexProperties(properties, "mysql")
    val quotedIndex = quoteIdentifier(indexName)
    val quotedTable = quoteIdentifier(tableIdent.name())
    val trailingProperties = indexPropertyList.mkString(" ")
    s"CREATE INDEX $quotedIndex $indexType ON $quotedTable ($quotedColumns) $trailingProperties"
  }

  // SHOW INDEX syntax
  // https://dev.mysql.com/doc/refman/8.0/en/show-index.html
  // Checks whether an index with the given name exists on the table by filtering the output
  // of SHOW INDEXES on key_name.
  override def indexExists(
      conn: Connection,
      indexName: String,
      tableIdent: Identifier,
      options: JDBCOptions): Boolean = {
    // The index name is embedded in a single-quoted SQL string literal, so a single quote in
    // the name is doubled to keep the literal (and therefore the statement) well-formed.
    val indexNameLiteral = indexName.replace("'", "''")
    val sql = s"SHOW INDEXES FROM ${quoteIdentifier(tableIdent.name())} " +
      s"WHERE key_name = '$indexNameLiteral'"
    JdbcUtils.checkIfIndexExists(conn, sql, options)
  }

  // Builds the DROP INDEX statement. Both the index name and the table name are quoted, the
  // same way createIndex and indexExists quote the table name, so that table names needing
  // quoting (reserved words, special characters) work for dropping as well.
  override def dropIndex(indexName: String, tableIdent: Identifier): String = {
    val quotedIndex = quoteIdentifier(indexName)
    val quotedTable = quoteIdentifier(tableIdent.name())
    s"DROP INDEX $quotedIndex ON $quotedTable"
  }

  // SHOW INDEX syntax
  // https://dev.mysql.com/doc/refman/8.0/en/show-index.html
  // Lists the indexes of a table from the output of SHOW INDEXES. That output has one row
  // per indexed column, so rows sharing a key_name are merged into a single TableIndex whose
  // column list grows row by row. This is best effort: a non-fatal failure is logged with its
  // cause, and the indexes collected before the failure (possibly none) are returned.
  override def listIndexes(
      conn: Connection,
      tableIdent: Identifier,
      options: JDBCOptions): Array[TableIndex] = {
    // Quote the table name the same way createIndex and indexExists do, so that names needing
    // quoting are handled.
    val sql = s"SHOW INDEXES FROM ${quoteIdentifier(tableIdent.name())}"
    var indexMap: Map[String, TableIndex] = Map()
    try {
      JdbcUtils.executeQuery(conn, options, sql) { rs =>
        while (rs.next()) {
          val indexName = rs.getString("key_name")
          val colName = rs.getString("column_name")
          val indexType = rs.getString("index_type")
          val indexComment = rs.getString("index_comment")
          val updatedIndex = indexMap.get(indexName) match {
            case Some(existing) =>
              // Another column of an index seen before: append it to the column list.
              new TableIndex(indexName, indexType,
                existing.columns() :+ FieldReference(colName),
                existing.columnProperties, existing.properties)
            case None =>
              // The only property we are building here is `COMMENT` because it's the only one
              // we can get from `SHOW INDEXES`.
              val properties = new util.Properties()
              // getString returns null for SQL NULL; treat that like an empty comment rather
              // than failing on it.
              Option(indexComment).filter(_.nonEmpty)
                .foreach(text => properties.put("COMMENT", text))
              new TableIndex(indexName, indexType, Array(FieldReference(colName)),
                new util.HashMap[NamedReference, util.Properties](), properties)
          }
          indexMap += (indexName -> updatedIndex)
        }
      }
    } catch {
      // NonFatal lets fatal errors and thread interruption propagate instead of being hidden.
      case NonFatal(e) =>
        logWarning("Cannot retrieved index info.", e)
    }
    indexMap.values.toArray
  }

  // Translates selected MySQL error codes into more specific exceptions, depending on the
  // operation identified by `errorClass`. Everything else is handed to the parent
  // implementation, except UnsupportedOperationException, which is rethrown unchanged.
  override def classifyException(
      e: Throwable,
      errorClass: String,
      messageParameters: Map[String, String],
      description: String): AnalysisException = {
    e match {
      case sqlException: SQLException =>
        (sqlException.getErrorCode, errorClass) match {
          // 1050: ER_TABLE_EXISTS_ERROR
          case (1050, "FAILED_JDBC.RENAME_TABLE") =>
            throw QueryCompilationErrors.tableAlreadyExistsError(messageParameters("newName"))
          // 1061: ER_DUP_KEYNAME
          case (1061, "FAILED_JDBC.CREATE_INDEX") =>
            throw new IndexAlreadyExistsException(
              messageParameters("indexName"), messageParameters("tableName"), cause = Some(e))
          // 1091: ER_CANT_DROP_FIELD_OR_KEY
          case (1091, "FAILED_JDBC.DROP_INDEX") =>
            throw new NoSuchIndexException(
              messageParameters("indexName"), messageParameters("tableName"), cause = Some(e))
          case _ =>
            super.classifyException(e, errorClass, messageParameters, description)
        }
      case unsupported: UnsupportedOperationException =>
        throw unsupported
      case _ =>
        super.classifyException(e, errorClass, messageParameters, description)
    }
  }

  // Builds the DROP SCHEMA statement. Only the cascading form is available in this dialect;
  // a non-cascading drop is rejected with the "unsupported drop namespace" error.
  override def dropSchema(schema: String, cascade: Boolean): String = {
    if (!cascade) {
      throw QueryExecutionErrors.unsupportedDropNamespaceError(schema)
    }
    s"DROP SCHEMA ${quoteIdentifier(schema)}"
  }

  // Query builder that emits LIMIT / OFFSET in MySQL's "LIMIT offset, row_count" form.
  class MySQLSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions)
    extends JdbcSQLQueryBuilder(dialect, options) {

    // Assembles the final SELECT statement from the clauses collected by the builder.
    override def build(): String = {
      val hasLimit = limit > 0
      val hasOffset = offset > 0
      val limitOrOffsetStmt = (hasLimit, hasOffset) match {
        case (true, true) =>
          s"LIMIT $offset, $limit"
        case (true, false) =>
          dialect.getLimitClause(limit)
        case (false, true) =>
          // MySQL doesn't support OFFSET without LIMIT. According to the suggestion of MySQL
          // official website, in order to retrieve all rows from a certain offset up to the
          // end of the result set, you can use some large number for the second parameter.
          // Please refer: https://dev.mysql.com/doc/refman/8.0/en/select.html
          s"LIMIT $offset, 18446744073709551615"
        case (false, false) =>
          ""
      }

      val selectStmt =
        s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
          s" $whereClause $groupByClause $orderByClause $limitOrOffsetStmt"
      options.prepareQuery + selectStmt
    }
  }

  // Supplies the MySQL-specific query builder, bound to this dialect and the given options.
  override def getJdbcSQLQueryBuilder(options: JDBCOptions): JdbcSQLQueryBuilder = {
    val builder = new MySQLSQLQueryBuilder(this, options)
    builder
  }

  // Declares that this dialect supports LIMIT; MySQLSQLQueryBuilder.build emits the clause.
  override def supportsLimit: Boolean = true

  // Declares that this dialect supports OFFSET; MySQLSQLQueryBuilder.build renders it through
  // the "LIMIT offset, row_count" form.
  override def supportsOffset: Boolean = true
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy