All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.optimizations.PlanNodeSearcher Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.PlanNode;

import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.alwaysFalse;
import static com.google.common.base.Predicates.alwaysTrue;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.sql.planner.iterative.Lookup.noLookup;
import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;

public class PlanNodeSearcher
{
    public static PlanNodeSearcher searchFrom(PlanNode node)
    {
        return searchFrom(node, noLookup());
    }

    /**
     * Use it in optimizer {@link io.trino.sql.planner.iterative.Rule} only if you truly do not have a better option
     * 

* TODO: replace it with a support for plan (physical) properties in rules pattern matching */ public static PlanNodeSearcher searchFrom(PlanNode node, Lookup lookup) { return new PlanNodeSearcher(node, lookup); } private final PlanNode node; private final Lookup lookup; private Predicate where = alwaysTrue(); private Predicate recurseOnlyWhen = alwaysTrue(); private PlanNodeSearcher(PlanNode node, Lookup lookup) { this.node = requireNonNull(node, "node is null"); this.lookup = requireNonNull(lookup, "lookup is null"); } @SafeVarargs public final PlanNodeSearcher whereIsInstanceOfAny(Class... classes) { return whereIsInstanceOfAny(asList(classes)); } public final PlanNodeSearcher whereIsInstanceOfAny(List> classes) { Predicate predicate = alwaysFalse(); for (Class clazz : classes) { predicate = predicate.or(clazz::isInstance); } return where(predicate); } public PlanNodeSearcher where(Predicate where) { this.where = requireNonNull(where, "where is null"); return this; } public PlanNodeSearcher recurseOnlyWhen(Predicate skipOnly) { this.recurseOnlyWhen = requireNonNull(skipOnly, "skipOnly is null"); return this; } public Optional findFirst() { return findFirstRecursive(node); } private Optional findFirstRecursive(PlanNode node) { node = lookup.resolve(node); if (where.test(node)) { return Optional.of(node); } if (recurseOnlyWhen.test(node)) { for (PlanNode source : node.getSources()) { Optional found = findFirstRecursive(source); if (found.isPresent()) { return found; } } } return Optional.empty(); } /** * Return a list of matching nodes ordered as in pre-order traversal of the plan tree. */ public List findAll() { ImmutableList.Builder nodes = ImmutableList.builder(); findAllRecursive(node, nodes); return nodes.build(); } public PlanNode findOnlyElement() { return getOnlyElement(findAll()); } private void findAllRecursive(PlanNode node, ImmutableList.Builder nodes) { node = lookup.resolve(node); if (where.test(node)) { nodes.add(node); } if (recurseOnlyWhen.test(node)) { for (PlanNode source : node.getSources()) { findAllRecursive(source, nodes); } } } public PlanNode removeAll() { return removeAllRecursive(node); } private PlanNode removeAllRecursive(PlanNode node) { node = lookup.resolve(node); if (where.test(node)) { checkArgument( node.getSources().size() == 1, "Unable to remove plan node as it contains 0 or more than 1 children"); return getOnlyElement(node.getSources()); } if (recurseOnlyWhen.test(node)) { List sources = node.getSources().stream() .map(this::removeAllRecursive) .collect(toImmutableList()); return replaceChildren(node, sources); } return node; } public PlanNode removeFirst() { return removeFirstRecursive(node); } private PlanNode removeFirstRecursive(PlanNode node) { node = lookup.resolve(node); if (where.test(node)) { checkArgument( node.getSources().size() == 1, "Unable to remove plan node as it contains 0 or more than 1 children"); return getOnlyElement(node.getSources()); } if (recurseOnlyWhen.test(node)) { List sources = node.getSources(); if (sources.isEmpty()) { return node; } if (sources.size() == 1) { return replaceChildren(node, ImmutableList.of(removeFirstRecursive(getOnlyElement(sources)))); } throw new IllegalArgumentException("Unable to remove first node when a node has multiple children, use removeAll instead"); } return node; } public PlanNode replaceAll(PlanNode newPlanNode) { return replaceAllRecursive(node, newPlanNode); } private PlanNode replaceAllRecursive(PlanNode node, PlanNode nodeToReplace) { node = lookup.resolve(node); if (where.test(node)) { return nodeToReplace; } if (recurseOnlyWhen.test(node)) { List sources = node.getSources().stream() .map(source -> replaceAllRecursive(source, nodeToReplace)) .collect(toImmutableList()); return replaceChildren(node, sources); } return node; } public PlanNode replaceFirst(PlanNode newPlanNode) { return replaceFirstRecursive(node, newPlanNode); } private PlanNode replaceFirstRecursive(PlanNode node, PlanNode nodeToReplace) { node = lookup.resolve(node); if (where.test(node)) { return nodeToReplace; } List sources = node.getSources(); if (sources.isEmpty()) { return node; } if (sources.size() == 1) { return replaceChildren(node, ImmutableList.of(replaceFirstRecursive(node, getOnlyElement(sources)))); } throw new IllegalArgumentException("Unable to replace first node when a node has multiple children, use replaceAll instead"); } public boolean matches() { return findFirst().isPresent(); } public int count() { return findAll().size(); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy