All downloads are free. The search and download functionality uses the official Maven repository.

tech.sourced.gitbase.spark.QueryBuilder.scala Maven / Gradle / Ivy

The newest version!
package tech.sourced.gitbase.spark

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types._
import tech.sourced.gitbase.spark.udf.{Uast, UastMode}

/**
  * Helpers that translate Catalyst expressions into SQL fragments for gitbase.
  *
  * Every `compile*` method returns `None` when the expression, or any
  * sub-expression it depends on, has no SQL translation here.
  */
object QueryBuilder {

  private val dialect = GitbaseDialect()

  // Characters with a special meaning in a regular expression. Each one gets a
  // backslash prefix when a plain string has to be matched literally.
  private val RegexMetaChars = """\.[]{}()*+?^$|"""

  /**
    * Returns the backtick-quoted name of the column, prefixed with its source
    * table when the attribute metadata carries a `Sources.SourceKey` entry.
    *
    * @param col attribute to qualify
    * @return `table.`name`` or just `` `name` `` when no source is recorded
    */
  def qualify(col: Attribute): String =
    if (col.metadata.contains(Sources.SourceKey)) {
      qualify(col.metadata.getString(Sources.SourceKey), col.name)
    } else {
      s"`${col.name}`"
    }

  /**
    * Returns the column name quoted with backticks and prefixed with the table.
    *
    * @param table table name, used verbatim
    * @param col   column name
    */
  def qualify(table: String, col: String): String = s"$table.`$col`"

  /** Delegates to the gitbase dialect's `compileValue`. */
  def compileValue(value: Any): Any = dialect.compileValue(value)

  /**
    * Compiles a pair of expressions.
    *
    * @return both SQL fragments, or `None` if either one cannot be compiled
    */
  def compileExpression(e1: Expression, e2: Expression): Option[(String, String)] =
    for {
      first <- compileExpression(e1)
      second <- compileExpression(e2)
    } yield (first, second)

  /** Compiles every expression; `None` as soon as one of them is unsupported. */
  private def compileAll(exprs: Seq[Expression]): Option[Seq[String]] = {
    val compiled = exprs.map(e => compileExpression(e))
    if (compiled.forall(_.isDefined)) Some(compiled.flatten) else None
  }

  /** Compiles a call to the backtick-quoted function `fn` with the given arguments. */
  private def compileCall(fn: String, args: Seq[Expression]): Option[String] =
    compileAll(args).map(compiled => s"`$fn`(${compiled.mkString(", ")})")

  /** Compiles `fn(arg)` for a single-argument SQL function. */
  private def compileFunction(fn: String, arg: Expression): Option[String] =
    compileExpression(arg).map(a => s"$fn($a)")

  /** Compiles `CAST(arg AS sqlType)`. */
  private def compileCast(arg: Expression, sqlType: String): Option[String] =
    compileExpression(arg).map(a => s"CAST($a AS $sqlType)")

  /** Compiles `left op right` for an infix SQL operator. */
  private def compileBinary(left: Expression, right: Expression, op: String): Option[String] =
    compileExpression(left, right).map { case (l, r) => s"$l $op $r" }

  /** Prefixes every regular expression metacharacter in `text` with a backslash. */
  private def escapeRegex(text: String): String =
    text.map(c => if (RegexMetaChars.indexOf(c) >= 0) "\\" + c else c.toString).mkString

  /**
    * Compiles a literal substring test (`StartsWith`, `EndsWith`, `Contains`)
    * into a `REGEXP` condition.
    *
    * Only non-null string literals are supported: the text has to be escaped
    * before it is embedded in a pattern, which is impossible for a value that is
    * only known at query time.
    *
    * @param attr    expression being tested
    * @param value   text to look for
    * @param pattern builds the final pattern from the escaped text
    */
  private def compileLiteralMatch(attr: Expression, value: Expression)
                                 (pattern: String => String): Option[String] =
    value match {
      case Literal(v, StringType) if v != null =>
        // Literal.sql takes care of quoting the pattern as a SQL string.
        val quoted = Literal(pattern(escapeRegex(v.toString))).sql
        compileExpression(attr).map(a => s"$a REGEXP $quoted")
      case _ => None
    }

  // scalastyle:off cyclomatic.complexity
  /**
    * Compiles a Catalyst expression into a SQL fragment.
    *
    * @param expr expression to compile
    * @return the SQL fragment, or `None` if the expression is not supported
    */
  def compileExpression(expr: Expression): Option[String] = {
    expr match {
      case Alias(child, name) =>
        compileExpression(child).map(c => s"$c AS `$name`")
      case ScalaUDF(_, _, children, _, Some(name), _, _) =>
        compileCall(name, children).filter(_ => udf.isSupported(name))
      case EqualTo(attr, value) => compileBinary(attr, value, "=")
      case EqualNullSafe(attr: Attribute, value) =>
        compileExpression(attr, value).map { case (a, b) =>
          s"(NOT ($a != $b OR $a IS NULL OR $b IS NULL) OR " +
            s"($a IS NULL AND $b IS NULL))"
        }
      case LessThan(attr, value) => compileBinary(attr, value, "<")
      case GreaterThan(attr, value) => compileBinary(attr, value, ">")
      case LessThanOrEqual(attr, value) => compileBinary(attr, value, "<=")
      case GreaterThanOrEqual(attr, value) => compileBinary(attr, value, ">=")
      case IsNull(attr) => compileExpression(attr).map(a => s"$a IS NULL")
      case IsNotNull(attr) => compileExpression(attr).map(a => s"$a IS NOT NULL")
      case RLike(left, right) => compileBinary(left, right, "REGEXP")

      case StartsWith(attr, value) => compileLiteralMatch(attr, value)("^" + _)
      case EndsWith(attr, value) => compileLiteralMatch(attr, value)(_ + "$")
      case Contains(attr, value) => compileLiteralMatch(attr, value)(identity)

      case In(attr, values) if values.isEmpty =>
        compileExpression(attr)
          .map(a => s"CASE WHEN $a IS NULL THEN NULL ELSE FALSE END")
      case In(attr, values) =>
        for {
          a <- compileExpression(attr)
          candidates <- compileAll(values)
        } yield s"$a IN (${candidates.mkString(", ")})"
      case Not(f1) => compileExpression(f1).map(p => s"(NOT ($p))")
      // An Or can only be compiled when both sides compile; dropping one side
      // would change the meaning of the condition. The same holds for And.
      case Or(f1, f2) =>
        compileExpression(f1, f2).map { case (l, r) => s"($l) OR ($r)" }
      case And(f1, f2) =>
        compileExpression(f1, f2).map { case (l, r) => s"($l) AND ($r)" }
      case col: Attribute => Some(qualify(col))
      // Literal sql method prints longs as {NUMBER}L, which is not valid SQL. To
      // prevent that, we need to do a special case for longs.
      case Literal(v: Long, _) => Some(v.toString)
      case lit: Literal => Some(lit.sql)
      case Year(e) => compileFunction("YEAR", e)
      case Month(e) => compileFunction("MONTH", e)
      case DayOfMonth(e) => compileFunction("DAY", e)
      case DayOfYear(e) => compileFunction("DAYOFYEAR", e)
      case Round(e, scale) =>
        compileExpression(e, scale).map { case (v, digits) => s"ROUND($v, $digits)" }
      case Ceil(e) => compileFunction("CEIL", e)
      case Floor(e) => compileFunction("FLOOR", e)

      case Substring(str, pos, len) =>
        for {
          s <- compileExpression(str)
          p <- compileExpression(pos)
          l <- compileExpression(len)
        } yield s"SUBSTRING($s, $p, $l)"

      case Cast(e, StringType, _) => compileCast(e, "CHAR")
      case Cast(e, BooleanType | IntegerType | ShortType, _) => compileCast(e, "SIGNED")
      case Cast(e, TimestampType, _) => compileCast(e, "DATETIME")
      case Cast(e, DateType, _) => compileCast(e, "DATE")
      case Cast(e, DoubleType | FloatType | LongType, _) => compileCast(e, "DECIMAL")
      // Casts to any other type are compiled as the bare child expression.
      case Cast(e, _, _) => compileExpression(e)

      case Min(child) => compileFunction("MIN", child)

      case Max(child) => compileFunction("MAX", child)

      case Sum(child) => compileFunction("SUM", child)

      // Only the single-argument form of COUNT is compiled.
      case Count(children) => children match {
        case Seq(only) => compileFunction("COUNT", only)
        case _ => None
      }

      case Average(child) => compileFunction("AVG", child)

      // We only compile non-distinct aggregations. Gitbase does not support
      // distinct in aggregations yet.
      case AggregateExpression(fn, _, false, _) =>
        compileExpression(fn)

      case Uast(children) => compileCall("uast", children)

      case UastMode(children) => compileCall("uast_mode", children)

      case _ => None
    }
  }

  // scalastyle:on cyclomatic.complexity
}

/**
  * Select query builder.
  *
  * @param node   node whose `buildQuery` produces the query rendered by [[sql]]
  * @param schema schema used to pick a column when a query projects nothing
  * @param query  initial query handed to `node.buildQuery`
  */
case class QueryBuilder(node: Node = null,
                        schema: StructType = StructType(Seq()),
                        query: Query = Query()) {

  import QueryBuilder._

  /**
    * Builds the comma-separated select list.
    *
    * Projections that cannot be compiled are left out. When the query has no
    * projection at all, the first field of the schema is selected.
    *
    * @throws SparkException if there is nothing to select: none of the
    *                        projections compile, or the query has no projection
    *                        and the schema has no fields
    */
  def selectedFields(q: Query): String =
    if (q.project.nonEmpty) {
      val fields = q.project.flatMap(e => compileExpression(e))
      if (fields.isEmpty) {
        // "SELECT  FROM ..." is not valid SQL; fail here with a clear message.
        throw new SparkException("unable to compile any of the projected expressions")
      }
      fields.mkString(", ")
    } else {
      // when there is no field selected, such as a count of repositories,
      // just get the first field to avoid returning all the fields
      schema.fields.headOption match {
        case Some(f) => qualify(AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
        case None =>
          throw new SparkException("unable to build sql query with no columns")
      }
    }

  /**
    * Builds the WHERE clause, including its leading space.
    *
    * Filters that cannot be compiled are left out; the result is empty when no
    * filter compiles.
    */
  def whereClause(q: Query = Query()): String = {
    val compiledFilters = q.filters.flatMap(x => compileExpression(x))
    if (compiledFilters.isEmpty) {
      ""
    } else {
      s" WHERE (${compiledFilters.mkString(") AND (")})"
    }
  }

  /**
    * Builds the ORDER BY clause, including its leading space, or an empty string
    * when the query has no sort order.
    *
    * @throws SparkException if the query is sorted but no sort field compiles
    */
  def orderByClause(query: Query = Query()): String = if (query.sort.isEmpty) {
    ""
  } else {
    val fields = query.sort.flatMap(orderByField)
    if (fields.isEmpty) {
      // A dangling " ORDER BY " would be rejected by the server anyway.
      throw new SparkException("unable to compile any of the ORDER BY expressions")
    }
    s" ORDER BY ${fields.mkString(", ")}"
  }

  /**
    * Builds the LIMIT clause, including its leading space, or an empty string
    * when the query has no limit. A trailing `L` in the limit's string form is
    * removed.
    */
  def limitClause(query: Query = Query()): String =
    query.limit.map(x => s" LIMIT ${x.toString.stripSuffix("L")}").getOrElse("")

  /**
    * Compiles a sort field followed by its direction.
    *
    * @return `None` if the sorted expression cannot be compiled
    */
  def orderByField(field: SortOrder): Option[String] = {
    compileExpression(field.child)
      .map(x => s"$x ${field.direction.sql}")
  }

  /**
    * Builds the GROUP BY clause, including its leading space, or an empty string
    * when the query has no grouping.
    *
    * @throws SparkException if the query is grouped but no grouping expression
    *                        compiles
    */
  def groupByClause(query: Query = Query()): String = if (query.grouping.isEmpty) {
    ""
  } else {
    val keys = query.grouping.flatMap(x => compileExpression(x))
    if (keys.isEmpty) {
      // A dangling " GROUP BY " would be rejected by the server anyway.
      throw new SparkException("unable to compile any of the GROUP BY expressions")
    }
    s" GROUP BY ${keys.mkString(", ")}"
  }

  /**
    * Compiles a join condition.
    *
    * @throws SparkException if the condition cannot be compiled
    */
  def getOnClause(cond: Expression): String = compileExpression(cond)
    .getOrElse(throw new SparkException(s"unable to compile expression $cond"))

  /**
    * Renders a source node as the table expression of a FROM clause.
    *
    * @throws SparkException if the node is neither a join nor a table, or a join
    *                        condition cannot be compiled
    */
  def sourceToSql(source: Node): String = source match {
    case Join(left, right, Some(cond)) =>
      s"${sourceToSql(left)} INNER JOIN ${sourceToSql(right)} ON ${getOnClause(cond)}"
    case Join(left, right, None) =>
      s"${sourceToSql(left)} JOIN ${sourceToSql(right)}"
    case Table(name) => name
    case _ => throw new SparkException("invalid node found in query source")
  }

  /** Renders a query as a parenthesised subquery aliased as `t`. */
  def subqueryToSql(q: Query): String = s"(${queryToSql(q)}) AS `t`"

  /**
    * Renders what goes after FROM: either the subquery or the source.
    *
    * @throws SparkException if the query has both or neither of them
    */
  def selectedTables(q: Query = Query()): String = (q.subquery, q.source) match {
    case (Some(_), Some(_)) =>
      throw new SparkException("This is likely a bug. Found source and subquery in the query.")
    case (Some(subquery), None) =>
      subqueryToSql(subquery)
    case (None, Some(source)) =>
      sourceToSql(source)
    case (None, None) =>
      throw new SparkException("no source or subquery found in query")
  }

  /** Assembles the full SELECT statement for the given query. */
  def queryToSql(q: Query): String =
    s"SELECT ${selectedFields(q)} FROM ${selectedTables(q)}${whereClause(q)}" +
      s"${groupByClause(q)}${orderByClause(q)}${limitClause(q)}"

  /**
    * Returns the built select SQL query.
    *
    * @return select SQL query
    * @throws SparkException if the builder has no node, or the query cannot be
    *                        rendered
    */
  def sql: String = {
    // node defaults to null; report that instead of a bare NullPointerException.
    if (node == null) {
      throw new SparkException("unable to build sql query without a node")
    }
    queryToSql(node.buildQuery(query))
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy