All downloads are FREE. Search and download functionality uses the official Maven repository.

io.trino.testing.statistics.MetricComparator Maven / Gradle / Ivy

There is a newer version: 458
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.testing.statistics;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.testing.MaterializedRow;
import io.trino.testing.QueryRunner;

import java.util.List;
import java.util.OptionalDouble;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.testing.TransactionBuilder.transaction;
import static java.lang.String.format;
import static java.util.stream.Collectors.joining;

/**
 * Static helpers that, for each requested {@link Metric}, pair the value derived from the
 * statistics estimate of a query's plan with the value obtained by actually executing the query.
 *
 * <p>NOTE(review): the generic type arguments used below ({@code OptionalDouble} for metric
 * values, {@code String -> Symbol} for the column map) were restored from this file's imports;
 * confirm them against the declarations of {@code Metric}, {@code MetricComparison} and
 * {@code StatsContext}.
 */
final class MetricComparator
{
    private MetricComparator() {}

    /**
     * Builds one {@link MetricComparison} per entry of {@code metrics}, in the same order.
     *
     * @param query SQL text of the query under test; it is planned once (for estimates) and
     *         executed once wrapped in an aggregating {@code SELECT} (for actual values)
     * @param runner query runner used for both planning and execution
     * @param metrics metrics to evaluate
     * @return comparisons positionally aligned with {@code metrics}
     * @throws RuntimeException if the aggregation query computing the actual values fails
     */
    static List<MetricComparison> getMetricComparisons(String query, QueryRunner runner, List<Metric> metrics)
    {
        List<OptionalDouble> estimatedValues = getEstimatedValues(metrics, query, runner);
        List<OptionalDouble> actualValues = getActualValues(metrics, query, runner);

        ImmutableList.Builder<MetricComparison> metricComparisons = ImmutableList.builder();
        for (int i = 0; i < metrics.size(); ++i) {
            //noinspection unchecked
            metricComparisons.add(new MetricComparison(
                    metrics.get(i),
                    estimatedValues.get(i),
                    actualValues.get(i)));
        }
        return metricComparisons.build();
    }

    /**
     * Plans {@code query} inside a single-statement transaction opened on the runner's default
     * session and extracts each metric's estimated value from the resulting plan.
     */
    private static List<OptionalDouble> getEstimatedValues(List<Metric> metrics, String query, QueryRunner runner)
    {
        // The lambda parameter type is spelled out explicitly, as in the original code.
        return transaction(runner.getTransactionManager(), runner.getPlannerContext().getMetadata(), runner.getAccessControl())
                .singleStatement()
                .execute(runner.getDefaultSession(), (Session session) -> getEstimatedValuesInternal(metrics, query, runner, session));
    }

    /**
     * Creates the plan for {@code query} in {@code session} and evaluates every metric against
     * the statistics estimate recorded for the plan's root node. When the plan carries no
     * estimate for the root, {@link PlanNodeStatsEstimate#unknown()} is used instead.
     *
     * <p>TODO inline back this method
     */
    private static List<OptionalDouble> getEstimatedValuesInternal(List<Metric> metrics, String query, QueryRunner runner, Session session)
    {
        Plan queryPlan = runner.createPlan(session, query);
        // The root of the plan is cast to OutputNode; a different root type fails here with ClassCastException.
        OutputNode outputNode = (OutputNode) queryPlan.getRoot();
        PlanNodeStatsEstimate outputNodeStats = queryPlan.getStatsAndCosts().getStats().getOrDefault(outputNode.getId(), PlanNodeStatsEstimate.unknown());
        StatsContext statsContext = buildStatsContext(outputNode);
        return getEstimatedValues(metrics, outputNodeStats, statsContext);
    }

    /**
     * Maps each output column name of {@code outputNode} to the output symbol at the same
     * position and wraps the mapping in a {@link StatsContext}.
     */
    private static StatsContext buildStatsContext(OutputNode outputNode)
    {
        ImmutableMap.Builder<String, Symbol> columnSymbols = ImmutableMap.builder();
        for (int columnId = 0; columnId < outputNode.getColumnNames().size(); ++columnId) {
            columnSymbols.put(outputNode.getColumnNames().get(columnId), outputNode.getOutputSymbols().get(columnId));
        }
        return new StatsContext(columnSymbols.buildOrThrow());
    }

    /**
     * Executes {@code query} wrapped in a {@code SELECT} of every metric's aggregation
     * expression and converts each field of the single result row into that metric's value.
     *
     * @throws RuntimeException wrapping any failure, with the generated statistics query in the message
     */
    private static List<OptionalDouble> getActualValues(List<Metric> metrics, String query, QueryRunner runner)
    {
        String statsQuery = "SELECT "
                + metrics.stream().map(Metric::getComputingAggregationSql).collect(joining(","))
                + " FROM (" + query + ")";
        try {
            // Exactly one row is expected; getOnlyElement throws otherwise.
            MaterializedRow actualValuesRow = getOnlyElement(runner.execute(statsQuery).getMaterializedRows());

            ImmutableList.Builder<OptionalDouble> actualValues = ImmutableList.builder();
            for (int i = 0; i < metrics.size(); ++i) {
                // Field i of the row corresponds to the i-th aggregation in the SELECT list.
                actualValues.add(metrics.get(i).getValueFromAggregationQueryResult(actualValuesRow.getField(i)));
            }
            return actualValues.build();
        }
        catch (Exception e) {
            throw new RuntimeException(format("Failed to execute query to compute actual values: %s", statsQuery), e);
        }
    }

    /**
     * Evaluates every metric against the given statistics estimate, preserving the order of {@code metrics}.
     */
    private static List<OptionalDouble> getEstimatedValues(List<Metric> metrics, PlanNodeStatsEstimate outputNodeStatisticsEstimates, StatsContext statsContext)
    {
        return metrics.stream()
                .map(metric -> metric.getValueFromPlanNodeEstimate(outputNodeStatisticsEstimates, statsContext))
                .collect(toImmutableList());
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy