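// com.facebook.presto.iceberg.optimizer.IcebergMetadataOptimizer (Maven / Gradle / Ivy)
// Artifact: presto-iceberg -- Presto - Iceberg Connector (listing taken from the newest published version).
// NOTE: the lines above the license header are residue from the artifact-browser page this
// listing was copied from ("Go to download", "Show more of this group", "Show more artifacts
// with this name", "Show all versions of presto-iceberg", "Show documentation"); they are not
// part of the original source file.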
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.iceberg.optimizer;
import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.predicate.NullableValue;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.common.predicate.TupleDomain.ColumnDomain;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.iceberg.IcebergAbstractMetadata;
import com.facebook.presto.iceberg.IcebergTransactionManager;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorPlanOptimizer;
import com.facebook.presto.spi.ConnectorPlanRewriter;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.ConnectorTableLayout;
import com.facebook.presto.spi.Constraint;
import com.facebook.presto.spi.DiscretePredicates;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.connector.ConnectorMetadata;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.ExpressionOptimizer.Level;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionService;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.iceberg.IcebergSessionProperties.getRowsForMetadataOptimizationThreshold;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;
/**
 * Connector plan optimizer that tries to answer an aggregation from table-layout metadata
 * instead of scanning the table.
 *
 * <p>When every column produced by the underlying table scan is covered by the layout's
 * {@link DiscretePredicates} (the columns the inline comments refer to as partition keys), the
 * scan is either folded into constants (ungrouped {@code min}/{@code max}) or replaced by a
 * {@link ValuesNode} holding one row per discrete predicate. In every other case the plan is
 * left to the default rewrite.
 */
public class IcebergMetadataOptimizer
        implements ConnectorPlanOptimizer
{
    public static final CatalogSchemaName DEFAULT_NAMESPACE = new CatalogSchemaName("presto", "default");

    // Non-DISTINCT aggregations that may still be evaluated over the metadata-derived rows
    private static final Set<QualifiedObjectName> ALLOWED_FUNCTIONS = ImmutableSet.of(
            QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "max"),
            QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "min"),
            QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "approx_distinct"));

    // Min/Max could be folded into LEAST/GREATEST
    private static final Map<QualifiedObjectName, QualifiedObjectName> AGGREGATION_SCALAR_MAPPING = ImmutableMap.of(
            QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "max"), QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "greatest"),
            QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "min"), QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "least"));

    private final FunctionMetadataManager functionMetadataManager;
    private final TypeManager typeManager;
    private final IcebergTransactionManager icebergTransactionManager;
    private final RowExpressionService rowExpressionService;
    private final StandardFunctionResolution functionResolution;

    /**
     * Creates the optimizer.
     *
     * @throws NullPointerException if any argument is null
     */
    public IcebergMetadataOptimizer(FunctionMetadataManager functionMetadataManager,
            TypeManager typeManager,
            IcebergTransactionManager icebergTransactionManager,
            RowExpressionService rowExpressionService,
            StandardFunctionResolution functionResolution)
    {
        this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
        this.typeManager = requireNonNull(typeManager, "typeManager is null");
        this.icebergTransactionManager = requireNonNull(icebergTransactionManager, "icebergTransactionManager is null");
        this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null");
        this.functionResolution = requireNonNull(functionResolution, "functionResolution is null");
    }

    /**
     * Rewrites {@code maxSubplan} with a fresh {@link Optimizer} configured from the session's
     * rows-for-metadata-optimization threshold.
     *
     * @param maxSubplan the sub-plan to rewrite
     * @param session session the threshold is read from and that is used for metadata calls
     * @param variableAllocator not used by this optimizer
     * @param idAllocator source of ids for the plan nodes created by the rewrite
     * @return the rewritten plan
     */
    @Override
    public PlanNode optimize(PlanNode maxSubplan, ConnectorSession session, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator)
    {
        int rowsForMetadataOptimizationThreshold = getRowsForMetadataOptimizationThreshold(session);
        Optimizer optimizer = new Optimizer(session, idAllocator,
                functionMetadataManager,
                typeManager,
                icebergTransactionManager,
                rowExpressionService,
                functionResolution,
                rowsForMetadataOptimizationThreshold);
        return ConnectorPlanRewriter.rewriteWith(optimizer, maxSubplan, null);
    }

    /**
     * Plan rewriter that performs the metadata-based rewrite on aggregation nodes.
     */
    private static class Optimizer
            extends ConnectorPlanRewriter<Void>
    {
        private final ConnectorSession connectorSession;
        private final PlanNodeIdAllocator idAllocator;
        private final FunctionMetadataManager functionMetadataManager;
        private final TypeManager typeManager;
        private final IcebergTransactionManager icebergTransactionManager;
        private final RowExpressionService rowExpressionService;
        private final StandardFunctionResolution functionResolution;
        private final int rowsForMetadataOptimizationThreshold;

        /**
         * @throws IllegalArgumentException if {@code rowsForMetadataOptimizationThreshold} is negative
         */
        private Optimizer(ConnectorSession connectorSession,
                PlanNodeIdAllocator idAllocator,
                FunctionMetadataManager functionMetadataManager,
                TypeManager typeManager,
                IcebergTransactionManager icebergTransactionManager,
                RowExpressionService rowExpressionService,
                StandardFunctionResolution functionResolution,
                int rowsForMetadataOptimizationThreshold)
        {
            checkArgument(rowsForMetadataOptimizationThreshold >= 0, "The value of `rowsForMetadataOptimizationThreshold` should not less than 0");
            this.connectorSession = connectorSession;
            this.idAllocator = idAllocator;
            this.functionMetadataManager = functionMetadataManager;
            this.icebergTransactionManager = icebergTransactionManager;
            this.rowExpressionService = rowExpressionService;
            this.functionResolution = functionResolution;
            this.typeManager = typeManager;
            this.rowsForMetadataOptimizationThreshold = rowsForMetadataOptimizationThreshold;
        }

        /**
         * Attempts to replace the table scan under {@code node} with metadata-derived values.
         *
         * <p>Falls back to {@code context.defaultRewrite(node)} whenever one of the preconditions
         * checked below does not hold.
         *
         * @param node the aggregation to rewrite
         * @param context rewrite context used for the fallback
         * @return a plan without the table scan, or the default rewrite of {@code node}
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
        {
            // supported functions are only MIN/MAX/APPROX_DISTINCT or distinct aggregates
            for (Aggregation aggregation : node.getAggregations().values()) {
                QualifiedObjectName functionName = functionMetadataManager.getFunctionMetadata(aggregation.getFunctionHandle()).getName();
                if (!ALLOWED_FUNCTIONS.contains(functionName) && !aggregation.isDistinct()) {
                    return context.defaultRewrite(node);
                }
            }

            Optional<TableScanNode> result = findTableScan(node.getSource(), rowExpressionService.getDeterminismEvaluator());
            if (!result.isPresent()) {
                return context.defaultRewrite(node);
            }

            // verify all outputs of table scan are partition keys
            TableScanNode tableScan = result.get();
            TableHandle table = tableScan.getTable();

            ImmutableMap.Builder<VariableReferenceExpression, ColumnHandle> columnBuilder = ImmutableMap.builder();
            List<VariableReferenceExpression> inputs = tableScan.getOutputVariables();
            for (VariableReferenceExpression variable : inputs) {
                ColumnHandle column = tableScan.getAssignments().get(variable);
                columnBuilder.put(variable, column);
            }
            Map<VariableReferenceExpression, ColumnHandle> columns = columnBuilder.build();

            // Materialize the list of partitions and replace the TableScan node
            // with a Values node
            ConnectorMetadata metadata = getConnectorMetadata(table);
            ConnectorTableLayout layout;
            if (!table.getLayout().isPresent()) {
                // no layout chosen yet: ask for one with an always-true constraint
                layout = metadata.getTableLayoutForConstraint(connectorSession, table.getConnectorHandle(), Constraint.alwaysTrue(), Optional.empty()).getTableLayout();
            }
            else {
                layout = metadata.getTableLayout(connectorSession, table.getLayout().get());
            }

            if (!layout.getDiscretePredicates().isPresent()) {
                return context.defaultRewrite(node);
            }
            DiscretePredicates discretePredicates = layout.getDiscretePredicates().get();

            // the optimization is only valid if there is no filter on non-partition columns
            if (layout.getPredicate().getColumnDomains().isPresent()) {
                List<ColumnHandle> predicateColumns = layout.getPredicate().getColumnDomains().get().stream()
                        .map(ColumnDomain::getColumn)
                        .collect(toImmutableList());
                if (!discretePredicates.getColumns().containsAll(predicateColumns)) {
                    return context.defaultRewrite(node);
                }
            }

            // Remaining predicate after tuple domain pushdown in getTableLayout(). This doesn't have overlap with discretePredicates.
            // So it only references non-partition columns. Disable the optimization in this case.
            Optional<RowExpression> remainingPredicate = layout.getRemainingPredicate();
            if (remainingPredicate.isPresent() && !remainingPredicate.get().equals(TRUE_CONSTANT)) {
                return context.defaultRewrite(node);
            }

            // the optimization is only valid if the aggregation node only relies on partition keys
            if (!discretePredicates.getColumns().containsAll(columns.values())) {
                return context.defaultRewrite(node);
            }

            if (isReducible(node, inputs)) {
                // Fold min/max aggregations to a constant value
                return reduce(node, inputs, columns, context, discretePredicates);
            }

            // When `rowsForMetadataOptimizationThreshold == 0`, or partitions number exceeds the threshold, skip the optimization
            if (rowsForMetadataOptimizationThreshold == 0 || Iterables.size(discretePredicates.getPredicates()) > rowsForMetadataOptimizationThreshold) {
                return context.defaultRewrite(node);
            }

            ImmutableList.Builder<List<RowExpression>> rowsBuilder = ImmutableList.builder();
            for (TupleDomain<ColumnHandle> domain : discretePredicates.getPredicates()) {
                if (domain.isNone()) {
                    continue;
                }
                Map<ColumnHandle, NullableValue> entries = TupleDomain.extractFixedValues(domain).get();
                ImmutableList.Builder<RowExpression> rowBuilder = ImmutableList.builder();
                // for each input column, add a literal expression using the entry value
                for (VariableReferenceExpression input : inputs) {
                    ColumnHandle column = columns.get(input);
                    NullableValue value = entries.get(column);
                    if (value == null) {
                        // partition key does not have a single value, so bail out to be safe
                        return context.defaultRewrite(node);
                    }
                    rowBuilder.add(new ConstantExpression(Optional.empty(), value.getValue(), input.getType()));
                }
                rowsBuilder.add(rowBuilder.build());
            }

            // replace the tablescan node with a values node
            ValuesNode valuesNode = new ValuesNode(node.getSourceLocation(), idAllocator.getNextId(), inputs, rowsBuilder.build(), Optional.empty());
            return ConnectorPlanRewriter.rewriteWith(new Replacer(valuesNode), node);
        }

        /**
         * Tells whether {@code node} can be folded into constants by {@link #reduce}.
         *
         * @param node the aggregation being rewritten
         * @param inputs output variables of the table scan feeding the aggregation
         * @return {@code true} when the node has at least one aggregation, no grouping keys, a
         *         {@link TableScanNode} as its direct source, and every aggregation is a min/max
         *         with at most one declared argument whose call arguments are all in {@code inputs}
         */
        private boolean isReducible(AggregationNode node, List<VariableReferenceExpression> inputs)
        {
            // The aggregation is reducible when there is no group by key
            if (node.getAggregations().isEmpty() || !node.getGroupingKeys().isEmpty() || !(node.getSource() instanceof TableScanNode)) {
                return false;
            }
            for (Aggregation aggregation : node.getAggregations().values()) {
                FunctionMetadata functionMetadata = functionMetadataManager.getFunctionMetadata(aggregation.getFunctionHandle());
                if (!AGGREGATION_SCALAR_MAPPING.containsKey(functionMetadata.getName()) ||
                        functionMetadata.getArgumentTypes().size() > 1 ||
                        !inputs.containsAll(aggregation.getCall().getArguments())) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Folds the min/max aggregations of {@code node} into constants computed from the fixed
         * values of the discrete predicates.
         *
         * @param node aggregation already accepted by {@link #isReducible}
         * @param inputs output variables of the table scan
         * @param columns column handle for each scan output variable
         * @param context rewrite context used when folding has to be abandoned
         * @param predicates discrete predicates of the table layout
         * @return a {@link ProjectNode} over a single-row {@link ValuesNode} carrying the folded
         *         results, or the default rewrite when some predicate does not fix a column to a
         *         single value
         */
        private PlanNode reduce(
                AggregationNode node,
                List<VariableReferenceExpression> inputs,
                Map<VariableReferenceExpression, ColumnHandle> columns,
                RewriteContext<Void> context,
                DiscretePredicates predicates)
        {
            // Fold min/max aggregations to a constant value
            ImmutableMap.Builder<VariableReferenceExpression, List<RowExpression>> inputColumnValuesBuilder = ImmutableMap.builder();
            for (VariableReferenceExpression input : inputs) {
                ImmutableList.Builder<RowExpression> arguments = ImmutableList.builder();
                ColumnHandle column = columns.get(input);
                // for each input column, add a literal expression using the entry value
                for (TupleDomain<ColumnHandle> domain : predicates.getPredicates()) {
                    if (domain.isNone()) {
                        continue;
                    }
                    Map<ColumnHandle, NullableValue> entries = TupleDomain.extractFixedValues(domain).get();
                    NullableValue value = entries.get(column);
                    if (value == null) {
                        // partition key does not have a single value, so bail out to be safe
                        return context.defaultRewrite(node);
                    }
                    // min/max ignores null value
                    if (value.getValue() != null) {
                        arguments.add(new ConstantExpression(Optional.empty(), value.getValue(), input.getType()));
                    }
                }
                inputColumnValuesBuilder.put(input, arguments.build());
            }
            Map<VariableReferenceExpression, List<RowExpression>> inputColumnValues = inputColumnValuesBuilder.build();

            Assignments.Builder assignmentsBuilder = Assignments.builder();
            for (VariableReferenceExpression outputVariable : node.getOutputVariables()) {
                Aggregation aggregation = node.getAggregations().get(outputVariable);
                RowExpression inputVariable = getOnlyElement(aggregation.getArguments());
                RowExpression result = evaluateMinMax(
                        functionMetadataManager.getFunctionMetadata(aggregation.getFunctionHandle()),
                        inputColumnValues.get(inputVariable));
                assignmentsBuilder.put(outputVariable, result);
            }
            Assignments assignments = assignmentsBuilder.build();

            // single-row values node holding the folded results, re-exposed under the aggregation's output variables
            ValuesNode valuesNode = new ValuesNode(node.getSourceLocation(), idAllocator.getNextId(), node.getOutputVariables(), ImmutableList.of(new ArrayList<>(assignments.getExpressions())), Optional.empty());
            return new ProjectNode(node.getSourceLocation(), idAllocator.getNextId(), valuesNode, assignments, LOCAL);
        }

        /**
         * Evaluates a min or max aggregation over constant {@code arguments} by repeatedly
         * applying the matching LEAST/GREATEST scalar function.
         *
         * @param aggregationFunctionMetadata metadata of the min/max aggregation being folded
         * @param arguments constant candidate values; may be empty
         * @return the single remaining expression, or a null constant of the aggregation's return
         *         type when {@code arguments} is empty
         * @throws PrestoException with {@code NOT_SUPPORTED} if the mapped scalar function is
         *         neither {@code greatest} nor {@code least}
         */
        private RowExpression evaluateMinMax(FunctionMetadata aggregationFunctionMetadata, List<RowExpression> arguments)
        {
            Type returnType = typeManager.getType(aggregationFunctionMetadata.getReturnType());
            if (arguments.isEmpty()) {
                return new ConstantExpression(Optional.empty(), null, returnType);
            }

            String scalarFunctionName = AGGREGATION_SCALAR_MAPPING.get(aggregationFunctionMetadata.getName()).getObjectName();
            // each pass shrinks the list by folding batches, until a single value is left
            while (arguments.size() > 1) {
                List<RowExpression> reducedArguments = new ArrayList<>();
                // We fold for every 100 values because GREATEST/LEAST has argument count limit
                for (List<RowExpression> partitionedArguments : Lists.partition(arguments, 100)) {
                    FunctionHandle functionHandle;
                    if (scalarFunctionName.equals("greatest")) {
                        functionHandle = functionResolution.greatestFunction(partitionedArguments.stream().map(RowExpression::getType).collect(toImmutableList()));
                    }
                    else if (scalarFunctionName.equals("least")) {
                        functionHandle = functionResolution.leastFunction(partitionedArguments.stream().map(RowExpression::getType).collect(toImmutableList()));
                    }
                    else {
                        throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "unsupported function: " + scalarFunctionName);
                    }

                    Object reducedValue = rowExpressionService.getExpressionOptimizer().optimize(
                            new CallExpression(
                                    Optional.empty(),
                                    scalarFunctionName,
                                    functionHandle,
                                    returnType,
                                    partitionedArguments),
                            Level.EVALUATED,
                            connectorSession,
                            variableReferenceExpression -> null);
                    reducedArguments.add(new ConstantExpression(reducedValue, returnType));
                }
                arguments = reducedArguments;
            }
            return getOnlyElement(arguments);
        }

        /**
         * Walks down from {@code source} through mark-distinct, filter, sort and deterministic
         * project nodes looking for a table scan.
         *
         * @param source node to start from
         * @param determinismEvaluator used to reject projections with non-deterministic expressions
         * @return the table scan reached, or empty if any other node type or a non-deterministic
         *         projection is encountered first
         */
        private static Optional<TableScanNode> findTableScan(PlanNode source, DeterminismEvaluator determinismEvaluator)
        {
            while (true) {
                // allow any chain of linear transformations
                if (source instanceof MarkDistinctNode ||
                        source instanceof FilterNode ||
                        source instanceof SortNode) {
                    source = source.getSources().get(0);
                }
                else if (source instanceof ProjectNode) {
                    // verify projections are deterministic
                    ProjectNode project = (ProjectNode) source;
                    if (!Iterables.all(project.getAssignments().getExpressions(), determinismEvaluator::isDeterministic)) {
                        return Optional.empty();
                    }
                    source = project.getSource();
                }
                else if (source instanceof TableScanNode) {
                    return Optional.of((TableScanNode) source);
                }
                else {
                    return Optional.empty();
                }
            }
        }

        /**
         * Looks up the connector metadata registered for the transaction of {@code tableHandle}.
         *
         * @throws IllegalStateException if the metadata is not an {@link IcebergAbstractMetadata}
         */
        private ConnectorMetadata getConnectorMetadata(TableHandle tableHandle)
        {
            requireNonNull(icebergTransactionManager, "icebergTransactionManager is null");
            ConnectorMetadata metadata = icebergTransactionManager.get(tableHandle.getTransaction());
            checkState(metadata instanceof IcebergAbstractMetadata, "metadata must be IcebergAbstractMetadata");
            return metadata;
        }
    }

    /**
     * Rewriter that substitutes a fixed {@link ValuesNode} for every table scan it visits.
     */
    private static class Replacer
            extends ConnectorPlanRewriter<Void>
    {
        private final ValuesNode replacement;

        private Replacer(ValuesNode replacement)
        {
            this.replacement = replacement;
        }

        @Override
        public PlanNode visitTableScan(TableScanNode node, RewriteContext<Void> context)
        {
            return replacement;
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy (footer of the page this listing was copied from; not part of the source file)