All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.sql.planner.plan.StatisticAggregations Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.plan;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.AggregationFunctionMetadata;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.plan.AggregationNode.Aggregation;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
 * A set of aggregations, keyed by the symbol each aggregation produces, together with the
 * symbols the aggregations are grouped by.
 * <p>
 * Both collections are defensively copied into Guava immutable collections on construction,
 * so an instance cannot be modified after it is created. The class is (de)serialized through
 * Jackson using the {@code aggregations} and {@code groupingSymbols} properties.
 */
public class StatisticAggregations
{
    // Output symbol -> aggregation whose result is written to that symbol
    private final Map<Symbol, Aggregation> aggregations;
    private final List<Symbol> groupingSymbols;

    /**
     * Creates a new instance from the given aggregations and grouping symbols.
     *
     * @param aggregations aggregations keyed by their output symbol; must not be null
     * @param groupingSymbols symbols to group the aggregations by; must not be null
     * @throws NullPointerException if either argument is null, or if it contains null
     *         keys, values or elements (rejected by the Guava immutable copies)
     */
    @JsonCreator
    public StatisticAggregations(
            @JsonProperty("aggregations") Map<Symbol, Aggregation> aggregations,
            @JsonProperty("groupingSymbols") List<Symbol> groupingSymbols)
    {
        this.aggregations = ImmutableMap.copyOf(requireNonNull(aggregations, "aggregations is null"));
        this.groupingSymbols = ImmutableList.copyOf(requireNonNull(groupingSymbols, "groupingSymbols is null"));
    }

    /**
     * Returns the aggregations keyed by their output symbol.
     *
     * @return an immutable map of output symbol to aggregation
     */
    @JsonProperty
    public Map<Symbol, Aggregation> getAggregations()
    {
        return aggregations;
    }

    /**
     * Returns the symbols the aggregations are grouped by.
     *
     * @return an immutable list of grouping symbols
     */
    @JsonProperty
    public List<Symbol> getGroupingSymbols()
    {
        return groupingSymbols;
    }

    /**
     * Splits every aggregation into a partial step and a final step.
     * <p>
     * For each aggregation:
     * <ul>
     *   <li>The partial step keeps the original arguments, distinct flag, filter, ordering scheme
     *       and mask. It writes to a newly allocated symbol typed with the function's
     *       intermediate type.</li>
     *   <li>The final step applies the same resolved function to that partial symbol. It has no
     *       distinct flag, filter, ordering or mask, and writes to the original output
     *       symbol.</li>
     * </ul>
     * Both steps keep the original grouping symbols.
     *
     * @param symbolAllocator allocator used to create the partial-result symbols
     * @param session session used to look up aggregation function metadata
     * @param plannerContext supplies the metadata and type manager used to resolve intermediate types
     * @return the partial and final aggregations, plus a mapping from each original output symbol
     *         to its partial-result symbol and from each grouping symbol to itself
     * @throws IllegalArgumentException if a grouping symbol is also an aggregation output symbol,
     *         because the mapping would then contain a duplicate key
     */
    public Parts createPartialAggregations(SymbolAllocator symbolAllocator, Session session, PlannerContext plannerContext)
    {
        ImmutableMap.Builder<Symbol, Aggregation> partialAggregation = ImmutableMap.builder();
        ImmutableMap.Builder<Symbol, Aggregation> finalAggregation = ImmutableMap.builder();
        ImmutableMap.Builder<Symbol, Symbol> mappings = ImmutableMap.builder();
        for (Map.Entry<Symbol, Aggregation> entry : aggregations.entrySet()) {
            Aggregation originalAggregation = entry.getValue();
            ResolvedFunction resolvedFunction = originalAggregation.getResolvedFunction();
            AggregationFunctionMetadata functionMetadata = plannerContext.getMetadata().getAggregationFunctionMetadata(session, resolvedFunction);
            List<Type> intermediateTypes = functionMetadata.getIntermediateTypes().stream()
                    .map(plannerContext.getTypeManager()::getType)
                    .collect(toImmutableList());
            // A single intermediate value is used directly; several are packed into an anonymous row
            Type intermediateType = intermediateTypes.size() == 1 ? intermediateTypes.get(0) : RowType.anonymous(intermediateTypes);
            Symbol partialSymbol = symbolAllocator.newSymbol(resolvedFunction.signature().getName().getFunctionName(), intermediateType);
            mappings.put(entry.getKey(), partialSymbol);
            partialAggregation.put(partialSymbol, new Aggregation(
                    resolvedFunction,
                    originalAggregation.getArguments(),
                    originalAggregation.isDistinct(),
                    originalAggregation.getFilter(),
                    originalAggregation.getOrderingScheme(),
                    originalAggregation.getMask()));
            // The final step only combines partial results, so distinct/filter/ordering/mask were
            // already applied by the partial step and are not repeated here
            finalAggregation.put(entry.getKey(),
                    new Aggregation(
                            resolvedFunction,
                            ImmutableList.of(partialSymbol.toSymbolReference()),
                            false,
                            Optional.empty(),
                            Optional.empty(),
                            Optional.empty()));
        }
        // Grouping symbols pass through both steps unchanged
        groupingSymbols.forEach(symbol -> mappings.put(symbol, symbol));
        return new Parts(
                new StatisticAggregations(partialAggregation.buildOrThrow(), groupingSymbols),
                new StatisticAggregations(finalAggregation.buildOrThrow(), groupingSymbols),
                mappings.buildOrThrow());
    }

    /**
     * The result of {@link #createPartialAggregations}: the partial and final aggregation
     * steps, plus the symbol mapping that connects them.
     */
    public static class Parts
    {
        private final StatisticAggregations partialAggregation;
        private final StatisticAggregations finalAggregation;
        // Original output/grouping symbol -> symbol that carries its value after the partial step
        private final Map<Symbol, Symbol> mappings;

        /**
         * Creates a new instance from the given steps and symbol mapping.
         *
         * @param partialAggregation the partial aggregation step; must not be null
         * @param finalAggregation the final aggregation step; must not be null
         * @param mappings symbol mapping, defensively copied; must not be null
         * @throws NullPointerException if any argument is null
         */
        public Parts(StatisticAggregations partialAggregation, StatisticAggregations finalAggregation, Map<Symbol, Symbol> mappings)
        {
            this.partialAggregation = requireNonNull(partialAggregation, "partialAggregation is null");
            this.finalAggregation = requireNonNull(finalAggregation, "finalAggregation is null");
            this.mappings = ImmutableMap.copyOf(requireNonNull(mappings, "mappings is null"));
        }

        /**
         * Returns the partial aggregation step.
         *
         * @return the partial aggregation step
         */
        public StatisticAggregations getPartialAggregation()
        {
            return partialAggregation;
        }

        /**
         * Returns the final aggregation step.
         *
         * @return the final aggregation step
         */
        public StatisticAggregations getFinalAggregation()
        {
            return finalAggregation;
        }

        /**
         * Returns the symbol mapping between the original and partial outputs.
         *
         * @return an immutable map from each original symbol to its post-partial-step symbol
         */
        public Map<Symbol, Symbol> getMappings()
        {
            return mappings;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy