// com.uber.engsec.dp.sql.ast.Transformer.scala Maven / Gradle / Ivy
// Go to download
// Show more of this group Show more artifacts with this name
// Show all versions of sql-differential-privacy Show documentation
// Show all versions of sql-differential-privacy Show documentation
// Differential privacy for SQL queries
// The newest version!
/*
* Copyright (c) 2017 Uber Technologies, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
package com.uber.engsec.dp.sql.ast
import com.facebook.presto.sql.tree.{AliasedRelation, AllColumns, ArithmeticBinaryExpression, ArithmeticUnaryExpression, AtTimeZone, BetweenPredicate, Cast, CoalesceExpression, ComparisonExpression, CurrentTime, DereferenceExpression, ExistsPredicate, Expression, Extract, FunctionCall, InListExpression, InPredicate, IsNotNullPredicate, IsNullPredicate, JoinOn, LikePredicate, Literal, LogicalBinaryExpression, LongLiteral, NotExpression, NullIfExpression, QualifiedNameReference, Query, QuerySpecification, Row, SearchedCaseExpression, SimpleCaseExpression, SingleColumn, SubqueryExpression, Table, TableSubquery, WhenClause, Except => PrestoExcept, Join => PrestoJoin, Node => PrestoNode, SelectItem => PrestoSelectItem, Union => PrestoUnion}
import com.uber.engsec.dp.analysis.name_resolution.{NameResolution, NameResolutionAnalysis, ReferenceInfo}
import com.uber.engsec.dp.exception._
import com.uber.engsec.dp.schema.{DatabaseModel, Schema}
import com.uber.engsec.dp.sql.dataflow_graph.reference.{ColumnReference, Function, Reference, UnstructuredReference}
import com.uber.engsec.dp.sql.dataflow_graph.relation._
import com.uber.engsec.dp.sql.dataflow_graph.{Node => DFGNode}
import com.uber.engsec.dp.sql.{AbstractAnalysis, TreePrinter}
import com.uber.engsec.dp.util.IdentityHashMap
import scala.collection.JavaConverters._
import scala.collection.mutable
/** Transforms a parsed AST (Presto tree) into a dataflow graph.
  *
  * Each call to [[convertToDataflowGraph]] clears the instance's working state: the name-resolution results, the
  * memo of already-converted nodes and the schema inferred from the query. An instance can therefore be reused
  * for several queries in sequence. That state lives in unsynchronized mutable maps, so a single instance should
  * not be used by several threads at once.
  */
class Transformer {
  /** Name-resolution results, keyed (by identity) by the Presto node they resolve: column references, aliased
    * tables, etc.
    */
  private val prestoReferences: IdentityHashMap[PrestoNode, ReferenceInfo] = IdentityHashMap.empty

  /** Memo of converted nodes, so a Presto node reached more than once maps to the same dataflow graph node. */
  private val prestoToDFGNode: IdentityHashMap[PrestoNode, DFGNode] = IdentityHashMap.empty

  /** Column names per normalized table name, inferred from how the current query references each table. */
  private val inferredSchemaForTables: mutable.HashMap[String, mutable.Set[String]] with mutable.MultiMap[String, String] = new mutable.HashMap[String, mutable.Set[String]] with mutable.MultiMap[String, String]

  /** Converts a parsed query into a dataflow graph.
    *
    * Runs name resolution over `statement` and infers a per-table schema from the resolved column references. It
    * then validates the name-resolution results and transforms the tree recursively.
    *
    * @param statement root of the Presto tree; must be a [[Query]] node
    * @return the dataflow graph node corresponding to `statement`
    * @throws IllegalArgumentException if `statement` is not a [[Query]]
    */
  def convertToDataflowGraph(statement: PrestoNode): DFGNode = {
    if (!statement.isInstanceOf[Query])
      throw new IllegalArgumentException("convertToDataflowGraph can only be called on an AST (Presto) tree (Query node), found type " + statement.getClass().getSimpleName)

    val nameResolutionResults: NameResolution = new NameResolutionAnalysis().run(statement)

    // Discard any state left over from a previous query.
    prestoToDFGNode.clear()
    prestoReferences.clear()
    inferredSchemaForTables.clear()

    // Pre-processing step: infer the schema for each table from the query structure, using the results of name
    // resolution analysis. The inferred schema is merged with the config schema during tree transformation,
    // subject to the schemaMode flag.
    nameResolutionResults.getColumnRefs.foreach { case (node, refInfo) =>
      prestoReferences += (node -> refInfo)

      // Prefer the explicitly recorded inner relation; otherwise use the referenced relation only if it is unique.
      val targetRelation =
        if (refInfo.innerRelation.isDefined)
          refInfo.innerRelation
        else if (refInfo.ref.isUnique)
          Some(refInfo.ref.getOnly)
        else
          None

      targetRelation match {
        case Some(table: Table) =>
          val tableName = DatabaseModel.normalizeTableName(table.getName.toString)
          node match {
            case q: QualifiedNameReference => inferredSchemaForTables.addBinding(tableName, q.getName.toString)
            case d: DereferenceExpression => inferredSchemaForTables.addBinding(tableName, d.getFieldName)
            case _ => ()
          }
        case _ => ()
      }
    }

    if (AbstractAnalysis.DEBUG) {
      // Print name resolution results
      TreePrinter.printTreePresto(statement, prestoReferences)
      // Print inferred schema
      println("Inferred schema: ")
      println(inferredSchemaForTables.toString)
    }

    // Validate results (i.e., make sure we've resolved all the columns) before starting transformation.
    nameResolutionResults.validate()

    // Transform root node (recursively)
    val result: DFGNode = toGraphNode(statement)

    if (AbstractAnalysis.DEBUG) {
      TreePrinter.printTree(result)
    }

    result
  }

  /** Returns `node` unchanged if it is already a [[Reference]].
    *
    * Otherwise it wraps `node` in an [[UnstructuredReference]] named "wrappedRelation". The wrapper carries the
    * same Presto source node as the wrapped node.
    */
  def wrapInReference(node: DFGNode): Reference = node match {
    case r : Reference => r // it's already a Reference type, no need to do anything
    case _ => UnstructuredReference("wrappedRelation", List(node))(node.prestoSource)
  }

  /** Converts one item of a SELECT list into dataflow graph select items.
    *
    * A single column yields one item, named by its alias or, if there is none, by its implicit column name.
    * `SELECT *` yields one item per column of the referenced relation.
    *
    * @throws AmbiguousWildcardException if the relation targeted by `SELECT *` has no known columns
    * @throws RuntimeException for unsupported select item types
    */
  def processSelectItem(prestoNode: PrestoSelectItem): Seq[SelectItem] = prestoNode match {
    case singleColumn: SingleColumn =>
      // Select column by name: Create a SelectItem node pointing to the target column
      val expr = singleColumn.getExpression
      val alias =
        if (singleColumn.getAlias.isPresent)
          singleColumn.getAlias.get
        else
          DatabaseModel.getImplicitColumnName(expr)
      val targetRelation = wrapInReference(toGraphNode(expr))
      List(SelectItem(alias, targetRelation, Some(singleColumn)))

    case all : AllColumns =>
      // SELECT * : Create a SelectItem for every column in the target relation.
      val targetRelation = getReferencedRelationNode(all)
      if (targetRelation.numCols == 0) throw new AmbiguousWildcardException("Empty schema for SELECT * from table " + targetRelation.toString)
      targetRelation.columnNames.zipWithIndex.map{ case (colName, idx) => SelectItem(colName, ColumnReference(idx, targetRelation)) }

    case _ => throw new RuntimeException("Unsupported SelectItem type: " + prestoNode.getClass.getSimpleName + " : " + prestoNode.toString)
  }

  /** Converts a single Presto node (and, recursively, its children) into a dataflow graph node.
    *
    * Results are memoized by node identity, so converting the same Presto node twice returns the same graph node.
    *
    * @throws UnknownColumnException if a referenced column cannot be found
    * @throws AmbiguousColumnReference if a referenced column matches more than one candidate
    * @throws InvalidQueryException for GROUP BY ordinals outside the select list, or UNION/EXCEPT schema mismatches
    * @throws RuntimeException for unsupported Presto node types
    */
  private def toGraphNode(prestoNode: PrestoNode): DFGNode = {
    if (prestoToDFGNode.contains(prestoNode)) {
      // If we've already processed this node, return it immediately.
      prestoToDFGNode(prestoNode)
    } else {
      // Otherwise convert the Presto node into a dataflow graph node and store the result before returning.
      implicit val prestoSource = Some(prestoNode) // automatically assign current node as prestoSource for constructed dataflow graph nodes

      val converted: DFGNode = prestoNode match {
        case query: Query => toGraphNode(query.getQueryBody)

        case spec: QuerySpecification =>
          val items = spec.getSelect.getSelectItems.asScala.map { processSelectItem(_) }.toList.flatten
          val where = if (!spec.getWhere.isPresent) None else Some(toGraphNode(spec.getWhere.get).asInstanceOf[Reference])
          // GROUP BY elements are converted into indexes into the Select relation's items.
          val groupBy = if (!spec.getGroupBy.isPresent) Nil else
            spec.getGroupBy.get.getGroupingElements.asScala.flatMap {
              _.enumerateGroupingSets.asScala.flatMap {
                _.asScala.map {
                  case l: LongLiteral =>
                    // Ordinal reference (e.g. GROUP BY 1). SQL ordinals start at 1, dataflow graph indices at 0.
                    val ordinal = l.getValue
                    if (ordinal < 1 || ordinal > items.size)
                      throw new InvalidQueryException("GROUP BY position " + ordinal + " is not in select list")
                    ordinal.toInt - 1

                  case q: QualifiedNameReference =>
                    val colName = q.getName.toString
                    val matchedItems = items.filter { _.as == colName }
                    val groupedItem = matchedItems match {
                      case Nil => throw new UnknownColumnException(colName)
                      case List(x) => x
                      case _ => throw new AmbiguousColumnReference(colName)
                    }
                    items.indexOf(groupedItem)

                  case d: DereferenceExpression =>
                    // We call toGraphNode to benefit from our handling of DereferenceExpression, which performs the
                    // same logic we would otherwise do here (i.e., normalizing col index in joins, column name
                    // uniqueness checks, etc.)
                    val targetRef = toGraphNode(d).asInstanceOf[ColumnReference]
                    // This gives us the column index into the dereferenced relation, but we need to convert that
                    // into the column of the Select relation we're about to create.
                    val groupedIndex = items.map{ _.ref }.indexWhere {
                      case ColumnReference(colIndex, of) => colIndex == targetRef.colIndex && of == targetRef.of
                      case _ => false
                    }
                    // Mirror the unqualified-name case above: a grouping column that does not appear in the select
                    // list cannot be expressed as an index into the Select relation (indexWhere would yield -1).
                    if (groupedIndex < 0) throw new UnknownColumnException(d.toString)
                    groupedIndex

                  case e: Expression => throw new UnsupportedConstructException("Unsupported grouping element type: " + e.toString)
                }
              }
            }.toList
          Select(items, where, groupBy)

        case func: FunctionCall =>
          val funcName = func.getName.toString
          val args =
            if (funcName == "count") {
              // Count has special semantics for its arguments that we need to handle separately from other functions.
              val arg = if (func.getArguments.isEmpty) None else Some(func.getArguments.get(0))
              if (arg.isDefined && (arg.get.isInstanceOf[QualifiedNameReference] || arg.get.isInstanceOf[DereferenceExpression])) {
                // A count of a specific column. These cases include COUNT(column_name) and COUNT(table.column_name).
                // Process as a regular column reference
                List(wrapInReference(toGraphNode(arg.get)))
              } else {
                // Otherwise, it's COUNT(*), COUNT(1), etc., i.e., a use of COUNT() that does not reference a
                // specific column. Create an unstructured reference node named 'countAll' that points to the
                // target relation in order to preserve the data dependence.
                val targetRelation = getReferencedRelationNode(func)
                List(UnstructuredReference("countAll", targetRelation))
              }
            } else {
              // Regular function. Create dataflow graph nodes for each of the function's arguments.
              func.getArguments.asScala.zipWithIndex.map { case (arg, idx) =>
                if (DatabaseModel.isFunctionArgumentLiteral(funcName, idx))
                  // Skip processing of nodes that we know by context are literals to avoid incorrectly processing them
                  // as column references even when they aren't escaped (this would otherwise result in UnknownColumnException)
                  UnstructuredReference(arg.toString)
                else
                  wrapInReference(toGraphNode(arg))
              }.toList
            }
          Function(funcName, args)

        case ref: QualifiedNameReference =>
          val name = ref.getName.toString
          if (DatabaseModel.isBuiltInFunction(name))
            Function(name)
          else {
            val targetRelation = getReferencedRelationNode(ref)
            val targetIndex =
              targetRelation.getColumnIndexes(name) match {
                case Nil => throw new UnknownColumnException(name)
                case x :: Nil => x
                case _ => throw new AmbiguousColumnReference(name)
              }
            ColumnReference(targetIndex, targetRelation)
          }

        case table: Table =>
          if (prestoReferences.contains(table)) {
            // This is an aliased table (not a database table), so resolve the alias and return the graph node for it.
            val aliasedRelation = prestoReferences(table).ref.getOnly
            toGraphNode(aliasedRelation)
          } else {
            // It's not an aliased table, create a new DataTable node.
            val tableName = DatabaseModel.normalizeTableName(table.getName.toString)
            if (Schema.getSchemaMapForTable(tableName).isEmpty && Transformer.isStrictMode())
              throw new UndefinedSchemaException(tableName)
            val configSchema = Schema.getSchemaForTable(tableName).map{ _.name }.toList
            // In best-effort mode, columns inferred from the query are appended to the configured schema.
            val effectiveSchema =
              if (Transformer.isBestEffortMode())
                mergeSchemas(configSchema, inferredSchemaForTables.getOrElse(tableName, Set()).toSet)
              else
                configSchema
            DataTable(tableName, effectiveSchema.toIndexedSeq)
          }

        case join: PrestoJoin =>
          val leftRelation = toGraphNode(join.getLeft).asInstanceOf[Relation]
          val rightRelation = toGraphNode(join.getRight).asInstanceOf[Relation]
          val joinType = JoinType.parse(join.getType.toString)
          val joinCondition =
            if (!join.getCriteria.isPresent) None else {
              join.getCriteria.get match {
                case jo : JoinOn => Some(toGraphNode(jo.getExpression).asInstanceOf[Reference])
                case _ => throw new UnsupportedOperationException("Unsupported join criteria type: " + join.getCriteria.get.getClass.getSimpleName)
              }
            }
          Join(leftRelation, rightRelation, joinType, joinCondition)

        case deref: DereferenceExpression =>
          val colName = deref.getFieldName
          val targetRelation = getReferencedRelationNode(deref)
          val targetColIndex = targetRelation match {
            // If this dereferences into a JOIN-created relation, calculate the correct index by figuring out which
            // inner relation's column is being referenced.
            case join : Join =>
              val innerRelationNode = toGraphNode(prestoReferences(deref).innerRelation.get).asInstanceOf[Relation]
              join.getColumnIndexForInnerRelation(colName, innerRelationNode) match {
                case Some(idx) => idx
                case None => throw new JoinException("No column " + colName + " found in inner relation " + innerRelationNode.toString + ", referenced from " + deref.toString)
              }
            // For all other relation node types, ask the graph node for the named column's index.
            case _ => targetRelation.getColumnIndexes(colName) match {
              case Nil => throw new UnknownColumnException(colName)
              case x :: Nil => x
              case _ => throw new AmbiguousColumnReference(colName)
            }
          }
          ColumnReference(targetColIndex, targetRelation)

        case union: PrestoUnion =>
          val children = union.getRelations.asScala.map{ toGraphNode(_).asInstanceOf[Relation] }.toList
          // In SQL, all relations in a UNION must have identical schema
          if (!children.forall(_.columnNames == children.head.columnNames)) throw new InvalidQueryException("Schema mismatch in UNION")
          Union(children)

        case except: PrestoExcept =>
          val (left, right) = (toGraphNode(except.getLeft).asInstanceOf[Relation], toGraphNode(except.getRight).asInstanceOf[Relation])
          // In SQL, all relations in an EXCEPT must have identical schema
          if (left.columnNames != right.columnNames) throw new InvalidQueryException("Schema mismatch in EXCEPT")
          Except(left, right)

        // We represent comparison expressions as Function nodes
        case c: ComparisonExpression => Function(c.getType.toString, wrapInReference(toGraphNode(c.getLeft)), wrapInReference(toGraphNode(c.getRight)))

        // Pass-through nodes (i.e., nodes for which we follow the data dependence chain but aren't represented explicitly in dataflow graphs)
        case t: TableSubquery => toGraphNode(t.getQuery)
        case a: AliasedRelation => toGraphNode(a.getRelation)
        case s: SubqueryExpression => toGraphNode(s.getQuery)
        case c: Cast => toGraphNode(c.getExpression)

        // Nodes represented as UnstructuredReference
        case l: Literal => UnstructuredReference(s"literal(${l.toString})")
        case a: ArithmeticBinaryExpression => UnstructuredReference("arithmeticBinary", toGraphNodes(a.getLeft, a.getRight))
        case t: CurrentTime => UnstructuredReference("currentTime")
        case n: NullIfExpression => UnstructuredReference("nullIf", toGraphNodes(n.getFirst, n.getSecond))
        case a: ArithmeticUnaryExpression => UnstructuredReference("arithmeticUnary", toGraphNodes(a.getValue))
        case c: CoalesceExpression => UnstructuredReference("coalesce", toGraphNodes(c.getOperands.asScala: _*))
        case a: AtTimeZone => UnstructuredReference("atTimeZone", toGraphNodes(a.getValue, a.getTimeZone))
        case i: IsNullPredicate => UnstructuredReference("isNull", toGraphNodes(i.getValue))
        case b: BetweenPredicate => UnstructuredReference("between", toGraphNodes(b.getValue, b.getMax, b.getMin))
        // The pattern can itself reference columns (e.g. `a LIKE b`), so it is part of the data dependence.
        // The escape expression may be absent (null); toGraphNodes skips nulls.
        case l: LikePredicate => UnstructuredReference("like", toGraphNodes(l.getValue, l.getPattern, l.getEscape))
        case n: NotExpression => UnstructuredReference("not", toGraphNodes(n.getValue))
        case i: InListExpression => UnstructuredReference("in", toGraphNodes(i.getValues.asScala: _*))
        case i: InPredicate => UnstructuredReference("in", toGraphNodes(i.getValue, i.getValueList))
        case i: IsNotNullPredicate => UnstructuredReference("notNull", toGraphNodes(i.getValue))
        case l: LogicalBinaryExpression => UnstructuredReference(s"logicalBinary-${l.getType}", toGraphNodes(l.getLeft, l.getRight))
        case w: WhenClause => UnstructuredReference("whenClause", toGraphNodes(w.getOperand, w.getResult))
        case s: SearchedCaseExpression => UnstructuredReference("case", toGraphNodes(s.getWhenClauses.asScala ++ List(s.getDefaultValue.orElse(null)): _*))
        case e: Extract => UnstructuredReference("extract", toGraphNodes(e.getExpression))
        case r: Row => UnstructuredReference("row", toGraphNodes(r.getItems.asScala: _*))
        case s: SimpleCaseExpression => UnstructuredReference("simpleCase", toGraphNodes(s.getWhenClauses.asScala ++ List(s.getDefaultValue.orElse(null)) ++ List(s.getOperand): _* ))
        case e: ExistsPredicate => UnstructuredReference("exists", toGraphNodes(e.getSubquery))

        case _ => throw new RuntimeException("Unsupported presto type: " + prestoNode.getClass.getSimpleName + " : " + prestoNode.toString)
      }

      prestoToDFGNode += (prestoNode -> converted)
      converted
    }
  }

  /** Converts a variable number of Presto nodes into a list of dataflow graph nodes, skipping `null` entries.
    *
    * This is a common pattern when processing functions and unstructured references. Absent optional children
    * arrive here as `null` and are dropped.
    */
  private def toGraphNodes(prestoNodes: PrestoNode*): List[DFGNode] = prestoNodes.filter{ _ != null }.map{ toGraphNode }.toList

  /** Returns the relation that `node` refers to, according to name resolution.
    *
    * When name resolution reports two candidate relations, the referenced column name is looked up in each
    * candidate's schema to pick one. Only [[QualifiedNameReference]] nodes are supported in that case.
    *
    * @throws AmbiguousColumnReference if both candidate relations contain the column
    * @throws UnknownColumnException if neither candidate relation contains the column
    */
  private def getReferencedRelationNode(node: PrestoNode): Relation = {
    val referenceInfo = prestoReferences(node).ref
    val result =
      if (!referenceInfo.hasTwoRelations)
        toGraphNode(referenceInfo.getOnly)
      else {
        // This node may refer to either of multiple relations. We use the schema information to figure
        // out which one, throwing AmbiguousColumnReference if both relations have the target column, and
        // UnknownColumnException if neither has it.
        if (!node.isInstanceOf[QualifiedNameReference]) throw new RuntimeException("Unsupported node type for ambiguous reference: " + node.getClass.getSimpleName)
        val referencedColumnName = node.asInstanceOf[QualifiedNameReference].getName.toString
        val firstRelation = toGraphNode(referenceInfo.first).asInstanceOf[Relation]
        val secondRelation = toGraphNode(referenceInfo.second.get).asInstanceOf[Relation]
        val foundInFirst = firstRelation.getColumnIndexes(referencedColumnName).nonEmpty
        val foundInSecond = secondRelation.getColumnIndexes(referencedColumnName).nonEmpty
        (foundInFirst, foundInSecond) match {
          case (true, true) => throw new AmbiguousColumnReference(referencedColumnName)
          case (false, false) => throw new UnknownColumnException(referencedColumnName)
          case (true, _) => firstRelation
          case (_, true) => secondRelation
        }
      }
    result.asInstanceOf[Relation]
  }

  /** Adds the schema elements from schema2 which are not already present in schema1 to the end of schema1.
    *
    * The order of schema1 is preserved. The appended elements follow the iteration order of schema2.
    */
  def mergeSchemas(schema1: List[String], schema2: Set[String]): List[String] = {
    schema1 ++ (schema2 -- schema1.toSet).toList
  }

  /** Debug helper: renders a dataflow graph node as `ClassName[toString]`, or "[null]" for a null node. */
  def nodeToStr(node: DFGNode): String = {
    if (node == null) "[null]"
    else node.getClass.getSimpleName + "[" + node.toString + "]"
  }
}
/** Global settings shared by every [[Transformer]] instance. */
object Transformer {

  /** The policy a [[Transformer]] follows when deciding which columns a database table has. */
  object SCHEMA_MODE extends Enumeration {
    /** Strict schema mode: only predefined schema information is used.
      *
      * Queries that reference a table with an undefined schema, or that conflict with the known schema, are
      * treated as invalid and raise an exception. This mode is useful for guaranteeing correct semantics of
      * dataflow graphs (including validation of the query) when the complete schema is known.
      */
    val STRICT: Value = Value("Strict")

    /** Best-effort mode: consult schema information when it is available, and infer the schema from the query
      * structure everywhere else.
      *
      * Schema information resolves ambiguous semantics, such as wildcard selection and left/right resolution when
      * selecting columns of joined tables. This mode is useful if the query is known to be valid but the schema
      * may be out of date. The schema is only consulted as a hint when necessary, and missing schema elements are
      * filled in, where possible, by inspecting the query.
      *
      * IMPORTANT: this mode assumes the query is valid, so certain types of errors can be ruled out. This can
      * significantly improve the transformation success rate when schema information is not available. However,
      * this mode may not preserve semantics if the query is invalid (in which case we may incorrectly generate a
      * valid graph). It also may not preserve semantics if the query uses wildcard selection (in which case we may
      * generate a query with incorrect or incomplete column names).
      *
      * For example, consider the query:
      *
      *   SELECT z FROM known_table
      *
      * Suppose the config schema defines columns [a, b, c] in table 'known_table'. This query suggests that
      * column z must also exist in 'known_table', so it is added automatically. In STRICT mode this query would
      * raise an UnknownColumnException. This has no effect for typical analyses, which care only about columns
      * referenced by the query (i.e., column "z"). However, it can cause subtle issues with wildcard selection.
      * For example:
      *
      *   WITH t1 AS (SELECT z FROM known_table) SELECT * FROM known_table
      *
      * will produce a dataflow graph with SelectItems "a", "b", "c", "z". The first 3 columns come from the
      * config schema, and the last one is added automatically based on inspection of the WITH clause. This is
      * the best guess for the query's semantics given the available information, but it may be incorrect if the
      * schema is wrong or incomplete. As a general rule, wildcard handling is only guaranteed to be correct in
      * strict mode. In best-effort mode, wildcard selection fails only if the combined (config and/or inferred)
      * schema is completely empty.
      *
      * As another example, consider the query:
      *
      *   SELECT col1 from known_table JOIN unknown_table
      *
      * If the schema for table 'known_table' contains a column named col1, the dataflow path for the selected
      * column is assumed to flow from that table, even without any schema information for 'unknown_table'. In
      * reality this may not be true: if 'unknown_table' also contains a column named 'col1', this query would
      * fail at runtime with an ambiguous column error.
      */
    val BEST_EFFORT: Value = Value("Best effort")
  }

  /** The schema mode currently in effect for all transformers. Defaults to [[SCHEMA_MODE.BEST_EFFORT]]. */
  var schemaMode: SCHEMA_MODE.Value = SCHEMA_MODE.BEST_EFFORT

  /** @return true when [[schemaMode]] is [[SCHEMA_MODE.STRICT]] */
  def isStrictMode(): Boolean = SCHEMA_MODE.STRICT == schemaMode

  /** @return true when [[schemaMode]] is [[SCHEMA_MODE.BEST_EFFORT]] */
  def isBestEffortMode(): Boolean = SCHEMA_MODE.BEST_EFFORT == schemaMode
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy