All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.hazelcast.org.apache.calcite.sql2rel.StandardConvertletTable Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.hazelcast.org.apache.calcite.sql2rel;

import com.hazelcast.org.apache.calcite.avatica.util.DateTimeUtils;
import com.hazelcast.org.apache.calcite.avatica.util.TimeUnit;
import com.hazelcast.org.apache.calcite.plan.RelOptUtil;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeFactory;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeFamily;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexCall;
import com.hazelcast.org.apache.calcite.rex.RexCallBinding;
import com.hazelcast.org.apache.calcite.rex.RexLiteral;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexRangeRef;
import com.hazelcast.org.apache.calcite.rex.RexUtil;
import com.hazelcast.org.apache.calcite.sql.SqlAggFunction;
import com.hazelcast.org.apache.calcite.sql.SqlBinaryOperator;
import com.hazelcast.org.apache.calcite.sql.SqlCall;
import com.hazelcast.org.apache.calcite.sql.SqlDataTypeSpec;
import com.hazelcast.org.apache.calcite.sql.SqlFunction;
import com.hazelcast.org.apache.calcite.sql.SqlFunctionCategory;
import com.hazelcast.org.apache.calcite.sql.SqlIdentifier;
import com.hazelcast.org.apache.calcite.sql.SqlIntervalLiteral;
import com.hazelcast.org.apache.calcite.sql.SqlIntervalQualifier;
import com.hazelcast.org.apache.calcite.sql.SqlJdbcFunctionCall;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.SqlLiteral;
import com.hazelcast.org.apache.calcite.sql.SqlNode;
import com.hazelcast.org.apache.calcite.sql.SqlNodeList;
import com.hazelcast.org.apache.calcite.sql.SqlNumericLiteral;
import com.hazelcast.org.apache.calcite.sql.SqlOperator;
import com.hazelcast.org.apache.calcite.sql.SqlUtil;
import com.hazelcast.org.apache.calcite.sql.SqlWindowTableFunction;
import com.hazelcast.org.apache.calcite.sql.fun.SqlArrayValueConstructor;
import com.hazelcast.org.apache.calcite.sql.fun.SqlBetweenOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlCase;
import com.hazelcast.org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlExtractFunction;
import com.hazelcast.org.apache.calcite.sql.fun.SqlJsonValueFunction;
import com.hazelcast.org.apache.calcite.sql.fun.SqlLibraryOperators;
import com.hazelcast.org.apache.calcite.sql.fun.SqlLiteralChainOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlMapValueConstructor;
import com.hazelcast.org.apache.calcite.sql.fun.SqlMultisetQueryConstructor;
import com.hazelcast.org.apache.calcite.sql.fun.SqlMultisetValueConstructor;
import com.hazelcast.org.apache.calcite.sql.fun.SqlOverlapsOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlRowOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlSequenceValueOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import com.hazelcast.org.apache.calcite.sql.fun.SqlTrimFunction;
import com.hazelcast.org.apache.calcite.sql.parser.SqlParserPos;
import com.hazelcast.org.apache.calcite.sql.type.SqlOperandTypeChecker;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeFamily;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeName;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeUtil;
import com.hazelcast.org.apache.calcite.sql.validate.SqlValidator;
import com.hazelcast.org.apache.calcite.sql.validate.SqlValidatorImpl;
import com.hazelcast.org.apache.calcite.util.Pair;
import com.hazelcast.org.apache.calcite.util.Util;

import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.com.google.common.collect.Lists;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Standard implementation of {@link SqlRexConvertletTable}.
 */
public class StandardConvertletTable extends ReflectiveConvertletTable {

  /** Singleton instance. */
  public static final StandardConvertletTable INSTANCE =
      new StandardConvertletTable();

  //~ Constructors -----------------------------------------------------------

  private StandardConvertletTable() {
    super();

    // Register aliases (operators which have a different name but
    // identical behavior to other operators).
    addAlias(SqlStdOperatorTable.CHARACTER_LENGTH,
        SqlStdOperatorTable.CHAR_LENGTH);
    addAlias(SqlStdOperatorTable.IS_UNKNOWN,
        SqlStdOperatorTable.IS_NULL);
    addAlias(SqlStdOperatorTable.IS_NOT_UNKNOWN,
        SqlStdOperatorTable.IS_NOT_NULL);
    addAlias(SqlStdOperatorTable.PERCENT_REMAINDER, SqlStdOperatorTable.MOD);

    // Register convertlets for specific objects.
    registerOp(SqlStdOperatorTable.CAST, this::convertCast);
    registerOp(SqlLibraryOperators.INFIX_CAST, this::convertCast);
    registerOp(SqlStdOperatorTable.IS_DISTINCT_FROM,
        (cx, call) -> convertIsDistinctFrom(cx, call, false));
    registerOp(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
        (cx, call) -> convertIsDistinctFrom(cx, call, true));

    registerOp(SqlStdOperatorTable.PLUS, this::convertPlus);

    registerOp(SqlStdOperatorTable.MINUS,
        (cx, call) -> {
          final RexCall e =
              (RexCall) StandardConvertletTable.this.convertCall(cx, call.getOperator(),
                  call.getOperandList());
          switch (e.getOperands().get(0).getType().getSqlTypeName()) {
          case DATE:
          case TIME:
          case TIMESTAMP:
            return convertDatetimeMinus(cx, SqlStdOperatorTable.MINUS_DATE,
                call);
          default:
            return e;
          }
        });

    registerOp(SqlLibraryOperators.LTRIM,
        new TrimConvertlet(SqlTrimFunction.Flag.LEADING));
    registerOp(SqlLibraryOperators.RTRIM,
        new TrimConvertlet(SqlTrimFunction.Flag.TRAILING));

    registerOp(SqlLibraryOperators.GREATEST, new GreatestConvertlet());
    registerOp(SqlLibraryOperators.LEAST, new GreatestConvertlet());

    registerOp(SqlLibraryOperators.NVL,
        (cx, call) -> {
          final RexBuilder rexBuilder = cx.getRexBuilder();
          final RexNode operand0 =
              cx.convertExpression(call.getOperandList().get(0));
          final RexNode operand1 =
              cx.convertExpression(call.getOperandList().get(1));
          final RelDataType type =
              cx.getValidator().getValidatedNodeType(call);
          return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE,
              ImmutableList.of(
                  rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
                      operand0),
                  rexBuilder.makeCast(type, operand0),
                  rexBuilder.makeCast(type, operand1)));
        });

    registerOp(SqlLibraryOperators.DECODE,
        (cx, call) -> {
          final RexBuilder rexBuilder = cx.getRexBuilder();
          final List operands = convertExpressionList(cx,
              call.getOperandList(), SqlOperandTypeChecker.Consistency.NONE);
          final RelDataType type =
              cx.getValidator().getValidatedNodeType(call);
          final List exprs = new ArrayList<>();
          for (int i = 1; i < operands.size() - 1; i += 2) {
            exprs.add(
                RelOptUtil.isDistinctFrom(rexBuilder, operands.get(0),
                    operands.get(i), true));
            exprs.add(operands.get(i + 1));
          }
          if (operands.size() % 2 == 0) {
            exprs.add(Util.last(operands));
          } else {
            exprs.add(rexBuilder.makeNullLiteral(type));
          }
          return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, exprs);
        });

    // Expand "x NOT LIKE y" into "NOT (x LIKE y)"
    registerOp(SqlStdOperatorTable.NOT_LIKE,
        (cx, call) -> cx.convertExpression(
            SqlStdOperatorTable.NOT.createCall(SqlParserPos.ZERO,
                SqlStdOperatorTable.LIKE.createCall(SqlParserPos.ZERO,
                    call.getOperandList()))));

    // Expand "x NOT SIMILAR y" into "NOT (x SIMILAR y)"
    registerOp(SqlStdOperatorTable.NOT_SIMILAR_TO,
        (cx, call) -> cx.convertExpression(
            SqlStdOperatorTable.NOT.createCall(SqlParserPos.ZERO,
                SqlStdOperatorTable.SIMILAR_TO.createCall(SqlParserPos.ZERO,
                    call.getOperandList()))));

    // Unary "+" has no effect, so expand "+ x" into "x".
    registerOp(SqlStdOperatorTable.UNARY_PLUS,
        (cx, call) -> cx.convertExpression(call.operand(0)));

    // "DOT"
    registerOp(SqlStdOperatorTable.DOT,
        (cx, call) -> cx.getRexBuilder().makeFieldAccess(
            cx.convertExpression(call.operand(0)),
            call.operand(1).toString(), false));
    // "AS" has no effect, so expand "x AS id" into "x".
    registerOp(SqlStdOperatorTable.AS,
        (cx, call) -> cx.convertExpression(call.operand(0)));
    // "SQRT(x)" is equivalent to "POWER(x, .5)"
    registerOp(SqlStdOperatorTable.SQRT,
        (cx, call) -> cx.convertExpression(
            SqlStdOperatorTable.POWER.createCall(SqlParserPos.ZERO,
                call.operand(0),
                SqlLiteral.createExactNumeric("0.5", SqlParserPos.ZERO))));

    // REVIEW jvs 24-Apr-2006: This only seems to be working from within a
    // windowed agg.  I have added an optimizer rule
    // com.hazelcast.org.apache.calcite.rel.rules.AggregateReduceFunctionsRule which handles
    // other cases post-translation.  The reason I did that was to defer the
    // implementation decision; e.g. we may want to push it down to a foreign
    // server directly rather than decomposed; decomposition is easier than
    // recognition.

    // Convert "avg()" to "cast(sum() / count() as
    // )". We don't need to handle the empty set specially, because
    // the SUM is already supposed to come out as NULL in cases where the
    // COUNT is zero, so the null check should take place first and prevent
    // division by zero. We need the cast because SUM and COUNT may use
    // different types, say BIGINT.
    //
    // Similarly STDDEV_POP and STDDEV_SAMP, VAR_POP and VAR_SAMP.
    registerOp(SqlStdOperatorTable.AVG,
        new AvgVarianceConvertlet(SqlKind.AVG));
    registerOp(SqlStdOperatorTable.STDDEV_POP,
        new AvgVarianceConvertlet(SqlKind.STDDEV_POP));
    registerOp(SqlStdOperatorTable.STDDEV_SAMP,
        new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
    registerOp(SqlStdOperatorTable.STDDEV,
        new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
    registerOp(SqlStdOperatorTable.VAR_POP,
        new AvgVarianceConvertlet(SqlKind.VAR_POP));
    registerOp(SqlStdOperatorTable.VAR_SAMP,
        new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
    registerOp(SqlStdOperatorTable.VARIANCE,
        new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
    registerOp(SqlStdOperatorTable.COVAR_POP,
        new RegrCovarianceConvertlet(SqlKind.COVAR_POP));
    registerOp(SqlStdOperatorTable.COVAR_SAMP,
        new RegrCovarianceConvertlet(SqlKind.COVAR_SAMP));
    registerOp(SqlStdOperatorTable.REGR_SXX,
        new RegrCovarianceConvertlet(SqlKind.REGR_SXX));
    registerOp(SqlStdOperatorTable.REGR_SYY,
        new RegrCovarianceConvertlet(SqlKind.REGR_SYY));

    final SqlRexConvertlet floorCeilConvertlet = new FloorCeilConvertlet();
    registerOp(SqlStdOperatorTable.FLOOR, floorCeilConvertlet);
    registerOp(SqlStdOperatorTable.CEIL, floorCeilConvertlet);

    registerOp(SqlStdOperatorTable.TIMESTAMP_ADD, new TimestampAddConvertlet());
    registerOp(SqlStdOperatorTable.TIMESTAMP_DIFF,
        new TimestampDiffConvertlet());

    // Convert "element()" to "$element_slice()", if the
    // expression is a multiset of scalars.
    if (false) {
      registerOp(SqlStdOperatorTable.ELEMENT,
          (cx, call) -> {
            assert call.operandCount() == 1;
            final SqlNode operand = call.operand(0);
            final RelDataType type =
                cx.getValidator().getValidatedNodeType(operand);
            if (!type.getComponentType().isStruct()) {
              return cx.convertExpression(
                  SqlStdOperatorTable.ELEMENT_SLICE.createCall(
                      SqlParserPos.ZERO, operand));
            }

            // fallback on default behavior
            return StandardConvertletTable.this.convertCall(cx, call);
          });
    }

    // Convert "$element_slice()" to "element().field#0"
    if (false) {
      registerOp(SqlStdOperatorTable.ELEMENT_SLICE,
          (cx, call) -> {
            assert call.operandCount() == 1;
            final SqlNode operand = call.operand(0);
            final RexNode expr =
                cx.convertExpression(
                    SqlStdOperatorTable.ELEMENT.createCall(SqlParserPos.ZERO,
                        operand));
            return cx.getRexBuilder().makeFieldAccess(expr, 0);
          });
    }
  }

  //~ Methods ----------------------------------------------------------------

  private RexNode or(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
    return rexBuilder.makeCall(SqlStdOperatorTable.OR, a0, a1);
  }

  private RexNode eq(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
    return rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, a0, a1);
  }

  private RexNode ge(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
    return rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, a0,
        a1);
  }

  private RexNode le(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
    return rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, a0, a1);
  }

  private RexNode and(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
    return rexBuilder.makeCall(SqlStdOperatorTable.AND, a0, a1);
  }

  private static RexNode divideInt(RexBuilder rexBuilder, RexNode a0,
      RexNode a1) {
    return rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE_INTEGER, a0, a1);
  }

  private RexNode plus(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
    return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, a0, a1);
  }

  private RexNode minus(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
    return rexBuilder.makeCall(SqlStdOperatorTable.MINUS, a0, a1);
  }

  private static RexNode multiply(RexBuilder rexBuilder, RexNode a0,
      RexNode a1) {
    return rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, a0, a1);
  }

  private RexNode case_(RexBuilder rexBuilder, RexNode... args) {
    return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args);
  }

  // SqlNode helpers

  private SqlCall plus(SqlParserPos pos, SqlNode a0, SqlNode a1) {
    return SqlStdOperatorTable.PLUS.createCall(pos, a0, a1);
  }

  /**
   * Converts a CASE expression.
   */
  public RexNode convertCase(
      SqlRexContext cx,
      SqlCase call) {
    SqlNodeList whenList = call.getWhenOperands();
    SqlNodeList thenList = call.getThenOperands();
    assert whenList.size() == thenList.size();

    RexBuilder rexBuilder = cx.getRexBuilder();
    final List exprList = new ArrayList<>();
    final RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
    final RexLiteral unknownLiteral = rexBuilder.makeNullLiteral(
        typeFactory.createSqlType(SqlTypeName.BOOLEAN));
    final RexLiteral nullLiteral = rexBuilder.makeNullLiteral(
        typeFactory.createSqlType(SqlTypeName.NULL));
    for (int i = 0; i < whenList.size(); i++) {
      if (SqlUtil.isNullLiteral(whenList.get(i), false)) {
        exprList.add(unknownLiteral);
      } else {
        exprList.add(cx.convertExpression(whenList.get(i)));
      }
      if (SqlUtil.isNullLiteral(thenList.get(i), false)) {
        exprList.add(nullLiteral);
      } else {
        exprList.add(cx.convertExpression(thenList.get(i)));
      }
    }
    if (SqlUtil.isNullLiteral(call.getElseOperand(), false)) {
      exprList.add(nullLiteral);
    } else {
      exprList.add(cx.convertExpression(call.getElseOperand()));
    }

    RelDataType type =
        rexBuilder.deriveReturnType(call.getOperator(), exprList);
    for (int i : elseArgs(exprList.size())) {
      exprList.set(i,
          rexBuilder.ensureType(type, exprList.get(i), false));
    }
    return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, exprList);
  }

  public RexNode convertMultiset(
      SqlRexContext cx,
      SqlMultisetValueConstructor op,
      SqlCall call) {
    final RelDataType originalType =
        cx.getValidator().getValidatedNodeType(call);
    RexRangeRef rr = cx.getSubQueryExpr(call);
    assert rr != null;
    RelDataType msType = rr.getType().getFieldList().get(0).getType();
    RexNode expr =
        cx.getRexBuilder().makeInputRef(
            msType,
            rr.getOffset());
    assert msType.getComponentType().isStruct();
    if (!originalType.getComponentType().isStruct()) {
      // If the type is not a struct, the multiset operator will have
      // wrapped the type as a record. Add a call to the $SLICE operator
      // to compensate. For example,
      // if '' has type 'RECORD (INTEGER x) MULTISET',
      // then '$SLICE() has type 'INTEGER MULTISET'.
      // This will be removed as the expression is translated.
      expr =
          cx.getRexBuilder().makeCall(originalType, SqlStdOperatorTable.SLICE,
              ImmutableList.of(expr));
    }
    return expr;
  }

  public RexNode convertArray(
      SqlRexContext cx,
      SqlArrayValueConstructor op,
      SqlCall call) {
    return convertCall(cx, call);
  }

  public RexNode convertMap(
      SqlRexContext cx,
      SqlMapValueConstructor op,
      SqlCall call) {
    return convertCall(cx, call);
  }

  public RexNode convertMultisetQuery(
      SqlRexContext cx,
      SqlMultisetQueryConstructor op,
      SqlCall call) {
    final RelDataType originalType =
        cx.getValidator().getValidatedNodeType(call);
    RexRangeRef rr = cx.getSubQueryExpr(call);
    assert rr != null;
    RelDataType msType = rr.getType().getFieldList().get(0).getType();
    RexNode expr =
        cx.getRexBuilder().makeInputRef(
            msType,
            rr.getOffset());
    assert msType.getComponentType().isStruct();
    if (!originalType.getComponentType().isStruct()) {
      // If the type is not a struct, the multiset operator will have
      // wrapped the type as a record. Add a call to the $SLICE operator
      // to compensate. For example,
      // if '' has type 'RECORD (INTEGER x) MULTISET',
      // then '$SLICE() has type 'INTEGER MULTISET'.
      // This will be removed as the expression is translated.
      expr =
          cx.getRexBuilder().makeCall(SqlStdOperatorTable.SLICE, expr);
    }
    return expr;
  }

  public RexNode convertJdbc(
      SqlRexContext cx,
      SqlJdbcFunctionCall op,
      SqlCall call) {
    // Yuck!! The function definition contains arguments!
    // TODO: adopt a more conventional definition/instance structure
    final SqlCall convertedCall = op.getLookupCall();
    return cx.convertExpression(convertedCall);
  }

  protected RexNode convertCast(
      SqlRexContext cx,
      final SqlCall call) {
    RelDataTypeFactory typeFactory = cx.getTypeFactory();
    assert call.getKind() == SqlKind.CAST;
    final SqlNode left = call.operand(0);
    final SqlNode right = call.operand(1);
    if (right instanceof SqlIntervalQualifier) {
      final SqlIntervalQualifier intervalQualifier =
          (SqlIntervalQualifier) right;
      if (left instanceof SqlIntervalLiteral) {
        RexLiteral sourceInterval =
            (RexLiteral) cx.convertExpression(left);
        BigDecimal sourceValue =
            (BigDecimal) sourceInterval.getValue();
        RexLiteral castedInterval =
            cx.getRexBuilder().makeIntervalLiteral(sourceValue,
                intervalQualifier);
        return castToValidatedType(cx, call, castedInterval);
      } else if (left instanceof SqlNumericLiteral) {
        RexLiteral sourceInterval =
            (RexLiteral) cx.convertExpression(left);
        BigDecimal sourceValue =
            (BigDecimal) sourceInterval.getValue();
        final BigDecimal multiplier = intervalQualifier.getUnit().multiplier;
        sourceValue = sourceValue.multiply(multiplier);
        RexLiteral castedInterval =
            cx.getRexBuilder().makeIntervalLiteral(
                sourceValue,
                intervalQualifier);
        return castToValidatedType(cx, call, castedInterval);
      }
      return castToValidatedType(cx, call, cx.convertExpression(left));
    }
    SqlDataTypeSpec dataType = (SqlDataTypeSpec) right;
    RelDataType type = dataType.deriveType(cx.getValidator());
    if (type == null) {
      type = cx.getValidator().getValidatedNodeType(dataType.getTypeName());
    }
    RexNode arg = cx.convertExpression(left);
    if (arg.getType().isNullable()) {
      type = typeFactory.createTypeWithNullability(type, true);
    }
    if (SqlUtil.isNullLiteral(left, false)) {
      final SqlValidatorImpl validator = (SqlValidatorImpl) cx.getValidator();
      validator.setValidatedNodeType(left, type);
      return cx.convertExpression(left);
    }
    if (null != dataType.getCollectionsTypeName()) {
      final RelDataType argComponentType =
          arg.getType().getComponentType();
      final RelDataType componentType = type.getComponentType();
      if (argComponentType.isStruct()
          && !componentType.isStruct()) {
        RelDataType tt =
            typeFactory.builder()
                .add(
                    argComponentType.getFieldList().get(0).getName(),
                    componentType)
                .build();
        tt = typeFactory.createTypeWithNullability(
            tt,
            componentType.isNullable());
        boolean isn = type.isNullable();
        type = typeFactory.createMultisetType(tt, -1);
        type = typeFactory.createTypeWithNullability(type, isn);
      }
    }
    return cx.getRexBuilder().makeCast(type, arg);
  }

  protected RexNode convertFloorCeil(SqlRexContext cx, SqlCall call) {
    final boolean floor = call.getKind() == SqlKind.FLOOR;
    // Rewrite floor, ceil of interval
    if (call.operandCount() == 1
        && call.operand(0) instanceof SqlIntervalLiteral) {
      final SqlIntervalLiteral literal = call.operand(0);
      SqlIntervalLiteral.IntervalValue interval =
          (SqlIntervalLiteral.IntervalValue) literal.getValue();
      BigDecimal val =
          interval.getIntervalQualifier().getStartUnit().multiplier;
      RexNode rexInterval = cx.convertExpression(literal);

      final RexBuilder rexBuilder = cx.getRexBuilder();
      RexNode zero = rexBuilder.makeExactLiteral(BigDecimal.valueOf(0));
      RexNode cond = ge(rexBuilder, rexInterval, zero);

      RexNode pad =
          rexBuilder.makeExactLiteral(val.subtract(BigDecimal.ONE));
      RexNode cast = rexBuilder.makeReinterpretCast(
          rexInterval.getType(), pad, rexBuilder.makeLiteral(false));
      RexNode sum = floor
          ? minus(rexBuilder, rexInterval, cast)
          : plus(rexBuilder, rexInterval, cast);

      RexNode kase = floor
          ? case_(rexBuilder, rexInterval, cond, sum)
          : case_(rexBuilder, sum, cond, rexInterval);

      RexNode factor = rexBuilder.makeExactLiteral(val);
      RexNode div = divideInt(rexBuilder, kase, factor);
      return multiply(rexBuilder, div, factor);
    }

    // normal floor, ceil function
    return convertFunction(cx, (SqlFunction) call.getOperator(), call);
  }

  /**
   * Converts a call to the {@code EXTRACT} function.
   *
   * 

Called automatically via reflection. */ public RexNode convertExtract( SqlRexContext cx, SqlExtractFunction op, SqlCall call) { return convertFunction(cx, (SqlFunction) call.getOperator(), call); } private RexNode mod(RexBuilder rexBuilder, RelDataType resType, RexNode res, BigDecimal val) { if (val.equals(BigDecimal.ONE)) { return res; } return rexBuilder.makeCall(SqlStdOperatorTable.MOD, res, rexBuilder.makeExactLiteral(val, resType)); } private static RexNode divide(RexBuilder rexBuilder, RexNode res, BigDecimal val) { if (val.equals(BigDecimal.ONE)) { return res; } // If val is between 0 and 1, rather than divide by val, multiply by its // reciprocal. For example, rather than divide by 0.001 multiply by 1000. if (val.compareTo(BigDecimal.ONE) < 0 && val.signum() == 1) { try { final BigDecimal reciprocal = BigDecimal.ONE.divide(val, RoundingMode.UNNECESSARY); return multiply(rexBuilder, res, rexBuilder.makeExactLiteral(reciprocal)); } catch (ArithmeticException e) { // ignore - reciprocal is not an integer } } return divideInt(rexBuilder, res, rexBuilder.makeExactLiteral(val)); } public RexNode convertDatetimeMinus( SqlRexContext cx, SqlDatetimeSubtractionOperator op, SqlCall call) { // Rewrite datetime minus final RexBuilder rexBuilder = cx.getRexBuilder(); final List operands = call.getOperandList(); final List exprs = convertExpressionList(cx, operands, SqlOperandTypeChecker.Consistency.NONE); final RelDataType resType = cx.getValidator().getValidatedNodeType(call); return rexBuilder.makeCall(resType, op, exprs.subList(0, 2)); } public RexNode convertFunction( SqlRexContext cx, SqlFunction fun, SqlCall call) { final List operands = call.getOperandList(); final List exprs = convertExpressionList(cx, operands, SqlOperandTypeChecker.Consistency.NONE); if (fun.getFunctionType() == SqlFunctionCategory.USER_DEFINED_CONSTRUCTOR) { return makeConstructorCall(cx, fun, exprs); } RelDataType returnType = cx.getValidator().getValidatedNodeTypeIfKnown(call); if (returnType == null) { returnType = cx.getRexBuilder().deriveReturnType(fun, exprs); } return cx.getRexBuilder().makeCall(returnType, fun, exprs); } public RexNode convertWindowFunction( SqlRexContext cx, SqlWindowTableFunction fun, SqlCall call) { // The first operand of window function is actually a query, skip that. final List operands = Util.skip(call.getOperandList(), 1); final List exprs = convertExpressionList(cx, operands, SqlOperandTypeChecker.Consistency.NONE); RelDataType returnType = cx.getValidator().getValidatedNodeTypeIfKnown(call); if (returnType == null) { returnType = cx.getRexBuilder().deriveReturnType(fun, exprs); } return cx.getRexBuilder().makeCall(returnType, fun, exprs); } public RexNode convertJsonValueFunction( SqlRexContext cx, SqlJsonValueFunction fun, SqlCall call) { // For Expression with explicit return type: // i.e. json_value('{"foo":"bar"}', 'lax $.foo', returning varchar(2000)) // use the specified type as the return type. List operands = call.getOperandList(); boolean hasExplicitReturningType = SqlJsonValueFunction.hasExplicitTypeSpec( operands.toArray(SqlNode.EMPTY_ARRAY)); if (hasExplicitReturningType) { operands = SqlJsonValueFunction.removeTypeSpecOperands(call); } final List exprs = convertExpressionList(cx, operands, SqlOperandTypeChecker.Consistency.NONE); RelDataType returnType = cx.getValidator().getValidatedNodeTypeIfKnown(call); return cx.getRexBuilder().makeCall(returnType, fun, exprs); } public RexNode convertSequenceValue( SqlRexContext cx, SqlSequenceValueOperator fun, SqlCall call) { final List operands = call.getOperandList(); assert operands.size() == 1; assert operands.get(0) instanceof SqlIdentifier; final SqlIdentifier id = (SqlIdentifier) operands.get(0); final String key = Util.listToString(id.names); RelDataType returnType = cx.getValidator().getValidatedNodeType(call); return cx.getRexBuilder().makeCall(returnType, fun, ImmutableList.of(cx.getRexBuilder().makeLiteral(key))); } public RexNode convertAggregateFunction( SqlRexContext cx, SqlAggFunction fun, SqlCall call) { final List operands = call.getOperandList(); final List exprs; if (call.isCountStar()) { exprs = ImmutableList.of(); } else { exprs = convertExpressionList(cx, operands, SqlOperandTypeChecker.Consistency.NONE); } RelDataType returnType = cx.getValidator().getValidatedNodeTypeIfKnown(call); final int groupCount = cx.getGroupCount(); if (returnType == null) { RexCallBinding binding = new RexCallBinding(cx.getTypeFactory(), fun, exprs, ImmutableList.of()) { @Override public int getGroupCount() { return groupCount; } }; returnType = fun.inferReturnType(binding); } return cx.getRexBuilder().makeCall(returnType, fun, exprs); } private static RexNode makeConstructorCall( SqlRexContext cx, SqlFunction constructor, List exprs) { final RexBuilder rexBuilder = cx.getRexBuilder(); RelDataType type = rexBuilder.deriveReturnType(constructor, exprs); int n = type.getFieldCount(); ImmutableList.Builder initializationExprs = ImmutableList.builder(); final InitializerContext initializerContext = new InitializerContext() { public RexBuilder getRexBuilder() { return rexBuilder; } public SqlNode validateExpression(RelDataType rowType, SqlNode expr) { throw new UnsupportedOperationException(); } public RexNode convertExpression(SqlNode e) { throw new UnsupportedOperationException(); } }; for (int i = 0; i < n; ++i) { initializationExprs.add( cx.getInitializerExpressionFactory().newAttributeInitializer( type, constructor, i, exprs, initializerContext)); } List defaultCasts = RexUtil.generateCastExpressions( rexBuilder, type, initializationExprs.build()); return rexBuilder.makeNewInvocation(type, defaultCasts); } /** * Converts a call to an operator into a {@link RexCall} to the same * operator. * *

Called automatically via reflection. * * @param cx Context * @param call Call * @return Rex call */ public RexNode convertCall( SqlRexContext cx, SqlCall call) { return convertCall(cx, call.getOperator(), call.getOperandList()); } /** Converts a {@link SqlCall} to a {@link RexCall} with a perhaps different * operator. */ private RexNode convertCall( SqlRexContext cx, SqlOperator op, List operands) { final RexBuilder rexBuilder = cx.getRexBuilder(); final SqlOperandTypeChecker.Consistency consistency = op.getOperandTypeChecker() == null ? SqlOperandTypeChecker.Consistency.NONE : op.getOperandTypeChecker().getConsistency(); final List exprs = convertExpressionList(cx, operands, consistency); RelDataType type = rexBuilder.deriveReturnType(op, exprs); return rexBuilder.makeCall(type, op, RexUtil.flatten(exprs, op)); } private List elseArgs(int count) { // If list is odd, e.g. [0, 1, 2, 3, 4] we get [1, 3, 4] // If list is even, e.g. [0, 1, 2, 3, 4, 5] we get [2, 4, 5] final List list = new ArrayList<>(); for (int i = count % 2;;) { list.add(i); i += 2; if (i >= count) { list.add(i - 1); break; } } return list; } private static List convertExpressionList(SqlRexContext cx, List nodes, SqlOperandTypeChecker.Consistency consistency) { final List exprs = new ArrayList<>(); for (SqlNode node : nodes) { exprs.add(cx.convertExpression(node)); } if (exprs.size() > 1) { final RelDataType type = consistentType(cx, consistency, RexUtil.types(exprs)); if (type != null) { final List oldExprs = Lists.newArrayList(exprs); exprs.clear(); for (RexNode expr : oldExprs) { exprs.add(cx.getRexBuilder().ensureType(type, expr, true)); } } } return exprs; } private static RelDataType consistentType(SqlRexContext cx, SqlOperandTypeChecker.Consistency consistency, List types) { switch (consistency) { case COMPARE: if (SqlTypeUtil.areSameFamily(types)) { // All arguments are of same family. No need for explicit casts. return null; } final List nonCharacterTypes = new ArrayList<>(); for (RelDataType type : types) { if (type.getFamily() != SqlTypeFamily.CHARACTER) { nonCharacterTypes.add(type); } } if (!nonCharacterTypes.isEmpty()) { final int typeCount = types.size(); types = nonCharacterTypes; if (nonCharacterTypes.size() < typeCount) { final RelDataTypeFamily family = nonCharacterTypes.get(0).getFamily(); if (family instanceof SqlTypeFamily) { // The character arguments might be larger than the numeric // argument. Give ourselves some headroom. switch ((SqlTypeFamily) family) { case INTEGER: case NUMERIC: nonCharacterTypes.add( cx.getTypeFactory().createSqlType(SqlTypeName.BIGINT)); } } } } // fall through case LEAST_RESTRICTIVE: return cx.getTypeFactory().leastRestrictive(types); default: return null; } } private RexNode convertPlus(SqlRexContext cx, SqlCall call) { final RexNode rex = convertCall(cx, call); switch (rex.getType().getSqlTypeName()) { case DATE: case TIME: case TIMESTAMP: // Use special "+" operator for datetime + interval. // Re-order operands, if necessary, so that interval is second. final RexBuilder rexBuilder = cx.getRexBuilder(); List operands = ((RexCall) rex).getOperands(); if (operands.size() == 2) { final SqlTypeName sqlTypeName = operands.get(0).getType().getSqlTypeName(); switch (sqlTypeName) { case INTERVAL_YEAR: case INTERVAL_YEAR_MONTH: case INTERVAL_MONTH: case INTERVAL_DAY: case INTERVAL_DAY_HOUR: case INTERVAL_DAY_MINUTE: case INTERVAL_DAY_SECOND: case INTERVAL_HOUR: case INTERVAL_HOUR_MINUTE: case INTERVAL_HOUR_SECOND: case INTERVAL_MINUTE: case INTERVAL_MINUTE_SECOND: case INTERVAL_SECOND: operands = ImmutableList.of(operands.get(1), operands.get(0)); } } return rexBuilder.makeCall(rex.getType(), SqlStdOperatorTable.DATETIME_PLUS, operands); default: return rex; } } private RexNode convertIsDistinctFrom( SqlRexContext cx, SqlCall call, boolean neg) { RexNode op0 = cx.convertExpression(call.operand(0)); RexNode op1 = cx.convertExpression(call.operand(1)); return RelOptUtil.isDistinctFrom( cx.getRexBuilder(), op0, op1, neg); } /** * Converts a BETWEEN expression. * *

Called automatically via reflection. */ public RexNode convertBetween( SqlRexContext cx, SqlBetweenOperator op, SqlCall call) { final List list = convertExpressionList(cx, call.getOperandList(), op.getOperandTypeChecker().getConsistency()); final RexNode x = list.get(SqlBetweenOperator.VALUE_OPERAND); final RexNode y = list.get(SqlBetweenOperator.LOWER_OPERAND); final RexNode z = list.get(SqlBetweenOperator.UPPER_OPERAND); final RexBuilder rexBuilder = cx.getRexBuilder(); RexNode ge1 = ge(rexBuilder, x, y); RexNode le1 = le(rexBuilder, x, z); RexNode and1 = and(rexBuilder, ge1, le1); RexNode res; final SqlBetweenOperator.Flag symmetric = op.flag; switch (symmetric) { case ASYMMETRIC: res = and1; break; case SYMMETRIC: RexNode ge2 = ge(rexBuilder, x, z); RexNode le2 = le(rexBuilder, x, y); RexNode and2 = and(rexBuilder, ge2, le2); res = or(rexBuilder, and1, and2); break; default: throw Util.unexpected(symmetric); } final SqlBetweenOperator betweenOp = (SqlBetweenOperator) call.getOperator(); if (betweenOp.isNegated()) { res = rexBuilder.makeCall(SqlStdOperatorTable.NOT, res); } return res; } /** * Converts a LiteralChain expression: that is, concatenates the operands * immediately, to produce a single literal string. * *

Called automatically via reflection. */ public RexNode convertLiteralChain( SqlRexContext cx, SqlLiteralChainOperator op, SqlCall call) { Util.discard(cx); SqlLiteral sum = SqlLiteralChainOperator.concatenateOperands(call); return cx.convertLiteral(sum); } /** * Converts a ROW. * *

Called automatically via reflection. */ public RexNode convertRow( SqlRexContext cx, SqlRowOperator op, SqlCall call) { if (cx.getValidator().getValidatedNodeType(call).getSqlTypeName() != SqlTypeName.COLUMN_LIST) { return convertCall(cx, call); } final RexBuilder rexBuilder = cx.getRexBuilder(); final List columns = new ArrayList<>(); for (SqlNode operand : call.getOperandList()) { columns.add( rexBuilder.makeLiteral( ((SqlIdentifier) operand).getSimple())); } final RelDataType type = rexBuilder.deriveReturnType(SqlStdOperatorTable.COLUMN_LIST, columns); return rexBuilder.makeCall(type, SqlStdOperatorTable.COLUMN_LIST, columns); } /** * Converts a call to OVERLAPS. * *

Called automatically via reflection. */ public RexNode convertOverlaps( SqlRexContext cx, SqlOverlapsOperator op, SqlCall call) { // for intervals [t0, t1] overlaps [t2, t3], we can find if the // intervals overlaps by: ~(t1 < t2 or t3 < t0) assert call.getOperandList().size() == 2; final Pair left = convertOverlapsOperand(cx, call.getParserPosition(), call.operand(0)); final RexNode r0 = left.left; final RexNode r1 = left.right; final Pair right = convertOverlapsOperand(cx, call.getParserPosition(), call.operand(1)); final RexNode r2 = right.left; final RexNode r3 = right.right; // Sort end points into start and end, such that (s0 <= e0) and (s1 <= e1). final RexBuilder rexBuilder = cx.getRexBuilder(); RexNode leftSwap = le(rexBuilder, r0, r1); final RexNode s0 = case_(rexBuilder, leftSwap, r0, r1); final RexNode e0 = case_(rexBuilder, leftSwap, r1, r0); RexNode rightSwap = le(rexBuilder, r2, r3); final RexNode s1 = case_(rexBuilder, rightSwap, r2, r3); final RexNode e1 = case_(rexBuilder, rightSwap, r3, r2); // (e0 >= s1) AND (e1 >= s0) switch (op.kind) { case OVERLAPS: return and(rexBuilder, ge(rexBuilder, e0, s1), ge(rexBuilder, e1, s0)); case CONTAINS: return and(rexBuilder, le(rexBuilder, s0, s1), ge(rexBuilder, e0, e1)); case PERIOD_EQUALS: return and(rexBuilder, eq(rexBuilder, s0, s1), eq(rexBuilder, e0, e1)); case PRECEDES: return le(rexBuilder, e0, s1); case IMMEDIATELY_PRECEDES: return eq(rexBuilder, e0, s1); case SUCCEEDS: return ge(rexBuilder, s0, e1); case IMMEDIATELY_SUCCEEDS: return eq(rexBuilder, s0, e1); default: throw new AssertionError(op); } } private Pair convertOverlapsOperand(SqlRexContext cx, SqlParserPos pos, SqlNode operand) { final SqlNode a0; final SqlNode a1; switch (operand.getKind()) { case ROW: a0 = ((SqlCall) operand).operand(0); final SqlNode a10 = ((SqlCall) operand).operand(1); final RelDataType t1 = cx.getValidator().getValidatedNodeType(a10); if (SqlTypeUtil.isInterval(t1)) { // make t1 = t0 + t1 when t1 is an interval. a1 = plus(pos, a0, a10); } else { a1 = a10; } break; default: a0 = operand; a1 = operand; } final RexNode r0 = cx.convertExpression(a0); final RexNode r1 = cx.convertExpression(a1); return Pair.of(r0, r1); } /** * Casts a RexNode value to the validated type of a SqlCall. If the value * was already of the validated type, then the value is returned without an * additional cast. */ public RexNode castToValidatedType( SqlRexContext cx, SqlCall call, RexNode value) { return castToValidatedType(call, value, cx.getValidator(), cx.getRexBuilder()); } /** * Casts a RexNode value to the validated type of a SqlCall. If the value * was already of the validated type, then the value is returned without an * additional cast. */ public static RexNode castToValidatedType(SqlNode node, RexNode e, SqlValidator validator, RexBuilder rexBuilder) { final RelDataType type = validator.getValidatedNodeType(node); if (e.getType() == type) { return e; } return rexBuilder.makeCast(type, e); } /** Convertlet that handles {@code COVAR_POP}, {@code COVAR_SAMP}, * {@code REGR_SXX}, {@code REGR_SYY} windowed aggregate functions. */ private static class RegrCovarianceConvertlet implements SqlRexConvertlet { private final SqlKind kind; RegrCovarianceConvertlet(SqlKind kind) { this.kind = kind; } public RexNode convertCall(SqlRexContext cx, SqlCall call) { assert call.operandCount() == 2; final SqlNode arg1 = call.operand(0); final SqlNode arg2 = call.operand(1); final SqlNode expr; final RelDataType type = cx.getValidator().getValidatedNodeType(call); switch (kind) { case COVAR_POP: expr = expandCovariance(arg1, arg2, null, type, cx, true); break; case COVAR_SAMP: expr = expandCovariance(arg1, arg2, null, type, cx, false); break; case REGR_SXX: expr = expandRegrSzz(arg2, arg1, type, cx, true); break; case REGR_SYY: expr = expandRegrSzz(arg1, arg2, type, cx, true); break; default: throw Util.unexpected(kind); } RexNode rex = cx.convertExpression(expr); return cx.getRexBuilder().ensureType(type, rex, true); } private SqlNode expandRegrSzz( final SqlNode arg1, final SqlNode arg2, final RelDataType avgType, final SqlRexContext cx, boolean variance) { final SqlParserPos pos = SqlParserPos.ZERO; final SqlNode count = SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg1, arg2); final SqlNode varPop = expandCovariance(arg1, variance ? arg1 : arg2, arg2, avgType, cx, true); final RexNode varPopRex = cx.convertExpression(varPop); final SqlNode varPopCast; varPopCast = getCastedSqlNode(varPop, avgType, pos, varPopRex); return SqlStdOperatorTable.MULTIPLY.createCall(pos, varPopCast, count); } private SqlNode expandCovariance( final SqlNode arg0Input, final SqlNode arg1Input, final SqlNode dependent, final RelDataType varType, final SqlRexContext cx, boolean biased) { // covar_pop(x1, x2) ==> // (sum(x1 * x2) - sum(x2) * sum(x1) / count(x1, x2)) // / count(x1, x2) // // covar_samp(x1, x2) ==> // (sum(x1 * x2) - sum(x1) * sum(x2) / count(x1, x2)) // / (count(x1, x2) - 1) final SqlParserPos pos = SqlParserPos.ZERO; final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO); final RexNode arg0Rex = cx.convertExpression(arg0Input); final RexNode arg1Rex = cx.convertExpression(arg1Input); final SqlNode arg0 = getCastedSqlNode(arg0Input, varType, pos, arg0Rex); final SqlNode arg1 = getCastedSqlNode(arg1Input, varType, pos, arg1Rex); final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg0, arg1); final SqlNode sumArgSquared; final SqlNode sum0; final SqlNode sum1; final SqlNode count; if (dependent == null) { sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared); sum0 = SqlStdOperatorTable.SUM.createCall(pos, arg0, arg1); sum1 = SqlStdOperatorTable.SUM.createCall(pos, arg1, arg0); count = SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg0, arg1); } else { sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared, dependent); sum0 = SqlStdOperatorTable.SUM.createCall( pos, arg0, Objects.equals(dependent, arg0Input) ? arg1 : dependent); sum1 = SqlStdOperatorTable.SUM.createCall( pos, arg1, Objects.equals(dependent, arg1Input) ? arg0 : dependent); count = SqlStdOperatorTable.REGR_COUNT.createCall( pos, arg0, Objects.equals(dependent, arg0Input) ? arg1 : dependent); } final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum0, sum1); final SqlNode countCasted = getCastedSqlNode(count, varType, pos, cx.convertExpression(count)); final SqlNode avgSumSquared = SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, countCasted); final SqlNode diff = SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared); SqlNode denominator; if (biased) { denominator = countCasted; } else { final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos); denominator = new SqlCase(SqlParserPos.ZERO, countCasted, SqlNodeList.of(SqlStdOperatorTable.EQUALS.createCall(pos, countCasted, one)), SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)), SqlStdOperatorTable.MINUS.createCall(pos, countCasted, one)); } return SqlStdOperatorTable.DIVIDE.createCall(pos, diff, denominator); } private SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType, SqlParserPos pos, RexNode argRex) { SqlNode arg; if (argRex != null && !argRex.getType().equals(varType)) { arg = SqlStdOperatorTable.CAST.createCall( pos, argInput, SqlTypeUtil.convertTypeToSpec(varType)); } else { arg = argInput; } return arg; } } /** Convertlet that handles {@code AVG} and {@code VARIANCE} * windowed aggregate functions. */ private static class AvgVarianceConvertlet implements SqlRexConvertlet { private final SqlKind kind; AvgVarianceConvertlet(SqlKind kind) { this.kind = kind; } public RexNode convertCall(SqlRexContext cx, SqlCall call) { assert call.operandCount() == 1; final SqlNode arg = call.operand(0); final SqlNode expr; final RelDataType type = cx.getValidator().getValidatedNodeType(call); switch (kind) { case AVG: expr = expandAvg(arg, type, cx); break; case STDDEV_POP: expr = expandVariance(arg, type, cx, true, true); break; case STDDEV_SAMP: expr = expandVariance(arg, type, cx, false, true); break; case VAR_POP: expr = expandVariance(arg, type, cx, true, false); break; case VAR_SAMP: expr = expandVariance(arg, type, cx, false, false); break; default: throw Util.unexpected(kind); } RexNode rex = cx.convertExpression(expr); return cx.getRexBuilder().ensureType(type, rex, true); } private SqlNode expandAvg( final SqlNode arg, final RelDataType avgType, final SqlRexContext cx) { final SqlParserPos pos = SqlParserPos.ZERO; final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg); final RexNode sumRex = cx.convertExpression(sum); final SqlNode sumCast; sumCast = getCastedSqlNode(sum, avgType, pos, sumRex); final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg); return SqlStdOperatorTable.DIVIDE.createCall( pos, sumCast, count); } private SqlNode expandVariance( final SqlNode argInput, final RelDataType varType, final SqlRexContext cx, boolean biased, boolean sqrt) { // stddev_pop(x) ==> // power( // (sum(x * x) - sum(x) * sum(x) / count(x)) // / count(x), // .5) // // stddev_samp(x) ==> // power( // (sum(x * x) - sum(x) * sum(x) / count(x)) // / (count(x) - 1), // .5) // // var_pop(x) ==> // (sum(x * x) - sum(x) * sum(x) / count(x)) // / count(x) // // var_samp(x) ==> // (sum(x * x) - sum(x) * sum(x) / count(x)) // / (count(x) - 1) final SqlParserPos pos = SqlParserPos.ZERO; final SqlNode arg = getCastedSqlNode(argInput, varType, pos, cx.convertExpression(argInput)); final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg); final SqlNode argSquaredCasted = getCastedSqlNode(argSquared, varType, pos, cx.convertExpression(argSquared)); final SqlNode sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquaredCasted); final SqlNode sumArgSquaredCasted = getCastedSqlNode(sumArgSquared, varType, pos, cx.convertExpression(sumArgSquared)); final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg); final SqlNode sumCasted = getCastedSqlNode(sum, varType, pos, cx.convertExpression(sum)); final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sumCasted, sumCasted); final SqlNode sumSquaredCasted = getCastedSqlNode(sumSquared, varType, pos, cx.convertExpression(sumSquared)); final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg); final SqlNode countCasted = getCastedSqlNode(count, varType, pos, cx.convertExpression(count)); final SqlNode avgSumSquared = SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquaredCasted, countCasted); final SqlNode avgSumSquaredCasted = getCastedSqlNode(avgSumSquared, varType, pos, cx.convertExpression(avgSumSquared)); final SqlNode diff = SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquaredCasted, avgSumSquaredCasted); final SqlNode diffCasted = getCastedSqlNode(diff, varType, pos, cx.convertExpression(diff)); final SqlNode denominator; if (biased) { denominator = countCasted; } else { final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos); final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO); denominator = new SqlCase(SqlParserPos.ZERO, count, SqlNodeList.of(SqlStdOperatorTable.EQUALS.createCall(pos, count, one)), SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)), SqlStdOperatorTable.MINUS.createCall(pos, count, one)); } final SqlNode div = SqlStdOperatorTable.DIVIDE.createCall(pos, diffCasted, denominator); final SqlNode divCasted = getCastedSqlNode(div, varType, pos, cx.convertExpression(div)); SqlNode result = div; if (sqrt) { final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos); result = SqlStdOperatorTable.POWER.createCall(pos, divCasted, half); } return result; } private SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType, SqlParserPos pos, RexNode argRex) { SqlNode arg; if (argRex != null && !argRex.getType().equals(varType)) { arg = SqlStdOperatorTable.CAST.createCall( pos, argInput, SqlTypeUtil.convertTypeToSpec(varType)); } else { arg = argInput; } return arg; } } /** Convertlet that converts {@code LTRIM} and {@code RTRIM} to * {@code TRIM}. */ private static class TrimConvertlet implements SqlRexConvertlet { private final SqlTrimFunction.Flag flag; TrimConvertlet(SqlTrimFunction.Flag flag) { this.flag = flag; } public RexNode convertCall(SqlRexContext cx, SqlCall call) { final RexBuilder rexBuilder = cx.getRexBuilder(); final RexNode operand = cx.convertExpression(call.getOperandList().get(0)); return rexBuilder.makeCall(SqlStdOperatorTable.TRIM, rexBuilder.makeFlag(flag), rexBuilder.makeLiteral(" "), operand); } } /** Convertlet that converts {@code GREATEST} and {@code LEAST}. */ private static class GreatestConvertlet implements SqlRexConvertlet { public RexNode convertCall(SqlRexContext cx, SqlCall call) { // Translate // GREATEST(a, b, c, d) // to // CASE // WHEN a IS NULL OR b IS NULL OR c IS NULL OR d IS NULL // THEN NULL // WHEN a > b AND a > c AND a > d // THEN a // WHEN b > c AND b > d // THEN b // WHEN c > d // THEN c // ELSE d // END final RexBuilder rexBuilder = cx.getRexBuilder(); final RelDataType type = cx.getValidator().getValidatedNodeType(call); final SqlBinaryOperator op; switch (call.getKind()) { case GREATEST: op = SqlStdOperatorTable.GREATER_THAN; break; case LEAST: op = SqlStdOperatorTable.LESS_THAN; break; default: throw new AssertionError(); } final List exprs = convertExpressionList(cx, call.getOperandList(), SqlOperandTypeChecker.Consistency.NONE); final List list = new ArrayList<>(); final List orList = new ArrayList<>(); for (RexNode expr : exprs) { orList.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, expr)); } list.add(RexUtil.composeDisjunction(rexBuilder, orList)); list.add(rexBuilder.makeNullLiteral(type)); for (int i = 0; i < exprs.size() - 1; i++) { RexNode expr = exprs.get(i); final List andList = new ArrayList<>(); for (int j = i + 1; j < exprs.size(); j++) { final RexNode expr2 = exprs.get(j); andList.add(rexBuilder.makeCall(op, expr, expr2)); } list.add(RexUtil.composeConjunction(rexBuilder, andList)); list.add(expr); } list.add(exprs.get(exprs.size() - 1)); return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, list); } } /** Convertlet that handles {@code FLOOR} and {@code CEIL} functions. */ private class FloorCeilConvertlet implements SqlRexConvertlet { public RexNode convertCall(SqlRexContext cx, SqlCall call) { return convertFloorCeil(cx, call); } } /** Convertlet that handles the {@code TIMESTAMPADD} function. */ private static class TimestampAddConvertlet implements SqlRexConvertlet { public RexNode convertCall(SqlRexContext cx, SqlCall call) { // TIMESTAMPADD(unit, count, timestamp) // => timestamp + count * INTERVAL '1' UNIT final RexBuilder rexBuilder = cx.getRexBuilder(); final SqlLiteral unitLiteral = call.operand(0); final TimeUnit unit = unitLiteral.symbolValue(TimeUnit.class); RexNode interval2Add; SqlIntervalQualifier qualifier = new SqlIntervalQualifier(unit, null, unitLiteral.getParserPosition()); RexNode op1 = cx.convertExpression(call.operand(1)); switch (unit) { case MICROSECOND: case NANOSECOND: interval2Add = divide(rexBuilder, multiply(rexBuilder, rexBuilder.makeIntervalLiteral(BigDecimal.ONE, qualifier), op1), BigDecimal.ONE.divide(unit.multiplier, RoundingMode.UNNECESSARY)); break; default: interval2Add = multiply(rexBuilder, rexBuilder.makeIntervalLiteral(unit.multiplier, qualifier), op1); } return rexBuilder.makeCall(SqlStdOperatorTable.DATETIME_PLUS, cx.convertExpression(call.operand(2)), interval2Add); } } /** Convertlet that handles the {@code TIMESTAMPDIFF} function. */ private static class TimestampDiffConvertlet implements SqlRexConvertlet { public RexNode convertCall(SqlRexContext cx, SqlCall call) { // TIMESTAMPDIFF(unit, t1, t2) // => (t2 - t1) UNIT final RexBuilder rexBuilder = cx.getRexBuilder(); final SqlLiteral unitLiteral = call.operand(0); TimeUnit unit = unitLiteral.symbolValue(TimeUnit.class); BigDecimal multiplier = BigDecimal.ONE; BigDecimal divider = BigDecimal.ONE; SqlTypeName sqlTypeName = unit == TimeUnit.NANOSECOND ? SqlTypeName.BIGINT : SqlTypeName.INTEGER; switch (unit) { case MICROSECOND: case MILLISECOND: case NANOSECOND: case WEEK: multiplier = BigDecimal.valueOf(DateTimeUtils.MILLIS_PER_SECOND); divider = unit.multiplier; unit = TimeUnit.SECOND; break; case QUARTER: divider = unit.multiplier; unit = TimeUnit.MONTH; break; } final SqlIntervalQualifier qualifier = new SqlIntervalQualifier(unit, null, SqlParserPos.ZERO); final RexNode op2 = cx.convertExpression(call.operand(2)); final RexNode op1 = cx.convertExpression(call.operand(1)); final RelDataType intervalType = cx.getTypeFactory().createTypeWithNullability( cx.getTypeFactory().createSqlIntervalType(qualifier), op1.getType().isNullable() || op2.getType().isNullable()); final RexCall rexCall = (RexCall) rexBuilder.makeCall( intervalType, SqlStdOperatorTable.MINUS_DATE, ImmutableList.of(op2, op1)); final RelDataType intType = cx.getTypeFactory().createTypeWithNullability( cx.getTypeFactory().createSqlType(sqlTypeName), SqlTypeUtil.containsNullable(rexCall.getType())); RexNode e = rexBuilder.makeCast(intType, rexCall); return rexBuilder.multiplyDivide(e, multiplier, divider); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy