
com.github.phantomthief.concurrent.AdaptiveExecutor Maven / Gradle / Ivy
The newest version!
/**
*
*/
package com.github.phantomthief.concurrent;
import static com.github.phantomthief.util.MoreSuppliers.lazy;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.Iterables.partition;
import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService;
import static com.google.common.util.concurrent.MoreExecutors.shutdownAndAwaitTermination;
import static java.lang.Math.ceil;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.Thread.currentThread;
import static java.time.Duration.ofMillis;
import static java.util.Collections.emptyList;
import static java.util.concurrent.TimeUnit.DAYS;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.stream.Collectors.toList;
import static org.slf4j.LoggerFactory.getLogger;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.ThreadPoolExecutor.CallerRunsPolicy;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntUnaryOperator;
import java.util.stream.Stream;
import org.slf4j.Logger;
import com.github.phantomthief.util.MoreSuppliers.CloseableSupplier;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
/**
* @author w.vela
*/
public class AdaptiveExecutor implements AutoCloseable {
private static final long DEFAULT_TIMEOUT = SECONDS.toMillis(60);
private static final Object EMPTY_OBJECT = new Object();
private static final CallerRunsPolicy CALLER_RUNS_POLICY = new CallerRunsPolicy();
private static final ListeningExecutorService DIRECT_EXECUTOR_SERVICE = newDirectExecutorService();
private static Logger logger = getLogger(AdaptiveExecutor.class);
private static CloseableSupplier cpuCoreAdaptive = lazy(
AdaptiveExecutor.newBuilder() //
.withGlobalMaxThread(Runtime.getRuntime().availableProcessors()) //
.maxThreadAsPossible(Runtime.getRuntime().availableProcessors())::build);
private final CloseableSupplier threadPoolExecutor;
private final IntUnaryOperator threadCountFunction;
private AdaptiveExecutor(Builder builder) {
this.threadCountFunction = builder.threadCountFunction;
this.threadPoolExecutor = lazy(() -> builder.executorFactory.create(0,
builder.globalMaxThread, ofMillis(builder.threadTimeout), new SynchronousQueue<>(),
CALLER_RUNS_POLICY));
}
public static Builder newBuilder() {
return new Builder();
}
public static AdaptiveExecutor getCpuCoreAdpativeExecutor() {
return cpuCoreAdaptive.get();
}
public final void run(Collection keys, Consumer func) {
run(null, keys, func);
}
public final void run(String threadSuffixName, Collection keys, Consumer func) {
invokeAll(threadSuffixName, keys, i -> {
func.accept(i);
return EMPTY_OBJECT;
});
}
public final Map invokeAll(Collection keys, Function func) {
return invokeAll(null, keys, func);
}
public final Map invokeAll(String threadSuffixName, Collection keys,
Function func) {
List> calls = keys.stream().> map(k -> () -> func.apply(k))
.collect(toList());
List callResult = invokeAll(threadSuffixName, calls);
Iterator iterator = callResult.iterator();
Map result = new HashMap<>();
for (K key : keys) {
V r;
if (iterator.hasNext()) {
r = iterator.next();
} else {
r = null;
}
result.put(key, r);
}
return result;
}
public final List invokeAll(List> calls) {
return invokeAll(null, calls);
}
public final List invokeAll(String threadSuffixName, List> calls) {
if (calls == null || calls.isEmpty()) {
return emptyList();
}
ExecutorService executorService;
int thread = max(1, threadCountFunction.applyAsInt(calls.size()));
if (thread == 1) {
executorService = DIRECT_EXECUTOR_SERVICE;
} else {
executorService = threadPoolExecutor.get();
}
Thread callersThread = currentThread();
List>> packed = new ArrayList<>();
for (List> list : partition(calls,
(int) ceil((double) calls.size() / thread))) {
packed.add(() -> {
String origThreadName = null;
Thread runningThread = currentThread();
if (runningThread != callersThread) {
origThreadName = renameCurrentThread(threadSuffixName);
}
try {
List result = new ArrayList<>(list.size());
for (Callable callable : list) {
result.add(callable.call());
}
return result;
} finally {
if (origThreadName != null) {
runningThread.setName(origThreadName);
}
}
});
}
try {
List>> invokeAll = executorService.invokeAll(packed);
return invokeAll.stream().flatMap(this::futureGet).collect(toList());
} catch (Throwable e) {
logger.error("Ops.", e);
return emptyList();
}
}
private String renameCurrentThread(String threadNameSuffix) {
Thread currentThread = currentThread();
String originalName = currentThread.getName();
currentThread.setName(originalName + "-" + threadNameSuffix);
return originalName;
}
private Stream futureGet(Future> future) {
try {
return future.get().stream();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
@Override
public void close() {
threadPoolExecutor.tryClose(exec -> shutdownAndAwaitTermination(exec, 1, DAYS));
}
/**
*
*/
public static final class Builder {
private int globalMaxThread;
private IntUnaryOperator threadCountFunction;
private long threadTimeout;
private ExecutorFactory executorFactory;
public Builder executorFactory(ExecutorFactory factory) {
this.executorFactory = factory;
return this;
}
public Builder withGlobalMaxThread(int globalMaxThread) {
this.globalMaxThread = globalMaxThread;
return this;
}
public Builder withThreadStrategy(IntUnaryOperator func) {
this.threadCountFunction = func;
return this;
}
public Builder threadTimeout(long time, TimeUnit unit) {
this.threadTimeout = unit.toMillis(time);
return this;
}
/**
* @param maxThreadPerOp 每个操作最多的线程数,尽可能多的使用多线程
*/
public Builder maxThreadAsPossible(int maxThreadPerOp) {
this.threadCountFunction = i -> min(maxThreadPerOp, i);
return this;
}
/**
* @param minMultiThreadThreshold 操作数超过这个阈值就启用多线程
* @param maxThreadPerOp 每个操作最多的线程数,尽可能多的使用多线程
*/
public Builder maxThreadAsPossible(int minMultiThreadThreshold, int maxThreadPerOp) {
this.threadCountFunction = i -> i <= minMultiThreadThreshold ? 1 : min(maxThreadPerOp,
i);
return this;
}
/**
* @param opPerThread 1个线程使用n个操作
* @param maxThreadPerOp 单次操作最多线程数
*/
public Builder adaptiveThread(int opPerThread, int maxThreadPerOp) {
this.threadCountFunction = i -> min(maxThreadPerOp, i / opPerThread);
return this;
}
public AdaptiveExecutor build() {
ensure();
return new AdaptiveExecutor(this);
}
private void ensure() {
checkNotNull(threadCountFunction, "thread count function is null.");
checkArgument(globalMaxThread > 0, "global max thread is illeagl.");
if (threadTimeout <= 0) {
threadTimeout = DEFAULT_TIMEOUT;
}
if (executorFactory == null) {
executorFactory = (corePoolSize, maxPoolThread, keepAliveTime, workQueue,
policy) -> new ThreadPoolExecutor(corePoolSize, maxPoolThread,
keepAliveTime.toMillis(), MILLISECONDS, new SynchronousQueue<>(),
new ThreadFactoryBuilder() //
.setNameFormat("pool-adaptive-thread-%d") //
.setUncaughtExceptionHandler(
(t, e) -> logger.error("Ops.", e)) //
.build(),
CALLER_RUNS_POLICY);
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy