// com.hazelcast.org.apache.calcite.sql2rel.StandardConvertletTable (Maven / Gradle / Ivy)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hazelcast.org.apache.calcite.sql2rel;
import com.hazelcast.org.apache.calcite.avatica.util.DateTimeUtils;
import com.hazelcast.org.apache.calcite.avatica.util.TimeUnit;
import com.hazelcast.org.apache.calcite.plan.RelOptUtil;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeFactory;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeFamily;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexCall;
import com.hazelcast.org.apache.calcite.rex.RexCallBinding;
import com.hazelcast.org.apache.calcite.rex.RexLiteral;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexRangeRef;
import com.hazelcast.org.apache.calcite.rex.RexUtil;
import com.hazelcast.org.apache.calcite.sql.SqlAggFunction;
import com.hazelcast.org.apache.calcite.sql.SqlBinaryOperator;
import com.hazelcast.org.apache.calcite.sql.SqlCall;
import com.hazelcast.org.apache.calcite.sql.SqlDataTypeSpec;
import com.hazelcast.org.apache.calcite.sql.SqlFunction;
import com.hazelcast.org.apache.calcite.sql.SqlFunctionCategory;
import com.hazelcast.org.apache.calcite.sql.SqlIdentifier;
import com.hazelcast.org.apache.calcite.sql.SqlIntervalLiteral;
import com.hazelcast.org.apache.calcite.sql.SqlIntervalQualifier;
import com.hazelcast.org.apache.calcite.sql.SqlJdbcFunctionCall;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.SqlLiteral;
import com.hazelcast.org.apache.calcite.sql.SqlNode;
import com.hazelcast.org.apache.calcite.sql.SqlNodeList;
import com.hazelcast.org.apache.calcite.sql.SqlNumericLiteral;
import com.hazelcast.org.apache.calcite.sql.SqlOperator;
import com.hazelcast.org.apache.calcite.sql.SqlUtil;
import com.hazelcast.org.apache.calcite.sql.SqlWindowTableFunction;
import com.hazelcast.org.apache.calcite.sql.fun.SqlArrayValueConstructor;
import com.hazelcast.org.apache.calcite.sql.fun.SqlBetweenOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlCase;
import com.hazelcast.org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlExtractFunction;
import com.hazelcast.org.apache.calcite.sql.fun.SqlJsonValueFunction;
import com.hazelcast.org.apache.calcite.sql.fun.SqlLibraryOperators;
import com.hazelcast.org.apache.calcite.sql.fun.SqlLiteralChainOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlMapValueConstructor;
import com.hazelcast.org.apache.calcite.sql.fun.SqlMultisetQueryConstructor;
import com.hazelcast.org.apache.calcite.sql.fun.SqlMultisetValueConstructor;
import com.hazelcast.org.apache.calcite.sql.fun.SqlOverlapsOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlRowOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlSequenceValueOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import com.hazelcast.org.apache.calcite.sql.fun.SqlTrimFunction;
import com.hazelcast.org.apache.calcite.sql.parser.SqlParserPos;
import com.hazelcast.org.apache.calcite.sql.type.SqlOperandTypeChecker;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeFamily;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeName;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeUtil;
import com.hazelcast.org.apache.calcite.sql.validate.SqlValidator;
import com.hazelcast.org.apache.calcite.sql.validate.SqlValidatorImpl;
import com.hazelcast.org.apache.calcite.util.Pair;
import com.hazelcast.org.apache.calcite.util.Util;
import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.com.google.common.collect.Lists;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
/**
* Standard implementation of {@link SqlRexConvertletTable}.
*/
public class StandardConvertletTable extends ReflectiveConvertletTable {
/**
 * Singleton instance.
 *
 * The constructor of this class is private, so callers obtain the table
 * through this field rather than by instantiating it.
 */
public static final StandardConvertletTable INSTANCE =
new StandardConvertletTable();
//~ Constructors -----------------------------------------------------------
private StandardConvertletTable() {
super();
// Register aliases (operators which have a different name but
// identical behavior to other operators).
addAlias(SqlStdOperatorTable.CHARACTER_LENGTH,
SqlStdOperatorTable.CHAR_LENGTH);
addAlias(SqlStdOperatorTable.IS_UNKNOWN,
SqlStdOperatorTable.IS_NULL);
addAlias(SqlStdOperatorTable.IS_NOT_UNKNOWN,
SqlStdOperatorTable.IS_NOT_NULL);
addAlias(SqlStdOperatorTable.PERCENT_REMAINDER, SqlStdOperatorTable.MOD);
// Register convertlets for specific objects.
registerOp(SqlStdOperatorTable.CAST, this::convertCast);
registerOp(SqlLibraryOperators.INFIX_CAST, this::convertCast);
registerOp(SqlStdOperatorTable.IS_DISTINCT_FROM,
(cx, call) -> convertIsDistinctFrom(cx, call, false));
registerOp(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
(cx, call) -> convertIsDistinctFrom(cx, call, true));
registerOp(SqlStdOperatorTable.PLUS, this::convertPlus);
registerOp(SqlStdOperatorTable.MINUS,
(cx, call) -> {
final RexCall e =
(RexCall) StandardConvertletTable.this.convertCall(cx, call.getOperator(),
call.getOperandList());
switch (e.getOperands().get(0).getType().getSqlTypeName()) {
case DATE:
case TIME:
case TIMESTAMP:
return convertDatetimeMinus(cx, SqlStdOperatorTable.MINUS_DATE,
call);
default:
return e;
}
});
registerOp(SqlLibraryOperators.LTRIM,
new TrimConvertlet(SqlTrimFunction.Flag.LEADING));
registerOp(SqlLibraryOperators.RTRIM,
new TrimConvertlet(SqlTrimFunction.Flag.TRAILING));
registerOp(SqlLibraryOperators.GREATEST, new GreatestConvertlet());
registerOp(SqlLibraryOperators.LEAST, new GreatestConvertlet());
registerOp(SqlLibraryOperators.NVL,
(cx, call) -> {
final RexBuilder rexBuilder = cx.getRexBuilder();
final RexNode operand0 =
cx.convertExpression(call.getOperandList().get(0));
final RexNode operand1 =
cx.convertExpression(call.getOperandList().get(1));
final RelDataType type =
cx.getValidator().getValidatedNodeType(call);
return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE,
ImmutableList.of(
rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
operand0),
rexBuilder.makeCast(type, operand0),
rexBuilder.makeCast(type, operand1)));
});
registerOp(SqlLibraryOperators.DECODE,
(cx, call) -> {
final RexBuilder rexBuilder = cx.getRexBuilder();
final List operands = convertExpressionList(cx,
call.getOperandList(), SqlOperandTypeChecker.Consistency.NONE);
final RelDataType type =
cx.getValidator().getValidatedNodeType(call);
final List exprs = new ArrayList<>();
for (int i = 1; i < operands.size() - 1; i += 2) {
exprs.add(
RelOptUtil.isDistinctFrom(rexBuilder, operands.get(0),
operands.get(i), true));
exprs.add(operands.get(i + 1));
}
if (operands.size() % 2 == 0) {
exprs.add(Util.last(operands));
} else {
exprs.add(rexBuilder.makeNullLiteral(type));
}
return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, exprs);
});
// Expand "x NOT LIKE y" into "NOT (x LIKE y)"
registerOp(SqlStdOperatorTable.NOT_LIKE,
(cx, call) -> cx.convertExpression(
SqlStdOperatorTable.NOT.createCall(SqlParserPos.ZERO,
SqlStdOperatorTable.LIKE.createCall(SqlParserPos.ZERO,
call.getOperandList()))));
// Expand "x NOT SIMILAR y" into "NOT (x SIMILAR y)"
registerOp(SqlStdOperatorTable.NOT_SIMILAR_TO,
(cx, call) -> cx.convertExpression(
SqlStdOperatorTable.NOT.createCall(SqlParserPos.ZERO,
SqlStdOperatorTable.SIMILAR_TO.createCall(SqlParserPos.ZERO,
call.getOperandList()))));
// Unary "+" has no effect, so expand "+ x" into "x".
registerOp(SqlStdOperatorTable.UNARY_PLUS,
(cx, call) -> cx.convertExpression(call.operand(0)));
// "DOT"
registerOp(SqlStdOperatorTable.DOT,
(cx, call) -> cx.getRexBuilder().makeFieldAccess(
cx.convertExpression(call.operand(0)),
call.operand(1).toString(), false));
// "AS" has no effect, so expand "x AS id" into "x".
registerOp(SqlStdOperatorTable.AS,
(cx, call) -> cx.convertExpression(call.operand(0)));
// "SQRT(x)" is equivalent to "POWER(x, .5)"
registerOp(SqlStdOperatorTable.SQRT,
(cx, call) -> cx.convertExpression(
SqlStdOperatorTable.POWER.createCall(SqlParserPos.ZERO,
call.operand(0),
SqlLiteral.createExactNumeric("0.5", SqlParserPos.ZERO))));
// REVIEW jvs 24-Apr-2006: This only seems to be working from within a
// windowed agg. I have added an optimizer rule
// com.hazelcast.org.apache.calcite.rel.rules.AggregateReduceFunctionsRule which handles
// other cases post-translation. The reason I did that was to defer the
// implementation decision; e.g. we may want to push it down to a foreign
// server directly rather than decomposed; decomposition is easier than
// recognition.
// Convert "avg()" to "cast(sum() / count() as
// )". We don't need to handle the empty set specially, because
// the SUM is already supposed to come out as NULL in cases where the
// COUNT is zero, so the null check should take place first and prevent
// division by zero. We need the cast because SUM and COUNT may use
// different types, say BIGINT.
//
// Similarly STDDEV_POP and STDDEV_SAMP, VAR_POP and VAR_SAMP.
registerOp(SqlStdOperatorTable.AVG,
new AvgVarianceConvertlet(SqlKind.AVG));
registerOp(SqlStdOperatorTable.STDDEV_POP,
new AvgVarianceConvertlet(SqlKind.STDDEV_POP));
registerOp(SqlStdOperatorTable.STDDEV_SAMP,
new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
registerOp(SqlStdOperatorTable.STDDEV,
new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
registerOp(SqlStdOperatorTable.VAR_POP,
new AvgVarianceConvertlet(SqlKind.VAR_POP));
registerOp(SqlStdOperatorTable.VAR_SAMP,
new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
registerOp(SqlStdOperatorTable.VARIANCE,
new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
registerOp(SqlStdOperatorTable.COVAR_POP,
new RegrCovarianceConvertlet(SqlKind.COVAR_POP));
registerOp(SqlStdOperatorTable.COVAR_SAMP,
new RegrCovarianceConvertlet(SqlKind.COVAR_SAMP));
registerOp(SqlStdOperatorTable.REGR_SXX,
new RegrCovarianceConvertlet(SqlKind.REGR_SXX));
registerOp(SqlStdOperatorTable.REGR_SYY,
new RegrCovarianceConvertlet(SqlKind.REGR_SYY));
final SqlRexConvertlet floorCeilConvertlet = new FloorCeilConvertlet();
registerOp(SqlStdOperatorTable.FLOOR, floorCeilConvertlet);
registerOp(SqlStdOperatorTable.CEIL, floorCeilConvertlet);
registerOp(SqlStdOperatorTable.TIMESTAMP_ADD, new TimestampAddConvertlet());
registerOp(SqlStdOperatorTable.TIMESTAMP_DIFF,
new TimestampDiffConvertlet());
// Convert "element()" to "$element_slice()", if the
// expression is a multiset of scalars.
if (false) {
registerOp(SqlStdOperatorTable.ELEMENT,
(cx, call) -> {
assert call.operandCount() == 1;
final SqlNode operand = call.operand(0);
final RelDataType type =
cx.getValidator().getValidatedNodeType(operand);
if (!type.getComponentType().isStruct()) {
return cx.convertExpression(
SqlStdOperatorTable.ELEMENT_SLICE.createCall(
SqlParserPos.ZERO, operand));
}
// fallback on default behavior
return StandardConvertletTable.this.convertCall(cx, call);
});
}
// Convert "$element_slice()" to "element().field#0"
if (false) {
registerOp(SqlStdOperatorTable.ELEMENT_SLICE,
(cx, call) -> {
assert call.operandCount() == 1;
final SqlNode operand = call.operand(0);
final RexNode expr =
cx.convertExpression(
SqlStdOperatorTable.ELEMENT.createCall(SqlParserPos.ZERO,
operand));
return cx.getRexBuilder().makeFieldAccess(expr, 0);
});
}
}
//~ Methods ----------------------------------------------------------------
/** Builds the call {@code a0 OR a1}. */
private RexNode or(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
  final SqlOperator operator = SqlStdOperatorTable.OR;
  return rexBuilder.makeCall(operator, a0, a1);
}

/** Builds the call {@code a0 = a1}. */
private RexNode eq(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
  final SqlOperator operator = SqlStdOperatorTable.EQUALS;
  return rexBuilder.makeCall(operator, a0, a1);
}

/** Builds the call {@code a0 >= a1}. */
private RexNode ge(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
  final SqlOperator operator = SqlStdOperatorTable.GREATER_THAN_OR_EQUAL;
  return rexBuilder.makeCall(operator, a0, a1);
}

/** Builds the call {@code a0 <= a1}. */
private RexNode le(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
  final SqlOperator operator = SqlStdOperatorTable.LESS_THAN_OR_EQUAL;
  return rexBuilder.makeCall(operator, a0, a1);
}

/** Builds the call {@code a0 AND a1}. */
private RexNode and(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
  final SqlOperator operator = SqlStdOperatorTable.AND;
  return rexBuilder.makeCall(operator, a0, a1);
}

/** Builds a call to the {@code DIVIDE_INTEGER} operator on a0 and a1. */
private static RexNode divideInt(RexBuilder rexBuilder, RexNode a0,
    RexNode a1) {
  final SqlOperator operator = SqlStdOperatorTable.DIVIDE_INTEGER;
  return rexBuilder.makeCall(operator, a0, a1);
}

/** Builds the call {@code a0 + a1}. */
private RexNode plus(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
  final SqlOperator operator = SqlStdOperatorTable.PLUS;
  return rexBuilder.makeCall(operator, a0, a1);
}

/** Builds the call {@code a0 - a1}. */
private RexNode minus(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
  final SqlOperator operator = SqlStdOperatorTable.MINUS;
  return rexBuilder.makeCall(operator, a0, a1);
}

/** Builds the call {@code a0 * a1}. */
private static RexNode multiply(RexBuilder rexBuilder, RexNode a0,
    RexNode a1) {
  final SqlOperator operator = SqlStdOperatorTable.MULTIPLY;
  return rexBuilder.makeCall(operator, a0, a1);
}

/** Builds a call to the {@code CASE} operator with the given operands. */
private RexNode case_(RexBuilder rexBuilder, RexNode... args) {
  final SqlOperator operator = SqlStdOperatorTable.CASE;
  return rexBuilder.makeCall(operator, args);
}

// SqlNode helpers

/** Creates the parse-tree call {@code a0 + a1} at the given position. */
private SqlCall plus(SqlParserPos pos, SqlNode a0, SqlNode a1) {
  final SqlCall sum = SqlStdOperatorTable.PLUS.createCall(pos, a0, a1);
  return sum;
}
/**
 * Converts a CASE expression.
 *
 * The converted operand list is laid out as
 * {@code [when0, then0, when1, then1, ..., else]}. NULL literals in a WHEN
 * position become a null literal of BOOLEAN type; NULL literals in a THEN
 * or ELSE position become a null literal of NULL type. After the return
 * type has been derived, the operands at the indexes returned by
 * {@code elseArgs} are coerced to that type.
 *
 * @param cx Context
 * @param call CASE call
 * @return call to the {@code CASE} operator
 */
public RexNode convertCase(
    SqlRexContext cx,
    SqlCase call) {
  final SqlNodeList whenList = call.getWhenOperands();
  final SqlNodeList thenList = call.getThenOperands();
  assert whenList.size() == thenList.size();

  final RexBuilder rexBuilder = cx.getRexBuilder();
  final List<RexNode> exprList = new ArrayList<>();
  final RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
  final RexLiteral unknownLiteral = rexBuilder.makeNullLiteral(
      typeFactory.createSqlType(SqlTypeName.BOOLEAN));
  final RexLiteral nullLiteral = rexBuilder.makeNullLiteral(
      typeFactory.createSqlType(SqlTypeName.NULL));

  for (int i = 0; i < whenList.size(); i++) {
    if (SqlUtil.isNullLiteral(whenList.get(i), false)) {
      exprList.add(unknownLiteral);
    } else {
      exprList.add(cx.convertExpression(whenList.get(i)));
    }
    if (SqlUtil.isNullLiteral(thenList.get(i), false)) {
      exprList.add(nullLiteral);
    } else {
      exprList.add(cx.convertExpression(thenList.get(i)));
    }
  }
  if (SqlUtil.isNullLiteral(call.getElseOperand(), false)) {
    exprList.add(nullLiteral);
  } else {
    exprList.add(cx.convertExpression(call.getElseOperand()));
  }

  final RelDataType type =
      rexBuilder.deriveReturnType(call.getOperator(), exprList);
  // Only the value-producing operands are coerced; conditions are left as is.
  for (int i : elseArgs(exprList.size())) {
    exprList.set(i,
        rexBuilder.ensureType(type, exprList.get(i), false));
  }
  return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, exprList);
}
/**
 * Converts a MULTISET value constructor into a reference to the sub-query
 * expression registered for the call, wrapped in {@code $SLICE} when the
 * validated component type is not a struct.
 */
public RexNode convertMultiset(
    SqlRexContext cx,
    SqlMultisetValueConstructor op,
    SqlCall call) {
  final RelDataType validatedType =
      cx.getValidator().getValidatedNodeType(call);
  final RexRangeRef rangeRef = cx.getSubQueryExpr(call);
  assert rangeRef != null;

  final RelDataType multisetType =
      rangeRef.getType().getFieldList().get(0).getType();
  final RexNode inputRef =
      cx.getRexBuilder().makeInputRef(multisetType, rangeRef.getOffset());
  assert multisetType.getComponentType().isStruct();

  if (validatedType.getComponentType().isStruct()) {
    return inputRef;
  }

  // If the type is not a struct, the multiset operator will have
  // wrapped the type as a record. Add a call to the $SLICE operator
  // to compensate. For example,
  // if 'ms' has type 'RECORD (INTEGER x) MULTISET',
  // then '$SLICE(ms)' has type 'INTEGER MULTISET'.
  // This will be removed as the expression is translated.
  return cx.getRexBuilder().makeCall(validatedType,
      SqlStdOperatorTable.SLICE, ImmutableList.of(inputRef));
}
/**
 * Converts an ARRAY value constructor by delegating to the generic
 * {@code convertCall(cx, call)}; the operator argument is not consulted.
 */
public RexNode convertArray(
    SqlRexContext cx,
    SqlArrayValueConstructor op,
    SqlCall call) {
  final RexNode converted = convertCall(cx, call);
  return converted;
}
/**
 * Converts a MAP value constructor by delegating to the generic
 * {@code convertCall(cx, call)}; the operator argument is not consulted.
 */
public RexNode convertMap(
    SqlRexContext cx,
    SqlMapValueConstructor op,
    SqlCall call) {
  final RexNode converted = convertCall(cx, call);
  return converted;
}
/**
 * Converts a MULTISET query constructor into a reference to the sub-query
 * expression registered for the call, wrapped in {@code $SLICE} when the
 * validated component type is not a struct.
 */
public RexNode convertMultisetQuery(
    SqlRexContext cx,
    SqlMultisetQueryConstructor op,
    SqlCall call) {
  final RelDataType validatedType =
      cx.getValidator().getValidatedNodeType(call);
  final RexRangeRef rangeRef = cx.getSubQueryExpr(call);
  assert rangeRef != null;

  final RelDataType multisetType =
      rangeRef.getType().getFieldList().get(0).getType();
  final RexNode inputRef =
      cx.getRexBuilder().makeInputRef(multisetType, rangeRef.getOffset());
  assert multisetType.getComponentType().isStruct();

  if (validatedType.getComponentType().isStruct()) {
    return inputRef;
  }

  // If the type is not a struct, the multiset operator will have
  // wrapped the type as a record. Add a call to the $SLICE operator
  // to compensate. For example,
  // if 'ms' has type 'RECORD (INTEGER x) MULTISET',
  // then '$SLICE(ms)' has type 'INTEGER MULTISET'.
  // This will be removed as the expression is translated.
  return cx.getRexBuilder().makeCall(SqlStdOperatorTable.SLICE, inputRef);
}
/**
 * Converts a JDBC function call by converting the lookup call held by the
 * operator itself.
 */
public RexNode convertJdbc(
    SqlRexContext cx,
    SqlJdbcFunctionCall op,
    SqlCall call) {
  // Yuck!! The function definition contains arguments!
  // TODO: adopt a more conventional definition/instance structure
  return cx.convertExpression(op.getLookupCall());
}
/**
 * Converts a CAST call.
 *
 * When the target is an interval qualifier, interval and numeric literals
 * are folded directly into an interval literal (a numeric literal is first
 * multiplied by the multiplier of the qualifier's unit); any other source
 * is converted and handed to {@code castToValidatedType}. Otherwise the
 * target type is derived from the type specification, made nullable if the
 * converted argument is nullable, and a cast is built. A NULL literal
 * source is re-typed in the validator and converted without a cast.
 */
protected RexNode convertCast(
    SqlRexContext cx,
    final SqlCall call) {
  RelDataTypeFactory typeFactory = cx.getTypeFactory();
  assert call.getKind() == SqlKind.CAST;
  final SqlNode source = call.operand(0);
  final SqlNode target = call.operand(1);

  if (target instanceof SqlIntervalQualifier) {
    final SqlIntervalQualifier qualifier = (SqlIntervalQualifier) target;
    final boolean fromIntervalLiteral = source instanceof SqlIntervalLiteral;
    final boolean fromNumericLiteral =
        !fromIntervalLiteral && source instanceof SqlNumericLiteral;
    if (!fromIntervalLiteral && !fromNumericLiteral) {
      return castToValidatedType(cx, call, cx.convertExpression(source));
    }
    final RexLiteral converted = (RexLiteral) cx.convertExpression(source);
    BigDecimal value = (BigDecimal) converted.getValue();
    if (fromNumericLiteral) {
      // Scale the bare number by the multiplier of the qualifier's unit.
      value = value.multiply(qualifier.getUnit().multiplier);
    }
    final RexLiteral intervalLiteral =
        cx.getRexBuilder().makeIntervalLiteral(value, qualifier);
    return castToValidatedType(cx, call, intervalLiteral);
  }

  final SqlDataTypeSpec typeSpec = (SqlDataTypeSpec) target;
  RelDataType targetType = typeSpec.deriveType(cx.getValidator());
  if (targetType == null) {
    targetType =
        cx.getValidator().getValidatedNodeType(typeSpec.getTypeName());
  }

  final RexNode operand = cx.convertExpression(source);
  if (operand.getType().isNullable()) {
    targetType = typeFactory.createTypeWithNullability(targetType, true);
  }

  if (SqlUtil.isNullLiteral(source, false)) {
    // Record the target type on the NULL literal and convert it again so
    // that the resulting literal carries that type.
    final SqlValidatorImpl validator = (SqlValidatorImpl) cx.getValidator();
    validator.setValidatedNodeType(source, targetType);
    return cx.convertExpression(source);
  }

  if (typeSpec.getCollectionsTypeName() != null) {
    final RelDataType sourceComponent = operand.getType().getComponentType();
    final RelDataType targetComponent = targetType.getComponentType();
    if (sourceComponent.isStruct() && !targetComponent.isStruct()) {
      // Wrap the scalar target component in a one-field record named after
      // the source's first field, then rebuild the multiset type around it.
      RelDataType wrapped =
          typeFactory.builder()
              .add(
                  sourceComponent.getFieldList().get(0).getName(),
                  targetComponent)
              .build();
      wrapped = typeFactory.createTypeWithNullability(
          wrapped,
          targetComponent.isNullable());
      final boolean targetNullable = targetType.isNullable();
      targetType = typeFactory.createMultisetType(wrapped, -1);
      targetType =
          typeFactory.createTypeWithNullability(targetType, targetNullable);
    }
  }
  return cx.getRexBuilder().makeCast(targetType, operand);
}
/**
 * Converts a call to FLOOR or CEIL.
 *
 * A single interval-literal operand is rewritten into integer arithmetic
 * on the multiplier of the interval's start unit; every other form is
 * converted as an ordinary function call.
 */
protected RexNode convertFloorCeil(SqlRexContext cx, SqlCall call) {
  final boolean floor = call.getKind() == SqlKind.FLOOR;

  final boolean singleIntervalLiteral =
      call.operandCount() == 1
          && call.operand(0) instanceof SqlIntervalLiteral;
  if (!singleIntervalLiteral) {
    // normal floor, ceil function
    return convertFunction(cx, (SqlFunction) call.getOperator(), call);
  }

  // Rewrite floor, ceil of interval
  final SqlIntervalLiteral literal = call.operand(0);
  final SqlIntervalLiteral.IntervalValue intervalValue =
      (SqlIntervalLiteral.IntervalValue) literal.getValue();
  final BigDecimal unit =
      intervalValue.getIntervalQualifier().getStartUnit().multiplier;
  final RexNode rexInterval = cx.convertExpression(literal);

  final RexBuilder rexBuilder = cx.getRexBuilder();
  final RexNode zero = rexBuilder.makeExactLiteral(BigDecimal.valueOf(0));
  final RexNode nonNegative = ge(rexBuilder, rexInterval, zero);

  // Padding of (unit - 1), reinterpreted as the interval's own type.
  final RexNode pad =
      rexBuilder.makeExactLiteral(unit.subtract(BigDecimal.ONE));
  final RexNode padAsInterval = rexBuilder.makeReinterpretCast(
      rexInterval.getType(), pad, rexBuilder.makeLiteral(false));

  // NOTE(review): the CASE operand order below is reproduced exactly as
  // written originally; confirm it is the intended (when, then, else) order.
  final RexNode adjusted;
  final RexNode selected;
  if (floor) {
    adjusted = minus(rexBuilder, rexInterval, padAsInterval);
    selected = case_(rexBuilder, rexInterval, nonNegative, adjusted);
  } else {
    adjusted = plus(rexBuilder, rexInterval, padAsInterval);
    selected = case_(rexBuilder, adjusted, nonNegative, rexInterval);
  }

  final RexNode factor = rexBuilder.makeExactLiteral(unit);
  final RexNode quotient = divideInt(rexBuilder, selected, factor);
  return multiply(rexBuilder, quotient, factor);
}
/**
 * Converts a call to the {@code EXTRACT} function as an ordinary function
 * call, using the operator carried by the call itself.
 *
 * <p>Called automatically via reflection.
 */
public RexNode convertExtract(
    SqlRexContext cx,
    SqlExtractFunction op,
    SqlCall call) {
  final SqlFunction function = (SqlFunction) call.getOperator();
  return convertFunction(cx, function, call);
}
/**
 * Builds {@code MOD(res, val)} with {@code val} as an exact literal of type
 * {@code resType}, or returns {@code res} unchanged when {@code val} equals
 * {@link BigDecimal#ONE} (a scale-sensitive comparison).
 */
private RexNode mod(RexBuilder rexBuilder, RelDataType resType, RexNode res,
    BigDecimal val) {
  final boolean isOne = val.equals(BigDecimal.ONE);
  if (!isOne) {
    final RexNode divisor = rexBuilder.makeExactLiteral(val, resType);
    return rexBuilder.makeCall(SqlStdOperatorTable.MOD, res, divisor);
  }
  return res;
}
/**
 * Builds an expression dividing {@code res} by the constant {@code val}.
 *
 * Returns {@code res} unchanged when {@code val} equals
 * {@link BigDecimal#ONE}. For a value strictly between 0 and 1 whose
 * reciprocal can be computed without rounding, multiplies by the
 * reciprocal instead; otherwise falls back to integer division.
 */
private static RexNode divide(RexBuilder rexBuilder, RexNode res,
    BigDecimal val) {
  if (val.equals(BigDecimal.ONE)) {
    return res;
  }
  // If val is between 0 and 1, rather than divide by val, multiply by its
  // reciprocal. For example, rather than divide by 0.001 multiply by 1000.
  final boolean betweenZeroAndOne =
      val.compareTo(BigDecimal.ONE) < 0 && val.signum() == 1;
  if (betweenZeroAndOne) {
    try {
      final BigDecimal reciprocal =
          BigDecimal.ONE.divide(val, RoundingMode.UNNECESSARY);
      final RexNode factor = rexBuilder.makeExactLiteral(reciprocal);
      return multiply(rexBuilder, res, factor);
    } catch (ArithmeticException ignored) {
      // ignore - reciprocal is not an integer
    }
  }
  final RexNode divisor = rexBuilder.makeExactLiteral(val);
  return divideInt(rexBuilder, res, divisor);
}
/**
 * Converts a datetime subtraction into a call to {@code op} on the first
 * two converted operands, typed with the validated type of the call.
 *
 * @param cx Context
 * @param op Datetime subtraction operator to use in the resulting call
 * @param call Call being converted
 * @return call to {@code op}
 */
public RexNode convertDatetimeMinus(
    SqlRexContext cx,
    SqlDatetimeSubtractionOperator op,
    SqlCall call) {
  // Rewrite datetime minus
  final RexBuilder rexBuilder = cx.getRexBuilder();
  final List<SqlNode> operands = call.getOperandList();
  final List<RexNode> exprs = convertExpressionList(cx, operands,
      SqlOperandTypeChecker.Consistency.NONE);

  final RelDataType resType =
      cx.getValidator().getValidatedNodeType(call);
  // Only the first two operands take part in the resulting call.
  return rexBuilder.makeCall(resType, op, exprs.subList(0, 2));
}
/**
 * Converts a function call.
 *
 * User-defined constructors are turned into constructor invocations. For
 * every other function the validated type of the call is used as the
 * return type, falling back to type derivation when the validator does not
 * know it.
 *
 * @param cx Context
 * @param fun Function being called
 * @param call Call being converted
 * @return converted expression
 */
public RexNode convertFunction(
    SqlRexContext cx,
    SqlFunction fun,
    SqlCall call) {
  final List<SqlNode> operands = call.getOperandList();
  final List<RexNode> exprs = convertExpressionList(cx, operands,
      SqlOperandTypeChecker.Consistency.NONE);
  if (fun.getFunctionType() == SqlFunctionCategory.USER_DEFINED_CONSTRUCTOR) {
    return makeConstructorCall(cx, fun, exprs);
  }
  RelDataType returnType =
      cx.getValidator().getValidatedNodeTypeIfKnown(call);
  if (returnType == null) {
    returnType = cx.getRexBuilder().deriveReturnType(fun, exprs);
  }
  return cx.getRexBuilder().makeCall(returnType, fun, exprs);
}
/**
 * Converts a window table function call, ignoring its first operand.
 *
 * The validated type of the call is used as the return type, falling back
 * to type derivation when the validator does not know it.
 *
 * @param cx Context
 * @param fun Window table function being called
 * @param call Call being converted
 * @return call to {@code fun} on the remaining operands
 */
public RexNode convertWindowFunction(
    SqlRexContext cx,
    SqlWindowTableFunction fun,
    SqlCall call) {
  // The first operand of window function is actually a query, skip that.
  final List<SqlNode> operands = Util.skip(call.getOperandList(), 1);
  final List<RexNode> exprs = convertExpressionList(cx, operands,
      SqlOperandTypeChecker.Consistency.NONE);
  RelDataType returnType =
      cx.getValidator().getValidatedNodeTypeIfKnown(call);
  if (returnType == null) {
    returnType = cx.getRexBuilder().deriveReturnType(fun, exprs);
  }
  return cx.getRexBuilder().makeCall(returnType, fun, exprs);
}
/**
 * Converts a call to {@code JSON_VALUE}.
 *
 * When the call carries an explicit type specification, the type
 * specification operands are removed before conversion; the validated
 * type of the call is used as the return type.
 *
 * @param cx Context
 * @param fun JSON_VALUE function
 * @param call Call being converted
 * @return call to {@code fun}
 * @throws NullPointerException if the validator does not know the type of
 *     the call
 */
public RexNode convertJsonValueFunction(
    SqlRexContext cx,
    SqlJsonValueFunction fun,
    SqlCall call) {
  // For Expression with explicit return type:
  // i.e. json_value('{"foo":"bar"}', 'lax $.foo', returning varchar(2000))
  // use the specified type as the return type.
  List<SqlNode> operands = call.getOperandList();
  final boolean hasExplicitReturningType =
      SqlJsonValueFunction.hasExplicitTypeSpec(
          operands.toArray(SqlNode.EMPTY_ARRAY));
  if (hasExplicitReturningType) {
    operands = SqlJsonValueFunction.removeTypeSpecOperands(call);
  }
  final List<RexNode> exprs = convertExpressionList(cx, operands,
      SqlOperandTypeChecker.Consistency.NONE);
  // getValidatedNodeTypeIfKnown may return null; fail here with a clear
  // message instead of handing a null type to the RexBuilder.
  final RelDataType returnType = Objects.requireNonNull(
      cx.getValidator().getValidatedNodeTypeIfKnown(call),
      () -> "Unable to get type of " + call);
  return cx.getRexBuilder().makeCall(returnType, fun, exprs);
}
/**
 * Converts a sequence value operator call into a call whose single operand
 * is a string literal built from the names of the sequence identifier.
 *
 * @param cx Context
 * @param fun Sequence value operator
 * @param call Call whose only operand is the sequence identifier
 * @return call to {@code fun}
 */
public RexNode convertSequenceValue(
    SqlRexContext cx,
    SqlSequenceValueOperator fun,
    SqlCall call) {
  final List<SqlNode> operands = call.getOperandList();
  assert operands.size() == 1;
  assert operands.get(0) instanceof SqlIdentifier;
  final SqlIdentifier id = (SqlIdentifier) operands.get(0);
  final String key = Util.listToString(id.names);
  final RelDataType returnType =
      cx.getValidator().getValidatedNodeType(call);
  return cx.getRexBuilder().makeCall(returnType, fun,
      ImmutableList.of(cx.getRexBuilder().makeLiteral(key)));
}
/**
 * Converts a call to an aggregate function.
 *
 * {@code COUNT(*)} is converted with an empty operand list. When the
 * validator does not know the type of the call, the return type is
 * inferred from a binding that reports the context's group count.
 *
 * @param cx Context
 * @param fun Aggregate function
 * @param call Call being converted
 * @return call to {@code fun}
 */
public RexNode convertAggregateFunction(
    SqlRexContext cx,
    SqlAggFunction fun,
    SqlCall call) {
  final List<SqlNode> operands = call.getOperandList();
  final List<RexNode> exprs;
  if (call.isCountStar()) {
    exprs = ImmutableList.of();
  } else {
    exprs = convertExpressionList(cx, operands,
        SqlOperandTypeChecker.Consistency.NONE);
  }
  RelDataType returnType =
      cx.getValidator().getValidatedNodeTypeIfKnown(call);
  final int groupCount = cx.getGroupCount();
  if (returnType == null) {
    final RexCallBinding binding =
        new RexCallBinding(cx.getTypeFactory(), fun, exprs,
            ImmutableList.of()) {
          @Override public int getGroupCount() {
            return groupCount;
          }
        };
    returnType = fun.inferReturnType(binding);
  }
  return cx.getRexBuilder().makeCall(returnType, fun, exprs);
}
/**
 * Builds a NEW invocation for a user-defined constructor.
 *
 * One initializer expression is requested from the context's initializer
 * expression factory for each field of the constructed type; the results
 * are cast to the field types and passed to the invocation.
 *
 * @param cx Context
 * @param constructor Constructor function
 * @param exprs Converted constructor arguments
 * @return new-invocation expression
 */
private static RexNode makeConstructorCall(
    SqlRexContext cx,
    SqlFunction constructor,
    List<RexNode> exprs) {
  final RexBuilder rexBuilder = cx.getRexBuilder();
  final RelDataType type = rexBuilder.deriveReturnType(constructor, exprs);

  final int n = type.getFieldCount();
  final ImmutableList.Builder<RexNode> initializationExprs =
      ImmutableList.builder();
  // Minimal context: only the RexBuilder is available to initializers;
  // validation and expression conversion are not supported here.
  final InitializerContext initializerContext = new InitializerContext() {
    @Override public RexBuilder getRexBuilder() {
      return rexBuilder;
    }

    @Override public SqlNode validateExpression(RelDataType rowType,
        SqlNode expr) {
      throw new UnsupportedOperationException();
    }

    @Override public RexNode convertExpression(SqlNode e) {
      throw new UnsupportedOperationException();
    }
  };
  for (int i = 0; i < n; ++i) {
    initializationExprs.add(
        cx.getInitializerExpressionFactory().newAttributeInitializer(
            type, constructor, i, exprs, initializerContext));
  }

  final List<RexNode> defaultCasts =
      RexUtil.generateCastExpressions(
          rexBuilder,
          type,
          initializationExprs.build());

  return rexBuilder.makeNewInvocation(type, defaultCasts);
}
/**
 * Converts a call to an operator into a {@link RexCall} to the same
 * operator.
 *
 * <p>Called automatically via reflection.
 *
 * @param cx Context
 * @param call Call
 * @return Rex call
 */
public RexNode convertCall(
    SqlRexContext cx,
    SqlCall call) {
  final SqlOperator operator = call.getOperator();
  return convertCall(cx, operator, call.getOperandList());
}
/**
 * Converts a {@link SqlCall} to a {@link RexCall} with a perhaps different
 * operator.
 *
 * Operands are converted using the consistency policy of the operator's
 * operand type checker ({@code NONE} when it has no checker), the return
 * type is derived from the converted operands, and the operand list is
 * flattened with respect to {@code op} before the call is built.
 *
 * @param cx Context
 * @param op Operator of the resulting call
 * @param operands Operands to convert
 * @return Rex call
 */
private RexNode convertCall(
    SqlRexContext cx, SqlOperator op, List<SqlNode> operands) {
  final RexBuilder rexBuilder = cx.getRexBuilder();
  final SqlOperandTypeChecker checker = op.getOperandTypeChecker();
  final SqlOperandTypeChecker.Consistency consistency =
      checker == null
          ? SqlOperandTypeChecker.Consistency.NONE
          : checker.getConsistency();
  final List<RexNode> exprs =
      convertExpressionList(cx, operands, consistency);
  final RelDataType type = rexBuilder.deriveReturnType(op, exprs);
  return rexBuilder.makeCall(type, op, RexUtil.flatten(exprs, op));
}
/**
 * Returns the indexes of the value-producing operands of a CASE operand
 * list of the given size: every second index starting at
 * {@code count % 2}, followed by the last index.
 *
 * For an odd count, e.g. operands [0, 1, 2, 3, 4], the result is
 * [1, 3, 4] (the THEN operands and the ELSE operand). For an even count,
 * e.g. [0, 1, 2, 3, 4, 5], the result is [0, 2, 4, 5].
 *
 * @param count Number of operands; the result is only meaningful for
 *     {@code count >= 2}
 * @return indexes in ascending order
 */
private static List<Integer> elseArgs(int count) {
  final List<Integer> list = new ArrayList<>();
  for (int i = count % 2;;) {
    list.add(i);
    i += 2;
    if (i >= count) {
      // Stepping past the end: close the list with the index before i.
      list.add(i - 1);
      break;
    }
  }
  return list;
}
/**
 * Converts a list of parse-tree nodes to row expressions.
 *
 * When there is more than one expression and the consistency policy
 * yields a common type, every expression is coerced to that type.
 *
 * @param cx Context
 * @param nodes Nodes to convert
 * @param consistency Consistency policy for the operand types
 * @return mutable list of converted expressions, in operand order
 */
private static List<RexNode> convertExpressionList(SqlRexContext cx,
    List<SqlNode> nodes, SqlOperandTypeChecker.Consistency consistency) {
  final List<RexNode> exprs = new ArrayList<>();
  for (SqlNode node : nodes) {
    exprs.add(cx.convertExpression(node));
  }
  if (exprs.size() > 1) {
    final RelDataType type =
        consistentType(cx, consistency, RexUtil.types(exprs));
    if (type != null) {
      // Coerce each expression in place to the common type.
      exprs.replaceAll(expr ->
          cx.getRexBuilder().ensureType(type, expr, true));
    }
  }
  return exprs;
}
/**
 * Chooses the type to which a list of operand types should be coerced
 * under the given consistency policy.
 *
 * For {@code COMPARE}: returns null if all types are already in the same
 * family; otherwise character types are left out of consideration when at
 * least one non-character type exists (adding BIGINT as headroom when
 * character types were dropped and the first remaining type is in the
 * INTEGER or NUMERIC family), and the least restrictive of the remaining
 * types is returned. For {@code LEAST_RESTRICTIVE}: returns the least
 * restrictive of all types. For any other policy: returns null.
 *
 * @param cx Context
 * @param consistency Consistency policy
 * @param types Operand types
 * @return type to coerce to, or null if no coercion is required
 */
private static RelDataType consistentType(SqlRexContext cx,
    SqlOperandTypeChecker.Consistency consistency, List<RelDataType> types) {
  switch (consistency) {
  case COMPARE:
    if (SqlTypeUtil.areSameFamily(types)) {
      // All arguments are of same family. No need for explicit casts.
      return null;
    }
    final List<RelDataType> nonCharacterTypes = new ArrayList<>();
    for (RelDataType type : types) {
      if (type.getFamily() != SqlTypeFamily.CHARACTER) {
        nonCharacterTypes.add(type);
      }
    }
    if (!nonCharacterTypes.isEmpty()) {
      final int typeCount = types.size();
      types = nonCharacterTypes;
      if (nonCharacterTypes.size() < typeCount) {
        final RelDataTypeFamily family =
            nonCharacterTypes.get(0).getFamily();
        if (family instanceof SqlTypeFamily) {
          // The character arguments might be larger than the numeric
          // argument. Give ourselves some headroom.
          switch ((SqlTypeFamily) family) {
          case INTEGER:
          case NUMERIC:
            nonCharacterTypes.add(
                cx.getTypeFactory().createSqlType(SqlTypeName.BIGINT));
            break;
          default:
            break;
          }
        }
      }
    }
    // fall through
  case LEAST_RESTRICTIVE:
    return cx.getTypeFactory().leastRestrictive(types);
  default:
    return null;
  }
}
/**
 * Converts a call to "+".
 *
 * The call is first converted generically. If the result type is DATE,
 * TIME or TIMESTAMP, the call is rebuilt with the {@code DATETIME_PLUS}
 * operator, swapping the two operands when the first one is an interval
 * so that the interval comes second.
 *
 * @param cx Context
 * @param call Call to "+"
 * @return converted expression
 */
private RexNode convertPlus(SqlRexContext cx, SqlCall call) {
  final RexNode rex = convertCall(cx, call);
  switch (rex.getType().getSqlTypeName()) {
  case DATE:
  case TIME:
  case TIMESTAMP:
    // Use special "+" operator for datetime + interval.
    // Re-order operands, if necessary, so that interval is second.
    final RexBuilder rexBuilder = cx.getRexBuilder();
    List<RexNode> operands = ((RexCall) rex).getOperands();
    if (operands.size() == 2) {
      final SqlTypeName sqlTypeName = operands.get(0).getType().getSqlTypeName();
      switch (sqlTypeName) {
      case INTERVAL_YEAR:
      case INTERVAL_YEAR_MONTH:
      case INTERVAL_MONTH:
      case INTERVAL_DAY:
      case INTERVAL_DAY_HOUR:
      case INTERVAL_DAY_MINUTE:
      case INTERVAL_DAY_SECOND:
      case INTERVAL_HOUR:
      case INTERVAL_HOUR_MINUTE:
      case INTERVAL_HOUR_SECOND:
      case INTERVAL_MINUTE:
      case INTERVAL_MINUTE_SECOND:
      case INTERVAL_SECOND:
        operands = ImmutableList.of(operands.get(1), operands.get(0));
        break;
      default:
        break;
      }
    }
    return rexBuilder.makeCall(rex.getType(),
        SqlStdOperatorTable.DATETIME_PLUS, operands);
  default:
    return rex;
  }
}
/**
 * Converts IS [NOT] DISTINCT FROM by converting both operands and handing
 * them, together with the {@code neg} flag, to
 * {@code RelOptUtil.isDistinctFrom}.
 */
private RexNode convertIsDistinctFrom(
    SqlRexContext cx,
    SqlCall call,
    boolean neg) {
  final RexNode left = cx.convertExpression(call.operand(0));
  final RexNode right = cx.convertExpression(call.operand(1));
  final RexBuilder rexBuilder = cx.getRexBuilder();
  return RelOptUtil.isDistinctFrom(rexBuilder, left, right, neg);
}
/**
 * Converts a BETWEEN expression.
 *
 * With x the value, y the lower and z the upper operand, ASYMMETRIC
 * becomes {@code x >= y AND x <= z}; SYMMETRIC additionally ORs in
 * {@code x >= z AND x <= y}. The result is wrapped in NOT when the
 * operator of the call is negated.
 *
 * <p>Called automatically via reflection.
 *
 * @param cx Context
 * @param op BETWEEN operator
 * @param call Call being converted
 * @return boolean expression equivalent to the BETWEEN call
 */
public RexNode convertBetween(
    SqlRexContext cx,
    SqlBetweenOperator op,
    SqlCall call) {
  final List<RexNode> list =
      convertExpressionList(cx, call.getOperandList(),
          op.getOperandTypeChecker().getConsistency());
  final RexNode x = list.get(SqlBetweenOperator.VALUE_OPERAND);
  final RexNode y = list.get(SqlBetweenOperator.LOWER_OPERAND);
  final RexNode z = list.get(SqlBetweenOperator.UPPER_OPERAND);

  final RexBuilder rexBuilder = cx.getRexBuilder();
  final RexNode ge1 = ge(rexBuilder, x, y);
  final RexNode le1 = le(rexBuilder, x, z);
  final RexNode and1 = and(rexBuilder, ge1, le1);

  RexNode res;
  final SqlBetweenOperator.Flag symmetric = op.flag;
  switch (symmetric) {
  case ASYMMETRIC:
    res = and1;
    break;
  case SYMMETRIC:
    // Also accept the bounds given in the opposite order.
    final RexNode ge2 = ge(rexBuilder, x, z);
    final RexNode le2 = le(rexBuilder, x, y);
    final RexNode and2 = and(rexBuilder, ge2, le2);
    res = or(rexBuilder, and1, and2);
    break;
  default:
    throw Util.unexpected(symmetric);
  }
  final SqlBetweenOperator betweenOp =
      (SqlBetweenOperator) call.getOperator();
  if (betweenOp.isNegated()) {
    res = rexBuilder.makeCall(SqlStdOperatorTable.NOT, res);
  }
  return res;
}
/**
 * Converts a LiteralChain expression: that is, concatenates the operands
 * immediately, to produce a single literal string.
 *
 * <p>Called automatically via reflection.
 */
public RexNode convertLiteralChain(
    SqlRexContext cx,
    SqlLiteralChainOperator op,
    SqlCall call) {
  Util.discard(cx);

  final SqlLiteral concatenated =
      SqlLiteralChainOperator.concatenateOperands(call);
  return cx.convertLiteral(concatenated);
}
/**
 * Converts a ROW.
 *
 * <p>Unless the validated type of the call is {@code COLUMN_LIST}, the call
 * is converted by {@code convertCall}. For a column list, every operand is
 * cast to {@link SqlIdentifier} and its simple name becomes a literal
 * operand of a {@code COLUMN_LIST} call.
 *
 * <p>Called automatically via reflection.
 *
 * @param cx   conversion context
 * @param op   ROW operator (not consulted by this method)
 * @param call the ROW call being converted
 * @return the converted expression
 */
public RexNode convertRow(
    SqlRexContext cx,
    SqlRowOperator op,
    SqlCall call) {
  if (cx.getValidator().getValidatedNodeType(call).getSqlTypeName()
      != SqlTypeName.COLUMN_LIST) {
    return convertCall(cx, call);
  }
  final RexBuilder rexBuilder = cx.getRexBuilder();
  // Parameterized rather than raw, so only RexNode values can be added.
  final List<RexNode> columns = new ArrayList<>();
  for (SqlNode operand : call.getOperandList()) {
    columns.add(
        rexBuilder.makeLiteral(
            ((SqlIdentifier) operand).getSimple()));
  }
  final RelDataType type =
      rexBuilder.deriveReturnType(SqlStdOperatorTable.COLUMN_LIST, columns);
  return rexBuilder.makeCall(type, SqlStdOperatorTable.COLUMN_LIST, columns);
}
/**
 * Converts a call to OVERLAPS, and to the other period predicates that
 * share its operator class (CONTAINS, EQUALS, PRECEDES, SUCCEEDS and their
 * IMMEDIATELY variants).
 *
 * <p>Each operand is first turned into a pair of end points, which are then
 * ordered into a start and an end before the predicate is expressed as
 * comparisons between those four values.
 *
 * <p>Called automatically via reflection.
 *
 * @param cx   conversion context
 * @param op   operator whose {@code kind} selects the predicate
 * @param call call with exactly two operands
 * @return the boolean expression implementing the predicate
 * @throws AssertionError if {@code op.kind} is not one of the handled kinds
 */
public RexNode convertOverlaps(
    SqlRexContext cx,
    SqlOverlapsOperator op,
    SqlCall call) {
  // for intervals [t0, t1] overlaps [t2, t3], we can find if the
  // intervals overlaps by: ~(t1 < t2 or t3 < t0)
  assert call.getOperandList().size() == 2;

  // Parameterized rather than raw: both components are read as RexNode.
  final Pair<RexNode, RexNode> left =
      convertOverlapsOperand(cx, call.getParserPosition(), call.operand(0));
  final RexNode r0 = left.left;
  final RexNode r1 = left.right;
  final Pair<RexNode, RexNode> right =
      convertOverlapsOperand(cx, call.getParserPosition(), call.operand(1));
  final RexNode r2 = right.left;
  final RexNode r3 = right.right;

  // Sort end points into start and end, such that (s0 <= e0) and (s1 <= e1).
  final RexBuilder rexBuilder = cx.getRexBuilder();
  final RexNode leftSwap = le(rexBuilder, r0, r1);
  final RexNode s0 = case_(rexBuilder, leftSwap, r0, r1);
  final RexNode e0 = case_(rexBuilder, leftSwap, r1, r0);
  final RexNode rightSwap = le(rexBuilder, r2, r3);
  final RexNode s1 = case_(rexBuilder, rightSwap, r2, r3);
  final RexNode e1 = case_(rexBuilder, rightSwap, r3, r2);

  switch (op.kind) {
  case OVERLAPS:
    // (e0 >= s1) AND (e1 >= s0)
    return and(rexBuilder,
        ge(rexBuilder, e0, s1),
        ge(rexBuilder, e1, s0));
  case CONTAINS:
    // (s0 <= s1) AND (e0 >= e1)
    return and(rexBuilder,
        le(rexBuilder, s0, s1),
        ge(rexBuilder, e0, e1));
  case PERIOD_EQUALS:
    return and(rexBuilder,
        eq(rexBuilder, s0, s1),
        eq(rexBuilder, e0, e1));
  case PRECEDES:
    return le(rexBuilder, e0, s1);
  case IMMEDIATELY_PRECEDES:
    return eq(rexBuilder, e0, s1);
  case SUCCEEDS:
    return ge(rexBuilder, s0, e1);
  case IMMEDIATELY_SUCCEEDS:
    return eq(rexBuilder, s0, e1);
  default:
    throw new AssertionError(op);
  }
}
/**
 * Converts one operand of a period predicate into its two end points.
 *
 * <p>A ROW operand {@code (a, b)} yields the converted {@code a} and
 * {@code b}, except that when the validated type of {@code b} is an
 * interval the second end point is {@code a + b}. Any other operand is
 * used for both end points and is therefore converted twice.
 *
 * @param cx      conversion context
 * @param pos     parser position given to the synthesized {@code a + b} call
 * @param operand operand to split into end points
 * @return pair of converted end points, in operand order (not sorted)
 */
private Pair<RexNode, RexNode> convertOverlapsOperand(SqlRexContext cx,
    SqlParserPos pos, SqlNode operand) {
  final SqlNode a0;
  final SqlNode a1;
  switch (operand.getKind()) {
  case ROW:
    a0 = ((SqlCall) operand).operand(0);
    final SqlNode a10 = ((SqlCall) operand).operand(1);
    final RelDataType t1 = cx.getValidator().getValidatedNodeType(a10);
    if (SqlTypeUtil.isInterval(t1)) {
      // make t1 = t0 + t1 when t1 is an interval.
      a1 = plus(pos, a0, a10);
    } else {
      a1 = a10;
    }
    break;
  default:
    a0 = operand;
    a1 = operand;
  }

  final RexNode r0 = cx.convertExpression(a0);
  final RexNode r1 = cx.convertExpression(a1);
  return Pair.of(r0, r1);
}
/**
 * Casts a RexNode value to the validated type of a SqlCall, using the
 * validator and the rex builder of the given context. A value that already
 * has the validated type comes back as is, with no extra cast.
 *
 * @param cx    context supplying the validator and the rex builder
 * @param call  call whose validated type is the target type
 * @param value expression to cast
 * @return {@code value}, possibly wrapped in a cast
 */
public RexNode castToValidatedType(
    SqlRexContext cx,
    SqlCall call,
    RexNode value) {
  // Delegates to the static overload; the validator is fetched first.
  final RexNode result =
      castToValidatedType(
          call,
          value,
          cx.getValidator(),
          cx.getRexBuilder());
  return result;
}
/**
 * Casts a RexNode value to the validated type of a SqlNode. If the value
 * was already of the validated type, then the value is returned without an
 * additional cast.
 *
 * @param node       node whose validated type is the target type
 * @param e          expression to cast
 * @param validator  validator that knows the type of {@code node}
 * @param rexBuilder builder used to create the cast
 * @return {@code e} itself when its type equals the validated type,
 *         otherwise a cast of {@code e} to that type
 */
public static RexNode castToValidatedType(SqlNode node, RexNode e,
    SqlValidator validator, RexBuilder rexBuilder) {
  final RelDataType type = validator.getValidatedNodeType(node);
  // Compare by value, as getCastedSqlNode does elsewhere in this file.
  // An identity check would add a redundant cast whenever two equal types
  // happen to be distinct instances.
  if (e.getType().equals(type)) {
    return e;
  }
  return rexBuilder.makeCast(type, e);
}
/** Convertlet that handles {@code COVAR_POP}, {@code COVAR_SAMP},
 * {@code REGR_SXX}, {@code REGR_SYY} windowed aggregate functions.
 *
 * <p>The call is rewritten as a SQL expression over {@code SUM} and
 * {@code REGR_COUNT}; that expression is converted and then coerced to the
 * validated type of the original call.
 */
private static class RegrCovarianceConvertlet implements SqlRexConvertlet {
  /** Aggregate kind expanded by this instance. */
  private final SqlKind kind;

  RegrCovarianceConvertlet(SqlKind kind) {
    this.kind = kind;
  }

  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
    assert call.operandCount() == 2;
    final SqlNode first = call.operand(0);
    final SqlNode second = call.operand(1);
    final RelDataType resultType =
        cx.getValidator().getValidatedNodeType(call);
    final SqlNode expanded = expand(first, second, resultType, cx);
    final RexNode converted = cx.convertExpression(expanded);
    return cx.getRexBuilder().ensureType(resultType, converted, true);
  }

  /** Picks the expansion that matches {@link #kind}. */
  private SqlNode expand(SqlNode first, SqlNode second,
      RelDataType resultType, SqlRexContext cx) {
    switch (kind) {
    case COVAR_POP:
      return expandCovariance(first, second, null, resultType, cx, true);
    case COVAR_SAMP:
      return expandCovariance(first, second, null, resultType, cx, false);
    case REGR_SXX:
      // SXX passes the operands in swapped order.
      return expandRegrSzz(second, first, resultType, cx, true);
    case REGR_SYY:
      return expandRegrSzz(first, second, resultType, cx, true);
    default:
      throw Util.unexpected(kind);
    }
  }

  /** Builds {@code covariance * REGR_COUNT(arg1, arg2)}, where the
   * covariance is the biased expansion with {@code arg2} as dependent. */
  private SqlNode expandRegrSzz(
      final SqlNode arg1, final SqlNode arg2,
      final RelDataType avgType, final SqlRexContext cx, boolean variance) {
    final SqlParserPos zero = SqlParserPos.ZERO;
    final SqlNode pairCount =
        SqlStdOperatorTable.REGR_COUNT.createCall(zero, arg1, arg2);
    final SqlNode secondArg = variance ? arg1 : arg2;
    final SqlNode covariance =
        expandCovariance(arg1, secondArg, arg2, avgType, cx, true);
    final RexNode covarianceRex = cx.convertExpression(covariance);
    final SqlNode covarianceCast =
        getCastedSqlNode(covariance, avgType, zero, covarianceRex);
    return SqlStdOperatorTable.MULTIPLY.createCall(
        zero, covarianceCast, pairCount);
  }

  private SqlNode expandCovariance(
      final SqlNode arg0Input,
      final SqlNode arg1Input,
      final SqlNode dependent,
      final RelDataType varType,
      final SqlRexContext cx,
      boolean biased) {
    // covar_pop(x1, x2) ==>
    //     (sum(x1 * x2) - sum(x2) * sum(x1) / count(x1, x2))
    //     / count(x1, x2)
    //
    // covar_samp(x1, x2) ==>
    //     (sum(x1 * x2) - sum(x1) * sum(x2) / count(x1, x2))
    //     / (count(x1, x2) - 1)
    final SqlParserPos zero = SqlParserPos.ZERO;

    // Both inputs are converted before either one is cast.
    final RexNode rex0 = cx.convertExpression(arg0Input);
    final RexNode rex1 = cx.convertExpression(arg1Input);
    final SqlNode x0 = getCastedSqlNode(arg0Input, varType, zero, rex0);
    final SqlNode x1 = getCastedSqlNode(arg1Input, varType, zero, rex1);
    final SqlNode product =
        SqlStdOperatorTable.MULTIPLY.createCall(zero, x0, x1);

    final SqlNode sumOfProducts;
    final SqlNode sumX0;
    final SqlNode sumX1;
    final SqlNode pairCount;
    if (dependent != null) {
      // The second operand is the other argument when the dependent is
      // the argument itself, and the dependent otherwise.
      final SqlNode partner0 =
          Objects.equals(dependent, arg0Input) ? x1 : dependent;
      final SqlNode partner1 =
          Objects.equals(dependent, arg1Input) ? x0 : dependent;
      sumOfProducts =
          SqlStdOperatorTable.SUM.createCall(zero, product, dependent);
      sumX0 = SqlStdOperatorTable.SUM.createCall(zero, x0, partner0);
      sumX1 = SqlStdOperatorTable.SUM.createCall(zero, x1, partner1);
      pairCount =
          SqlStdOperatorTable.REGR_COUNT.createCall(zero, x0, partner0);
    } else {
      sumOfProducts = SqlStdOperatorTable.SUM.createCall(zero, product);
      sumX0 = SqlStdOperatorTable.SUM.createCall(zero, x0, x1);
      sumX1 = SqlStdOperatorTable.SUM.createCall(zero, x1, x0);
      pairCount = SqlStdOperatorTable.REGR_COUNT.createCall(zero, x0, x1);
    }

    final SqlNode productOfSums =
        SqlStdOperatorTable.MULTIPLY.createCall(zero, sumX0, sumX1);
    final SqlNode countCasted =
        getCastedSqlNode(pairCount, varType, zero,
            cx.convertExpression(pairCount));
    final SqlNode correction =
        SqlStdOperatorTable.DIVIDE.createCall(
            zero, productOfSums, countCasted);
    final SqlNode numerator =
        SqlStdOperatorTable.MINUS.createCall(
            zero, sumOfProducts, correction);

    final SqlNode denominator = biased
        ? countCasted
        : sampleDenominator(countCasted, varType, zero);
    return SqlStdOperatorTable.DIVIDE.createCall(
        zero, numerator, denominator);
  }

  /** Builds the CASE node used as denominator of the sample covariance:
   * NULL when the count equals 1, otherwise {@code count - 1}. */
  private SqlNode sampleDenominator(SqlNode countCasted,
      RelDataType varType, SqlParserPos zero) {
    final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", zero);
    final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
    // A null rex makes getCastedSqlNode hand the literal back unchanged.
    return new SqlCase(SqlParserPos.ZERO, countCasted,
        SqlNodeList.of(
            SqlStdOperatorTable.EQUALS.createCall(zero, countCasted, one)),
        SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, zero, null)),
        SqlStdOperatorTable.MINUS.createCall(zero, countCasted, one));
  }

  /** Wraps {@code argInput} in a CAST to {@code varType} when a converted
   * form is supplied and its type differs; otherwise returns it as is. */
  private SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType,
      SqlParserPos pos, RexNode argRex) {
    if (argRex == null || argRex.getType().equals(varType)) {
      return argInput;
    }
    return SqlStdOperatorTable.CAST.createCall(
        pos, argInput, SqlTypeUtil.convertTypeToSpec(varType));
  }
}
/** Convertlet that handles {@code AVG} and {@code VARIANCE}
 * windowed aggregate functions.
 *
 * <p>The call is rewritten as a SQL expression over {@code SUM} and
 * {@code COUNT}; that expression is converted and then coerced to the
 * validated type of the original call. */
private static class AvgVarianceConvertlet implements SqlRexConvertlet {
  /** Aggregate kind expanded by this instance. */
  private final SqlKind kind;

  AvgVarianceConvertlet(SqlKind kind) {
    this.kind = kind;
  }

  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
    assert call.operandCount() == 1;
    final SqlNode operand = call.operand(0);
    final RelDataType resultType =
        cx.getValidator().getValidatedNodeType(call);
    final SqlNode expanded = expand(operand, resultType, cx);
    final RexNode converted = cx.convertExpression(expanded);
    return cx.getRexBuilder().ensureType(resultType, converted, true);
  }

  /** Picks the expansion that matches {@link #kind}. */
  private SqlNode expand(SqlNode operand, RelDataType resultType,
      SqlRexContext cx) {
    switch (kind) {
    case AVG:
      return expandAvg(operand, resultType, cx);
    case STDDEV_POP:
      return expandVariance(operand, resultType, cx, true, true);
    case STDDEV_SAMP:
      return expandVariance(operand, resultType, cx, false, true);
    case VAR_POP:
      return expandVariance(operand, resultType, cx, true, false);
    case VAR_SAMP:
      return expandVariance(operand, resultType, cx, false, false);
    default:
      throw Util.unexpected(kind);
    }
  }

  /** Expands {@code avg(x)} to {@code sum(x) / count(x)}, with the sum
   * cast to {@code avgType} when its converted type differs. */
  private SqlNode expandAvg(
      final SqlNode arg, final RelDataType avgType, final SqlRexContext cx) {
    final SqlParserPos zero = SqlParserPos.ZERO;
    final SqlNode total = SqlStdOperatorTable.SUM.createCall(zero, arg);
    final SqlNode totalCast = coerce(total, avgType, cx);
    final SqlNode rowCount = SqlStdOperatorTable.COUNT.createCall(zero, arg);
    return SqlStdOperatorTable.DIVIDE.createCall(zero, totalCast, rowCount);
  }

  private SqlNode expandVariance(
      final SqlNode argInput,
      final RelDataType varType,
      final SqlRexContext cx,
      boolean biased,
      boolean sqrt) {
    // stddev_pop(x) ==>
    //   power(
    //     (sum(x * x) - sum(x) * sum(x) / count(x))
    //     / count(x),
    //     .5)
    //
    // stddev_samp(x) ==>
    //   power(
    //     (sum(x * x) - sum(x) * sum(x) / count(x))
    //     / (count(x) - 1),
    //     .5)
    //
    // var_pop(x) ==>
    //     (sum(x * x) - sum(x) * sum(x) / count(x))
    //     / count(x)
    //
    // var_samp(x) ==>
    //     (sum(x * x) - sum(x) * sum(x) / count(x))
    //     / (count(x) - 1)
    final SqlParserPos zero = SqlParserPos.ZERO;

    // Every intermediate node is converted and, if needed, cast to varType
    // before it is used to build the next node.
    final SqlNode x = coerce(argInput, varType, cx);
    final SqlNode xSquared =
        coerce(SqlStdOperatorTable.MULTIPLY.createCall(zero, x, x),
            varType, cx);
    final SqlNode sumOfSquares =
        coerce(SqlStdOperatorTable.SUM.createCall(zero, xSquared),
            varType, cx);
    final SqlNode total =
        coerce(SqlStdOperatorTable.SUM.createCall(zero, x), varType, cx);
    final SqlNode totalSquared =
        coerce(SqlStdOperatorTable.MULTIPLY.createCall(zero, total, total),
            varType, cx);
    final SqlNode rowCount = SqlStdOperatorTable.COUNT.createCall(zero, x);
    final SqlNode rowCountCasted = coerce(rowCount, varType, cx);
    final SqlNode correction =
        coerce(
            SqlStdOperatorTable.DIVIDE.createCall(
                zero, totalSquared, rowCountCasted),
            varType, cx);
    final SqlNode numerator =
        coerce(
            SqlStdOperatorTable.MINUS.createCall(
                zero, sumOfSquares, correction),
            varType, cx);

    final SqlNode denominator;
    if (!biased) {
      // The sample form works on the uncast count: NULL when the count
      // equals 1, otherwise count - 1.
      final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", zero);
      final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
      // A null rex makes getCastedSqlNode hand the literal back unchanged.
      denominator = new SqlCase(SqlParserPos.ZERO,
          rowCount,
          SqlNodeList.of(
              SqlStdOperatorTable.EQUALS.createCall(zero, rowCount, one)),
          SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, zero, null)),
          SqlStdOperatorTable.MINUS.createCall(zero, rowCount, one));
    } else {
      denominator = rowCountCasted;
    }

    final SqlNode quotient =
        SqlStdOperatorTable.DIVIDE.createCall(zero, numerator, denominator);
    // The quotient is converted in both modes; only the square-root form
    // goes on to use the cast result.
    final SqlNode quotientCasted = coerce(quotient, varType, cx);
    if (!sqrt) {
      return quotient;
    }
    final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", zero);
    return SqlStdOperatorTable.POWER.createCall(zero, quotientCasted, half);
  }

  /** Converts {@code node} and casts it to {@code type} if the converted
   * type differs. */
  private SqlNode coerce(SqlNode node, RelDataType type, SqlRexContext cx) {
    return getCastedSqlNode(node, type, SqlParserPos.ZERO,
        cx.convertExpression(node));
  }

  /** Wraps {@code argInput} in a CAST to {@code varType} when a converted
   * form is supplied and its type differs; otherwise returns it as is. */
  private SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType,
      SqlParserPos pos, RexNode argRex) {
    if (argRex == null || argRex.getType().equals(varType)) {
      return argInput;
    }
    return SqlStdOperatorTable.CAST.createCall(
        pos, argInput, SqlTypeUtil.convertTypeToSpec(varType));
  }
}
/** Convertlet that converts {@code LTRIM} and {@code RTRIM} to
 * {@code TRIM}, trimming a single space character on the side selected by
 * the flag given at construction. */
private static class TrimConvertlet implements SqlRexConvertlet {
  /** Flag passed as first operand of the generated TRIM call. */
  private final SqlTrimFunction.Flag flag;

  TrimConvertlet(SqlTrimFunction.Flag flag) {
    this.flag = flag;
  }

  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
    final RexBuilder builder = cx.getRexBuilder();
    // The value is converted first; flag and pad literal are built after.
    final RexNode value =
        cx.convertExpression(call.getOperandList().get(0));
    final RexNode side = builder.makeFlag(flag);
    final RexNode blank = builder.makeLiteral(" ");
    return builder.makeCall(SqlStdOperatorTable.TRIM, side, blank, value);
  }
}
/** Convertlet that converts {@code GREATEST} and {@code LEAST} into a
 * {@code CASE} expression that yields NULL as soon as any operand is
 * NULL. */
private static class GreatestConvertlet implements SqlRexConvertlet {
  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
    // Translate
    //   GREATEST(a, b, c, d)
    // to
    //   CASE
    //   WHEN a IS NULL OR b IS NULL OR c IS NULL OR d IS NULL
    //   THEN NULL
    //   WHEN a > b AND a > c AND a > d
    //   THEN a
    //   WHEN b > c AND b > d
    //   THEN b
    //   WHEN c > d
    //   THEN c
    //   ELSE d
    //   END
    //
    // LEAST is translated the same way, with < in place of >.
    final RexBuilder rexBuilder = cx.getRexBuilder();
    final RelDataType type =
        cx.getValidator().getValidatedNodeType(call);
    final SqlBinaryOperator op;
    switch (call.getKind()) {
    case GREATEST:
      op = SqlStdOperatorTable.GREATER_THAN;
      break;
    case LEAST:
      op = SqlStdOperatorTable.LESS_THAN;
      break;
    default:
      throw new AssertionError();
    }

    // Parameterized rather than raw: the elements are iterated as RexNode.
    final List<RexNode> exprs = convertExpressionList(cx,
        call.getOperandList(), SqlOperandTypeChecker.Consistency.NONE);

    // Operands of the CASE call: alternating WHEN and THEN, then ELSE.
    final List<RexNode> list = new ArrayList<>();

    // First branch: any NULL operand makes the result NULL.
    final List<RexNode> orList = new ArrayList<>();
    for (RexNode expr : exprs) {
      orList.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, expr));
    }
    list.add(RexUtil.composeDisjunction(rexBuilder, orList));
    list.add(rexBuilder.makeNullLiteral(type));

    // One branch per operand except the last: it wins if it beats every
    // operand that follows it.
    for (int i = 0; i < exprs.size() - 1; i++) {
      final RexNode expr = exprs.get(i);
      final List<RexNode> andList = new ArrayList<>();
      for (int j = i + 1; j < exprs.size(); j++) {
        final RexNode expr2 = exprs.get(j);
        andList.add(rexBuilder.makeCall(op, expr, expr2));
      }
      list.add(RexUtil.composeConjunction(rexBuilder, andList));
      list.add(expr);
    }

    // ELSE: the last operand.
    list.add(exprs.get(exprs.size() - 1));
    return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, list);
  }
}
/** Convertlet that handles {@code FLOOR} and {@code CEIL} functions by
 * forwarding the call to {@code convertFloorCeil}. */
private class FloorCeilConvertlet implements SqlRexConvertlet {
  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
    final RexNode converted = convertFloorCeil(cx, call);
    return converted;
  }
}
/** Convertlet that handles the {@code TIMESTAMPADD} function. */
private static class TimestampAddConvertlet implements SqlRexConvertlet {
  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
    // TIMESTAMPADD(unit, count, timestamp)
    //  => timestamp + count * INTERVAL '1' UNIT
    final RexBuilder builder = cx.getRexBuilder();
    final SqlLiteral unitLiteral = call.operand(0);
    final TimeUnit unit = unitLiteral.symbolValue(TimeUnit.class);
    final SqlIntervalQualifier qualifier =
        new SqlIntervalQualifier(unit, null, unitLiteral.getParserPosition());
    final RexNode count = cx.convertExpression(call.operand(1));

    final RexNode delta;
    if (unit == TimeUnit.MICROSECOND || unit == TimeUnit.NANOSECOND) {
      // Multiply a literal of 1 by the count, then divide by the
      // reciprocal of the unit's multiplier.
      final RexNode scaled =
          multiply(builder,
              builder.makeIntervalLiteral(BigDecimal.ONE, qualifier), count);
      final BigDecimal divisor =
          BigDecimal.ONE.divide(unit.multiplier, RoundingMode.UNNECESSARY);
      delta = divide(builder, scaled, divisor);
    } else {
      delta =
          multiply(builder,
              builder.makeIntervalLiteral(unit.multiplier, qualifier), count);
    }

    // The timestamp operand is converted last, after the interval is built.
    final RexNode timestamp = cx.convertExpression(call.operand(2));
    return builder.makeCall(
        SqlStdOperatorTable.DATETIME_PLUS, timestamp, delta);
  }
}
/** Convertlet that handles the {@code TIMESTAMPDIFF} function. */
private static class TimestampDiffConvertlet implements SqlRexConvertlet {
  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
    // TIMESTAMPDIFF(unit, t1, t2)
    //    => (t2 - t1) UNIT
    final RexBuilder builder = cx.getRexBuilder();
    final SqlLiteral unitLiteral = call.operand(0);
    final TimeUnit requestedUnit = unitLiteral.symbolValue(TimeUnit.class);

    BigDecimal multiplier = BigDecimal.ONE;
    BigDecimal divider = BigDecimal.ONE;

    // The result type depends on the unit that was asked for, not on the
    // unit substituted for it below.
    final SqlTypeName resultTypeName;
    if (requestedUnit == TimeUnit.NANOSECOND) {
      resultTypeName = SqlTypeName.BIGINT;
    } else {
      resultTypeName = SqlTypeName.INTEGER;
    }

    // Some units are computed in SECOND or MONTH and rescaled afterwards.
    TimeUnit intervalUnit = requestedUnit;
    switch (requestedUnit) {
    case MICROSECOND:
    case MILLISECOND:
    case NANOSECOND:
    case WEEK:
      multiplier = BigDecimal.valueOf(DateTimeUtils.MILLIS_PER_SECOND);
      divider = requestedUnit.multiplier;
      intervalUnit = TimeUnit.SECOND;
      break;
    case QUARTER:
      divider = requestedUnit.multiplier;
      intervalUnit = TimeUnit.MONTH;
      break;
    default:
      // Every other unit is used as is, with no rescaling.
      break;
    }

    final SqlIntervalQualifier qualifier =
        new SqlIntervalQualifier(intervalUnit, null, SqlParserPos.ZERO);

    // t2 is converted before t1.
    final RexNode later = cx.convertExpression(call.operand(2));
    final RexNode earlier = cx.convertExpression(call.operand(1));

    final boolean nullable =
        earlier.getType().isNullable() || later.getType().isNullable();
    final RelDataType intervalType =
        cx.getTypeFactory().createTypeWithNullability(
            cx.getTypeFactory().createSqlIntervalType(qualifier),
            nullable);
    final RexCall difference = (RexCall) builder.makeCall(
        intervalType, SqlStdOperatorTable.MINUS_DATE,
        ImmutableList.of(later, earlier));

    final RelDataType integerType =
        cx.getTypeFactory().createTypeWithNullability(
            cx.getTypeFactory().createSqlType(resultTypeName),
            SqlTypeUtil.containsNullable(difference.getType()));
    final RexNode asInteger = builder.makeCast(integerType, difference);
    return builder.multiplyDivide(asInteger, multiplier, divider);
  }
}
}