All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.hive.thriftserver.CalliopeSQLOperationManager.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hive.thriftserver

import java.sql.{DatabaseMetaData, Timestamp}
import java.util
import java.util.{Map => JMap}

import com.datastax.driver.core.KeyspaceMetadata
import com.tuplejump.calliope.server.ReflectionUtils
import com.tuplejump.calliope.sql.{CassandraProperties, CassandraSchemaHelper}
import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hadoop.hive.metastore.TableType
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation._
import org.apache.hive.service.cli.session.HiveSession
import org.apache.hive.service.cli.thrift.TColumnValue
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.plans.logical.SetCommand
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow}

import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, Map}
import scala.math.{random, round}
import scala.util.matching.Regex

/**
 * Executes queries using Spark SQL, and maintains a list of handles to active queries.
 */
class CalliopeSQLOperationManager(hiveContext: HiveContext)
  extends OperationManager with Logging {

  val handleToOperation = ReflectionUtils
    .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")

  // TODO: Currently this will grow indefinitely, even as sessions expire
  val sessionToActivePool = Map[HiveSession, String]()

  override def newExecuteStatementOperation(
      parentSession: HiveSession,
      statement: String,
      confOverlay: JMap[String, String],
      async: Boolean): ExecuteStatementOperation = synchronized {

    val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) {
      private var result: SchemaRDD = _
      private var iter: Iterator[SparkRow] = _
      private var dataTypes: Array[DataType] = _

      /**
       * Closes this operation. Nothing is released explicitly; only a debug message is logged,
       * since RDDs are cleaned up automatically upon garbage collection.
       */
      def close(): Unit = logDebug("CLOSING")

      /**
       * Returns the next batch of result rows, converted from Spark rows to Hive `Row`s.
       *
       * @param order    requested fetch orientation; it is not consulted, rows are always taken
       *                 from the result iterator in order
       * @param maxRowsL upper bound on the number of rows placed in the returned batch
       * @return a `RowSet` holding at most `maxRowsL` rows, or an empty `RowSet` when the result
       *         iterator has no more rows
       */
      def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = {
        if (!iter.hasNext) {
          new RowSet()
        } else {
          // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int.
          // Clamp into [0, Int.MaxValue] before narrowing: a plain toInt wraps around for
          // out-of-range values and could yield an empty batch while rows are still pending.
          val maxRows = math.max(0L, math.min(maxRowsL, Int.MaxValue.toLong)).toInt
          val rowSet = new ArrayBuffer[Row](maxRows.min(1024))
          var curRow = 0

          while (curRow < maxRows && iter.hasNext) {
            val sparkRow = iter.next()
            val row = new Row()
            var curCol = 0

            // Convert column by column; nulls need a typed null value of their own.
            while (curCol < sparkRow.length) {
              if (sparkRow.isNullAt(curCol)) {
                addNullColumnValue(sparkRow, row, curCol)
              } else {
                addNonNullColumnValue(sparkRow, row, curCol)
              }
              curCol += 1
            }
            rowSet += row
            curRow += 1
          }
          new RowSet(rowSet, 0)
        }
      }

      /**
       * Appends the non-null value found at `ordinal` in `from` to the Hive row `to`, choosing
       * the Hive column value constructor from the Catalyst type recorded in `dataTypes`.
       *
       * Strings are added through `Row.addString`; decimals are sent as `HiveDecimal`; binary,
       * array, struct and map values are rendered with the query execution's `toHiveString`.
       *
       * @param from    Spark row to read from
       * @param to      Hive row to append to
       * @param ordinal column index into both `from` and `dataTypes`
       */
      def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int): Unit = {
        val columnType = dataTypes(ordinal)
        columnType match {
          case StringType =>
            to.addString(from(ordinal).asInstanceOf[String])
          case _ =>
            // Every remaining type goes through addColumnValue, so build the value first.
            val converted = columnType match {
              case IntegerType =>
                ColumnValue.intValue(from.getInt(ordinal))
              case BooleanType =>
                ColumnValue.booleanValue(from.getBoolean(ordinal))
              case DoubleType =>
                ColumnValue.doubleValue(from.getDouble(ordinal))
              case FloatType =>
                ColumnValue.floatValue(from.getFloat(ordinal))
              case DecimalType =>
                val javaDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
                ColumnValue.stringValue(new HiveDecimal(javaDecimal))
              case LongType =>
                ColumnValue.longValue(from.getLong(ordinal))
              case ByteType =>
                ColumnValue.byteValue(from.getByte(ordinal))
              case ShortType =>
                ColumnValue.shortValue(from.getShort(ordinal))
              case TimestampType =>
                ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])
              case BinaryType | _: ArrayType | _: StructType | _: MapType =>
                val execution = result.queryExecution.asInstanceOf[HiveContext#QueryExecution]
                val rendered = execution.toHiveString((from.get(ordinal), columnType))
                ColumnValue.stringValue(rendered)
            }
            to.addColumnValue(converted)
        }
      }

      /**
       * Appends a typed null for the column at `ordinal` to the Hive row `to`. The null is
       * built with the same `ColumnValue` constructor that `addNonNullColumnValue` uses for the
       * column's Catalyst type.
       *
       * @param from    Spark row being converted (not read here)
       * @param to      Hive row to append to
       * @param ordinal column index into `dataTypes`
       */
      def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) {
        dataTypes(ordinal) match {
          case StringType =>
            to.addString(null)
          case IntegerType =>
            to.addColumnValue(ColumnValue.intValue(null))
          case BooleanType =>
            to.addColumnValue(ColumnValue.booleanValue(null))
          case DoubleType =>
            to.addColumnValue(ColumnValue.doubleValue(null))
          case FloatType =>
            to.addColumnValue(ColumnValue.floatValue(null))
          case DecimalType =>
            to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal))
          case LongType =>
            to.addColumnValue(ColumnValue.longValue(null))
          case ByteType =>
            to.addColumnValue(ColumnValue.byteValue(null))
          case ShortType =>
            // Was intValue(null): use the short constructor so null and non-null shorts agree.
            to.addColumnValue(ColumnValue.shortValue(null))
          case TimestampType =>
            to.addColumnValue(ColumnValue.timestampValue(null))
          case BinaryType | _: ArrayType | _: StructType | _: MapType =>
            to.addColumnValue(ColumnValue.stringValue(null: String))
        }
      }

      /**
       * Builds the Hive table schema for the current result from the analyzed plan's output
       * attributes. When the plan produces no attributes, a single string column named
       * "Result" is reported instead.
       */
      def getResultSetSchema: TableSchema = {
        val outputAttributes = result.queryExecution.analyzed.output
        logInfo(s"Result Schema: $outputAttributes")
        if (outputAttributes.nonEmpty) {
          val fields = outputAttributes.map { attribute =>
            val hiveTypeName = HiveMetastoreTypes.toMetastoreType(attribute.dataType)
            new FieldSchema(attribute.name, hiveTypeName, "")
          }
          new TableSchema(fields)
        } else {
          new TableSchema(new FieldSchema("Result", "string", "") :: Nil)
        }
      }

      /**
       * Executes the statement through the HiveContext and prepares `iter` and `dataTypes` for
       * subsequent `getNextRowSet` calls.
       *
       * The operation state moves to RUNNING, then to FINISHED on success or ERROR on failure.
       * A `SET` of the thrift server pool key is remembered per session and applied as the
       * `spark.scheduler.pool` local property.
       *
       * @throws HiveSQLException carrying the string form of any failure raised while running
       */
      def run(): Unit = {
        logInfo(s"Running query '$statement'")
        setState(OperationState.RUNNING)
        try {
          result = hiveContext.sql(statement)
          logDebug(result.queryExecution.toString())
          result.queryExecution.logical match {
            case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) =>
              sessionToActivePool(parentSession) = value
              logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.")
            case _ =>
          }

          val groupId = round(random * 1000000).toString
          hiveContext.sparkContext.setJobGroup(groupId, statement)
          sessionToActivePool.get(parentSession).foreach { pool =>
            hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool)
          }
          iter = {
            val resultRdd = result.queryExecution.toRdd
            val useIncrementalCollect =
              hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
            // Incremental collect iterates the RDD locally instead of collecting all rows
            // into one array first.
            if (useIncrementalCollect) {
              resultRdd.toLocalIterator
            } else {
              resultRdd.collect().iterator
            }
          }
          dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray
          setHasResultSet(true)
        } catch {
          // Actually do need to catch Throwable as some failures don't inherit from Exception and
          // HiveServer will silently swallow them.
          case e: Throwable =>
            logError("Error executing query:", e)
            // Record the failure; otherwise the operation would stay in RUNNING after an error.
            setState(OperationState.ERROR)
            throw new HiveSQLException(e.toString)
        }
        setState(OperationState.FINISHED)
      }
    }

   handleToOperation.put(operation.getHandle, operation)
   operation
  }

  /**
   * Creates and registers a GetCatalogs operation that runs Hive's implementation and then
   * prints the rows it produced to standard output.
   *
   * @param parentSession session the operation belongs to
   * @return the operation, already stored in `handleToOperation` under its handle
   */
  override def newGetCatalogsOperation(parentSession: HiveSession): GetCatalogsOperation = {
    val catalogsOperation: GetCatalogsOperation = new GetCatalogsOperation(parentSession) {
      override def run(): Unit = {
        println("Running GetCatalogsOperation")
        super.run()
        // The row set and its rows are not exposed by the parent class; read them reflectively.
        val hiveRowSet = ReflectionUtils.getSuperField[RowSet](this, "rowSet")
        val hiveRows = ReflectionUtils.getPrivateField[java.util.ArrayList[Row]](hiveRowSet, "rows")
        println(hiveRows)
      }
    }

    handleToOperation.put(catalogsOperation.getHandle, catalogsOperation)
    catalogsOperation
  }

  /**
   * Creates and registers a GetSchemas operation that runs Hive's implementation and then
   * appends every Cassandra keyspace whose name matches the requested schema pattern.
   *
   * A null or empty `schemaName` matches every keyspace. Otherwise the pattern is converted
   * with `convertPattern` and matched case-insensitively; `|` separates alternatives.
   *
   * @param parentSession session the operation belongs to
   * @param catalogName   catalog name passed through to Hive's operation
   * @param schemaName    schema name pattern; may be null or empty
   * @return the operation, already stored in `handleToOperation` under its handle
   */
  override def newGetSchemasOperation(
      parentSession: HiveSession,
      catalogName: String,
      schemaName: String): GetSchemasOperation = {
    val operation: GetSchemasOperation = new GetSchemasOperation(parentSession, catalogName, schemaName){
      val schemaNameI = schemaName
      private val RESULT_SET_SCHEMA = ReflectionUtils.getSuperField[TableSchema](this, "RESULT_SET_SCHEMA")

      override def run(): Unit = {
        println("Running GetSchemasOperation")
        super.run()
        val rowSet = ReflectionUtils.getSuperField[RowSet](this, "rowSet")
        val rows = ReflectionUtils.getPrivateField[util.ArrayList[Row]](rowSet, "rows")
        printRows(rows)
        // Guard against a null pattern before inspecting it.
        val patternString: String = if (schemaNameI == null || schemaNameI.isEmpty) {
          "(.*)"
        } else {
          val subpatterns: Array[String] = convertPattern(schemaNameI).trim.split("\\|")
          // Join with the regex alternation operator. The previous " || " separator embedded
          // literal spaces and empty alternatives, so multi-part patterns never matched.
          val alternatives = subpatterns.mkString("|")
          s"(?i)($alternatives)"
        }

        println(s"PATTERN: $patternString")

        val pattern = patternString.r

        val cassandraProperties = CassandraProperties(hiveContext.sparkContext)

        val cassandraMeta = CassandraSchemaHelper.getCassandraMetadata(cassandraProperties.cassandraHost,
          cassandraProperties.cassandraNativePort,
          cassandraProperties.cassandraUsername,
          cassandraProperties.cassandraPassword)

        if (cassandraMeta != null) {
          cassandraMeta.getKeyspaces.map(_.getName).foreach {
            case pattern(dbname) =>
              println(s"Adding $dbname")
              rowSet.addRow(RESULT_SET_SCHEMA, Array(dbname, ""))
            case _ => // Keyspace does not match the requested pattern.
          }
        }

        ReflectionUtils.setSuperField(this, "rowSet", rowSet)
      }
    }
    handleToOperation.put(operation.getHandle, operation)
    operation
  }

  /**
   * Prints the given rows to standard output, one per line, between "[[" and "]]" markers.
   *
   * @param rows rows to print
   */
  def printRows(rows: util.ArrayList[Row]): Unit = {
    println("[[")
    val cursor = rows.iterator()
    while (cursor.hasNext) {
      println(cursor.next())
    }
    println("]]")
  }

  /**
   * Rewrites the `%` wildcard of a search pattern into the regular expression `.*`.
   *
   * An unescaped `%` becomes `.*`, while a backslash-escaped `\%` becomes a literal `%`.
   *
   * @param pattern search pattern using `%` as the wildcard
   * @return the pattern with its wildcards rewritten
   */
  private def convertPattern(pattern: String): String = {
    val wildcard = ".*"
    // 1. A '%' preceded by any character other than a backslash is a wildcard.
    val innerExpanded = pattern.replaceAll("([^\\\\])%", "$1" + wildcard)
    // 2. An escaped '%' stands for the literal character.
    val escapesResolved = innerExpanded.replaceAll("\\\\%", "%")
    // 3. A '%' at the very start has no preceding character, so step 1 cannot see it.
    escapesResolved.replaceAll("^%", wildcard)
  }

  /**
   * Creates and registers a GetTables operation that runs Hive's implementation and then
   * appends the Cassandra tables whose keyspace and table names match the requested patterns.
   *
   * Keyspaces whose name starts with "system" are skipped. Cassandra tables are reported as
   * "keyspace.table" with table type "CASSANDRA_TABLES". A null or empty pattern matches
   * everything; otherwise matching is case-insensitive and `|` separates alternatives.
   *
   * @param parentSession session the operation belongs to
   * @param catalogName   catalog name passed through to Hive's operation
   * @param schemaName    schema (keyspace) name pattern; may be null or empty
   * @param tableName     table name pattern; may be null or empty
   * @param tableTypes    table types passed through to Hive's operation
   * @return the operation, already stored in `handleToOperation` under its handle
   */
  override def newGetTablesOperation(
      parentSession: HiveSession,
      catalogName: String,
      schemaName: String,
      tableName: String,
      tableTypes: util.List[String]): MetadataOperation = {
    val operation: MetadataOperation = new GetTablesOperation(parentSession, catalogName, schemaName, tableName, tableTypes) {
      private val RESULT_SET_SCHEMA = ReflectionUtils.getSuperField[TableSchema](this, "RESULT_SET_SCHEMA")

      val schemaNameI = schemaName
      val tableNameI = tableName

      override def run(): Unit = {
        println("Running GetTablesOperation")
        super.run()
        val rowSet = ReflectionUtils.getSuperField[RowSet](this, "rowSet")
        val rows = ReflectionUtils.getPrivateField[java.util.ArrayList[Row]](rowSet, "rows")
        printRows(rows)

        // Turns a name pattern into a case-insensitive regex. Alternatives are joined with
        // the regex "|" operator; the previous " || " separator embedded literal spaces and
        // empty alternatives, so multi-part patterns never matched.
        def toRegex(namePattern: String): String = {
          if (namePattern == null || namePattern.isEmpty) {
            "(.*)"
          } else {
            val alternatives = convertPattern(namePattern).trim.split("\\|").mkString("|")
            s"(?i)($alternatives)"
          }
        }

        val schemaPattern = toRegex(schemaNameI)

        val cassandraProperties = CassandraProperties(hiveContext.sparkContext)

        val cassandraMeta = CassandraSchemaHelper.getCassandraMetadata(cassandraProperties.cassandraHost,
          cassandraProperties.cassandraNativePort,
          cassandraProperties.cassandraUsername,
          cassandraProperties.cassandraPassword)

        if (cassandraMeta != null) {
          // Built from the table name; it was previously derived from the schema name by mistake.
          val tablePattern = toRegex(tableNameI)

          cassandraMeta.getKeyspaces.filter(k => (!k.getName.startsWith("system")) && k.getName.matches(schemaPattern)).foreach {
            ks =>
              ks.getTables.filter(_.getName.matches(tablePattern)).foreach {
                tbl =>
                  println(s"Adding table: ${tbl.getName}")
                  val rowData: Array[AnyRef] = Array[AnyRef](
                    "",
                    "",
                    s"${ks.getName}.${tbl.getName}",
                    "CASSANDRA_TABLES",
                    tbl.getOptions.getComment)

                  rowSet.addRow(RESULT_SET_SCHEMA, rowData)
              }
          }

          ReflectionUtils.setSuperField(this, "rowSet", rowSet)
        }
      }
    }
    handleToOperation.put(operation.getHandle, operation)
    operation
  }

  /**
   * Creates and registers a GetTableTypes operation that runs Hive's implementation and then
   * appends the extra table type "CASSANDRA_TABLES" to its result.
   *
   * @param parentSession session the operation belongs to
   * @return the operation, already stored in `handleToOperation` under its handle
   */
  override def newGetTableTypesOperation(parentSession: HiveSession): GetTableTypesOperation = {
    val tableTypesOperation: GetTableTypesOperation = new GetTableTypesOperation(parentSession) {
      private val RESULT_SET_SCHEMA =
        ReflectionUtils.getSuperField[TableSchema](this, "RESULT_SET_SCHEMA")

      override def run(): Unit = {
        println("Running GetTableTypesOperation")
        super.run()

        // Dump what Hive produced before adding to it.
        val hiveRowSet = ReflectionUtils.getSuperField[RowSet](this, "rowSet")
        val hiveRows = ReflectionUtils.getPrivateField[java.util.ArrayList[Row]](hiveRowSet, "rows")
        printRows(hiveRows)

        hiveRowSet.addRow(RESULT_SET_SCHEMA, Array("CASSANDRA_TABLES"))
        ReflectionUtils.setSuperField(this, "rowSet", hiveRowSet)
      }
    }

    handleToOperation.put(tableTypesOperation.getHandle, tableTypesOperation)
    tableTypesOperation
  }

  /**
   * Creates and registers a GetColumns operation that answers from Cassandra metadata when
   * the request resolves to a Cassandra keyspace, and runs Hive's implementation otherwise.
   *
   * The keyspace is `schemaName` unless that is null, empty or "*", in which case it is the
   * part of `tableName` before the first dot. If the keyspace exists but the table does not,
   * the result is empty. A null `tableName` is handled by Hive's implementation.
   *
   * @param parentSession session the operation belongs to
   * @param catalogName   catalog name passed through to Hive's operation
   * @param schemaName    schema (keyspace) name; may be null, empty or "*"
   * @param tableName     table name, optionally qualified as "keyspace.table"; may be null
   * @param columnName    column name passed through to Hive's operation
   * @return the operation, already stored in `handleToOperation` under its handle
   */
  override def newGetColumnsOperation(
      parentSession: HiveSession,
      catalogName: String,
      schemaName: String,
      tableName: String,
      columnName: String): GetColumnsOperation = {
    val operation: GetColumnsOperation = new GetColumnsOperation(parentSession, catalogName, schemaName, tableName, columnName){
      private val RESULT_SET_SCHEMA = ReflectionUtils.getSuperField[TableSchema](this, "RESULT_SET_SCHEMA")

      override def run(): Unit = {
        println(s"Running GetColumnsOperation $schemaName : $tableName")

        // Dot-separated segments of the table name; empty when no table name was given.
        val tableParts: Array[String] =
          if (tableName == null) Array.empty[String] else tableName.split("\\.")

        val keyspace: String =
          if (schemaName == null || schemaName.isEmpty || schemaName.equals("*")) {
            if (tableParts.nonEmpty) tableParts(0) else null
          } else {
            schemaName
          }
        println(s"Checking keyspace: $keyspace")

        val cassandraProperties = CassandraProperties(hiveContext.sparkContext)

        val cassandraMeta = CassandraSchemaHelper.getCassandraMetadata(cassandraProperties.cassandraHost,
          cassandraProperties.cassandraNativePort,
          cassandraProperties.cassandraUsername,
          cassandraProperties.cassandraPassword)

        // Only consult Cassandra when there is a usable keyspace and table name.
        val canUseCassandra =
          cassandraMeta != null && keyspace != null && !keyspace.isEmpty && tableName != null
        val keyspaceMeta: KeyspaceMetadata =
          if (canUseCassandra) cassandraMeta.getKeyspace(keyspace) else null
        println(keyspaceMeta)
        if (keyspaceMeta != null) {
          val rowSet = new RowSet()
          // Unqualified names are used as-is; "ks." style names have no table part at all.
          val tableStr =
            if (!tableName.contains(".")) tableName
            else if (tableParts.length > 1) tableParts(1)
            else ""
          setState(OperationState.RUNNING)
          // A missing table yields an empty result instead of a NullPointerException.
          val table = if (tableStr.isEmpty) None else Option(keyspaceMeta.getTable(tableStr))

          table.foreach { tableMeta =>
            tableMeta.getColumns.zipWithIndex.foreach {
              case (column, idx) =>
                val rowData: Array[AnyRef] = Array(
                  null, // TABLE_CAT
                  "", // TABLE_SCHEM
                  s"$keyspace.$tableStr", // TABLE_NAME
                  column.getName, // COLUMN_NAME
                  toJavaSqlType(column.getType), // DATA_TYPE
                  column.getType.getName.toString, // TYPE_NAME
                  getTypeSize(column.getType), // COLUMN_SIZE
                  null, // BUFFER_LENGTH
                  getDecimalDigits(column.getType), // DECIMAL_DIGITS
                  getNumPrecRadix(column.getType), // NUM_PREC_RADIX
                  Integer.valueOf(DatabaseMetaData.columnNullable), // NULLABLE
                  "", // REMARKS
                  null, // COLUMN_DEF
                  null, // SQL_DATA_TYPE
                  null, // SQL_DATETIME_SUB
                  null, // CHAR_OCTET_LENGTH
                  Integer.valueOf(idx + 1), // ORDINAL_POSITION, 1-based per JDBC
                  "YES", // IS_NULLABLE
                  null, // SCOPE_CATALOG
                  null, // SCOPE_SCHEMA
                  null, // SCOPE_TABLE
                  null, // SOURCE_DATA_TYPE
                  "NO") // IS_AUTOINCREMENT

                rowSet.addRow(RESULT_SET_SCHEMA, rowData)
            }
          }

          ReflectionUtils.setSuperField(this, "rowSet", rowSet)
          setState(OperationState.FINISHED)
        } else {
          super.run()
          val rowSet = ReflectionUtils.getSuperField[RowSet](this, "rowSet")
          val rows = ReflectionUtils.getPrivateField[java.util.ArrayList[Row]](rowSet, "rows")
          printRows(rows)
        }
      }
    }
    handleToOperation.put(operation.getHandle, operation)
    operation
  }

  import com.datastax.driver.core.{DataType => CassandraDataType}
  /**
   * Maps a Cassandra column type to the `java.sql.Types` code reported to JDBC clients.
   *
   * @param dataType Cassandra column type
   * @return the JDBC type code; `java.sql.Types.OTHER` for a type with no mapping here
   */
  private def toJavaSqlType(dataType: CassandraDataType): Integer = {
    dataType.getName match {
      case CassandraDataType.Name.ASCII => java.sql.Types.VARCHAR
      case CassandraDataType.Name.BIGINT => java.sql.Types.BIGINT
      case CassandraDataType.Name.BLOB => java.sql.Types.BLOB
      case CassandraDataType.Name.BOOLEAN => java.sql.Types.BOOLEAN
      case CassandraDataType.Name.COUNTER => java.sql.Types.BIGINT
      case CassandraDataType.Name.DECIMAL => java.sql.Types.DECIMAL
      case CassandraDataType.Name.DOUBLE => java.sql.Types.DOUBLE
      case CassandraDataType.Name.FLOAT => java.sql.Types.FLOAT
      case CassandraDataType.Name.INT => java.sql.Types.INTEGER
      case CassandraDataType.Name.TEXT => java.sql.Types.VARCHAR
      case CassandraDataType.Name.VARINT => java.sql.Types.DECIMAL //Big Integer is treated as BigDecimal by Catalyst
      case CassandraDataType.Name.INET => java.sql.Types.VARCHAR //TODO: Stopgap solution
      case CassandraDataType.Name.UUID => java.sql.Types.VARCHAR //TODO: Stopgap solution
      case CassandraDataType.Name.TIMEUUID => java.sql.Types.VARCHAR //TODO: Stopgap solution
      case CassandraDataType.Name.TIMESTAMP => java.sql.Types.TIMESTAMP
      case CassandraDataType.Name.CUSTOM => java.sql.Types.BLOB //TODO: Stopgap solution. Explore use of UDF
      case CassandraDataType.Name.MAP => java.sql.Types.VARCHAR
      case CassandraDataType.Name.SET => java.sql.Types.VARCHAR
      case CassandraDataType.Name.LIST => java.sql.Types.VARCHAR
      case CassandraDataType.Name.VARCHAR => java.sql.Types.VARCHAR
      // Any type not listed above is reported as OTHER instead of failing with a MatchError.
      case _ => java.sql.Types.OTHER
    }
  }

  /**
   * Column size reported for a Cassandra type: `Integer.MAX_VALUE` for ASCII, VARCHAR and
   * BLOB, 30 for TIMESTAMP, and null for every other type.
   */
  private def getTypeSize(dataType: CassandraDataType): Integer = {
    dataType.getName match {
      case CassandraDataType.Name.ASCII |
           CassandraDataType.Name.VARCHAR |
           CassandraDataType.Name.BLOB =>
        Integer.MAX_VALUE
      case CassandraDataType.Name.TIMESTAMP =>
        30
      case _ =>
        null
    }
  }

  /**
   * Decimal digits reported for a Cassandra type: 0 for BIGINT, BOOLEAN, COUNTER and INT,
   * 7 for FLOAT, 15 for DOUBLE, and null for every other type.
   */
  private def getDecimalDigits(dataType: CassandraDataType): Integer = {
    dataType.getName match {
      case CassandraDataType.Name.BIGINT |
           CassandraDataType.Name.BOOLEAN |
           CassandraDataType.Name.COUNTER |
           CassandraDataType.Name.INT =>
        0
      case CassandraDataType.Name.FLOAT =>
        7
      case CassandraDataType.Name.DOUBLE =>
        15
      case _ =>
        null
    }
  }


  /**
   * Precision for a Cassandra type: 10 for INT, 19 for BIGINT, 7 for FLOAT, 15 for DOUBLE,
   * and null for every other type.
   */
  private def getPrecision(dataType: CassandraDataType): Integer = {
    val typeName = dataType.getName
    if (typeName == CassandraDataType.Name.INT) {
      10
    } else if (typeName == CassandraDataType.Name.BIGINT) {
      19
    } else if (typeName == CassandraDataType.Name.FLOAT) {
      7
    } else if (typeName == CassandraDataType.Name.DOUBLE) {
      15
    } else {
      null
    }
  }

  /**
   * Scale for a Cassandra type: 0 for BIGINT, BOOLEAN, COUNTER and INT, 7 for FLOAT, 15 for
   * DOUBLE, `Integer.MAX_VALUE` for DECIMAL and VARINT, and null for every other type.
   */
  private def getScale(dataType: CassandraDataType): Integer = {
    dataType.getName match {
      case CassandraDataType.Name.BIGINT |
           CassandraDataType.Name.BOOLEAN |
           CassandraDataType.Name.COUNTER |
           CassandraDataType.Name.INT =>
        0
      case CassandraDataType.Name.FLOAT =>
        7
      case CassandraDataType.Name.DOUBLE =>
        15
      case CassandraDataType.Name.DECIMAL |
           CassandraDataType.Name.VARINT =>
        Integer.MAX_VALUE
      case _ =>
        null
    }
  }

  /**
   * Numeric precision radix reported for a Cassandra type: 10 for BIGINT, COUNTER, INT,
   * DECIMAL and VARINT, 2 for FLOAT and DOUBLE, and null for every other type.
   */
  private def getNumPrecRadix(dataType: CassandraDataType): Integer = {
    dataType.getName match {
      case CassandraDataType.Name.BIGINT |
           CassandraDataType.Name.COUNTER |
           CassandraDataType.Name.INT |
           CassandraDataType.Name.DECIMAL |
           CassandraDataType.Name.VARINT =>
        10
      case CassandraDataType.Name.FLOAT |
           CassandraDataType.Name.DOUBLE =>
        2
      case _ =>
        null
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy