All Downloads are FREE. Search and download functionalities are using the official Maven repository.

visitor.QueryVisitor.kt Maven / Gradle / Ivy

There is a newer version: 1.1.10
Show newest version
package visitor

import ast.expr.*
import ast.order.SqlOrderBy
import database.DB
import dsl.*
import query.select.SelectQuery
import java.sql.SQLException
import java.util.*

/**
 * Converts a DSL [Query] node into a [QueryExpr] by dispatching on its concrete type.
 *
 * A `null` query becomes a [QueryExpr] wrapping [SqlNullExpr], and a node that already is a
 * [QueryExpr] is returned unchanged. Every other node type is handed to its `visit*` function.
 *
 * @param query the DSL node to convert, or `null`
 * @param dbType the target database, forwarded to the visitors that accept it
 * @return the converted expression
 */
fun getQueryExpr(query: Query?, dbType: DB): QueryExpr = when (query) {
    // Terminal cases: nothing to visit.
    null -> QueryExpr(SqlNullExpr())
    is QueryColumn -> visitQueryColumn(query)
    is QueryExprFunction -> visitQueryExprFunction(query, dbType)
    is QueryAggFunction -> visitQueryAggFunction(query, dbType)
    // Constants are converted directly from their literal value.
    is QueryConst<*> -> QueryExpr(getExpr(query.value), query.alias)
    is QueryBinary -> visitQueryBinary(query, dbType)
    is QueryExpr -> query
    is QueryCase<*> -> visitQueryCase(query, dbType)
    is QuerySub -> visitQuerySub(query)
    is QueryTableColumn -> visitQueryTableColumn(query)
    is QueryJson -> visitQueryJson(query, dbType)
    is QueryCast -> visitQueryCast(query, dbType)
    is QueryInList<*> -> visitQueryInList(query, dbType)
    is QueryInSubQuery -> visitQueryInSubQuery(query, dbType)
    is QueryBetween<*> -> visitQueryBetween(query, dbType)
    is QueryAllColumn -> visitQueryAllColumn(query)
    is QueryOver -> visitQueryOver(query, dbType)
    is QuerySubQueryPredicate -> visitQuerySubQueryPredicate(query)
}

/**
 * Converts a [QueryColumn] into a column expression.
 *
 * - A name containing a dot is split on `.`; the first segment is used as the owner and the
 *   last segment as the column (any segments in between are ignored). If the last segment
 *   contains `*` the result is an owner-qualified [SqlAllColumnExpr], otherwise a [SqlPropertyExpr].
 * - A name without a dot becomes an unqualified [SqlAllColumnExpr] when it contains `*`,
 *   otherwise a [SqlIdentifierExpr].
 *
 * @param query the column node
 * @return the expression paired with the node's alias
 */
fun visitQueryColumn(query: QueryColumn): QueryExpr {
    val name = query.column
    val columnExpr = when {
        "." in name -> {
            val segments = name.split(".")
            val owner = segments.first()
            val field = segments.last()
            if ("*" in field) SqlAllColumnExpr(owner) else SqlPropertyExpr(owner, field)
        }
        "*" in name -> SqlAllColumnExpr()
        else -> SqlIdentifierExpr(name)
    }
    return QueryExpr(columnExpr, query.alias)
}

/**
 * Scalar functions that need a dedicated visitor, keyed by function name.
 *
 * [visitQueryExprFunction] looks the function name up here first and, on a hit, delegates to
 * the mapped visitor instead of emitting a generic function call.
 */
val specialExprFunction = mapOf(
    "*IFNULL" to ::visitFunctionIfNull,
    "*FIND_IN_SET" to ::visitFunctionFindInSet,
    "*JSON_LENGTH" to ::visitFunctionJsonLength,
    "*CONCAT" to ::visitFunctionConcat,
    "*CONCAT_WS" to ::visitFunctionConcatWs
)

/**
 * Converts a scalar function call into a [SqlExprFunctionExpr].
 *
 * If the function name is registered in [specialExprFunction], the registered visitor produces
 * the result. Otherwise each argument is converted with [getQueryExpr] and appended, in order,
 * to a function expression carrying the original name.
 *
 * @param query the function-call node
 * @param dbType the target database
 * @return the expression paired with the node's alias
 */
fun visitQueryExprFunction(query: QueryExprFunction, dbType: DB): QueryExpr {
    val specialVisitor = specialExprFunction[query.name]
    if (specialVisitor != null) {
        return specialVisitor(query, dbType)
    }

    val call = SqlExprFunctionExpr(query.name)
    for (argument in query.args) {
        call.args.add(getQueryExpr(argument, dbType).expr)
    }
    return QueryExpr(call, query.alias)
}

/**
 * Aggregate functions that need a dedicated visitor, keyed by function name.
 *
 * [visitQueryAggFunction] consults this map before falling back to a generic aggregate call.
 */
val specialAggFunction = mapOf(
    "*STRING_AGG" to ::visitFunctionStringAgg,
    "*ARRAY_AGG" to ::visitFunctionArrayAgg
)

/**
 * Converts an aggregate function call into a [SqlAggFunctionExpr].
 *
 * If the function name is registered in [specialAggFunction], the registered visitor produces
 * the result. Otherwise the arguments, optional attributes and order-by items of the node are
 * each converted with [getQueryExpr] and copied onto a new aggregate expression.
 *
 * @param query the aggregate node
 * @param dbType the target database
 * @return the expression paired with the node's alias
 */
fun visitQueryAggFunction(query: QueryAggFunction, dbType: DB): QueryExpr {
    val specialVisitor = specialAggFunction[query.name]
    if (specialVisitor != null) {
        return specialVisitor(query, dbType)
    }

    val aggregate = SqlAggFunctionExpr(query.name, distinct = query.distinct)

    for (argument in query.args) {
        aggregate.args.add(getQueryExpr(argument, dbType).expr)
    }

    // Attributes are optional on the node.
    val attributes = query.attributes
    if (attributes != null) {
        for ((key, value) in attributes) {
            aggregate.attributes[key] = getQueryExpr(value, dbType).expr
        }
    }

    for (item in query.orderBy) {
        aggregate.orderBy.add(SqlOrderBy(getQueryExpr(item.query, dbType).expr, item.order))
    }

    return QueryExpr(aggregate, query.alias)
}

/**
 * Converts a binary operation into a [SqlBinaryExpr], converting the left operand first and
 * the right operand second.
 *
 * @param query the binary node
 * @param dbType the target database
 * @return the expression paired with the node's alias
 */
fun visitQueryBinary(query: QueryBinary, dbType: DB): QueryExpr {
    val lhs = getQueryExpr(query.left, dbType).expr
    val rhs = getQueryExpr(query.right, dbType).expr
    return QueryExpr(SqlBinaryExpr(lhs, query.operator, rhs), query.alias)
}

/**
 * Converts a CASE node into a [SqlCaseExpr].
 *
 * The default value and every branch result may be either a [Query] (converted with
 * [getQueryExpr]) or a plain value (converted with [getExpr]). Branch conditions are always
 * converted with [getQueryExpr].
 *
 * @param query the CASE node
 * @param dbType the target database
 * @return the expression paired with the node's alias
 */
fun visitQueryCase(query: QueryCase<*>, dbType: DB): QueryExpr {
    // Shared conversion for "either a Query or a literal value".
    fun operand(value: Any?) = if (value is Query) getQueryExpr(value, dbType).expr else getExpr(value)

    val caseExpr = SqlCaseExpr(default = operand(query.default))

    for (branch in query.conditions) {
        val result = operand(branch.then)
        val condition = getQueryExpr(branch.query, dbType).expr
        caseExpr.caseList.add(SqlCase(condition, result))
    }

    return QueryExpr(caseExpr, query.alias)
}

/**
 * Wraps the select statement of a sub-query node in a [SqlSelectQueryExpr].
 *
 * @param query the sub-query node
 * @return the expression paired with the node's alias
 */
fun visitQuerySub(query: QuerySub): QueryExpr =
    QueryExpr(SqlSelectQueryExpr(query.selectQuery.getSelect()), query.alias)

/**
 * Converts a table-qualified column into a [SqlPropertyExpr] built from its table and column.
 *
 * @param query the table-column node
 * @return the expression paired with the node's alias
 */
fun visitQueryTableColumn(query: QueryTableColumn): QueryExpr =
    QueryExpr(SqlPropertyExpr(query.table, query.column), query.alias)

/**
 * Converts a JSON access node into a [SqlBinaryExpr] for the target database.
 *
 * - PGSQL: the accessed key must be an `Int` or a `String`. The operand is used as-is when it
 *   is itself a [QueryJson], otherwise it is first cast to `JSONB`.
 * - MYSQL: the chain of accesses is flattened by [visitMysqlQueryJson] into a path string and
 *   a root operand; the path becomes the right-hand side of the binary expression.
 *
 * @param query the JSON access node
 * @param dbType the target database
 * @return the expression paired with the node's alias
 * @throws TypeCastException for an unsupported key type (PGSQL) or any other database
 */
fun visitQueryJson(query: QueryJson, dbType: DB): QueryExpr {
    val jsonOperator = query.operator
    val binary = when (dbType) {
        DB.PGSQL -> {
            val key = query.value
            val keyExpr = when (key) {
                is Int -> SqlNumberExpr(key)
                is String -> SqlCharExpr(key)
                else -> throw TypeCastException("取Json值时,表达式右侧只支持String或Int")
            }

            // A nested JSON access is already JSON-typed; anything else gets an explicit cast.
            val target = query.query
            val operand = if (target is QueryJson) target else cast(target, "JSONB")

            SqlBinaryExpr(getQueryExpr(operand, dbType).expr, jsonOperator, keyExpr)
        }
        DB.MYSQL -> {
            val (path, root) = visitMysqlQueryJson(query)
            SqlBinaryExpr(getQueryExpr(root, dbType).expr, jsonOperator, SqlCharExpr(path))
        }
        else -> throw TypeCastException("Json操作暂不支持此数据库")
    }
    return QueryExpr(binary, query.alias)
}

/**
 * Converts a cast node into a [SqlCastExpr] over the converted operand and the node's type.
 *
 * @param query the cast node
 * @param dbType the target database
 * @return the expression paired with the node's alias
 */
fun visitQueryCast(query: QueryCast, dbType: DB): QueryExpr =
    QueryExpr(SqlCastExpr(getQueryExpr(query.query, dbType).expr, query.type), query.alias)

/**
 * Converts an `IN (list)` node into a [SqlInExpr].
 *
 * List elements that are [Query] nodes are converted with [getQueryExpr]; all other elements
 * are converted with [getExpr]. The result carries no alias.
 *
 * @param query the IN-list node
 * @param dbType the target database
 * @return the un-aliased expression
 */
fun visitQueryInList(query: QueryInList<*>, dbType: DB): QueryExpr {
    val tested = getQueryExpr(query.query, dbType).expr
    val candidates = SqlListExpr(query.list.map { element ->
        if (element is Query) getQueryExpr(element, dbType).expr else getExpr(element)
    })
    return QueryExpr(SqlInExpr(tested, candidates, query.isNot))
}

/**
 * Converts an `IN (sub-query)` node into a [SqlInExpr] whose right-hand side is the
 * sub-query's select statement. The result carries no alias.
 *
 * @param query the IN-sub-query node
 * @param dbType the target database
 * @return the un-aliased expression
 */
fun visitQueryInSubQuery(query: QueryInSubQuery, dbType: DB): QueryExpr {
    val tested = getQueryExpr(query.query, dbType).expr
    val subSelect = SqlSelectQueryExpr(query.subQuery.getSelect())
    return QueryExpr(SqlInExpr(tested, subSelect, query.isNot))
}

/**
 * Converts a BETWEEN node into a [SqlBetweenExpr].
 *
 * Each bound may be a [Query] (converted with [getQueryExpr]) or a plain value (converted
 * with [getExpr]). The result carries no alias.
 *
 * @param query the BETWEEN node
 * @param dbType the target database
 * @return the un-aliased expression
 */
fun visitQueryBetween(query: QueryBetween<*>, dbType: DB): QueryExpr {
    // Shared conversion for "either a Query or a literal value".
    fun bound(value: Any?) = if (value is Query) getQueryExpr(value, dbType).expr else getExpr(value)

    val lower = bound(query.start)
    val upper = bound(query.end)
    val tested = getQueryExpr(query.query, dbType).expr

    return QueryExpr(SqlBetweenExpr(tested, lower, upper, query.isNot))
}

/**
 * Converts an all-columns node into a [SqlAllColumnExpr] for the node's owner.
 * The result carries no alias.
 *
 * @param query the all-columns node
 * @return the un-aliased expression
 */
fun visitQueryAllColumn(query: QueryAllColumn): QueryExpr = QueryExpr(SqlAllColumnExpr(query.owner))

/**
 * Converts a window-function node into a [SqlOverExpr].
 *
 * The wrapped aggregate is converted with [visitQueryAggFunction]; partition-by and order-by
 * items are converted with [getQueryExpr] and appended in order.
 *
 * @param query the OVER node
 * @param dbType the target database
 * @return the expression paired with the node's alias
 */
fun visitQueryOver(query: QueryOver, dbType: DB): QueryExpr {
    // NOTE(review): this cast fails if visitQueryAggFunction takes its special-function path and
    // that visitor yields something other than a SqlAggFunctionExpr — confirm this cannot happen.
    val aggregate = visitQueryAggFunction(query.function, dbType).expr as SqlAggFunctionExpr
    val window = SqlOverExpr(aggregate)

    for (partition in query.partitionBy) {
        window.partitionBy.add(getQueryExpr(partition, dbType).expr)
    }
    for (item in query.orderBy) {
        window.orderBy.add(SqlOrderBy(getQueryExpr(item.query, dbType).expr, item.order))
    }

    return QueryExpr(window, query.alias)
}

/**
 * Rewrites the two-argument "if null" function into the form used by the target database:
 * `IFNULL` (MySQL, SQLite), `COALESCE` (PostgreSQL), `NVL` (Oracle), `ISNULL` (SQL Server), or
 * `IF(value IS NULL, fallback, value)` (Hive, ClickHouse).
 *
 * @param query the function node; its first two arguments are the value and the fallback
 * @param dbType the target database
 * @return the rewritten expression paired with the original node's alias
 */
fun visitFunctionIfNull(query: QueryExprFunction, dbType: DB): QueryExpr {
    val value = query.args[0]
    val fallback = query.args[1]

    val rewritten = when (dbType) {
        DB.MYSQL, DB.SQLITE -> QueryExprFunction("IFNULL", listOf(value, fallback))
        DB.PGSQL -> QueryExprFunction("COALESCE", listOf(value, fallback))
        DB.ORACLE -> QueryExprFunction("NVL", listOf(value, fallback))
        DB.HIVE, DB.CLICKHOUSE -> QueryExprFunction("IF", listOf(value.isNull(), fallback, value))
        DB.SQLSERVER -> QueryExprFunction("ISNULL", listOf(value, fallback))
    }
    return QueryExpr(getQueryExpr(rewritten, dbType).expr, query.alias)
}

/**
 * Rewrites the "find in set" function for the target database.
 *
 * - MYSQL: a plain `FIND_IN_SET(arg0, arg1)` call.
 * - PGSQL: `CAST(arg0 AS VARCHAR) = ANY(STRING_TO_ARRAY(arg1, ','))`.
 *
 * @param query the function node; the first two arguments are used
 * @param dbType the target database
 * @return the rewritten expression paired with the original node's alias
 * @throws TypeCastException for any other database
 */
fun visitFunctionFindInSet(query: QueryExprFunction, dbType: DB): QueryExpr {
    val rewritten = when (dbType) {
        DB.MYSQL -> QueryExprFunction("FIND_IN_SET", listOf(query.args[0], query.args[1]))
        DB.PGSQL -> {
            val needle = cast(query.args[0], "VARCHAR")
            val elements = QueryExprFunction("STRING_TO_ARRAY", listOf(query.args[1], const(",")))
            QueryBinary(needle, SqlBinaryOperator.EQ, QueryExprFunction("ANY", listOf(elements)))
        }
        // TODO
        else -> throw TypeCastException("暂不支持该数据库使用此函数")
    }

    return QueryExpr(getQueryExpr(rewritten, dbType).expr, query.alias)
}

/**
 * Rewrites the JSON length function for the target database.
 *
 * - MYSQL: `JSON_LENGTH(root, path)` when the argument is a [QueryJson] (flattened with
 *   [visitMysqlQueryJson]), otherwise `JSON_LENGTH(arg)`.
 * - PGSQL: `JSONB_ARRAY_LENGTH(arg)`, casting the argument to `JSONB` unless it is a [QueryJson].
 *
 * @param query the function node; only the first argument is used
 * @param dbType the target database
 * @return the rewritten expression paired with the original node's alias
 * @throws TypeCastException for any other database
 */
fun visitFunctionJsonLength(query: QueryExprFunction, dbType: DB): QueryExpr {
    val target = query.args[0]

    val rewritten = when (dbType) {
        DB.MYSQL -> when (target) {
            is QueryJson -> {
                val (path, root) = visitMysqlQueryJson(target)
                QueryExprFunction("JSON_LENGTH", listOf(root, const(path)))
            }
            else -> QueryExprFunction("JSON_LENGTH", listOf(target))
        }
        DB.PGSQL -> {
            val jsonb = if (target is QueryJson) target else cast(target, "JSONB")
            QueryExprFunction("JSONB_ARRAY_LENGTH", listOf(jsonb))
        }
        // TODO
        else -> throw TypeCastException("暂不支持该数据库使用此函数")
    }

    return QueryExpr(getQueryExpr(rewritten, dbType).expr, query.alias)
}

/**
 * Rewrites string concatenation for the target database.
 *
 * - MySQL, PostgreSQL, Hive, SQL Server, ClickHouse: a plain `CONCAT(args...)` call.
 * - Oracle, SQLite: a left-nested chain of [SqlBinaryOperator.CONCAT] binary expressions,
 *   i.e. `((a || b) || c)`.
 *
 * For Oracle/SQLite a single argument is returned as-is (previously it was dropped and the
 * whole expression rendered as SQL NULL). An empty argument list still yields SQL NULL.
 *
 * @param query the function node whose arguments are concatenated in order
 * @param dbType the target database
 * @return the rewritten expression paired with the original node's alias
 */
fun visitFunctionConcat(query: QueryExprFunction, dbType: DB): QueryExpr {
    val function = when (dbType) {
        DB.MYSQL, DB.PGSQL, DB.HIVE, DB.SQLSERVER, DB.CLICKHOUSE -> QueryExprFunction("CONCAT", query.args)
        DB.ORACLE, DB.SQLITE -> {
            val operands = query.args
            if (operands.isEmpty()) {
                // Nothing to concatenate; getQueryExpr turns null into SQL NULL.
                null
            } else {
                // Seed with the first operand so a lone argument passes through unchanged.
                operands.drop(1).fold(operands.first()) { joined, next ->
                    QueryBinary(joined, SqlBinaryOperator.CONCAT, next)
                }
            }
        }
    }
    return QueryExpr(getQueryExpr(function, dbType).expr, query.alias)
}

/**
 * Rewrites "concatenate with separator" for the target database.
 *
 * The first argument is the separator and the remaining arguments are the values.
 *
 * - MySQL, PostgreSQL, Hive, SQL Server: a plain `CONCAT_WS(args...)` call.
 * - ClickHouse: `CONCAT(v1, sep, v2, sep, ..., vN)`.
 * - Oracle, SQLite: the same interleaved list joined as a left-nested chain of
 *   [SqlBinaryOperator.CONCAT] binary expressions.
 *
 * For Oracle/SQLite a single value is returned as-is (previously it was dropped and the whole
 * expression rendered as SQL NULL). With no values at all the result is still SQL NULL.
 *
 * @param query the function node (separator followed by values)
 * @param dbType the target database
 * @return the rewritten expression paired with the original node's alias
 */
fun visitFunctionConcatWs(query: QueryExprFunction, dbType: DB): QueryExpr {
    val function = when (dbType) {
        DB.MYSQL, DB.PGSQL, DB.HIVE, DB.SQLSERVER -> QueryExprFunction("CONCAT_WS", query.args)
        DB.ORACLE, DB.SQLITE, DB.CLICKHOUSE -> {
            // Interleave the separator between the values: v1, sep, v2, sep, ..., vN.
            val args = query.args
                .drop(1)
                .flatMap { listOf(it, query.args[0]) }
                .dropLast(1)

            when {
                dbType == DB.CLICKHOUSE -> QueryExprFunction("CONCAT", args)
                // Nothing to concatenate; getQueryExpr turns null into SQL NULL.
                args.isEmpty() -> null
                // Seed with the first value so a lone value passes through unchanged.
                else -> args.drop(1).fold(args.first()) { joined, next ->
                    QueryBinary(joined, SqlBinaryOperator.CONCAT, next)
                }
            }
        }
    }
    return QueryExpr(getQueryExpr(function, dbType).expr, query.alias)
}

/**
 * Rewrites the string aggregation function for the target database.
 *
 * - MYSQL: `GROUP_CONCAT(arg0)` with the second argument attached as the `SEPARATOR` attribute.
 * - PGSQL: `STRING_AGG(CAST(arg0 AS VARCHAR), arg1)`.
 *
 * The `distinct` flag and order-by items of the original node are carried over.
 *
 * @param query the aggregate node; the first two arguments are the value and the separator
 * @param dbType the target database
 * @return the rewritten expression paired with the original node's alias
 * @throws TypeCastException for any other database
 */
fun visitFunctionStringAgg(query: QueryAggFunction, dbType: DB): QueryExpr {
    val rewritten = when (dbType) {
        DB.MYSQL -> {
            val value = query.args[0]
            val separator = query.args[1]
            QueryAggFunction(
                "GROUP_CONCAT",
                listOf(value),
                attributes = mapOf("SEPARATOR" to separator),
                distinct = query.distinct,
                orderBy = query.orderBy
            )
        }
        DB.PGSQL -> {
            val text = cast(query.args[0], "VARCHAR")
            val separator = query.args[1]
            QueryAggFunction(
                "STRING_AGG",
                listOf(text, separator),
                distinct = query.distinct,
                orderBy = query.orderBy
            )
        }
        // TODO
        else -> throw TypeCastException("暂不支持该数据库使用此函数")
    }
    return QueryExpr(getQueryExpr(rewritten, dbType).expr, query.alias)
}


/**
 * Rewrites the array aggregation function for the target database.
 *
 * - MYSQL: `GROUP_CONCAT(arg0)` with the second argument attached as the `SEPARATOR` attribute.
 * - PGSQL: `ARRAY_TO_STRING(ARRAY_AGG(CAST(arg0 AS VARCHAR)), arg1)`.
 *
 * The `distinct` flag and order-by items of the original node are carried over to the aggregate.
 *
 * @param query the aggregate node; the first two arguments are the value and the separator
 * @param dbType the target database
 * @return the rewritten expression paired with the original node's alias
 * @throws TypeCastException for any other database
 */
fun visitFunctionArrayAgg(query: QueryAggFunction, dbType: DB): QueryExpr {
    val rewritten = when (dbType) {
        DB.MYSQL -> {
            val value = query.args[0]
            val separator = query.args[1]
            QueryAggFunction(
                "GROUP_CONCAT",
                listOf(value),
                attributes = mapOf("SEPARATOR" to separator),
                distinct = query.distinct,
                orderBy = query.orderBy
            )
        }
        DB.PGSQL -> {
            val collected = QueryAggFunction(
                "ARRAY_AGG",
                listOf(cast(query.args[0], "VARCHAR")),
                distinct = query.distinct,
                orderBy = query.orderBy
            )
            QueryExprFunction("ARRAY_TO_STRING", listOf(collected, query.args[1]))
        }
        // TODO
        else -> throw TypeCastException("暂不支持该数据库使用此函数")
    }
    return QueryExpr(getQueryExpr(rewritten, dbType).expr, query.alias)
}

/**
 * Flattens a chain of nested [QueryJson] accesses into a MySQL JSON path and its root operand.
 *
 * Each access contributes one path segment: a `Number` key becomes `[n]` and a `String` key
 * becomes `.key`. Segments are emitted from the innermost access outwards and the path is
 * prefixed with `$`, e.g. accesses `"a"`, `0`, `"b"` produce `$.a[0].b`.
 *
 * @param queryJson the outermost JSON access node
 * @return a pair of (path string, innermost non-[QueryJson] operand)
 * @throws TypeCastException when a key is neither a `Number` nor a `String`
 */
fun visitMysqlQueryJson(queryJson: QueryJson): Pair<String, Query> {
    // Renders a single key as a MySQL JSON path segment.
    fun transMysqlJson(value: Any): String {
        return when (value) {
            is Number -> "[$value]"
            is String -> ".$value"
            else -> throw TypeCastException("取Json值时,表达式右侧只支持String或Int")
        }
    }

    // Collects the path of every access nested inside `node` (excluding `node`'s own key)
    // together with the operand at the bottom of the chain.
    fun visit(node: QueryJson): Pair<String, Query> {
        val inner = node.query

        if (inner is QueryJson) {
            val (innerPath, root) = visit(inner)
            return Pair(innerPath + transMysqlJson(inner.value), root)
        }
        return Pair("", inner)
    }

    val (nestedPath, root) = visit(queryJson)
    return Pair("$" + nestedPath + transMysqlJson(queryJson.value), root)
}

/**
 * Converts a sub-query predicate node into a [SqlSubQueryPredicateExpr] over the sub-query's
 * select statement and the node's predicate.
 *
 * @param query the sub-query predicate node
 * @return the expression paired with the node's alias
 */
fun visitQuerySubQueryPredicate(query: QuerySubQueryPredicate): QueryExpr {
    val subSelect = SqlSelectQueryExpr(query.query.getSelect())
    return QueryExpr(SqlSubQueryPredicateExpr(subSelect, query.predicate), query.alias)
}

/**
 * Converts a plain Kotlin value into the matching literal [SqlExpr].
 *
 * Supported inputs: `null`, [Number], [Date], [String], [Boolean], [List] (converted
 * element-wise) and [SelectQuery] (wrapped as a sub-query expression).
 *
 * @param value the value to convert
 * @return the literal expression for [value]
 * @throws TypeCastException when [value] is of any other type
 */
fun <T> getExpr(value: T): SqlExpr {
    return when (value) {
        null -> SqlNullExpr()

        is Number -> SqlNumberExpr(value)

        is Date -> SqlDateExpr(value)

        is String -> SqlCharExpr(value)

        is Boolean -> SqlBooleanExpr(value)

        // Lists are converted recursively, element by element.
        is List<*> -> SqlListExpr(value.map { getExpr(it) })

        is SelectQuery -> SqlSelectQueryExpr(value.getSelect())

        else -> throw TypeCastException("表达式中存在不合法的数据类型(表达式中参数支持Number,String,Boolean,Data,List,Query表达式类型以及空值null和子查询)")
    }
}

/**
 * Rejects operations that are not available on the analytical databases.
 *
 * @param dbType the target database
 * @throws SQLException when [dbType] is ClickHouse or Hive
 */
fun checkOLAP(dbType: DB) {
    val analytical = dbType == DB.CLICKHOUSE || dbType == DB.HIVE
    if (analytical) {
        throw SQLException("分析型数据库不支持此操作")
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy