All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.facebook.presto.verifier.rewrite.DefaultTreeRewriter Maven / Gradle / Ivy

There is a newer version: 0.290
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.verifier.rewrite;

import com.facebook.presto.sql.tree.AddColumn;
import com.facebook.presto.sql.tree.AliasedRelation;
import com.facebook.presto.sql.tree.Analyze;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.Call;
import com.facebook.presto.sql.tree.CallArgument;
import com.facebook.presto.sql.tree.ColumnDefinition;
import com.facebook.presto.sql.tree.CreateMaterializedView;
import com.facebook.presto.sql.tree.CreateSchema;
import com.facebook.presto.sql.tree.CreateTable;
import com.facebook.presto.sql.tree.CreateTableAsSelect;
import com.facebook.presto.sql.tree.CreateView;
import com.facebook.presto.sql.tree.Cube;
import com.facebook.presto.sql.tree.Deallocate;
import com.facebook.presto.sql.tree.Delete;
import com.facebook.presto.sql.tree.Except;
import com.facebook.presto.sql.tree.Execute;
import com.facebook.presto.sql.tree.Explain;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FrameBound;
import com.facebook.presto.sql.tree.GroupBy;
import com.facebook.presto.sql.tree.GroupingElement;
import com.facebook.presto.sql.tree.GroupingSets;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.Insert;
import com.facebook.presto.sql.tree.Intersect;
import com.facebook.presto.sql.tree.Join;
import com.facebook.presto.sql.tree.JoinCriteria;
import com.facebook.presto.sql.tree.JoinOn;
import com.facebook.presto.sql.tree.Lateral;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.OrderBy;
import com.facebook.presto.sql.tree.Prepare;
import com.facebook.presto.sql.tree.Property;
import com.facebook.presto.sql.tree.Query;
import com.facebook.presto.sql.tree.QueryBody;
import com.facebook.presto.sql.tree.QuerySpecification;
import com.facebook.presto.sql.tree.RefreshMaterializedView;
import com.facebook.presto.sql.tree.Relation;
import com.facebook.presto.sql.tree.Return;
import com.facebook.presto.sql.tree.Rollup;
import com.facebook.presto.sql.tree.Row;
import com.facebook.presto.sql.tree.SampledRelation;
import com.facebook.presto.sql.tree.Select;
import com.facebook.presto.sql.tree.SelectItem;
import com.facebook.presto.sql.tree.ShowStats;
import com.facebook.presto.sql.tree.SimpleGroupBy;
import com.facebook.presto.sql.tree.SingleColumn;
import com.facebook.presto.sql.tree.SortItem;
import com.facebook.presto.sql.tree.Statement;
import com.facebook.presto.sql.tree.Table;
import com.facebook.presto.sql.tree.TableElement;
import com.facebook.presto.sql.tree.TableSubquery;
import com.facebook.presto.sql.tree.Union;
import com.facebook.presto.sql.tree.Unnest;
import com.facebook.presto.sql.tree.Values;
import com.facebook.presto.sql.tree.Window;
import com.facebook.presto.sql.tree.WindowFrame;
import com.facebook.presto.sql.tree.With;
import com.facebook.presto.sql.tree.WithQuery;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;

import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * A default implementation of {@link AstVisitor} that reconstructs a node if any of its children is reconstructed.
 * Expression node reconstruction is not supported and left to the users. At the moment it is only used for presto.verifier package.
 * Generalize it with caution.
 *
 * <p>Every {@code visitXxx} method follows the same pattern: process the children, and if every processed
 * child is the very same instance (reference equality) as the original child, return the original node
 * untouched; otherwise build a new node from the processed children. Keeping the original instance when
 * nothing changed is what lets parents detect, by reference comparison, whether anything was rewritten below them.
 *
 * @param <C> the type of the context object passed through to every visit method
 */
public class DefaultTreeRewriter<C>
        extends AstVisitor<Node, C>
{
    /**
     * Fallback for nodes without a dedicated override: the node is returned as is.
     */
    @Override
    protected Node visitNode(Node node, C context)
    {
        return node;
    }

    /**
     * Expression rewriting is intentionally not implemented here; subclasses must override the
     * expression visit methods they need.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected Node visitExpression(Expression node, C context)
    {
        throw new UnsupportedOperationException("Not yet implemented: " + getClass().getSimpleName() + " for " + node.getClass().getName());
    }

    /** Rebuilds the {@link AddColumn} if its column definition was rewritten. */
    @Override
    protected Node visitAddColumn(AddColumn node, C context)
    {
        Node column = process(node.getColumn(), context);
        if (node.getColumn() == column) {
            return node;
        }

        return new AddColumn(node.getName(), (ColumnDefinition) column, node.isTableExists(), node.isColumnNotExists());
    }

    /** Rebuilds the {@link AliasedRelation} if its relation, alias or column names were rewritten. */
    @Override
    protected Node visitAliasedRelation(AliasedRelation node, C context)
    {
        Node relation = process(node.getRelation(), context);
        Node alias = process(node.getAlias(), context);
        // The list helper passes a null column-name list through unchanged.
        List<Identifier> columnNames = process(node.getColumnNames(), context);
        if (node.getRelation() == relation && node.getAlias() == alias && sameElements(node.getColumnNames(), columnNames)) {
            return node;
        }

        return new AliasedRelation((Relation) relation, (Identifier) alias, columnNames);
    }

    /** Rebuilds the {@link Analyze} statement if any of its properties was rewritten. */
    @Override
    protected Node visitAnalyze(Analyze node, C context)
    {
        List<Property> properties = process(node.getProperties(), context);
        if (sameElements(node.getProperties(), properties)) {
            return node;
        }

        return new Analyze(node.getTableName(), properties);
    }

    /** Rebuilds the {@link Call} if any of its arguments was rewritten. */
    @Override
    protected Node visitCall(Call node, C context)
    {
        List<CallArgument> arguments = process(node.getArguments(), context);
        if (sameElements(node.getArguments(), arguments)) {
            return node;
        }

        return new Call(node.getName(), arguments);
    }

    /** Rebuilds the {@link CallArgument} if its value was rewritten, preserving the optional argument name. */
    @Override
    protected Node visitCallArgument(CallArgument node, C context)
    {
        Node value = process(node.getValue(), context);
        if (node.getValue() == value) {
            return node;
        }

        return node.getName().isPresent() ? new CallArgument(node.getName().get(), (Expression) value) : new CallArgument((Expression) value);
    }

    /** Rebuilds the {@link ColumnDefinition} if its name or any of its properties was rewritten. */
    @Override
    protected Node visitColumnDefinition(ColumnDefinition node, C context)
    {
        Node name = process(node.getName(), context);
        List<Property> properties = process(node.getProperties(), context);
        if (node.getName() == name && sameElements(node.getProperties(), properties)) {
            return node;
        }

        return new ColumnDefinition((Identifier) name, node.getType(), node.isNullable(), properties, node.getComment());
    }

    /** Rebuilds the {@link CreateMaterializedView} if its query or any of its properties was rewritten. */
    @Override
    protected Node visitCreateMaterializedView(CreateMaterializedView node, C context)
    {
        Node query = process(node.getQuery(), context);
        List<Property> properties = process(node.getProperties(), context);
        if (node.getQuery() == query && sameElements(node.getProperties(), properties)) {
            return node;
        }

        return new CreateMaterializedView(node.getName(), (Query) query, node.isNotExists(), properties, node.getComment());
    }

    /** Rebuilds the {@link CreateSchema} if any of its properties was rewritten. */
    @Override
    protected Node visitCreateSchema(CreateSchema node, C context)
    {
        List<Property> properties = process(node.getProperties(), context);
        if (sameElements(node.getProperties(), properties)) {
            return node;
        }

        return new CreateSchema(node.getSchemaName(), node.isNotExists(), properties);
    }

    /** Rebuilds the {@link CreateTable} if any table element or property was rewritten. */
    @Override
    protected Node visitCreateTable(CreateTable node, C context)
    {
        List<TableElement> elements = process(node.getElements(), context);
        List<Property> properties = process(node.getProperties(), context);
        if (sameElements(node.getElements(), elements) && sameElements(node.getProperties(), properties)) {
            return node;
        }

        return new CreateTable(node.getName(), elements, node.isNotExists(), properties, node.getComment());
    }

    /** Rebuilds the {@link CreateTableAsSelect} if its query, properties or column aliases were rewritten. */
    @Override
    protected Node visitCreateTableAsSelect(CreateTableAsSelect node, C context)
    {
        Node query = process(node.getQuery(), context);
        List<Property> properties = process(node.getProperties(), context);
        Optional<List<Identifier>> columnAliases = node.getColumnAliases().map(aliases -> process(aliases, context));
        if (node.getQuery() == query && sameElements(node.getProperties(), properties) && sameElement(node.getColumnAliases(), columnAliases)) {
            return node;
        }

        // Pass the rewritten aliases; using node.getColumnAliases() here would silently drop the rewrite.
        return new CreateTableAsSelect(node.getName(), (Query) query, node.isNotExists(), properties, node.isWithData(), columnAliases, node.getComment());
    }

    /** Rebuilds the {@link CreateView} if its query was rewritten. */
    @Override
    protected Node visitCreateView(CreateView node, C context)
    {
        Node query = process(node.getQuery(), context);
        if (node.getQuery() == query) {
            return node;
        }

        return new CreateView(node.getName(), (Query) query, node.isReplace(), node.getSecurity());
    }

    /** Rebuilds the {@link Cube} if any of its expressions was rewritten. */
    @Override
    protected Node visitCube(Cube node, C context)
    {
        List<Expression> expressions = process(node.getExpressions(), context);
        if (sameElements(node.getExpressions(), expressions)) {
            return node;
        }

        return new Cube(expressions);
    }

    /** Rebuilds the {@link Deallocate} if its name was rewritten. */
    @Override
    protected Node visitDeallocate(Deallocate node, C context)
    {
        Node name = process(node.getName(), context);
        if (node.getName() == name) {
            return node;
        }

        return new Deallocate((Identifier) name);
    }

    /** Rebuilds the {@link Delete} if its table or WHERE clause was rewritten. */
    @Override
    protected Node visitDelete(Delete node, C context)
    {
        Node table = process(node.getTable(), context);
        Optional<Expression> where = process(node.getWhere(), context);
        if (node.getTable() == table && sameElement(node.getWhere(), where)) {
            return node;
        }

        return new Delete((Table) table, where);
    }

    /** Rebuilds the {@link Except} if either side was rewritten. */
    @Override
    protected Node visitExcept(Except node, C context)
    {
        Node left = process(node.getLeft(), context);
        Node right = process(node.getRight(), context);
        if (node.getLeft() == left && node.getRight() == right) {
            return node;
        }

        return new Except((Relation) left, (Relation) right, node.isDistinct());
    }

    /** Rebuilds the {@link Execute} if its name or any of its parameters was rewritten. */
    @Override
    protected Node visitExecute(Execute node, C context)
    {
        Node name = process(node.getName(), context);
        List<Expression> parameters = process(node.getParameters(), context);
        if (node.getName() == name && sameElements(node.getParameters(), parameters)) {
            return node;
        }

        // Use the rewritten name; using node.getName() here would silently drop the rewrite.
        return new Execute((Identifier) name, parameters);
    }

    /** Rebuilds the {@link Explain} if the explained statement was rewritten. */
    @Override
    protected Node visitExplain(Explain node, C context)
    {
        Node statement = process(node.getStatement(), context);
        if (node.getStatement() == statement) {
            return node;
        }

        return new Explain((Statement) statement, node.isAnalyze(), node.isVerbose(), node.getOptions());
    }

    /** Rebuilds the {@link FrameBound} if its optional value expression was rewritten. */
    @Override
    protected Node visitFrameBound(FrameBound node, C context)
    {
        Optional<Expression> value = process(node.getValue(), context);
        if (sameElement(node.getValue(), value)) {
            return node;
        }

        return value.isPresent() ? new FrameBound(node.getType(), value.get()) : new FrameBound(node.getType());
    }

    /** Rebuilds the {@link GroupBy} if any grouping element was rewritten. */
    @Override
    protected Node visitGroupBy(GroupBy node, C context)
    {
        List<GroupingElement> groupingElements = process(node.getGroupingElements(), context);
        if (sameElements(node.getGroupingElements(), groupingElements)) {
            return node;
        }

        return new GroupBy(node.isDistinct(), groupingElements);
    }

    /** Rebuilds the {@link GroupingSets} if any expression in any set was rewritten. */
    @Override
    protected Node visitGroupingSets(GroupingSets node, C context)
    {
        // Each inner list comes back as the original instance when none of its expressions changed,
        // so the outer reference-wise comparison below is sufficient.
        List<List<Expression>> sets = node.getSets().stream()
                .map(expressionList -> process(expressionList, context))
                .collect(ImmutableList.toImmutableList());

        if (sameElements(node.getSets(), sets)) {
            return node;
        }

        return new GroupingSets(sets);
    }

    /** Rebuilds the {@link Insert} if its query or column list was rewritten. */
    @Override
    protected Node visitInsert(Insert node, C context)
    {
        Node query = process(node.getQuery(), context);
        Optional<List<Identifier>> columns = node.getColumns().map(columnList -> process(columnList, context));
        if (node.getQuery() == query && sameElement(node.getColumns(), columns)) {
            return node;
        }

        return new Insert(node.getTarget(), columns, (Query) query);
    }

    /** Rebuilds the {@link Intersect} if any of its relations was rewritten. */
    @Override
    protected Node visitIntersect(Intersect node, C context)
    {
        List<Relation> relations = process(node.getRelations(), context);
        if (sameElements(node.getRelations(), relations)) {
            return node;
        }

        return new Intersect(relations, node.isDistinct());
    }

    /**
     * Rebuilds the {@link Join} if either side, or the expression of a {@link JoinOn} criteria, was rewritten.
     * Criteria of any other kind are carried over unchanged.
     */
    @Override
    protected Node visitJoin(Join node, C context)
    {
        Node left = process(node.getLeft(), context);
        Node right = process(node.getRight(), context);
        Optional<JoinCriteria> joinCriteria = node.getCriteria()
                .map(criteria -> {
                    if (criteria instanceof JoinOn) {
                        Node expression = process(((JoinOn) criteria).getExpression(), context);
                        if (((JoinOn) criteria).getExpression() == expression) {
                            return criteria;
                        }
                        return new JoinOn((Expression) expression);
                    }
                    return criteria;
                });
        // Compare the contained criteria, not the Optional wrappers: Optional.map returns a new wrapper
        // whenever a value is present, so "==" on the Optionals would force a rebuild of every join
        // that has criteria even when nothing changed.
        if (node.getLeft() == left && node.getRight() == right && sameElement(node.getCriteria(), joinCriteria)) {
            return node;
        }

        return new Join(node.getType(), (Relation) left, (Relation) right, joinCriteria);
    }

    /** Rebuilds the {@link Lateral} if its query was rewritten. */
    @Override
    protected Node visitLateral(Lateral node, C context)
    {
        Node query = process(node.getQuery(), context);
        if (node.getQuery() == query) {
            return node;
        }

        return new Lateral((Query) query);
    }

    /** Rebuilds the {@link OrderBy} if any sort item was rewritten. */
    @Override
    protected Node visitOrderBy(OrderBy node, C context)
    {
        List<SortItem> sortItems = process(node.getSortItems(), context);
        if (sameElements(node.getSortItems(), sortItems)) {
            return node;
        }

        return new OrderBy(sortItems);
    }

    /** Rebuilds the {@link Prepare} if the prepared statement was rewritten. */
    @Override
    protected Node visitPrepare(Prepare node, C context)
    {
        Node statement = process(node.getStatement(), context);
        if (node.getStatement() == statement) {
            return node;
        }

        return new Prepare(node.getName(), (Statement) statement);
    }

    /** Rebuilds the {@link Property} if its name or value was rewritten. */
    @Override
    protected Node visitProperty(Property node, C context)
    {
        Node name = process(node.getName(), context);
        Node value = process(node.getValue(), context);
        if (node.getName() == name && node.getValue() == value) {
            return node;
        }

        return new Property((Identifier) name, (Expression) value);
    }

    /** Rebuilds the {@link Query} if its WITH clause, body or ORDER BY was rewritten; offset and limit are carried over. */
    @Override
    protected Node visitQuery(Query node, C context)
    {
        Optional<With> with = process(node.getWith(), context);
        Node queryBody = process(node.getQueryBody(), context);
        Optional<OrderBy> orderBy = process(node.getOrderBy(), context);
        if (node.getQueryBody() == queryBody && sameElement(node.getWith(), with) && sameElement(node.getOrderBy(), orderBy)) {
            return node;
        }

        return new Query(with, (QueryBody) queryBody, orderBy, node.getOffset(), node.getLimit());
    }

    /**
     * Rebuilds the {@link QuerySpecification} if its SELECT, FROM, WHERE, GROUP BY, HAVING or ORDER BY
     * was rewritten; offset and limit are carried over.
     */
    @Override
    protected Node visitQuerySpecification(QuerySpecification node, C context)
    {
        Node select = process(node.getSelect(), context);
        Optional<Relation> from = process(node.getFrom(), context);
        Optional<Expression> where = process(node.getWhere(), context);
        Optional<GroupBy> groupBy = process(node.getGroupBy(), context);
        Optional<Expression> having = process(node.getHaving(), context);
        Optional<OrderBy> orderBy = process(node.getOrderBy(), context);
        if (node.getSelect() == select
                && sameElement(node.getFrom(), from)
                && sameElement(node.getWhere(), where)
                && sameElement(node.getGroupBy(), groupBy)
                && sameElement(node.getHaving(), having)
                && sameElement(node.getOrderBy(), orderBy)) {
            return node;
        }

        return new QuerySpecification(
                (Select) select,
                from,
                where,
                groupBy,
                having,
                orderBy,
                node.getOffset(),
                node.getLimit());
    }

    /** Rebuilds the {@link RefreshMaterializedView} if its target or WHERE expression was rewritten. */
    @Override
    protected Node visitRefreshMaterializedView(RefreshMaterializedView node, C context)
    {
        Node table = process(node.getTarget(), context);
        Node where = process(node.getWhere(), context);
        if (node.getTarget() == table && node.getWhere() == where) {
            return node;
        }

        return new RefreshMaterializedView((Table) table, (Expression) where);
    }

    /** Rebuilds the {@link Return} if its expression was rewritten. */
    @Override
    protected Node visitReturn(Return node, C context)
    {
        Node expression = process(node.getExpression(), context);
        if (node.getExpression() == expression) {
            return node;
        }

        return new Return((Expression) expression);
    }

    /** Rebuilds the {@link Rollup} if any of its expressions was rewritten. */
    @Override
    protected Node visitRollup(Rollup node, C context)
    {
        List<Expression> expressions = process(node.getExpressions(), context);
        if (sameElements(node.getExpressions(), expressions)) {
            return node;
        }

        return new Rollup(expressions);
    }

    /** Rebuilds the {@link Row} if any of its items was rewritten. */
    @Override
    protected Node visitRow(Row node, C context)
    {
        List<Expression> items = process(node.getItems(), context);
        if (sameElements(node.getItems(), items)) {
            return node;
        }

        return new Row(items);
    }

    /** Rebuilds the {@link SampledRelation} if its relation or sample percentage was rewritten. */
    @Override
    protected Node visitSampledRelation(SampledRelation node, C context)
    {
        Node relation = process(node.getRelation(), context);
        Node samplePercentage = process(node.getSamplePercentage(), context);
        if (node.getRelation() == relation && node.getSamplePercentage() == samplePercentage) {
            return node;
        }

        return new SampledRelation((Relation) relation, node.getType(), (Expression) samplePercentage);
    }

    /** Rebuilds the {@link Select} if any select item was rewritten. */
    @Override
    protected Node visitSelect(Select node, C context)
    {
        List<SelectItem> selectItems = process(node.getSelectItems(), context);
        if (sameElements(node.getSelectItems(), selectItems)) {
            return node;
        }

        return new Select(node.isDistinct(), selectItems);
    }

    /** Rebuilds the {@link ShowStats} if its relation was rewritten. */
    @Override
    protected Node visitShowStats(ShowStats node, C context)
    {
        Node relation = process(node.getRelation(), context);
        if (node.getRelation() == relation) {
            return node;
        }

        return new ShowStats((Relation) relation);
    }

    /** Rebuilds the {@link SimpleGroupBy} if any of its expressions was rewritten. */
    @Override
    protected Node visitSimpleGroupBy(SimpleGroupBy node, C context)
    {
        List<Expression> columns = process(node.getExpressions(), context);
        if (sameElements(node.getExpressions(), columns)) {
            return node;
        }

        return new SimpleGroupBy(columns);
    }

    /** Rebuilds the {@link SingleColumn} if its expression was rewritten; the alias is carried over. */
    @Override
    protected Node visitSingleColumn(SingleColumn node, C context)
    {
        Node expression = process(node.getExpression(), context);
        if (node.getExpression() == expression) {
            return node;
        }

        return new SingleColumn((Expression) expression, node.getAlias());
    }

    /** Rebuilds the {@link SortItem} if its sort key was rewritten. */
    @Override
    protected Node visitSortItem(SortItem node, C context)
    {
        Node sortKey = process(node.getSortKey(), context);
        if (node.getSortKey() == sortKey) {
            return node;
        }

        return new SortItem((Expression) sortKey, node.getOrdering(), node.getNullOrdering());
    }

    /** Rebuilds the {@link TableSubquery} if its query was rewritten. */
    @Override
    protected Node visitTableSubquery(TableSubquery node, C context)
    {
        Node query = process(node.getQuery(), context);
        if (node.getQuery() == query) {
            return node;
        }

        return new TableSubquery((Query) query);
    }

    /** Rebuilds the {@link Union} if any of its relations was rewritten. */
    @Override
    protected Node visitUnion(Union node, C context)
    {
        List<Relation> relations = process(node.getRelations(), context);
        if (sameElements(node.getRelations(), relations)) {
            return node;
        }

        return new Union(relations, node.isDistinct());
    }

    /** Rebuilds the {@link Unnest} if any of its expressions was rewritten. */
    @Override
    protected Node visitUnnest(Unnest node, C context)
    {
        List<Expression> expressions = process(node.getExpressions(), context);
        if (sameElements(node.getExpressions(), expressions)) {
            return node;
        }

        return new Unnest(expressions, node.isWithOrdinality());
    }

    /** Rebuilds the {@link Values} if any of its rows was rewritten. */
    @Override
    protected Node visitValues(Values node, C context)
    {
        List<Expression> rows = process(node.getRows(), context);
        if (sameElements(node.getRows(), rows)) {
            return node;
        }

        return new Values(rows);
    }

    /** Rebuilds the {@link Window} if its partitioning, ordering or frame was rewritten. */
    @Override
    protected Node visitWindow(Window node, C context)
    {
        List<Expression> partitionBy = process(node.getPartitionBy(), context);
        Optional<OrderBy> orderBy = process(node.getOrderBy(), context);
        Optional<WindowFrame> frame = process(node.getFrame(), context);
        if (sameElements(node.getPartitionBy(), partitionBy) && sameElement(node.getOrderBy(), orderBy) && sameElement(node.getFrame(), frame)) {
            return node;
        }

        return new Window(partitionBy, orderBy, frame);
    }

    /** Rebuilds the {@link WindowFrame} if its start or optional end bound was rewritten. */
    @Override
    protected Node visitWindowFrame(WindowFrame node, C context)
    {
        Node start = process(node.getStart(), context);
        Optional<FrameBound> end = process(node.getEnd(), context);
        if (node.getStart() == start && sameElement(node.getEnd(), end)) {
            return node;
        }

        return new WindowFrame(node.getType(), (FrameBound) start, end);
    }

    /** Rebuilds the {@link With} if any of its queries was rewritten. */
    @Override
    protected Node visitWith(With node, C context)
    {
        List<WithQuery> queries = process(node.getQueries(), context);
        if (sameElements(node.getQueries(), queries)) {
            return node;
        }

        return new With(node.isRecursive(), queries);
    }

    /** Rebuilds the {@link WithQuery} if its name, query or column names were rewritten. */
    @Override
    protected Node visitWithQuery(WithQuery node, C context)
    {
        Node name = process(node.getName(), context);
        Node query = process(node.getQuery(), context);
        Optional<List<Identifier>> columnNames = node.getColumnNames().map(columnNamesList -> process(columnNamesList, context));
        if (node.getName() == name && node.getQuery() == query && sameElement(node.getColumnNames(), columnNames)) {
            return node;
        }

        // Use the rewritten name and column names; the originals would silently drop those rewrites.
        return new WithQuery((Identifier) name, (Query) query, columnNames);
    }

    /**
     * Processes every element of {@code elements} with this visitor.
     *
     * @param elements the nodes to process; may be {@code null}
     * @param context the context passed on to each visit
     * @return {@code null} if {@code elements} is {@code null}; the original list instance if every processed
     * element is the same instance as its input; otherwise a new immutable list of the processed elements
     */
    @SuppressWarnings("unchecked") // a rewrite is expected to yield a node usable in place of the original
    private <T extends Node> List<T> process(List<T> elements, C context)
    {
        if (elements == null) {
            return null;
        }
        List<T> result = elements.stream()
                .map(element -> (T) process(element, context))
                .collect(ImmutableList.toImmutableList());
        // Hand back the original instance when nothing changed so callers can compare by reference.
        return sameElements(elements, result) ? elements : result;
    }

    /**
     * Processes the node held by {@code element}, if any, with this visitor.
     *
     * @param element the optional node to process; may be {@code null}
     * @param context the context passed on to the visit
     * @return {@code null} if {@code element} is {@code null}; the original {@code Optional} if it is empty or
     * the processed node is the same instance as its input; otherwise an {@code Optional} of the processed node
     */
    @SuppressWarnings("unchecked") // a rewrite is expected to yield a node usable in place of the original
    private <T extends Node> Optional<T> process(Optional<T> element, C context)
    {
        if (element == null) {
            return null;
        }
        Optional<T> result = element.map(e -> (T) process(e, context));
        return sameElement(element, result) ? element : result;
    }

    /**
     * Returns whether two optionals are both {@code null}, both empty, or both hold the very same
     * instance (reference equality, not {@code equals}).
     */
    @SuppressWarnings("ObjectEquality")
    private static <T> boolean sameElement(Optional<T> a, Optional<T> b)
    {
        if (a == null && b == null) {
            return true;
        }
        if (a == null || b == null) {
            return false;
        }
        if (a.isPresent() != b.isPresent()) {
            return false;
        }

        return !a.isPresent() || a.get() == b.get();
    }

    /**
     * Returns whether two iterables are both {@code null}, or have the same size and hold the very same
     * instances (reference equality, not {@code equals}) in the same order.
     */
    @SuppressWarnings("ObjectEquality")
    private static <T> boolean sameElements(Iterable<? extends T> a, Iterable<? extends T> b)
    {
        if (a == null && b == null) {
            return true;
        }
        if (a == null || b == null) {
            return false;
        }

        if (Iterables.size(a) != Iterables.size(b)) {
            return false;
        }

        Iterator<? extends T> first = a.iterator();
        Iterator<? extends T> second = b.iterator();

        while (first.hasNext() && second.hasNext()) {
            if (first.next() != second.next()) {
                return false;
            }
        }

        return true;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy