All downloads are free. The search and download functionality uses the official Maven repository.

com.pivovarit.collectors.AsyncParallelCollector Maven / Gradle / Ivy

There is a newer version: 3.2.0
Show newest version
package com.pivovarit.collectors;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Stream;

import static com.pivovarit.collectors.BatchingSpliterator.batching;
import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
import static com.pivovarit.collectors.Dispatcher.getDefaultParallelism;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.supplyAsync;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;

/**
 * @author Grzegorz Piwowarek
 */
final class AsyncParallelCollector
  implements Collector>, CompletableFuture> {

    private final Dispatcher dispatcher;
    private final Function mapper;
    private final Function, C> processor;

    private AsyncParallelCollector(
      Function mapper,
      Dispatcher dispatcher,
      Function, C> processor) {
        this.dispatcher = dispatcher;
        this.processor = processor;
        this.mapper = mapper;
    }

    @Override
    public Supplier>> supplier() {
        return ArrayList::new;
    }

    @Override
    public BinaryOperator>> combiner() {
        return (left, right) -> {
            throw new UnsupportedOperationException("Using parallel stream with parallel collectors is a bad idea");
        };
    }

    @Override
    public BiConsumer>, T> accumulator() {
        return (acc, e) -> acc.add(dispatcher.enqueue(() -> mapper.apply(e)));
    }

    @Override
    public Function>, CompletableFuture> finisher() {
        return futures -> combine(futures).thenApply(processor);
    }

    @Override
    public Set characteristics() {
        return Collections.emptySet();
    }

    private static  CompletableFuture> combine(List> futures) {
        CompletableFuture[] futuresArray = (CompletableFuture[]) futures.toArray(new CompletableFuture[0]);
        CompletableFuture> combined = allOf(futuresArray)
          .thenApply(__ -> Arrays.stream(futuresArray).map(CompletableFuture::join));

        for (CompletableFuture f : futuresArray) {
            f.exceptionally(ex -> {
                combined.completeExceptionally(ex);
                return null;
            });
        }

        return combined;
    }

    static  Collector>> collectingToStream(Function mapper, Executor executor) {
        return collectingToStream(mapper, executor, getDefaultParallelism());
    }

    static  Collector>> collectingToStream(Function mapper, Executor executor, int parallelism) {
        requireNonNull(executor, "executor can't be null");
        requireNonNull(mapper, "mapper can't be null");
        requireValidParallelism(parallelism);

        return parallelism == 1
          ? asyncCollector(mapper, executor, i -> i)
          : new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), t -> t);
    }

    static  Collector> collectingWithCollector(Collector collector, Function mapper, Executor executor) {
        return collectingWithCollector(collector, mapper, executor, getDefaultParallelism());
    }

    static  Collector> collectingWithCollector(Collector collector, Function mapper, Executor executor, int parallelism) {
        requireNonNull(collector, "collector can't be null");
        requireNonNull(executor, "executor can't be null");
        requireNonNull(mapper, "mapper can't be null");
        requireValidParallelism(parallelism);

        return parallelism == 1
          ? asyncCollector(mapper, executor, s -> s.collect(collector))
          : new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), s -> s.collect(collector));
    }

    static void requireValidParallelism(int parallelism) {
        if (parallelism < 1) {
            throw new IllegalArgumentException("Parallelism can't be lower than 1");
        }
    }

    static  Collector> asyncCollector(Function mapper, Executor executor, Function, RR> finisher) {
        return collectingAndThen(toList(), list -> supplyAsync(() -> {
            Stream.Builder acc = Stream.builder();
            for (T t : list) {
                acc.add(mapper.apply(t));
            }
            return finisher.apply(acc.build());
        }, executor));
    }

    static final class BatchingCollectors {

        private BatchingCollectors() {
        }

        static  Collector> collectingWithCollector(Collector collector, Function mapper, Executor executor, int parallelism) {
            requireNonNull(collector, "collector can't be null");
            requireNonNull(executor, "executor can't be null");
            requireNonNull(mapper, "mapper can't be null");
            requireValidParallelism(parallelism);

            return parallelism == 1
              ? asyncCollector(mapper, executor, s -> s.collect(collector))
              : batchingCollector(mapper, executor, parallelism, s -> s.collect(collector));
        }

        static  Collector>> collectingToStream(
          Function mapper,
          Executor executor, int parallelism) {
            requireNonNull(executor, "executor can't be null");
            requireNonNull(mapper, "mapper can't be null");
            requireValidParallelism(parallelism);

            return parallelism == 1
              ? asyncCollector(mapper, executor, i -> i)
              : batchingCollector(mapper, executor, parallelism, s -> s);
        }

        private static  Collector> batchingCollector(Function mapper, Executor executor, int parallelism, Function, RR> finisher) {
            return collectingAndThen(
              toList(),
              list -> {
                  // no sense to repack into batches of size 1
                  if (list.size() == parallelism) {
                      return list.stream()
                        .collect(new AsyncParallelCollector<>(
                          mapper,
                          Dispatcher.from(executor, parallelism),
                          finisher));
                  } else {
                      return partitioned(list, parallelism)
                        .collect(new AsyncParallelCollector<>(
                          batching(mapper),
                          Dispatcher.from(executor, parallelism),
                          listStream -> finisher.apply(listStream.flatMap(Collection::stream))));
                  }
              });
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy