All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.alibaba.schedulerx.worker.container.DualConcurrencyThreadPoolExecutor Maven / Gradle / Ivy

There is a newer version: 1.12.2
Show newest version
package com.alibaba.schedulerx.worker.container;

import com.google.common.collect.Maps;

import java.util.Map;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * DualConcurrencyThreadPoolExecutor
 * @author yaohui
 * @create 2023/6/3 12:33 AM
 **/
public class DualConcurrencyThreadPoolExecutor extends ThreadPoolExecutor {

    public DualConcurrencyThreadPoolExecutor(int corePoolSize, int workQueueSize) {
        super(corePoolSize, corePoolSize, 0, TimeUnit.NANOSECONDS, new SemaphoreBlockingQueue(workQueueSize));
        super.prestartAllCoreThreads();
    }

    public DualConcurrencyThreadPoolExecutor(int corePoolSize, int workQueueSize, ThreadFactory threadFactory) {
        super(corePoolSize, corePoolSize, 0, TimeUnit.NANOSECONDS, new SemaphoreBlockingQueue(workQueueSize), threadFactory);
        super.prestartAllCoreThreads();
    }

    public DualConcurrencyThreadPoolExecutor(int corePoolSize, int workQueueSize, RejectedExecutionHandler handler) {
        super(corePoolSize, corePoolSize, 0, TimeUnit.NANOSECONDS, new SemaphoreBlockingQueue(workQueueSize), handler);
        super.prestartAllCoreThreads();
    }

    public DualConcurrencyThreadPoolExecutor(int corePoolSize, int workQueueSize, ThreadFactory threadFactory, RejectedExecutionHandler handler) {

        super(corePoolSize, corePoolSize, 0, TimeUnit.NANOSECONDS, new SemaphoreBlockingQueue(workQueueSize), threadFactory, handler);
        super.prestartAllCoreThreads();
    }


    public void registerSemaphore(K key, Integer concurrency) {
        SemaphoreBlockingQueue queue = (SemaphoreBlockingQueue)this.getQueue();
        queue.putSemaphore(key, concurrency);
    }

    public void clear(K key) {
        SemaphoreBlockingQueue queue = (SemaphoreBlockingQueue)this.getQueue();
        queue.clear(key);
    }

    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        SemaphoreBlockingQueue queue = (SemaphoreBlockingQueue)this.getQueue();
        if (r instanceof MarkedRunnable) {
            queue.releaseSemaphore(((MarkedRunnable) r).identify());
        }
    }

    @Override
    protected RunnableFuture newTaskFor(Runnable runnable, Object value) {
        return new MarkedFutureTask((MarkedRunnable)runnable, value);
    }

    static class MarkedFutureTask extends FutureTask implements MarkedRunnable  {

        private MarkedRunnable runnable;

        public MarkedFutureTask(Callable callable) {
            super(callable);
            throw new IllegalArgumentException("can not support callable.");
        }

        public MarkedFutureTask(MarkedRunnable runnable, Object result) {
            super(runnable, result);
            this.runnable = runnable;
        }

        @Override
        public T identify() {
            return runnable.identify();
        }
    }

    static class SemaphoreBlockingQueue extends LinkedBlockingQueue> {

        private final Map semaphoreMap = Maps.newHashMap();

        private final Map>> waitQueueMap = Maps.newHashMap();

        private final AtomicInteger waitCount = new AtomicInteger();

        public SemaphoreBlockingQueue(int size) {
            super(size);
        }

        public SemaphoreBlockingQueue() {
            super();
        }

        @Override
        public MarkedRunnable take() throws InterruptedException {
            // 此处会阻塞,不能加同步锁
            while (true) {
                MarkedRunnable markedRunnable = super.take();
                if (markedRunnable == null) {
                    return null;
                }
                if (checkConcurrency(markedRunnable)) {
                    return markedRunnable;
                }
            }
        }

        @Override
        public MarkedRunnable poll(long timeout, TimeUnit unit) throws InterruptedException {
            while (true) {
                MarkedRunnable markedRunnable = super.poll(timeout, unit);
                if (markedRunnable == null) {
                    return null;
                }
                if (checkConcurrency(markedRunnable)) {
                    return markedRunnable;
                }
            }
        }

        @Override
        public boolean offer(MarkedRunnable runnable) {
            if (this.remainingCapacity() <= 0) {
                return false;
            }
            return super.offer(runnable);
        }

        @Override
        public boolean offer(MarkedRunnable runnable, long timeout, TimeUnit unit) throws InterruptedException {
            if (this.remainingCapacity() <= 0) {
                return false;
            }
            return super.offer(runnable, timeout, unit);
        }

        @Override
        public void put(MarkedRunnable runnable) throws InterruptedException {
            // 暂不重写,线程池未使用
            super.put(runnable);
        }


        @Override
        public int size() {
            return super.size() + waitCount.get();
        }

        @Override
        public int remainingCapacity() {
            synchronized (this) {
                return super.remainingCapacity() - waitCount.get();
            }
        }

        /**
         * checkConcurrency
         * @param markedRunnable
         * @return
         */
        private boolean checkConcurrency(MarkedRunnable markedRunnable) {
            T key = markedRunnable.identify();
            Semaphore semaphore = semaphoreMap.get(key);
            if (semaphore == null) {
                return true;
            }
            synchronized (semaphore) {
                if(!semaphore.tryAcquire()) {
                    // 设置等待
                    Queue> waitQueue = waitQueueMap.get(key);
                    if (waitQueue == null) {
                        if (!semaphoreMap.containsKey(key)) {
                            return true;
                        }
                        waitQueue = new ConcurrentLinkedQueue<>();
                        waitQueueMap.put(key, waitQueue);
                    }
                    if (waitQueue.offer(markedRunnable)) {
                        waitCount.incrementAndGet();
                        return false;
                    } else {
                        // 如果添加等待队列失败,则不进行并发限制
                        return true;
                    }
                }
            }
            return true;
        }

        public void putSemaphore(T key, Integer concurrency) {
            Semaphore semaphore = semaphoreMap.get(key);
            if (semaphore == null) {
                synchronized (semaphoreMap) {
                    semaphore = semaphoreMap.get(key);
                    if (semaphore == null) {
                        semaphoreMap.put(key, new Semaphore(concurrency));
                    }
                }
            }
        }

        public void releaseSemaphore(T key) {
            Semaphore semaphore = semaphoreMap.get(key);
            if (semaphore != null) {
                synchronized (semaphore) {
                    semaphore.release();
                    // 释放等待队列任务
                    Queue> waitQueue = waitQueueMap.get(key);
                    if (waitQueue != null) {
                        MarkedRunnable runnable = waitQueue.poll();
                        if (runnable != null) {
                            waitCount.decrementAndGet();
                            super.offer(runnable);
                        }
                    }
                }
            }
        }

        public void clear(T key) {
            Semaphore semaphore = semaphoreMap.remove(key);
            if (semaphore!=null) {
                synchronized (semaphore) {
                    Queue waitQueue = waitQueueMap.remove(key);
                    if (waitQueue != null) {
                        waitCount.addAndGet(-1 * waitQueue.size());
                    }
                }
            } else {
                Queue waitQueue = waitQueueMap.remove(key);
                if (waitQueue != null) {
                    waitCount.addAndGet(-1 * waitQueue.size());
                }
            }
        }

        @Override
        public void clear() {
            for (T key:semaphoreMap.keySet()) {
                clear(key);
            }
            super.clear();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy