All downloads are free. Search and download functionality uses the official Maven repository.

com.facebook.presto.verifier.rewrite.FunctionCallRewriter Maven / Gradle / Ivy

There is a newer version: 0.290
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.verifier.rewrite;

import com.facebook.presto.sql.ExpressionFormatter;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.Query;
import com.facebook.presto.sql.tree.SubqueryExpression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

public class FunctionCallRewriter
        extends DefaultTreeRewriter
{
    public static class FunctionCallSubstitute
    {
        private final QualifiedName name;
        private final List originalArgumentIndices;

        public FunctionCallSubstitute(QualifiedName name, List originalArgumentIndices)
        {
            this.name = requireNonNull(name, "name is null");
            this.originalArgumentIndices = ImmutableList.copyOf(requireNonNull(originalArgumentIndices, "originalArgumentIndices is null"));
        }

        public QualifiedName name()
        {
            return name;
        }

        public List originalArgumentIndices()
        {
            return originalArgumentIndices;
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            FunctionCallSubstitute that = (FunctionCallSubstitute) o;
            return name.equals(that.name) && Objects.equals(originalArgumentIndices, that.originalArgumentIndices);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(name, originalArgumentIndices);
        }
    }

    // Pattern of a function call with function name and a comma-separated argument list, ex., func_x(c_0,c_1), func_x(_,c_1) or func_x(,c_1).
    private static final Pattern FUNCTION_CALL_PATTERN = Pattern.compile("(\\w+)\\(([\\w|,]+)\\)");
    // Pattern to specify function call substitution, with the first element as the original and the second element as the substitute.
    private static final Pattern FUNCTION_CALL_SUBSTITUTION_PATTERN = Pattern.compile(String.format("/%s/%s/", FUNCTION_CALL_PATTERN, FUNCTION_CALL_PATTERN));

    private final Map functionCallSubstituteMap;

    private List> processedFunctionCallSubstitutes = ImmutableList.of();

    private FunctionCallRewriter(Map functionCallSubstituteMap)
    {
        this.functionCallSubstituteMap = requireNonNull(functionCallSubstituteMap, "functionCallSubstituteMap is null.");
    }

    public static FunctionCallRewriter getInstance(String functionCallSubstitutes)
    {
        Map functionCallSubstituteMap = constructFunctionCallSubstituteMap(functionCallSubstitutes);
        return functionCallSubstituteMap.isEmpty() ? null : new FunctionCallRewriter(functionCallSubstituteMap);
    }

    public static boolean validateFunctionCallSubstitutes(String functionCallSubstitutes)
    {
        if (functionCallSubstitutes == null) {
            return false;
        }

        Matcher matcher = FUNCTION_CALL_SUBSTITUTION_PATTERN.matcher(functionCallSubstitutes);
        return matcher.find();
    }

    public static Map constructFunctionCallSubstituteMap(String functionCallSubstitutes)
    {
        ImmutableMap.Builder map = ImmutableMap.builder();
        if (functionCallSubstitutes == null) {
            return map.build();
        }

        Matcher matcher = FUNCTION_CALL_SUBSTITUTION_PATTERN.matcher(functionCallSubstitutes);

        while (matcher.find()) {
            String originalName = matcher.group(1);
            List originalArgumentList = ImmutableList.copyOf(matcher.group(2).split(","));

            String substituteName = matcher.group(3);
            List substituteArgumentList = ImmutableList.copyOf(matcher.group(4).split(","));
            List originalArgumentIndices = substituteArgumentList.stream().map(originalArgumentList::indexOf).collect(Collectors.toList());
            FunctionCallSubstitute substitute = new FunctionCallSubstitute(QualifiedName.of(substituteName), originalArgumentIndices);

            map.put(QualifiedName.of(originalName), substitute);
        }

        return map.build();
    }

    public Node rewrite(Node root)
    {
        Context context = new Context();
        Node rewritten = process(root, context);
        processedFunctionCallSubstitutes = context.getFunctionCallSubstitutes();
        return rewritten;
    }

    public String getProcessedFunctionCallSubstitutes()
    {
        return processedFunctionCallSubstitutes.stream().map(functionCallSubstitute -> {
                    String formattedOriginal = ExpressionFormatter.formatExpression(functionCallSubstitute.get(0), Optional.empty());
                    String formattedSubstitute = ExpressionFormatter.formatExpression(functionCallSubstitute.get(1), Optional.empty());
                    return String.format("% is substituted with %s", formattedOriginal, formattedSubstitute);
                }
        ).collect(Collectors.joining(", "));
    }

    @Override
    protected Node visitExpression(Expression node, FunctionCallRewriter.Context context)
    {
        FunctionCallRewriter nodeRewritter = this;

        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter()
        {
            @Override
            public Expression rewriteFunctionCall(FunctionCall expression, Void voidContext, ExpressionTreeRewriter treeRewriter)
            {
                if (!functionCallSubstituteMap.containsKey(expression.getName())) {
                    return expression;
                }

                FunctionCall defaultRewrite = treeRewriter.defaultRewrite(expression, voidContext);

                FunctionCallSubstitute substitute = functionCallSubstituteMap.get(expression.getName());
                List originalArguments = expression.getArguments();
                List rewrittenArguments = substitute.originalArgumentIndices.stream()
                        .map(originalIndex -> {
                            Expression originalArgument = originalArguments.get(originalIndex);
                            return treeRewriter.rewrite(originalArgument, voidContext);
                        }).collect(toImmutableList());

                FunctionCall rewritten = new FunctionCall(substitute.name, defaultRewrite.getWindow(), defaultRewrite.getFilter(), defaultRewrite.getOrderBy(),
                        defaultRewrite.isDistinct(), defaultRewrite.isIgnoreNulls(), rewrittenArguments);
                context.addFunctionCallSubstitute(expression, rewritten);

                return rewritten;
            }

            public Expression rewriteSubqueryExpression(SubqueryExpression expression, Void voidContext, ExpressionTreeRewriter treeRewriter)
            {
                Node query = nodeRewritter.process(expression.getQuery(), context);
                if (expression.getQuery() == query) {
                    return expression;
                }

                return new SubqueryExpression((Query) query);
            }
        }, node);
    }

    public static class Context
    {
        private List> functionCallSubstitutes = new ArrayList<>();

        public void addFunctionCallSubstitute(FunctionCall original, FunctionCall substitute)
        {
            functionCallSubstitutes.add(Arrays.asList(original, substitute));
        }

        public List> getFunctionCallSubstitutes()
        {
            return ImmutableList.copyOf(functionCallSubstitutes);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy