All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.iterative.rule.ImplementTableFunctionSource Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.Type;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableFunctionNode;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification;
import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties;
import io.trino.sql.planner.plan.TableFunctionProcessorNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.plan.WindowNode.Frame;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.GenericLiteral;
import io.trino.sql.tree.IfExpression;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullLiteral;

import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.planner.plan.JoinNode.Type.FULL;
import static io.trino.sql.planner.plan.JoinNode.Type.INNER;
import static io.trino.sql.planner.plan.JoinNode.Type.LEFT;
import static io.trino.sql.planner.plan.JoinNode.Type.RIGHT;
import static io.trino.sql.planner.plan.Patterns.tableFunction;
import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN;
import static io.trino.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM;
import static io.trino.sql.tree.FrameBound.Type.UNBOUNDED_FOLLOWING;
import static io.trino.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING;
import static io.trino.sql.tree.LogicalExpression.Operator.AND;
import static io.trino.sql.tree.LogicalExpression.Operator.OR;
import static io.trino.sql.tree.WindowFrame.Type.ROWS;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

/**
 * This rule prepares cartesian product of partitions
 * from all inputs of table function.
 * 

* It rewrites TableFunctionNode with potentially many sources * into a TableFunctionProcessorNode. The new node has one * source being a combination of the original sources. *

* The original sources are combined with joins. The join * conditions depend on the prune when empty property, and on * the co-partitioning of sources. *

* The resulting source should be partitioned and ordered * according to combined schemas from the component sources. *

* Example transformation for two sources, both with set semantics * and KEEP WHEN EMPTY property: *

 * - TableFunction foo
 *      - source T1(a1, b1) PARTITION BY a1 ORDER BY b1
 *      - source T2(a2, b2) PARTITION BY a2
 * 
* Is transformed into: *
 * - TableFunctionProcessor foo
 *      PARTITION BY (a1, a2), ORDER BY combined_row_number
 *      - Project
 *          marker_1 <= IF(table1_row_number = combined_row_number, table1_row_number, CAST(null AS bigint))
 *          marker_2 <= IF(table2_row_number = combined_row_number, table2_row_number, CAST(null AS bigint))
 *          - Project
 *              combined_row_number <= IF(COALESCE(table1_row_number, BIGINT '-1') > COALESCE(table2_row_number, BIGINT '-1'), table1_row_number, table2_row_number)
 *              combined_partition_size <= IF(COALESCE(table1_partition_size, BIGINT '-1') > COALESCE(table2_partition_size, BIGINT '-1'), table1_partition_size, table2_partition_size)
 *              - FULL Join
 *                  [table1_row_number = table2_row_number OR
 *                   table1_row_number > table2_partition_size AND table2_row_number = BIGINT '1' OR
 *                   table2_row_number > table1_partition_size AND table1_row_number = BIGINT '1']
 *                  - Window [PARTITION BY a1 ORDER BY b1]
 *                      table1_row_number <= row_number()
 *                      table1_partition_size <= count()
 *                          - source T1(a1, b1)
 *                  - Window [PARTITION BY a2]
 *                      table2_row_number <= row_number()
 *                      table2_partition_size <= count()
 *                          - source T2(a2, b2)
 * 
*/ public class ImplementTableFunctionSource implements Rule { private static final Pattern PATTERN = tableFunction(); private static final Frame FULL_FRAME = new Frame( ROWS, UNBOUNDED_PRECEDING, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); private static final DataOrganizationSpecification UNORDERED_SINGLE_PARTITION = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); private final Metadata metadata; public ImplementTableFunctionSource(Metadata metadata) { this.metadata = requireNonNull(metadata, "metadata is null"); } @Override public Pattern getPattern() { return PATTERN; } @Override public Result apply(TableFunctionNode node, Captures captures, Context context) { if (node.getSources().isEmpty()) { return Result.ofPlanNode(new TableFunctionProcessorNode( node.getId(), node.getName(), node.getProperOutputs(), Optional.empty(), false, ImmutableList.of(), ImmutableList.of(), Optional.empty(), Optional.empty(), ImmutableSet.of(), 0, Optional.empty(), node.getHandle())); } if (node.getSources().size() == 1) { // Single source does not require pre-processing. // If the source has row semantics, its specification is empty. // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. // This property can be used later to choose optimal distribution. TableArgumentProperties sourceProperties = getOnlyElement(node.getTableArgumentProperties()); return Result.ofPlanNode(new TableFunctionProcessorNode( node.getId(), node.getName(), node.getProperOutputs(), Optional.of(getOnlyElement(node.getSources())), sourceProperties.isPruneWhenEmpty(), ImmutableList.of(sourceProperties.getPassThroughSpecification()), ImmutableList.of(sourceProperties.getRequiredColumns()), Optional.empty(), sourceProperties.getSpecification(), ImmutableSet.of(), 0, Optional.empty(), node.getHandle())); } Map sources = mapSourcesByName(node.getSources(), node.getTableArgumentProperties()); ImmutableList.Builder intermediateResultsBuilder = ImmutableList.builder(); ResolvedFunction rowNumberFunction = metadata.resolveBuiltinFunction("row_number", ImmutableList.of()); ResolvedFunction countFunction = metadata.resolveBuiltinFunction("count", ImmutableList.of()); // handle co-partitioned sources for (List copartitioningList : node.getCopartitioningLists()) { List sourceList = copartitioningList.stream() .map(sources::get) .collect(toImmutableList()); intermediateResultsBuilder.add(copartition(sourceList, rowNumberFunction, countFunction, context)); } // prepare non-co-partitioned sources Set copartitionedSources = node.getCopartitioningLists().stream() .flatMap(Collection::stream) .collect(toImmutableSet()); sources.entrySet().stream() .filter(entry -> !copartitionedSources.contains(entry.getKey())) .map(entry -> planWindowFunctionsForSource(entry.getValue().source(), entry.getValue().properties(), rowNumberFunction, countFunction, context)) .forEach(intermediateResultsBuilder::add); NodeWithSymbols finalResultSource; List intermediateResultSources = intermediateResultsBuilder.build(); if (intermediateResultSources.size() == 1) { finalResultSource = getOnlyElement(intermediateResultSources); } else { NodeWithSymbols first = intermediateResultSources.get(0); NodeWithSymbols second = intermediateResultSources.get(1); JoinedNodes joined = join(first, second, context); for (int i = 2; i < intermediateResultSources.size(); i++) { NodeWithSymbols joinedWithSymbols = appendHelperSymbolsForJoinedNodes(joined, context); joined = join(joinedWithSymbols, intermediateResultSources.get(i), context); } finalResultSource = appendHelperSymbolsForJoinedNodes(joined, context); } // For each source, all source's output symbols are mapped to the source's row number symbol. // The row number symbol will be later converted to a marker of "real" input rows vs "filler" input rows of the source. // The "filler" input rows are the rows appended while joining partitions of different lengths, // to fill the smaller partition up to the bigger partition's size. They are a side effect of the algorithm, // and should not be processed by the table function. Map rowNumberSymbols = finalResultSource.rowNumberSymbolsMapping(); // The max row number symbol from all joined partitions. Symbol finalRowNumberSymbol = finalResultSource.rowNumber(); // Combined partitioning lists from all sources. List finalPartitionBy = finalResultSource.partitionBy(); NodeWithMarkers marked = appendMarkerSymbols(finalResultSource.node(), ImmutableSet.copyOf(rowNumberSymbols.values()), finalRowNumberSymbol, context); // Remap the symbol mapping: replace the row number symbol with the corresponding marker symbol. // In the new map, every source symbol is associated with the corresponding marker symbol. // Null value of the marker indicates that the source value should be ignored by the table function. ImmutableMap markerSymbols = rowNumberSymbols.entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> marked.symbolToMarker().get(entry.getValue()))); // Use the final row number symbol for ordering the combined sources. // It runs along each partition in the cartesian product, numbering the partition's rows according to the expected ordering / orderings. // note: ordering is necessary even if all the source tables are not ordered. Thanks to the ordering, the original rows // of each input table come before the "filler" rows. Optional finalOrderBy = Optional.of(new OrderingScheme(ImmutableList.of(finalRowNumberSymbol), ImmutableMap.of(finalRowNumberSymbol, ASC_NULLS_LAST))); // derive the prune when empty property boolean pruneWhenEmpty = node.getTableArgumentProperties().stream().anyMatch(TableArgumentProperties::isPruneWhenEmpty); // Combine the pass through specifications from all sources List passThroughSpecifications = node.getTableArgumentProperties().stream() .map(TableArgumentProperties::getPassThroughSpecification) .collect(toImmutableList()); // Combine the required symbols from all sources List> requiredSymbols = node.getTableArgumentProperties().stream() .map(TableArgumentProperties::getRequiredColumns) .collect(toImmutableList()); return Result.ofPlanNode(new TableFunctionProcessorNode( node.getId(), node.getName(), node.getProperOutputs(), Optional.of(marked.node()), pruneWhenEmpty, passThroughSpecifications, requiredSymbols, Optional.of(markerSymbols), Optional.of(new DataOrganizationSpecification(finalPartitionBy, finalOrderBy)), ImmutableSet.of(), 0, Optional.empty(), node.getHandle())); } private static Map mapSourcesByName(List sources, List properties) { return Streams.zip(sources.stream(), properties.stream(), SourceWithProperties::new) .collect(toImmutableMap(entry -> entry.properties().getArgumentName(), identity())); } private static NodeWithSymbols planWindowFunctionsForSource( PlanNode source, TableArgumentProperties argumentProperties, ResolvedFunction rowNumberFunction, ResolvedFunction countFunction, Context context) { String argumentName = argumentProperties.getArgumentName(); Symbol rowNumber = context.getSymbolAllocator().newSymbol(argumentName + "_row_number", BIGINT); Map rowNumberSymbolMapping = source.getOutputSymbols().stream() .collect(toImmutableMap(identity(), symbol -> rowNumber)); Symbol partitionSize = context.getSymbolAllocator().newSymbol(argumentName + "_partition_size", BIGINT); // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. // If the source has row semantics, its specification is empty. Currently, such source is processed // as if it was a single partition. Alternatively, it could be split into smaller partitions of arbitrary size. DataOrganizationSpecification specification = argumentProperties.getSpecification().orElse(UNORDERED_SINGLE_PARTITION); PlanNode window = new WindowNode( context.getIdAllocator().getNextId(), source, specification, ImmutableMap.of( rowNumber, new WindowNode.Function(rowNumberFunction, ImmutableList.of(), FULL_FRAME, false), partitionSize, new WindowNode.Function(countFunction, ImmutableList.of(), FULL_FRAME, false)), Optional.empty(), ImmutableSet.of(), 0); return new NodeWithSymbols(window, rowNumber, partitionSize, specification.getPartitionBy(), argumentProperties.isPruneWhenEmpty(), rowNumberSymbolMapping); } private static NodeWithSymbols copartition( List sourceList, ResolvedFunction rowNumberFunction, ResolvedFunction countFunction, Context context) { checkArgument(sourceList.size() >= 2, "co-partitioning list should contain at least two tables"); // Reorder the co-partitioned sources to process the sources with prune when empty property first. // It allows to use inner or side joins instead of outer joins. sourceList = sourceList.stream() .sorted(Comparator.comparingInt(source -> source.properties().isPruneWhenEmpty() ? -1 : 1)) .collect(toImmutableList()); NodeWithSymbols first = planWindowFunctionsForSource(sourceList.get(0).source(), sourceList.get(0).properties(), rowNumberFunction, countFunction, context); NodeWithSymbols second = planWindowFunctionsForSource(sourceList.get(1).source(), sourceList.get(1).properties(), rowNumberFunction, countFunction, context); JoinedNodes copartitioned = copartition(first, second, context); for (int i = 2; i < sourceList.size(); i++) { NodeWithSymbols copartitionedWithSymbols = appendHelperSymbolsForCopartitionedNodes(copartitioned, context); NodeWithSymbols next = planWindowFunctionsForSource(sourceList.get(i).source(), sourceList.get(i).properties(), rowNumberFunction, countFunction, context); copartitioned = copartition(copartitionedWithSymbols, next, context); } return appendHelperSymbolsForCopartitionedNodes(copartitioned, context); } private static JoinedNodes copartition(NodeWithSymbols left, NodeWithSymbols right, Context context) { checkArgument(left.partitionBy().size() == right.partitionBy().size(), "co-partitioning lists do not match"); // In StatementAnalyzer we require that co-partitioned tables have non-empty partitioning column lists. // Co-partitioning tables with empty partition by would be ineffective. checkState(!left.partitionBy().isEmpty(), "co-partitioned tables must have partitioning columns"); Expression leftRowNumber = left.rowNumber().toSymbolReference(); Expression leftPartitionSize = left.partitionSize().toSymbolReference(); List leftPartitionBy = left.partitionBy().stream() .map(Symbol::toSymbolReference) .collect(toImmutableList()); Expression rightRowNumber = right.rowNumber().toSymbolReference(); Expression rightPartitionSize = right.partitionSize().toSymbolReference(); List rightPartitionBy = right.partitionBy().stream() .map(Symbol::toSymbolReference) .collect(toImmutableList()); List copartitionConjuncts = Streams.zip( leftPartitionBy.stream(), rightPartitionBy.stream(), (leftColumn, rightColumn) -> new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, leftColumn, rightColumn))) .collect(toImmutableList()); // Align matching partitions (co-partitions) from left and right source, according to row number. // Matching partitions are identified by their corresponding partitioning columns being NOT DISTINCT from each other. // If one or both sources are ordered, the row number reflects the ordering. // The second and third disjunct in the join condition account for the situation when partitions have different sizes. // It preserves the outstanding rows from the bigger partition, matching them to the first row from the smaller partition. // // (P1_1 IS NOT DISTINCT FROM P2_1) AND (P1_2 IS NOT DISTINCT FROM P2_2) AND ... // AND ( // R1 = R2 // OR // (R1 > S2 AND R2 = 1) // OR // (R2 > S1 AND R1 = 1)) Expression joinCondition = new LogicalExpression( AND, ImmutableList.builder() .addAll(copartitionConjuncts) .add(new LogicalExpression(OR, ImmutableList.of( new ComparisonExpression(EQUAL, leftRowNumber, rightRowNumber), new LogicalExpression(AND, ImmutableList.of( new ComparisonExpression(GREATER_THAN, leftRowNumber, rightPartitionSize), new ComparisonExpression(EQUAL, rightRowNumber, new GenericLiteral("BIGINT", "1")))), new LogicalExpression(AND, ImmutableList.of( new ComparisonExpression(GREATER_THAN, rightRowNumber, leftPartitionSize), new ComparisonExpression(EQUAL, leftRowNumber, new GenericLiteral("BIGINT", "1"))))))) .build()); // The join type depends on the prune when empty property of the sources. // If a source is prune when empty, we should not process any co-partition which is not present in this source, // so effectively the other source becomes inner side of the join. // // example: // table T1 partition by P1 table T2 partition by P2 // P1 C1 P2 C2 // ---------- ---------- // 1 'a' 2 'c' // 2 'b' 3 'd' // // co-partitioning results: // 1) T1 is prune when empty: do LEFT JOIN to drop co-partition '3' // P1 C1 P2 C2 // ------------------------ // 1 'a' null null // 2 'b' 2 'c' // // 2) T2 is prune when empty: do RIGHT JOIN to drop co-partition '1' // P1 C1 P2 C2 // ------------------------ // 2 'b' 2 'c' // null null 3 'd' // // 3) T1 and T2 are both prune when empty: do INNER JOIN to drop co-partitions '1' and '3' // P1 C1 P2 C2 // ------------------------ // 2 'b' 2 'c' // // 4) neither table is prune when empty: do FULL JOIN to preserve all co-partitions // P1 C1 P2 C2 // ------------------------ // 1 'a' null null // 2 'b' 2 'c' // null null 3 'd' JoinNode.Type joinType; if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { joinType = INNER; } else if (left.pruneWhenEmpty()) { joinType = LEFT; } else if (right.pruneWhenEmpty()) { joinType = RIGHT; } else { joinType = FULL; } return new JoinedNodes( new JoinNode( context.getIdAllocator().getNextId(), joinType, left.node(), right.node(), ImmutableList.of(), left.node().getOutputSymbols(), right.node().getOutputSymbols(), false, Optional.of(joinCondition), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty()), left.rowNumber(), left.partitionSize(), left.partitionBy(), left.pruneWhenEmpty(), left.rowNumberSymbolsMapping(), right.rowNumber(), right.partitionSize(), right.partitionBy(), right.pruneWhenEmpty(), right.rowNumberSymbolsMapping()); } private static NodeWithSymbols appendHelperSymbolsForCopartitionedNodes( JoinedNodes copartitionedNodes, Context context) { checkArgument(copartitionedNodes.leftPartitionBy().size() == copartitionedNodes.rightPartitionBy().size(), "co-partitioning lists do not match"); Expression leftRowNumber = copartitionedNodes.leftRowNumber().toSymbolReference(); Expression leftPartitionSize = copartitionedNodes.leftPartitionSize().toSymbolReference(); Expression rightRowNumber = copartitionedNodes.rightRowNumber().toSymbolReference(); Expression rightPartitionSize = copartitionedNodes.rightPartitionSize().toSymbolReference(); // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. Symbol joinedRowNumber = context.getSymbolAllocator().newSymbol("combined_row_number", BIGINT); Expression rowNumberExpression = new IfExpression( new ComparisonExpression( GREATER_THAN, new CoalesceExpression(leftRowNumber, new GenericLiteral("BIGINT", "-1")), new CoalesceExpression(rightRowNumber, new GenericLiteral("BIGINT", "-1"))), leftRowNumber, rightRowNumber); // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. Symbol joinedPartitionSize = context.getSymbolAllocator().newSymbol("combined_partition_size", BIGINT); Expression partitionSizeExpression = new IfExpression( new ComparisonExpression( GREATER_THAN, new CoalesceExpression(leftPartitionSize, new GenericLiteral("BIGINT", "-1")), new CoalesceExpression(rightPartitionSize, new GenericLiteral("BIGINT", "-1"))), leftPartitionSize, rightPartitionSize); // Derive partitioning columns for joined partitions. // Either the combined partitioning columns are pairwise NOT DISTINCT (this is the co-partitioning rule), // or one of them is null as a result of outer join. ImmutableList.Builder joinedPartitionBy = ImmutableList.builder(); Assignments.Builder joinedPartitionByAssignments = Assignments.builder(); for (int i = 0; i < copartitionedNodes.leftPartitionBy().size(); i++) { Symbol leftColumn = copartitionedNodes.leftPartitionBy().get(i); Symbol rightColumn = copartitionedNodes.rightPartitionBy().get(i); Type type = context.getSymbolAllocator().getTypes().get(leftColumn); Symbol joinedColumn = context.getSymbolAllocator().newSymbol("combined_partition_column", type); joinedPartitionByAssignments.put(joinedColumn, new CoalesceExpression(leftColumn.toSymbolReference(), rightColumn.toSymbolReference())); joinedPartitionBy.add(joinedColumn); } PlanNode project = new ProjectNode( context.getIdAllocator().getNextId(), copartitionedNodes.joinedNode(), Assignments.builder() .putIdentities(copartitionedNodes.joinedNode().getOutputSymbols()) .put(joinedRowNumber, rowNumberExpression) .put(joinedPartitionSize, partitionSizeExpression) .putAll(joinedPartitionByAssignments.build()) .build()); boolean joinedPruneWhenEmpty = copartitionedNodes.leftPruneWhenEmpty() || copartitionedNodes.rightPruneWhenEmpty(); Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() .putAll(copartitionedNodes.leftRowNumberSymbolsMapping()) .putAll(copartitionedNodes.rightRowNumberSymbolsMapping()) .buildOrThrow(); return new NodeWithSymbols(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy.build(), joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); } private static JoinedNodes join(NodeWithSymbols left, NodeWithSymbols right, Context context) { Expression leftRowNumber = left.rowNumber().toSymbolReference(); Expression leftPartitionSize = left.partitionSize().toSymbolReference(); Expression rightRowNumber = right.rowNumber().toSymbolReference(); Expression rightPartitionSize = right.partitionSize().toSymbolReference(); // Align rows from left and right source according to row number. Because every partition is row-numbered, this produces cartesian product of partitions. // If one or both sources are ordered, the row number reflects the ordering. // The second and third disjunct in the join condition account for the situation when partitions have different sizes. It preserves the outstanding rows // from the bigger partition, matching them to the first row from the smaller partition. // // R1 = R2 // OR // (R1 > S2 AND R2 = 1) // OR // (R2 > S1 AND R1 = 1) Expression joinCondition = new LogicalExpression(OR, ImmutableList.of( new ComparisonExpression(EQUAL, leftRowNumber, rightRowNumber), new LogicalExpression(AND, ImmutableList.of( new ComparisonExpression(GREATER_THAN, leftRowNumber, rightPartitionSize), new ComparisonExpression(EQUAL, rightRowNumber, new GenericLiteral("BIGINT", "1")))), new LogicalExpression(AND, ImmutableList.of( new ComparisonExpression(GREATER_THAN, rightRowNumber, leftPartitionSize), new ComparisonExpression(EQUAL, leftRowNumber, new GenericLiteral("BIGINT", "1")))))); JoinNode.Type joinType; if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { joinType = INNER; } else if (left.pruneWhenEmpty()) { joinType = LEFT; } else if (right.pruneWhenEmpty()) { joinType = RIGHT; } else { joinType = FULL; } return new JoinedNodes( new JoinNode( context.getIdAllocator().getNextId(), joinType, left.node(), right.node(), ImmutableList.of(), left.node().getOutputSymbols(), right.node().getOutputSymbols(), false, Optional.of(joinCondition), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty()), left.rowNumber(), left.partitionSize(), left.partitionBy(), left.pruneWhenEmpty(), left.rowNumberSymbolsMapping(), right.rowNumber(), right.partitionSize(), right.partitionBy(), right.pruneWhenEmpty(), right.rowNumberSymbolsMapping()); } private static NodeWithSymbols appendHelperSymbolsForJoinedNodes(JoinedNodes joinedNodes, Context context) { Expression leftRowNumber = joinedNodes.leftRowNumber().toSymbolReference(); Expression leftPartitionSize = joinedNodes.leftPartitionSize().toSymbolReference(); Expression rightRowNumber = joinedNodes.rightRowNumber().toSymbolReference(); Expression rightPartitionSize = joinedNodes.rightPartitionSize().toSymbolReference(); // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. Symbol joinedRowNumber = context.getSymbolAllocator().newSymbol("combined_row_number", BIGINT); Expression rowNumberExpression = new IfExpression( new ComparisonExpression( GREATER_THAN, new CoalesceExpression(leftRowNumber, new GenericLiteral("BIGINT", "-1")), new CoalesceExpression(rightRowNumber, new GenericLiteral("BIGINT", "-1"))), leftRowNumber, rightRowNumber); // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. Symbol joinedPartitionSize = context.getSymbolAllocator().newSymbol("combined_partition_size", BIGINT); Expression partitionSizeExpression = new IfExpression( new ComparisonExpression( GREATER_THAN, new CoalesceExpression(leftPartitionSize, new GenericLiteral("BIGINT", "-1")), new CoalesceExpression(rightPartitionSize, new GenericLiteral("BIGINT", "-1"))), leftPartitionSize, rightPartitionSize); PlanNode project = new ProjectNode( context.getIdAllocator().getNextId(), joinedNodes.joinedNode(), Assignments.builder() .putIdentities(joinedNodes.joinedNode().getOutputSymbols()) .put(joinedRowNumber, rowNumberExpression) .put(joinedPartitionSize, partitionSizeExpression) .build()); List joinedPartitionBy = ImmutableList.builder() .addAll(joinedNodes.leftPartitionBy()) .addAll(joinedNodes.rightPartitionBy()) .build(); boolean joinedPruneWhenEmpty = joinedNodes.leftPruneWhenEmpty() || joinedNodes.rightPruneWhenEmpty(); Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() .putAll(joinedNodes.leftRowNumberSymbolsMapping()) .putAll(joinedNodes.rightRowNumberSymbolsMapping()) .buildOrThrow(); return new NodeWithSymbols(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy, joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); } private static NodeWithMarkers appendMarkerSymbols(PlanNode node, Set symbols, Symbol referenceSymbol, Context context) { Assignments.Builder assignments = Assignments.builder(); assignments.putIdentities(node.getOutputSymbols()); ImmutableMap.Builder symbolsToMarkers = ImmutableMap.builder(); for (Symbol symbol : symbols) { Symbol marker = context.getSymbolAllocator().newSymbol("marker", BIGINT); symbolsToMarkers.put(symbol, marker); Expression actual = symbol.toSymbolReference(); Expression reference = referenceSymbol.toSymbolReference(); assignments.put(marker, new IfExpression(new ComparisonExpression(EQUAL, actual, reference), actual, new Cast(new NullLiteral(), toSqlType(BIGINT)))); } PlanNode project = new ProjectNode( context.getIdAllocator().getNextId(), node, assignments.build()); return new NodeWithMarkers(project, symbolsToMarkers.buildOrThrow()); } private record SourceWithProperties(PlanNode source, TableArgumentProperties properties) { SourceWithProperties { requireNonNull(source, "source is null"); requireNonNull(properties, "properties is null"); } } private record NodeWithSymbols(PlanNode node, Symbol rowNumber, Symbol partitionSize, List partitionBy, boolean pruneWhenEmpty, Map rowNumberSymbolsMapping) { NodeWithSymbols { requireNonNull(node, "node is null"); requireNonNull(rowNumber, "rowNumber is null"); requireNonNull(partitionSize, "partitionSize is null"); partitionBy = ImmutableList.copyOf(partitionBy); rowNumberSymbolsMapping = ImmutableMap.copyOf(rowNumberSymbolsMapping); } } private record JoinedNodes( PlanNode joinedNode, Symbol leftRowNumber, Symbol leftPartitionSize, List leftPartitionBy, boolean leftPruneWhenEmpty, Map leftRowNumberSymbolsMapping, Symbol rightRowNumber, Symbol rightPartitionSize, List rightPartitionBy, boolean rightPruneWhenEmpty, Map rightRowNumberSymbolsMapping) { JoinedNodes { requireNonNull(joinedNode, "joinedNode is null"); requireNonNull(leftRowNumber, "leftRowNumber is null"); requireNonNull(leftPartitionSize, "leftPartitionSize is null"); leftPartitionBy = ImmutableList.copyOf(leftPartitionBy); leftRowNumberSymbolsMapping = ImmutableMap.copyOf(leftRowNumberSymbolsMapping); requireNonNull(rightRowNumber, "rightRowNumber is null"); requireNonNull(rightPartitionSize, "rightPartitionSize is null"); rightPartitionBy = ImmutableList.copyOf(rightPartitionBy); rightRowNumberSymbolsMapping = ImmutableMap.copyOf(rightRowNumberSymbolsMapping); } } private record NodeWithMarkers(PlanNode node, Map symbolToMarker) { NodeWithMarkers { requireNonNull(node, "node is null"); symbolToMarker = ImmutableMap.copyOf(symbolToMarker); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy