All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.phantomthief.concurrent.AdaptiveExecutor Maven / Gradle / Ivy

There is a newer version: 0.1.4
Show newest version
/**
 * 
 */
package com.github.phantomthief.concurrent;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.ThreadPoolExecutor.CallerRunsPolicy;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntUnaryOperator;
import java.util.stream.Collectors;

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;

/**
 * @author w.vela
 */
public class AdaptiveExecutor {

    private static final Object EMPTY_OBJECT = new Object();
    private static final CallerRunsPolicy CALLER_RUNS_POLICY = new CallerRunsPolicy();
    private static final ListeningExecutorService DIRECT_EXECUTOR_SERVICE = MoreExecutors
            .newDirectExecutorService();

    private final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(getClass());

    private final IntUnaryOperator threadCountFunction;
    private final ThreadFactory threadFactory;
    private final boolean callerRuns;

    private volatile int threadCounter;

    /**
     * @param globalMaxThread
     * @param threadCountFunction
     * @param threadFactory
     * @param callerRuns
     */
    private AdaptiveExecutor(int globalMaxThread, //
            IntUnaryOperator threadCountFunction, //
            ThreadFactory threadFactory, //
            boolean callerRuns) {
        this.threadCountFunction = threadCountFunction;
        this.threadFactory = threadFactory;
        this.callerRuns = callerRuns;
        this.threadCounter = globalMaxThread;
    }

    public final  void run(Collection keys, Consumer func) {
        invokeAll(keys, i -> {
            func.accept(i);
            return EMPTY_OBJECT;
        });
    }

    public final  Map invokeAll(Collection keys, Function func) {
        List> calls = keys.stream().> map(k -> () -> func.apply(k))
                .collect(Collectors.toList());
        List callResult = invokeAll(calls);
        Iterator iterator = callResult.iterator();
        Map result = new HashMap<>();
        for (K key : keys) {
            V r;
            if (iterator.hasNext()) {
                r = iterator.next();
            } else {
                r = null;
            }
            result.put(key, r);
        }
        return result;
    }

    public final  List invokeAll(List> calls) {
        if (calls == null || calls.isEmpty()) {
            return Collections.emptyList();
        }
        ExecutorService executorService = newExecutor(calls.size(), callerRuns);
        try {
            List> invokeAll = executorService.invokeAll(calls);
            return invokeAll.stream().map(this::futureGet).collect(Collectors.toList());
        } catch (Throwable e) {
            logger.error("Ops.", e);
            return Collections.emptyList();
        } finally {
            shutdownExecutor(executorService);
        }
    }

    private ExecutorService newExecutor(int keySize, boolean callerRuns) {
        int needThread = threadCountFunction.applyAsInt(keySize);
        if (needThread <= 1) {
            logger.trace("need thread one, using director service.");
            return DIRECT_EXECUTOR_SERVICE;
        }
        int leftThread;
        synchronized (this) {
            if (threadCounter >= needThread) {
                leftThread = needThread;
                threadCounter -= needThread;
            } else {
                leftThread = threadCounter;
                threadCounter = 0;
            }
        }
        if (leftThread <= 0) {
            logger.trace("no left thread availabled, using direct executor service.");
            return DIRECT_EXECUTOR_SERVICE;
        } else {
            ThreadPoolExecutor threadPoolExecutor;
            if (callerRuns) {
                threadPoolExecutor = new ThreadPoolExecutor(leftThread, leftThread, 0L,
                        TimeUnit.MILLISECONDS, new LinkedBlockingQueue(1), threadFactory,
                        CALLER_RUNS_POLICY);
            } else {
                threadPoolExecutor = new ThreadPoolExecutor(leftThread, leftThread, 0L,
                        TimeUnit.MILLISECONDS, new LinkedBlockingQueue(1) {

                            private static final long serialVersionUID = 1L;

                            @Override
                            public boolean offer(Runnable e) {
                                try {
                                    put(e);
                                    return true;
                                } catch (InterruptedException ie) {
                                    Thread.currentThread().interrupt();
                                }
                                return false;
                            }
                        }, threadFactory);
            }
            logger.trace("init a executor, thread count:{}", leftThread);
            return threadPoolExecutor;
        }
    }

    private final  V futureGet(Future future) {
        try {
            return future.get();
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    private final void shutdownExecutor(ExecutorService executorService) {
        if (executorService instanceof ListeningExecutorService) {
            return;
        }
        if (MoreExecutors.shutdownAndAwaitTermination(executorService, 1, TimeUnit.DAYS)) {
            if (executorService instanceof ThreadPoolExecutor) {
                ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) executorService;
                synchronized (this) {
                    threadCounter += threadPoolExecutor.getCorePoolSize();
                    logger.trace("destoried a executor, with thread:{}, availabled thread:{}",
                            threadPoolExecutor.getCorePoolSize(), threadCounter);
                }
            }
        }
    }

    public static final class Builder {

        private int globalMaxThread;
        private IntUnaryOperator threadCountFunction;
        private ThreadFactory threadFactory;
        private boolean callerRuns;

        public Builder withGlobalMaxThread(int globalMaxThread) {
            this.globalMaxThread = globalMaxThread;
            return this;
        }

        public Builder enableCallerRunsPolicy() {
            callerRuns = true;
            return this;
        }

        public Builder withThreadStrategy(IntUnaryOperator func) {
            this.threadCountFunction = func;
            return this;
        }

        /**
         * @param maxThreadPerOp 每个操作最多的线程数,尽可能多的使用多线程
         * @return
         */
        public Builder maxThreadAsPossible(int maxThreadPerOp) {
            this.threadCountFunction = i -> Math.min(maxThreadPerOp, i);
            return this;
        }

        public Builder threadFactory(ThreadFactory threadFactory) {
            this.threadFactory = threadFactory;
            return this;
        }

        /**
         * @param minMultiThreadThreshold 操作数超过这个阈值就启用多线程
         * @param maxThreadPerOp 每个操作最多的线程数,尽可能多的使用多线程
         * @return
         */
        public Builder maxThreadAsPossible(int minMultiThreadThreshold, int maxThreadPerOp) {
            this.threadCountFunction = i -> i <= minMultiThreadThreshold ? 1
                    : Math.min(maxThreadPerOp, i);
            return this;
        }

        /**
         * @param opPerThread 1个线程使用n个操作
         * @param maxThreadPerOp 单次操作最多线程数
         * @return
         */
        public Builder adaptiveThread(int opPerThread, int maxThreadPerOp) {
            this.threadCountFunction = i -> Math.min(maxThreadPerOp, i / opPerThread);
            return this;
        }

        public AdaptiveExecutor build() {
            ensure();
            return new AdaptiveExecutor(globalMaxThread, threadCountFunction, threadFactory,
                    callerRuns);
        }

        private void ensure() {
            if (threadCountFunction == null) {
                throw new NullPointerException("thread count function is null.");
            }
            if (globalMaxThread <= 0) {
                throw new IllegalArgumentException("global max thread is illeagl.");
            }
            if (threadFactory == null) {
                threadFactory = new ThreadFactoryBuilder() //
                        .setNameFormat("pool-adaptive-thread-%d") //
                        .build();
            }
        }

    }

    public static final Builder newBuilder() {
        return new Builder();
    }

    private static Supplier cpuCoreAdaptive = Suppliers
            .memoize(AdaptiveExecutor.newBuilder() //
                    .withGlobalMaxThread(Runtime.getRuntime().availableProcessors()) //
                    .maxThreadAsPossible(Runtime.getRuntime().availableProcessors()) //
                    ::build);

    public static final AdaptiveExecutor getCpuCoreAdpativeExecutor() {
        return cpuCoreAdaptive.get();
    }

    public int getLeftThreadCount() {
        return threadCounter;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy