All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.sql.planner.iterative.rule.LambdaCaptureDesugaringRewriter Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.sql.ir.Bind;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;

import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols;
import static java.util.Objects.requireNonNull;

/**
 * Desugars lambda captures. Every {@link Lambda} whose body references symbols that are not among
 * its own arguments is rewritten so that each such captured symbol becomes an extra leading lambda
 * argument, and the captured values are supplied through a {@link Bind} expression.
 * <p>
 * Example: {@code x -> f(x, captureSymbol)} becomes
 * {@code Bind(captureSymbol, (extraSymbol, x) -> f(x, extraSymbol))}, where {@code extraSymbol}
 * is a fresh symbol obtained from the supplied {@link SymbolAllocator}.
 */
public final class LambdaCaptureDesugaringRewriter
{
    /**
     * Rewrites every lambda inside {@code expression} so that none of them captures free symbols.
     *
     * @param expression the expression to rewrite
     * @param symbolAllocator allocator used to create the new lambda arguments that stand in for captured symbols
     * @return the rewritten expression
     */
    public static Expression rewrite(Expression expression, SymbolAllocator symbolAllocator)
    {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(symbolAllocator), expression, new Context());
    }

    private LambdaCaptureDesugaringRewriter() {}

    private static class Visitor
            extends ExpressionRewriter<Context>
    {
        private final SymbolAllocator symbolAllocator;

        public Visitor(SymbolAllocator symbolAllocator)
        {
            this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        /**
         * Rewrites one lambda: first its body (with a fresh referenced-symbol set), then turns every
         * referenced symbol that is not a lambda argument into an extra leading argument bound via
         * {@link Bind}. The captured symbols are finally reported to the enclosing context so that an
         * outer lambda captures them as well.
         */
        @Override
        public Expression rewriteLambda(Lambda node, Context context, ExpressionTreeRewriter<Context> treeRewriter)
        {
            // Use linked hash set to guarantee deterministic iteration order
            // (this fixes the order of the generated lambda arguments and Bind values)
            LinkedHashSet<Symbol> referencedSymbols = new LinkedHashSet<>();
            Expression rewrittenBody = treeRewriter.rewrite(node.body(), context.withReferencedSymbols(referencedSymbols));

            List<Symbol> lambdaArguments = node.arguments();

            // Snapshot the difference once: the live SetView would re-check membership on every
            // traversal below. immutableCopy() preserves the view's iteration order.
            Set<Symbol> captureSymbols = Sets.difference(referencedSymbols, ImmutableSet.copyOf(lambdaArguments)).immutableCopy();

            // x -> f(x, captureSymbol)    will be rewritten into
            // "Bind"(captureSymbol, (extraSymbol, x) -> f(x, extraSymbol))

            ImmutableMap.Builder<Symbol, Symbol> captureSymbolToExtraSymbol = ImmutableMap.builder();
            ImmutableList.Builder<Symbol> newLambdaArguments = ImmutableList.builder();
            for (Symbol captureSymbol : captureSymbols) {
                Symbol extraSymbol = symbolAllocator.newSymbol(captureSymbol.name(), captureSymbol.type());
                captureSymbolToExtraSymbol.put(captureSymbol, extraSymbol);
                newLambdaArguments.add(extraSymbol);
            }
            // Extra (captured) arguments come first, followed by the lambda's original arguments
            newLambdaArguments.addAll(lambdaArguments);

            ImmutableMap<Symbol, Symbol> symbolsMap = captureSymbolToExtraSymbol.buildOrThrow();
            // Replace captured symbols in the body with their extra arguments; other symbols map to themselves
            Function<Symbol, Expression> symbolMapping = symbol -> symbolsMap.getOrDefault(symbol, symbol).toSymbolReference();
            Lambda lambda = new Lambda(newLambdaArguments.build(), inlineSymbols(symbolMapping, rewrittenBody));

            Expression rewrittenExpression = lambda;
            if (!captureSymbols.isEmpty()) {
                List<Expression> capturedValues = captureSymbols.stream()
                        .map(symbol -> new Reference(symbol.type(), symbol.name()))
                        .collect(toImmutableList());
                rewrittenExpression = new Bind(capturedValues, lambda);
            }

            // The captured symbols are free in the enclosing scope too, so report them upward
            context.getReferencedSymbols().addAll(captureSymbols);
            return rewrittenExpression;
        }

        /**
         * Records the referenced symbol in the current context. Returns {@code null}, presumably
         * meaning "no rewrite, keep the original node" per the ExpressionRewriter convention.
         */
        @Override
        public Expression rewriteReference(Reference node, Context context, ExpressionTreeRewriter<Context> treeRewriter)
        {
            context.getReferencedSymbols().add(new Symbol(node.type(), node.name()));
            return null;
        }
    }

    /**
     * Rewrite context carrying the set into which referenced symbols of the current lambda scope are collected.
     */
    private static class Context
    {
        // Use linked hash set to guarantee deterministic iteration order
        private final LinkedHashSet<Symbol> referencedSymbols;

        public Context()
        {
            this(new LinkedHashSet<>());
        }

        private Context(LinkedHashSet<Symbol> referencedSymbols)
        {
            this.referencedSymbols = referencedSymbols;
        }

        public LinkedHashSet<Symbol> getReferencedSymbols()
        {
            return referencedSymbols;
        }

        /**
         * Returns a new context that collects referenced symbols into {@code symbols}.
         */
        public Context withReferencedSymbols(LinkedHashSet<Symbol> symbols)
        {
            return new Context(symbols);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy