All downloads are free. The search and download features use the official Maven repository.

org.elasticsearch.xpack.esql.optimizer.rules.CombineProjections Maven / Gradle / Ivy

There is a newer version: 8.16.1
Show newest version
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.optimizer.rules;

import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules;
import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Project;

import java.util.ArrayList;
import java.util.List;

/**
 * Optimizer rule that merges a projection-like node with the node directly next to it.
 * <p>
 * Three plan shapes are handled by {@link #rule(UnaryPlan)}:
 * <ul>
 *     <li>{@link Project} over {@link Project}: the lower project is removed after its alias
 *     definitions have been substituted into the upper one;</li>
 *     <li>{@link Project} over {@link Aggregate}: the project is folded into the aggregate, unless the
 *     project refers to the same underlying expression more than once;</li>
 *     <li>{@link Aggregate} over {@link Project}: the project is removed and the aggregate's groupings and
 *     aggregates are rewritten against the project's child.</li>
 * </ul>
 * The rule is registered with {@link OptimizerRules.TransformDirection#UP}.
 */
public final class CombineProjections extends OptimizerRules.OptimizerRule<UnaryPlan> {

    public CombineProjections() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Applies the combination to a single unary node.
     *
     * @param plan the node being visited
     * @return the rewritten node, or {@code plan} itself when none of the handled shapes match
     * @throws EsqlIllegalArgumentException if an {@link Aggregate} sitting on top of a {@link Project}
     *                                      has a grouping that is not an {@link Attribute}
     */
    @Override
    @SuppressWarnings("unchecked")
    protected LogicalPlan rule(UnaryPlan plan) {
        LogicalPlan child = plan.child();

        if (plan instanceof Project project) {
            if (child instanceof Project p) {
                // eliminate lower project but first replace the aliases in the upper one
                project = p.withProjections(combineProjections(project.projections(), p.projections()));
                child = project.child();
                plan = project;
                // don't return the plan since the grandchild (now child) might be an aggregate that could not be folded on the way up
                // e.g. stats c = count(x) | project c, c as x | project x
                // try to apply the rule again opportunistically as another node might be pushed in (a limit might be pushed in)
            }
            // check if the projection eliminates certain aggregates
            // but be mindful of aliases to existing aggregates that we don't want to duplicate to avoid redundant work
            if (child instanceof Aggregate a) {
                var aggs = a.aggregates();
                var newAggs = projectAggregations(project.projections(), aggs);
                // a non-null result means the project can be fully removed
                if (newAggs != null) {
                    var newGroups = replacePrunedAliasesUsedInGroupBy(a.groupings(), aggs, newAggs);
                    plan = new Aggregate(a.source(), a.child(), a.aggregateType(), newGroups, newAggs);
                }
            }
            return plan;
        }

        // Agg with underlying Project (group by on sub-queries)
        if (plan instanceof Aggregate a) {
            if (child instanceof Project p) {
                var groupings = a.groupings();
                List<Attribute> groupingAttrs = new ArrayList<>(groupings.size());
                for (Expression grouping : groupings) {
                    if (grouping instanceof Attribute attribute) {
                        groupingAttrs.add(attribute);
                    } else {
                        // After applying ReplaceStatsNestedExpressionWithEval, groupings can only contain attributes.
                        throw new EsqlIllegalArgumentException("Expected an Attribute, got {}", grouping);
                    }
                }
                plan = new Aggregate(
                    a.source(),
                    p.child(),
                    a.aggregateType(),
                    combineUpperGroupingsAndLowerProjections(groupingAttrs, p.projections()),
                    combineProjections(a.aggregates(), p.projections())
                );
            }
        }

        return plan;
    }

    /**
     * Variant of {@link #combineProjections(List, List)} specialized for a project sitting on top of an
     * aggregation. It pays attention to:
     * <ul>
     *     <li>aggregations that are projected away - these are removed;</li>
     *     <li>aliases in the project that point to aggregates - these are kept in place (to avoid
     *     duplicating the aggs).</li>
     * </ul>
     *
     * @param upperProjection   the projections of the project node
     * @param lowerAggregations the aggregates of the aggregation below it
     * @return the combined aggregate list, or {@code null} when the unwrapped expression of an upper
     *         projection was already contributed by an earlier one, in which case the project is kept
     */
    private static List<NamedExpression> projectAggregations(
        List<? extends NamedExpression> upperProjection,
        List<? extends NamedExpression> lowerAggregations
    ) {
        AttributeSet seen = new AttributeSet();
        for (NamedExpression upper : upperProjection) {
            Expression unwrapped = Alias.unwrap(upper);
            // projection contains an inner alias (point to an existing fields inside the projection)
            if (seen.contains(unwrapped)) {
                return null;
            }
            seen.add(Expressions.attribute(unwrapped));
        }

        return combineProjections(upperProjection, lowerAggregations);
    }

    /**
     * Rewrites the upper list of named expressions on top of the definitions found in the lower list.
     * <p>
     * Normally only the upper projections should survive, but since the lower list might define aliases
     * that are reused by the upper one, these need to be replaced: an alias defined in the lower list and
     * referred to in the upper would otherwise become invalid.
     *
     * @param upper the expressions that survive the combination
     * @param lower the expressions whose alias definitions are substituted into {@code upper}
     * @return a new list with one entry per element of {@code upper}, in the same order
     */
    private static List<NamedExpression> combineProjections(List<? extends NamedExpression> upper, List<? extends NamedExpression> lower) {

        // collect named expressions declaration in the lower list
        AttributeMap<NamedExpression> namedExpressions = new AttributeMap<>();
        // while also collecting the alias map for resolving the source (f1 = 1, f2 = f1, etc..)
        AttributeMap<Expression> aliases = new AttributeMap<>();
        for (NamedExpression ne : lower) {
            // record the alias
            aliases.put(ne.toAttribute(), Alias.unwrap(ne));

            // record named expression as is
            if (ne instanceof Alias as) {
                Expression child = as.child();
                namedExpressions.put(ne.toAttribute(), as.replaceChild(aliases.resolve(child, child)));
            }
        }
        List<NamedExpression> replaced = new ArrayList<>(upper.size());

        // replace any matching attribute with a lower alias (if there's a match)
        // but clean-up non-top aliases at the end
        for (NamedExpression ne : upper) {
            NamedExpression replacedExp = (NamedExpression) ne.transformUp(Attribute.class, a -> namedExpressions.resolve(a, a));
            replaced.add((NamedExpression) trimNonTopLevelAliases(replacedExp));
        }
        return replaced;
    }

    /**
     * Rewrites the groupings of an aggregation so that they refer to what the project below it exposes.
     *
     * @param upperGroupings   the groupings of the aggregation, already narrowed to attributes
     * @param lowerProjections the projections of the project being removed
     * @return the resolved groupings, collected through an {@link AttributeSet}
     */
    private static List<Expression> combineUpperGroupingsAndLowerProjections(
        List<? extends Attribute> upperGroupings,
        List<? extends NamedExpression> lowerProjections
    ) {
        // Collect the alias map for resolving the source (f1 = 1, f2 = f1, etc..)
        AttributeMap<Attribute> aliases = new AttributeMap<>();
        for (NamedExpression ne : lowerProjections) {
            // Projections are just aliases for attributes, so casting is safe.
            aliases.put(ne.toAttribute(), (Attribute) Alias.unwrap(ne));
        }

        // Replace any matching attribute directly with the aliased attribute from the projection.
        AttributeSet replaced = new AttributeSet();
        for (Attribute attr : upperGroupings) {
            // All substitutions happen before; groupings must be attributes at this point.
            replaced.add(aliases.resolve(attr, attr));
        }
        return new ArrayList<>(replaced);
    }

    /**
     * Replace grouping alias previously contained in the aggregations that might have been projected away.
     *
     * @param groupings the groupings of the original aggregation
     * @param oldAggs   the aggregates before the projection was folded in
     * @param newAggs   the aggregates after the projection was folded in
     * @return {@code groupings} itself when no alias was removed; otherwise a new list in which removed
     *         aliases are replaced by their child and entries equal as attributes are added only once
     */
    private List<Expression> replacePrunedAliasesUsedInGroupBy(
        List<Expression> groupings,
        List<? extends NamedExpression> oldAggs,
        List<? extends NamedExpression> newAggs
    ) {
        AttributeMap<Expression> removedAliases = new AttributeMap<>();
        AttributeSet currentAliases = new AttributeSet(Expressions.asAttributes(newAggs));

        // record only removed aliases
        for (NamedExpression ne : oldAggs) {
            if (ne instanceof Alias alias) {
                var attr = ne.toAttribute();
                if (currentAliases.contains(attr) == false) {
                    removedAliases.put(attr, alias.child());
                }
            }
        }

        if (removedAliases.isEmpty()) {
            return groupings;
        }

        var newGroupings = new ArrayList<Expression>(groupings.size());
        for (Expression group : groupings) {
            var transformed = group.transformUp(Attribute.class, a -> removedAliases.resolve(a, a));
            // skip groupings that collapse onto one that was already added
            if (Expressions.anyMatch(newGroupings, g -> Expressions.equalsAsAttribute(g, transformed)) == false) {
                newGroupings.add(transformed);
            }
        }

        return newGroupings;
    }

    /**
     * Strips every {@link Alias} nested inside the given expression while keeping a top-level one.
     *
     * @param e the expression to clean up
     * @return if {@code e} is an alias, that alias with all aliases removed from its child; otherwise
     *         {@code e} with all aliases removed
     */
    public static Expression trimNonTopLevelAliases(Expression e) {
        return e instanceof Alias a ? a.replaceChild(trimAliases(a.child())) : trimAliases(e);
    }

    /**
     * Replaces every {@link Alias} found in the expression tree with its child.
     */
    private static Expression trimAliases(Expression e) {
        return e.transformDown(Alias.class, Alias::child);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy