org.elasticsearch.compute.operator.OrdinalsGroupingOperator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of x-pack-esql-compute Show documentation
Show all versions of x-pack-esql-compute Show documentation
Elasticsearch subproject :x-pack:plugin:esql:compute
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.operator;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.PriorityQueue;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.compute.Describable;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.GroupingAggregator.Factory;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash.GroupSpec;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.mapper.BlockLoader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.IntFunction;
import java.util.function.Supplier;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;
/**
* Unlike {@link HashAggregationOperator}, this hash operator also extracts values or ordinals of the input documents.
*/
public class OrdinalsGroupingOperator implements Operator {
/**
 * Factory that records the settings of an {@link OrdinalsGroupingOperator} and builds a new
 * operator instance for each {@link DriverContext} it is asked for.
 */
public record OrdinalsGroupingOperatorFactory(
    IntFunction blockLoaders,
    List shardContexts,
    ElementType groupingElementType,
    int docChannel,
    String groupingField,
    List aggregators,
    int maxPageSize
) implements OperatorFactory {

    /**
     * Creates an operator from the recorded settings.
     *
     * @param driverContext the context handed to the new operator
     * @return a new {@link OrdinalsGroupingOperator}
     */
    @Override
    public Operator get(DriverContext driverContext) {
        return new OrdinalsGroupingOperator(
            this.blockLoaders,
            this.shardContexts,
            this.groupingElementType,
            this.docChannel,
            this.groupingField,
            this.aggregators,
            this.maxPageSize,
            driverContext
        );
    }

    /** Returns a description listing the descriptions of all configured aggregators, comma separated. */
    @Override
    public String describe() {
        final var describedAggs = aggregators.stream().map(Describable::describe).collect(joining(", "));
        return "OrdinalsGroupingOperator(aggs = " + describedAggs + ")";
    }
}
// Block loader lookup; addInput applies it to the shard index read from the page's doc vector.
private final IntFunction blockLoaders;
// Shard contexts, indexed by shard index when resolving a segment's leaf reader.
private final List shardContexts;
// Channel (block index) of the incoming page that holds the doc block.
private final int docChannel;
// Name of the field to group by; handed to the values extractor on the non-ordinal path.
private final String groupingField;
// Factories used to create a fresh list of grouping aggregators (see createGroupingAggregators).
private final List aggregatorFactories;
// Element type of the grouping field; used by the non-ordinal (values) path.
private final ElementType groupingElementType;
// Per-segment ordinal aggregators, keyed by SegmentID and created on first use in addInput.
private final Map ordinalAggregators;
private final DriverContext driverContext;
// Set by finish(); once true the operator stops accepting input and may produce output.
private boolean finished = false;
// used to extract and aggregate values
private final int maxPageSize;
// Created lazily in addInput for pages that cannot use the ordinal path; reset to null once its output is taken.
private ValuesAggregator valuesAggregator;
/**
 * Creates the operator. No aggregators are built here; the per-segment aggregator map starts out empty.
 *
 * @param blockLoaders        lookup of a block loader by shard index
 * @param shardContexts       shard contexts, indexed by shard index
 * @param groupingElementType element type of the grouping field
 * @param docChannel          channel of the input page that holds the doc block
 * @param groupingField       name of the grouping field
 * @param aggregatorFactories factories for the grouping aggregators; must not be null
 * @param maxPageSize         page size limit passed on to the value-based aggregation path
 * @param driverContext       context supplying the block factory and big arrays
 * @throws NullPointerException if {@code aggregatorFactories} is null
 */
public OrdinalsGroupingOperator(
    IntFunction blockLoaders,
    List shardContexts,
    ElementType groupingElementType,
    int docChannel,
    String groupingField,
    List aggregatorFactories,
    int maxPageSize,
    DriverContext driverContext
) {
    this.aggregatorFactories = Objects.requireNonNull(aggregatorFactories);
    this.driverContext = driverContext;
    this.maxPageSize = maxPageSize;
    this.groupingField = groupingField;
    this.groupingElementType = groupingElementType;
    this.docChannel = docChannel;
    this.shardContexts = shardContexts;
    this.blockLoaders = blockLoaders;
    this.ordinalAggregators = new HashMap<>();
}
/** Reports whether the operator still accepts pages, which is the case until {@link #finish()} has been called. */
@Override
public boolean needsInput() {
    return false == finished;
}
/**
 * Consumes a page of documents and routes it to either the ordinal-based or the value-based aggregation path.
 * <p>
 * If the page's doc vector is single-segment and non-decreasing, and the shard's block loader supports
 * ordinals, the page goes to the {@link OrdinalSegmentAggregator} of that segment (created on first use).
 * Otherwise it goes to the lazily created {@link ValuesAggregator}.
 * <p>
 * The page is marked as handed off immediately before the chosen aggregator is invoked. If anything
 * fails before that point, including resolving the doc vector or the block loader, the page's blocks
 * are released here so they do not leak.
 *
 * @param page the input page; must not be null
 * @throws IllegalArgumentException if the operator has already been finished
 * @throws NullPointerException     if {@code page} is null
 * @throws UncheckedIOException     if creating the per-segment aggregator fails with an {@link IOException}
 */
@Override
public void addInput(Page page) {
    checkState(needsInput(), "Operator is already finishing");
    requireNonNull(page, "page is null");
    boolean pagePassed = false;
    try {
        // Resolved inside the try so that a failure while inspecting the page still releases its blocks.
        final DocVector docVector = page.<DocBlock>getBlock(docChannel).asVector();
        // The shard is taken from the first position only.
        final int shardIndex = docVector.shards().getInt(0);
        final var blockLoader = blockLoaders.apply(shardIndex);
        if (docVector.singleSegmentNonDecreasing() && blockLoader.supportsOrdinals()) {
            final IntVector segmentIndexVector = docVector.segments();
            assert segmentIndexVector.isConstant();
            final OrdinalSegmentAggregator ordinalAggregator = this.ordinalAggregators.computeIfAbsent(
                new SegmentID(shardIndex, segmentIndexVector.getInt(0)),
                k -> {
                    try {
                        return new OrdinalSegmentAggregator(
                            driverContext.blockFactory(),
                            this::createGroupingAggregators,
                            () -> blockLoader.ordinals(shardContexts.get(k.shardIndex).reader().leaves().get(k.segmentIndex)),
                            driverContext.bigArrays()
                        );
                    } catch (IOException e) {
                        // computeIfAbsent's mapping function cannot throw checked exceptions
                        throw new UncheckedIOException(e);
                    }
                }
            );
            // From here on the segment aggregator is responsible for releasing the page.
            pagePassed = true;
            ordinalAggregator.addInput(docVector.docs(), page);
        } else {
            if (valuesAggregator == null) {
                int channelIndex = page.getBlockCount(); // extractor will append a new block at the end
                valuesAggregator = new ValuesAggregator(
                    blockLoaders,
                    shardContexts,
                    groupingElementType,
                    docChannel,
                    groupingField,
                    channelIndex,
                    aggregatorFactories,
                    maxPageSize,
                    driverContext
                );
            }
            // From here on the values aggregator receives the page.
            pagePassed = true;
            valuesAggregator.addInput(page);
        }
    } finally {
        if (pagePassed == false) {
            Releasables.closeExpectNoException(page::releaseBlocks);
        }
    }
}
/**
 * Builds one grouping aggregator per configured factory, each bound to this operator's driver context.
 * If a factory fails part-way through, the aggregators created so far are closed before the failure propagates.
 *
 * @return the newly created aggregators, in factory order
 */
private List createGroupingAggregators() {
    final List created = new ArrayList<>(aggregatorFactories.size());
    boolean completed = false;
    try {
        for (GroupingAggregator.Factory factory : aggregatorFactories) {
            created.add(factory.apply(driverContext));
        }
        completed = true;
    } finally {
        if (completed == false) {
            Releasables.close(created);
        }
    }
    return created;
}
/**
 * Produces the next output page once the operator has been finished.
 * <p>
 * While a values aggregator is pending, its output is returned and the aggregator is closed and cleared.
 * Otherwise, if any per-segment ordinal aggregators exist, their merged result is returned and they are
 * closed and removed. Returns {@code null} before {@link #finish()} or when nothing is left to emit.
 *
 * @throws UncheckedIOException if merging the ordinal results fails with an {@link IOException}
 */
@Override
public Page getOutput() {
    if (finished == false) {
        return null;
    }
    final ValuesAggregator pendingValues = this.valuesAggregator;
    if (pendingValues != null) {
        try {
            return pendingValues.getOutput();
        } finally {
            // drop the reference first so the aggregator is never closed twice
            this.valuesAggregator = null;
            Releasables.close(pendingValues);
        }
    }
    if (ordinalAggregators.isEmpty()) {
        return null;
    }
    try {
        return mergeOrdinalsSegmentResults();
    } catch (IOException e) {
        throw new UncheckedIOException(e);
    } finally {
        Releasables.close(() -> Releasables.close(ordinalAggregators.values()), ordinalAggregators::clear);
    }
}
/** Marks the operator as finished and, if a values aggregator exists, finishes it as well. */
@Override
public void finish() {
    finished = true;
    final ValuesAggregator values = valuesAggregator;
    if (values == null) {
        return;
    }
    values.finish();
}
/**
 * Merges the results of all per-segment ordinal aggregators into a single page.
 * <p>
 * The first block of the page holds the group keys (terms), preceded by a null key when any segment
 * saw documents without a value. The remaining blocks are produced by evaluating freshly created
 * grouping aggregators that were fed the intermediate rows of every segment.
 *
 * @return a page with the key block followed by the aggregator result blocks
 * @throws IOException if reading terms from doc values fails
 */
private Page mergeOrdinalsSegmentResults() throws IOException {
// TODO: Should we also combine from the results from ValuesAggregator
// Queue of per-segment iterators compared by their current term, so the same term coming from
// different segments is visited back to back (top() is expected to be the smallest term).
final PriorityQueue pq = new PriorityQueue<>(ordinalAggregators.size()) {
@Override
protected boolean lessThan(AggregatedResultIterator a, AggregatedResultIterator b) {
return a.currentTerm.compareTo(b.currentTerm) < 0;
}
};
// Fresh aggregators that accumulate the merged state across all segments.
final List aggregators = createGroupingAggregators();
try {
boolean seenNulls = false;
// Group 0 of a segment stands for "no value"; fold each such group into position 0 of the merged result.
for (OrdinalSegmentAggregator agg : ordinalAggregators.values()) {
if (agg.seenNulls()) {
seenNulls = true;
for (int i = 0; i < aggregators.size(); i++) {
aggregators.get(i).addIntermediateRow(0, agg.aggregators.get(i), 0);
}
}
}
// Seed the queue with every segment that has at least one visited non-null ordinal.
for (OrdinalSegmentAggregator agg : ordinalAggregators.values()) {
final AggregatedResultIterator it = agg.getResultIterator();
if (it.next()) {
pq.add(it);
}
}
// Position 0 is taken by the null group when present; otherwise start at -1 so the first term lands on 0.
final int startPosition = seenNulls ? 0 : -1;
int position = startPosition;
final BytesRefBuilder lastTerm = new BytesRefBuilder();
final Block[] blocks;
final int[] aggBlockCounts;
try (var keysBuilder = driverContext.blockFactory().newBytesRefBlockBuilder(1)) {
if (seenNulls) {
keysBuilder.appendNull();
}
while (pq.size() > 0) {
final AggregatedResultIterator top = pq.top();
// A new output position is opened for the very first term and whenever the term changes.
if (position == startPosition || lastTerm.get().equals(top.currentTerm) == false) {
position++;
lastTerm.copyBytes(top.currentTerm);
keysBuilder.appendBytesRef(top.currentTerm);
}
// Combine this segment's group (addressed by its shifted ordinal) into the merged position.
for (int i = 0; i < top.aggregators.size(); i++) {
aggregators.get(i).addIntermediateRow(position, top.aggregators.get(i), top.currentPosition());
}
if (top.next()) {
pq.updateTop();
} else {
pq.pop();
}
}
aggBlockCounts = aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray();
// One slot for the key block plus the blocks each aggregator reports it will produce.
blocks = new Block[1 + Arrays.stream(aggBlockCounts).sum()];
blocks[0] = keysBuilder.build();
}
boolean success = false;
try {
// Select every group position, 0..positionCount-1, for evaluation.
try (IntVector selected = IntVector.range(0, blocks[0].getPositionCount(), driverContext.blockFactory())) {
int offset = 1;
for (int i = 0; i < aggregators.size(); i++) {
aggregators.get(i).evaluate(blocks, offset, selected, driverContext);
offset += aggBlockCounts[i];
}
}
success = true;
return new Page(blocks);
} finally {
// On failure the blocks built so far are released here instead of being handed to a page.
if (success == false) {
Releasables.closeExpectNoException(blocks);
}
}
} finally {
// The merging aggregators are only needed while building the page.
Releasables.close(() -> Releasables.close(aggregators));
}
}
/**
 * Reports whether the operator is done: it has been finished, the values aggregator has been consumed,
 * and no per-segment ordinal aggregators remain.
 */
@Override
public boolean isFinished() {
    if (finished == false) {
        return false;
    }
    return valuesAggregator == null && ordinalAggregators.isEmpty();
}
/** Closes all remaining per-segment ordinal aggregators and the values aggregator, if any. */
@Override
public void close() {
    final Releasable segmentAggregators = () -> Releasables.close(ordinalAggregators.values());
    Releasables.close(segmentAggregators, valuesAggregator);
}
/**
 * Verifies a precondition.
 *
 * @param condition the condition that must hold
 * @param msg       the exception message used when it does not
 * @throws IllegalArgumentException if {@code condition} is false
 */
private static void checkState(boolean condition, String msg) {
    if (condition) {
        return;
    }
    throw new IllegalArgumentException(msg);
}
/** Renders the simple class name followed by the configured aggregator factories. */
@Override
public String toString() {
    final String simpleName = getClass().getSimpleName();
    return simpleName + "[aggregators=" + aggregatorFactories + "]";
}
record SegmentID(int shardIndex, int segmentIndex) {
}
/**
 * Aggregates the documents of a single segment using the segment's ordinals as group ids.
 * Ordinals are shifted by one so that group id 0 stands for documents without a value;
 * every group id that was fed to the aggregators is remembered in {@code visitedOrds}.
 */
static final class OrdinalSegmentAggregator implements Releasable, SeenGroupIds {
private final BlockFactory blockFactory;
// Aggregators whose group ids are the shifted ordinals of this segment.
private final List aggregators;
// Supplies a new doc values instance for this segment each time it is called.
private final CheckedSupplier docValuesSupplier;
// Bit n is set once shifted ordinal n has been seen; bit 0 means a document without a value was seen.
private final BitArray visitedOrds;
// Reader reused across pages while BlockOrdinalsReader.canReuse allows it.
private BlockOrdinalsReader currentReader;
/**
 * Creates the aggregator, obtaining doc values, the visited-ordinals bit set and the grouping aggregators.
 * If construction fails, the bit set and aggregators created so far are closed.
 *
 * NOTE(review): the type arguments on the parameters below appear to have been lost in extraction — confirm against the original source.
 */
OrdinalSegmentAggregator(
BlockFactory blockFactory,
Supplier> aggregatorsSupplier,
CheckedSupplier docValuesSupplier,
BigArrays bigArrays
) throws IOException {
boolean success = false;
List groupingAggregators = null;
BitArray bitArray = null;
try {
final SortedSetDocValues sortedSetDocValues = docValuesSupplier.get();
// Initial size is the segment's value count.
bitArray = new BitArray(sortedSetDocValues.getValueCount(), bigArrays);
groupingAggregators = aggregatorsSupplier.get();
this.currentReader = BlockOrdinalsReader.newReader(blockFactory, sortedSetDocValues);
this.blockFactory = blockFactory;
this.docValuesSupplier = docValuesSupplier;
this.aggregators = groupingAggregators;
this.visitedOrds = bitArray;
success = true;
} finally {
if (success == false) {
if (bitArray != null) Releasables.close(bitArray);
if (groupingAggregators != null) Releasables.close(groupingAggregators);
}
}
}
/**
 * Reads the shifted ordinals of the given documents and feeds them, as group ids, to every aggregator.
 * The page's blocks are always released before this method returns, whether it succeeds or fails.
 *
 * @param docs the doc ids of the page within this segment
 * @param page the page the aggregators read their input from
 * @throws UncheckedIOException if reading ordinals fails with an IOException
 */
void addInput(IntVector docs, Page page) {
try {
GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()];
for (int i = 0; i < prepared.length; i++) {
prepared[i] = aggregators.get(i).prepareProcessPage(this, page);
}
// A new reader is needed when the current one cannot be reused for the first doc of this page.
if (BlockOrdinalsReader.canReuse(currentReader, docs.getInt(0)) == false) {
currentReader = BlockOrdinalsReader.newReader(blockFactory, docValuesSupplier.get());
}
try (IntBlock ordinals = currentReader.readOrdinalsAdded1(docs)) {
// Use the vector form when the block can be viewed as a vector, the block form otherwise.
final IntVector ordinalsVector = ordinals.asVector();
if (ordinalsVector != null) {
addOrdinalsInput(ordinalsVector, prepared);
} else {
addOrdinalsInput(ordinals, prepared);
}
}
} catch (IOException e) {
throw new UncheckedIOException(e);
} finally {
page.releaseBlocks();
}
}
/**
 * Marks every ordinal in the block as visited, then passes the block to each prepared input
 * with a position offset of 0.
 */
void addOrdinalsInput(IntBlock ordinals, GroupingAggregatorFunction.AddInput[] prepared) {
for (int p = 0; p < ordinals.getPositionCount(); p++) {
int start = ordinals.getFirstValueIndex(p);
int end = start + ordinals.getValueCount(p);
for (int i = start; i < end; i++) {
long ord = ordinals.getInt(i);
visitedOrds.set(ord);
}
}
for (GroupingAggregatorFunction.AddInput addInput : prepared) {
addInput.add(0, ordinals);
}
}
/**
 * Marks every ordinal in the vector as visited, then passes the vector to each prepared input
 * with a position offset of 0.
 */
void addOrdinalsInput(IntVector ordinals, GroupingAggregatorFunction.AddInput[] prepared) {
for (int p = 0; p < ordinals.getPositionCount(); p++) {
long ord = ordinals.getInt(p);
visitedOrds.set(ord);
}
for (GroupingAggregatorFunction.AddInput addInput : prepared) {
addInput.add(0, ordinals);
}
}
/** Returns an iterator over the visited non-null ordinals, backed by doc values obtained from the supplier. */
AggregatedResultIterator getResultIterator() throws IOException {
return new AggregatedResultIterator(aggregators, visitedOrds, docValuesSupplier.get());
}
/** Returns whether group id 0, the group for documents without a value, was seen. */
boolean seenNulls() {
return visitedOrds.get(0);
}
/**
 * Returns a new bit set containing the visited group ids. The new bit set is closed here
 * if copying fails; otherwise it is returned to the caller.
 */
@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
final BitArray seen = new BitArray(0, bigArrays);
boolean success = false;
try {
// the or method can grow the `seen` bits
seen.or(visitedOrds);
success = true;
return seen;
} finally {
if (success == false) {
Releasables.close(seen);
}
}
}
/** Closes the visited-ordinals bit set and all aggregators. */
@Override
public void close() {
Releasables.close(visitedOrds, () -> Releasables.close(aggregators));
}
}
/**
 * Steps through the set bits of a segment's visited-ordinals bit set, starting after bit 0,
 * and exposes the term that belongs to each one. Bits hold ordinals shifted by one, so the
 * term for bit {@code n} is looked up at ordinal {@code n - 1}.
 */
private static class AggregatedResultIterator {
    private final List aggregators;
    private final BitArray ords;
    private final SortedSetDocValues dv;
    private long currentOrd = 0;
    private BytesRef currentTerm;

    AggregatedResultIterator(List aggregators, BitArray ords, SortedSetDocValues dv) {
        this.dv = dv;
        this.ords = ords;
        this.aggregators = aggregators;
    }

    /** Returns the current shifted ordinal as an int; must not be called once the iterator is exhausted. */
    int currentPosition() {
        assert currentOrd != Long.MAX_VALUE : "Must not read position when iterator is exhausted";
        return Math.toIntExact(currentOrd);
    }

    /**
     * Advances to the next set bit and loads its term.
     *
     * @return {@code true} if a term is available, {@code false} if the bit set reported no further
     *         set bit (signalled by {@link Long#MAX_VALUE}), in which case the current term is cleared
     */
    boolean next() throws IOException {
        final long nextOrd = ords.nextSetBit(currentOrd + 1);
        currentOrd = nextOrd;
        assert nextOrd > 0 : nextOrd;
        if (nextOrd >= Long.MAX_VALUE) {
            currentTerm = null;
            return false;
        }
        currentTerm = dv.lookupOrd(nextOrd - 1);
        return true;
    }
}
/**
 * Value-based aggregation path: a {@link ValuesSourceReaderOperator} extracts the grouping field
 * from each page and a {@link HashAggregationOperator} aggregates the extractor's output.
 */
private static class ValuesAggregator implements Releasable {
    private final ValuesSourceReaderOperator extractor;
    private final HashAggregationOperator aggregator;

    /**
     * Wires the extractor for {@code groupingField} to a hash aggregation that groups on
     * {@code channelIndex}, the channel in which the caller expects the extracted block to appear.
     */
    ValuesAggregator(
        IntFunction blockLoaders,
        List shardContexts,
        ElementType groupingElementType,
        int docChannel,
        String groupingField,
        int channelIndex,
        List aggregatorFactories,
        int maxPageSize,
        DriverContext driverContext
    ) {
        final var groupingFieldInfo = new ValuesSourceReaderOperator.FieldInfo(groupingField, groupingElementType, blockLoaders);
        this.extractor = new ValuesSourceReaderOperator(
            driverContext.blockFactory(),
            List.of(groupingFieldInfo),
            shardContexts,
            docChannel
        );
        this.aggregator = new HashAggregationOperator(
            aggregatorFactories,
            () -> BlockHash.build(
                List.of(new GroupSpec(channelIndex, groupingElementType)),
                driverContext.blockFactory(),
                maxPageSize,
                false
            ),
            driverContext
        );
    }

    /** Runs the page through the extractor and, if the extractor yields a page, forwards it to the aggregator. */
    void addInput(Page page) {
        extractor.addInput(page);
        final Page extracted = extractor.getOutput();
        if (extracted == null) {
            return;
        }
        aggregator.addInput(extracted);
    }

    /** Finishes the hash aggregation. */
    void finish() {
        aggregator.finish();
    }

    /** Returns the hash aggregation's output. */
    Page getOutput() {
        return aggregator.getOutput();
    }

    /** Closes the extractor and the aggregator. */
    @Override
    public void close() {
        Releasables.close(extractor, aggregator);
    }
}
/**
 * Base class for readers that turn a vector of doc ids into a block of ordinals shifted by one.
 * Each reader remembers the thread that created it.
 */
abstract static class BlockOrdinalsReader {
    protected final Thread creationThread;
    protected final BlockFactory blockFactory;

    BlockOrdinalsReader(BlockFactory blockFactory) {
        this.creationThread = Thread.currentThread();
        this.blockFactory = blockFactory;
    }

    /**
     * Picks the reader implementation: the single-valued reader when the doc values unwrap to a
     * singleton, the multi-valued reader otherwise.
     */
    static BlockOrdinalsReader newReader(BlockFactory blockFactory, SortedSetDocValues sortedSetDocValues) {
        final SortedDocValues singleton = DocValues.unwrapSingleton(sortedSetDocValues);
        return singleton == null
            ? new SortedSetDocValuesBlockOrdinalsReader(blockFactory, sortedSetDocValues)
            : new SortedDocValuesBlockOrdinalsReader(blockFactory, singleton);
    }

    /** Reads the ordinals of the given documents, each increased by one. */
    abstract IntBlock readOrdinalsAdded1(IntVector docs) throws IOException;

    /** Returns the doc id reported by the underlying doc values. */
    abstract int docID();

    /**
     * Checks if the reader can be used to read a range of documents starting with the given docID by the current thread.
     * That requires a non-null reader, created by the current thread, whose doc id is not past the starting doc.
     */
    static boolean canReuse(BlockOrdinalsReader reader, int startingDocID) {
        if (reader == null) {
            return false;
        }
        if (reader.creationThread != Thread.currentThread()) {
            return false;
        }
        return reader.docID() <= startingDocID;
    }
}
/** Ordinals reader for doc values that may hold several values per document. */
private static class SortedSetDocValuesBlockOrdinalsReader extends BlockOrdinalsReader {
    private final SortedSetDocValues sortedSetDocValues;

    SortedSetDocValuesBlockOrdinalsReader(BlockFactory blockFactory, SortedSetDocValues sortedSetDocValues) {
        super(blockFactory);
        this.sortedSetDocValues = sortedSetDocValues;
    }

    /**
     * Builds one position per document: 0 when the document has no value, the single shifted ordinal
     * when it has exactly one, and a multi-valued entry of shifted ordinals otherwise.
     */
    @Override
    IntBlock readOrdinalsAdded1(IntVector docs) throws IOException {
        final int positionCount = docs.getPositionCount();
        try (IntBlock.Builder ordinals = blockFactory.newIntBlockBuilder(positionCount)) {
            for (int position = 0; position < positionCount; position++) {
                final int docId = docs.getInt(position);
                if (sortedSetDocValues.advanceExact(docId) == false) {
                    // no value for this document
                    ordinals.appendInt(0);
                } else {
                    final int valueCount = sortedSetDocValues.docValueCount();
                    if (valueCount == 1) {
                        ordinals.appendInt(Math.toIntExact(sortedSetDocValues.nextOrd() + 1));
                    } else {
                        ordinals.beginPositionEntry();
                        for (int v = 0; v < valueCount; v++) {
                            ordinals.appendInt(Math.toIntExact(sortedSetDocValues.nextOrd() + 1));
                        }
                        ordinals.endPositionEntry();
                    }
                }
            }
            return ordinals.build();
        }
    }

    @Override
    int docID() {
        return sortedSetDocValues.docID();
    }
}
/** Ordinals reader for doc values that hold at most one value per document. */
private static class SortedDocValuesBlockOrdinalsReader extends BlockOrdinalsReader {
    private final SortedDocValues sortedDocValues;

    SortedDocValuesBlockOrdinalsReader(BlockFactory blockFactory, SortedDocValues sortedDocValues) {
        super(blockFactory);
        this.sortedDocValues = sortedDocValues;
    }

    /** Writes, for each document, its ordinal plus one, or 0 when the document has no value. */
    @Override
    IntBlock readOrdinalsAdded1(IntVector docs) throws IOException {
        final int positionCount = docs.getPositionCount();
        try (IntVector.FixedBuilder ordinals = blockFactory.newIntVectorFixedBuilder(positionCount)) {
            for (int position = 0; position < positionCount; position++) {
                final boolean hasValue = sortedDocValues.advanceExact(docs.getInt(position));
                final int shiftedOrd = hasValue ? sortedDocValues.ordValue() + 1 : 0;
                ordinals.appendInt(position, shiftedOrd);
            }
            return ordinals.build().asBlock();
        }
    }

    @Override
    int docID() {
        return sortedDocValues.docID();
    }
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy