io.trino.spi.exchange.ExchangeSourceOutputSelector Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.spi.exchange;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.StringJoiner;
import java.util.TreeMap;
import java.util.function.Function;
import static io.airlift.slice.SizeOf.instanceSize;
import static io.trino.spi.exchange.ExchangeSourceOutputSelector.Selection.EXCLUDED;
import static io.trino.spi.exchange.ExchangeSourceOutputSelector.Selection.INCLUDED;
import static io.trino.spi.exchange.ExchangeSourceOutputSelector.Selection.UNKNOWN;
import static java.lang.Math.max;
import static java.util.Arrays.fill;
import static java.util.Map.entry;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toUnmodifiableMap;
public class ExchangeSourceOutputSelector
{
private static final long INSTANCE_SIZE = instanceSize(ExchangeSourceOutputSelector.class);
private final int version;
private final Map values;
private final boolean finalSelector;
// visible for Jackson
@JsonCreator
public ExchangeSourceOutputSelector(
@JsonProperty("version") int version,
@JsonProperty("values") Map values,
@JsonProperty("finalSelector") boolean finalSelector)
{
this.version = version;
this.values = Map.copyOf(requireNonNull(values, "values is null"));
this.finalSelector = finalSelector;
}
@JsonProperty
public int getVersion()
{
return version;
}
// visible for Jackson
@JsonProperty
public Map getValues()
{
return values;
}
@JsonProperty("finalSelector")
public boolean isFinal()
{
return finalSelector;
}
public Selection getSelection(ExchangeId exchangeId, int taskPartitionId, int attemptId)
{
requireNonNull(exchangeId, "exchangeId is null");
if (taskPartitionId < 0) {
throw new IllegalArgumentException("unexpected taskPartitionId: " + taskPartitionId);
}
if (attemptId < 0 || attemptId > Byte.MAX_VALUE) {
throw new IllegalArgumentException("unexpected attemptId: " + attemptId);
}
Slice exchangeValues = values.get(exchangeId);
if (exchangeValues == null) {
throwIfFinal(exchangeId, taskPartitionId);
return UNKNOWN;
}
if (exchangeValues.length() <= taskPartitionId) {
throwIfFinal(exchangeId, taskPartitionId);
return UNKNOWN;
}
byte selectedAttempt = exchangeValues.getByte(taskPartitionId);
if (selectedAttempt == UNKNOWN.getValue()) {
throwIfFinal(exchangeId, taskPartitionId);
return UNKNOWN;
}
if (selectedAttempt == EXCLUDED.getValue()) {
return EXCLUDED;
}
if (selectedAttempt < 0) {
throw new IllegalArgumentException("unexpected selectedAttempt: " + selectedAttempt);
}
return selectedAttempt == attemptId ? INCLUDED : EXCLUDED;
}
public long getRetainedSizeInBytes()
{
return INSTANCE_SIZE
+ SizeOf.estimatedSizeOf(values, ExchangeId::getRetainedSizeInBytes, Slice::getRetainedSize);
}
public void checkValidTransition(ExchangeSourceOutputSelector newSelector)
{
if (this.version >= newSelector.version) {
throw new IllegalArgumentException("Invalid transition to the same or an older version");
}
if (this.isFinal()) {
throw new IllegalArgumentException("Invalid transition from final selector");
}
Set exchangeIds = new HashSet<>();
exchangeIds.addAll(this.values.keySet());
exchangeIds.addAll(newSelector.values.keySet());
for (ExchangeId exchangeId : exchangeIds) {
int taskPartitionCount = max(this.getPartitionCount(exchangeId), newSelector.getPartitionCount(exchangeId));
for (int taskPartitionId = 0; taskPartitionId < taskPartitionCount; taskPartitionId++) {
byte currentValue = this.getValue(exchangeId, taskPartitionId);
byte newValue = newSelector.getValue(exchangeId, taskPartitionId);
if (currentValue == UNKNOWN.getValue()) {
// transition from UNKNOWN is always valid
continue;
}
if (currentValue != newValue) {
throw new IllegalArgumentException("Invalid transition for exchange %s, taskPartitionId %s: %s -> %s".formatted(exchangeId, taskPartitionId, currentValue, newValue));
}
}
}
}
public ExchangeSourceOutputSelector merge(ExchangeSourceOutputSelector other)
{
Map values = new HashMap<>(this.values);
other.values.forEach((exchangeId, value) -> {
Slice currentValue = values.putIfAbsent(exchangeId, value);
if (currentValue != null) {
throw new IllegalArgumentException("duplicated selector for exchange: " + exchangeId);
}
});
return new ExchangeSourceOutputSelector(
this.version + other.version,
values,
this.finalSelector && other.finalSelector);
}
@Override
public String toString()
{
return new StringJoiner(", ", ExchangeSourceOutputSelector.class.getSimpleName() + "[", "]")
.add("version=" + version)
.add("values=" + values.entrySet().stream()
.map(e -> entry(e.getKey().toString(), valuesSliceToString(e.getValue())))
// collect to TreeMap to ensure ordering of keys
.collect(toMap(
Entry::getKey,
Entry::getValue,
(a, b) -> { throw new IllegalArgumentException("got duplicate key " + a + ", " + b); },
TreeMap::new)))
.add("finalSelector=" + finalSelector)
.toString();
}
private String valuesSliceToString(Slice values)
{
StringBuilder builder = new StringBuilder();
builder.append("[");
try (BasicSliceInput input = new BasicSliceInput(values)) {
int taskPartitionId = 0;
while (true) {
int value = input.read();
if (value == -1) {
break;
}
if (taskPartitionId != 0) {
builder.append(",");
}
builder.append(taskPartitionId);
builder.append("=");
if ((byte) value == EXCLUDED.value) {
builder.append("E");
}
else if ((byte) value == UNKNOWN.value) {
builder.append("U");
}
else {
builder.append(value);
}
taskPartitionId++;
}
}
builder.append("]");
return builder.toString();
}
private int getPartitionCount(ExchangeId exchangeId)
{
Slice values = this.values.get(exchangeId);
if (values == null) {
return 0;
}
return values.length();
}
private byte getValue(ExchangeId exchangeId, int taskPartitionId)
{
Slice exchangeValues = values.get(exchangeId);
if (exchangeValues == null) {
return UNKNOWN.getValue();
}
if (exchangeValues.length() <= taskPartitionId) {
return UNKNOWN.getValue();
}
return exchangeValues.getByte(taskPartitionId);
}
private void throwIfFinal(ExchangeId exchangeId, int taskPartitionId)
{
if (isFinal()) {
throw new IllegalArgumentException("selection not found for exchangeId %s, taskPartitionId %s".formatted(exchangeId, taskPartitionId));
}
}
public enum Selection
{
INCLUDED((byte) -1),
EXCLUDED((byte) -2),
UNKNOWN((byte) -3);
private final byte value;
Selection(byte value)
{
this.value = value;
}
public byte getValue()
{
return value;
}
}
public static Builder builder(Set sourceExchanges)
{
return new Builder(sourceExchanges);
}
public static class Builder
{
private int nextVersion;
private final Map exchangeValues;
private boolean finalSelector;
private final Map exchangeTaskPartitionCount = new HashMap<>();
public Builder(Set sourceExchanges)
{
requireNonNull(sourceExchanges, "sourceExchanges is null");
exchangeValues = sourceExchanges.stream()
.collect(toUnmodifiableMap(Function.identity(), exchangeId -> new ValuesBuilder()));
}
public Builder include(ExchangeId exchangeId, int taskPartitionId, int attemptId)
{
getValuesBuilderForExchange(exchangeId).include(taskPartitionId, attemptId);
return this;
}
public Builder exclude(ExchangeId exchangeId, int taskPartitionId)
{
getValuesBuilderForExchange(exchangeId).exclude(taskPartitionId);
return this;
}
private ValuesBuilder getValuesBuilderForExchange(ExchangeId exchangeId)
{
ValuesBuilder result = exchangeValues.get(exchangeId);
if (result == null) {
throw new IllegalArgumentException("Unexpected exchange: " + exchangeId);
}
return result;
}
public Builder setPartitionCount(ExchangeId exchangeId, int count)
{
Integer previousCount = exchangeTaskPartitionCount.putIfAbsent(exchangeId, count);
if (previousCount != null) {
throw new IllegalStateException("Partition count for exchange is already set: " + exchangeId);
}
return this;
}
public Builder setFinal()
{
if (finalSelector) {
throw new IllegalStateException("selector is already marked as final");
}
for (ExchangeId exchangeId : exchangeValues.keySet()) {
if (!exchangeTaskPartitionCount.containsKey(exchangeId)) {
throw new IllegalStateException("partition count is missing for exchange: " + exchangeId);
}
}
this.finalSelector = true;
return this;
}
public ExchangeSourceOutputSelector build()
{
return new ExchangeSourceOutputSelector(
nextVersion++,
exchangeValues.entrySet().stream()
.collect(toMap(Entry::getKey, entry -> {
ExchangeId exchangeId = entry.getKey();
ValuesBuilder valuesBuilder = entry.getValue();
if (finalSelector) {
return valuesBuilder.buildFinal(exchangeTaskPartitionCount.get(exchangeId));
}
else {
return valuesBuilder.build();
}
})),
finalSelector);
}
}
private static class ValuesBuilder
{
private Slice values = Slices.allocate(0);
private int maxTaskPartitionId = -1;
public void include(int taskPartitionId, int attemptId)
{
updateMaxTaskPartitionIdAndEnsureCapacity(taskPartitionId);
if (attemptId < 0 || attemptId > Byte.MAX_VALUE) {
throw new IllegalArgumentException("unexpected attemptId: " + attemptId);
}
byte currentValue = values.getByte(taskPartitionId);
if (currentValue != UNKNOWN.getValue()) {
throw new IllegalArgumentException("decision for partition %s is already made: %s".formatted(taskPartitionId, currentValue));
}
values.setByte(taskPartitionId, (byte) attemptId);
}
public void exclude(int taskPartitionId)
{
updateMaxTaskPartitionIdAndEnsureCapacity(taskPartitionId);
byte currentValue = values.getByte(taskPartitionId);
if (currentValue != UNKNOWN.getValue()) {
throw new IllegalArgumentException("decision for partition %s is already made: %s".formatted(taskPartitionId, currentValue));
}
values.setByte(taskPartitionId, EXCLUDED.getValue());
}
private void updateMaxTaskPartitionIdAndEnsureCapacity(int taskPartitionId)
{
if (taskPartitionId > maxTaskPartitionId) {
maxTaskPartitionId = taskPartitionId;
}
if (taskPartitionId < values.length()) {
return;
}
byte[] newValues = new byte[(maxTaskPartitionId + 1) * 2];
fill(newValues, UNKNOWN.getValue());
values.getBytes(0, newValues, 0, values.length());
values = Slices.wrappedBuffer(newValues);
}
public Slice build()
{
return createResult(maxTaskPartitionId + 1);
}
public Slice buildFinal(int totalPartitionCount)
{
Slice result = createResult(totalPartitionCount);
for (int partitionId = 0; partitionId < totalPartitionCount; partitionId++) {
byte selectedAttempt = result.getByte(partitionId);
if (selectedAttempt == UNKNOWN.getValue()) {
throw new IllegalStateException("Attempt is unknown for partition: " + partitionId);
}
}
return result;
}
private Slice createResult(int partitionCount)
{
if (maxTaskPartitionId >= partitionCount) {
throw new IllegalArgumentException("expected maxTaskPartitionId to be less than or equal to " + (partitionCount - 1));
}
byte[] result = new byte[partitionCount];
fill(result, UNKNOWN.getValue());
values.getBytes(0, result, 0, maxTaskPartitionId + 1);
return Slices.wrappedBuffer(result);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy