All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.partiql.lang.ast.passes.StatementRedactor.kt Maven / Gradle / Ivy

There is a newer version: 1.0.0-perf.1
Show newest version
package org.partiql.lang.ast.passes

import com.amazon.ion.system.IonSystemBuilder
import com.amazon.ionelement.api.StringElement
import org.partiql.lang.ast.SourceLocationMeta
import org.partiql.lang.ast.sourceLocation
import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.syntax.PartiQLParserBuilder

/**
 * Function type used to decide which arguments of a user-defined function (UDF) call are redacted.
 *
 * The lambda receives the call's full argument list and returns the arguments that should be redacted.
 * An implementation usually has two parts:
 *     1. Inspect the arguments that determine safety (e.g. check them against the safe field names, see
 *        [skipRedaction]).
 *     2. Return the arguments that should be redacted.
 *
 * For example, for a function with a fixed number of arguments, func(a, b, c, d), an implementation can check
 * whether `a` and `b` are safe field names and, if so, return `c` and `d` so that they are redacted.
 */
typealias UserDefinedFunctionRedactionLambda = (List<PartiqlAst.Expr>) -> List<PartiqlAst.Expr>

/** Ion system used to build [parser]. */
private val ion = IonSystemBuilder.standard().build()

/** Parser used by the [redact] overload that receives only the statement text. */
private val parser = PartiQLParserBuilder()
    .ionSystem(ion)
    .build()

/** Text written over every redacted span of the statement. */
private const val maskPattern: String = "***(Redacted)"

/** Error message used when an operator node does not have the expected number of operands. */
const val INVALID_NUM_ARGS: String = "Invalid number of args in node"

/** Error message used when a redaction location does not fit inside the input statement text. */
const val INPUT_AST_STATEMENT_MISMATCH: String =
    "Unable to redact the statement. Please check that the input ast is the parsed result of the input statement"

/**
 * Returns true if the value associated with [node] must NOT be redacted.
 *
 * Rules:
 * - If [safeFieldNames] is empty, nothing is skipped and this always returns false.
 * - [PartiqlAst.Expr.Id]: skipped if its name is in [safeFieldNames].
 * - [PartiqlAst.Expr.Lit]: skipped if it is a string literal whose value is in [safeFieldNames];
 *   non-string literals are never skipped.
 * - [PartiqlAst.Expr.Path]: skipped if any [PartiqlAst.PathStep.PathExpr] step has an index that is itself
 *   skipped by these rules.
 * - Any other node type: skipped.
 *
 * NOTE(review): because of the last rule, a non-field left-hand side such as a function call is treated as safe
 * whenever [safeFieldNames] is non-empty, but not when it is empty — confirm this is intended.
 *
 * @param node the expression naming the field, e.g. the left-hand side of a comparison.
 * @param safeFieldNames field names whose values must not be redacted.
 */
fun skipRedaction(node: PartiqlAst.Expr, safeFieldNames: Set<String>): Boolean {
    if (safeFieldNames.isEmpty()) {
        return false
    }

    return when (node) {
        is PartiqlAst.Expr.Id -> node.name.text in safeFieldNames
        is PartiqlAst.Expr.Lit -> {
            val value = node.value
            value is StringElement && value.stringValue in safeFieldNames
        }
        is PartiqlAst.Expr.Path -> node.steps.any { step ->
            step is PartiqlAst.PathStep.PathExpr && skipRedaction(step.index, safeFieldNames)
        }
        else -> true // Skip redaction for other nodes
    }
}

/**
 * Parses the PartiQL [statement] and returns a copy of it in which [PartiqlAst.Expr.Lit]s not associated with one of
 * [providedSafeFieldNames] are replaced with "***(Redacted)".
 *
 * Only `WHERE` clauses (of `SELECT` and DML statements), DML assignments, and `INSERT` values are examined; the rest
 * of the statement text is returned unchanged.
 *
 * [providedSafeFieldNames] is an optional set of fields whose values are not to be redacted. If no set is provided,
 * no field is treated as safe.
 *     For example, given a [providedSafeFieldNames] of setOf("hashkey"),
 *     the query `SELECT * FROM tb WHERE hashkey = 'a' AND attr = 12345` is redacted to:
 *               `SELECT * FROM tb WHERE hashkey = 'a' AND attr = ***(Redacted)`
 * [userDefinedFunctionRedactionConfig] is an optional mapping of UDF names to functions determining which call
 * arguments are to be redacted. For an example, please check StatementRedactorTest.kt.
 *
 * Errors raised while parsing [statement] propagate to the caller.
 *
 * @return the redacted statement text.
 */
fun redact(
    statement: String,
    providedSafeFieldNames: Set<String> = emptySet(),
    userDefinedFunctionRedactionConfig: Map<String, UserDefinedFunctionRedactionLambda> = emptyMap()
): String {
    val ast = parser.parseAstStatement(statement)
    return redact(statement, ast, providedSafeFieldNames, userDefinedFunctionRedactionConfig)
}

/**
 * Returns a copy of the PartiQL [statement] in which [PartiqlAst.Expr.Lit]s not associated with one of
 * [providedSafeFieldNames] are replaced with "***(Redacted)".
 *
 * [partiqlAst] must be the parsed form of [statement]: the source locations stored in its nodes are used to find the
 * text to replace. Only `WHERE` clauses (of `SELECT` and DML statements), DML assignments, and `INSERT` values are
 * examined.
 *
 * [providedSafeFieldNames] is an optional set of fields whose values are not to be redacted. If no set is provided,
 * no field is treated as safe.
 *     For example, given a [providedSafeFieldNames] of setOf("hashkey"),
 *     the query `SELECT * FROM tb WHERE hashkey = 'a' AND attr = 12345` is redacted to:
 *               `SELECT * FROM tb WHERE hashkey = 'a' AND attr = ***(Redacted)`
 * [userDefinedFunctionRedactionConfig] is an optional mapping of UDF names to functions determining which call
 * arguments are to be redacted. For an example, please check StatementRedactorTest.kt.
 *
 * @return the redacted statement text.
 * @throws IllegalArgumentException if a source location in [partiqlAst] does not fit inside [statement]
 * ([INPUT_AST_STATEMENT_MISMATCH]), or if a binary operator node does not have exactly two operands
 * ([INVALID_NUM_ARGS]).
 * @throws IllegalStateException if a node to be redacted carries no source location.
 */
fun redact(
    statement: String,
    partiqlAst: PartiqlAst.Statement,
    providedSafeFieldNames: Set<String> = emptySet(),
    userDefinedFunctionRedactionConfig: Map<String, UserDefinedFunctionRedactionLambda> = emptyMap()
): String {
    val visitor = StatementRedactionVisitor(statement, providedSafeFieldNames, userDefinedFunctionRedactionConfig)
    visitor.walkStatement(partiqlAst)
    return visitor.getRedactedStatement()
}

/**
 * Collects the source locations of the parts of a PartiQL statement that must be redacted, then rewrites the statement
 * text, replacing each collected span with "***(Redacted)".
 *
 * [PartiqlAst.Expr.Lit]s (and the type of `IS <type>` checks) whose associated field is not in [safeFieldNames] are
 * redacted. Function calls that have an entry in [userDefinedFunctionRedactionConfig] only have the arguments returned
 * by their redaction lambda redacted; other calls have all of their arguments redacted.
 *
 * Only `WHERE` clauses of `SELECT` and DML statements, DML assignments, and `INSERT` values are examined.
 *
 * Usage: call [walkStatement] with the AST parsed from [statement], then [getRedactedStatement].
 */
private class StatementRedactionVisitor(
    private val statement: String,
    private val safeFieldNames: Set<String>,
    private val userDefinedFunctionRedactionConfig: Map<String, UserDefinedFunctionRedactionLambda>
) : PartiqlAst.Visitor() {
    /** Source spans collected during the walk; each is replaced by the mask in [getRedactedStatement]. */
    private val sourceLocationMetaForRedaction = arrayListOf<SourceLocationMeta>()

    /**
     * Returns [statement] with every collected source span replaced by the mask.
     *
     * @throws IllegalArgumentException if a collected location does not fit inside [statement], which indicates
     * that the walked AST was not parsed from [statement].
     */
    fun getRedactedStatement(): String {
        val lines = statement.lines()
        // totalCharactersInPreviousLines[i] is the index in [statement] of the first character of (0-based) line i.
        // NOTE(review): this assumes each line terminator is exactly one character, but `lines()` also splits on
        // "\r\n"; confirm how the lexer counts columns for CRLF input.
        val totalCharactersInPreviousLines = IntArray(lines.size)
        for (lineNum in 1 until lines.size) {
            totalCharactersInPreviousLines[lineNum] = totalCharactersInPreviousLines[lineNum - 1] + lines[lineNum - 1].length + 1
        }

        val redactedStatement = StringBuilder(statement)
        // Net length change caused by the replacements applied so far. Spans are processed in source order, so
        // every later span is shifted by exactly this amount.
        var offset = 0
        sourceLocationMetaForRedaction.sortWith(compareBy<SourceLocationMeta> { it.lineNum }.thenBy { it.charOffset })

        sourceLocationMetaForRedaction.forEach { location ->
            val length = location.length.toInt()
            val lineNum = location.lineNum.toInt()
            require(lineNum in 1..totalCharactersInPreviousLines.size) {
                "$INPUT_AST_STATEMENT_MISMATCH, line number: $lineNum"
            }
            // charOffset is treated as 1-based within its line.
            val start = totalCharactersInPreviousLines[lineNum - 1] + location.charOffset.toInt() - 1 + offset
            require(
                start >= 0 && length >= 0 && start < redactedStatement.length &&
                    start <= redactedStatement.length - length
            ) { INPUT_AST_STATEMENT_MISMATCH }
            redactedStatement.replace(start, start + length, maskPattern)
            offset += maskPattern.length - length
        }
        return redactedStatement.toString()
    }

    /** Redacts the `WHERE` clause of a `SELECT`, if present. */
    override fun visitExprSelect(node: PartiqlAst.Expr.Select) {
        node.where?.let { redactExpr(it) }
    }

    /** Redacts the `WHERE` clause of a DML statement, if present. */
    override fun visitStatementDml(node: PartiqlAst.Statement.Dml) {
        node.where?.let { redactExpr(it) }
    }

    /** Redacts the assigned value unless the assignment target is a safe field. */
    override fun visitAssignment(node: PartiqlAst.Assignment) {
        if (!skipRedaction(node.target, safeFieldNames)) {
            redactExpr(node.value)
        }
    }

    /** Redacts the value of `INSERT ... VALUE`; top-level struct keys may mark their values as safe. */
    override fun visitDmlOpInsertValue(node: PartiqlAst.DmlOp.InsertValue) {
        when (val value = node.value) {
            is PartiqlAst.Expr.Struct -> redactStructInInsertValueOp(value)
            else -> redactExpr(value)
        }
    }

    /** Redacts the values of `INSERT`; struct elements of a bag are handled like `INSERT ... VALUE` structs. */
    override fun visitDmlOpInsert(node: PartiqlAst.DmlOp.Insert) {
        when (val values = node.values) {
            is PartiqlAst.Expr.Bag -> redactBagInInsertOpValues(values)
            else -> redactExpr(values)
        }
    }

    /** Dispatches [node] to the redaction rule for its type; unsupported node types are left untouched. */
    private fun redactExpr(node: PartiqlAst.Expr) {
        when (node) {
            is PartiqlAst.Expr.Lit -> redactLiteral(node)
            // Once `bag`, `list`, and `sexp` are modeled as described in
            // https://github.com/partiql/partiql-lang-kotlin/issues/239, these branches can be merged.
            is PartiqlAst.Expr.List -> redactEach(node.values)
            is PartiqlAst.Expr.Sexp -> redactEach(node.values)
            is PartiqlAst.Expr.Bag -> redactEach(node.values)
            is PartiqlAst.Expr.Struct -> redactStruct(node)
            is PartiqlAst.Expr.IsType -> redactTypes(node)
            // Operators and calls; every other node type reaches the no-op branch of redactNAry.
            else -> redactNAry(node)
        }
    }

    /** Redacts every expression in [exprs]. */
    private fun redactEach(exprs: List<PartiqlAst.Expr>) {
        exprs.forEach { redactExpr(it) }
    }

    /**
     * Redacts the right-hand operand of a binary comparison unless the left-hand operand is a safe field.
     *
     * @throws IllegalArgumentException if [args] does not contain exactly two operands.
     */
    private fun redactComparisonOp(args: List<PartiqlAst.Expr>) {
        require(args.size == 2) { INVALID_NUM_ARGS }
        if (!skipRedaction(args[0], safeFieldNames)) {
            redactExpr(args[1])
        }
    }

    /**
     * Redacts both operands of a binary arithmetic or concatenation operator.
     *
     * @throws IllegalArgumentException if [args] does not contain exactly two operands.
     */
    private fun redactBinaryOperands(args: List<PartiqlAst.Expr>) {
        require(args.size == 2) { INVALID_NUM_ARGS }
        redactExpr(args[0])
        redactExpr(args[1])
    }

    // Once NAry nodes are modeled better in PIG (https://github.com/partiql/partiql-lang-kotlin/issues/241), this code
    // can be refactored.
    // TODO: other NAry ops that are not modeled (LIKE, INTERSECT, INTERSECT_ALL, EXCEPT, EXCEPT_ALL, UNION, UNION_ALL)
    /** Applies the operator-specific redaction rule to an operator or call node; any other node is ignored. */
    private fun redactNAry(node: PartiqlAst.Expr) {
        when (node) {
            // Logical Ops
            is PartiqlAst.Expr.And -> redactEach(node.operands)
            is PartiqlAst.Expr.Or -> redactEach(node.operands)
            is PartiqlAst.Expr.Not -> redactExpr(node.expr)
            // Comparison Ops
            is PartiqlAst.Expr.Eq -> redactComparisonOp(node.operands)
            is PartiqlAst.Expr.Ne -> redactComparisonOp(node.operands)
            is PartiqlAst.Expr.Gt -> redactComparisonOp(node.operands)
            is PartiqlAst.Expr.Gte -> redactComparisonOp(node.operands)
            is PartiqlAst.Expr.Lt -> redactComparisonOp(node.operands)
            is PartiqlAst.Expr.Lte -> redactComparisonOp(node.operands)
            is PartiqlAst.Expr.InCollection -> redactComparisonOp(node.operands)
            // Arithmetic Ops
            is PartiqlAst.Expr.Plus -> redactBinaryOperands(node.operands)
            is PartiqlAst.Expr.Minus -> redactBinaryOperands(node.operands)
            is PartiqlAst.Expr.Pos -> redactExpr(node.expr)
            is PartiqlAst.Expr.Neg -> redactExpr(node.expr)
            is PartiqlAst.Expr.Times -> redactBinaryOperands(node.operands)
            is PartiqlAst.Expr.Divide -> redactBinaryOperands(node.operands)
            is PartiqlAst.Expr.Modulo -> redactBinaryOperands(node.operands)
            is PartiqlAst.Expr.Concat -> redactBinaryOperands(node.operands)
            // BETWEEN: redact both bounds unless the tested value is a safe field.
            is PartiqlAst.Expr.Between -> {
                if (!skipRedaction(node.value, safeFieldNames)) {
                    redactExpr(node.from)
                    redactExpr(node.to)
                }
            }
            // CALL
            is PartiqlAst.Expr.Call -> redactCall(node)
            else -> { /* intentionally blank */ }
        }
    }

    /**
     * Records the source span of [literal] for redaction.
     *
     * @throws IllegalStateException if [literal] has no source location.
     */
    private fun redactLiteral(literal: PartiqlAst.Expr.Lit) {
        val sourceLocation = literal.metas.sourceLocation ?: error("Cannot redact due to missing source location")
        sourceLocationMetaForRedaction.add(sourceLocation)
    }

    /** Redacts literal keys and all values of a struct. */
    private fun redactStruct(struct: PartiqlAst.Expr.Struct) {
        struct.fields.forEach { field ->
            val key = field.first
            if (key is PartiqlAst.Expr.Lit) {
                redactLiteral(key)
            }
            redactExpr(field.second)
        }
    }

    /**
     * Redacts the arguments of a function call: the arguments chosen by the function's entry in
     * [userDefinedFunctionRedactionConfig] if there is one, otherwise all arguments.
     */
    private fun redactCall(node: PartiqlAst.Expr.Call) {
        val redactionLambda = userDefinedFunctionRedactionConfig[node.funcName.text]
        val argsToRedact = redactionLambda?.invoke(node.args) ?: node.args
        redactEach(argsToRedact)
    }

    /**
     * Records the type of an `IS <type>` check for redaction when the checked value is an identifier that is not a
     * safe field.
     *
     * @throws IllegalStateException if the type node has no source location.
     */
    private fun redactTypes(typed: PartiqlAst.Expr.IsType) {
        val value = typed.value
        if (value is PartiqlAst.Expr.Id && !skipRedaction(value, safeFieldNames)) {
            val sourceLocation = typed.type.metas.sourceLocation ?: error("Cannot redact due to missing source location")
            sourceLocationMetaForRedaction.add(sourceLocation)
        }
    }

    /**
     * For [PartiqlAst.DmlOp.Insert], redacts every element of the bag of values. Struct elements follow the rules of
     * [redactStructInInsertValueOp]; other elements are redacted with the general rules.
     * For example, with safe field names `hk` and `rk`, given:
     * INSERT INTO tb <<{ 'hk': 'a', 'rk': 1, 'attr': { 'hk': 'a' }}>>
     * REPLACE INTO tb << { 'dummy1' : 'hashKey', 'dummy2' : 'rangeKey', 'dummyTestAttribute' : '123' } >>
     *
     * Expected:
     * INSERT INTO tb <<{ 'hk': 'a', 'rk': 1, 'attr': { ***(Redacted): ***(Redacted) }}>>
     * REPLACE INTO tb <<{ 'dummy1' : ***(Redacted), 'dummy2' : ***(Redacted), 'dummyTestAttribute' : ***(Redacted) }>>
     */
    private fun redactBagInInsertOpValues(bag: PartiqlAst.Expr.Bag) {
        bag.values.forEach { element ->
            if (element is PartiqlAst.Expr.Struct) {
                redactStructInInsertValueOp(element)
            } else {
                redactExpr(element)
            }
        }
    }

    /**
     * For [PartiqlAst.DmlOp.InsertValue], only the outermost level of struct fields can hold key attributes: a
     * top-level value is kept if its literal key is a safe field, otherwise it is redacted with the general rules
     * (so nested structs have both their keys and values redacted, see [redactStruct]).
     * For example, with safe field names `hk` and `rk`, in { 'hk': 'a', 'rk': 1, 'attr': { 'hk': 'a' }} only the
     * nested `{ 'hk': 'a' }` under 'attr' is redacted.
     */
    private fun redactStructInInsertValueOp(struct: PartiqlAst.Expr.Struct) {
        struct.fields.forEach { field ->
            val key = field.first
            // NOTE(review): values whose key is not a literal (e.g. an identifier) are never redacted here —
            // confirm this is intended, since such values would appear unredacted in the output.
            if (key is PartiqlAst.Expr.Lit && !skipRedaction(key, safeFieldNames)) {
                redactExpr(field.second)
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy