io.substrait.relation.RelCopyOnWriteVisitor Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of core Show documentation
Show all versions of core Show documentation
Create a well-defined, cross-language specification for data compute operations
package io.substrait.relation;
import static io.substrait.relation.CopyOnWriteUtils.allEmpty;
import static io.substrait.relation.CopyOnWriteUtils.or;
import static io.substrait.relation.CopyOnWriteUtils.transformList;
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
/**
* Class used to visit all child relations from a root relation and optionally replace subtrees by
* overriding a visitor method. The traversal will include relations inside of subquery expressions.
* By default, no subtree substitution will be performed. However, if a visit method is overridden
* to return a non-empty optional value, then that value will replace the relation in the tree.
*/
public class RelCopyOnWriteVisitor
implements RelVisitor, EXCEPTION> {
private final ExpressionCopyOnWriteVisitor expressionCopyOnWriteVisitor;
public RelCopyOnWriteVisitor() {
this.expressionCopyOnWriteVisitor = new ExpressionCopyOnWriteVisitor<>(this);
}
public RelCopyOnWriteVisitor(
ExpressionCopyOnWriteVisitor expressionCopyOnWriteVisitor) {
this.expressionCopyOnWriteVisitor = expressionCopyOnWriteVisitor;
}
public RelCopyOnWriteVisitor(
Function, ExpressionCopyOnWriteVisitor> fn) {
this.expressionCopyOnWriteVisitor = fn.apply(this);
}
protected ExpressionCopyOnWriteVisitor getExpressionCopyOnWriteVisitor() {
return expressionCopyOnWriteVisitor;
}
@Override
public Optional visit(Aggregate aggregate) throws EXCEPTION {
var input = aggregate.getInput().accept(this);
var groupings = transformList(aggregate.getGroupings(), this::visitGrouping);
var measures = transformList(aggregate.getMeasures(), this::visitMeasure);
if (allEmpty(input, groupings, measures)) {
return Optional.empty();
}
return Optional.of(
Aggregate.builder()
.from(aggregate)
.input(input.orElse(aggregate.getInput()))
.groupings(groupings.orElse(aggregate.getGroupings()))
.measures(measures.orElse(aggregate.getMeasures()))
.build());
}
protected Optional visitGrouping(Aggregate.Grouping grouping)
throws EXCEPTION {
return visitExprList(grouping.getExpressions())
.map(exprs -> Aggregate.Grouping.builder().from(grouping).expressions(exprs).build());
}
protected Optional visitMeasure(Aggregate.Measure measure) throws EXCEPTION {
var preMeasureFilter = visitOptionalExpression(measure.getPreMeasureFilter());
var afi = visitAggregateFunction(measure.getFunction());
if (allEmpty(preMeasureFilter, afi)) {
return Optional.empty();
}
return Optional.of(
Aggregate.Measure.builder()
.from(measure)
.preMeasureFilter(or(preMeasureFilter, measure::getPreMeasureFilter))
.function(afi.orElse(measure.getFunction()))
.build());
}
protected Optional visitAggregateFunction(
AggregateFunctionInvocation afi) throws EXCEPTION {
var arguments = visitFunctionArguments(afi.arguments());
var sort = transformList(afi.sort(), this::visitSortField);
if (allEmpty(arguments, sort)) {
return Optional.empty();
}
return Optional.of(
AggregateFunctionInvocation.builder()
.from(afi)
.arguments(arguments.orElse(afi.arguments()))
.sort(sort.orElse(afi.sort()))
.build());
}
@Override
public Optional visit(EmptyScan emptyScan) throws EXCEPTION {
Optional filter = visitOptionalExpression(emptyScan.getFilter());
if (allEmpty(filter)) {
return Optional.empty();
}
return Optional.of(
EmptyScan.builder()
.from(emptyScan)
.filter(filter.isPresent() ? filter : emptyScan.getFilter())
.build());
}
@Override
public Optional visit(Fetch fetch) throws EXCEPTION {
return fetch
.getInput()
.accept(this)
.map(input -> Fetch.builder().from(fetch).input(input).build());
}
@Override
public Optional visit(Filter filter) throws EXCEPTION {
var input = filter.getInput().accept(this);
var condition = filter.getCondition().accept(getExpressionCopyOnWriteVisitor());
if (allEmpty(input, condition)) {
return Optional.empty();
}
return Optional.of(
Filter.builder()
.from(filter)
.input(input.orElse(filter.getInput()))
.condition(condition.orElse(filter.getCondition()))
.build());
}
@Override
public Optional visit(Join join) throws EXCEPTION {
var left = join.getLeft().accept(this);
var right = join.getRight().accept(this);
var condition = visitOptionalExpression(join.getCondition());
var postFilter = visitOptionalExpression(join.getPostJoinFilter());
if (allEmpty(left, right, condition, postFilter)) {
return Optional.empty();
}
return Optional.of(
ImmutableJoin.builder()
.from(join)
.left(left.orElse(join.getLeft()))
.right(right.orElse(join.getRight()))
.condition(or(condition, join::getCondition))
.postJoinFilter(or(postFilter, join::getPostJoinFilter))
.build());
}
@Override
public Optional visit(Set set) throws EXCEPTION {
return transformList(set.getInputs(), t -> t.accept(this))
.map(s -> Set.builder().from(set).inputs(s).build());
}
@Override
public Optional visit(NamedScan namedScan) throws EXCEPTION {
var filter = visitOptionalExpression(namedScan.getFilter());
if (allEmpty(filter)) {
return Optional.empty();
}
return Optional.of(
NamedScan.builder().from(namedScan).filter(or(filter, namedScan::getFilter)).build());
}
@Override
public Optional visit(LocalFiles localFiles) throws EXCEPTION {
var filter = visitOptionalExpression(localFiles.getFilter());
if (allEmpty(filter)) {
return Optional.empty();
}
return Optional.of(
LocalFiles.builder().from(localFiles).filter(or(filter, localFiles::getFilter)).build());
}
@Override
public Optional visit(Project project) throws EXCEPTION {
var input = project.getInput().accept(this);
var expressions = visitExprList(project.getExpressions());
if (allEmpty(input, expressions)) {
return Optional.empty();
}
return Optional.of(
Project.builder()
.from(project)
.input(input.orElse(project.getInput()))
.expressions(expressions.orElse(project.getExpressions()))
.build());
}
@Override
public Optional visit(Expand expand) throws EXCEPTION {
throw new UnsupportedOperationException();
}
@Override
public Optional visit(Sort sort) throws EXCEPTION {
var input = sort.getInput().accept(this);
var sortFields = transformList(sort.getSortFields(), this::visitSortField);
if (allEmpty(input, sortFields)) {
return Optional.empty();
}
return Optional.of(
Sort.builder()
.from(sort)
.input(input.orElse(sort.getInput()))
.sortFields(sortFields.orElse(sort.getSortFields()))
.build());
}
@Override
public Optional visit(Cross cross) throws EXCEPTION {
var left = cross.getLeft().accept(this);
var right = cross.getRight().accept(this);
if (allEmpty(left, right)) {
return Optional.empty();
}
return Optional.of(
Cross.builder()
.from(cross)
.left(left.orElse(cross.getLeft()))
.right(right.orElse(cross.getRight()))
.build());
}
@Override
public Optional visit(VirtualTableScan virtualTableScan) throws EXCEPTION {
var filter = visitOptionalExpression(virtualTableScan.getFilter());
if (allEmpty(filter)) {
return Optional.empty();
}
return Optional.of(
VirtualTableScan.builder()
.from(virtualTableScan)
.filter(or(filter, virtualTableScan::getFilter))
.build());
}
@Override
public Optional visit(ExtensionLeaf extensionLeaf) throws EXCEPTION {
return Optional.empty();
}
@Override
public Optional visit(ExtensionSingle extensionSingle) throws EXCEPTION {
return extensionSingle
.getInput()
.accept(this)
.map(input -> ExtensionSingle.builder().from(extensionSingle).input(input).build());
}
@Override
public Optional visit(ExtensionMulti extensionMulti) throws EXCEPTION {
return transformList(extensionMulti.getInputs(), rel -> rel.accept(this))
.map(inputs -> ExtensionMulti.builder().from(extensionMulti).inputs(inputs).build());
}
@Override
public Optional visit(ExtensionTable extensionTable) throws EXCEPTION {
var filter = visitOptionalExpression(extensionTable.getFilter());
if (allEmpty(filter)) {
return Optional.empty();
}
return Optional.of(
ExtensionTable.builder()
.from(extensionTable)
.filter(or(filter, extensionTable::getFilter))
.build());
}
@Override
public Optional visit(HashJoin hashJoin) throws EXCEPTION {
var left = hashJoin.getLeft().accept(this);
var right = hashJoin.getRight().accept(this);
var leftKeys = transformList(hashJoin.getLeftKeys(), this::visitFieldReference);
var rightKeys = transformList(hashJoin.getRightKeys(), this::visitFieldReference);
var postFilter = visitOptionalExpression(hashJoin.getPostJoinFilter());
if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) {
return Optional.empty();
}
return Optional.of(
HashJoin.builder()
.from(hashJoin)
.left(left.orElse(hashJoin.getLeft()))
.right(right.orElse(hashJoin.getRight()))
.leftKeys(leftKeys.orElse(hashJoin.getLeftKeys()))
.rightKeys(rightKeys.orElse(hashJoin.getRightKeys()))
.postJoinFilter(or(postFilter, hashJoin::getPostJoinFilter))
.build());
}
@Override
public Optional visit(MergeJoin mergeJoin) throws EXCEPTION {
var left = mergeJoin.getLeft().accept(this);
var right = mergeJoin.getRight().accept(this);
var leftKeys = transformList(mergeJoin.getLeftKeys(), this::visitFieldReference);
var rightKeys = transformList(mergeJoin.getRightKeys(), this::visitFieldReference);
var postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter());
if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) {
return Optional.empty();
}
return Optional.of(
MergeJoin.builder()
.from(mergeJoin)
.left(left.orElse(mergeJoin.getLeft()))
.right(right.orElse(mergeJoin.getRight()))
.leftKeys(leftKeys.orElse(mergeJoin.getLeftKeys()))
.rightKeys(rightKeys.orElse(mergeJoin.getRightKeys()))
.postJoinFilter(or(postFilter, mergeJoin::getPostJoinFilter))
.build());
}
@Override
public Optional visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
var left = nestedLoopJoin.getLeft().accept(this);
var right = nestedLoopJoin.getRight().accept(this);
var condition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor());
if (allEmpty(left, right, condition)) {
return Optional.empty();
}
return Optional.of(
NestedLoopJoin.builder()
.from(nestedLoopJoin)
.left(left.orElse(nestedLoopJoin.getLeft()))
.right(right.orElse(nestedLoopJoin.getRight()))
.condition(condition.orElse(nestedLoopJoin.getCondition()))
.build());
}
@Override
public Optional visit(ConsistentPartitionWindow consistentPartitionWindow) throws EXCEPTION {
var windowFunctions =
transformList(consistentPartitionWindow.getWindowFunctions(), this::visitWindowRelFunction);
var partitionExpressions =
transformList(
consistentPartitionWindow.getPartitionExpressions(),
t -> t.accept(getExpressionCopyOnWriteVisitor()));
var sorts = transformList(consistentPartitionWindow.getSorts(), this::visitSortField);
if (allEmpty(windowFunctions, partitionExpressions, sorts)) {
return Optional.empty();
}
return Optional.of(
ConsistentPartitionWindow.builder()
.from(consistentPartitionWindow)
.partitionExpressions(
partitionExpressions.orElse(consistentPartitionWindow.getPartitionExpressions()))
.sorts(sorts.orElse(consistentPartitionWindow.getSorts()))
.windowFunctions(windowFunctions.orElse(consistentPartitionWindow.getWindowFunctions()))
.build());
}
protected Optional visitWindowRelFunction(
ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunctionInvocation)
throws EXCEPTION {
var functionArgs = visitFunctionArguments(windowRelFunctionInvocation.arguments());
if (allEmpty(functionArgs)) {
return Optional.empty();
}
return Optional.of(
ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
.from(windowRelFunctionInvocation)
.arguments(functionArgs.orElse(windowRelFunctionInvocation.arguments()))
.build());
}
// utilities
protected Optional> visitExprList(List exprs) throws EXCEPTION {
return transformList(exprs, t -> t.accept(getExpressionCopyOnWriteVisitor()));
}
public Optional visitFieldReference(FieldReference fieldReference)
throws EXCEPTION {
var inputExpression = visitOptionalExpression(fieldReference.inputExpression());
if (allEmpty(inputExpression)) {
return Optional.empty();
}
return Optional.of(FieldReference.builder().inputExpression(inputExpression).build());
}
protected Optional> visitFunctionArguments(List funcArgs)
throws EXCEPTION {
return CopyOnWriteUtils.transformList(
funcArgs,
arg -> {
if (arg instanceof Expression expr) {
return expr.accept(getExpressionCopyOnWriteVisitor())
.flatMap(Optional::of);
} else {
return Optional.empty();
}
});
}
protected Optional visitSortField(Expression.SortField sortField)
throws EXCEPTION {
return sortField
.expr()
.accept(getExpressionCopyOnWriteVisitor())
.map(expr -> Expression.SortField.builder().from(sortField).expr(expr).build());
}
private Optional visitOptionalExpression(Optional optExpr)
throws EXCEPTION {
// not using optExpr.map to allow us to propagate the THROWABLE nicely
if (optExpr.isPresent()) {
return optExpr.get().accept(getExpressionCopyOnWriteVisitor());
}
return Optional.empty();
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy