All downloads are free. Search and download functionality uses the official Maven repository.

io.prestosql.jdbc.$internal.airlift.concurrent.AsyncSemaphore Maven / Gradle / Ivy

The newest version!
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.jdbc.$internal.airlift.concurrent;

import io.prestosql.jdbc.$internal.guava.util.concurrent.FutureCallback;
import io.prestosql.jdbc.$internal.guava.util.concurrent.Futures;
import io.prestosql.jdbc.$internal.guava.util.concurrent.ListenableFuture;
import io.prestosql.jdbc.$internal.guava.util.concurrent.SettableFuture;

import io.prestosql.jdbc.$internal.javax.annotation.concurrent.ThreadSafe;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

import static io.prestosql.jdbc.$internal.guava.base.Preconditions.checkArgument;
import static io.prestosql.jdbc.$internal.guava.util.concurrent.MoreExecutors.directExecutor;
import static java.util.Objects.requireNonNull;

/**
 * Guarantees that no more than maxPermits of tasks will be run concurrently.
 * The class will rely on the ListenableFuture returned by the submitter function to determine
 * when a task has been completed. The submitter function NEEDS to be thread-safe and is recommended
 * to do the bulk of its work asynchronously.
 */
@ThreadSafe
public class AsyncSemaphore
{
    private final Queue> queuedTasks = new ConcurrentLinkedQueue<>();
    private final AtomicInteger counter = new AtomicInteger();
    private final Runnable runNextTask = this::runNext;
    private final int maxPermits;
    private final Executor submitExecutor;
    private final Function> submitter;

    public AsyncSemaphore(int maxPermits, Executor submitExecutor, Function> submitter)
    {
        checkArgument(maxPermits > 0, "must have at least one permit");
        this.maxPermits = maxPermits;
        this.submitExecutor = requireNonNull(submitExecutor, "submitExecutor is null");
        this.submitter = requireNonNull(submitter, "submitter is null");
    }

    public ListenableFuture submit(T task)
    {
        QueuedTask queuedTask = new QueuedTask<>(task);
        queuedTasks.add(queuedTask);
        acquirePermit();
        return queuedTask.getCompletionFuture();
    }

    private void acquirePermit()
    {
        if (counter.incrementAndGet() <= maxPermits) {
            // Kick off a task if not all permits have been handed out
            submitExecutor.execute(runNextTask);
        }
    }

    private void releasePermit()
    {
        if (counter.getAndDecrement() > maxPermits) {
            // Now that a task has finished, we can kick off another task if there are more tasks than permits
            submitExecutor.execute(runNextTask);
        }
    }

    private void runNext()
    {
        final QueuedTask queuedTask = queuedTasks.poll();
        ListenableFuture future = submitTask(queuedTask.getTask());
        FutureCallback callback = new FutureCallback()
        {
            @Override
            public void onSuccess(Object result)
            {
                queuedTask.markCompleted();
                releasePermit();
            }

            @Override
            public void onFailure(Throwable t)
            {
                queuedTask.markFailure(t);
                releasePermit();
            }
        };
        Futures.addCallback(future, callback, directExecutor());
    }

    private ListenableFuture submitTask(T task)
    {
        try {
            ListenableFuture future = submitter.apply(task);
            if (future == null) {
                return Futures.immediateFailedFuture(new NullPointerException("Submitter returned a null future for task: " + task));
            }
            return future;
        }
        catch (Exception e) {
            return Futures.immediateFailedFuture(e);
        }
    }

    private static class QueuedTask
    {
        private final T task;
        private final SettableFuture settableFuture = SettableFuture.create();

        private QueuedTask(T task)
        {
            this.task = requireNonNull(task, "task is null");
        }

        public T getTask()
        {
            return task;
        }

        public void markFailure(Throwable throwable)
        {
            settableFuture.setException(throwable);
        }

        public void markCompleted()
        {
            settableFuture.set(null);
        }

        public ListenableFuture getCompletionFuture()
        {
            return settableFuture;
        }
    }
}