All downloads are free. Search and download functionality uses the official Maven repository.

com.hazelcast.org.apache.calcite.rel.rules.ReduceDecimalsRule Maven / Gradle / Ivy

There is a newer version: 5.5.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.hazelcast.org.apache.calcite.rel.rules;

import com.hazelcast.org.apache.calcite.linq4j.Ord;
import com.hazelcast.org.apache.calcite.plan.Convention;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalCalc;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeSystem;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexCall;
import com.hazelcast.org.apache.calcite.rex.RexLiteral;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexProgram;
import com.hazelcast.org.apache.calcite.rex.RexProgramBuilder;
import com.hazelcast.org.apache.calcite.rex.RexShuttle;
import com.hazelcast.org.apache.calcite.rex.RexUtil;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.SqlOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeName;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeUtil;
import com.hazelcast.org.apache.calcite.tools.RelBuilderFactory;
import com.hazelcast.org.apache.calcite.util.Pair;
import com.hazelcast.org.apache.calcite.util.Util;

import com.hazelcast.com.google.common.collect.ImmutableList;

import com.hazelcast.org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.hazelcast.org.apache.calcite.util.Static.RESOURCE;

import static java.util.Objects.requireNonNull;

/**
 * Rule that reduces decimal operations (such as casts
 * or arithmetic) into operations involving more primitive types (such as longs
 * and doubles). The rule allows Calcite implementations to deal with decimals
 * in a consistent manner, while saving the effort of implementing them.
 *
 * 

The rule can be applied to a * {@link com.hazelcast.org.apache.calcite.rel.logical.LogicalCalc} with a program for which * {@link RexUtil#requiresDecimalExpansion} returns true. The rule relies on a * {@link RexShuttle} to walk over relational expressions and replace them. * *

While decimals are generally not implemented by the Calcite runtime, the * rule is optionally applied, in order to support the situation in which we * would like to push down decimal operations to an external database. * * @see CoreRules#CALC_REDUCE_DECIMALS */ @Value.Enclosing public class ReduceDecimalsRule extends RelRule implements TransformationRule { /** Creates a ReduceDecimalsRule. */ protected ReduceDecimalsRule(Config config) { super(config); } @Deprecated // to be removed before 2.0 public ReduceDecimalsRule(RelBuilderFactory relBuilderFactory) { this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory) .as(Config.class)); } //~ Methods ---------------------------------------------------------------- @Override public @Nullable Convention getOutConvention() { return Convention.NONE; } @Override public void onMatch(RelOptRuleCall call) { LogicalCalc calc = call.rel(0); // Expand decimals in every expression in this program. If no // expression changes, don't apply the rule. final RexProgram program = calc.getProgram(); if (!RexUtil.requiresDecimalExpansion(program, true)) { return; } final RexBuilder rexBuilder = calc.getCluster().getRexBuilder(); final RexShuttle shuttle = new DecimalShuttle(rexBuilder); RexProgramBuilder programBuilder = RexProgramBuilder.create( rexBuilder, calc.getInput().getRowType(), program.getExprList(), program.getProjectList(), program.getCondition(), program.getOutputRowType(), shuttle, true); final RexProgram newProgram = programBuilder.getProgram(); LogicalCalc newCalc = LogicalCalc.create(calc.getInput(), newProgram); call.transformTo(newCalc); } //~ Inner Classes ---------------------------------------------------------- /** * A shuttle which converts decimal expressions to expressions based on * longs. 
*/ public static class DecimalShuttle extends RexShuttle { private final Map, RexNode> irreducible; private final Map, RexNode> results; private final ExpanderMap expanderMap; DecimalShuttle(RexBuilder rexBuilder) { irreducible = new HashMap<>(); results = new HashMap<>(); expanderMap = new ExpanderMap(rexBuilder); } /** * Rewrites a call in place, from bottom up. Algorithm is as follows: * *

    *
  1. visit operands *
  2. visit call node * *
      *
    1. rewrite call *
    2. visit the rewritten call *
    *
*/ @Override public RexNode visitCall(RexCall call) { RexNode savedResult = lookup(call); if (savedResult != null) { return savedResult; } RexNode newCall = call; RexNode rewrite = rewriteCall(call); if (rewrite != call) { newCall = rewrite.accept(this); } register(call, newCall); return newCall; } /** * Registers node so it will not be computed again. */ private void register(RexNode node, RexNode reducedNode) { Pair key = RexUtil.makeKey(node); if (node == reducedNode) { irreducible.put(key, reducedNode); } else { results.put(key, reducedNode); } } /** * Looks up a registered node. */ private @Nullable RexNode lookup(RexNode node) { Pair key = RexUtil.makeKey(node); if (irreducible.get(key) != null) { return node; } return results.get(key); } /** * Rewrites a call, if required, or returns the original call. */ private RexNode rewriteCall(RexCall call) { SqlOperator operator = call.getOperator(); if (!operator.requiresDecimalExpansion()) { return call; } RexExpander expander = getExpander(call); if (expander.canExpand(call)) { return expander.expand(call); } return call; } /** * Returns a {@link RexExpander} for a call. */ private RexExpander getExpander(RexCall call) { return expanderMap.getExpander(call); } } /** * Maps a RexCall to a RexExpander. 
*/ private static class ExpanderMap { private final Map map; private RexExpander defaultExpander; private ExpanderMap(RexBuilder rexBuilder) { map = new HashMap<>(); defaultExpander = new CastArgAsDoubleExpander(rexBuilder); registerExpanders(map, rexBuilder); } private static void registerExpanders(Map map, RexBuilder rexBuilder) { RexExpander cast = new CastExpander(rexBuilder); map.put(SqlStdOperatorTable.CAST, cast); RexExpander passThrough = new PassThroughExpander(rexBuilder); map.put(SqlStdOperatorTable.UNARY_MINUS, passThrough); map.put(SqlStdOperatorTable.ABS, passThrough); map.put(SqlStdOperatorTable.IS_NULL, passThrough); map.put(SqlStdOperatorTable.IS_NOT_NULL, passThrough); RexExpander arithmetic = new BinaryArithmeticExpander(rexBuilder); map.put(SqlStdOperatorTable.DIVIDE, arithmetic); map.put(SqlStdOperatorTable.MULTIPLY, arithmetic); map.put(SqlStdOperatorTable.PLUS, arithmetic); map.put(SqlStdOperatorTable.MINUS, arithmetic); map.put(SqlStdOperatorTable.MOD, arithmetic); map.put(SqlStdOperatorTable.EQUALS, arithmetic); map.put(SqlStdOperatorTable.GREATER_THAN, arithmetic); map.put( SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, arithmetic); map.put(SqlStdOperatorTable.LESS_THAN, arithmetic); map.put(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, arithmetic); RexExpander floor = new FloorExpander(rexBuilder); map.put(SqlStdOperatorTable.FLOOR, floor); RexExpander ceil = new CeilExpander(rexBuilder); map.put(SqlStdOperatorTable.CEIL, ceil); RexExpander reinterpret = new ReinterpretExpander(rexBuilder); map.put(SqlStdOperatorTable.REINTERPRET, reinterpret); RexExpander caseExpander = new CaseExpander(rexBuilder); map.put(SqlStdOperatorTable.CASE, caseExpander); } RexExpander getExpander(RexCall call) { RexExpander expander = map.get(call.getOperator()); return (expander != null) ? expander : defaultExpander; } } /** * Rewrites a decimal expression for a specific set of SqlOperator's. 
In * general, most expressions are rewritten in such a way that SqlOperator's * do not have to deal with decimals. Decimals are represented by their * unscaled integer representations, similar to * {@link BigDecimal#unscaledValue()} (i.e. 10^scale). Once decimals are * decoded, SqlOperators can then operate on the integer representations. The * value can later be recoded as a decimal. * *

For example, suppose one casts 2.0 as a decimal(10,4). The value is * decoded (20), multiplied by a scale factor (1000), for a result of * (20000) which is encoded as a decimal(10,4), in this case 2.0000 * *

To avoid the lengthy coding of RexNode expressions, this base class * provides succinct methods for building expressions used in rewrites. */ public abstract static class RexExpander { /** * Factory for creating relational expressions. */ final RexBuilder builder; /** * Type for the internal representation of decimals. This type is a * non-nullable type and requires extra work to make it nullable. */ final RelDataType int8; /** * Type for doubles. This type is a non-nullable type and requires extra * work to make it nullable. */ final RelDataType real8; /** * Creates a RexExpander. */ RexExpander(RexBuilder builder) { this.builder = builder; int8 = builder.getTypeFactory().createSqlType(SqlTypeName.BIGINT); real8 = builder.getTypeFactory().createSqlType(SqlTypeName.DOUBLE); } /** * This defaults to the utility method, * {@link RexUtil#requiresDecimalExpansion(RexNode, boolean)} which checks * general guidelines on whether a rewrite should be considered at all. In * general, it is helpful to update the utility method since that method is * often used to filter the somewhat expensive rewrite process. * *

However, this method provides another place for implementations of * RexExpander to make a more detailed analysis before deciding on * whether to perform a rewrite. */ public boolean canExpand(RexCall call) { return RexUtil.requiresDecimalExpansion(call, false); } /** * Rewrites an expression containing decimals. Normally, this method * always performs a rewrite, but implementations may choose to return * the original expression if no change was required. */ public abstract RexNode expand(RexCall call); /** * Makes an exact numeric literal to be used for scaling. * * @param scale a scale from one to max precision - 1 * @return 10^scale as an exact numeric value */ protected RexNode makeScaleFactor(int scale) { assert scale > 0; assert scale < builder.getTypeFactory().getTypeSystem().getMaxNumericPrecision(); return makeExactLiteral(powerOfTen(scale)); } /** * Makes an approximate literal to be used for scaling. * * @param scale a scale from -99 to 99 * @return 10^scale as an approximate value */ protected RexNode makeApproxScaleFactor(int scale) { assert (-100 < scale) && (scale < 100) : "could not make approximate scale factor"; if (scale >= 0) { return makeApproxLiteral(BigDecimal.TEN.pow(scale)); } else { BigDecimal tenth = BigDecimal.valueOf(1, 1); return makeApproxLiteral(tenth.pow(-scale)); } } /** * Makes an exact numeric value to be used for rounding. * * @param scale a scale from 1 to max precision - 1 * @return 10^scale / 2 as an exact numeric value */ protected RexNode makeRoundFactor(int scale) { assert scale > 0; assert scale < builder.getTypeFactory().getTypeSystem().getMaxNumericPrecision(); return makeExactLiteral(powerOfTen(scale) / 2); } /** * Calculates a power of ten, as a long value. */ protected long powerOfTen(int scale) { assert scale >= 0; assert scale < builder.getTypeFactory().getTypeSystem().getMaxNumericPrecision(); return BigInteger.TEN.pow(scale).longValue(); } /** * Makes an exact, non-nullable literal of Bigint type. 
*/ protected RexNode makeExactLiteral(long l) { BigDecimal bd = BigDecimal.valueOf(l); return builder.makeExactLiteral(bd, int8); } /** * Makes an approximate literal of double precision. */ protected RexNode makeApproxLiteral(BigDecimal bd) { return builder.makeApproxLiteral(bd); } /** * Scales up a decimal value and returns the scaled value as an exact * number. * * @param value the integer representation of a decimal * @param scale a value from zero to max precision - 1 * @return value * 10^scale as an exact numeric value */ protected RexNode scaleUp(RexNode value, int scale) { assert scale >= 0; assert scale < builder.getTypeFactory().getTypeSystem().getMaxNumericPrecision(); if (scale == 0) { return value; } return builder.makeCall( SqlStdOperatorTable.MULTIPLY, value, makeScaleFactor(scale)); } /** * Scales down a decimal value, and returns the scaled value as an exact * numeric. with the rounding convention * {@link BigDecimal#ROUND_HALF_UP BigDecimal.ROUND_HALF_UP}. (Values midway * between two points are rounded away from zero.) 
* * @param value the integer representation of a decimal * @param scale a value from zero to max precision * @return value/10^scale, rounded away from zero and returned as an * exact numeric value */ protected RexNode scaleDown(RexNode value, int scale) { final int maxPrecision = builder.getTypeFactory().getTypeSystem().getMaxNumericPrecision(); assert scale >= 0 && scale <= maxPrecision; if (scale == 0) { return value; } if (scale == maxPrecision) { long half = BigInteger.TEN.pow(scale - 1).longValue() * 5; return makeCase( builder.makeCall( SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, value, makeExactLiteral(half)), makeExactLiteral(1), builder.makeCall( SqlStdOperatorTable.LESS_THAN_OR_EQUAL, value, makeExactLiteral(-half)), makeExactLiteral(-1), makeExactLiteral(0)); } RexNode roundFactor = makeRoundFactor(scale); RexNode roundValue = makeCase( builder.makeCall( SqlStdOperatorTable.GREATER_THAN, value, makeExactLiteral(0)), makePlus(value, roundFactor), makeMinus(value, roundFactor)); return makeDivide( roundValue, makeScaleFactor(scale)); } /** * Scales down a decimal value and returns the scaled value as a an * double precision approximate value. Scaling is implemented with * double precision arithmetic. * * @param value the integer representation of a decimal * @param scale a value from zero to max precision * @return value/10^scale as a double precision value */ protected RexNode scaleDownDouble(RexNode value, int scale) { assert scale >= 0; assert scale <= builder.getTypeFactory().getTypeSystem().getMaxNumericPrecision(); RexNode cast = ensureType(real8, value); if (scale == 0) { return cast; } return makeDivide( cast, makeApproxScaleFactor(scale)); } /** * Ensures a value is of a required scale. If it is not, then the value * is multiplied by a scale factor. Scaling up an exact value is limited * to max precision - 1, because we cannot represent the result of * larger scales internally. 
Scaling up a floating point value is more * flexible since the value may be very small despite having a scale of * zero and the scaling may still produce a reasonable result * * @param value integer representation of decimal, or a floating point * number * @param scale current scale, 0 for floating point numbers * @param required required scale, must be at least the current scale; * the scale difference may not be greater than max * precision - 1 for exact numerics * @return value * 10^scale, returned as an exact or approximate value * corresponding to the input value */ protected RexNode ensureScale(RexNode value, int scale, int required) { final RelDataTypeSystem typeSystem = builder.getTypeFactory().getTypeSystem(); final int maxPrecision = typeSystem.getMaxNumericPrecision(); assert scale <= maxPrecision && required <= maxPrecision; assert required >= scale; if (scale == required) { return value; } int scaleDiff = required - scale; if (SqlTypeUtil.isApproximateNumeric(value.getType())) { return makeMultiply( value, makeApproxScaleFactor(scaleDiff)); } // TODO: make a validator exception for this if (scaleDiff >= maxPrecision) { throw Util.needToImplement("Source type with scale " + scale + " cannot be converted to target type with scale " + required + " because the smallest value of the " + "source type is too large to be encoded by the " + "target type"); } return scaleUp(value, scaleDiff); } /** * Retrieves a decimal node's integer representation. * * @param decimalNode the decimal value as an opaque type * @return an integer representation of the decimal value */ protected RexNode decodeValue(RexNode decimalNode) { assert SqlTypeUtil.isDecimal(decimalNode.getType()); return builder.decodeIntervalOrDecimal(decimalNode); } /** * Retrieves the primitive value of a numeric node. If the node is a * decimal, then it must first be decoded. Otherwise the original node * may be returned. 
* * @param node a numeric node, possibly a decimal * @return the primitive value of the numeric node */ protected RexNode accessValue(RexNode node) { assert SqlTypeUtil.isNumeric(node.getType()); if (SqlTypeUtil.isDecimal(node.getType())) { return decodeValue(node); } return node; } /** * Casts a decimal's integer representation to a decimal node. If the * expression is not the expected integer type, then it is casted first. * *

This method does not request an overflow check. * * @param value integer representation of decimal * @param decimalType type integer will be reinterpreted as * @return the integer representation reinterpreted as a decimal type */ protected RexNode encodeValue(RexNode value, RelDataType decimalType) { return encodeValue(value, decimalType, false); } /** * Casts a decimal's integer representation to a decimal node. If the * expression is not the expected integer type, then it is casted first. * *

An overflow check may be requested to ensure the internal value * does not exceed the maximum value of the decimal type. * * @param value integer representation of decimal * @param decimalType type integer will be reinterpreted as * @param checkOverflow indicates whether an overflow check is required * when reinterpreting this particular value as the * decimal type. A check usually not required for * arithmetic, but is often required for rounding and * explicit casts. * @return the integer reinterpreted as an opaque decimal type */ protected RexNode encodeValue( RexNode value, RelDataType decimalType, boolean checkOverflow) { return builder.encodeIntervalOrDecimal( value, decimalType, checkOverflow); } /** * Ensures expression is interpreted as a specified type. The returned * expression may be wrapped with a cast. * *

This method corrects the nullability of the specified type to * match the nullability of the expression. * * @param type desired type * @param node expression * @return a casted expression or the original expression */ protected RexNode ensureType(RelDataType type, RexNode node) { return ensureType(type, node, true); } /** * Ensures expression is interpreted as a specified type. The returned * expression may be wrapped with a cast. * * @param type desired type * @param node expression * @param matchNullability whether to correct nullability of specified * type to match the expression; this usually should * be true, except for explicit casts which can * override default nullability * @return a casted expression or the original expression */ protected RexNode ensureType( RelDataType type, RexNode node, boolean matchNullability) { return builder.ensureType(type, node, matchNullability); } protected RexNode makeCase( RexNode condition, RexNode thenClause, RexNode elseClause) { return builder.makeCall( SqlStdOperatorTable.CASE, condition, thenClause, elseClause); } protected RexNode makeCase( RexNode whenA, RexNode thenA, RexNode whenB, RexNode thenB, RexNode elseClause) { return builder.makeCall( SqlStdOperatorTable.CASE, whenA, thenA, whenB, thenB, elseClause); } protected RexNode makePlus( RexNode a, RexNode b) { return builder.makeCall( SqlStdOperatorTable.PLUS, a, b); } protected RexNode makeMinus( RexNode a, RexNode b) { return builder.makeCall( SqlStdOperatorTable.MINUS, a, b); } protected RexNode makeDivide( RexNode a, RexNode b) { return builder.makeCall( SqlStdOperatorTable.DIVIDE_INTEGER, a, b); } protected RexNode makeMultiply( RexNode a, RexNode b) { return builder.makeCall( SqlStdOperatorTable.MULTIPLY, a, b); } protected RexNode makeIsPositive( RexNode a) { return builder.makeCall( SqlStdOperatorTable.GREATER_THAN, a, makeExactLiteral(0)); } protected RexNode makeIsNegative( RexNode a) { return builder.makeCall( SqlStdOperatorTable.LESS_THAN, a, 
makeExactLiteral(0)); } } /** * Expands a decimal cast expression. */ private static class CastExpander extends RexExpander { private CastExpander(RexBuilder builder) { super(builder); } // implement RexExpander @Override public RexNode expand(RexCall call) { List operands = call.operands; assert call.isA(SqlKind.CAST); assert operands.size() == 1; assert !RexLiteral.isNullLiteral(operands.get(0)); RexNode operand = operands.get(0); RelDataType fromType = operand.getType(); RelDataType toType = call.getType(); assert SqlTypeUtil.isDecimal(fromType) || SqlTypeUtil.isDecimal(toType); if (SqlTypeUtil.isIntType(toType)) { // decimal to int return ensureType( toType, scaleDown( decodeValue(operand), fromType.getScale()), false); } else if (SqlTypeUtil.isApproximateNumeric(toType)) { // decimal to floating point return ensureType( toType, scaleDownDouble( decodeValue(operand), fromType.getScale()), false); } else if (SqlTypeUtil.isApproximateNumeric(fromType)) { // real to decimal return encodeValue( ensureScale( operand, 0, toType.getScale()), toType, true); } if (!SqlTypeUtil.isExactNumeric(fromType) || !SqlTypeUtil.isExactNumeric(toType)) { throw Util.needToImplement( "Cast from '" + fromType.toString() + "' to '" + toType.toString() + "'"); } int fromScale = fromType.getScale(); int toScale = toType.getScale(); int fromDigits = fromType.getPrecision() - fromScale; int toDigits = toType.getPrecision() - toScale; // NOTE: precision 19 overflows when its underlying // bigint representation overflows boolean checkOverflow = (toType.getPrecision() < 19) && (toDigits < fromDigits); if (SqlTypeUtil.isIntType(fromType)) { // int to decimal return encodeValue( ensureScale( operand, 0, toType.getScale()), toType, checkOverflow); } else if ( SqlTypeUtil.isDecimal(fromType) && SqlTypeUtil.isDecimal(toType)) { // decimal to decimal RexNode value = decodeValue(operand); RexNode scaled; if (fromScale <= toScale) { scaled = ensureScale(value, fromScale, toScale); } else { if 
(toDigits == fromDigits) { // rounding away from zero may cause an overflow // for example: cast(9.99 as decimal(2,1)) checkOverflow = true; } scaled = scaleDown(value, fromScale - toScale); } return encodeValue(scaled, toType, checkOverflow); } else { throw Util.needToImplement( "Reduce decimal cast from " + fromType + " to " + toType); } } } /** * Expands a decimal arithmetic expression. */ private static class BinaryArithmeticExpander extends RexExpander { @MonotonicNonNull RelDataType typeA; @MonotonicNonNull RelDataType typeB; int scaleA; int scaleB; private BinaryArithmeticExpander(RexBuilder builder) { super(builder); } @Override public RexNode expand(RexCall call) { List operands = call.operands; assert operands.size() == 2; RelDataType typeA = operands.get(0).getType(); RelDataType typeB = operands.get(1).getType(); assert SqlTypeUtil.isNumeric(typeA) && SqlTypeUtil.isNumeric(typeB); if (SqlTypeUtil.isApproximateNumeric(typeA) || SqlTypeUtil.isApproximateNumeric(typeB)) { List newOperands; if (SqlTypeUtil.isApproximateNumeric(typeA)) { newOperands = ImmutableList.of( operands.get(0), ensureType(real8, operands.get(1))); } else { newOperands = ImmutableList.of( ensureType(real8, operands.get(0)), operands.get(1)); } return builder.makeCall( call.getOperator(), newOperands); } analyzeOperands(operands); if (call.isA(SqlKind.PLUS)) { return expandPlusMinus(call, operands); } else if (call.isA(SqlKind.MINUS)) { return expandPlusMinus(call, operands); } else if (call.isA(SqlKind.DIVIDE)) { return expandDivide(call, operands); } else if (call.isA(SqlKind.TIMES)) { return expandTimes(call, operands); } else if (call.isA(SqlKind.COMPARISON)) { return expandComparison(call, operands); } else if (call.getOperator() == SqlStdOperatorTable.MOD) { return expandMod(call, operands); } else { throw new AssertionError("ReduceDecimalsRule could not expand " + call.getOperator()); } } /** * Convenience method for reading characteristics of operands (such as * scale, 
precision, whole digits) into an ArithmeticExpander. The * operands are restricted by the following constraints: * *

    *
  • there are exactly two operands *
  • both are exact numeric types *
*/ private void analyzeOperands(List operands) { assert operands.size() == 2; typeA = operands.get(0).getType(); typeB = operands.get(1).getType(); assert SqlTypeUtil.isExactNumeric(typeA) && SqlTypeUtil.isExactNumeric(typeB); scaleA = typeA.getScale(); scaleB = typeB.getScale(); } private RexNode expandPlusMinus(RexCall call, List operands) { RelDataType outType = call.getType(); int outScale = outType.getScale(); return encodeValue( builder.makeCall( call.getOperator(), ensureScale( accessValue(operands.get(0)), scaleA, outScale), ensureScale( accessValue(operands.get(1)), scaleB, outScale)), outType); } private RexNode expandDivide(RexCall call, List operands) { RelDataType outType = call.getType(); RexNode dividend = builder.makeCall( call.getOperator(), ensureType( real8, accessValue(operands.get(0))), ensureType( real8, accessValue(operands.get(1)))); int scaleDifference = outType.getScale() - scaleA + scaleB; RexNode rescale = builder.makeCall( SqlStdOperatorTable.MULTIPLY, dividend, makeApproxScaleFactor(scaleDifference)); return encodeValue(rescale, outType); } private RexNode expandTimes(RexCall call, List operands) { // Multiplying the internal values of the two arguments leads to // a number with scale = scaleA + scaleB. If the result type has // a lower scale, then the number should be scaled down. 
int divisor = scaleA + scaleB - call.getType().getScale(); if (builder.getTypeFactory().getTypeSystem().shouldUseDoubleMultiplication( builder.getTypeFactory(), requireNonNull(typeA, "typeA"), requireNonNull(typeB, "typeB"))) { // Approximate implementation: // cast (a as double) * cast (b as double) // / 10^divisor RexNode division = makeDivide( makeMultiply( ensureType(real8, accessValue(operands.get(0))), ensureType(real8, accessValue(operands.get(1)))), makeApproxLiteral(BigDecimal.TEN.pow(divisor))); return encodeValue(division, call.getType(), true); } else { // Exact implementation: scaleDown(a * b) return encodeValue( scaleDown( builder.makeCall( call.getOperator(), accessValue(operands.get(0)), accessValue(operands.get(1))), divisor), call.getType()); } } private RexNode expandComparison(RexCall call, List operands) { int commonScale = Math.max(scaleA, scaleB); return builder.makeCall( call.getOperator(), ensureScale( accessValue(operands.get(0)), scaleA, commonScale), ensureScale( accessValue(operands.get(1)), scaleB, commonScale)); } private RexNode expandMod(RexCall call, List operands) { assert SqlTypeUtil.isExactNumeric(requireNonNull(typeA, "typeA")); assert SqlTypeUtil.isExactNumeric(requireNonNull(typeB, "typeB")); if (scaleA != 0 || scaleB != 0) { throw RESOURCE.argumentMustHaveScaleZero(call.getOperator().getName()) .ex(); } RexNode result = builder.makeCall( call.getOperator(), accessValue(operands.get(0)), accessValue(operands.get(1))); RelDataType retType = call.getType(); if (SqlTypeUtil.isDecimal(retType)) { return encodeValue(result, retType); } return ensureType( call.getType(), result); } } /** * Expander that rewrites {@code FLOOR(DECIMAL)} expressions. * Rewrite is as follows: * *
   * if (value < 0)
   *     (value - 0.99...) / (10^scale)
   * else
   *     value / (10 ^ scale)
   * 
*/
  private static class FloorExpander extends RexExpander {
    private FloorExpander(RexBuilder rexBuilder) {
      super(rexBuilder);
    }

    /**
     * Rewrites a FLOOR call over a decimal operand and encodes the result
     * as the type of the call.
     *
     * @param call a call to {@code SqlStdOperatorTable.FLOOR}
     * @return the rewritten expression
     */
    @Override public RexNode expand(RexCall call) {
      assert call.getOperator() == SqlStdOperatorTable.FLOOR;
      final RexNode operand = call.operands.get(0);
      final int operandScale = operand.getType().getScale();
      final RexNode unscaled = decodeValue(operand);
      final RelDataTypeSystem typeSystem =
          builder.getTypeFactory().getTypeSystem();
      return encodeValue(
          rewriteFloor(operand, unscaled, operandScale, typeSystem),
          call.getType());
    }

    /** Builds the floor expression appropriate for the operand's scale. */
    private RexNode rewriteFloor(RexNode operand, RexNode unscaled,
        int scale, RelDataTypeSystem typeSystem) {
      if (scale == 0) {
        // No fractional digits: the operand is passed through unchanged.
        return operand;
      }
      if (scale == typeSystem.getMaxNumericPrecision()) {
        // NOTE(review): presumably no integer digits remain at this scale,
        // so the result collapses to -1 (negative input) or 0 - confirm.
        return makeCase(
            makeIsNegative(unscaled),
            makeExactLiteral(-1),
            makeExactLiteral(0));
      }
      // For negative inputs, subtract (10^scale - 1) before dividing.
      final RexNode adjustment = makeExactLiteral(1 - powerOfTen(scale));
      final RexNode divisor = makeScaleFactor(scale);
      return makeCase(
          makeIsNegative(unscaled),
          makeDivide(
              makePlus(unscaled, adjustment),
              divisor),
          makeDivide(unscaled, divisor));
    }
  }

  /**
   * Expander that rewrites {@code CEILING(DECIMAL)} expressions.
   * Rewrite is as follows:
   *
   *
   * if (value > 0)
   *     (value + 0.99...) / (10 ^ scale)
   * else
   *     value / (10 ^ scale)
   * 
*/
  private static class CeilExpander extends RexExpander {
    private CeilExpander(RexBuilder rexBuilder) {
      super(rexBuilder);
    }

    /**
     * Rewrites a CEIL call over a decimal operand and encodes the result
     * as the type of the call.
     *
     * @param call a call to {@code SqlStdOperatorTable.CEIL}
     * @return the rewritten expression
     */
    @Override public RexNode expand(RexCall call) {
      assert call.getOperator() == SqlStdOperatorTable.CEIL;
      final RexNode operand = call.operands.get(0);
      final int operandScale = operand.getType().getScale();
      final RexNode unscaled = decodeValue(operand);
      final RelDataTypeSystem typeSystem =
          builder.getTypeFactory().getTypeSystem();
      return encodeValue(
          rewriteCeil(operand, unscaled, operandScale, typeSystem),
          call.getType());
    }

    /** Builds the ceiling expression appropriate for the operand's scale. */
    private RexNode rewriteCeil(RexNode operand, RexNode unscaled,
        int scale, RelDataTypeSystem typeSystem) {
      if (scale == 0) {
        // No fractional digits: the operand is passed through unchanged.
        return operand;
      }
      if (scale == typeSystem.getMaxNumericPrecision()) {
        // NOTE(review): presumably no integer digits remain at this scale,
        // so the result collapses to 1 (positive input) or 0 - confirm.
        return makeCase(
            makeIsPositive(unscaled),
            makeExactLiteral(1),
            makeExactLiteral(0));
      }
      // For positive inputs, add (10^scale - 1) before dividing.
      final RexNode adjustment = makeExactLiteral(powerOfTen(scale) - 1);
      final RexNode divisor = makeScaleFactor(scale);
      return makeCase(
          makeIsPositive(unscaled),
          makeDivide(
              makePlus(unscaled, adjustment),
              divisor),
          makeDivide(unscaled, divisor));
    }
  }

  /**
   * Expander that rewrites case expressions, in place. Starting from:
   *
   *
(when $cond then $val)+ else $default
* *

this expander casts all values to the return type. If the target type is * a decimal, then the values are then decoded. The result of expansion is * that the case operator no longer deals with decimals args. (The return * value is encoded if necessary.) * *

Note: a decimal type is returned iff arguments have decimals. */ private static class CaseExpander extends RexExpander { private CaseExpander(RexBuilder rexBuilder) { super(rexBuilder); } @Override public RexNode expand(RexCall call) { RelDataType retType = call.getType(); int argCount = call.operands.size(); ImmutableList.Builder opBuilder = ImmutableList.builder(); for (int i = 0; i < argCount; i++) { // skip case conditions if (((i % 2) == 0) && (i != (argCount - 1))) { opBuilder.add(call.operands.get(i)); continue; } RexNode expr = ensureType(retType, call.operands.get(i), false); if (SqlTypeUtil.isDecimal(retType)) { expr = decodeValue(expr); } opBuilder.add(expr); } RexNode newCall = builder.makeCall(retType, call.getOperator(), opBuilder.build()); if (SqlTypeUtil.isDecimal(retType)) { newCall = encodeValue(newCall, retType); } return newCall; } } /** * An expander that substitutes decimals with their integer representations. * If the output is decimal, the output is reinterpreted from the integer * representation into a decimal. */ private static class PassThroughExpander extends RexExpander { private PassThroughExpander(RexBuilder builder) { super(builder); } @Override public boolean canExpand(RexCall call) { return RexUtil.requiresDecimalExpansion(call, false); } @Override public RexNode expand(RexCall call) { ImmutableList.Builder opBuilder = ImmutableList.builder(); for (RexNode operand : call.operands) { if (SqlTypeUtil.isNumeric(operand.getType())) { opBuilder.add(accessValue(operand)); } else { opBuilder.add(operand); } } RexNode newCall = builder.makeCall(call.getType(), call.getOperator(), opBuilder.build()); if (SqlTypeUtil.isDecimal(call.getType())) { return encodeValue( newCall, call.getType()); } else { return newCall; } } } /** * Expander that casts DECIMAL arguments as DOUBLE. 
*/
  private static class CastArgAsDoubleExpander extends CastArgAsTypeExpander {
    private CastArgAsDoubleExpander(RexBuilder builder) {
      super(builder);
    }

    /**
     * Returns the double type, made nullable when the operand at
     * {@code ordinal} is nullable.
     */
    @Override public RelDataType getArgType(RexCall call, int ordinal) {
      RelDataType type = real8;
      if (call.operands.get(ordinal).getType().isNullable()) {
        type =
            builder.getTypeFactory().createTypeWithNullability(
                type,
                true);
      }
      return type;
    }
  }

  /**
   * Expander that casts DECIMAL arguments as another type.
   */
  private abstract static class CastArgAsTypeExpander extends RexExpander {
    private CastArgAsTypeExpander(RexBuilder builder) {
      super(builder);
    }

    /**
     * Returns the type that the operand at position {@code ordinal} of
     * {@code call} should be converted to if it is a decimal.
     */
    public abstract RelDataType getArgType(RexCall call, int ordinal);

    /**
     * Rebuilds the call with each decimal operand converted to the type
     * reported by {@link #getArgType}; other operands are kept as they are.
     * The rebuilt call is then converted to the original call's type.
     *
     * @param call the call to rewrite
     * @return the rewritten expression
     */
    @Override public RexNode expand(RexCall call) {
      ImmutableList.Builder<RexNode> opBuilder = ImmutableList.builder();

      for (Ord<RexNode> operand : Ord.zip(call.operands)) {
        RelDataType targetType = getArgType(call, operand.i);
        if (SqlTypeUtil.isDecimal(operand.e.getType())) {
          opBuilder.add(ensureType(targetType, operand.e, true));
        } else {
          opBuilder.add(operand.e);
        }
      }

      RexNode ret =
          builder.makeCall(
              call.getType(),
              call.getOperator(),
              opBuilder.build());
      ret =
          ensureType(
              call.getType(),
              ret,
              true);
      return ret;
    }
  }

  /**
   * An expander that simplifies reinterpret calls.
   *
   *

Consider (1.0+1)*1. The inner
   * operation encodes a decimal (Reinterpret(...)) which the outer operation
   * immediately decodes: (Reinterpret(Reinterpret(...))). Arithmetic overflow
   * is handled by underlying integer operations, so we don't have to consider
   * it. Simply remove the nested Reinterpret.
   */
  private static class ReinterpretExpander extends RexExpander {
    private ReinterpretExpander(RexBuilder builder) {
      super(builder);
    }

    /** Accepts only a REINTERPRET whose first operand is itself a REINTERPRET. */
    @Override public boolean canExpand(RexCall call) {
      return call.isA(SqlKind.REINTERPRET)
          && call.operands.get(0).isA(SqlKind.REINTERPRET);
    }

    /**
     * Returns the innermost value when the two nested reinterprets can be
     * dropped, otherwise returns {@code call} unchanged.
     *
     * @param call outer reinterpret call whose first operand is a call
     * @return the simplified expression, or {@code call} itself
     */
    @Override public RexNode expand(RexCall call) {
      List<RexNode> operands = call.operands;
      RexCall subCall = (RexCall) operands.get(0);
      RexNode innerValue = subCall.operands.get(0);
      if (canSimplify(call, subCall, innerValue)) {
        return innerValue;
      }
      return call;
    }

    /**
     * Detect, in a generic, but strict way, whether it is possible to
     * simplify a reinterpret cast. The rules are as follows:
     *
     *

    *
  1. If value is not the same basic type as outer, then we cannot * simplify *
  2. If the value is nullable but the inner or outer are not, then we * cannot simplify. *
  3. If inner is nullable but outer is not, we cannot simplify. *
  4. If an overflow check is required from either inner or outer, we * cannot simplify. *
  5. Otherwise, given the same type, and sufficient nullability * constraints, we can simplify. *
* * @param outer outer call to reinterpret * @param inner inner call to reinterpret * @param value inner value * @return whether the two reinterpret casts can be removed */ private static boolean canSimplify( RexCall outer, RexCall inner, RexNode value) { RelDataType outerType = outer.getType(); RelDataType innerType = inner.getType(); RelDataType valueType = value.getType(); boolean outerCheck = RexUtil.canReinterpretOverflow(outer); boolean innerCheck = RexUtil.canReinterpretOverflow(inner); if ((outerType.getSqlTypeName() != valueType.getSqlTypeName()) || (outerType.getPrecision() != valueType.getPrecision()) || (outerType.getScale() != valueType.getScale())) { return false; } if (valueType.isNullable() && (!innerType.isNullable() || !outerType.isNullable())) { return false; } if (innerType.isNullable() && !outerType.isNullable()) { return false; } // One would think that we could go from Nullable -> Not Nullable // since we are substituting a general type with a more specific // type. However the optimizer doesn't like it. if (valueType.isNullable() != outerType.isNullable()) { return false; } if (innerCheck || outerCheck) { return false; } return true; } } /** Rule configuration. */ @Value.Immutable public interface Config extends RelRule.Config { Config DEFAULT = ImmutableReduceDecimalsRule.Config.of() .withOperandSupplier(b -> b.operand(LogicalCalc.class).anyInputs()); @Override default ReduceDecimalsRule toRule() { return new ReduceDecimalsRule(this); } } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy