All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.execution.scheduler.faulttolerant.TaskDescriptorStorage Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.execution.scheduler.faulttolerant;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Suppliers;
import com.google.common.base.VerifyException;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Multimap;
import com.google.common.collect.Table;
import com.google.common.math.Quantiles;
import com.google.common.math.Stats;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Inject;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.annotation.NotThreadSafe;
import io.trino.execution.QueryManagerConfig;
import io.trino.execution.StageId;
import io.trino.metadata.Split;
import io.trino.spi.QueryId;
import io.trino.spi.TrinoException;
import io.trino.sql.planner.plan.PlanNodeId;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.ImmutableSetMultimap.toImmutableSetMultimap;
import static com.google.common.math.Quantiles.percentiles;
import static io.airlift.units.DataSize.succinctBytes;
import static io.trino.spi.StandardErrorCode.EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class TaskDescriptorStorage
{
    private static final Logger log = Logger.get(TaskDescriptorStorage.class);

    private final long maxMemoryInBytes;
    private final JsonCodec splitJsonCodec;
    private final StorageStats storageStats;

    @GuardedBy("this")
    private final Map storages = new HashMap<>();
    @GuardedBy("this")
    private long reservedBytes;

    @Inject
    public TaskDescriptorStorage(
            QueryManagerConfig config,
            JsonCodec splitJsonCodec)
    {
        this(config.getFaultTolerantExecutionTaskDescriptorStorageMaxMemory(), splitJsonCodec);
    }

    public TaskDescriptorStorage(DataSize maxMemory, JsonCodec splitJsonCodec)
    {
        this.maxMemoryInBytes = maxMemory.toBytes();
        this.splitJsonCodec = requireNonNull(splitJsonCodec, "splitJsonCodec is null");
        this.storageStats = new StorageStats(Suppliers.memoizeWithExpiration(this::computeStats, 1, TimeUnit.SECONDS));
    }

    /**
     * Initializes task descriptor storage for a given queryId.
     * It is expected to be called before query scheduling begins.
     */
    public synchronized void initialize(QueryId queryId)
    {
        TaskDescriptors storage = new TaskDescriptors();
        verify(storages.putIfAbsent(queryId, storage) == null, "storage is already initialized for query: %s", queryId);
        updateMemoryReservation(storage.getReservedBytes());
    }

    /**
     * Stores {@link TaskDescriptor} for a task identified by the stageId and partitionId.
     * The partitionId is obtained from the {@link TaskDescriptor} by calling {@link TaskDescriptor#getPartitionId()}.
     * If the query has been terminated the call is ignored.
     *
     * @throws IllegalStateException if the storage already has a task descriptor for a given task
     */
    public synchronized void put(StageId stageId, TaskDescriptor descriptor)
    {
        TaskDescriptors storage = storages.get(stageId.getQueryId());
        if (storage == null) {
            // query has been terminated
            return;
        }
        long previousReservedBytes = storage.getReservedBytes();
        storage.put(stageId, descriptor.getPartitionId(), descriptor);
        long currentReservedBytes = storage.getReservedBytes();
        long delta = currentReservedBytes - previousReservedBytes;
        updateMemoryReservation(delta);
    }

    /**
     * Get task descriptor
     *
     * @return Non empty {@link TaskDescriptor} for a task identified by the stageId and partitionId.
     * Returns {@link Optional#empty()} if the query of a given stageId has been finished (e.g.: cancelled by the user or finished early).
     * @throws java.util.NoSuchElementException if {@link TaskDescriptor} for a given task does not exist
     */
    public synchronized Optional get(StageId stageId, int partitionId)
    {
        TaskDescriptors storage = storages.get(stageId.getQueryId());
        if (storage == null) {
            // query has been terminated
            return Optional.empty();
        }
        return Optional.of(storage.get(stageId, partitionId));
    }

    /**
     * Removes {@link TaskDescriptor} for a task identified by the stageId and partitionId.
     * If the query has been terminated the call is ignored.
     *
     * @throws java.util.NoSuchElementException if {@link TaskDescriptor} for a given task does not exist
     */
    public synchronized void remove(StageId stageId, int partitionId)
    {
        TaskDescriptors storage = storages.get(stageId.getQueryId());
        if (storage == null) {
            // query has been terminated
            return;
        }
        long previousReservedBytes = storage.getReservedBytes();
        storage.remove(stageId, partitionId);
        long currentReservedBytes = storage.getReservedBytes();
        long delta = currentReservedBytes - previousReservedBytes;
        updateMemoryReservation(delta);
    }

    /**
     * Notifies the storage that the query with a given queryId has been finished and the task descriptors can be safely discarded.
     * 

* The engine may decided to destroy the storage while the scheduling is still in process (for example if query was cancelled). Under such * circumstances the implementation will ignore future calls to {@link #put(StageId, TaskDescriptor)} and return * {@link Optional#empty()} from {@link #get(StageId, int)}. The scheduler is expected to handle this condition appropriately. */ public synchronized void destroy(QueryId queryId) { TaskDescriptors storage = storages.remove(queryId); if (storage != null) { updateMemoryReservation(-storage.getReservedBytes()); } } private synchronized void updateMemoryReservation(long delta) { reservedBytes += delta; if (delta <= 0) { return; } while (reservedBytes > maxMemoryInBytes) { // drop a query that uses the most storage QueryId killCandidate = storages.entrySet().stream() .max(Comparator.comparingLong(entry -> entry.getValue().getReservedBytes())) .map(Map.Entry::getKey) .orElseThrow(() -> new VerifyException(format("storage is empty but reservedBytes (%s) is still greater than maxMemoryInBytes (%s)", reservedBytes, maxMemoryInBytes))); TaskDescriptors storage = storages.get(killCandidate); long previousReservedBytes = storage.getReservedBytes(); if (log.isInfoEnabled()) { log.info("Failing query %s; reclaiming %s of %s task descriptor memory from %s queries; extraStorageInfo=%s", killCandidate, storage.getReservedBytes(), succinctBytes(reservedBytes), storages.size(), storage.getDebugInfo()); } storage.fail(new TrinoException( EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY, format("Task descriptor storage capacity has been exceeded: %s > %s", succinctBytes(maxMemoryInBytes), succinctBytes(reservedBytes)))); long currentReservedBytes = storage.getReservedBytes(); reservedBytes += (currentReservedBytes - previousReservedBytes); } } @VisibleForTesting synchronized long getReservedBytes() { return reservedBytes; } @Managed @Nested public StorageStats getStats() { // This should not contain materialized values. 
GuiceMBeanExporter calls it only once during application startup // and then only @Managed methods all called on that instance. return storageStats; } private synchronized StorageStatsValue computeStats() { int queriesCount = storages.size(); long stagesCount = storages.values().stream().mapToLong(TaskDescriptors::getStagesCount).sum(); Quantiles.ScaleAndIndexes percentiles = percentiles().indexes(50, 90, 95); long queryReservedBytesP50 = 0; long queryReservedBytesP90 = 0; long queryReservedBytesP95 = 0; long queryReservedBytesAvg = 0; long stageReservedBytesP50 = 0; long stageReservedBytesP90 = 0; long stageReservedBytesP95 = 0; long stageReservedBytesAvg = 0; if (queriesCount > 0) { // we cannot compute percentiles for empty set Map queryReservedBytesPercentiles = percentiles.compute( storages.values().stream() .map(TaskDescriptors::getReservedBytes) .collect(toImmutableList())); queryReservedBytesP50 = queryReservedBytesPercentiles.get(50).longValue(); queryReservedBytesP90 = queryReservedBytesPercentiles.get(90).longValue(); queryReservedBytesP95 = queryReservedBytesPercentiles.get(95).longValue(); queryReservedBytesAvg = reservedBytes / queriesCount; List storagesReservedBytes = storages.values().stream() .flatMap(TaskDescriptors::getStagesReservedBytes) .collect(toImmutableList()); if (!storagesReservedBytes.isEmpty()) { Map stagesReservedBytesPercentiles = percentiles.compute( storagesReservedBytes); stageReservedBytesP50 = stagesReservedBytesPercentiles.get(50).longValue(); stageReservedBytesP90 = stagesReservedBytesPercentiles.get(90).longValue(); stageReservedBytesP95 = stagesReservedBytesPercentiles.get(95).longValue(); stageReservedBytesAvg = reservedBytes / stagesCount; } } return new StorageStatsValue( queriesCount, stagesCount, reservedBytes, queryReservedBytesAvg, queryReservedBytesP50, queryReservedBytesP90, queryReservedBytesP95, stageReservedBytesAvg, stageReservedBytesP50, stageReservedBytesP90, stageReservedBytesP95); } @NotThreadSafe private 
class TaskDescriptors { private final Table descriptors = HashBasedTable.create(); private long reservedBytes; private final Map stagesReservedBytes = new HashMap<>(); private RuntimeException failure; public void put(StageId stageId, int partitionId, TaskDescriptor descriptor) { throwIfFailed(); checkState(!descriptors.contains(stageId, partitionId), "task descriptor is already present for key %s/%s ", stageId, partitionId); descriptors.put(stageId, partitionId, descriptor); long descriptorRetainedBytes = descriptor.getRetainedSizeInBytes(); reservedBytes += descriptorRetainedBytes; stagesReservedBytes.computeIfAbsent(stageId, ignored -> new AtomicLong()).addAndGet(descriptorRetainedBytes); } public TaskDescriptor get(StageId stageId, int partitionId) { throwIfFailed(); TaskDescriptor descriptor = descriptors.get(stageId, partitionId); if (descriptor == null) { throw new NoSuchElementException(format("descriptor not found for key %s/%s", stageId, partitionId)); } return descriptor; } public void remove(StageId stageId, int partitionId) { throwIfFailed(); TaskDescriptor descriptor = descriptors.remove(stageId, partitionId); if (descriptor == null) { throw new NoSuchElementException(format("descriptor not found for key %s/%s", stageId, partitionId)); } long descriptorRetainedBytes = descriptor.getRetainedSizeInBytes(); reservedBytes -= descriptorRetainedBytes; requireNonNull(stagesReservedBytes.get(stageId), () -> format("no entry for stage %s", stageId)).addAndGet(-descriptorRetainedBytes); } public long getReservedBytes() { return reservedBytes; } private String getDebugInfo() { Multimap descriptorsByStageId = descriptors.cellSet().stream() .collect(toImmutableSetMultimap( Table.Cell::getRowKey, Table.Cell::getValue)); Map debugInfoByStageId = descriptorsByStageId.asMap().entrySet().stream() .collect(toImmutableMap( Map.Entry::getKey, entry -> getDebugInfo(entry.getValue()))); List biggestSplits = descriptorsByStageId.entries().stream() .flatMap(entry -> 
entry.getValue().getSplits().getSplitsFlat().entries().stream().map(splitEntry -> Map.entry("%s/%s".formatted(entry.getKey(), splitEntry.getKey()), splitEntry.getValue()))) .sorted(Comparator.>comparingLong(entry -> entry.getValue().getRetainedSizeInBytes()).reversed()) .limit(3) .map(entry -> "{nodeId=%s, size=%s, split=%s}".formatted(entry.getKey(), entry.getValue().getRetainedSizeInBytes(), splitJsonCodec.toJson(entry.getValue()))) .toList(); return "stagesInfo=%s; biggestSplits=%s".formatted(debugInfoByStageId, biggestSplits); } private String getDebugInfo(Collection taskDescriptors) { int taskDescriptorsCount = taskDescriptors.size(); Stats taskDescriptorsRetainedSizeStats = Stats.of(taskDescriptors.stream().mapToLong(TaskDescriptor::getRetainedSizeInBytes)); Set planNodeIds = taskDescriptors.stream().flatMap(taskDescriptor -> taskDescriptor.getSplits().getSplitsFlat().keySet().stream()).collect(toImmutableSet()); Map splitsDebugInfo = new HashMap<>(); for (PlanNodeId planNodeId : planNodeIds) { Stats splitCountStats = Stats.of(taskDescriptors.stream().mapToLong(taskDescriptor -> taskDescriptor.getSplits().getSplitsFlat().asMap().get(planNodeId).size())); Stats splitSizeStats = Stats.of(taskDescriptors.stream().flatMap(taskDescriptor -> taskDescriptor.getSplits().getSplitsFlat().get(planNodeId).stream()).mapToLong(Split::getRetainedSizeInBytes)); splitsDebugInfo.put( planNodeId, "{splitCountMean=%s, splitCountStdDev=%s, splitSizeMean=%s, splitSizeStdDev=%s}".formatted( splitCountStats.mean(), splitCountStats.populationStandardDeviation(), splitSizeStats.mean(), splitSizeStats.populationStandardDeviation())); } return "[taskDescriptorsCount=%s, taskDescriptorsRetainedSizeMean=%s, taskDescriptorsRetainedSizeStdDev=%s, splits=%s]".formatted( taskDescriptorsCount, taskDescriptorsRetainedSizeStats.mean(), taskDescriptorsRetainedSizeStats.populationStandardDeviation(), splitsDebugInfo); } private void fail(RuntimeException failure) { if (this.failure == null) { 
descriptors.clear(); reservedBytes = 0; this.failure = failure; } } private void throwIfFailed() { if (failure != null) { throw failure; } } public int getStagesCount() { return descriptors.rowMap().size(); } public Stream getStagesReservedBytes() { return stagesReservedBytes.values().stream() .map(AtomicLong::get); } } private record StorageStatsValue( long queriesCount, long stagesCount, long reservedBytes, long queryReservedBytesAvg, long queryReservedBytesP50, long queryReservedBytesP90, long queryReservedBytesP95, long stageReservedBytesAvg, long stageReservedBytesP50, long stageReservedBytesP90, long stageReservedBytesP95) {} public static class StorageStats { private final Supplier statsSupplier; StorageStats(Supplier statsSupplier) { this.statsSupplier = requireNonNull(statsSupplier, "statsSupplier is null"); } @Managed public long getQueriesCount() { return statsSupplier.get().queriesCount(); } @Managed public long getStagesCount() { return statsSupplier.get().stagesCount(); } @Managed public long getReservedBytes() { return statsSupplier.get().reservedBytes(); } @Managed public long getQueryReservedBytesAvg() { return statsSupplier.get().queryReservedBytesAvg(); } @Managed public long getQueryReservedBytesP50() { return statsSupplier.get().queryReservedBytesP50(); } @Managed public long getQueryReservedBytesP90() { return statsSupplier.get().queryReservedBytesP90(); } @Managed public long getQueryReservedBytesP95() { return statsSupplier.get().queryReservedBytesP95(); } @Managed public long getStageReservedBytesAvg() { return statsSupplier.get().stageReservedBytesP50(); } @Managed public long getStageReservedBytesP50() { return statsSupplier.get().stageReservedBytesP50(); } @Managed public long getStageReservedBytesP90() { return statsSupplier.get().stageReservedBytesP90(); } @Managed public long getStageReservedBytesP95() { return statsSupplier.get().stageReservedBytesP95(); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy