org.apache.flink.table.planner.plan.QueryOperationConverter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of flink-table-planner-blink_2.12 Show documentation
This module bridges Table/SQL API and runtime. It contains
all resources that are required during pre-flight and runtime
phase. The content of this module is work-in-progress. It will
replace flink-table-planner once it is stable. See FLINK-11439
and FLIP-32 for more details.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.plan;
import org.apache.flink.annotation.Internal;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.catalog.CatalogManager;
import org.apache.flink.table.catalog.ConnectorCatalogTable;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.catalog.ObjectIdentifier;
import org.apache.flink.table.catalog.UnresolvedIdentifier;
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionDefaultVisitor;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.table.functions.TableFunctionDefinition;
import org.apache.flink.table.operations.AggregateQueryOperation;
import org.apache.flink.table.operations.CalculatedQueryOperation;
import org.apache.flink.table.operations.CatalogQueryOperation;
import org.apache.flink.table.operations.DistinctQueryOperation;
import org.apache.flink.table.operations.FilterQueryOperation;
import org.apache.flink.table.operations.JavaDataStreamQueryOperation;
import org.apache.flink.table.operations.JoinQueryOperation;
import org.apache.flink.table.operations.JoinQueryOperation.JoinType;
import org.apache.flink.table.operations.ProjectQueryOperation;
import org.apache.flink.table.operations.QueryOperation;
import org.apache.flink.table.operations.QueryOperationVisitor;
import org.apache.flink.table.operations.ScalaDataStreamQueryOperation;
import org.apache.flink.table.operations.SetQueryOperation;
import org.apache.flink.table.operations.SortQueryOperation;
import org.apache.flink.table.operations.TableSourceQueryOperation;
import org.apache.flink.table.operations.ValuesQueryOperation;
import org.apache.flink.table.operations.WindowAggregateQueryOperation;
import org.apache.flink.table.operations.WindowAggregateQueryOperation.ResolvedGroupWindow;
import org.apache.flink.table.operations.utils.QueryOperationDefaultVisitor;
import org.apache.flink.table.planner.calcite.FlinkContext;
import org.apache.flink.table.planner.calcite.FlinkRelBuilder;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.expressions.PlannerProctimeAttribute;
import org.apache.flink.table.planner.expressions.PlannerRowtimeAttribute;
import org.apache.flink.table.planner.expressions.PlannerWindowEnd;
import org.apache.flink.table.planner.expressions.PlannerWindowReference;
import org.apache.flink.table.planner.expressions.PlannerWindowStart;
import org.apache.flink.table.planner.expressions.RexNodeExpression;
import org.apache.flink.table.planner.expressions.SqlAggFunctionVisitor;
import org.apache.flink.table.planner.expressions.converter.ExpressionConverter;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
import org.apache.flink.table.planner.functions.utils.TableSqlFunction;
import org.apache.flink.table.planner.operations.DataStreamQueryOperation;
import org.apache.flink.table.planner.operations.PlannerQueryOperation;
import org.apache.flink.table.planner.operations.RichTableSourceQueryOperation;
import org.apache.flink.table.planner.plan.logical.LogicalWindow;
import org.apache.flink.table.planner.plan.logical.SessionGroupWindow;
import org.apache.flink.table.planner.plan.logical.SlidingGroupWindow;
import org.apache.flink.table.planner.plan.logical.TumblingGroupWindow;
import org.apache.flink.table.planner.plan.schema.DataStreamTable;
import org.apache.flink.table.planner.plan.schema.DataStreamTable$;
import org.apache.flink.table.planner.plan.schema.LegacyTableSourceTable;
import org.apache.flink.table.planner.plan.schema.TypedFlinkTableFunction;
import org.apache.flink.table.planner.plan.stats.FlinkStatistic;
import org.apache.flink.table.planner.sources.TableSourceUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.sources.LookupableTableSource;
import org.apache.flink.table.sources.StreamTableSource;
import org.apache.flink.table.sources.TableSource;
import org.apache.flink.table.types.DataType;
import org.apache.flink.util.Preconditions;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
import org.apache.calcite.rel.logical.LogicalTableScan;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilder.AggCall;
import org.apache.calcite.tools.RelBuilder.GroupKey;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import scala.Some;
import static java.util.Arrays.asList;
import static java.util.stream.Collectors.toList;
import static org.apache.flink.table.expressions.ApiExpressionUtils.isFunctionOfKind;
import static org.apache.flink.table.expressions.ExpressionUtils.extractValue;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AS;
import static org.apache.flink.table.functions.FunctionKind.AGGREGATE;
import static org.apache.flink.table.functions.FunctionKind.TABLE_AGGREGATE;
import static org.apache.flink.table.types.utils.TypeConversions.fromDataToLogicalType;
import static org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType;
/**
* Converter from Flink's specific relational representation: {@link QueryOperation} to Calcite's specific relational
* representation: {@link RelNode}.
*/
@Internal
public class QueryOperationConverter extends QueryOperationDefaultVisitor {
    // Builder used to incrementally assemble the Calcite RelNode tree; children
    // of an operation are pushed onto its stack before the operation itself is converted.
    private final FlinkRelBuilder relBuilder;
    // Converts a single QueryOperation whose inputs are already on the builder stack.
    private final SingleRelVisitor singleRelVisitor = new SingleRelVisitor();
    // Translates Table API expressions into Calcite RexNodes.
    private final ExpressionConverter expressionConverter;
    // Translates named aggregate expressions into Calcite AggCalls.
    private final AggregateVisitor aggregateVisitor = new AggregateVisitor();
    // Translates table aggregate expressions into Calcite AggCalls.
    private final TableAggregateVisitor tableAggregateVisitor = new TableAggregateVisitor();
    // Rewrites join conditions so field references address the two join inputs.
    private final JoinExpressionVisitor joinExpressionVisitor = new JoinExpressionVisitor();
/**
 * Creates a converter that assembles {@link RelNode}s with the given builder.
 *
 * @param relBuilder builder used for constructing the relational expression tree
 */
public QueryOperationConverter(FlinkRelBuilder relBuilder) {
    this.relBuilder = relBuilder;
    this.expressionConverter = new ExpressionConverter(relBuilder);
}
@Override
public RelNode defaultMethod(QueryOperation other) {
    // Convert every input first and push it onto the builder stack; the
    // single-node visitor then builds this operation's RelNode on top of them.
    other.getChildren().stream()
            .map(child -> child.accept(this))
            .forEach(relBuilder::push);
    return other.accept(singleRelVisitor);
}
private class SingleRelVisitor implements QueryOperationVisitor {
@Override
public RelNode visit(ProjectQueryOperation projection) {
    // Convert the projection expressions and rename the output columns to the
    // operation's declared schema; 'true' forces the rename.
    // Fix: restore the generic type argument stripped from the raw List.
    List<RexNode> rexNodes = convertToRexNodes(projection.getProjectList());
    return relBuilder
            .project(rexNodes, asList(projection.getTableSchema().getFieldNames()), true)
            .build();
}
@Override
public RelNode visit(AggregateQueryOperation aggregate) {
    // Each aggregate expression becomes a Calcite AggCall; grouping expressions
    // become the group key. Fix: restore stripped generic type arguments.
    List<AggCall> aggregations = aggregate.getAggregateExpressions()
            .stream()
            .map(this::getAggCall)
            .collect(toList());
    List<RexNode> groupings = convertToRexNodes(aggregate.getGroupingExpressions());
    GroupKey groupKey = relBuilder.groupKey(groupings);
    return relBuilder.aggregate(groupKey, aggregations).build();
}
@Override
public RelNode visit(WindowAggregateQueryOperation windowAggregate) {
    // Fix: restore stripped generic type arguments on the local lists.
    List<AggCall> aggregations = windowAggregate.getAggregateExpressions()
            .stream()
            .map(this::getAggCall)
            .collect(toList());
    List<RexNode> groupings = convertToRexNodes(windowAggregate.getGroupingExpressions());
    // Translate the API group window into the planner's logical window; the
    // window reference (alias) is reused for the window property expressions.
    LogicalWindow logicalWindow = toLogicalWindow(windowAggregate.getGroupWindow());
    PlannerWindowReference windowReference = logicalWindow.aliasAttribute();
    List<FlinkRelBuilder.PlannerNamedWindowProperty> windowProperties = windowAggregate
            .getWindowPropertiesExpressions()
            .stream()
            .map(expr -> convertToWindowProperty(expr, windowReference))
            .collect(toList());
    GroupKey groupKey = relBuilder.groupKey(groupings);
    return relBuilder.windowAggregate(logicalWindow, groupKey, windowProperties, aggregations).build();
}
/**
 * Converts an {@code AS(windowProperty, name)} call into a named planner window property
 * (window start/end, proctime or rowtime attribute).
 *
 * @param expression aliased window property expression produced by the API
 * @param windowReference reference to the window the property belongs to
 * @throws TableException if the alias literal is invalid or the property function is unknown
 */
private FlinkRelBuilder.PlannerNamedWindowProperty convertToWindowProperty(
        Expression expression,
        PlannerWindowReference windowReference) {
    // Fix: the original messages ("This should never happened", "Invalid
    // literal." for an unknown property) were ungrammatical or misleading.
    Preconditions.checkArgument(
            expression instanceof CallExpression,
            "Window property expected to be a call expression.");
    CallExpression aliasExpr = (CallExpression) expression;
    Preconditions.checkArgument(
            BuiltInFunctionDefinitions.AS == aliasExpr.getFunctionDefinition(),
            "Window property expected to be an aliased (AS) expression.");
    // Second child of AS is the alias name as a string literal.
    String name = ((ValueLiteralExpression) aliasExpr.getChildren().get(1))
            .getValueAs(String.class)
            .orElseThrow(() -> new TableException("Invalid literal."));
    Expression windowPropertyExpr = aliasExpr.getChildren().get(0);
    Preconditions.checkArgument(
            windowPropertyExpr instanceof CallExpression,
            "Window property expected to be a call expression.");
    CallExpression windowPropertyCallExpr = (CallExpression) windowPropertyExpr;
    FunctionDefinition fd = windowPropertyCallExpr.getFunctionDefinition();
    if (BuiltInFunctionDefinitions.WINDOW_START == fd) {
        return new FlinkRelBuilder.PlannerNamedWindowProperty(name, new PlannerWindowStart(windowReference));
    } else if (BuiltInFunctionDefinitions.WINDOW_END == fd) {
        return new FlinkRelBuilder.PlannerNamedWindowProperty(name, new PlannerWindowEnd(windowReference));
    } else if (BuiltInFunctionDefinitions.PROCTIME == fd) {
        return new FlinkRelBuilder.PlannerNamedWindowProperty(name, new PlannerProctimeAttribute(windowReference));
    } else if (BuiltInFunctionDefinitions.ROWTIME == fd) {
        return new FlinkRelBuilder.PlannerNamedWindowProperty(name, new PlannerRowtimeAttribute(windowReference));
    } else {
        throw new TableException("Unexpected window property function: " + fd);
    }
}
/**
 * Gets the {@link AggCall} corresponding to the given aggregate or table aggregate expression.
 */
private AggCall getAggCall(Expression aggregateExpression) {
    // Table aggregates and plain aggregates use dedicated visitors.
    return isFunctionOfKind(aggregateExpression, TABLE_AGGREGATE)
            ? aggregateExpression.accept(tableAggregateVisitor)
            : aggregateExpression.accept(aggregateVisitor);
}
@Override
public RelNode visit(JoinQueryOperation join) {
    // Correlated joins need a correlation id created on the left input's
    // cluster. Fix: restore the stripped Set<CorrelationId> type argument.
    final Set<CorrelationId> corSet;
    if (join.isCorrelated()) {
        corSet = Collections.singleton(relBuilder.peek().getCluster().createCorrel());
    } else {
        corSet = Collections.emptySet();
    }
    return relBuilder
            .join(
                    convertJoinType(join.getJoinType()),
                    join.getCondition().accept(joinExpressionVisitor),
                    corSet)
            .build();
}
@Override
public RelNode visit(SetQueryOperation setOperation) {
    // Both inputs were already pushed onto the builder stack by defaultMethod().
    switch (setOperation.getType()) {
        case INTERSECT:
            relBuilder.intersect(setOperation.isAll());
            break;
        case MINUS:
            relBuilder.minus(setOperation.isAll());
            break;
        case UNION:
            relBuilder.union(setOperation.isAll());
            break;
        default:
            // Fix: fail fast instead of silently building a wrong plan if a new
            // set operation type is ever added to the enum.
            throw new TableException("Unknown set operation: " + setOperation.getType());
    }
    return relBuilder.build();
}
@Override
public RelNode visit(FilterQueryOperation filter) {
    // Translate the predicate and apply it on top of the pushed input.
    final RexNode predicate = convertExprToRexNode(filter.getCondition());
    return relBuilder.filter(predicate).build();
}
@Override
public RelNode visit(DistinctQueryOperation distinct) {
    // The input is already on the builder stack; distinct() removes duplicate rows.
    return relBuilder.distinct().build();
}
@Override
public RelNode visit(SortQueryOperation sort) {
    // Convert the ordering expressions and apply sort with offset/fetch.
    // Fix: restore the generic type argument stripped from the raw List.
    List<RexNode> rexNodes = convertToRexNodes(sort.getOrder());
    return relBuilder
            .sortLimit(sort.getOffset(), sort.getFetch(), rexNodes)
            .build();
}
@Override
public RelNode visit(CalculatedQueryOperation calculatedTable) {
    FunctionDefinition functionDefinition = calculatedTable.getFunctionDefinition();
    List parameters = convertToRexNodes(calculatedTable.getArguments());
    FlinkTypeFactory typeFactory = relBuilder.getTypeFactory();
    // Functions registered through the legacy TableFunction stack take a
    // separate conversion path based on TableSqlFunction.
    if (functionDefinition instanceof TableFunctionDefinition) {
        return convertLegacyTableFunction(
                calculatedTable,
                (TableFunctionDefinition) functionDefinition,
                parameters,
                typeFactory);
    }
    // New function stack: bridge the FunctionDefinition into a Calcite
    // SqlFunction, create a function scan and rename its output columns to the
    // schema declared by the operation.
    DataTypeFactory dataTypeFactory = ShortcutUtils.unwrapContext(relBuilder.getCluster())
            .getCatalogManager()
            .getDataTypeFactory();
    final BridgingSqlFunction sqlFunction = BridgingSqlFunction.of(
            dataTypeFactory,
            typeFactory,
            SqlKind.OTHER_FUNCTION,
            calculatedTable.getFunctionIdentifier().orElse(null),
            calculatedTable.getFunctionDefinition(),
            calculatedTable.getFunctionDefinition().getTypeInference(dataTypeFactory));
    return relBuilder.functionScan(
            sqlFunction,
            0,
            parameters)
            .rename(Arrays.asList(calculatedTable.getTableSchema().getFieldNames()))
            .build();
}
/**
 * Converts a legacy {@link TableFunctionDefinition} into a {@code LogicalTableFunctionScan}
 * by wrapping the table function into a {@code TableSqlFunction}.
 */
private RelNode convertLegacyTableFunction(
        CalculatedQueryOperation calculatedTable,
        TableFunctionDefinition functionDefinition,
        List parameters,
        FlinkTypeFactory typeFactory) {
    String[] fieldNames = calculatedTable.getTableSchema().getFieldNames();
    TableFunction tableFunction = functionDefinition.getTableFunction();
    // The legacy stack declares its result via TypeInformation; bridge it to DataType.
    DataType resultType = fromLegacyInfoToDataType(functionDefinition.getResultType());
    TypedFlinkTableFunction function = new TypedFlinkTableFunction(
            tableFunction,
            fieldNames,
            resultType
    );
    final TableSqlFunction sqlFunction = new TableSqlFunction(
            calculatedTable.getFunctionIdentifier().orElse(null),
            tableFunction.toString(),
            tableFunction,
            resultType,
            typeFactory,
            function,
            scala.Option.empty());
    // The scan has no inputs of its own; parameters are baked into the call.
    return LogicalTableFunctionScan.create(
            relBuilder.peek().getCluster(),
            Collections.emptyList(),
            relBuilder.getRexBuilder()
                    .makeCall(function.getRowType(typeFactory), sqlFunction, parameters),
            function.getElementType(null),
            function.getRowType(typeFactory),
            null);
}
@Override
public RelNode visit(CatalogQueryOperation catalogTable) {
    // Scan the table via its fully qualified path: catalog.database.object.
    final ObjectIdentifier id = catalogTable.getTableIdentifier();
    return relBuilder
            .scan(id.getCatalogName(), id.getDatabaseName(), id.getObjectName())
            .build();
}
@Override
public RelNode visit(ValuesQueryOperation values) {
    RelDataType rowType = relBuilder.getTypeFactory().buildRelNodeRowType(values.getTableSchema());
    if (values.getValues().isEmpty()) {
        relBuilder.values(rowType);
        return relBuilder.build();
    }
    // Rows made purely of literals become a single LogicalValues; rows with
    // other expressions become projections over a one-row values node. If both
    // groups exist, union all the inputs back together.
    // Fix: restore stripped generics; use isEmpty() instead of size() != 0.
    List<List<RexLiteral>> rexLiterals = new ArrayList<>();
    List<List<RexNode>> rexProjections = new ArrayList<>();
    splitToProjectionsAndLiterals(values, rexLiterals, rexProjections);
    int inputs = 0;
    if (!rexLiterals.isEmpty()) {
        inputs += 1;
        relBuilder.values(rexLiterals, rowType);
    }
    if (!rexProjections.isEmpty()) {
        inputs += rexProjections.size();
        applyProjections(values, rexProjections);
    }
    if (inputs > 1) {
        relBuilder.union(true, inputs);
    }
    return relBuilder.build();
}
private void applyProjections(ValuesQueryOperation values, List> rexProjections) {
List relNodes = rexProjections.stream().map(exprs -> {
relBuilder.push(LogicalValues.createOneRow(relBuilder.getCluster()));
relBuilder.project(exprs, asList(values.getTableSchema().getFieldNames()));
return relBuilder.build();
}).collect(toList());
relBuilder.pushAll(relNodes);
}
private void splitToProjectionsAndLiterals(
ValuesQueryOperation values,
List> rexValues,
List> rexProjections) {
values.getValues().stream()
.map(this::convertToRexNodes)
.forEach(row -> {
boolean allLiterals = row.stream().allMatch(expr -> expr instanceof RexLiteral);
if (allLiterals) {
rexValues.add(row.stream().map(expr -> (RexLiteral) expr).collect(toList()));
} else {
rexProjections.add(row);
}
}
);
}
@Override
public RelNode visit(QueryOperation other) {
    // Fallback for operation types without a dedicated visit overload.
    if (other instanceof PlannerQueryOperation) {
        // Already converted by the planner; reuse the existing Calcite tree.
        return ((PlannerQueryOperation) other).getCalciteTree();
    } else if (other instanceof DataStreamQueryOperation) {
        return convertToDataStreamScan((DataStreamQueryOperation) other);
    } else if (other instanceof JavaDataStreamQueryOperation) {
        JavaDataStreamQueryOperation dataStreamQueryOperation = (JavaDataStreamQueryOperation) other;
        return convertToDataStreamScan(
                dataStreamQueryOperation.getDataStream(),
                dataStreamQueryOperation.getFieldIndices(),
                dataStreamQueryOperation.getTableSchema(),
                dataStreamQueryOperation.getIdentifier());
    } else if (other instanceof ScalaDataStreamQueryOperation) {
        ScalaDataStreamQueryOperation dataStreamQueryOperation = (ScalaDataStreamQueryOperation) other;
        return convertToDataStreamScan(
                dataStreamQueryOperation.getDataStream(),
                dataStreamQueryOperation.getFieldIndices(),
                dataStreamQueryOperation.getTableSchema(),
                dataStreamQueryOperation.getIdentifier());
    }
    throw new TableException("Unknown table operation: " + other);
}
@Override
public RelNode visit(TableSourceQueryOperation tableSourceOperation) {
    TableSource tableSource = tableSourceOperation.getTableSource();
    boolean isBatch;
    // NOTE(review): LookupableTableSource is checked before StreamTableSource,
    // so a source implementing both takes its batch flag from the operation
    // rather than from isBounded() — presumably intentional; confirm.
    if (tableSource instanceof LookupableTableSource) {
        isBatch = tableSourceOperation.isBatch();
    } else if (tableSource instanceof StreamTableSource) {
        isBatch = ((StreamTableSource) tableSource).isBounded();
    } else {
        throw new TableException(String.format("%s is not supported.", tableSource.getClass().getSimpleName()));
    }
    FlinkStatistic statistic;
    ObjectIdentifier tableIdentifier;
    // Registered sources carry an identifier and statistics; unregistered ones
    // get a synthetic unique name and unknown statistics.
    if (tableSourceOperation instanceof RichTableSourceQueryOperation &&
            ((RichTableSourceQueryOperation) tableSourceOperation).getIdentifier() != null) {
        tableIdentifier = ((RichTableSourceQueryOperation) tableSourceOperation).getIdentifier();
        statistic = ((RichTableSourceQueryOperation) tableSourceOperation).getStatistic();
    } else {
        statistic = FlinkStatistic.UNKNOWN();
        // TableSourceScan requires a unique name of a Table for computing a digest.
        // We are using the identity hash of the TableSource object.
        String refId = "Unregistered_TableSource_" + System.identityHashCode(tableSource);
        CatalogManager catalogManager = relBuilder.getCluster().getPlanner().getContext()
                .unwrap(FlinkContext.class).getCatalogManager();
        tableIdentifier = catalogManager.qualifyIdentifier(UnresolvedIdentifier.of(refId));
    }
    RelDataType rowType = TableSourceUtil.getSourceRowTypeFromSource(
            relBuilder.getTypeFactory(),
            tableSource,
            !isBatch);
    LegacyTableSourceTable tableSourceTable = new LegacyTableSourceTable<>(
            relBuilder.getRelOptSchema(),
            tableIdentifier,
            rowType,
            statistic,
            tableSource,
            !isBatch,
            ConnectorCatalogTable.source(tableSource, isBatch));
    return LogicalTableScan.create(relBuilder.getCluster(), tableSourceTable);
}
/**
 * Creates a table scan over the planner's {@code DataStreamQueryOperation}, carrying the
 * operation's statistics and field nullability information.
 */
private RelNode convertToDataStreamScan(DataStreamQueryOperation operation) {
    // Registered streams are addressed by their catalog path; unregistered ones
    // get a synthetic name derived from the stream's id so digests stay unique.
    List names;
    ObjectIdentifier identifier = operation.getIdentifier();
    if (identifier != null) {
        names = Arrays.asList(
                identifier.getCatalogName(),
                identifier.getDatabaseName(),
                identifier.getObjectName());
    } else {
        String refId = String.format("Unregistered_DataStream_%s", operation.getDataStream().getId());
        names = Collections.singletonList(refId);
    }
    final RelDataType rowType = DataStreamTable$.MODULE$
            .getRowType(relBuilder.getTypeFactory(),
                    operation.getDataStream(),
                    operation.getTableSchema().getFieldNames(),
                    operation.getFieldIndices(),
                    scala.Option.apply(operation.getFieldNullables()));
    DataStreamTable dataStreamTable = new DataStreamTable<>(
            relBuilder.getRelOptSchema(),
            names,
            rowType,
            operation.getDataStream(),
            operation.getFieldIndices(),
            operation.getTableSchema().getFieldNames(),
            operation.getStatistic(),
            scala.Option.apply(operation.getFieldNullables()));
    return LogicalTableScan.create(relBuilder.getCluster(), dataStreamTable);
}
/**
 * Creates a table scan over a raw {@link DataStream}. Unlike the operation-based overload,
 * this variant has no nullability information (empty Option) and unknown statistics.
 */
private RelNode convertToDataStreamScan(
        DataStream dataStream,
        int[] fieldIndices,
        TableSchema tableSchema,
        Optional identifier) {
    List names;
    if (identifier.isPresent()) {
        names = Arrays.asList(
                identifier.get().getCatalogName(),
                identifier.get().getDatabaseName(),
                identifier.get().getObjectName());
    } else {
        // Synthetic unique name derived from the stream's id for digest computation.
        String refId = String.format("Unregistered_DataStream_%s", dataStream.getId());
        names = Collections.singletonList(refId);
    }
    final RelDataType rowType = DataStreamTable$.MODULE$
            .getRowType(relBuilder.getTypeFactory(),
                    dataStream,
                    tableSchema.getFieldNames(),
                    fieldIndices,
                    scala.Option.empty());
    DataStreamTable dataStreamTable = new DataStreamTable<>(
            relBuilder.getRelOptSchema(),
            names,
            rowType,
            dataStream,
            fieldIndices,
            tableSchema.getFieldNames(),
            FlinkStatistic.UNKNOWN(),
            scala.Option.empty());
    return LogicalTableScan.create(relBuilder.getCluster(), dataStreamTable);
}
/**
 * Converts a list of resolved API expressions into Calcite {@link RexNode}s.
 * Fix: restore the generic type arguments stripped from the raw lists.
 */
private List<RexNode> convertToRexNodes(List<ResolvedExpression> expressions) {
    return expressions
            .stream()
            .map(QueryOperationConverter.this::convertExprToRexNode)
            .collect(toList());
}
/**
 * Translates an API {@link ResolvedGroupWindow} into the planner's {@link LogicalWindow}.
 *
 * @throws TableException if a required window parameter is absent or the type is unknown
 */
private LogicalWindow toLogicalWindow(ResolvedGroupWindow window) {
    DataType windowType = window.getTimeAttribute().getOutputDataType();
    PlannerWindowReference windowReference = new PlannerWindowReference(
            window.getAlias(),
            new Some<>(fromDataToLogicalType(windowType)));
    // Fix: error messages were ungrammatical ("missed size parameters!") and the
    // default branch did not report the offending window type.
    switch (window.getType()) {
        case SLIDE:
            return new SlidingGroupWindow(
                    windowReference,
                    window.getTimeAttribute(),
                    window.getSize().orElseThrow(
                            () -> new TableException("Missing size parameter of a sliding window.")),
                    window.getSlide().orElseThrow(
                            () -> new TableException("Missing slide parameter of a sliding window.")));
        case SESSION:
            return new SessionGroupWindow(
                    windowReference,
                    window.getTimeAttribute(),
                    window.getGap().orElseThrow(
                            () -> new TableException("Missing gap parameter of a session window.")));
        case TUMBLE:
            return new TumblingGroupWindow(
                    windowReference,
                    window.getTimeAttribute(),
                    window.getSize().orElseThrow(
                            () -> new TableException("Missing size parameter of a tumbling window.")));
        default:
            throw new TableException("Unknown window type: " + window.getType());
    }
}
/** Translates the Table API join type into its Calcite counterpart. */
private JoinRelType convertJoinType(JoinType joinType) {
    final JoinRelType relType;
    switch (joinType) {
        case INNER:
            relType = JoinRelType.INNER;
            break;
        case LEFT_OUTER:
            relType = JoinRelType.LEFT;
            break;
        case RIGHT_OUTER:
            relType = JoinRelType.RIGHT;
            break;
        case FULL_OUTER:
            relType = JoinRelType.FULL;
            break;
        default:
            throw new TableException("Unknown join type: " + joinType);
    }
    return relType;
}
}
/**
 * Rewrites a join condition so that field references address the correct one of the two join
 * inputs on the builder stack, then converts it to a {@link RexNode}.
 */
private class JoinExpressionVisitor extends ExpressionDefaultVisitor {
    // A join always has exactly two inputs on the builder stack.
    // Fix: constant renamed to UPPER_SNAKE_CASE per Java convention.
    private static final int NUMBER_OF_JOIN_INPUTS = 2;
    @Override
    public RexNode visit(CallExpression callExpression) {
        // Pre-convert all children so nested field references are resolved
        // against both join inputs before the call itself is converted.
        // Fix: restore the stripped List<ResolvedExpression> type argument.
        final List<ResolvedExpression> newChildren = callExpression.getChildren().stream().map(expr -> {
            RexNode convertedNode = expr.accept(this);
            return new RexNodeExpression(convertedNode, ((ResolvedExpression) expr).getOutputDataType());
        }).collect(Collectors.toList());
        CallExpression newCall;
        if (callExpression.getFunctionIdentifier().isPresent()) {
            newCall = new CallExpression(
                    callExpression.getFunctionIdentifier().get(),
                    callExpression.getFunctionDefinition(),
                    newChildren,
                    callExpression.getOutputDataType());
        } else {
            newCall = new CallExpression(
                    callExpression.getFunctionDefinition(),
                    newChildren,
                    callExpression.getOutputDataType());
        }
        return convertExprToRexNode(newCall);
    }
    @Override
    public RexNode visit(FieldReferenceExpression fieldReference) {
        // Resolve the reference relative to the referenced join input.
        return relBuilder.field(
                NUMBER_OF_JOIN_INPUTS, fieldReference.getInputIndex(), fieldReference.getFieldIndex());
    }
    @Override
    protected RexNode defaultMethod(Expression expression) {
        return convertExprToRexNode(expression);
    }
}
/**
 * Converts a named aggregate expression ({@code AS(aggFunction, name)}) into a Calcite
 * {@link AggCall}.
 */
private class AggregateVisitor extends ExpressionDefaultVisitor {
    @Override
    public AggCall visit(CallExpression unresolvedCall) {
        // Unwrap the AS alias and delegate to AggCallVisitor with the extracted name.
        if (unresolvedCall.getFunctionDefinition() == AS) {
            String aggregateName = extractValue(unresolvedCall.getChildren().get(1), String.class)
                    .orElseThrow(() -> new TableException("Unexpected name."));
            Expression aggregate = unresolvedCall.getChildren().get(0);
            if (isFunctionOfKind(aggregate, AGGREGATE)) {
                return aggregate.accept(
                        new AggCallVisitor(relBuilder, expressionConverter, aggregateName, false));
            }
        }
        throw new TableException("Expected named aggregate. Got: " + unresolvedCall);
    }
    @Override
    protected AggCall defaultMethod(Expression expression) {
        throw new TableException("Unexpected expression: " + expression);
    }
    /** Builds the Calcite AggCall for a (possibly DISTINCT-wrapped) aggregate function call. */
    private class AggCallVisitor extends ExpressionDefaultVisitor {
        private final RelBuilder relBuilder;
        private final SqlAggFunctionVisitor sqlAggFunctionVisitor;
        private final ExpressionConverter expressionConverter;
        // Output name of the aggregate call.
        private final String name;
        // Whether the aggregate was wrapped in DISTINCT.
        private final boolean isDistinct;
        public AggCallVisitor(RelBuilder relBuilder, ExpressionConverter expressionConverter, String name,
                boolean isDistinct) {
            this.relBuilder = relBuilder;
            this.sqlAggFunctionVisitor = new SqlAggFunctionVisitor((FlinkTypeFactory) relBuilder.getTypeFactory());
            this.expressionConverter = expressionConverter;
            this.name = name;
            this.isDistinct = isDistinct;
        }
        @Override
        public RelBuilder.AggCall visit(CallExpression call) {
            FunctionDefinition def = call.getFunctionDefinition();
            if (BuiltInFunctionDefinitions.DISTINCT == def) {
                // DISTINCT wraps the actual aggregate call; recurse with the flag set.
                Expression innerAgg = call.getChildren().get(0);
                return innerAgg.accept(new AggCallVisitor(relBuilder, expressionConverter, name, true));
            } else {
                SqlAggFunction sqlAggFunction = call.accept(sqlAggFunctionVisitor);
                return relBuilder.aggregateCall(
                        sqlAggFunction,
                        isDistinct,
                        false,
                        null,
                        name,
                        call.getChildren().stream().map(expr -> expr.accept(expressionConverter))
                                .collect(Collectors.toList()));
            }
        }
        @Override
        protected RelBuilder.AggCall defaultMethod(Expression expression) {
            throw new TableException("Unexpected expression: " + expression);
        }
    }
}
/** Converts a table aggregate expression into a Calcite {@link AggCall}. */
private class TableAggregateVisitor extends ExpressionDefaultVisitor {
    @Override
    public AggCall visit(CallExpression call) {
        if (isFunctionOfKind(call, TABLE_AGGREGATE)) {
            return call.accept(new TableAggCallVisitor(relBuilder, expressionConverter));
        }
        return defaultMethod(call);
    }
    @Override
    protected AggCall defaultMethod(Expression expression) {
        throw new TableException("Expected table aggregate. Got: " + expression);
    }
    /** Builds the Calcite AggCall for a table aggregate function call. */
    private class TableAggCallVisitor extends ExpressionDefaultVisitor {
        private final RelBuilder relBuilder;
        private final SqlAggFunctionVisitor sqlAggFunctionVisitor;
        private final ExpressionConverter expressionConverter;
        public TableAggCallVisitor(RelBuilder relBuilder, ExpressionConverter expressionConverter) {
            this.relBuilder = relBuilder;
            this.sqlAggFunctionVisitor = new SqlAggFunctionVisitor((FlinkTypeFactory) relBuilder.getTypeFactory());
            this.expressionConverter = expressionConverter;
        }
        @Override
        public RelBuilder.AggCall visit(CallExpression call) {
            // Table aggregates are never DISTINCT and are named after the function itself.
            SqlAggFunction sqlAggFunction = call.accept(sqlAggFunctionVisitor);
            return relBuilder.aggregateCall(
                    sqlAggFunction,
                    false,
                    false,
                    null,
                    sqlAggFunction.toString(),
                    call.getChildren().stream().map(expr -> expr.accept(expressionConverter)).collect(toList()));
        }
        @Override
        protected RelBuilder.AggCall defaultMethod(Expression expression) {
            throw new TableException("Expected table aggregate. Got: " + expression);
        }
    }
}
/** Converts a single Table API expression into a Calcite {@link RexNode}. */
private RexNode convertExprToRexNode(Expression expr) {
    return expr.accept(expressionConverter);
}
}