org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of x-pack-esql Show documentation
Show all versions of x-pack-esql Show documentation
The plugin that powers ESQL for Elasticsearch
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.planner;
import org.elasticsearch.compute.aggregation.Aggregator;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.operator.AggregationOperator;
import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import static java.util.Collections.emptyList;
public abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders {
private final AggregateMapper aggregateMapper = new AggregateMapper();
@Override
public final PhysicalOperation groupingPhysicalOperation(
AggregateExec aggregateExec,
PhysicalOperation source,
LocalExecutionPlannerContext context
) {
Layout.Builder layout = new Layout.Builder();
Operator.OperatorFactory operatorFactory = null;
AggregateExec.Mode mode = aggregateExec.getMode();
var aggregates = aggregateExec.aggregates();
var sourceLayout = source.layout;
AggregatorMode aggregatorMode;
if (mode == AggregateExec.Mode.FINAL) {
aggregatorMode = AggregatorMode.FINAL;
} else if (mode == AggregateExec.Mode.PARTIAL) {
if (aggregateExec.child() instanceof ExchangeSourceExec) {// the reducer step at data node (local) level
aggregatorMode = AggregatorMode.INTERMEDIATE;
} else {
aggregatorMode = AggregatorMode.INITIAL;
}
} else {
assert false : "Invalid aggregator mode [" + mode + "]";
aggregatorMode = AggregatorMode.SINGLE;
}
if (aggregateExec.groupings().isEmpty()) {
// not grouping
List aggregatorFactories = new ArrayList<>();
// append channels to the layout
if (mode == AggregateExec.Mode.FINAL) {
layout.append(aggregates);
} else {
layout.append(aggregateMapper.mapNonGrouping(aggregates));
}
// create the agg factories
aggregatesToFactory(
aggregates,
aggregatorMode,
sourceLayout,
false, // non-grouping
s -> aggregatorFactories.add(s.supplier.aggregatorFactory(s.mode))
);
if (aggregatorFactories.isEmpty() == false) {
operatorFactory = new AggregationOperator.AggregationOperatorFactory(aggregatorFactories, aggregatorMode);
}
} else {
// grouping
List aggregatorFactories = new ArrayList<>();
List groupSpecs = new ArrayList<>(aggregateExec.groupings().size());
for (Expression group : aggregateExec.groupings()) {
var groupAttribute = Expressions.attribute(group);
if (groupAttribute == null) {
throw new EsqlIllegalArgumentException("Unexpected non-named expression[{}] as grouping in [{}]", group, aggregateExec);
}
Layout.ChannelSet groupAttributeLayout = new Layout.ChannelSet(new HashSet<>(), groupAttribute.dataType());
groupAttributeLayout.nameIds().add(groupAttribute.id());
/*
* Check for aliasing in aggregates which occurs in two cases (due to combining project + stats):
* - before stats (keep x = a | stats by x) which requires the partial input to use a's channel
* - after stats (stats by a | keep x = a) which causes the output layout to refer to the follow-up alias
*/
for (NamedExpression agg : aggregates) {
if (agg instanceof Alias a) {
if (a.child() instanceof Attribute attr) {
if (groupAttribute.id().equals(attr.id())) {
groupAttributeLayout.nameIds().add(a.id());
// TODO: investigate whether a break could be used since it shouldn't be possible to have multiple
// attributes pointing to the same attribute
}
// partial mode only
// check if there's any alias used in grouping - no need for the final reduction since the intermediate data
// is in the output form
// if the group points to an alias declared in the aggregate, use the alias child as source
else if (mode == AggregateExec.Mode.PARTIAL) {
if (groupAttribute.semanticEquals(a.toAttribute())) {
groupAttribute = attr;
break;
}
}
}
}
}
layout.append(groupAttributeLayout);
Layout.ChannelAndType groupInput = source.layout.get(groupAttribute.id());
groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), groupAttribute));
}
if (mode == AggregateExec.Mode.FINAL) {
for (var agg : aggregates) {
if (Alias.unwrap(agg) instanceof AggregateFunction) {
layout.append(agg);
}
}
} else {
layout.append(aggregateMapper.mapGrouping(aggregates));
}
// create the agg factories
aggregatesToFactory(
aggregates,
aggregatorMode,
sourceLayout,
true, // grouping
s -> aggregatorFactories.add(s.supplier.groupingAggregatorFactory(s.mode))
);
if (groupSpecs.size() == 1 && groupSpecs.get(0).channel == null) {
operatorFactory = ordinalGroupingOperatorFactory(
source,
aggregateExec,
aggregatorFactories,
groupSpecs.get(0).attribute,
groupSpecs.get(0).elementType(),
context
);
} else {
operatorFactory = new HashAggregationOperatorFactory(
groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(),
aggregatorFactories,
context.pageSize(aggregateExec.estimatedRowSize())
);
}
}
if (operatorFactory != null) {
return source.with(operatorFactory, layout.build());
}
throw new EsqlIllegalArgumentException("no operator factory");
}
/***
* Creates a standard layout for intermediate aggregations, typically used across exchanges.
* Puts the group first, followed by each aggregation.
*
* It's similar to the code above (groupingPhysicalOperation) but ignores the factory creation.
*/
public static List intermediateAttributes(List extends NamedExpression> aggregates, List extends Expression> groupings) {
var aggregateMapper = new AggregateMapper();
List attrs = new ArrayList<>();
// no groups
if (groupings.isEmpty()) {
attrs = Expressions.asAttributes(aggregateMapper.mapNonGrouping(aggregates));
}
// groups
else {
for (Expression group : groupings) {
var groupAttribute = Expressions.attribute(group);
if (groupAttribute == null) {
throw new EsqlIllegalArgumentException("Unexpected non-named expression[{}] as grouping", group);
}
Set grpAttribIds = new HashSet<>();
grpAttribIds.add(groupAttribute.id());
/*
* Check for aliasing in aggregates which occurs in two cases (due to combining project + stats):
* - before stats (keep x = a | stats by x) which requires the partial input to use a's channel
* - after stats (stats by a | keep x = a) which causes the output layout to refer to the follow-up alias
*/
for (NamedExpression agg : aggregates) {
if (agg instanceof Alias a) {
if (a.child() instanceof Attribute attr) {
if (groupAttribute.id().equals(attr.id())) {
grpAttribIds.add(a.id());
// TODO: investigate whether a break could be used since it shouldn't be possible to have multiple
// attributes
// pointing to the same attribute
}
}
}
}
attrs.add(groupAttribute);
}
attrs.addAll(Expressions.asAttributes(aggregateMapper.mapGrouping(aggregates)));
}
return attrs;
}
private record AggFunctionSupplierContext(AggregatorFunctionSupplier supplier, AggregatorMode mode) {}
private void aggregatesToFactory(
List extends NamedExpression> aggregates,
AggregatorMode mode,
Layout layout,
boolean grouping,
Consumer consumer
) {
for (NamedExpression ne : aggregates) {
if (ne instanceof Alias alias) {
var child = alias.child();
if (child instanceof AggregateFunction aggregateFunction) {
List extends NamedExpression> sourceAttr;
if (mode == AggregatorMode.INITIAL) {
// TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1)
Expression field = aggregateFunction.field();
// Only count can now support literals - all the other aggs should be optimized away
if (field.foldable()) {
if (aggregateFunction instanceof Count) {
sourceAttr = emptyList();
} else {
throw new InvalidArgumentException(
"Does not support yet aggregations over constants - [{}]",
aggregateFunction.sourceText()
);
}
} else {
sourceAttr = aggregateFunction.inputExpressions().stream().map(e -> {
Attribute attr = Expressions.attribute(e);
if (attr == null) {
throw new EsqlIllegalArgumentException(
"Cannot work with target field [{}] for agg [{}]",
e.sourceText(),
aggregateFunction.sourceText()
);
}
return attr;
}).toList();
}
} else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
if (grouping) {
sourceAttr = aggregateMapper.mapGrouping(aggregateFunction);
} else {
sourceAttr = aggregateMapper.mapNonGrouping(aggregateFunction);
}
} else {
throw new EsqlIllegalArgumentException("illegal aggregation mode");
}
List inputChannels = sourceAttr.stream().map(attr -> layout.get(attr.id()).channel()).toList();
assert inputChannels.stream().allMatch(i -> i >= 0) : inputChannels;
if (aggregateFunction instanceof ToAggregator agg) {
consumer.accept(new AggFunctionSupplierContext(agg.supplier(inputChannels), mode));
} else {
throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
}
}
}
}
}
private record GroupSpec(Integer channel, Attribute attribute) {
BlockHash.GroupSpec toHashGroupSpec() {
if (channel == null) {
throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
}
return new BlockHash.GroupSpec(channel, elementType());
}
ElementType elementType() {
return PlannerUtils.toElementType(attribute.dataType());
}
}
/**
* Build a grouping operator that operates on ordinals if possible.
*/
public abstract Operator.OperatorFactory ordinalGroupingOperatorFactory(
PhysicalOperation source,
AggregateExec aggregateExec,
List aggregatorFactories,
Attribute attrSource,
ElementType groupType,
LocalExecutionPlannerContext context
);
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy