com.pivovarit.collectors.AsyncParallelCollector Maven / Gradle / Ivy
package com.pivovarit.collectors;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Stream;
import static com.pivovarit.collectors.BatchingSpliterator.batching;
import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.supplyAsync;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;
/**
* @author Grzegorz Piwowarek
*/
final class AsyncParallelCollector
implements Collector>, CompletableFuture> {
private final Dispatcher dispatcher;
private final Function task;
private final Function, C> finalizer;
private AsyncParallelCollector(
Function task,
Dispatcher dispatcher,
Function, C> finalizer) {
this.dispatcher = dispatcher;
this.finalizer = finalizer;
this.task = task;
}
@Override
public Supplier>> supplier() {
return ArrayList::new;
}
@Override
public BinaryOperator>> combiner() {
return (left, right) -> {
throw new UnsupportedOperationException("Using parallel stream with parallel collectors is a bad idea");
};
}
@Override
public BiConsumer>, T> accumulator() {
return (acc, e) -> acc.add(dispatcher.enqueue(() -> task.apply(e)));
}
@Override
public Function>, CompletableFuture> finisher() {
return futures -> combine(futures).thenApply(finalizer);
}
@Override
public Set characteristics() {
return Collections.emptySet();
}
private static CompletableFuture> combine(List> futures) {
var combined = allOf(futures.toArray(CompletableFuture[]::new))
.thenApply(__ -> futures.stream().map(CompletableFuture::join));
for (var future : futures) {
future.whenComplete((o, ex) -> {
if (ex != null) {
combined.completeExceptionally(ex);
}
});
}
return combined;
}
static Collector>> collectingToStream(Function mapper) {
requireNonNull(mapper, "mapper can't be null");
return new AsyncParallelCollector<>(mapper, Dispatcher.virtual(), Function.identity());
}
static Collector>> collectingToStream(Function mapper, Executor executor, int parallelism) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
return parallelism == 1
? asyncCollector(mapper, executor, i -> i)
: new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), Function.identity());
}
static Collector> collectingWithCollector(Collector collector, Function mapper) {
requireNonNull(collector, "collector can't be null");
requireNonNull(mapper, "mapper can't be null");
return new AsyncParallelCollector<>(mapper, Dispatcher.virtual(),s -> s.collect(collector));
}
static Collector> collectingWithCollector(Collector collector, Function mapper, Executor executor, int parallelism) {
requireNonNull(collector, "collector can't be null");
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
return parallelism == 1
? asyncCollector(mapper, executor, s -> s.collect(collector))
: new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), s -> s.collect(collector));
}
static void requireValidParallelism(int parallelism) {
if (parallelism < 1) {
throw new IllegalArgumentException("Parallelism can't be lower than 1");
}
}
static Collector> asyncCollector(Function mapper, Executor executor, Function, RR> finisher) {
return collectingAndThen(toList(), list -> supplyAsync(() -> {
Stream.Builder acc = Stream.builder();
for (T t : list) {
acc.add(mapper.apply(t));
}
return finisher.apply(acc.build());
}, executor));
}
static final class BatchingCollectors {
private BatchingCollectors() {
}
static Collector> collectingWithCollector(Collector collector, Function mapper, Executor executor, int parallelism) {
requireNonNull(collector, "collector can't be null");
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
return parallelism == 1
? asyncCollector(mapper, executor, s -> s.collect(collector))
: batchingCollector(mapper, executor, parallelism, s -> s.collect(collector));
}
static Collector>> collectingToStream(
Function mapper,
Executor executor, int parallelism) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
return parallelism == 1
? asyncCollector(mapper, executor, i -> i)
: batchingCollector(mapper, executor, parallelism, s -> s);
}
private static Collector> batchingCollector(Function mapper, Executor executor, int parallelism, Function, RR> finisher) {
return collectingAndThen(
toList(),
list -> {
// no sense to repack into batches of size 1
if (list.size() == parallelism) {
return list.stream()
.collect(new AsyncParallelCollector<>(
mapper,
Dispatcher.from(executor, parallelism),
finisher));
} else {
return partitioned(list, parallelism)
.collect(new AsyncParallelCollector<>(
batching(mapper),
Dispatcher.from(executor, parallelism),
listStream -> finisher.apply(listStream.flatMap(Collection::stream))));
}
});
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy