com.uber.engsec.dp.sql.TreePrinter.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of sql-differential-privacy Show documentation
Show all versions of sql-differential-privacy Show documentation
Differential privacy for SQL queries
The newest version!
/*
* Copyright (c) 2017 Uber Technologies, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
package com.uber.engsec.dp.sql
import com.facebook.presto.sql.tree.{AliasedRelation, AllColumns, DereferenceExpression, FunctionCall, GroupBy, QualifiedNameReference, Query, QueryBody, SimpleGroupBy, SingleColumn, SortItem, StringLiteral, Table, With, WithQuery, Expression => PrestoExpression, Join => PrestoJoin, Node => PrestoNode, Select => PrestoSelect}
import com.uber.engsec.dp.dataflow.column.NodeColumnFacts
import com.uber.engsec.dp.sql.ast.ASTFunctions
import com.uber.engsec.dp.sql.dataflow_graph.reference.{ColumnReference, Function, UnstructuredReference}
import com.uber.engsec.dp.sql.dataflow_graph.relation.{Relation => DFGRelation, _}
import com.uber.engsec.dp.sql.dataflow_graph.{Node => DFGNode}
import com.uber.engsec.dp.sql.relational_algebra.{Expression, RelOrExpr, Relation}
import com.uber.engsec.dp.util.IdentityHashMap
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.{Aggregate, Filter, Project, Sort, TableScan, Join => RelJoin}
import org.apache.calcite.rex._
import scala.collection.{MapLike, mutable}
/** Entry points for pretty-printing parse trees together with analysis results. Useful for debugging.
  *
  * Each method builds a `nodePrintInfo` function (node => display string + labeled children) for one tree
  * representation and delegates the actual rendering to an instance of the `TreePrinter` class.
  */
object TreePrinter {

  /** Java-friendly variant of `printTree`: wraps the Java map as a Scala map and prints the dataflow graph
    * without a "current node" marker.
    *
    * @param root      root of the dataflow graph to print
    * @param resultMap analysis results keyed by node
    */
  def printTreeJava(root: DFGNode, resultMap: java.util.Map[DFGNode, _]): Unit = {
    import scala.collection.JavaConverters._
    printTree(root, resultMap.asScala, None)
  }

  /** Prints the given Presto tree.
    *
    * @param root        root of the Presto AST
    * @param resultMap   analysis results keyed by node (defaults to an empty map)
    * @param currentNode node to flag as "current" in the output, if any
    */
  def printTreePresto(root: PrestoNode,
                      resultMap: MapLike[PrestoNode,_,_] = mutable.Map.empty,
                      currentNode: Option[PrestoNode] = None): Unit = {

    def nodePrintInfo(node: PrestoNode): (String, Seq[LabeledNode[PrestoNode]]) = {
      val className = node.getClass.getSimpleName

      // Order matters: specific node types must precede the generic PrestoExpression / QueryBody cases.
      val printStr = node match {
        case t: Table => "Table[" + t.getName + "]"
        case _: With => "With"
        case w: WithQuery => "Name[" + w.getName + "]"
        case j: PrestoJoin => "Join [" + j.getType.toString + "]"
        case s: SortItem => "OrderBy[" + s.toString + "]"
        case s: PrestoSelect => "Select" + (if (s.isDistinct) " [DISTINCT]" else "")
        case q: Query => "Query" + (if (q.getLimit.isPresent) (" [LIMIT " + q.getLimit.get() + "]") else "")
        // Unwrap the Optional prefix; calling toString on the Optional itself would print "Optional[...]".
        case a: AllColumns => if (a.getPrefix.isPresent) (a.getPrefix.get() + ".*") else "*"
        case s: SingleColumn => className + (if (s.getAlias.isPresent) (" [ALIAS: " + s.getAlias.get() + "]") else "")
        case l: StringLiteral => className + "[" + l.toString + "]"
        case q: QualifiedNameReference => "QualifiedNameReference[" + q.getName + "]"
        case d: DereferenceExpression => "DereferenceExpression[" + d.toString + "]"
        case f: FunctionCall => "FunctionCall[" + f.getName + "]"
        case a: AliasedRelation => "Alias[" + a.getAlias + "]"
        case e: PrestoExpression => className + "[" + e.toString + "]"
        case _: QueryBody => className
        case g: GroupBy => "GroupBy" + (if (g.isDistinct) " [DISTINCT]" else "")
        case _: SimpleGroupBy => "SimpleGroupBy"
        case _ => "??? [" + className + "]"
      }

      val namedChildren = ASTFunctions.getChildren(node).map { LabeledNode(None, _) }.toSeq
      (printStr, namedChildren)
    }

    new TreePrinter[PrestoNode]().printTree(root, resultMap, false, nodePrintInfo, currentNode)
  }

  /** Prints the given relational algebra tree.
    *
    * @param root        root relation or expression
    * @param resultMap   analysis results keyed by node (defaults to an empty map)
    * @param currentNode node to flag as "current" in the output, if any
    * @throws RuntimeException if a relational node type other than Project, TableScan, Aggregate, Filter,
    *                          Join or Sort is encountered
    */
  def printRelTree(root: RelOrExpr,
                   resultMap: MapLike[RelOrExpr,_,_] = mutable.Map.empty,
                   currentNode: Option[RelOrExpr] = None): Unit = {
    import scala.collection.JavaConverters._

    // Relation whose row type is used to resolve the names of input references in the expressions that are
    // visited next. The printer visits a relation's expression children before its relation children, so the
    // value set by a relation is still in effect when its own expressions are rendered.
    var currentInputNode: Option[RelNode] = None

    def nodePrintInfo(node: RelOrExpr): (String, Seq[LabeledNode[RelOrExpr]]) = node match {
      case Relation(rel) =>
        val colNames = rel.getRowType.getFieldNames.asScala.toIndexedSeq

        rel match {
          case p: Project =>
            currentInputNode = Some(p.getInput)
            ("Project",
              p.getProjects.asScala.zipWithIndex.map{ case (node, idx) => LabeledNode(Some(s"$idx [as ${colNames(idx)}]"), Expression(node)) }.toList
                ++ List(LabeledNode(Some("input"), Relation(p.getInput))))

          case t: TableScan =>
            ("TableScan[" + t.getTable.getQualifiedName.asScala.mkString(".") + "]", Nil)

          case a: Aggregate =>
            ("Aggregate[" + a.getAggCallList.asScala.mkString(",") + "] groupSets:" + a.groupSets.asScala.map{ _.toList.toString }.mkString(", "),
              List(LabeledNode(Some("input"), Relation(a.getInput))))

          case f: Filter =>
            // The condition's input references index into the filter's input.
            currentInputNode = Some(f.getInput)
            ("Filter",
              LabeledNode(Some("condition"), Expression(f.getCondition))
                :: LabeledNode(Some("input"), Relation(f.getInput()))
                :: Nil)

          case j: RelJoin =>
            // NOTE(review): assumes the join's own row type lists the fields the condition indexes into
            // (left fields followed by right fields); the lookup below is bounds-checked in case it does not.
            currentInputNode = Some(j)
            ("Join[" + j.getJoinType.toString + "]",
              LabeledNode(Some("condition"), Expression(j.getCondition))
                :: LabeledNode(Some("left"), Relation(j.getLeft))
                :: LabeledNode(Some("right"), Relation(j.getRight))
                :: Nil)

          case s: Sort =>
            ("Sort[" + s.collation.getFieldCollations.asScala.map{ col => s"${col.getFieldIndex} ${col.direction.shortString}" }.mkString(", ") + "]",
              List(LabeledNode(Some("input"), Relation(s.getInput))))

          case _ =>
            // ("??? (" + rel.getClass.getSimpleName + ")", rel.getChildExps.asScala.map { x => ("? expr", Expression(x) ) } ++ rel.getInputs.asScala.map { x => ("? rel", Relation(x) ) })
            throw new RuntimeException(s"Unrecognized relational node type: ${rel.getClass.toString}")
        }

      case Expression(ex) =>
        ex match {
          case rexCall: RexCall =>
            // stripPrefix (rather than substring(3)) keeps class names that do not start with "Rex" intact.
            (rexCall.getClass.getSimpleName.stripPrefix("Rex") + "[" + rexCall.op.toString + "]",
              rexCall.operands.asScala.map{ ex => LabeledNode(None, Expression(ex)) })

          case inputRef: RexInputRef =>
            val idx = inputRef.getIndex
            // Resolve the referenced column name when an input relation is known and the index is in range;
            // otherwise print the bare reference instead of failing.
            val fieldName = currentInputNode
              .map(_.getRowType.getFieldNames)
              .filter(names => idx >= 0 && idx < names.size)
              .map(_.get(idx))
            val nameSuffix = fieldName.fold("") { name => " (" + name + ")" }
            ("Ref[input." + idx.toString + "]" + nameSuffix, Nil)

          case rexSlot: RexSlot =>
            (rexSlot.getClass.getSimpleName.stripPrefix("Rex") + ": " + rexSlot.getIndex.toString, Nil)

          case rexLiteral: RexLiteral =>
            // String concatenation renders a null value as "null" instead of throwing.
            ("Literal[" + rexLiteral.getTypeName.toString + "]: " + rexLiteral.getValue, Nil)

          case _ =>
            ("UNIMPLEMENTED: " + ex.getClass.getSimpleName.toString + " (" + ex.getKind.toString + ")" , Nil)
        }
    }

    new TreePrinter[RelOrExpr]().printTree(root, resultMap, false, nodePrintInfo, currentNode)
  }

  /** Prints the given dataflow graph as a tree.
    *
    * @param root        root of the dataflow graph
    * @param resultMap   analysis results keyed by node (defaults to an empty map)
    * @param currentNode node to flag as "current" in the output, if any
    */
  def printTree(root: DFGNode,
                resultMap: MapLike[DFGNode, _, _] = mutable.Map.empty,
                currentNode: Option[DFGNode] = None): Unit = {

    def nodePrintInfo(node: DFGNode): (String, Seq[LabeledNode[DFGNode]]) = {
      val nodeStr = node.getClass.getSimpleName + (if (node.nodeStr.isEmpty) "" else "[" + node.nodeStr + "]")

      // The (label, node) tuples below are turned into LabeledNode values by the implicit conversions in the
      // LabeledNode companion object.
      val namedChildren: Seq[LabeledNode[DFGNode]] = node match {
        case c: ColumnReference => List(("of", c.of))
        case f: Function => f.args.view.zipWithIndex.map{ case (node, idx) => (s"arg${idx}", node) }
        case d: DataTable => Nil
        case s: Select => s.items.view.zipWithIndex.map{ case (node, idx) => (s"${idx} [as ${node.as}]", node.ref) } ++ List(s.where).flatten.map{ ("where", _) } ++ s.groupBy.map{ idx => ("groupBy", s.items(idx).ref) }
        case u: UnstructuredReference => u.children.view.zipWithIndex.map{ case (node, idx) => (s"arg${idx}", node) }
        case j: Join => List(("left", j.left), ("right", j.right)) ++ List(j.condition).flatten.map{ ("condition", _) }
        case u: Union => u.children.view.zipWithIndex.map{ case (node, idx) => (s"r${idx}", node) }
        case e: Except => List(("left", e.left), ("right", e.right))
        // Node types not listed above are printed as leaves rather than aborting the print with a MatchError.
        case _ => Nil
      }

      (nodeStr, namedChildren)
    }

    new TreePrinter[DFGNode]().printTree(root, resultMap, true, nodePrintInfo, currentNode)
  }
}
/** Generic tree renderer that writes a tree of `T` nodes to standard output, one node per line, with an
  * aligned "State" column showing the analysis result for each node.
  *
  * The instance holds mutable bookkeeping state which is reset at the start of every `printTree` call.
  *
  * @tparam T node type of the tree being printed
  */
class TreePrinter[T <: AnyRef] {
  /* Truncate node strings longer than this for easier readability */
  val MAX_NODE_STRING_LENGTH = 64

  /* Width of the tree column, computed by the preprocessing pass before anything is printed */
  var maxIndent = 0

  /* Next node number to assign during preprocessing */
  var nodeNum = 0

  /* Per-node print information collected during preprocessing */
  val nodeMap = new IdentityHashMap[T, PrintedNode[T]]()

  /* Nodes that have already been written out during the current print pass */
  val printedNodes = scala.collection.mutable.Set[T]()

  /** Prints the tree rooted at `root` to standard output.
    *
    * @param root             root node of the tree
    * @param resultMap        analysis results keyed by node
    * @param printNodeNumbers whether to print a column with each node's number
    * @param nodePrintInfo    returns, for a node, its display string and its (optionally labeled) children
    * @param currentNode      node to flag with "<== Current node", if any
    */
  def printTree(root: T,
                resultMap: MapLike[T, _, _] = mutable.Map.empty,
                printNodeNumbers: Boolean,
                nodePrintInfo: T => (String, Seq[LabeledNode[T]]),
                currentNode: Option[T]): Unit = {
    import scala.collection.JavaConverters._

    // Marker appended to node strings that were cut down to MAX_NODE_STRING_LENGTH characters.
    val truncationMarker = " ..."

    def _printTree(node: T, prefix: String, isRoot: Boolean = false): Unit = {
      val nodeInfo = nodeMap(node)

      val schemaStr = node match {
        // Dataflow graph nodes
        case r: DFGRelation => " | Columns: " + r.columnNames.mkString(", ")
        case c: ColumnReference => " | Column: " + c.of.getColumnName(c.colIndex)
        // RelRex nodes
        case Relation(r: RelNode) => " | Columns: " + r.getRowType.getFieldNames.asScala.mkString(", ")
        case _ => ""
      }

      // Only dataflow graph nodes are collapsed to "..." when they are reached a second time.
      val alreadyPrinted = node.isInstanceOf[DFGNode] && printedNodes.contains(node)
      val isTail = alreadyPrinted || nodeInfo.isTail

      // Generate the node print string. Strings of up to MAX_NODE_STRING_LENGTH characters are printed
      // unchanged; only longer ones are cut and marked.
      val nodePrintStr =
        if (alreadyPrinted)
          "..."
        else if (nodeInfo.printStr.length <= MAX_NODE_STRING_LENGTH)
          nodeInfo.printStr
        else
          nodeInfo.printStr.substring(0, MAX_NODE_STRING_LENGTH) + truncationMarker

      def printNodeAndColFacts(node: T, nodeFact: Option[_], colFacts: IndexedSeq[_]): String = {
        val colNames = node match {
          // Dataflow graph nodes
          case r: DFGRelation => r.columnNames
          case c: ColumnReference => List(c.of.getColumnName(c.colIndex))
          // RelRex nodes
          case Relation(r) => r.getRowType.getFieldNames.asScala
          case _ => Nil // All other node types have no formal column names
        }

        // Only print node result if current node is a relation
        val nodeResult =
          if (colNames.isEmpty) None
          else nodeFact.map{ x => s"{ $x }" }

        // Print all column results, one per line
        val colResults =
          if (colNames.isEmpty) {
            assert (colFacts.size == 1)
            List(colFacts.head.toString)
          }
          else {
            assert (colNames.size == colFacts.size)
            val maxColName = colNames.map(_.length).max
            colNames.zip(colFacts).map { x => s"%-${maxColName}s : %s".format(x._1, x._2.toString) }
          }

        val printLines = nodeResult ++ colResults
        val linePrefix = prefix + (if (isTail) " " else "│ ") + (if (nodeInfo.namedChildren.isEmpty) "" else "│")

        // Continuation lines repeat the tree prefix, padded to the tree column width, so that every result
        // line starts in the state column.
        val formattedResult = printLines.mkString(
          s"\n%-${maxIndent}s ${if (printNodeNumbers) " " else ""}".format(linePrefix)
        )
        formattedResult
      }

      // Generate the "state" print string
      val statePrintStr =
        if (alreadyPrinted)
          "..."
        else if (currentNode.contains(node))
          "<== Current node"
        else {
          resultMap.get(node) match {
            // Print column fact analysis results with aligned formatting to help readability
            case Some(NodeColumnFacts(nodeFact, colFacts)) => printNodeAndColFacts(node, Some(nodeFact), colFacts)
            case Some(colFacts: IndexedSeq[_]) => printNodeAndColFacts(node, None, colFacts)
            case Some(x) => x.toString + schemaStr
            case None if resultMap.nonEmpty => "" // if the result map is non-empty but current node isn't found, we must have aborted analysis early; don't print schema for these unprocessed nodes
            case _ => "- " + schemaStr
          }
        }

      // Figure out what should be printed as a prefix for this and next lines
      val prefixStr = prefix + (if (isRoot) "─> " else if (isTail) "└──> " else "├──> ")
      val nextPrefix = prefix + (if (isTail) " " else "│ ")
      val childPrefix = prefix + (if (isTail) " " else "│")

      // Print the node!
      if (printNodeNumbers)
        System.out.format(s"%-${maxIndent}s %-2s %s\n", prefixStr + nodePrintStr, nodeInfo.nodeNum.toString, statePrintStr)
      else
        System.out.format(s"%-${maxIndent}s %s\n", prefixStr + nodePrintStr, statePrintStr)

      printedNodes += node

      if (alreadyPrinted || (nodeInfo.namedChildren.isEmpty && isTail)) {
        // If we already printed this node earlier in the graph, or the node has no children and is a tail node, print a
        // spacer row (this improves readability of the graph)
        System.out.format("%-" + (maxIndent + 11) + "s\n", nextPrefix)
      } else {
        // Otherwise print all the children (recursively)
        nodeInfo.namedChildren.foreach { child =>
          child.label.foreach { label =>
            System.out.println(s"${childPrefix} │")
            System.out.println(s"${childPrefix} ." + label)
            System.out.println(s"${childPrefix} │")
          }
          _printTree(child.node, nextPrefix)
        }
      }
    }

    // Scan through all the nodes to construct the node string and assign unique identifiers to each node. This
    // allows us to calculate the required indentation level before printing the first node.
    def preprocess(node: T, isTail: Boolean, depth: Int): Unit = {
      // Nodes reachable through more than one path are only registered once.
      if (!nodeMap.contains(node)) {
        val (nodeStr, namedChildren) = nodePrintInfo(node)

        // Reserve exactly the width the node string will occupy when printed: its own length, or the
        // truncated length plus the marker when it exceeds the limit.
        val shownLength =
          if (nodeStr.length <= MAX_NODE_STRING_LENGTH) nodeStr.length
          else MAX_NODE_STRING_LENGTH + truncationMarker.length
        maxIndent = math.max(maxIndent, 5*depth + shownLength)

        val lastIdx = namedChildren.length - 1
        nodeMap += (node -> new PrintedNode[T](nodeStr, nodeNum, isTail, namedChildren))
        nodeNum = nodeNum + 1

        namedChildren.zipWithIndex.foreach{ case (child, idx) => preprocess(child.node, idx == lastIdx, depth+1) }
      }
    }

    // Reset all state left over from a previous call on this instance.
    maxIndent = 0
    nodeNum = 0
    nodeMap.clear()
    printedNodes.clear()

    preprocess(root, true, 1)

    // Column headers
    if (printNodeNumbers) {
      System.out.format(s"%-${maxIndent}s %s %s\n", "", "# ", "State")
      System.out.format(s"%-${maxIndent}s %s %s\n", "", "--", "-------")
    }
    else {
      System.out.format(s"%-${maxIndent}s State\n", "")
      System.out.format(s"%-${maxIndent}s -------\n", "")
    }

    _printTree(root, "", true)
  }
}
/** A tree node paired with an optional label describing its role under its parent.
  *
  * @param label optional label for the edge leading to this node
  * @param node  the wrapped node
  */
case class LabeledNode[+T](label: Option[String], node: T)

/** Implicit conversions that let plain nodes and (label, node) tuples be used where a LabeledNode (or a
  * sequence of them) is expected.
  */
object LabeledNode {
  import scala.language.implicitConversions

  /** Wraps a (label, node) pair as a labeled node. */
  implicit def tuple2LabeledNode[T](arg: (String, T)): LabeledNode[T] =
    LabeledNode(label = Some(arg._1), node = arg._2)

  /** Wraps every (label, node) pair of a sequence as a labeled node. */
  implicit def list2LabeledNode[T](args: Seq[(String, T)]): Seq[LabeledNode[T]] =
    args.map { pair => tuple2LabeledNode[T](pair) }

  /** Wraps a bare node as an unlabeled node. */
  implicit def node2LabeledNode[T](arg: T): LabeledNode[T] =
    LabeledNode(label = None, node = arg)
}
/** Print information recorded for a single tree node.
  *
  * @param printStr      display string of the node
  * @param nodeNum       number assigned to the node
  * @param isTail        whether the node is the last child of its parent
  * @param namedChildren the node's children, each with an optional label
  */
class PrintedNode[T](
  val printStr: String,
  val nodeNum: Int,
  val isTail: Boolean,
  val namedChildren: Seq[LabeledNode[T]]
)