org.elasticsearch.compute.aggregation.blockhash.PackedValuesBlockHash Maven / Gradle / Ivy
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation.blockhash;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.mvdedupe.BatchEncoder;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import java.util.Arrays;
import java.util.List;
/**
* Maps any number of columns to group ids with every unique combination resulting
* in a unique group id. Works by unique-ing the values of each column and concatenating
* the combinatorial explosion of all values into a byte array and then hashing each
* byte array. If the values are
* {@code
* a=(1, 2, 3) b=(2, 3) c=(4, 5, 5)
* }
* Then you get these grouping keys:
* {@code
* 1, 2, 4
* 1, 2, 5
* 1, 3, 4
* 1, 3, 5
* 2, 2, 4
* 2, 2, 5
* 2, 3, 4
* 2, 3, 5
* 3, 2, 4
* 3, 2, 5
* 3, 3, 4
* 3, 3, 5
* }
*
* The iteration order in the above is how we do it - it's as though it's
* nested {@code for} loops with the first column being the outer-most loop
* and the last column being the inner-most loop. See {@link Group} for more.
*
*/
final class PackedValuesBlockHash extends BlockHash {
/**
 * Default batch size handed to the {@link BatchEncoder}s: 10kb expressed as a number of bytes.
 */
static final int DEFAULT_BATCH_SIZE = Math.toIntExact(ByteSizeValue.ofKb(10).getBytes());
/** Passed to the {@link AddWork} superclass constructor for every page that is added. */
private final int emitBatchSize;
/** Stores every packed key seen so far; group ids are derived from the ordinals it hands out. */
private final BytesRefHash bytesRefHash;
/** Number of bytes at the front of each packed key that hold the null bit mask, one bit per grouping key. */
private final int nullTrackingBytes;
/** Scratch buffer in which the packed key for the current combination of values is assembled. */
private final BytesRefBuilder bytes = new BytesRefBuilder();
/** The grouping keys, in the order in which their values are packed into {@link #bytes}. */
private final List<GroupSpec> specs;

/**
 * Builds an empty hash over the given grouping keys.
 *
 * @param specs         the grouping keys; one bit of null tracking is reserved per entry
 * @param blockFactory  factory forwarded to the superclass; its {@code bigArrays()} backs the key hash
 * @param emitBatchSize value forwarded to every {@link AddWork} created by {@code add}
 */
PackedValuesBlockHash(List<GroupSpec> specs, BlockFactory blockFactory, int emitBatchSize) {
    super(blockFactory);
    this.specs = specs;
    this.emitBatchSize = emitBatchSize;
    this.bytesRefHash = new BytesRefHash(1, blockFactory.bigArrays());
    // Round up to whole bytes: 1-8 keys need one byte, 9-16 need two, and so on.
    this.nullTrackingBytes = (specs.size() + 7) / 8;
    // Make sure the backing array can hold the null mask; startPosition zeroes that prefix directly.
    bytes.grow(nullTrackingBytes);
}
/**
 * Adds every position of {@code page} to the hash, encoding the keys in batches of
 * {@link #DEFAULT_BATCH_SIZE}.
 */
@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
    this.add(page, addInput, DEFAULT_BATCH_SIZE);
}

/**
 * Adds every position of {@code page} to the hash using an explicit encoder batch size.
 * The {@link AddWork} that does the job is closed before this method returns, whether or
 * not it completed normally.
 *
 * @param page      the page whose key columns are read
 * @param addInput  receiver handed to the {@link AddWork}
 * @param batchSize batch size given to each key's {@link BatchEncoder}
 */
void add(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) {
    try (AddWork addWork = new AddWork(page, addInput, batchSize)) {
        addWork.add();
    }
}
/**
 * The on-heap state of one {@code for} loop in the nest of loops that enumerates every
 * combination of key values at a position. There is one instance per grouping key.
 */
private static class Group implements Releasable {
    /** The grouping key this loop iterates over. */
    final GroupSpec spec;
    /** Encoder producing the byte representation of this key's values, batch by batch. */
    final BatchEncoder encoder;
    /** Index of the current position inside the encoder's current batch. */
    int positionOffset;
    /** Index, inside the encoder's current batch, of the first value of the current position. */
    int valueOffset;
    /**
     * How many values of the current position have been written for this key so far.
     * This is the loop variable of the {@code for} loop.
     */
    int writtenValues;
    /**
     * How many values this key has at the current position. This is the upper bound
     * of the {@code for} loop.
     */
    int valueCount;
    /** Length of the shared key buffer right before this key's value was appended; rewinding truncates back to it. */
    int bytesStart;

    Group(GroupSpec spec, Page page, int batchSize) {
        this.spec = spec;
        Block keyBlock = page.getBlock(spec.channel());
        this.encoder = MultivalueDedupe.batchEncoder(keyBlock, batchSize, true);
    }

    /** Releases the encoder. */
    @Override
    public void close() {
        encoder.close();
    }
}
/**
 * The work of adding a single {@link Page} to the hash. Holds one {@link Group} per
 * grouping key and walks the page position by position.
 */
class AddWork extends AbstractAddBlock {
    final Group[] groups;
    final int positionCount;
    int position;

    AddWork(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) {
        super(blockFactory, emitBatchSize, addInput);
        this.groups = specs.stream().map(spec -> new Group(spec, page, batchSize)).toArray(Group[]::new);
        this.positionCount = page.getPositionCount();
    }

    /**
     * Encodes one combination of the keys at a time into {@link #bytes} and adds it
     * to the {@link #bytesRefHash}. The encoding is mostly provided by
     * {@link BatchEncoder}, with nulls living in a bit mask at the front of the bytes.
     */
    void add() {
        position = 0;
        while (position < positionCount) {
            if (startPosition(groups)) {
                addSingleEntry();
            } else {
                addMultipleEntries();
            }
            position++;
        }
        emitOrds();
    }

    /** Handles a position at which every key has exactly one value: one key, one ordinal. */
    private void addSingleEntry() {
        fillBytesSv(groups);
        long hashOrd = bytesRefHash.add(bytes.get());
        ords.appendInt(Math.toIntExact(hashOrdToGroup(hashOrd)));
        addedValue(position);
    }

    /**
     * Handles a position at which at least one key does not have exactly one value:
     * every combination of values is added and all resulting ordinals share one position entry.
     */
    private void addMultipleEntries() {
        ords.beginPositionEntry();
        // rewindKeys returns the key to resume filling from, or a negative value once every combination was visited.
        for (int resumeFrom = 0; resumeFrom >= 0; resumeFrom = rewindKeys(groups)) {
            fillBytesMv(groups, resumeFrom);
            long hashOrd = bytesRefHash.add(bytes.get());
            ords.appendInt(Math.toIntExact(hashOrdToGroup(hashOrd)));
            addedValueInMultivaluePosition(position);
        }
        ords.endPositionEntry();
        // Step every key past the values of this position.
        for (int i = 0; i < groups.length; i++) {
            groups[i].valueOffset += groups[i].valueCount;
        }
    }

    /** Closes the superclass state and every {@link Group}. */
    @Override
    public void close() {
        Releasables.closeExpectNoException(super::close, Releasables.wrap(groups));
    }
}
/**
 * Looks up the group ids for the keys in {@code page} without adding anything to the hash.
 * Keys are encoded in batches of {@link #DEFAULT_BATCH_SIZE}.
 *
 * @param page            the page whose key columns are read
 * @param targetBlockSize size, in bytes, that bounds how far each returned block is filled
 * @return an iterator of {@link IntBlock}s; the caller is responsible for closing it
 */
@Override
public ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
    return new LookupWork(page, targetBlockSize.getBytes(), DEFAULT_BATCH_SIZE);
}
/**
 * Iterator that resolves the positions of one {@link Page} against the keys already
 * stored in {@link #bytesRefHash}. It only calls {@code find} on the hash, never {@code add}.
 * Each call to {@link #next()} returns a block covering the next run of positions.
 */
class LookupWork implements ReleasableIterator<IntBlock> {
    private final Group[] groups;
    private final long targetByteSize;
    private final int positionCount;
    /** The next position of the page to look up. */
    private int position;

    LookupWork(Page page, long targetByteSize, int batchSize) {
        this.groups = specs.stream().map(s -> new Group(s, page, batchSize)).toArray(Group[]::new);
        this.positionCount = page.getPositionCount();
        this.targetByteSize = targetByteSize;
    }

    @Override
    public boolean hasNext() {
        return position < positionCount;
    }

    /**
     * Builds the block for the next run of positions. Positions are consumed until the page
     * is exhausted or the builder's estimated size reaches {@code targetByteSize}.
     *
     * @throws IllegalStateException if the freshly created builder already exceeds the target size
     */
    @Override
    public IntBlock next() {
        int size = Math.toIntExact(Math.min(positionCount - position, targetByteSize / Integer.BYTES / 2));
        try (IntBlock.Builder ords = blockFactory.newIntBlockBuilder(size)) {
            if (ords.estimatedBytes() > targetByteSize) {
                throw new IllegalStateException(
                    "initial builder overshot target [" + ords.estimatedBytes() + "] vs [" + targetByteSize + "]"
                );
            }
            while (position < positionCount && ords.estimatedBytes() < targetByteSize) {
                // TODO a test where targetByteSize is very small should still make a few rows.
                boolean singleEntry = startPosition(groups);
                if (singleEntry) {
                    lookupSingleEntry(ords);
                } else {
                    lookupMultipleEntries(ords);
                }
                position++;
            }
            return ords.build();
        }
    }

    /** Looks up a position at which every key has exactly one value; appends null when the key is absent. */
    private void lookupSingleEntry(IntBlock.Builder ords) {
        fillBytesSv(groups);
        long found = bytesRefHash.find(bytes.get());
        if (found < 0) {
            ords.appendNull();
        } else {
            ords.appendInt(Math.toIntExact(found));
        }
    }

    /**
     * Looks up every combination of values at a position. Appends null when none is found,
     * a plain int when exactly one is found, and a multivalued entry otherwise.
     *
     * @throws IllegalArgumentException if more than {@code Block.MAX_LOOKUP} combinations are found
     */
    private void lookupMultipleEntries(IntBlock.Builder ords) {
        // The first hit is held back so that a lone hit can be emitted without a position entry.
        long firstFound = -1;
        boolean began = false;
        int g = 0;
        int count = 0;
        do {
            fillBytesMv(groups, g);
            // emit ords
            long found = bytesRefHash.find(bytes.get());
            if (found >= 0) {
                if (firstFound < 0) {
                    firstFound = found;
                } else {
                    if (began == false) {
                        // Second hit: open the multivalued entry and flush the held-back first hit.
                        began = true;
                        ords.beginPositionEntry();
                        ords.appendInt(Math.toIntExact(firstFound));
                        count++;
                    }
                    ords.appendInt(Math.toIntExact(found));
                    count++;
                    if (count > Block.MAX_LOOKUP) {
                        // TODO replace this with a warning and break
                        throw new IllegalArgumentException("Found a single entry with " + count + " entries");
                    }
                }
            }
            g = rewindKeys(groups);
        } while (g >= 0);
        if (firstFound < 0) {
            ords.appendNull();
        } else if (began) {
            ords.endPositionEntry();
        } else {
            // Only found one value
            // NOTE(review): this is the only lookup path that goes through hashOrdToGroup; the others append
            // the found ordinal directly. Presumably equivalent for non-negative ordinals - confirm.
            ords.appendInt(Math.toIntExact(hashOrdToGroup(firstFound)));
        }
        // Step every key past the values of this position.
        for (Group group : groups) {
            group.valueOffset += group.valueCount;
        }
    }

    /** Closes every {@link Group}. */
    @Override
    public void close() {
        Releasables.closeExpectNoException(groups);
    }
}
/**
 * Moves every one of the {@code groups} on to the next position, then resets
 * {@link #bytes} so that it contains nothing but a zeroed null-tracking prefix.
 * Call this before encoding a new position.
 *
 * @return {@code true} if every key has exactly one value at this position
 */
private boolean startPosition(Group[] groups) {
    boolean allSingleValued = true;
    for (int i = 0; i < groups.length; i++) {
        Group group = groups[i];
        BatchEncoder encoder = group.encoder;
        /*
         * Advance to the next position. If the encoder's current batch does not
         * contain it, encode further batches and restart both offsets at the
         * beginning of the new batch.
         */
        group.positionOffset++;
        while (group.positionOffset >= encoder.positionCount()) {
            encoder.encodeNextBatch();
            group.positionOffset = 0;
            group.valueOffset = 0;
        }
        group.valueCount = encoder.valueCount(group.positionOffset);
        if (group.valueCount != 1) {
            allSingleValued = false;
        }
    }
    // Clear the null bits left over from the previous position and drop everything after them.
    Arrays.fill(bytes.bytes(), 0, nullTrackingBytes, (byte) 0);
    bytes.setLength(nullTrackingBytes);
    return allSingleValued;
}
/**
 * Appends the single value of every key at the current position to {@link #bytes} and
 * advances each key's value offset by one. A key whose read reports zero is flagged in
 * the null bit mask at the front of the buffer.
 */
private void fillBytesSv(Group[] groups) {
    int index = 0;
    for (Group group : groups) {
        assert group.writtenValues == 0;
        assert group.valueCount == 1;
        if (group.encoder.read(group.valueOffset++, bytes) == 0) {
            // Bit (index % 8) of byte (index / 8) marks key number `index` as null.
            bytes.bytes()[index / 8] |= (byte) (1 << (index % 8));
        }
        index++;
    }
}
/**
 * Appends the next unwritten value of every key from {@code startingGroup} onwards to
 * {@link #bytes}. For each of those keys it records where its value starts in the buffer
 * and counts the value as written.
 *
 * @param startingGroup index of the first key to fill; earlier keys keep the bytes already in the buffer
 */
private void fillBytesMv(Group[] groups, int startingGroup) {
    int g = startingGroup;
    while (g < groups.length) {
        Group group = groups[g];
        // Remember where this key's value begins so rewindKeys can truncate back to it.
        group.bytesStart = bytes.length();
        int valueIndex = group.valueOffset + group.writtenValues;
        if (group.encoder.read(valueIndex, bytes) == 0) {
            assert group.valueCount == 1 : "null value in non-singleton list";
            // Bit (g % 8) of byte (g / 8) marks key number g as null.
            bytes.bytes()[g / 8] |= (byte) (1 << (g % 8));
        }
        group.writtenValues++;
        g++;
    }
}
/**
 * Steps the nest of loops on to the next combination. Starting at the last key, it
 * truncates {@link #bytes} back to where that key's value began. While the key has
 * written all of its values, its counter is reset and the walk moves one key towards the front.
 *
 * @return the index of the key to resume filling from, or {@code -1} once the
 *         first key has written all of its values too
 */
private int rewindKeys(Group[] groups) {
    int g = groups.length - 1;
    while (true) {
        Group group = groups[g];
        bytes.setLength(group.bytesStart);
        if (group.writtenValues != group.valueCount) {
            // This key still has values left: the next combination resumes here.
            return g;
        }
        group.writtenValues = 0;
        if (g == 0) {
            return -1;
        }
        g--;
    }
}
/**
 * Decodes every key stored in {@link #bytesRefHash} back into one block per grouping key.
 * Keys are read from the hash in ordinal order and decoded in batches of at most 100.
 *
 * @return one built block per grouping key, in the order of {@link #specs}
 */
@Override
public Block[] getKeys() {
    int size = Math.toIntExact(bytesRefHash.size());
    int keyCount = specs.size();
    BatchEncoder.Decoder[] decoders = new BatchEncoder.Decoder[keyCount];
    Block.Builder[] builders = new Block.Builder[keyCount];
    try {
        for (int g = 0; g < builders.length; g++) {
            ElementType elementType = specs.get(g).elementType();
            decoders[g] = BatchEncoder.decoder(elementType);
            builders[g] = elementType.newBlockBuilder(size, blockFactory);
        }

        // Reusable references: values[i] views the encoded values of a key, nulls[i] views its null mask.
        int batchCapacity = (int) Math.min(100, bytesRefHash.size());
        BytesRef[] values = new BytesRef[batchCapacity];
        BytesRef[] nulls = new BytesRef[batchCapacity];
        for (int slot = 0; slot < batchCapacity; slot++) {
            values[slot] = new BytesRef();
            BytesRef nullRef = new BytesRef();
            nullRef.length = nullTrackingBytes;
            nulls[slot] = nullRef;
        }

        int filled = 0;
        for (int id = 0; id < bytesRefHash.size(); id++) {
            BytesRef value = bytesRefHash.get(id, values[filled]);
            values[filled] = value;
            // The null mask sits at the front of the packed key; the values follow it.
            BytesRef nullRef = nulls[filled];
            nullRef.bytes = value.bytes;
            nullRef.offset = value.offset;
            value.offset += nullTrackingBytes;
            value.length -= nullTrackingBytes;
            filled++;
            if (filled == values.length) {
                readKeys(decoders, builders, nulls, values, filled);
                filled = 0;
            }
        }
        if (filled > 0) {
            // Decode the final, partially filled batch.
            readKeys(decoders, builders, nulls, values, filled);
        }
        return Block.Builder.buildAll(builders);
    } finally {
        Releasables.closeExpectNoException(builders);
    }
}
/**
 * Decodes one batch of packed keys, one grouping key at a time. Each decoder gets a
 * predicate that tests that key's bit in the null mask of the packed key.
 *
 * @param decoders one decoder per grouping key
 * @param builders one block builder per grouping key, receiving the decoded values
 * @param nulls    references positioned at the null mask of each packed key
 * @param values   references positioned just after the null mask of each packed key
 * @param count    number of leading entries of {@code nulls} and {@code values} to decode
 */
private void readKeys(BatchEncoder.Decoder[] decoders, Block.Builder[] builders, BytesRef[] nulls, BytesRef[] values, int count) {
    for (int g = 0; g < builders.length; g++) {
        final int byteIndex = g / 8;
        final int bitMask = 1 << (g % 8);
        BatchEncoder.IsNull isNull = slot -> (nulls[slot].bytes[nulls[slot].offset + byteIndex] & bitMask) != 0;
        decoders[g].decode(builders[g], isNull, values, count);
    }
}
/**
 * Returns the range {@code [0, number of keys in the hash)} as a vector.
 */
@Override
public IntVector nonEmpty() {
    int groupCount = Math.toIntExact(bytesRefHash.size());
    return IntVector.range(0, groupCount, blockFactory);
}
/**
 * Returns the seen-group bit set for the range {@code [0, number of keys in the hash)},
 * allocated from {@code bigArrays}.
 */
@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
    int groupCount = Math.toIntExact(bytesRefHash.size());
    SeenGroupIds.Range allGroups = new SeenGroupIds.Range(0, groupCount);
    return allGroups.seenGroupIds(bigArrays);
}
/**
 * Releases the hash that stores the packed keys.
 */
@Override
public void close() {
    this.bytesRefHash.close();
}
/**
 * Describes this hash as
 * {@code PackedValuesBlockHash{groups=[channel:type, ...], entries=N, size=S}}, where
 * {@code entries} is the number of keys in the hash and {@code size} is its reported memory use.
 */
@Override
public String toString() {
    StringBuilder b = new StringBuilder();
    b.append("PackedValuesBlockHash{groups=[");
    for (int i = 0; i < specs.size(); i++) {
        if (i > 0) {
            b.append(", ");
        }
        GroupSpec spec = specs.get(i);
        b.append(spec.channel()).append(':').append(spec.elementType());
    }
    b.append("], entries=").append(bytesRefHash.size());
    b.append(", size=").append(ByteSizeValue.ofBytes(bytesRefHash.ramBytesUsed()));
    return b.append("}").toString();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy