Please wait. This can take a few minutes...
Downloading a project requires significant server resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
org.apache.flink.table.planner.expressions.converter.CustomizedConvertRule Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.expressions.converter;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.expressions.ApiExpressionUtils;
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionUtils;
import org.apache.flink.table.expressions.TableReferenceExpression;
import org.apache.flink.table.expressions.TypeLiteralExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.operations.QueryOperation;
import org.apache.flink.table.planner.calcite.FlinkRelBuilder;
import org.apache.flink.table.planner.functions.InternalFunctionDefinitions;
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable;
import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction;
import org.apache.flink.table.types.DataType;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.sql.fun.SqlTrimFunction;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static org.apache.calcite.sql.type.SqlTypeName.VARCHAR;
import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
import static org.apache.flink.table.planner.expressions.converter.ExpressionConverter.extractValue;
import static org.apache.flink.table.planner.expressions.converter.ExpressionConverter.toRexNodes;
import static org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType;
import static org.apache.flink.table.runtime.typeutils.TypeCheckUtils.isCharacterString;
import static org.apache.flink.table.runtime.typeutils.TypeCheckUtils.isTemporal;
import static org.apache.flink.table.runtime.typeutils.TypeCheckUtils.isTimeInterval;
/**
 * Customized {@link CallExpressionConvertRule} for functions whose conversion to a Calcite {@link
 * RexNode} needs special logic, for example reading the literal values of some inputs, reordering
 * operands, or expanding a call into a combination of other functions.
 *
 * <p>Every supported {@link FunctionDefinition} is mapped to a {@link Conversion} routine. Calls to
 * any other function are not handled by this rule ({@link #convert} returns an empty result).
 */
public class CustomizedConvertRule implements CallExpressionConvertRule {

    /** Dispatch table from a function definition to the routine converting calls to it. */
    private static final Map<FunctionDefinition, Conversion> DEFINITION_RULE_MAP =
            new HashMap<>();

    static {
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.CAST, CustomizedConvertRule::convertCast);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.REINTERPRET_CAST,
                CustomizedConvertRule::convertReinterpretCast);
        DEFINITION_RULE_MAP.put(BuiltInFunctionDefinitions.IN, CustomizedConvertRule::convertIn);
        DEFINITION_RULE_MAP.put(BuiltInFunctionDefinitions.GET, CustomizedConvertRule::convertGet);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.TRIM, CustomizedConvertRule::convertTrim);
        DEFINITION_RULE_MAP.put(BuiltInFunctionDefinitions.AS, CustomizedConvertRule::convertAs);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.BETWEEN, CustomizedConvertRule::convertBetween);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.NOT_BETWEEN, CustomizedConvertRule::convertNotBetween);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.REPLACE, CustomizedConvertRule::convertReplace);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.PLUS, CustomizedConvertRule::convertPlus);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.TEMPORAL_OVERLAPS,
                CustomizedConvertRule::convertTemporalOverlaps);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.TIMESTAMP_DIFF,
                CustomizedConvertRule::convertTimestampDiff);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.ARRAY, CustomizedConvertRule::convertArray);
        DEFINITION_RULE_MAP.put(BuiltInFunctionDefinitions.MAP, CustomizedConvertRule::convertMap);
        DEFINITION_RULE_MAP.put(BuiltInFunctionDefinitions.ROW, CustomizedConvertRule::convertRow);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.ORDER_ASC, CustomizedConvertRule::convertOrderAsc);
        DEFINITION_RULE_MAP.put(
                BuiltInFunctionDefinitions.SQRT, CustomizedConvertRule::convertSqrt);
        // blink (planner-internal) expression
        DEFINITION_RULE_MAP.put(
                InternalFunctionDefinitions.THROW_EXCEPTION,
                CustomizedConvertRule::convertThrowException);
    }

    /**
     * Converts {@code call} if a conversion is registered for its function definition.
     *
     * @param call the call expression to convert
     * @param context conversion context providing the rel builder, type factory and child
     *     conversion
     * @return the converted node, or {@link Optional#empty()} if this rule does not handle the
     *     call's function definition
     */
    @Override
    public Optional<RexNode> convert(CallExpression call, ConvertContext context) {
        final Conversion conversion = DEFINITION_RULE_MAP.get(call.getFunctionDefinition());
        return Optional.ofNullable(conversion).map(c -> c.convert(call, context));
    }

    /** Converts {@code CAST(expr, type)} into an abstract cast to the literal target type. */
    private static RexNode convertCast(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 2);
        final RexNode child = context.toRexNode(call.getChildren().get(0));
        final TypeLiteralExpression targetType = (TypeLiteralExpression) call.getChildren().get(1);
        final RelDataType targetRelDataType =
                context.getTypeFactory()
                        .createFieldTypeFromLogicalType(
                                targetType.getOutputDataType().getLogicalType());
        return context.getRelBuilder().getRexBuilder().makeAbstractCast(targetRelDataType, child);
    }

    /**
     * Converts {@code ORDER_ASC(expr)} by converting only its operand. No ascending marker is
     * emitted (presumably because ascending is the default sort direction).
     */
    private static RexNode convertOrderAsc(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 1);
        return context.toRexNode(call.getChildren().get(0));
    }

    /**
     * Converts {@code TIMESTAMP_DIFF(unit, a, b)}. The second and third operands are swapped
     * because the Flink Table API and Calcite order them differently.
     */
    private static RexNode convertTimestampDiff(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 3);
        final List<RexNode> childrenRexNode = toRexNodes(context, call.getChildren());
        // different orders between flink table api and calcite.
        return context.getRelBuilder()
                .call(
                        FlinkSqlOperatorTable.TIMESTAMP_DIFF,
                        childrenRexNode.get(0),
                        childrenRexNode.get(2),
                        childrenRexNode.get(1));
    }

    /** Expands {@code expr NOT BETWEEN lower AND upper} into {@code expr < lower OR expr > upper}. */
    private static RexNode convertNotBetween(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 3);
        final List<RexNode> childrenRexNode = toRexNodes(context, call.getChildren());
        final RexNode expr = childrenRexNode.get(0);
        final RexNode lowerBound = childrenRexNode.get(1);
        final RexNode upperBound = childrenRexNode.get(2);
        return context.getRelBuilder()
                .or(
                        context.getRelBuilder()
                                .call(FlinkSqlOperatorTable.LESS_THAN, expr, lowerBound),
                        context.getRelBuilder()
                                .call(FlinkSqlOperatorTable.GREATER_THAN, expr, upperBound));
    }

    /** Expands {@code expr BETWEEN lower AND upper} into {@code expr >= lower AND expr <= upper}. */
    private static RexNode convertBetween(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 3);
        final List<RexNode> childrenRexNode = toRexNodes(context, call.getChildren());
        final RexNode expr = childrenRexNode.get(0);
        final RexNode lowerBound = childrenRexNode.get(1);
        final RexNode upperBound = childrenRexNode.get(2);
        return context.getRelBuilder()
                .and(
                        context.getRelBuilder()
                                .call(
                                        FlinkSqlOperatorTable.GREATER_THAN_OR_EQUAL,
                                        expr,
                                        lowerBound),
                        context.getRelBuilder()
                                .call(FlinkSqlOperatorTable.LESS_THAN_OR_EQUAL, expr, upperBound));
    }

    /** Converts {@code ARRAY(...)} into an array value constructor typed as the call's output. */
    private static RexNode convertArray(CallExpression call, ConvertContext context) {
        final List<RexNode> childrenRexNode = toRexNodes(context, call.getChildren());
        final RelDataType relDataType =
                context.getTypeFactory()
                        .createFieldTypeFromLogicalType(call.getOutputDataType().getLogicalType());
        return context.getRelBuilder()
                .getRexBuilder()
                .makeCall(
                        relDataType,
                        FlinkSqlOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
                        childrenRexNode);
    }

    /**
     * Converts {@code MAP(k1, v1, k2, v2, ...)} into a map value constructor typed as the call's
     * output. Requires a non-empty, even number of operands.
     */
    private static RexNode convertMap(CallExpression call, ConvertContext context) {
        final List<Expression> children = call.getChildren();
        checkArgument(call, !children.isEmpty() && children.size() % 2 == 0);
        final List<RexNode> childrenRexNode = toRexNodes(context, children);
        final RelDataType mapType =
                context.getTypeFactory()
                        .createFieldTypeFromLogicalType(call.getOutputDataType().getLogicalType());
        return context.getRelBuilder()
                .getRexBuilder()
                .makeCall(mapType, FlinkSqlOperatorTable.MAP_VALUE_CONSTRUCTOR, childrenRexNode);
    }

    /** Converts {@code ROW(...)} into a ROW call typed as the call's output. */
    private static RexNode convertRow(CallExpression call, ConvertContext context) {
        final List<RexNode> childrenRexNode = toRexNodes(context, call.getChildren());
        final RelDataType relDataType =
                context.getTypeFactory()
                        .createFieldTypeFromLogicalType(call.getOutputDataType().getLogicalType());
        return context.getRelBuilder()
                .getRexBuilder()
                .makeCall(relDataType, FlinkSqlOperatorTable.ROW, childrenRexNode);
    }

    /**
     * Expands {@code (leftPoint, leftTemporal) OVERLAPS (rightPoint, rightTemporal)}.
     *
     * <p>A temporal operand that is an interval is first added to its time point. Each pair is
     * then sorted into {@code (start, end)}, and the result is {@code e0 >= s1 AND e1 >= s0}.
     */
    private static RexNode convertTemporalOverlaps(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 4);
        final List<RexNode> childrenRexNode = toRexNodes(context, call.getChildren());
        // Standard conversion of the OVERLAPS operator.
        // Source: [[org.apache.calcite.sql2rel.StandardConvertletTable#convertOverlaps()]]
        final RexNode leftTimePoint = childrenRexNode.get(0);
        final RexNode leftTemporal = childrenRexNode.get(1);
        final RexNode rightTimePoint = childrenRexNode.get(2);
        final RexNode rightTemporal = childrenRexNode.get(3);
        final RexNode convLeftT;
        if (isTimeInterval(toLogicalType(leftTemporal.getType()))) {
            convLeftT =
                    context.getRelBuilder()
                            .call(FlinkSqlOperatorTable.DATETIME_PLUS, leftTimePoint, leftTemporal);
        } else {
            convLeftT = leftTemporal;
        }
        // sort end points into start and end, such that (s0 <= e0) and (s1 <= e1).
        final RexNode leftLe =
                context.getRelBuilder()
                        .call(FlinkSqlOperatorTable.LESS_THAN_OR_EQUAL, leftTimePoint, convLeftT);
        final RexNode s0 =
                context.getRelBuilder()
                        .call(FlinkSqlOperatorTable.CASE, leftLe, leftTimePoint, convLeftT);
        final RexNode e0 =
                context.getRelBuilder()
                        .call(FlinkSqlOperatorTable.CASE, leftLe, convLeftT, leftTimePoint);
        final RexNode convRightT;
        if (isTimeInterval(toLogicalType(rightTemporal.getType()))) {
            convRightT =
                    context.getRelBuilder()
                            .call(
                                    FlinkSqlOperatorTable.DATETIME_PLUS,
                                    rightTimePoint,
                                    rightTemporal);
        } else {
            convRightT = rightTemporal;
        }
        final RexNode rightLe =
                context.getRelBuilder()
                        .call(FlinkSqlOperatorTable.LESS_THAN_OR_EQUAL, rightTimePoint, convRightT);
        final RexNode s1 =
                context.getRelBuilder()
                        .call(FlinkSqlOperatorTable.CASE, rightLe, rightTimePoint, convRightT);
        final RexNode e1 =
                context.getRelBuilder()
                        .call(FlinkSqlOperatorTable.CASE, rightLe, convRightT, rightTimePoint);
        // (e0 >= s1) AND (e1 >= s0)
        final RexNode leftPred =
                context.getRelBuilder().call(FlinkSqlOperatorTable.GREATER_THAN_OR_EQUAL, e0, s1);
        final RexNode rightPred =
                context.getRelBuilder().call(FlinkSqlOperatorTable.GREATER_THAN_OR_EQUAL, e1, s0);
        return context.getRelBuilder().call(FlinkSqlOperatorTable.AND, leftPred, rightPred);
    }

    /**
     * Converts {@code a + b} according to the operand types:
     *
     * <ul>
     *   <li>either operand is a character string: {@code CONCAT}, with the other side cast to
     *       VARCHAR;
     *   <li>both operands are intervals of the same type: numeric {@code PLUS};
     *   <li>interval + temporal: {@code DATETIME_PLUS} with the operands swapped;
     *   <li>temporal + temporal: {@code DATETIME_PLUS};
     *   <li>otherwise: numeric {@code PLUS}.
     * </ul>
     */
    private static RexNode convertPlus(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 2);
        final List<RexNode> childrenRexNode = toRexNodes(context, call.getChildren());
        final RexNode left = childrenRexNode.get(0);
        final RexNode right = childrenRexNode.get(1);
        if (isCharacterString(toLogicalType(left.getType()))) {
            return context.getRelBuilder()
                    .call(
                            FlinkSqlOperatorTable.CONCAT,
                            left,
                            context.getRelBuilder().cast(right, VARCHAR));
        } else if (isCharacterString(toLogicalType(right.getType()))) {
            return context.getRelBuilder()
                    .call(
                            FlinkSqlOperatorTable.CONCAT,
                            context.getRelBuilder().cast(left, VARCHAR),
                            right);
        } else if (isTimeInterval(toLogicalType(left.getType()))
                // equals() instead of ==: do not rely on both types being the same instance
                && left.getType().equals(right.getType())) {
            return context.getRelBuilder().call(FlinkSqlOperatorTable.PLUS, childrenRexNode);
        } else if (isTimeInterval(toLogicalType(left.getType()))
                && isTemporal(toLogicalType(right.getType()))) {
            // Calcite has a bug that can't apply INTERVAL + DATETIME (INTERVAL at left)
            // we manually switch them here
            return context.getRelBuilder().call(FlinkSqlOperatorTable.DATETIME_PLUS, right, left);
        } else if (isTemporal(toLogicalType(left.getType()))
                && isTemporal(toLogicalType(right.getType()))) {
            return context.getRelBuilder()
                    .call(FlinkSqlOperatorTable.DATETIME_PLUS, childrenRexNode);
        } else {
            return context.getRelBuilder().call(FlinkSqlOperatorTable.PLUS, childrenRexNode);
        }
    }

    /**
     * Converts {@code REPLACE}. With three operands the call is passed through unchanged. With two
     * operands, {@code CHAR_LENGTH} of the first operand is supplied as the third operand.
     */
    private static RexNode convertReplace(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 2, 3);
        final List<Expression> children = call.getChildren();
        final List<RexNode> childrenRexNode = toRexNodes(context, children);
        if (children.size() == 2) {
            // NOTE(review): the synthesized third REPLACE operand is an integer length, not a
            // string; confirm this is the intended two-argument semantics.
            return context.getRelBuilder()
                    .call(
                            FlinkSqlOperatorTable.REPLACE,
                            childrenRexNode.get(0),
                            childrenRexNode.get(1),
                            context.getRelBuilder()
                                    .call(
                                            FlinkSqlOperatorTable.CHAR_LENGTH,
                                            childrenRexNode.get(0)));
        } else {
            return context.getRelBuilder().call(FlinkSqlOperatorTable.REPLACE, childrenRexNode);
        }
    }

    /** Converts {@code AS(expr, 'name')} into an alias of the converted expression. */
    private static RexNode convertAs(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 2);
        final String name =
                extractValue((ValueLiteralExpression) call.getChildren().get(1), String.class);
        final RexNode child = context.toRexNode(call.getChildren().get(0));
        return context.getRelBuilder().alias(child, name);
    }

    /**
     * Converts {@code TRIM(removeLeading, removeTrailing, trimString, str)} into Calcite's TRIM
     * with a BOTH / LEADING / TRAILING flag derived from the two boolean literals.
     *
     * @throws IllegalArgumentException if neither leading nor trailing removal is requested
     */
    private static RexNode convertTrim(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 4);
        final List<Expression> children = call.getChildren();
        final ValueLiteralExpression removeLeadingExpr = (ValueLiteralExpression) children.get(0);
        final boolean removeLeading = extractValue(removeLeadingExpr, Boolean.class);
        final ValueLiteralExpression removeTrailingExpr = (ValueLiteralExpression) children.get(1);
        final boolean removeTrailing = extractValue(removeTrailingExpr, Boolean.class);
        final RexNode trimString = context.toRexNode(children.get(2));
        final RexNode str = context.toRexNode(children.get(3));
        final SqlTrimFunction.Flag trimMode;
        if (removeLeading && removeTrailing) {
            trimMode = SqlTrimFunction.Flag.BOTH;
        } else if (removeLeading) {
            trimMode = SqlTrimFunction.Flag.LEADING;
        } else if (removeTrailing) {
            trimMode = SqlTrimFunction.Flag.TRAILING;
        } else {
            throw new IllegalArgumentException("Unsupported trim mode.");
        }
        return context.getRelBuilder()
                .call(
                        FlinkSqlOperatorTable.TRIM,
                        context.getRelBuilder().getRexBuilder().makeFlag(trimMode),
                        trimString,
                        str);
    }

    /**
     * Converts {@code GET(composite, key)} into a field access. The key is either a field name
     * (string literal) or a field index (integer literal).
     *
     * @throws TableException if the key does not resolve to a field of the composite type
     */
    private static RexNode convertGet(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 2);
        final RexNode child = context.toRexNode(call.getChildren().get(0));
        final ValueLiteralExpression keyLiteral = (ValueLiteralExpression) call.getChildren().get(1);
        final List<String> fieldNames = child.getType().getFieldNames();
        final int index =
                ExpressionUtils.extractValue(keyLiteral, String.class)
                        .map(fieldNames::indexOf)
                        .orElseGet(() -> extractValue(keyLiteral, Integer.class));
        // An unknown field name yields -1 from indexOf; reject it (and any out-of-range integer
        // key) here with a descriptive error instead of letting makeFieldAccess fail.
        if (index < 0 || index >= fieldNames.size()) {
            throw new TableException(
                    "Invalid field access "
                            + keyLiteral
                            + " on type "
                            + child.getType()
                            + " in call: "
                            + call);
        }
        return context.getRelBuilder().getRexBuilder().makeFieldAccess(child, index);
    }

    /**
     * Converts {@code IN}. If the second operand is a table reference, the call becomes an IN
     * sub-query over that table. Otherwise it becomes {@code expr IN (v1, v2, ...)}.
     */
    private static RexNode convertIn(CallExpression call, ConvertContext context) {
        checkArgument(call, call.getChildren().size() > 1);
        final Expression headExpr = call.getChildren().get(1);
        if (headExpr instanceof TableReferenceExpression) {
            final QueryOperation tableOperation =
                    ((TableReferenceExpression) headExpr).getQueryOperation();
            final RexNode child = context.toRexNode(call.getChildren().get(0));
            return RexSubQuery.in(
                    ((FlinkRelBuilder) context.getRelBuilder())
                            .queryOperation(tableOperation)
                            .build(),
                    ImmutableList.of(child));
        } else {
            final List<RexNode> operands = toRexNodes(context, call.getChildren());
            return context.getRelBuilder()
                    .getRexBuilder()
                    .makeIn(operands.get(0), operands.subList(1, operands.size()));
        }
    }

    /**
     * Converts {@code REINTERPRET_CAST(expr, type, checkOverflow)}. The target type takes on the
     * nullability of {@code expr}.
     */
    private static RexNode convertReinterpretCast(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 3);
        final RexNode child = context.toRexNode(call.getChildren().get(0));
        final TypeLiteralExpression type = (TypeLiteralExpression) call.getChildren().get(1);
        final RexNode checkOverflow = context.toRexNode(call.getChildren().get(2));
        return context.getRelBuilder()
                .getRexBuilder()
                .makeReinterpretCast(
                        context.getTypeFactory()
                                .createFieldTypeFromLogicalType(
                                        type.getOutputDataType()
                                                .getLogicalType()
                                                .copy(child.getType().isNullable())),
                        child,
                        checkOverflow);
    }

    /** Converts {@code SQRT(x)} into {@code POWER(x, 0.5)}. */
    private static RexNode convertSqrt(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 1);
        return context.getRelBuilder()
                .call(
                        FlinkSqlOperatorTable.POWER,
                        toRexNodes(
                                context,
                                Arrays.asList(
                                        call.getChildren().get(0),
                                        ApiExpressionUtils.valueLiteral(0.5))));
    }

    /**
     * Converts the internal {@code THROW_EXCEPTION(message, type)} into a call to a {@link
     * SqlThrowExceptionFunction} whose return type is the given type literal.
     */
    private static RexNode convertThrowException(CallExpression call, ConvertContext context) {
        checkArgumentNumber(call, 2);
        final DataType type =
                ((TypeLiteralExpression) call.getChildren().get(1)).getOutputDataType();
        final SqlThrowExceptionFunction function =
                new SqlThrowExceptionFunction(
                        context.getTypeFactory()
                                .createFieldTypeFromLogicalType(fromDataTypeToLogicalType(type)));
        return context.getRelBuilder().call(function, context.toRexNode(call.getChildren().get(0)));
    }

    /**
     * Checks that {@code call} has exactly one of the given numbers of operands.
     *
     * @throws TableException if the operand count matches none of {@code numbers}
     */
    private static void checkArgumentNumber(CallExpression call, int... numbers) {
        final int actual = call.getChildren().size();
        checkArgument(call, Arrays.stream(numbers).anyMatch(number -> number == actual));
    }

    /**
     * Throws if {@code check} is false.
     *
     * @throws TableException naming the offending call when {@code check} is false
     */
    private static void checkArgument(CallExpression call, boolean check) {
        if (!check) {
            throw new TableException("Invalid arguments for call: " + call);
        }
    }

    /** Converts one call of a specific function definition into a Calcite {@link RexNode}. */
    @FunctionalInterface
    private interface Conversion {
        RexNode convert(CallExpression call, ConvertContext context);
    }
}