All downloads are free. Search and download functionality uses the official Maven repository.

com.vmware.dcm.compiler.TranslateViewToIR Maven / Gradle / Ivy

Go to download

Library for building declarative cluster managers. Please refer to the README at github.com/vmware/declarative-cluster-management/ for instructions on setting up solvers before use.

There is a newer version: 0.15.0
Show newest version
/*
 * Copyright 2018-2020 VMware, Inc. All Rights Reserved.
 * SPDX-License-Identifier: BSD-2
 */

package com.vmware.dcm.compiler;

import com.facebook.presto.sql.tree.AllColumns;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.ExistsPredicate;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.GroupBy;
import com.facebook.presto.sql.tree.GroupingElement;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.Query;
import com.facebook.presto.sql.tree.QuerySpecification;
import com.facebook.presto.sql.tree.SelectItem;
import com.facebook.presto.sql.tree.SingleColumn;
import com.facebook.presto.sql.tree.StringLiteral;
import com.facebook.presto.sql.tree.SubqueryExpression;
import com.google.common.base.Splitter;
import com.vmware.dcm.compiler.ir.BinaryOperatorPredicate;
import com.vmware.dcm.compiler.ir.BinaryOperatorPredicateWithAggregate;
import com.vmware.dcm.compiler.ir.CheckQualifier;
import com.vmware.dcm.compiler.ir.ColumnIdentifier;
import com.vmware.dcm.compiler.ir.Expr;
import com.vmware.dcm.compiler.ir.FunctionCall;
import com.vmware.dcm.compiler.ir.GroupByComprehension;
import com.vmware.dcm.compiler.ir.GroupByQualifier;
import com.vmware.dcm.compiler.ir.Head;
import com.vmware.dcm.compiler.ir.JoinPredicate;
import com.vmware.dcm.compiler.ir.ListComprehension;
import com.vmware.dcm.compiler.ir.Literal;
import com.vmware.dcm.compiler.ir.Qualifier;
import com.vmware.dcm.compiler.ir.TableRowGenerator;
import com.vmware.dcm.compiler.ir.UnaryOperator;

import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

public class TranslateViewToIR extends DefaultTraversalVisitor, Void> {
    private static final Map ARITHMETIC_OP_TABLE =
            new EnumMap<>(ArithmeticBinaryExpression.Operator.class);
    private static final Map COMPARISON_OP_TABLE =
            new EnumMap<>(ComparisonExpression.Operator.class);
    private static final Map LOGICAL_OP_TABLE =
            new EnumMap<>(LogicalBinaryExpression.Operator.class);

    static {
        ARITHMETIC_OP_TABLE.put(ArithmeticBinaryExpression.Operator.ADD,
                BinaryOperatorPredicate.Operator.ADD);
        ARITHMETIC_OP_TABLE.put(ArithmeticBinaryExpression.Operator.SUBTRACT,
                BinaryOperatorPredicate.Operator.SUBTRACT);
        ARITHMETIC_OP_TABLE.put(ArithmeticBinaryExpression.Operator.MULTIPLY,
                BinaryOperatorPredicate.Operator.MULTIPLY);
        ARITHMETIC_OP_TABLE.put(ArithmeticBinaryExpression.Operator.DIVIDE,
                BinaryOperatorPredicate.Operator.DIVIDE);
        ARITHMETIC_OP_TABLE.put(ArithmeticBinaryExpression.Operator.MODULUS,
                BinaryOperatorPredicate.Operator.MODULUS);

        COMPARISON_OP_TABLE.put(ComparisonExpression.Operator.EQUAL,
                BinaryOperatorPredicate.Operator.EQUAL);
        COMPARISON_OP_TABLE.put(ComparisonExpression.Operator.LESS_THAN,
                BinaryOperatorPredicate.Operator.LESS_THAN);
        COMPARISON_OP_TABLE.put(ComparisonExpression.Operator.GREATER_THAN,
                BinaryOperatorPredicate.Operator.GREATER_THAN);
        COMPARISON_OP_TABLE.put(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL,
                BinaryOperatorPredicate.Operator.LESS_THAN_OR_EQUAL);
        COMPARISON_OP_TABLE.put(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL,
                BinaryOperatorPredicate.Operator.GREATER_THAN_OR_EQUAL);
        COMPARISON_OP_TABLE.put(ComparisonExpression.Operator.NOT_EQUAL,
                BinaryOperatorPredicate.Operator.NOT_EQUAL);

        LOGICAL_OP_TABLE.put(LogicalBinaryExpression.Operator.AND, BinaryOperatorPredicate.Operator.AND);
        LOGICAL_OP_TABLE.put(LogicalBinaryExpression.Operator.OR, BinaryOperatorPredicate.Operator.OR);
    }

    private final IRContext irContext;
    private final Set tablesReferencedInView;
    private final boolean isAggregate;

    TranslateViewToIR(final IRContext irContext, final Set tablesReferencedInView, final boolean isAggregate) {
        this.irContext = irContext;
        this.tablesReferencedInView = tablesReferencedInView;
        this.isAggregate = isAggregate;
    }

    private Expr translateExpression(final Expression expression) {
        return translateExpression(expression, irContext, tablesReferencedInView, isAggregate);
    }

    @Override
    protected Optional visitLogicalBinaryExpression(final LogicalBinaryExpression node, final Void context) {
        final Expr left = translateExpression(node.getLeft());
        final Expr right = translateExpression(node.getRight());
        final BinaryOperatorPredicate.Operator operator = operatorTranslator(node.getOperator());
        return Optional.of(createOperatorPredicate(operator, left, right, isAggregate));
    }

    @Override
    protected Optional visitComparisonExpression(final ComparisonExpression node, final Void context) {
        final Expr left = translateExpression(node.getLeft());
        final Expr right = translateExpression(node.getRight());
        final BinaryOperatorPredicate.Operator operator = operatorTranslator(node.getOperator());
        return Optional.of(createOperatorPredicate(operator, left, right, isAggregate));
    }

    @Override
    protected Optional visitArithmeticBinary(final ArithmeticBinaryExpression node, final Void context) {
        final Expr left = translateExpression(node.getLeft());
        final Expr right = translateExpression(node.getRight());
        final BinaryOperatorPredicate.Operator operator = operatorTranslator(node.getOperator());
        return Optional.of(createOperatorPredicate(operator, left, right, isAggregate));
    }

    @Override
    protected Optional visitArithmeticUnary(final ArithmeticUnaryExpression node, final Void context) {
        final Expr innerExpr = translateExpression(node.getValue());
        final ArithmeticUnaryExpression.Sign sign = node.getSign();
        final UnaryOperator.Operator signStr = sign.equals(ArithmeticUnaryExpression.Sign.MINUS) ?
                UnaryOperator.Operator.MINUS : UnaryOperator.Operator.PLUS;
        final UnaryOperator operatorPredicate = new UnaryOperator(signStr, innerExpr);
        return Optional.of(operatorPredicate);
    }

    @Override
    protected Optional visitExists(final ExistsPredicate node, final Void context) {
        final Expr innerExpr = translateExpression(node.getSubquery());
        final com.vmware.dcm.compiler.ir.ExistsPredicate operatorPredicate =
                new com.vmware.dcm.compiler.ir.ExistsPredicate(innerExpr);
        return Optional.of(operatorPredicate);
    }

    @Override
    protected Optional visitInPredicate(final InPredicate node, final Void context) {
        final Expr left = translateExpression(node.getValue());
        final Expr right = translateExpression(node.getValueList());
        final BinaryOperatorPredicate operatorPredicate =
                new BinaryOperatorPredicate(BinaryOperatorPredicate.Operator.IN, left, right);
        return Optional.of(operatorPredicate);
    }

    @Override
    protected Optional visitFunctionCall(final com.facebook.presto.sql.tree.FunctionCall node,
                                               final Void context) {
        if (node.getArguments().size() == 1
                || (node.getArguments().isEmpty() && "count".equalsIgnoreCase(node.getName().getSuffix()))
                || (node.getArguments().size() == 4 && node.getName().getSuffix().equals("capacity_constraint"))
                || (node.getArguments().size() == 2 && node.getName().getSuffix().equals("contains"))) {
            // Only having clauses will have function calls in the expression.
            final Expr function;
            final String functionNameStr = node.getName().toString().toUpperCase(Locale.US);
            final FunctionCall.Function functionType = FunctionCall.Function.valueOf(functionNameStr);
            if (node.getArguments().size() >= 1) {
                final List arguments = node.getArguments().stream()
                        .map(e -> translateExpression(e, irContext, tablesReferencedInView, isAggregate))
                        .collect(Collectors.toList());
                function = new FunctionCall(functionType, arguments);
            } else if (node.getArguments().isEmpty() &&
                    "count".equalsIgnoreCase(node.getName().getSuffix())) {
                // The presto parser does not consider count(*) as a function with a single
                // argument "*", but instead treats it as a function without any arguments.
                // The parser code therefore has this special case behavior when it
                // comes to the count function. See Presto's ExpressionFormatter.visitFunctionCall() for how
                // this is handled externally from the FunctionCall code.
                //
                // We therefore replace the argument for count with the first column of one of the tables.
                final IRTable table = tablesReferencedInView.iterator().next();
                final IRColumn field = table.getIRColumns().entrySet().iterator().next().getValue();
                final ColumnIdentifier column = new ColumnIdentifier(table.getName(), field, false);
                function = new FunctionCall(functionType, column);
            } else {
                throw new RuntimeException("I don't know what to do with this function call type: " + node);
            }
            return Optional.of(function);
        } else {
            throw new RuntimeException("I don't know what do with the following node: " + node);
        }
    }

    /**
     * Parse columns like 'reference.field'
     */
    @Override
    protected Optional visitDereferenceExpression(final DereferenceExpression node, final Void context) {
        final IRColumn irColumn = getIRColumnFromDereferencedExpression(node, irContext);
        final ColumnIdentifier columnIdentifier = new ColumnIdentifier(irColumn.getIRTable().getName(), irColumn,
                                                true);
        return Optional.of(columnIdentifier);
    }

    @Override
    protected Optional visitSubqueryExpression(final SubqueryExpression node, final Void context) {
        final Query subQuery = node.getQuery();
        return Optional.of(apply(subQuery, Optional.empty(), irContext));
    }

    @Override
    protected Optional visitLiteral(final com.facebook.presto.sql.tree.Literal node, final Void context) {
        return super.visitLiteral(node, context);
    }

    @Override
    protected Optional visitStringLiteral(final StringLiteral node, final Void context) {
        return Optional.of(new Literal<>("'" + node.getValue() + "'", String.class));
    }

    @Override
    protected Optional visitLongLiteral(final LongLiteral node, final Void context) {
        return Optional.of(new Literal<>(Long.valueOf(node.toString()), Long.class));
    }

    @Override
    protected Optional visitBooleanLiteral(final BooleanLiteral node, final Void context) {
        return Optional.of(new Literal<>(Boolean.valueOf(node.toString()), Boolean.class));
    }

    @Override
    protected Optional visitIdentifier(final Identifier node, final Void context) {
        final IRColumn irColumn = irContext.getColumnIfUnique(node.toString(), tablesReferencedInView);
        assert tablesReferencedInView.stream().map(IRTable::getName)
                .collect(Collectors.toSet())
                .contains(irColumn.getIRTable().getName());
        final ColumnIdentifier identifier = new ColumnIdentifier(irColumn.getIRTable().getName(), irColumn, false);
        return Optional.of(identifier);
    }

    @Override
    protected Optional visitNotExpression(final NotExpression node, final Void context) {
        final Expr innerExpr = translateExpression(node.getValue());
        final UnaryOperator operatorPredicate = new UnaryOperator(UnaryOperator.Operator.NOT, innerExpr);
        return Optional.of(operatorPredicate);
    }

    @Override
    protected Optional visitIsNullPredicate(final IsNullPredicate node, final Void context) {
        final Expr innerExpr = translateExpression(node.getValue());
        final com.vmware.dcm.compiler.ir.IsNullPredicate isNullPredicate =
                new com.vmware.dcm.compiler.ir.IsNullPredicate(innerExpr);
        return Optional.of(isNullPredicate);
    }

    @Override
    protected Optional visitIsNotNullPredicate(final IsNotNullPredicate node, final Void context) {
        final Expr innerExpr = translateExpression(node.getValue());
        final com.vmware.dcm.compiler.ir.IsNotNullPredicate isNotNullPredicate =
                new com.vmware.dcm.compiler.ir.IsNotNullPredicate(innerExpr);
        return Optional.of(isNotNullPredicate);
    }


    /**
     * Converts an SQL view into an IR list comprehension.
     *
     * @param view the AST corresponding to an SQL view statement
     * @param irContext an IRContext instance
     * @return A list comprehension corresponding to the view parameter
     */
    static ListComprehension apply(final Query view, final Optional check, final IRContext irContext) {
        final FromExtractor fromParser = new FromExtractor(irContext);
        fromParser.process(view.getQueryBody());

        final Set tables = fromParser.getTables();
        final Optional where = ((QuerySpecification) view.getQueryBody()).getWhere();
        final List joinConditions = fromParser.getJoinConditions();
        final Optional having = ((QuerySpecification) view.getQueryBody()).getHaving();
        final Optional groupBy = ((QuerySpecification) view.getQueryBody()).getGroupBy();

        // Construct Monoid Comprehension
        final List selectItems = ((QuerySpecification) view.getQueryBody()).getSelect().getSelectItems();
        final List selectItemExpr = translateSelectItems(selectItems, irContext, tables);
        final Head head = new Head(selectItemExpr);

        final List qualifiers = new ArrayList<>();
        tables.forEach(t -> qualifiers.add(new TableRowGenerator(t)));
        where.ifPresent(e -> qualifiers.add((Qualifier) translateExpression(e, irContext, tables, false)));
        having.ifPresent(e -> qualifiers.add((Qualifier) translateExpression(e, irContext, tables, true)));
        final UsesAggregateFunctions usesAggregateFunctions = new UsesAggregateFunctions();
        check.ifPresent(e -> {
            usesAggregateFunctions.process(e);
            final Expr expr = translateExpression(e, irContext, tables, usesAggregateFunctions.isFound()
                                                      || groupBy.isPresent() || having.isPresent());
            qualifiers.add(new CheckQualifier(expr));
        });

        joinConditions.forEach(e -> {
            final Qualifier joinQualifier = (Qualifier) translateExpression(e, irContext, tables, false);
            assert joinQualifier instanceof BinaryOperatorPredicate;
            qualifiers.add(new JoinPredicate((BinaryOperatorPredicate) joinQualifier));
        });

        if (groupBy.isPresent()) {
            final List groupingElement = groupBy.get().getGroupingElements();
            final List columnIdentifiers = columnListFromGroupBy(groupingElement, irContext, tables);
            final GroupByQualifier groupByQualifier = new GroupByQualifier(columnIdentifiers);
            final ListComprehension comprehension = new ListComprehension(head, qualifiers);
            return new GroupByComprehension(comprehension, groupByQualifier);
        } else if (usesAggregateFunctions.isFound() || having.isPresent()) { // group by 1
            final GroupByQualifier groupByQualifier = new GroupByQualifier(
                    List.of(new Literal<>(1, Integer.class)));
            final ListComprehension comprehension = new ListComprehension(head, qualifiers);
            return new GroupByComprehension(comprehension, groupByQualifier);
        }
        return new ListComprehension(head, qualifiers);
    }

    /**
     * Translates a list of SQL SelectItems into a corresponding list of IR Expr
     */
    private static List translateSelectItems(final List selectItems,
                                                   final IRContext irContext,
                                                   final Set tables) {
        final List exprs = new ArrayList<>();
        for (final SelectItem selectItem: selectItems) {
            if (selectItem instanceof SingleColumn) {
                final SingleColumn singleColumn = (SingleColumn) selectItem;
                final Expression expression = singleColumn.getExpression();
                final Expr expr = translateExpression(expression, irContext, tables, false);
                singleColumn.getAlias().ifPresent(v -> expr.setAlias(v.toString()));
                exprs.add(expr);
            } else if (selectItem instanceof AllColumns) {
                tables.forEach(
                        table ->
                            table.getIRColumns().forEach((fieldName, irColumn) ->
                                exprs.add(new ColumnIdentifier(table.getName(), irColumn, false)))
                );
            }
        }
        return exprs;
    }

    /**
     * Translates a list of SQL GroupingElements into a corresponding list of ColumnIdentifiers. This method does
     * not work for GroupBy expressions that are not over columns or constant expressions.
     */
    private static List columnListFromGroupBy(final List groupingElements,
                                                                final IRContext irContext, final Set tables) {
        return groupingElements.stream()
                .map(GroupingElement::getExpressions) // We only support SimpleGroupBy
                .flatMap(Collection::stream)
                .map(expr -> translateExpression(expr, irContext, tables, false))
                .collect(Collectors.toList());
    }

    /**
     * Translates an SQL AST expression into an IR Expr type.
     */
    private static Expr translateExpression(final Expression expression, final IRContext irContext,
                                           final Set tablesReferencedInView, final boolean isAggregate) {
        final TranslateViewToIR traverser = new TranslateViewToIR(irContext, tablesReferencedInView, isAggregate);
        return traverser.process(expression).orElseThrow();
    }

    private static BinaryOperatorPredicate createOperatorPredicate(final BinaryOperatorPredicate.Operator operator,
                                                                   final Expr left, final Expr right,
                                                                   final boolean isAggregate) {
        return isAggregate
                ? new BinaryOperatorPredicateWithAggregate(operator, left, right)
                : new BinaryOperatorPredicate(operator, left, right);
    }

    /**
     * Retrieve an IRColumn from a given IRContext that corresponds to a DerefenceExpression node from the SQL AST
     */
    static IRColumn getIRColumnFromDereferencedExpression(final DereferenceExpression node, final IRContext irContext) {
        final List identifier = Splitter.on(".")
                .trimResults()
                .omitEmptyStrings()
                .splitToList(node.toString());

        // Only supports dereference expressions that have exactly 1 dot.
        // At the moment we don't support, e.g. schema.reference.field - that is we only support queries
        // within the same schema
        if (identifier.size() != 2) {
            throw new UnsupportedOperationException("Dereference fields can only be of the format `table.field`");
        }
        final String tableName = identifier.get(0);
        final String fieldName = identifier.get(1);
        return irContext.getColumn(tableName, fieldName);
    }

    private static BinaryOperatorPredicate.Operator operatorTranslator(final ArithmeticBinaryExpression.Operator op) {
        return ARITHMETIC_OP_TABLE.get(op);
    }

    private static BinaryOperatorPredicate.Operator operatorTranslator(final LogicalBinaryExpression.Operator op) {
        return LOGICAL_OP_TABLE.get(op);
    }

    private static BinaryOperatorPredicate.Operator operatorTranslator(final ComparisonExpression.Operator op) {
        return COMPARISON_OP_TABLE.get(op);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy