All downloads are free. Search and download functionality uses the official Maven repository.

com.github.dakusui.cmd.core.Tee Maven / Gradle / Ivy

There is a newer version: 1.0.1
Show newest version
package com.github.dakusui.cmd.core;

import java.util.*;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import static java.util.stream.Collectors.toList;

/**
 * A thread that copies every element of an input stream into several bounded queues and exposes
 * each queue as its own sequential {@link Stream} (see {@link #streams()}).
 *
 * <p>This thread is the single producer. It blocks while any down-stream queue is full, so a
 * down-stream that stops reading stalls the others once its queue has filled up. {@code null}
 * elements are passed through unchanged.
 *
 * @param <T> the element type
 */
public class Tee<T> extends Thread {
  /** End-of-stream marker appended after the last input element. */
  private static final Object SENTINEL     = new Object() {
    @Override
    public String toString() {
      return "SENTINEL";
    }
  };
  /** Stand-in for {@code null} input elements, which {@link ArrayBlockingQueue} rejects. */
  private static final Object NULL_ELEMENT = new Object() {
    @Override
    public String toString() {
      return "NULL_ELEMENT";
    }
  };

  private final Stream<T>           in;
  /** One bounded buffer per down-stream; the list object is also the monitor shared by all threads. */
  private final List<Queue<Object>> queues;
  private final List<Stream<T>>     streams;

  /**
   * Creates a tee over {@code in}. Nothing is read until this thread is {@link #start() started}.
   *
   * @param in             the stream to copy; it is consumed by this thread
   * @param numDownStreams the number of down-stream copies to create
   * @param queueSize      the capacity of each down-stream buffer
   * @throws NullPointerException     if {@code in} is {@code null}
   * @throws IllegalArgumentException if {@code numDownStreams > 0} and {@code queueSize} is not positive
   */
  Tee(Stream<T> in, int numDownStreams, int queueSize) {
    this.in = Objects.requireNonNull(in);
    this.queues = new ArrayList<>();
    for (int i = 0; i < numDownStreams; i++) {
      this.queues.add(new ArrayBlockingQueue<>(queueSize));
    }
    this.streams = createDownStreams();
  }

  /**
   * Returns the down-stream copies, one per down-stream requested at construction time. Each one
   * yields the input elements in their original order and ends after the last of them.
   *
   * @return an unmodifiable list of the down-stream streams
   */
  @SuppressWarnings("WeakerAccess")
  public List<Stream<T>> streams() {
    return Collections.unmodifiableList(this.streams);
  }

  /**
   * Feeds every input element, followed by the end-of-stream marker, to all down-stream queues.
   */
  @Override
  public void run() {
    // sequential(): with a parallel input, forEach would invoke publish() from several threads at
    // once and the end-of-stream marker could overtake elements that are still in flight.
    Stream.concat(in, Stream.of(SENTINEL))
        .sequential()
        .forEach(item -> publish(item == null ? NULL_ELEMENT : item));
  }

  /**
   * Hands {@code item} to every down-stream queue, waiting while any of them is full. Returns as
   * soon as all queues have accepted it; it does not wait for consumers to take it out.
   */
  private void publish(Object item) {
    List<Queue<Object>> pendings = new ArrayList<>(this.queues);
    synchronized (this.queues) {
      while (true) {
        pendings.removeIf(queue -> queue.offer(item));
        // Wake consumers that are waiting on an empty queue.
        this.queues.notifyAll();
        if (pendings.isEmpty())
          return;
        try {
          this.queues.wait();
        } catch (InterruptedException ignored) {
          // Deliberately ignored: abandoning the hand-off would leave consumers waiting for an
          // end-of-stream marker that never arrives.
        }
      }
    }
  }

  /**
   * Starts building a tee over {@code in} with the default queue size.
   *
   * @param in the stream to split
   * @return a new connector
   */
  public static <T> Connector<T> tee(Stream<T> in) {
    return new Connector<>(in);
  }

  /**
   * Starts building a tee over {@code in} whose down-stream buffers hold {@code queueSize} elements.
   *
   * @param in        the stream to split
   * @param queueSize the capacity of each down-stream buffer
   * @return a new connector
   * @throws IllegalArgumentException if {@code queueSize} is not positive
   */
  public static <T> Connector<T> tee(Stream<T> in, int queueSize) {
    return tee(in).setQueueSize(queueSize);
  }

  /** Wraps each queue in a sequential stream backed by a {@link DownStreamIterator}. */
  private List<Stream<T>> createDownStreams() {
    return queues.stream()
        .map(this::toDownStream)
        .collect(toList());
  }

  /** Builds the down-stream stream that drains {@code queue}. */
  private Stream<T> toDownStream(Queue<Object> queue) {
    Iterator<T> iterator = new DownStreamIterator<>(this.queues, queue);
    return StreamSupport.stream(Spliterators.spliteratorUnknownSize(iterator, Spliterator.ORDERED), false);
  }

  /**
   * Returns {@code value} if it satisfies {@code condition}; otherwise throws the exception
   * produced by {@code exceptionSupplier}.
   */
  private static <T, E extends Throwable> T check(T value, Predicate<T> condition, Supplier<E> exceptionSupplier) throws E {
    if (condition.test(value))
      return value;
    throw exceptionSupplier.get();
  }

  /**
   * Iterator over a single down-stream queue. Blocks on the shared monitor until the producer has
   * supplied the next element or the end-of-stream marker.
   *
   * @param <E> the element type
   */
  private static final class DownStreamIterator<E> implements Iterator<E> {
    private final Object        monitor;
    private final Queue<Object> queue;
    /** Item taken from the queue but not yet returned by {@link #next()}; {@code null} if none. */
    private       Object        buffered;

    DownStreamIterator(Object monitor, Queue<Object> queue) {
      this.monitor = monitor;
      this.queue = queue;
    }

    @Override
    public boolean hasNext() {
      fill();
      return buffered != SENTINEL;
    }

    @SuppressWarnings("unchecked")
    @Override
    public E next() {
      fill();
      Object item = check(buffered, v -> v != SENTINEL, NoSuchElementException::new);
      // Only reached for real elements; the end-of-stream marker stays buffered so that later
      // calls keep reporting exhaustion.
      buffered = null;
      return item == NULL_ELEMENT ? null : (E) item;
    }

    /**
     * Takes the next item from the queue unless one is already buffered, waiting while the queue
     * is empty.
     */
    private void fill() {
      if (buffered != null)
        return;
      boolean interrupted = false;
      synchronized (monitor) {
        while ((buffered = queue.poll()) == null) {
          try {
            monitor.wait();
          } catch (InterruptedException e) {
            // Keep waiting, but remember the interrupt so it can be re-asserted for the caller.
            interrupted = true;
          }
        }
        // Room was freed in this queue: wake the producer in case it is waiting for it.
        monitor.notifyAll();
      }
      if (interrupted)
        Thread.currentThread().interrupt();
    }
  }

  /**
   * Fluent builder that connects consumers to copies of a stream and runs them concurrently.
   *
   * @param <T> the element type of the input stream
   */
  public static class Connector<T> {
    private final Stream<T>                 in;
    private       int                       queueSize   = 8192;
    private       long                      timeOut     = 60;
    private       TimeUnit                  timeOutUnit = TimeUnit.SECONDS;
    private final List<Consumer<Stream<T>>> consumers   = new ArrayList<>();

    /**
     * @param in the stream to split
     * @throws NullPointerException if {@code in} is {@code null}
     */
    public Connector(Stream<T> in) {
      this.in = Objects.requireNonNull(in);
    }

    /**
     * Sets the capacity of the buffer between the tee and each consumer (8192 unless changed).
     *
     * @param queueSize the buffer capacity
     * @return this connector
     * @throws IllegalArgumentException if {@code queueSize} is not positive
     */
    public Connector<T> setQueueSize(int queueSize) {
      this.queueSize = check(queueSize, v -> v > 0, IllegalArgumentException::new);
      return this;
    }

    /**
     * Sets the time-out used by {@link #run()} (60 seconds unless changed).
     *
     * @param timeOut  the maximum time to wait
     * @param timeUnit the unit of {@code timeOut}
     * @return this connector
     * @throws IllegalArgumentException if {@code timeOut} is not positive
     * @throws NullPointerException     if {@code timeUnit} is {@code null}
     */
    public Connector<T> timeOut(long timeOut, TimeUnit timeUnit) {
      this.timeOut = check(timeOut, v -> v > 0, IllegalArgumentException::new);
      this.timeOutUnit = Objects.requireNonNull(timeUnit);
      return this;
    }

    /**
     * Adds a consumer that maps its copy of the stream with {@code map} and passes every resulting
     * element to {@code action}.
     *
     * @return this connector
     */
    public <U> Connector<T> connect(Function<Stream<T>, Stream<U>> map, Consumer<U> action) {
      this.consumers.add(stream -> map.apply(stream).forEach(action));
      return this;
    }

    /**
     * Adds a consumer that receives every element of its copy of the stream.
     *
     * @return this connector
     */
    public Connector<T> connect(Consumer<T> consumer) {
      return this.connect(stream -> stream, consumer);
    }

    /**
     * Runs every connected consumer on its own copy of the input stream and waits for them to
     * finish, using the time-out set via {@link #timeOut(long, TimeUnit)} (60 seconds unless changed).
     *
     * @return {@code true} if all consumers finished in time, {@code false} if the time-out elapsed first
     * @throws InterruptedException if interrupted while waiting
     * @see ForkJoinPool#awaitTermination(long, TimeUnit)
     */
    public boolean run() throws InterruptedException {
      return run(this.timeOut, this.timeOutUnit);
    }

    /**
     * Runs every connected consumer on its own copy of the input stream and waits up to the given
     * time for all of them to finish.
     *
     * <p>The copies are fed by a {@link Tee} thread started here, and each consumer runs as a task in
     * a dedicated {@link ForkJoinPool}. Nothing is cancelled when the time-out elapses, and exceptions
     * thrown by a consumer are not reported to the caller. The input stream can only be consumed
     * once, so a connector should be run only once.
     *
     * @param timeOut the maximum time to wait
     * @param unit    the time unit of the timeout argument
     * @return {@code true} if all consumers finished in time, {@code false} if the time-out elapsed first
     * @throws InterruptedException if interrupted while waiting
     */
    public boolean run(long timeOut, TimeUnit unit) throws InterruptedException {
      Tee<T> tee = new Tee<>(this.in, consumers.size(), this.queueSize);
      List<Stream<T>> downStreams = tee.streams();
      // ForkJoinPool rejects a parallelism of 0, so keep one worker even when nothing is connected.
      ForkJoinPool pool = new ForkJoinPool(Math.max(1, consumers.size()));
      tee.start();
      try {
        int index = 0;
        for (Consumer<Stream<T>> consumer : consumers) {
          Stream<T> downStream = downStreams.get(index++);
          pool.submit((Runnable) () -> consumer.accept(downStream));
        }
      } finally {
        pool.shutdown();
      }
      return pool.awaitTermination(timeOut, unit);
    }
  }
}