Please wait. This may take a few minutes...
Downloading a project requires considerable resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
io.trino.cost.ComparisonStatsCalculator Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.cost;
import io.trino.sql.planner.Symbol;
import io.trino.sql.tree.ComparisonExpression;
import java.util.Optional;
import java.util.OptionalDouble;
import static io.trino.cost.SymbolStatsEstimate.buildFrom;
import static io.trino.util.MoreMath.averageExcludingNaNs;
import static io.trino.util.MoreMath.max;
import static io.trino.util.MoreMath.maxExcludeNaN;
import static io.trino.util.MoreMath.min;
import static io.trino.util.MoreMath.minExcludeNaN;
import static java.lang.Double.NEGATIVE_INFINITY;
import static java.lang.Double.NaN;
import static java.lang.Double.POSITIVE_INFINITY;
import static java.lang.Double.isFinite;
import static java.lang.Double.isNaN;
/**
 * Heuristic selectivity estimation for a single comparison predicate, either
 * {@code expression <op> literal} or {@code expression <op> expression}.
 * <p>
 * Estimates are derived from the input {@link PlanNodeStatsEstimate} and the
 * {@link SymbolStatsEstimate} of the compared expressions (value range, distinct-values count,
 * nulls fraction, average row size). In every estimated path the output row count is scaled down
 * by the non-null fraction of the compared expressions, and any symbol statistics written to the
 * result carry a nulls fraction of zero.
 */
public final class ComparisonStatsCalculator
{
    // We assume uniform distribution of values within each range.
    // Within the overlapping range, we assume that all pairs of distinct values from both ranges exist.
    // Based on the above, we estimate that half of the pairs of values will match inequality predicate on average.
    public static final double OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT = 0.5;

    private ComparisonStatsCalculator() {}

    /**
     * Estimates the output statistics of {@code expression <operator> literal}.
     *
     * @param inputStatistics statistics of the rows the predicate is applied to
     * @param expressionStatistics statistics of the compared expression
     * @param expressionSymbol symbol whose statistics are replaced in the result, if any
     * @param literalValue the literal as a double; when empty, the corresponding bound is treated as
     *         unbounded (an infinite range is used in its place)
     * @param operator the comparison operator
     * @return the estimated statistics, or {@link PlanNodeStatsEstimate#unknown()} for
     *         {@code IS_DISTINCT_FROM}
     * @throws IllegalArgumentException if {@code operator} is not one of the handled operators
     */
    public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison(
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate expressionStatistics,
            Optional<Symbol> expressionSymbol,
            OptionalDouble literalValue,
            ComparisonExpression.Operator operator)
    {
        switch (operator) {
            case EQUAL:
                return estimateExpressionEqualToLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue);
            case NOT_EQUAL:
                return estimateExpressionNotEqualToLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                return estimateExpressionLessThanLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue);
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
                return estimateExpressionGreaterThanLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue);
            case IS_DISTINCT_FROM:
                return PlanNodeStatsEstimate.unknown();
        }
        throw new IllegalArgumentException("Unexpected comparison operator: " + operator);
    }

    /**
     * Builds the range used for (in)equality against a literal.
     * <p>
     * When the literal is known, the result is the single point {@code [literal, literal]}.
     * Otherwise it is the unbounded range. Both carry a distinct-values count of 1.
     */
    private static StatisticRange singleValueRange(OptionalDouble literalValue)
    {
        if (literalValue.isPresent()) {
            double value = literalValue.getAsDouble();
            return new StatisticRange(value, value, 1);
        }
        return new StatisticRange(NEGATIVE_INFINITY, POSITIVE_INFINITY, 1);
    }

    /**
     * Estimates {@code expression = literal} by intersecting the expression range with the literal's
     * single-value range.
     */
    private static PlanNodeStatsEstimate estimateExpressionEqualToLiteral(
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate expressionStatistics,
            Optional<Symbol> expressionSymbol,
            OptionalDouble literalValue)
    {
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionSymbol, singleValueRange(literalValue));
    }

    /**
     * Estimates {@code expression <> literal} as the complement of the equality selectivity, applied
     * to the non-null rows.
     * <p>
     * When a symbol is given, its statistics keep everything from the input except that the nulls
     * fraction becomes 0 and the distinct-values count drops by one (floored at 0).
     */
    private static PlanNodeStatsEstimate estimateExpressionNotEqualToLiteral(
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate expressionStatistics,
            Optional<Symbol> expressionSymbol,
            OptionalDouble literalValue)
    {
        StatisticRange expressionRange = StatisticRange.from(expressionStatistics);
        StatisticRange filterRange = singleValueRange(literalValue);
        StatisticRange intersectRange = expressionRange.intersect(filterRange);
        // fraction of the expression range NOT covered by the literal
        double filterFactor = 1 - expressionRange.overlapPercentWith(intersectRange);

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics);
        estimate.setOutputRowCount(filterFactor * (1 - expressionStatistics.getNullsFraction()) * inputStatistics.getOutputRowCount());
        expressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(
                symbol,
                buildFrom(expressionStatistics)
                        .setNullsFraction(0.0)
                        // the literal's value is assumed to be one of the distinct values removed
                        .setDistinctValuesCount(max(expressionStatistics.getDistinctValuesCount() - 1, 0))
                        .build()));
        return estimate.build();
    }

    /**
     * Estimates {@code expression < literal} and {@code expression <= literal} using the range
     * {@code (-inf, literal]}, or an unbounded range when the literal is unknown.
     */
    private static PlanNodeStatsEstimate estimateExpressionLessThanLiteral(
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate expressionStatistics,
            Optional<Symbol> expressionSymbol,
            OptionalDouble literalValue)
    {
        StatisticRange filterRange = new StatisticRange(NEGATIVE_INFINITY, literalValue.orElse(POSITIVE_INFINITY), NaN);
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionSymbol, filterRange);
    }

    /**
     * Estimates {@code expression > literal} and {@code expression >= literal} using the range
     * {@code [literal, +inf)}, or an unbounded range when the literal is unknown.
     */
    private static PlanNodeStatsEstimate estimateExpressionGreaterThanLiteral(
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate expressionStatistics,
            Optional<Symbol> expressionSymbol,
            OptionalDouble literalValue)
    {
        StatisticRange filterRange = new StatisticRange(literalValue.orElse(NEGATIVE_INFINITY), POSITIVE_INFINITY, NaN);
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionSymbol, filterRange);
    }

    /**
     * Shared estimation for range-shaped literal predicates.
     * <p>
     * The output row count is scaled by two factors: the fraction of the expression range that
     * overlaps {@code filterRange}, and the expression's non-null fraction. When a symbol is given,
     * its statistics are replaced by new statistics built from the intersected range. Those keep
     * only the original average row size and use a nulls fraction of 0.
     */
    private static PlanNodeStatsEstimate estimateFilterRange(
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate expressionStatistics,
            Optional<Symbol> expressionSymbol,
            StatisticRange filterRange)
    {
        StatisticRange expressionRange = StatisticRange.from(expressionStatistics);
        StatisticRange intersectRange = expressionRange.intersect(filterRange);

        double filterFactor = expressionRange.overlapPercentWith(intersectRange);

        PlanNodeStatsEstimate estimate = inputStatistics.mapOutputRowCount(rowCount -> filterFactor * (1 - expressionStatistics.getNullsFraction()) * rowCount);
        if (expressionSymbol.isPresent()) {
            SymbolStatsEstimate symbolNewEstimate =
                    SymbolStatsEstimate.builder()
                            .setAverageRowSize(expressionStatistics.getAverageRowSize())
                            .setStatisticsRange(intersectRange)
                            .setNullsFraction(0.0)
                            .build();
            estimate = estimate.mapSymbolColumnStatistics(expressionSymbol.get(), oldStats -> symbolNewEstimate);
        }
        return estimate;
    }

    /**
     * Estimates the output statistics of {@code leftExpression <operator> rightExpression}.
     *
     * @param inputStatistics statistics of the rows the predicate is applied to
     * @param leftExpressionStatistics statistics of the left operand
     * @param leftExpressionSymbol symbol whose statistics are updated for the left operand, if any
     * @param rightExpressionStatistics statistics of the right operand
     * @param rightExpressionSymbol symbol whose statistics are updated for the right operand, if any
     * @param operator the comparison operator
     * @return the estimated statistics, or {@link PlanNodeStatsEstimate#unknown()} for
     *         {@code IS_DISTINCT_FROM} or when the inputs are insufficient
     * @throws IllegalArgumentException if {@code operator} is not one of the handled operators
     */
    public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison(
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate leftExpressionStatistics,
            Optional<Symbol> leftExpressionSymbol,
            SymbolStatsEstimate rightExpressionStatistics,
            Optional<Symbol> rightExpressionSymbol,
            ComparisonExpression.Operator operator)
    {
        switch (operator) {
            case EQUAL:
                return estimateExpressionEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionSymbol, rightExpressionStatistics, rightExpressionSymbol);
            case NOT_EQUAL:
                return estimateExpressionNotEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionSymbol, rightExpressionStatistics, rightExpressionSymbol);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
                return estimateExpressionToExpressionInequality(
                        operator,
                        inputStatistics,
                        leftExpressionStatistics,
                        leftExpressionSymbol,
                        rightExpressionStatistics,
                        rightExpressionSymbol);
            case IS_DISTINCT_FROM:
                return PlanNodeStatsEstimate.unknown();
        }
        throw new IllegalArgumentException("Unexpected comparison operator: " + operator);
    }

    /**
     * Estimates {@code left = right}.
     * <p>
     * Returns unknown if either distinct-values count is NaN. Otherwise the selectivity is
     * {@code 1 / max(leftNdv, rightNdv, 1)}, applied to rows where both sides are non-null (the two
     * non-null fractions are multiplied).
     * <p>
     * Both symbols, when present, receive the same statistics:
     * <ul>
     *   <li>the intersection of the two ranges;</li>
     *   <li>the smaller distinct-values count;</li>
     *   <li>nulls fraction 0;</li>
     *   <li>the average of the two average row sizes, ignoring NaN.</li>
     * </ul>
     */
    private static PlanNodeStatsEstimate estimateExpressionEqualToExpression(
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate leftExpressionStatistics,
            Optional<Symbol> leftExpressionSymbol,
            SymbolStatsEstimate rightExpressionStatistics,
            Optional<Symbol> rightExpressionSymbol)
    {
        if (isNaN(leftExpressionStatistics.getDistinctValuesCount()) || isNaN(rightExpressionStatistics.getDistinctValuesCount())) {
            return PlanNodeStatsEstimate.unknown();
        }

        StatisticRange leftExpressionRange = StatisticRange.from(leftExpressionStatistics);
        StatisticRange rightExpressionRange = StatisticRange.from(rightExpressionStatistics);

        StatisticRange intersect = leftExpressionRange.intersect(rightExpressionRange);

        double nullsFilterFactor = (1 - leftExpressionStatistics.getNullsFraction()) * (1 - rightExpressionStatistics.getNullsFraction());
        double leftNdv = leftExpressionRange.getDistinctValuesCount();
        double rightNdv = rightExpressionRange.getDistinctValuesCount();
        double filterFactor = 1.0 / max(leftNdv, rightNdv, 1);
        double retainedNdv = min(leftNdv, rightNdv);

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics)
                .setOutputRowCount(inputStatistics.getOutputRowCount() * nullsFilterFactor * filterFactor);

        SymbolStatsEstimate equalityStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(averageExcludingNaNs(leftExpressionStatistics.getAverageRowSize(), rightExpressionStatistics.getAverageRowSize()))
                .setNullsFraction(0)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();

        leftExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(symbol, equalityStats));
        rightExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(symbol, equalityStats));

        return estimate.build();
    }

    /**
     * Estimates {@code left <> right} as the complement of the equality estimate, computed over the
     * rows where both sides are non-null.
     * <p>
     * Symbol statistics are the inputs with nulls fraction set to 0.
     * <p>
     * NOTE(review): NULLs are treated here as independent (product of the non-null fractions),
     * whereas the inequality path assumes fully correlated NULLs (max of the nulls fractions).
     */
    private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression(
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate leftExpressionStatistics,
            Optional<Symbol> leftExpressionSymbol,
            SymbolStatsEstimate rightExpressionStatistics,
            Optional<Symbol> rightExpressionSymbol)
    {
        double nullsFilterFactor = (1 - leftExpressionStatistics.getNullsFraction()) * (1 - rightExpressionStatistics.getNullsFraction());
        PlanNodeStatsEstimate inputNullsFiltered = inputStatistics.mapOutputRowCount(size -> size * nullsFilterFactor);
        SymbolStatsEstimate leftNullsFiltered = leftExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0);
        SymbolStatsEstimate rightNullsFiltered = rightExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0);
        PlanNodeStatsEstimate equalityStats = estimateExpressionEqualToExpression(
                inputNullsFiltered,
                leftNullsFiltered,
                leftExpressionSymbol,
                rightNullsFiltered,
                rightExpressionSymbol);
        if (equalityStats.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }

        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(inputNullsFiltered);
        double equalityFilterFactor = equalityStats.getOutputRowCount() / inputNullsFiltered.getOutputRowCount();
        if (!isFinite(equalityFilterFactor)) {
            // e.g. zero non-null input rows (0 / 0): treat as "nothing is equal"
            equalityFilterFactor = 0.0;
        }
        result.setOutputRowCount(inputNullsFiltered.getOutputRowCount() * (1 - equalityFilterFactor));
        leftExpressionSymbol.ifPresent(symbol -> result.addSymbolStatistics(symbol, leftNullsFiltered));
        rightExpressionSymbol.ifPresent(symbol -> result.addSymbolStatistics(symbol, rightNullsFiltered));
        return result.build();
    }

    /**
     * Estimates {@code <}, {@code <=}, {@code >} and {@code >=} between two expressions.
     * <p>
     * Returns unknown if either side's statistics are unknown, or if both nulls fractions are NaN.
     * Returns zero rows if either side's range is empty. Greater-than variants are evaluated as
     * less-than-or-equal with the operands swapped.
     *
     * @throws IllegalArgumentException if {@code operator} is not an inequality
     */
    private static PlanNodeStatsEstimate estimateExpressionToExpressionInequality(
            ComparisonExpression.Operator operator,
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate leftExpressionStatistics,
            Optional<Symbol> leftExpressionSymbol,
            SymbolStatsEstimate rightExpressionStatistics,
            Optional<Symbol> rightExpressionSymbol)
    {
        if (leftExpressionStatistics.isUnknown() || rightExpressionStatistics.isUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }
        if (isNaN(leftExpressionStatistics.getNullsFraction()) && isNaN(rightExpressionStatistics.getNullsFraction())) {
            return PlanNodeStatsEstimate.unknown();
        }
        if (leftExpressionStatistics.statisticRange().isEmpty() || rightExpressionStatistics.statisticRange().isEmpty()) {
            return inputStatistics.mapOutputRowCount(rowCount -> 0.0);
        }
        // We don't know the correlation between NULLs, so we take the max nullsFraction from the expression statistics
        // to make a conservative estimate (nulls are fully correlated) for the NULLs filter factor
        double nullsFilterFactor = 1 - maxExcludeNaN(leftExpressionStatistics.getNullsFraction(), rightExpressionStatistics.getNullsFraction());
        switch (operator) {
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                return estimateExpressionLessThanOrEqualToExpression(
                        inputStatistics,
                        leftExpressionStatistics,
                        leftExpressionSymbol,
                        rightExpressionStatistics,
                        rightExpressionSymbol,
                        nullsFilterFactor);
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
                // a > b  <=>  b < a : swap operands
                return estimateExpressionLessThanOrEqualToExpression(
                        inputStatistics,
                        rightExpressionStatistics,
                        rightExpressionSymbol,
                        leftExpressionStatistics,
                        leftExpressionSymbol,
                        nullsFilterFactor);
            default:
                throw new IllegalArgumentException("Unsupported inequality operator " + operator);
        }
    }

    /**
     * Estimates {@code left <= right}, assuming values are uniformly distributed within each range.
     * It is used for both {@code <} and {@code <=}.
     * <ul>
     *   <li>If the left range lies entirely above the right range, no rows match.</li>
     *   <li>If the left range lies entirely below the right range, every non-null row matches.</li>
     *   <li>Otherwise the filter factor is the sum of three parts:
     *     <ul>
     *       <li>left values below the right range (they match any right value);</li>
     *       <li>{@link #OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT} of the pairs where both
     *           values fall in the overlapping range;</li>
     *       <li>pairs where the left value is in the overlap and the right value is above the left
     *           range.</li>
     *     </ul>
     *   </li>
     * </ul>
     * In the partial-overlap case, the left symbol's high bound is capped at the right range's
     * high bound, and the right symbol's low bound is raised to the left range's low bound.
     *
     * @param nullsFilterFactor fraction of rows retained after removing NULLs, as computed by the caller
     */
    private static PlanNodeStatsEstimate estimateExpressionLessThanOrEqualToExpression(
            PlanNodeStatsEstimate inputStatistics,
            SymbolStatsEstimate leftExpressionStatistics,
            Optional<Symbol> leftExpressionSymbol,
            SymbolStatsEstimate rightExpressionStatistics,
            Optional<Symbol> rightExpressionSymbol,
            double nullsFilterFactor)
    {
        StatisticRange leftRange = StatisticRange.from(leftExpressionStatistics);
        StatisticRange rightRange = StatisticRange.from(rightExpressionStatistics);
        // left is always greater than right, no overlap
        if (leftRange.getLow() > rightRange.getHigh()) {
            return inputStatistics.mapOutputRowCount(rowCount -> 0.0);
        }
        // left is always lesser than right
        if (leftRange.getHigh() < rightRange.getLow()) {
            PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics);
            leftExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(
                    symbol,
                    leftExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0)));
            rightExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(
                    symbol,
                    rightExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0)));
            return estimate.setOutputRowCount(inputStatistics.getOutputRowCount() * nullsFilterFactor)
                    .build();
        }

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics);

        // fraction of the left range that overlaps the right range
        double leftOverlappingRangeFraction = leftRange.overlapPercentWith(rightRange);
        // fraction of the left range strictly below the right range's low bound
        double leftAlwaysLessRangeFraction;
        if (leftRange.getLow() < rightRange.getLow()) {
            leftAlwaysLessRangeFraction = min(
                    leftRange.overlapPercentWith(new StatisticRange(leftRange.getLow(), rightRange.getLow(), NaN)),
                    // Prevents expanding NDVs in case range fractions addition goes beyond 1 for infinite ranges
                    1 - leftOverlappingRangeFraction);
        }
        else {
            leftAlwaysLessRangeFraction = 0;
        }
        leftExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(
                symbol,
                SymbolStatsEstimate.builder()
                        .setLowValue(leftRange.getLow())
                        .setHighValue(minExcludeNaN(leftRange.getHigh(), rightRange.getHigh()))
                        .setAverageRowSize(leftExpressionStatistics.getAverageRowSize())
                        .setDistinctValuesCount(leftExpressionStatistics.getDistinctValuesCount() * (leftAlwaysLessRangeFraction + leftOverlappingRangeFraction))
                        .setNullsFraction(0)
                        .build()));

        // fraction of the right range that overlaps the left range
        double rightOverlappingRangeFraction = rightRange.overlapPercentWith(leftRange);
        // fraction of the right range strictly above the left range's high bound
        double rightAlwaysGreaterRangeFraction;
        if (leftRange.getHigh() < rightRange.getHigh()) {
            rightAlwaysGreaterRangeFraction = min(
                    rightRange.overlapPercentWith(new StatisticRange(leftRange.getHigh(), rightRange.getHigh(), NaN)),
                    // Prevents expanding NDVs in case range fractions addition goes beyond 1 for infinite ranges
                    1 - rightOverlappingRangeFraction);
        }
        else {
            rightAlwaysGreaterRangeFraction = 0;
        }
        rightExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(
                symbol,
                SymbolStatsEstimate.builder()
                        .setLowValue(maxExcludeNaN(leftRange.getLow(), rightRange.getLow()))
                        .setHighValue(rightRange.getHigh())
                        .setAverageRowSize(rightExpressionStatistics.getAverageRowSize())
                        .setDistinctValuesCount(rightExpressionStatistics.getDistinctValuesCount() * (rightOverlappingRangeFraction + rightAlwaysGreaterRangeFraction))
                        .setNullsFraction(0)
                        .build()));

        double filterFactor =
                // all left range values which are below right range are selected
                leftAlwaysLessRangeFraction +
                        // for pairs in overlapping range, only half of pairs are selected
                        leftOverlappingRangeFraction * rightOverlappingRangeFraction * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT +
                        // all pairs where left value is in overlapping range and right value is above left range are selected
                        leftOverlappingRangeFraction * rightAlwaysGreaterRangeFraction;

        return estimate.setOutputRowCount(inputStatistics.getOutputRowCount() * nullsFilterFactor * filterFactor).build();
    }
}