All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.prestosql.sql.planner.iterative.rule.UnwrapCastInComparison Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.OperatorNotFoundException;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.LongTimestampWithTimeZone;
import io.prestosql.spi.type.RealType;
import io.prestosql.spi.type.TimeWithTimeZoneType;
import io.prestosql.spi.type.TimeZoneKey;
import io.prestosql.spi.type.TimestampType;
import io.prestosql.spi.type.TimestampWithTimeZoneType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeOperators;
import io.prestosql.sql.InterpretedFunctionInvoker;
import io.prestosql.sql.planner.ExpressionInterpreter;
import io.prestosql.sql.planner.LiteralEncoder;
import io.prestosql.sql.planner.NoOpSymbolResolver;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.IsNotNullPredicate;
import io.prestosql.sql.tree.IsNullPredicate;
import io.prestosql.sql.tree.NullLiteral;
import io.prestosql.type.TypeCoercion;

import java.lang.invoke.MethodHandle;
import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.time.zone.ZoneOffsetTransition;
import java.util.Optional;

import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.spi.type.BooleanType.BOOLEAN;
import static io.prestosql.spi.type.DateTimeEncoding.packDateTimeWithZone;
import static io.prestosql.spi.type.DateTimeEncoding.unpackMillisUtc;
import static io.prestosql.spi.type.DateTimeEncoding.unpackZoneKey;
import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static io.prestosql.spi.type.IntegerType.INTEGER;
import static io.prestosql.spi.type.RealType.REAL;
import static io.prestosql.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND;
import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN;
import static io.prestosql.sql.ExpressionUtils.and;
import static io.prestosql.sql.ExpressionUtils.or;
import static io.prestosql.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.EQUAL;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.GREATER_THAN;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL;
import static java.lang.Float.intBitsToFloat;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

/**
 * Given s of type S, a constant expression t of type T, and when an implicit
 * cast exists between S->T, converts expression of the form:
 *
 * 
 * CAST(s as T) = t
 * 
*

* into * *

 * s = CAST(t as S)
 * 
*

* For example: * *

 * CAST(x AS bigint) = bigint '1'
 * 
*

* turns into * *

 * x = smallint '1'
 * 
*

* It can simplify expressions that are known to be true or false, and * remove the comparisons altogether. For example, give x::smallint, * for an expression like: * *

 * CAST(x AS bigint) > bigint '10000000'
 * 
*/ public class UnwrapCastInComparison extends ExpressionRewriteRuleSet { public UnwrapCastInComparison(Metadata metadata, TypeOperators typeOperators, TypeAnalyzer typeAnalyzer) { super(createRewrite(metadata, typeOperators, typeAnalyzer)); } private static ExpressionRewriter createRewrite(Metadata metadata, TypeOperators typeOperators, TypeAnalyzer typeAnalyzer) { requireNonNull(metadata, "metadata is null"); requireNonNull(typeAnalyzer, "typeAnalyzer is null"); return (expression, context) -> unwrapCasts(context.getSession(), metadata, typeOperators, typeAnalyzer, context.getSymbolAllocator().getTypes(), expression); } public static Expression unwrapCasts(Session session, Metadata metadata, TypeOperators typeOperators, TypeAnalyzer typeAnalyzer, TypeProvider types, Expression expression) { if (SystemSessionProperties.isUnwrapCasts(session)) { return ExpressionTreeRewriter.rewriteWith(new Visitor(metadata, typeOperators, typeAnalyzer, session, types), expression); } return expression; } private static class Visitor extends io.prestosql.sql.tree.ExpressionRewriter { private final Metadata metadata; private final TypeOperators typeOperators; private final TypeAnalyzer typeAnalyzer; private final Session session; private final TypeProvider types; private final InterpretedFunctionInvoker functionInvoker; private final LiteralEncoder literalEncoder; public Visitor(Metadata metadata, TypeOperators typeOperators, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types) { this.metadata = requireNonNull(metadata, "metadata is null"); this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); this.functionInvoker = new InterpretedFunctionInvoker(metadata); this.literalEncoder = new LiteralEncoder(metadata); } @Override public Expression rewriteComparisonExpression(ComparisonExpression node, Void context, ExpressionTreeRewriter treeRewriter) { ComparisonExpression expression = (ComparisonExpression) treeRewriter.defaultRewrite((Expression) node, null); return unwrapCast(expression); } private Expression unwrapCast(ComparisonExpression expression) { // Canonicalization is handled by CanonicalizeExpressionRewriter if (!(expression.getLeft() instanceof Cast)) { return expression; } Object right = ExpressionInterpreter.expressionOptimizer(expression.getRight(), metadata, session, typeAnalyzer.getTypes(session, types, expression.getRight())) .optimize(NoOpSymbolResolver.INSTANCE); Cast cast = (Cast) expression.getLeft(); ComparisonExpression.Operator operator = expression.getOperator(); if (right == null || right instanceof NullLiteral) { switch (operator) { case EQUAL: case NOT_EQUAL: case LESS_THAN: case LESS_THAN_OR_EQUAL: case GREATER_THAN: case GREATER_THAN_OR_EQUAL: return new Cast(new NullLiteral(), toSqlType(BOOLEAN)); case IS_DISTINCT_FROM: return new IsNotNullPredicate(cast); default: throw new UnsupportedOperationException("Not yet implemented"); } } if (right instanceof Expression) { return expression; } Type sourceType = typeAnalyzer.getType(session, types, cast.getExpression()); Type targetType = typeAnalyzer.getType(session, types, expression.getRight()); if (targetType instanceof TimestampWithTimeZoneType) { // Note: two TIMESTAMP WITH TIME ZONE values differing in zone only (same instant) are considered equal. right = withTimeZone(((TimestampWithTimeZoneType) targetType), right, session.getTimeZoneKey()); } if (!hasInjectiveImplicitCoercion(sourceType, targetType, right)) { return expression; } // Handle comparison against NaN. // It must be done before source type range bounds are compared to target value. if (isFloatingPointNaN(targetType, right)) { switch (operator) { case EQUAL: case GREATER_THAN: case GREATER_THAN_OR_EQUAL: case LESS_THAN: case LESS_THAN_OR_EQUAL: return falseIfNotNull(cast.getExpression()); case NOT_EQUAL: return trueIfNotNull(cast.getExpression()); case IS_DISTINCT_FROM: if (!typeHasNaN(sourceType)) { return TRUE_LITERAL; } else { // NaN on the right of comparison will be cast to source type later break; } default: throw new UnsupportedOperationException("Not yet implemented: " + operator); } } ResolvedFunction sourceToTarget = metadata.getCoercion(sourceType, targetType); Optional sourceRange = sourceType.getRange(); if (sourceRange.isPresent()) { Object max = sourceRange.get().getMax(); Object maxInTargetType = coerce(max, sourceToTarget); // NaN values of `right` are excluded at this point. Otherwise, NaN would be recognized as // greater than source type upper bound, and incorrect expression might be derived. int upperBoundComparison = compare(targetType, right, maxInTargetType); if (upperBoundComparison > 0) { // larger than maximum representable value switch (operator) { case EQUAL: case GREATER_THAN: case GREATER_THAN_OR_EQUAL: return falseIfNotNull(cast.getExpression()); case NOT_EQUAL: case LESS_THAN: case LESS_THAN_OR_EQUAL: return trueIfNotNull(cast.getExpression()); case IS_DISTINCT_FROM: return TRUE_LITERAL; default: throw new UnsupportedOperationException("Not yet implemented: " + operator); } } if (upperBoundComparison == 0) { // equal to max representable value switch (operator) { case GREATER_THAN: return falseIfNotNull(cast.getExpression()); case GREATER_THAN_OR_EQUAL: return new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(max, sourceType)); case LESS_THAN_OR_EQUAL: return trueIfNotNull(cast.getExpression()); case LESS_THAN: return new ComparisonExpression(NOT_EQUAL, cast.getExpression(), literalEncoder.toExpression(max, sourceType)); case EQUAL: case NOT_EQUAL: case IS_DISTINCT_FROM: return new ComparisonExpression(operator, cast.getExpression(), literalEncoder.toExpression(max, sourceType)); default: throw new UnsupportedOperationException("Not yet implemented: " + operator); } } Object min = sourceRange.get().getMin(); Object minInTargetType = coerce(min, sourceToTarget); int lowerBoundComparison = compare(targetType, right, minInTargetType); if (lowerBoundComparison < 0) { // smaller than minimum representable value switch (operator) { case NOT_EQUAL: case GREATER_THAN: case GREATER_THAN_OR_EQUAL: return trueIfNotNull(cast.getExpression()); case EQUAL: case LESS_THAN: case LESS_THAN_OR_EQUAL: return falseIfNotNull(cast.getExpression()); case IS_DISTINCT_FROM: return TRUE_LITERAL; default: throw new UnsupportedOperationException("Not yet implemented: " + operator); } } if (lowerBoundComparison == 0) { // equal to min representable value switch (operator) { case LESS_THAN: return falseIfNotNull(cast.getExpression()); case LESS_THAN_OR_EQUAL: return new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(min, sourceType)); case GREATER_THAN_OR_EQUAL: return trueIfNotNull(cast.getExpression()); case GREATER_THAN: return new ComparisonExpression(NOT_EQUAL, cast.getExpression(), literalEncoder.toExpression(min, sourceType)); case EQUAL: case NOT_EQUAL: case IS_DISTINCT_FROM: return new ComparisonExpression(operator, cast.getExpression(), literalEncoder.toExpression(min, sourceType)); default: throw new UnsupportedOperationException("Not yet implemented: " + operator); } } } ResolvedFunction targetToSource; try { targetToSource = metadata.getCoercion(targetType, sourceType); } catch (OperatorNotFoundException e) { // Without a cast between target -> source, there's nothing more we can do return expression; } Object literalInSourceType; try { literalInSourceType = coerce(right, targetToSource); } catch (PrestoException e) { // A failure to cast from target -> source type could be because: // 1. missing cast // 2. bad implementation // 3. out of range or otherwise unrepresentable value // Since we can't distinguish between those cases, take the conservative option // and bail out. return expression; } Object roundtripLiteral = coerce(literalInSourceType, sourceToTarget); int literalVsRoundtripped = compare(targetType, right, roundtripLiteral); if (literalVsRoundtripped > 0) { // cast rounded down switch (operator) { case EQUAL: return falseIfNotNull(cast.getExpression()); case NOT_EQUAL: return trueIfNotNull(cast.getExpression()); case IS_DISTINCT_FROM: return TRUE_LITERAL; case LESS_THAN: case LESS_THAN_OR_EQUAL: if (sourceRange.isPresent() && compare(sourceType, sourceRange.get().getMin(), literalInSourceType) == 0) { return new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(literalInSourceType, sourceType)); } return new ComparisonExpression(LESS_THAN_OR_EQUAL, cast.getExpression(), literalEncoder.toExpression(literalInSourceType, sourceType)); case GREATER_THAN: case GREATER_THAN_OR_EQUAL: // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot produce a value // larger than the next value in the source type return new ComparisonExpression(GREATER_THAN, cast.getExpression(), literalEncoder.toExpression(literalInSourceType, sourceType)); default: throw new UnsupportedOperationException("Not yet implemented: " + operator); } } if (literalVsRoundtripped < 0) { // cast rounded up switch (operator) { case EQUAL: return falseIfNotNull(cast.getExpression()); case NOT_EQUAL: return trueIfNotNull(cast.getExpression()); case IS_DISTINCT_FROM: return TRUE_LITERAL; case LESS_THAN: case LESS_THAN_OR_EQUAL: // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot produce a value // smaller than the next value in the source type return new ComparisonExpression(LESS_THAN, cast.getExpression(), literalEncoder.toExpression(literalInSourceType, sourceType)); case GREATER_THAN: case GREATER_THAN_OR_EQUAL: if (sourceRange.isPresent() && compare(sourceType, sourceRange.get().getMax(), literalInSourceType) == 0) { return new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(literalInSourceType, sourceType)); } return new ComparisonExpression(GREATER_THAN_OR_EQUAL, cast.getExpression(), literalEncoder.toExpression(literalInSourceType, sourceType)); default: throw new UnsupportedOperationException("Not yet implemented: " + operator); } } return new ComparisonExpression(operator, cast.getExpression(), literalEncoder.toExpression(literalInSourceType, sourceType)); } private boolean hasInjectiveImplicitCoercion(Type source, Type target, Object value) { if ((source.equals(BIGINT) && target.equals(DOUBLE)) || (source.equals(BIGINT) && target.equals(REAL)) || (source.equals(INTEGER) && target.equals(REAL))) { // Not every BIGINT fits in DOUBLE/REAL due to 64 bit vs 53-bit/23-bit mantissa. Similarly, // not every INTEGER fits in a REAL (32-bit vs 23-bit mantissa) if (target.equals(DOUBLE)) { double doubleValue = (double) value; return doubleValue > Long.MAX_VALUE || doubleValue < Long.MIN_VALUE || Double.isNaN(doubleValue) || (doubleValue > -1L << 53 && doubleValue < 1L << 53); // in (-2^53, 2^53), bigint follows an injective implicit coercion w.r.t double } else { float realValue = intBitsToFloat(toIntExact((long) value)); return (source.equals(BIGINT) && (realValue > Long.MAX_VALUE || realValue < Long.MIN_VALUE)) || (source.equals(INTEGER) && (realValue > Integer.MAX_VALUE || realValue < Integer.MIN_VALUE)) || Float.isNaN(realValue) || (realValue > -1L << 23 && realValue < 1L << 23); // in (-2^23, 2^23), bigint (and integer) follows an injective implicit coercion w.r.t real } } if (source instanceof DecimalType) { int precision = ((DecimalType) source).getPrecision(); if (precision > 15 && target.equals(DOUBLE)) { // decimal(p,s) with p > 15 doesn't fit in a double without loss return false; } if (precision > 7 && target.equals(REAL)) { // decimal(p,s) with p > 7 doesn't fit in a double without loss return false; } } if (target instanceof TimestampWithTimeZoneType) { TimestampWithTimeZoneType timestampWithTimeZoneType = (TimestampWithTimeZoneType) target; if (source instanceof TimestampType) { // Cast from TIMESTAMP WITH TIME ZONE to TIMESTAMP and back to TIMESTAMP WITH TIME ZONE does not round trip, unless the value's zone is equal to sesion zone if (!getTimeZone(timestampWithTimeZoneType, value).equals(session.getTimeZoneKey())) { return false; } // Cast from TIMESTAMP to TIMESTAMP WITH TIME ZONE is not monotonic when there is a forward DST change in the session zone if (!isTimestampToTimestampWithTimeZoneInjectiveAt(session.getTimeZoneKey().getZoneId(), getInstantWithTruncation(timestampWithTimeZoneType, value))) { return false; } return true; } // CAST from TIMESTAMP WITH TIME ZONE to d and back to TIMESTAMP WITH TIME ZONE does not round trip for most types d // TODO add test coverage // TODO (https://github.com/prestosql/presto/issues/5798) handle DATE -> TIMESTAMP WITH TIME ZONE return false; } if (target instanceof TimeWithTimeZoneType) { // For example, CAST from TIME WITH TIME ZONE to TIME and back to TIME WITH TIME ZONE does not round trip // TODO add test coverage return false; } // Well-behaved implicit casts are injective return new TypeCoercion(metadata::getType).canCoerce(source, target); } private Object coerce(Object value, ResolvedFunction coercion) { return functionInvoker.invoke(coercion, session.toConnectorSession(), value); } private boolean typeHasNaN(Type type) { return type instanceof DoubleType || type instanceof RealType; } private int compare(Type type, Object first, Object second) { requireNonNull(first, "first is null"); requireNonNull(second, "second is null"); MethodHandle comparisonOperator = typeOperators.getComparisonOperator(type, InvocationConvention.simpleConvention(FAIL_ON_NULL, NEVER_NULL, NEVER_NULL)); try { return (int) (long) comparisonOperator.invoke(first, second); } catch (Throwable throwable) { Throwables.throwIfUnchecked(throwable); throw new PrestoException(GENERIC_INTERNAL_ERROR, throwable); } } } /** * Replace time zone component of a {@link TimestampWithTimeZoneType} value with a given one, preserving point in time * (equivalent to {@link java.time.ZonedDateTime#withZoneSameInstant}. */ private static Object withTimeZone(TimestampWithTimeZoneType type, Object value, TimeZoneKey newZone) { if (type.isShort()) { return packDateTimeWithZone(unpackMillisUtc((long) value), newZone); } LongTimestampWithTimeZone longTimestampWithTimeZone = (LongTimestampWithTimeZone) value; return LongTimestampWithTimeZone.fromEpochMillisAndFraction(longTimestampWithTimeZone.getEpochMillis(), longTimestampWithTimeZone.getPicosOfMilli(), newZone); } private static TimeZoneKey getTimeZone(TimestampWithTimeZoneType type, Object value) { if (type.isShort()) { return unpackZoneKey(((long) value)); } return TimeZoneKey.getTimeZoneKey(((LongTimestampWithTimeZone) value).getTimeZoneKey()); } @VisibleForTesting static boolean isTimestampToTimestampWithTimeZoneInjectiveAt(ZoneId zone, Instant instant) { ZoneOffsetTransition transition = zone.getRules().previousTransition(instant.plusNanos(1)); if (transition != null) { // DST change forward and the instant is ambiguous, being within the 'gap' area non-monotonic remapping if (!transition.getDuration().isNegative() && !transition.getDateTimeAfter().minusNanos(1).atZone(zone).toInstant().isBefore(instant)) { return false; } } return true; } private static Instant getInstantWithTruncation(TimestampWithTimeZoneType type, Object value) { if (type.isShort()) { return Instant.ofEpochMilli(unpackMillisUtc(((long) value))); } LongTimestampWithTimeZone longTimestampWithTimeZone = (LongTimestampWithTimeZone) value; return Instant.ofEpochMilli(longTimestampWithTimeZone.getEpochMillis()) .plus(longTimestampWithTimeZone.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND, ChronoUnit.NANOS); } private static Expression falseIfNotNull(Expression argument) { return and(new IsNullPredicate(argument), new NullLiteral()); } private static Expression trueIfNotNull(Expression argument) { return or(new IsNotNullPredicate(argument), new NullLiteral()); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy