All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.joekerouac.common.tools.collection.MultiChannelQueueImpl Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
 * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
 * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
 * License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */
package com.github.joekerouac.common.tools.collection;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import com.github.joekerouac.common.tools.constant.ExceptionProviderConst;
import com.github.joekerouac.common.tools.lock.LockTaskUtil;
import com.github.joekerouac.common.tools.util.Assert;

/**
 * 并发安全的实现;
 * 
 * 注意:如果{@link #take()}和{@link #removeChannel(Object)}并发执行,有可能存在通道已经被删除了,但是take仍然会将对应通道的数据返回
 * 
 * @since 1.0.0
 * @author JoeKerouac
 * @date 2022-10-14 14:37:00
 */
public class MultiChannelQueueImpl implements MultiChannelQueue {

    /**
     * 存放所有通道
     */
    private final Map> map;

    /**
     * 全局令牌,用于全局控速
     */
    private final Semaphore semaphore;

    /**
     * 全局队列,主要为了数据消费的公平,使整个队列是FIFO的,list会经常从head删除,往tail添加,所以这里使用LinkedList
     */
    private final LinkedList> list;

    /**
     * 全局锁,用于操作{@link #list}的时候加锁,同时也是为了获取下边的condition
     */
    private final Lock lock;

    /**
     * 条件,添加数据和释放令牌的时候会触发通知,删除数据的时候会消费通知
     */
    private final Condition condition;

    public MultiChannelQueueImpl(int maxConcurrency) {
        this.map = new ConcurrentHashMap<>();
        this.semaphore = new Semaphore(maxConcurrency);
        // 如无必要,请勿修改这个list的实现类,因为这个list会频繁随机删除(不一定是删除头部或者尾部),所以采用LinkedList
        this.list = new LinkedList<>();
        this.lock = new ReentrantLock();
        this.condition = lock.newCondition();
    }

    @Override
    public boolean addChannel(ID id, int size, int maxConcurrency) {
        Assert.argNotNull(id, "id");
        Assert.assertTrue(size > 0, "size必须大于0", ExceptionProviderConst.IllegalArgumentExceptionProvider);
        Assert.assertTrue(maxConcurrency > 0, "maxConcurrency必须大于0",
            ExceptionProviderConst.IllegalArgumentExceptionProvider);

        // 初始化通道节点
        ChannelEntry entry = new ChannelEntry<>();
        // 通道内元素正常情况下都是从head开始删除的
        entry.list = new LinkedList<>();
        entry.lock = new ReentrantLock();
        entry.condition = entry.lock.newCondition();
        entry.size = size;
        entry.token = new Semaphore(maxConcurrency);

        return map.putIfAbsent(id, entry) == null;
    }

    @Override
    public List removeChannel(ID id) {
        ChannelEntry remove = map.remove(id);
        if (remove == null) {
            return Collections.emptyList();
        } else {
            return remove.list;
        }
    }

    @Override
    public Pair take() throws InterruptedException {
        return take0(0, TimeUnit.MILLISECONDS);
    }

    @Override
    public Pair take(long timeout, TimeUnit unit) throws InterruptedException {
        Assert.assertTrue(timeout > 0, "timeout必须大于0", ExceptionProviderConst.IllegalArgumentExceptionProvider);
        Assert.argNotNull(unit, "unit");
        return take0(timeout, unit);
    }

    /**
     * 从队列中获取通道并发数还未达到最大并且较早放入的数据
     * 
     * @param timeout
     *            超时时间,小于等于0表示没有超时判断
     * @param unit
     *            时间单位,不能为null
     * @return 通道数据,如果指定了超时时间则可能返回null
     * @throws InterruptedException
     *             中断异常
     */
    private Pair take0(final long timeout, final TimeUnit unit) throws InterruptedException {
        // 需要先申请全局令牌,注意,申请令牌的时候不要加锁
        semaphore.acquire();

        // 结束时间
        long end = System.currentTimeMillis() + unit.toMillis(timeout);
        try {
            return LockTaskUtil.runInterruptedTaskWithLock(lock, () -> {
                while (true) {
                    Pair pair = null;

                    Iterator> iterator = list.iterator();

                    // 这里要有序从头到尾遍历list,确保公平
                    while (iterator.hasNext()) {
                        Pair p = iterator.next();

                        ChannelEntry entry = map.get(p.getKey());
                        // 尝试申请通道的令牌,申请到就消费这个通道的这个数据
                        if (entry.token.tryAcquire()) {
                            pair = p;
                            // 直接删除
                            iterator.remove();
                            break;
                        }

                    }

                    // pair不等于null的时候说明成功获取到数据了,否则说明当前没有可获取的数据(可能是通道并发都达到上限导致虽然有数据但是无法获取),继续获取
                    if (pair == null) {
                        // timeout大于0,表示指定了超时时间,否则表示永久等待
                        if (timeout > 0) {
                            // 最长等待时间
                            long waitTime = end - System.currentTimeMillis();

                            // 如果当前已经没有等待时间了,直接返回null
                            if (waitTime <= 0) {
                                return null;
                            }

                            // 等待下一次通知,在队列增加数据、释放令牌的时候会通知
                            if (!condition.await(waitTime, TimeUnit.MILLISECONDS)) {
                                // 如果等到超时也没有等来通知,则直接返回null
                                return null;
                            }
                            // 等到了通知,继续循环
                        } else {
                            condition.await();
                        }
                    } else {
                        ChannelEntry entry = map.get(pair.getKey());
                        // 如果这个entry已经被删除了(并发调用了删除通道)
                        if (entry != null) {
                            // 从指定队列中移除
                            Assert.assertTrue(entry.remove(pair.getValue()), "未预期异常,这里应该删除成功的",
                                ExceptionProviderConst.IllegalStateExceptionProvider);
                        }

                        return pair;
                    }
                }
            });
        } catch (InterruptedException e) {
            semaphore.release();
            throw e;
        }
    }

    @Override
    public void add(ID id, T data) throws InterruptedException {
        add0(id, data, 0, TimeUnit.MILLISECONDS);
    }

    @Override
    public boolean add(ID id, T data, long timeout, TimeUnit unit) throws InterruptedException {
        Assert.assertTrue(timeout > 0, ExceptionProviderConst.IllegalArgumentExceptionProvider);

        return add0(id, data, timeout, unit);
    }

    /**
     * 往通道中增加数据,如果通道队列已满则阻塞等待直到可以放入或者等到超时返回false
     * 
     * @param id
     *            id
     * @param data
     *            要放入的数据,不能为null
     * @param timeout
     *            超时时间,小于等于0的时候表示不用超时判断
     * @param unit
     *            时间单位
     * @return 如果到超时时间还没有放入成功则返回false,放入成功返回true
     * @throws InterruptedException
     *             中断异常
     */
    private boolean add0(ID id, T data, long timeout, TimeUnit unit) throws InterruptedException {
        Assert.argNotNull(id, "id");
        Assert.argNotNull(data, "data");
        Assert.argNotNull(unit, "unit");

        ChannelEntry entry = map.get(id);

        if (entry == null) {
            throw new IllegalStateException("当前ID还未初始化队列,id:" + id);
        }

        // 先往通道自己的队列中添加
        if (entry.add(data, timeout, unit)) {
            // 全局队列添加数据,注意要加锁,因为list不是线程安全的变量,并且我们要调用condition
            LockTaskUtil.runInterruptedTaskWithLock(lock, () -> {
                list.add(new Pair<>(id, data));
                // 通知数据变化
                condition.signalAll();
            });
            return true;
        } else {
            return false;
        }
    }

    @Override
    public boolean remove(ID id, T data) {
        ChannelEntry entry = map.get(id);

        if (entry == null) {
            return false;
        }

        // 这里要加锁,要把两个数据一起移除
        return LockTaskUtil.runWithLock(lock, () -> {
            list.remove(new Pair<>(id, data));
            return entry.remove(data);
        });
    }

    @Override
    public void consumed(ID id) {
        Assert.argNotNull(id, "id");

        // 无论如何,先释放全局令牌
        semaphore.release();

        ChannelEntry entry = map.get(id);
        // 如果通道节点不存在,可能是被删除了,不用管
        if (entry != null) {
            entry.token.release();
            // 通知,有令牌被释放了,只有令牌真的释放才会通知,既然该通道已经不存在了,那就不通知了,因为即使通知了后边也找不到该通道的数据了
            LockTaskUtil.runWithLock(lock, condition::signalAll);
        }
    }

    @Override
    public void clear() {
        LockTaskUtil.runWithLock(lock, () -> {
            map.values().forEach(entry -> entry.list.clear());
            map.clear();
            list.clear();
        });
    }

    private static class ChannelEntry {

        /**
         * 通道的队列
         */
        List list;

        /**
         * 对{@link #list}操作加的锁
         */
        Lock lock;

        /**
         * 删除{@link #list}中的数据的时候通知,往{@link #list}中添加数据的时候消费,用于限制list长度
         */
        Condition condition;

        /**
         * 指定{@link #list}的最大长度
         */
        int size;

        /**
         * 令牌
         */
        Semaphore token;

        /**
         * 往通道中增加数据,如果通道队列已满则阻塞等待直到可以放入或者等到超时返回false
         * 
         * @param data
         *            要放入的数据,不能为null
         * @param timeout
         *            超时时间,小于等于0的时候表示不用超时判断
         * @param unit
         *            时间单位
         * @return 如果到超时时间还没有放入成功则返回false,放入成功返回true
         * @throws InterruptedException
         *             中断异常
         */
        boolean add(T data, long timeout, TimeUnit unit) throws InterruptedException {
            long end = System.currentTimeMillis() + unit.toMillis(timeout);

            return LockTaskUtil.runInterruptedTaskWithLock(lock, () -> {
                while (true) {
                    if (list.size() >= size) {
                        long waitTime = end - System.currentTimeMillis();

                        // 如果等到了超时,直接返回false
                        if (waitTime <= 0 || !condition.await(waitTime, TimeUnit.MILLISECONDS)) {
                            return false;
                        }
                    } else {
                        list.add(data);
                        return true;
                    }
                }
            });
        }

        /**
         * 删除指定数据
         * 
         * @param data
         *            要删除的数据
         * @return true表示删除成功
         */
        boolean remove(T data) {
            return LockTaskUtil.runWithLock(lock, () -> {
                if (list.remove(data)) {
                    condition.signalAll();
                    return true;
                } else {
                    return false;
                }
            });
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy