All downloads are free. Search and download functionality uses the official Maven repository.

com.hubspot.singularity.async.AsyncSemaphore Maven / Gradle / Ivy

package com.hubspot.singularity.async;

import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.locks.StampedLock;
import java.util.function.Supplier;

import com.google.common.base.Suppliers;
import com.google.common.util.concurrent.ThreadFactoryBuilder;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;

/**
 * AsyncSemaphore guarantees that at most N executions
 * of an underlying completablefuture exeuction are occuring
 * at the same time.
 *
 * The general strategy is to try acquiring a permit for execution.
 * If it succeeds, the semaphore just proceeds normally. Otherwise,
 * it pushes the execution onto a queue.
 *
 * At the completion of any execution, the queue is checked for
 * any pending executions. If any executions are found, they are
 * executed in order.
 *
 * @param 
 */
public class AsyncSemaphore {
  private final StampedLock stampedLock = new StampedLock();
  private final AtomicInteger concurrentRequests = new AtomicInteger();
  private final Queue> requestQueue;
  private final com.google.common.base.Supplier queueRejectionThreshold;
  private final Supplier timeoutExceptionSupplier;
  private final PermitSource permitSource;
  private final ScheduledExecutorService flushingExecutor = Executors.newScheduledThreadPool(5,
      new ThreadFactoryBuilder().setDaemon(true).setNameFormat("async-semaphore-flush-pool- %d").build());;

  /**
   * Create an AsyncSemaphore with the given limit.
   *
   * @param concurrentRequestLimit - A supplier saying how many concurrent requests are allowed
   */
  public static AsyncSemaphoreBuilder newBuilder(Supplier concurrentRequestLimit) {
    return new AsyncSemaphoreBuilder(new PermitSource(concurrentRequestLimit));
  }

  /**
   * Create an AsyncSemaphore with the given permit source.
   *
   * @param permitSource - A source for the permits used by the semaphore
   */
  public static AsyncSemaphoreBuilder newBuilder(PermitSource permitSource) {
    return new AsyncSemaphoreBuilder(permitSource);
  }

  AsyncSemaphore(PermitSource permitSource,
                 Queue> requestQueue,
                 Supplier queueRejectionThreshold,
                 Supplier timeoutExceptionSupplier,
                 boolean flushQueuePeriodically) {
    this.permitSource = permitSource;
    this.requestQueue = requestQueue;
    this.queueRejectionThreshold = Suppliers.memoizeWithExpiration(queueRejectionThreshold::get, 1, TimeUnit.MINUTES);
    this.timeoutExceptionSupplier = timeoutExceptionSupplier;
    if (flushQueuePeriodically) {
      flushingExecutor.scheduleAtFixedRate(() -> flushQueue(), 1, 1, TimeUnit.SECONDS);
    }
  }

  public CompletableFuture call(Callable> execution) {
    return callWithQueueTimeout(execution, Optional.empty());
  }

  /**
   * Try to execute the supplier if there are enough permits available.
   * If not, add the execution to a queue (if there is room).
   * If the queue attempts to start the execution after the timeout
   * has passed, short circuit the execution and complete the future
   * exceptionally with TimeoutException
   *
   * @param execution - The execution of the item
   * @param timeout - The time before which we'll short circuit the execution
   * @param timeUnit
   * @return
   */
  public CompletableFuture callWithQueueTimeout(Callable> execution, long timeout, TimeUnit timeUnit) {
    return callWithQueueTimeout(execution, Optional.of(TimeUnit.MILLISECONDS.convert(timeout, timeUnit)));
  }

  private CompletableFuture callWithQueueTimeout(Callable> execution,
                                                    Optional timeoutInMillis) {

    if (timeoutInMillis.isPresent() && timeoutInMillis.get() <= 0) {
      return CompletableFutures.exceptionalFuture(timeoutExceptionSupplier.get());

    } else if (tryAcquirePermit()) {
      CompletableFuture responseFuture = executeCall(execution);
      pollQueueOnCompletion(responseFuture);
      return responseFuture;

    } else {
      DelayedExecution delayedExecution = new DelayedExecution<>(execution, timeoutExceptionSupplier, timeoutInMillis);
      if (!tryEnqueueAttempt(delayedExecution)) {
        return CompletableFutures.exceptionalFuture(
            new RejectedExecutionException("Could not queue future for execution.")
        );
      }
      return delayedExecution.getResponseFuture();
    }
  }

  private  void pollQueueOnCompletion(CompletableFuture future) {
    future.whenComplete((ignored1, ignored2) -> {

      // iterate through expired executions rather than using callbacks
      // to avoid StackoverflowError if futures are completed or expired
      while (true) {
        DelayedExecution nextExecutionDue = requestQueue.poll();

        if (nextExecutionDue == null) {
          releasePermit();
          return;

        } else if (nextExecutionDue.isExpired()) {
          nextExecutionDue.getResponseFuture().completeExceptionally(timeoutExceptionSupplier.get());

        } else {
          // reuse the previous permit for the queued request
          CompletableFuture nextExecution = nextExecutionDue.execute();

          if (!nextExecution.isDone()) {
            pollQueueOnCompletion(nextExecution);
            return;
          }
        }
      }
    });
  }

  private boolean tryAcquirePermit() {
    boolean acquired = permitSource.tryAcquire();

    if (acquired) {
      concurrentRequests.incrementAndGet();
    }

    return acquired;
  }

  private int releasePermit() {
    permitSource.release();
    return concurrentRequests.decrementAndGet();
  }

  private static  CompletableFuture executeCall(Callable> execution) {
    try {
      return execution.call();
    } catch (Throwable t) {
      return CompletableFutures.exceptionalFuture(t);
    }
  }

  /**
   * enqueue the attempt into our underlying queue. since it's expensive to dynamically
   * resize the queue, we have a separate rejection threshold which, if less than 0 is
   * ignored, but otherwise is the practical cap on the size of the queue.
   */
  private boolean tryEnqueueAttempt(DelayedExecution delayedExecution) {
    int rejectionThreshold = queueRejectionThreshold.get();
    if (rejectionThreshold < 0) {
      return requestQueue.offer(delayedExecution);
    }
    long stamp = stampedLock.readLock();
    try {
      while (requestQueue.size() < rejectionThreshold) {
        long writeStamp = stampedLock.tryConvertToWriteLock(stamp);
        if (writeStamp != 0L) {
          stamp = writeStamp;
          return requestQueue.offer(delayedExecution);
        } else {
          stampedLock.unlock(stamp);
          stamp = stampedLock.writeLock();
        }
      }
      return false;
    } finally {
      stampedLock.unlock(stamp);
    }
  }

  private void  flushQueue() {
    if (tryAcquirePermit()) {
      // Pass in an already completed future so that we execute the callback on this thread
      pollQueueOnCompletion(CompletableFuture.completedFuture(true));
    }
  }

  static class DelayedExecution {
    private static final AtomicIntegerFieldUpdater EXECUTED_UPDATER = AtomicIntegerFieldUpdater.newUpdater(
        DelayedExecution.class,
        "executed"
    );
    private final Callable> execution;
    private final CompletableFuture responseFuture;
    private final Supplier timeoutExceptionSupplier;
    private final long deadlineEpochMillis;
    @SuppressWarnings( "unused" ) // use the EXECUTED_UPDATER
    private volatile int executed = 0;

    private DelayedExecution(Callable> execution,
                             Supplier timeoutExceptionSupplier,
                             Optional timeoutMillis) {
      this.execution = execution;
      this.responseFuture = new CompletableFuture<>();
      this.timeoutExceptionSupplier = timeoutExceptionSupplier;
      this.deadlineEpochMillis = timeoutMillis.map(x -> System.currentTimeMillis() + x).orElse(0L);
    }

    private CompletableFuture getResponseFuture() {
      return responseFuture;
    }

    @SuppressFBWarnings("NP_NONNULL_PARAM_VIOLATION") // https://github.com/findbugsproject/findbugs/issues/79
    private CompletableFuture execute() {
      if (!EXECUTED_UPDATER.compareAndSet(this, 0, 1)) {
        return CompletableFuture.completedFuture(null);
      }

      return executeCall(execution).whenComplete((response, ex) -> {
        if (ex == null) {
          responseFuture.complete(response);
        } else {
          responseFuture.completeExceptionally(ex);
        }
      }).thenApply(ignored -> null);
    }

    private boolean isExpired() {
      return deadlineEpochMillis > 0 && System.currentTimeMillis() > deadlineEpochMillis;
    }
  }

  public int getQueueSize() {
    long stamp = stampedLock.tryOptimisticRead();
    int queueSize = requestQueue.size();
    if (!stampedLock.validate(stamp)) {
      stamp = stampedLock.readLock();
      try {
        queueSize = requestQueue.size();
      } finally {
        stampedLock.unlockRead(stamp);
      }
    }
    return queueSize;
  }

  public int getConcurrentRequests() {
    return concurrentRequests.get();
  }
}