All downloads are free. Search and download functionality uses the official Maven repository.

io.github.bucket4j.util.concurrent.batch.BatchHelper Maven / Gradle / Ivy

The newest version!
/*-
 * ========================LICENSE_START=================================
 * Bucket4j
 * %%
 * Copyright (C) 2015 - 2022 Vladimir Bukhtoyarov
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =========================LICENSE_END==================================
 */
package io.github.bucket4j.util.concurrent.batch;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Function;

import static java.util.Objects.requireNonNull;

/**
 * Helper that batches tasks submitted concurrently by different threads.
 *
 * <p>At most one thread executes work at any moment:
 * <ul>
 *     <li>A thread that calls {@link #execute(Object)} while no work is in progress runs its own task
 *     immediately with {@code taskExecutor}.</li>
 *     <li>A thread that arrives while work is in progress pushes its task onto a CAS-maintained linked stack
 *     and blocks.</li>
 *     <li>When the executing thread finishes, it wakes the most recently pushed waiter. That waiter drains
 *     every queued task in the order the tasks were enqueued and merges them with {@code taskCombiner}. It
 *     runs the merged task with {@code combinedTaskExecutor}, splits the merged result with
 *     {@code combinedResultSplitter}, and completes the other waiters with their individual results.</li>
 *     <li>A drained batch that holds a single task is run with {@code taskExecutor} instead.</li>
 * </ul>
 *
 * <p>{@code combinedResultSplitter} must return one result per task, in the same order as the task list
 * passed to {@code taskCombiner}.
 *
 * @param <T>  Task type
 * @param <R>  Task result type
 * @param <CT> Combined task type
 * @param <CR> Combined task result
 */
public class BatchHelper<T, R, CT, CR> {

    /** Result used to tell a woken waiter that it must execute the next batch itself. */
    private static final Object NEED_TO_EXECUTE_NEXT_BATCH = new Object();
    /** Head marker: no task is waiting, but some thread is currently executing. */
    private static final WaitingTask<?, ?> QUEUE_EMPTY_BUT_EXECUTION_IN_PROGRESS = new WaitingTask<>(null);
    /** Head marker: no task is waiting and nothing is executing. */
    private static final WaitingTask<?, ?> QUEUE_EMPTY = new WaitingTask<>(null);

    private final Function<List<T>, CT> taskCombiner;
    private final Function<CT, CR> combinedTaskExecutor;
    private final Function<T, R> taskExecutor;

    private final BiFunction<CT, CR, List<R>> combinedResultSplitter;

    /** Top of the waiting stack, or one of the two head markers above. */
    private final AtomicReference<WaitingTask<T, R>> headReference = new AtomicReference<>(queueEmpty());

    /**
     * Creates a helper with an explicit executor for tasks that run alone.
     *
     * @param taskCombiner           merges the tasks of a batch into one combined task
     * @param combinedTaskExecutor   executes a combined task
     * @param taskExecutor           executes a single task that is not batched with others
     * @param combinedResultSplitter splits a combined result into one result per task, in task order
     * @return a new helper
     * @throws NullPointerException if any argument is {@code null}
     */
    public static <T, R, CT, CR> BatchHelper<T, R, CT, CR> create(
            Function<List<T>, CT> taskCombiner,
            Function<CT, CR> combinedTaskExecutor,
            Function<T, R> taskExecutor,
            BiFunction<CT, CR, List<R>> combinedResultSplitter) {
        return new BatchHelper<>(taskCombiner, combinedTaskExecutor, taskExecutor, combinedResultSplitter);
    }

    /**
     * Creates a helper that runs a lone task as a one-element batch.
     *
     * <p>The lone task is combined, executed and split, and the first result is returned.
     *
     * @param taskCombiner           merges the tasks of a batch into one combined task
     * @param combinedTaskExecutor   executes a combined task
     * @param combinedResultSplitter splits a combined result into one result per task, in task order
     * @return a new helper
     * @throws NullPointerException if any argument is {@code null}
     */
    public static <T, R, CT, CR> BatchHelper<T, R, CT, CR> create(
            Function<List<T>, CT> taskCombiner,
            Function<CT, CR> combinedTaskExecutor,
            BiFunction<CT, CR, List<R>> combinedResultSplitter) {
        Function<T, R> taskExecutor = task -> {
            CT combinedTask = taskCombiner.apply(Collections.singletonList(task));
            CR combinedResult = combinedTaskExecutor.apply(combinedTask);
            List<R> results = combinedResultSplitter.apply(combinedTask, combinedResult);
            return results.get(0);
        };
        return new BatchHelper<>(taskCombiner, combinedTaskExecutor, taskExecutor, combinedResultSplitter);
    }

    private BatchHelper(Function<List<T>, CT> taskCombiner,
                        Function<CT, CR> combinedTaskExecutor,
                        Function<T, R> taskExecutor,
                        BiFunction<CT, CR, List<R>> combinedResultSplitter) {
        this.taskCombiner = requireNonNull(taskCombiner);
        this.combinedTaskExecutor = requireNonNull(combinedTaskExecutor);
        this.taskExecutor = requireNonNull(taskExecutor);
        this.combinedResultSplitter = requireNonNull(combinedResultSplitter);
    }

    /**
     * Executes {@code task}, possibly as part of one batch with tasks submitted concurrently by other threads.
     *
     * <p>Blocks until the result is available. Interrupts received while waiting do not abort the wait; the
     * thread's interrupt status is restored once waiting ends.
     *
     * @param task the task to execute
     * @return the result computed for {@code task}
     * @throws BatchFailedException if {@code task} was part of a multi-task batch that failed. Exceptions
     *                              thrown by {@code taskExecutor} when the task runs alone propagate unchanged.
     */
    public R execute(T task) {
        WaitingTask<T, R> waitingNode = lockExclusivelyOrEnqueue(task);

        if (waitingNode == null) {
            // nobody was executing: run alone, then hand over to the next waiter or release the lock
            try {
                return taskExecutor.apply(task);
            } finally {
                wakeupAnyThreadFromNextBatchOrFreeLock();
            }
        }

        R result = waitingNode.waitUninterruptedly();
        if (result != NEED_TO_EXECUTE_NEXT_BATCH) {
            // our future was completed by the thread that executed the batch containing our task
            return result;
        }

        // current thread is responsible for executing the batch of commands
        try {
            return executeBatch(waitingNode);
        } finally {
            wakeupAnyThreadFromNextBatchOrFreeLock();
        }
    }

    /**
     * Drains all waiting tasks and executes them, completing every drained waiter.
     *
     * @param currentWaitingNode the node of the calling thread; it is part of the drained list
     * @return the result belonging to {@code currentWaitingNode}
     * @throws BatchFailedException if the multi-task batch failed; all drained waiters are then failed too
     */
    private R executeBatch(WaitingTask<T, R> currentWaitingNode) {
        List<WaitingTask<T, R>> waitingNodes = takeAllWaitingTasksOrFreeLock();

        if (waitingNodes.size() == 1) {
            // only the caller's own task was queued: no need to combine
            T singleCommand = waitingNodes.get(0).wrappedTask;
            return taskExecutor.apply(singleCommand);
        }

        try {
            int resultIndex = -1;
            List<T> commandsInBatch = new ArrayList<>(waitingNodes.size());
            for (int i = 0; i < waitingNodes.size(); i++) {
                WaitingTask<T, R> waitingNode = waitingNodes.get(i);
                commandsInBatch.add(waitingNode.wrappedTask);
                if (waitingNode == currentWaitingNode) {
                    resultIndex = i;
                }
            }
            CT multiCommand = taskCombiner.apply(commandsInBatch);

            CR multiResult = combinedTaskExecutor.apply(multiCommand);
            List<R> singleResults = combinedResultSplitter.apply(multiCommand, multiResult);
            for (int i = 0; i < waitingNodes.size(); i++) {
                R singleResult = singleResults.get(i);
                // no-op for the caller's own node, whose future already holds the wake-up marker
                waitingNodes.get(i).future.complete(singleResult);
            }

            return singleResults.get(resultIndex);
        } catch (Throwable e) {
            for (WaitingTask<T, R> waitingNode : waitingNodes) {
                waitingNode.future.completeExceptionally(e);
            }
            throw new BatchFailedException(e);
        }
    }

    /**
     * Either acquires exclusive execution rights or pushes {@code command} onto the waiting stack.
     *
     * @return {@code null} if execution rights were acquired (the caller must run its task and then call
     * {@link #wakeupAnyThreadFromNextBatchOrFreeLock()}); otherwise the node that was pushed
     */
    private WaitingTask<T, R> lockExclusivelyOrEnqueue(T command) {
        WaitingTask<T, R> currentTask = new WaitingTask<>(command);

        while (true) {
            WaitingTask<T, R> previous = headReference.get();
            if (previous == QUEUE_EMPTY) {
                if (headReference.compareAndSet(previous, executionInProgress())) {
                    return null;
                }
                continue;
            }

            currentTask.previous = previous;
            if (headReference.compareAndSet(previous, currentTask)) {
                return currentTask;
            }
            currentTask.previous = null;
        }
    }

    /**
     * Called by the executing thread when it is done.
     *
     * <p>If tasks are waiting, the most recently pushed waiter is woken and becomes responsible for the next
     * batch. Otherwise the lock is released.
     *
     * @throws IllegalStateException if called while nobody holds execution rights
     */
    @SuppressWarnings("unchecked")
    private void wakeupAnyThreadFromNextBatchOrFreeLock() {
        while (true) {
            WaitingTask<T, R> previous = headReference.get();
            if (previous == QUEUE_EMPTY_BUT_EXECUTION_IN_PROGRESS) {
                if (headReference.compareAndSet(previous, queueEmpty())) {
                    return;
                }
            } else if (previous != QUEUE_EMPTY) {
                // the marker is only compared by identity in execute(), never used as an R
                previous.future.complete((R) NEED_TO_EXECUTE_NEXT_BATCH);
                return;
            } else {
                // should never come there
                String msg = "Detected illegal usage of API, wakeupAnyThreadFromNextBatchOrFreeLock should not be called on empty queue";
                throw new IllegalStateException(msg);
            }
        }
    }

    /**
     * Atomically detaches every waiting node while keeping execution rights.
     *
     * <p>If nothing is waiting, the lock is released instead.
     *
     * @return the detached nodes in the order they were enqueued (oldest first), or an empty list if the lock
     * was released
     */
    private List<WaitingTask<T, R>> takeAllWaitingTasksOrFreeLock() {
        WaitingTask<T, R> head;
        while (true) {
            head = headReference.get();
            if (head == QUEUE_EMPTY_BUT_EXECUTION_IN_PROGRESS) {
                if (headReference.compareAndSet(head, queueEmpty())) {
                    return Collections.emptyList();
                }
                continue;
            }

            if (headReference.compareAndSet(head, executionInProgress())) {
                break;
            }
        }

        WaitingTask<T, R> current = head;
        List<WaitingTask<T, R>> waitingNodes = new ArrayList<>();
        while (current != QUEUE_EMPTY_BUT_EXECUTION_IN_PROGRESS) {
            waitingNodes.add(current);
            WaitingTask<T, R> tmp = current.previous;
            current.previous = null; // nullify the reference to previous node in order to avoid GC nepotism
            current = tmp;
        }
        // the stack yields newest first; reverse so tasks are combined in enqueue order
        Collections.reverse(waitingNodes);
        return waitingNodes;
    }

    /** Typed view of {@link #QUEUE_EMPTY}; the marker holds no task or result, so the cast is safe. */
    @SuppressWarnings("unchecked")
    private WaitingTask<T, R> queueEmpty() {
        return (WaitingTask<T, R>) QUEUE_EMPTY;
    }

    /**
     * Typed view of {@link #QUEUE_EMPTY_BUT_EXECUTION_IN_PROGRESS}.
     *
     * <p>The marker holds no task or result, so the cast is safe.
     */
    @SuppressWarnings("unchecked")
    private WaitingTask<T, R> executionInProgress() {
        return (WaitingTask<T, R>) QUEUE_EMPTY_BUT_EXECUTION_IN_PROGRESS;
    }

    /** Node of the waiting stack: one submitted task plus the future through which its result is delivered. */
    private static class WaitingTask<T, R> {

        public final T wrappedTask;
        public final CompletableFuture<R> future = new CompletableFuture<>();
        public final Thread thread = Thread.currentThread();

        /** Node pushed before this one (towards the bottom of the stack). */
        public WaitingTask<T, R> previous;

        WaitingTask(T task) {
            this.wrappedTask = task;
        }

        /**
         * Blocks until {@link #future} completes, ignoring interrupts but restoring the interrupt status.
         *
         * @return the completed value
         * @throws BatchFailedException if the future was completed exceptionally; its cause is the original
         *                              failure
         */
        public R waitUninterruptedly() {
            boolean wasInterrupted = false;
            try {
                while (true) {
                    wasInterrupted = wasInterrupted || Thread.interrupted();
                    try {
                        return future.get();
                    } catch (InterruptedException e) {
                        wasInterrupted = true;
                    } catch (ExecutionException e) {
                        throw new BatchFailedException(e.getCause());
                    }
                }
            } finally {
                if (wasInterrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }
    }

    /** Thrown when a combined batch fails; the cause is the original failure. */
    public static class BatchFailedException extends IllegalStateException {

        public BatchFailedException(Throwable e) {
            super(e);
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy