All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.operator.StreamingDirectExchangeBuffer Maven / Gradle / Ivy

There is a newer version: 468
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import io.trino.execution.TaskId;
import io.trino.spi.TrinoException;

import java.util.ArrayDeque;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.Executor;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.util.concurrent.Futures.immediateVoidFuture;
import static com.google.common.util.concurrent.Futures.nonCancellationPropagating;
import static io.trino.spi.StandardErrorCode.REMOTE_TASK_FAILED;
import static java.lang.Math.max;
import static java.util.Objects.requireNonNull;

public class StreamingDirectExchangeBuffer
        implements DirectExchangeBuffer
{
    private static final Logger log = Logger.get(StreamingDirectExchangeBuffer.class);

    private final Executor executor;
    private final long bufferCapacityInBytes;

    @GuardedBy("this")
    private final Queue bufferedPages = new ArrayDeque<>();
    @GuardedBy("this")
    private volatile long bufferRetainedSizeInBytes;
    @GuardedBy("this")
    private volatile long maxBufferRetainedSizeInBytes;
    @GuardedBy("this")
    private Queue> blocked = new ArrayDeque<>();
    @GuardedBy("this")
    private final Set activeTasks = new HashSet<>();
    @GuardedBy("this")
    private boolean noMoreTasks;
    @GuardedBy("this")
    private Throwable failure;
    @GuardedBy("this")
    private boolean closed;

    public StreamingDirectExchangeBuffer(Executor executor, DataSize bufferCapacity)
    {
        this.executor = requireNonNull(executor, "executor is null");
        this.bufferCapacityInBytes = bufferCapacity.toBytes();
    }

    @Override
    public synchronized ListenableFuture isBlocked()
    {
        if (!bufferedPages.isEmpty() || isFailed() || (noMoreTasks && activeTasks.isEmpty())) {
            return immediateVoidFuture();
        }
        SettableFuture callback = SettableFuture.create();
        blocked.add(callback);

        return nonCancellationPropagating(callback);
    }

    @Override
    public synchronized Slice pollPage()
    {
        throwIfFailed();

        if (closed) {
            return null;
        }
        Slice page = bufferedPages.poll();
        if (page != null) {
            bufferRetainedSizeInBytes -= page.getRetainedSize();
            checkState(bufferRetainedSizeInBytes >= 0, "unexpected bufferRetainedSizeInBytes: %s", bufferRetainedSizeInBytes);
        }
        return page;
    }

    @Override
    public synchronized void addTask(TaskId taskId)
    {
        if (closed) {
            return;
        }
        checkState(!noMoreTasks, "no more tasks are expected");
        activeTasks.add(taskId);
    }

    @Override
    public void addPages(TaskId taskId, List pages)
    {
        long pagesRetainedSizeInBytes = 0;
        for (Slice page : pages) {
            pagesRetainedSizeInBytes += page.getRetainedSize();
        }
        synchronized (this) {
            if (closed) {
                return;
            }
            checkState(activeTasks.contains(taskId), "taskId is not active: %s", taskId);
            bufferedPages.addAll(pages);
            bufferRetainedSizeInBytes += pagesRetainedSizeInBytes;
            maxBufferRetainedSizeInBytes = max(maxBufferRetainedSizeInBytes, bufferRetainedSizeInBytes);
            // Unblock the same number of consumers as pages to reduce the possibility of a thread waking up with an empty pull from the buffer.
            unblock(pages.size());
        }
    }

    @Override
    public synchronized void taskFinished(TaskId taskId)
    {
        if (closed) {
            return;
        }
        checkState(activeTasks.contains(taskId), "taskId not registered: %s", taskId);
        activeTasks.remove(taskId);
        if (noMoreTasks && activeTasks.isEmpty()) {
            unblockAll();
        }
    }

    @Override
    public synchronized void taskFailed(TaskId taskId, Throwable t)
    {
        if (closed) {
            return;
        }
        checkState(activeTasks.contains(taskId), "taskId not registered: %s", taskId);

        if (t instanceof TrinoException && REMOTE_TASK_FAILED.toErrorCode().equals(((TrinoException) t).getErrorCode())) {
            // This error indicates that a downstream task was trying to fetch results from an upstream task that is marked as failed
            // Instead of failing a downstream task let the coordinator handle and report the failure of an upstream task to ensure correct error reporting
            log.debug("Task failure discovered while fetching task results: %s", taskId);
            return;
        }

        failure = t;
        activeTasks.remove(taskId);
        unblockAll();
    }

    @Override
    public synchronized void noMoreTasks()
    {
        noMoreTasks = true;
        if (activeTasks.isEmpty()) {
            unblockAll();
        }
    }

    @Override
    public synchronized boolean isFinished()
    {
        return failure == null && noMoreTasks && activeTasks.isEmpty() && bufferedPages.isEmpty();
    }

    @Override
    public synchronized boolean isFailed()
    {
        return failure != null;
    }

    @Override
    public long getRemainingCapacityInBytes()
    {
        return max(bufferCapacityInBytes - bufferRetainedSizeInBytes, 0);
    }

    @Override
    public long getRetainedSizeInBytes()
    {
        return bufferRetainedSizeInBytes;
    }

    @Override
    public long getMaxRetainedSizeInBytes()
    {
        return maxBufferRetainedSizeInBytes;
    }

    @Override
    public synchronized int getBufferedPageCount()
    {
        return bufferedPages.size();
    }

    @Override
    public long getSpilledBytes()
    {
        return 0;
    }

    @Override
    public int getSpilledPageCount()
    {
        return 0;
    }

    @Override
    public synchronized void close()
    {
        if (closed) {
            return;
        }
        bufferedPages.clear();
        bufferRetainedSizeInBytes = 0;
        activeTasks.clear();
        noMoreTasks = true;
        closed = true;
        unblockAll();
    }

    private synchronized void unblock(int unblock)
    {
        for (int i = 0; i < unblock; i++) {
            SettableFuture callback = blocked.poll();
            if (callback == null) {
                break;
            }
            executor.execute(() -> callback.set(null));
        }
    }

    private synchronized void unblockAll()
    {
        unblock(blocked.size());
        checkState(blocked.isEmpty(), "blocked callbacks is not empty");
    }

    private synchronized void throwIfFailed()
    {
        if (failure != null) {
            throwIfUnchecked(failure);
            throw new RuntimeException(failure);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy