/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.comet.fuzz

import java.io.{BufferedWriter, FileWriter}

import scala.collection.mutable
import scala.util.Random

import org.apache.spark.sql.SparkSession

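/**
 * Generates random SQL queries against previously generated Parquet test files
 * (`test0.parquet` .. `test{numFiles - 1}.parquet`) and writes them, deduplicated,
 * to `queries.sql` for later execution.
 */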
object QueryGen {

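  /**
   * Registers each input file as a temp view (`test0` .. `test{numFiles - 1}`), then draws
   * `numQueries` queries at random from the generators below, keeping only the first
   * occurrence of each distinct query.
   */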
  def generateRandomQueries(
      r: Random,
      spark: SparkSession,
      numFiles: Int,
      numQueries: Int): Unit = {
    for (i <- 0 until numFiles) {
      spark.read.parquet(s"test$i.parquet").createTempView(s"test$i")
    }

    val w = new BufferedWriter(new FileWriter("queries.sql"))
    try {
      val uniqueQueries = mutable.HashSet[String]()

      for (_ <- 0 until numQueries) {
        // pick one of the eight query generators uniformly at random
        // (nextInt(8) avoids the modulo bias and Int.MinValue edge case of nextInt().abs % 8)
        val sql = r.nextInt(8) match {
          case 0 => generateJoin(r, spark, numFiles)
          case 1 => generateAggregate(r, spark, numFiles)
          case 2 => generateScalar(r, spark, numFiles)
          case 3 => generateCast(r, spark, numFiles)
          case 4 => generateUnaryArithmetic(r, spark, numFiles)
          case 5 => generateBinaryArithmetic(r, spark, numFiles)
          case 6 => generateBinaryComparison(r, spark, numFiles)
          case _ => generateConditional(r, spark, numFiles)
        }
        // mutable.HashSet.add returns true only the first time an element is seen,
        // so each distinct query is written exactly once
        if (uniqueQueries.add(sql)) {
          w.write(sql + "\n")
        }
      }
    } finally {
      // close the writer even if a generator fails
      w.close()
    }
  }
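
  /*
   * Hypothetical usage sketch (the values below are illustrative, not part of this file):
   *
   *   val spark = SparkSession.builder().master("local[*]").getOrCreate()
   *   QueryGen.generateRandomQueries(new Random(42L), spark, numFiles = 2, numQueries = 500)
   */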

  private def generateAggregate(r: Random, spark: SparkSession, numFiles: Int): String = {
    val tableName = s"test${r.nextInt(numFiles)}"
    val table = spark.table(tableName)

    val func = Utils.randomChoice(Meta.aggFunc, r)
    val args = Range(0, func.num_args)
      .map(_ => Utils.randomChoice(table.columns, r))

    // use between zero and two grouping columns; with none, fall back to a plain aggregate
    // (a fixed Range(0, 2) would make the isEmpty branch below unreachable)
    val groupingCols = Range(0, r.nextInt(3)).map(_ => Utils.randomChoice(table.columns, r))

    if (groupingCols.isEmpty) {
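      // Example SELECT c0, min(c0) AS x FROM test0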
      s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
        s"FROM $tableName " +
        s"ORDER BY ${args.mkString(", ")};"
    } else {
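      // Example SELECT c0, c1, min(c2) FROM test0 GROUP BY c0, c1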
      s"SELECT ${groupingCols.mkString(", ")}, ${func.name}(${args.mkString(", ")}) " +
        s"FROM $tableName " +
        s"GROUP BY ${groupingCols.mkString(", ")} " +
        s"ORDER BY ${groupingCols.mkString(", ")};"
    }
  }

  private def generateScalar(r: Random, spark: SparkSession, numFiles: Int): String = {
    val tableName = s"test${r.nextInt(numFiles)}"
    val table = spark.table(tableName)

    val func = Utils.randomChoice(Meta.scalarFunc, r)
    val args = Range(0, func.num_args)
      .map(_ => Utils.randomChoice(table.columns, r))

    // Example SELECT c0, log(c0) as x FROM test0
    s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
      s"FROM $tableName " +
      s"ORDER BY ${args.mkString(", ")};"
  }

  private def generateUnaryArithmetic(r: Random, spark: SparkSession, numFiles: Int): String = {
    val tableName = s"test${r.nextInt(numFiles)}"
    val table = spark.table(tableName)

    val op = Utils.randomChoice(Meta.unaryArithmeticOps, r)
    val a = Utils.randomChoice(table.columns, r)

    // Example SELECT a, -a FROM test0
    s"SELECT $a, $op$a " +
      s"FROM $tableName " +
      s"ORDER BY $a;"
  }

  private def generateBinaryArithmetic(r: Random, spark: SparkSession, numFiles: Int): String = {
    val tableName = s"test${r.nextInt(numFiles)}"
    val table = spark.table(tableName)

    val op = Utils.randomChoice(Meta.binaryArithmeticOps, r)
    val a = Utils.randomChoice(table.columns, r)
    val b = Utils.randomChoice(table.columns, r)

    // Example SELECT a, b, a+b FROM test0
    s"SELECT $a, $b, $a $op $b " +
      s"FROM $tableName " +
      s"ORDER BY $a, $b;"
  }

  private def generateBinaryComparison(r: Random, spark: SparkSession, numFiles: Int): String = {
    val tableName = s"test${r.nextInt(numFiles)}"
    val table = spark.table(tableName)

    val op = Utils.randomChoice(Meta.comparisonOps, r)
    val a = Utils.randomChoice(table.columns, r)
    val b = Utils.randomChoice(table.columns, r)

    // Example SELECT a, b, a <=> b FROM test0
    s"SELECT $a, $b, $a $op $b " +
      s"FROM $tableName " +
      s"ORDER BY $a, $b;"
  }

  private def generateConditional(r: Random, spark: SparkSession, numFiles: Int): String = {
    val tableName = s"test${r.nextInt(numFiles)}"
    val table = spark.table(tableName)

    val op = Utils.randomChoice(Meta.comparisonOps, r)
    val a = Utils.randomChoice(table.columns, r)
    val b = Utils.randomChoice(table.columns, r)

    // Example SELECT a, b, IF(a <=> b, 1, 2), CASE WHEN a <=> b THEN 1 ELSE 2 END FROM test0
    s"SELECT $a, $b, $a $op $b, IF($a $op $b, 1, 2), CASE WHEN $a $op $b THEN 1 ELSE 2 END " +
      s"FROM $tableName " +
      s"ORDER BY $a, $b;"
  }

  private def generateCast(r: Random, spark: SparkSession, numFiles: Int): String = {
    val tableName = s"test${r.nextInt(numFiles)}"
    val table = spark.table(tableName)

    val toType = Utils.randomWeightedChoice(Meta.dataTypes, r).sql
    val arg = Utils.randomChoice(table.columns, r)

    // We test both `cast` and `try_cast` to cover LEGACY and TRY eval modes. It is not
    // recommended to run Comet Fuzz with ANSI enabled currently.
    // Example SELECT c0, cast(c0 as float), try_cast(c0 as float) FROM test0
    s"SELECT $arg, cast($arg as $toType), try_cast($arg as $toType) " +
      s"FROM $tableName " +
      s"ORDER BY $arg;"
  }

  private def generateJoin(r: Random, spark: SparkSession, numFiles: Int): String = {
    val leftTableName = s"test${r.nextInt(numFiles)}"
    val rightTableName = s"test${r.nextInt(numFiles)}"
    val leftTable = spark.table(leftTableName)
    val rightTable = spark.table(rightTableName)

    val leftCol = Utils.randomChoice(leftTable.columns, r)
    val rightCol = Utils.randomChoice(rightTable.columns, r)

    val joinTypes = Seq(("INNER", 0.4), ("LEFT", 0.3), ("RIGHT", 0.3))
    val joinType = Utils.randomWeightedChoice(joinTypes, r)

    val leftColProjection = leftTable.columns.map(c => s"l.$c").mkString(", ")
    val rightColProjection = rightTable.columns.map(c => s"r.$c").mkString(", ")
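    // Example SELECT l.c0, l.c1, r.c0, r.c1 FROM test0 l INNER JOIN test1 r ON l.c0 = r.c0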
    "SELECT " +
      s"$leftColProjection, " +
      s"$rightColProjection " +
      s"FROM $leftTableName l " +
      s"$joinType JOIN $rightTableName r " +
      s"ON l.$leftCol = r.$rightCol " +
      "ORDER BY " +
      s"$leftColProjection, " +
      s"$rightColProjection;"
  }

}

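/**
 * Metadata describing a SQL function by name and argument count. The `Meta` object used
 * above is defined elsewhere in this module and is assumed to expose sequences of these,
 * e.g. `Meta.aggFunc: Seq[Function]`, alongside operator and data-type lists.
 */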
case class Function(name: String, num_args: Int)
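
/*
 * `Utils.randomChoice` and `Utils.randomWeightedChoice` are also defined elsewhere in the
 * fuzz module. A minimal sketch of the behavior QueryGen relies on, with signatures
 * inferred from the call sites above, might look like:
 *
 *   object Utils {
 *     // uniform choice over a non-empty sequence
 *     def randomChoice[T](values: Seq[T], r: Random): T =
 *       values(r.nextInt(values.length))
 *
 *     // weighted choice: walk the cumulative weights until the random target is passed
 *     def randomWeightedChoice[T](valuesWithWeights: Seq[(T, Double)], r: Random): T = {
 *       val total = valuesWithWeights.map(_._2).sum
 *       var remaining = r.nextDouble() * total
 *       valuesWithWeights
 *         .find { case (_, weight) =>
 *           remaining -= weight
 *           remaining < 0
 *         }
 *         .map(_._1)
 *         .getOrElse(valuesWithWeights.last._1)
 *     }
 *   }
 */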
