All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.prestosql.sql.planner.GroupingOperationRewriter Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.sql.planner;

import io.prestosql.sql.analyzer.FieldId;
import io.prestosql.sql.analyzer.RelationId;
import io.prestosql.sql.analyzer.ResolvedField;
import io.prestosql.sql.tree.ArithmeticBinaryExpression;
import io.prestosql.sql.tree.ArrayConstructor;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.GenericLiteral;
import io.prestosql.sql.tree.GroupingOperation;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.SubscriptExpression;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.prestosql.sql.tree.ArithmeticBinaryExpression.Operator.ADD;
import static java.util.Objects.requireNonNull;

public final class GroupingOperationRewriter
{
    private GroupingOperationRewriter() {}

    public static Expression rewriteGroupingOperation(GroupingOperation expression, List> groupingSets, Map, ResolvedField> columnReferenceFields, Optional groupIdSymbol)
    {
        requireNonNull(groupIdSymbol, "groupIdSymbol is null");

        // No GroupIdNode and a GROUPING() operation imply a single grouping, which
        // means that any columns specified as arguments to GROUPING() will be included
        // in the group and none of them will be aggregated over. Hence, re-write the
        // GroupingOperation to a constant literal of 0.
        // See SQL:2011:4.16.2 and SQL:2011:6.9.10.
        if (groupingSets.size() == 1) {
            return new LongLiteral("0");
        }
        else {
            checkState(groupIdSymbol.isPresent(), "groupId symbol is missing");

            RelationId relationId = columnReferenceFields.get(NodeRef.of(expression.getGroupingColumns().get(0))).getFieldId().getRelationId();

            List columns = expression.getGroupingColumns().stream()
                    .map(NodeRef::of)
                    .peek(groupingColumn -> checkState(columnReferenceFields.containsKey(groupingColumn), "the grouping column is not in the columnReferencesField map"))
                    .map(columnReferenceFields::get)
                    .map(ResolvedField::getFieldId)
                    .map(fieldId -> translateFieldToInteger(fieldId, relationId))
                    .collect(toImmutableList());

            List groupingResults = groupingSets.stream()
                    .map(groupingSet -> String.valueOf(calculateGrouping(groupingSet, columns)))
                    .map(LongLiteral::new)
                    .collect(toImmutableList());

            // It is necessary to add a 1 to the groupId because the underlying array is indexed starting at 1
            return new SubscriptExpression(
                    new ArrayConstructor(groupingResults),
                    new ArithmeticBinaryExpression(ADD, groupIdSymbol.get().toSymbolReference(), new GenericLiteral("BIGINT", "1")));
        }
    }

    private static int translateFieldToInteger(FieldId fieldId, RelationId requiredOriginRelationId)
    {
        // TODO: this section should be rewritten when support is added for GROUP BY columns to reference an outer scope
        checkState(fieldId.getRelationId().equals(requiredOriginRelationId), "grouping arguments must all come from the same relation");
        return fieldId.getFieldIndex();
    }

    /**
     * The grouping function is used in conjunction with GROUPING SETS, ROLLUP and CUBE to
     * indicate which columns are present in that grouping.
     *
     * 

The grouping function must be invoked with arguments that exactly match the columns * referenced in the corresponding GROUPING SET, ROLLUP or CUBE clause at the associated * query level. Those column arguments are not evaluated and instead the function is * re-written with the arguments below. * *

To compute the resulting bit set for a particular row, bits are assigned to the * argument columns with the rightmost column being the most significant bit. For a * given grouping, a bit is set to 0 if the corresponding column is included in the * grouping and 1 otherwise. For an example, see the SQL documentation for the * function. * * @param columns The column arguments with which the function was invoked * converted to ordinals with respect to the base table column ordering. * @param groupingSet A collection containing the ordinals of the * columns present in the grouping. * @return A bit set converted to decimal indicating which columns are present in * the grouping. If a column is NOT present in the grouping its corresponding * bit is set to 1 and to 0 if the column is present in the grouping. */ static long calculateGrouping(Set groupingSet, List columns) { long grouping = (1L << columns.size()) - 1; for (int index = 0; index < columns.size(); index++) { int column = columns.get(index); if (groupingSet.contains(column)) { // Leftmost argument to grouping() (i.e. when index = 0) corresponds to // the most significant bit in the result. That is why we shift 1L starting // from the columns.size() - 1 bit index. grouping = grouping & ~(1L << (columns.size() - 1 - index)); } } return grouping; } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy