// Extracted source: com.pivovarit.collectors.AsyncParallelCollector
// (parallel-collectors library, latest published version on Maven Central)
package com.pivovarit.collectors;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Stream;

import static com.pivovarit.collectors.BatchingSpliterator.batching;
import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.supplyAsync;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;

/**
 * @author Grzegorz Piwowarek
 */
final class AsyncParallelCollector
  implements Collector>, CompletableFuture> {

    private final Dispatcher dispatcher;
    private final Function task;
    private final Function, C> finalizer;

    private AsyncParallelCollector(
      Function task,
      Dispatcher dispatcher,
      Function, C> finalizer) {
        this.dispatcher = dispatcher;
        this.finalizer = finalizer;
        this.task = task;
    }

    @Override
    public Supplier>> supplier() {
        return ArrayList::new;
    }

    @Override
    public BinaryOperator>> combiner() {
        return (left, right) -> {
            throw new UnsupportedOperationException("Using parallel stream with parallel collectors is a bad idea");
        };
    }

    @Override
    public BiConsumer>, T> accumulator() {
        return (acc, e) -> {
            if (!dispatcher.isRunning()) {
                dispatcher.start();
            }
            acc.add(dispatcher.enqueue(() -> task.apply(e)));
        };
    }

    @Override
    public Function>, CompletableFuture> finisher() {
        return futures -> {
            dispatcher.stop();

            return combine(futures).thenApply(finalizer);
        };
    }

    @Override
    public Set characteristics() {
        return Collections.emptySet();
    }

    private static  CompletableFuture> combine(List> futures) {
        var combined = allOf(futures.toArray(CompletableFuture[]::new))
          .thenApply(__ -> futures.stream().map(CompletableFuture::join));

        for (var future : futures) {
            future.whenComplete((o, ex) -> {
                if (ex != null) {
                    combined.completeExceptionally(ex);
                }
            });
        }

        return combined;
    }

    static  Collector>> collectingToStream(Function mapper) {
        requireNonNull(mapper, "mapper can't be null");

        return new AsyncParallelCollector<>(mapper, Dispatcher.virtual(), Function.identity());
    }

    static  Collector>> collectingToStream(Function mapper, int parallelism) {
        requireNonNull(mapper, "mapper can't be null");
        requireValidParallelism(parallelism);

        return new AsyncParallelCollector<>(mapper, Dispatcher.virtual(parallelism), Function.identity());
    }

    static  Collector>> collectingToStream(Function mapper, Executor executor, int parallelism) {
        requireNonNull(executor, "executor can't be null");
        requireNonNull(mapper, "mapper can't be null");
        requireValidParallelism(parallelism);

        return parallelism == 1
          ? asyncCollector(mapper, executor, i -> i)
          : new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), Function.identity());
    }

    static  Collector> collectingWithCollector(Collector collector, Function mapper) {
        requireNonNull(collector, "collector can't be null");
        requireNonNull(mapper, "mapper can't be null");

        return new AsyncParallelCollector<>(mapper, Dispatcher.virtual(), s -> s.collect(collector));
    }

    static  Collector> collectingWithCollector(Collector collector, Function mapper, int parallelism) {
        requireNonNull(collector, "collector can't be null");
        requireNonNull(mapper, "mapper can't be null");
        requireValidParallelism(parallelism);

        return parallelism == 1
          ? asyncCollector(mapper, Executors.newVirtualThreadPerTaskExecutor(), s -> s.collect(collector))
          : new AsyncParallelCollector<>(mapper, Dispatcher.virtual(parallelism), s -> s.collect(collector));
    }

    static  Collector> collectingWithCollector(Collector collector, Function mapper, Executor executor, int parallelism) {
        requireNonNull(collector, "collector can't be null");
        requireNonNull(executor, "executor can't be null");
        requireNonNull(mapper, "mapper can't be null");
        requireValidParallelism(parallelism);

        return parallelism == 1
          ? asyncCollector(mapper, executor, s -> s.collect(collector))
          : new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), s -> s.collect(collector));
    }

    static void requireValidParallelism(int parallelism) {
        if (parallelism < 1) {
            throw new IllegalArgumentException("Parallelism can't be lower than 1");
        }
    }

    static  Collector> asyncCollector(Function mapper, Executor executor, Function, RR> finisher) {
        return collectingAndThen(toList(), list -> supplyAsync(() -> {
            Stream.Builder acc = Stream.builder();
            for (T t : list) {
                acc.add(mapper.apply(t));
            }
            return finisher.apply(acc.build());
        }, executor));
    }

    static final class BatchingCollectors {

        private BatchingCollectors() {
        }

        static  Collector> collectingWithCollector(Collector collector, Function mapper, Executor executor, int parallelism) {
            requireNonNull(collector, "collector can't be null");
            requireNonNull(executor, "executor can't be null");
            requireNonNull(mapper, "mapper can't be null");
            requireValidParallelism(parallelism);

            return parallelism == 1
              ? asyncCollector(mapper, executor, s -> s.collect(collector))
              : batchingCollector(mapper, executor, parallelism, s -> s.collect(collector));
        }

        static  Collector>> collectingToStream(
          Function mapper,
          Executor executor, int parallelism) {
            requireNonNull(executor, "executor can't be null");
            requireNonNull(mapper, "mapper can't be null");
            requireValidParallelism(parallelism);

            return parallelism == 1
              ? asyncCollector(mapper, executor, i -> i)
              : batchingCollector(mapper, executor, parallelism, s -> s);
        }

        private static  Collector> batchingCollector(Function mapper, Executor executor, int parallelism, Function, RR> finisher) {
            return collectingAndThen(
              toList(),
              list -> {
                  // no sense to repack into batches of size 1
                  if (list.size() == parallelism) {
                      return list.stream()
                        .collect(new AsyncParallelCollector<>(
                          mapper,
                          Dispatcher.from(executor, parallelism),
                          finisher));
                  } else {
                      return partitioned(list, parallelism)
                        .collect(new AsyncParallelCollector<>(
                          batching(mapper),
                          Dispatcher.from(executor, parallelism),
                          listStream -> finisher.apply(listStream.flatMap(Collection::stream))));
                  }
              });
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy