All downloads are free. Search and download functionality uses the official Maven repository.

com.github.tonivade.purefun.monad.IO Maven / Gradle / Ivy

/*
 * Copyright (c) 2018-2023, Antonio Gabriel Muñoz Conejo 
 * Distributed under the terms of the MIT License
 */
package com.github.tonivade.purefun.monad;

import static com.github.tonivade.purefun.Function1.identity;
import static com.github.tonivade.purefun.Matcher1.always;
import static com.github.tonivade.purefun.Precondition.checkNonNull;

import java.time.Duration;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;

import com.github.tonivade.purefun.CheckedRunnable;
import com.github.tonivade.purefun.Consumer1;
import com.github.tonivade.purefun.Effect;
import com.github.tonivade.purefun.Function1;
import com.github.tonivade.purefun.Function2;
import com.github.tonivade.purefun.HigherKind;
import com.github.tonivade.purefun.Kind;
import com.github.tonivade.purefun.Operator1;
import com.github.tonivade.purefun.PartialFunction1;
import com.github.tonivade.purefun.Producer;
import com.github.tonivade.purefun.Recoverable;
import com.github.tonivade.purefun.Tuple;
import com.github.tonivade.purefun.Tuple2;
import com.github.tonivade.purefun.Unit;
import com.github.tonivade.purefun.concurrent.Future;
import com.github.tonivade.purefun.concurrent.Promise;
import com.github.tonivade.purefun.data.ImmutableList;
import com.github.tonivade.purefun.data.ImmutableMap;
import com.github.tonivade.purefun.data.Sequence;
import com.github.tonivade.purefun.type.Either;
import com.github.tonivade.purefun.type.Option;
import com.github.tonivade.purefun.type.Try;
import com.github.tonivade.purefun.typeclasses.Fiber;

@HigherKind
public sealed interface IO extends IOOf, Effect, Recoverable {

  /** Shared program that completes immediately with {@code Unit.unit()}; returned by {@link #unit()}. */
  IO UNIT = pure(Unit.unit());

  /**
   * Runs this program on an uncancellable connection and returns a {@link Future} of its result.
   * Evaluation starts on the calling thread and continues there until the first asynchronous
   * step is reached.
   */
  default Future runAsync() {
    return Future.from(runAsync(this, IOConnection.UNCANCELLABLE));
  }

  /**
   * Runs this program after an asynchronous hop onto {@code executor} (see
   * {@link #forked(Executor)}), so evaluation is handed to the executor instead of starting on the
   * calling thread.
   */
  default Future runAsync(Executor executor) {
    return forked(executor).andThen(this).runAsync();
  }
  
  /**
   * Waits for this program to finish and returns its value; a failure is thrown through
   * {@code Try.getOrElseThrow()}.
   */
  default T unsafeRunSync() {
    return safeRunSync().getOrElseThrow();
  }

  /**
   * Runs this program with {@link #runAsync()} and waits for it via {@code Future.await()},
   * returning the outcome as a {@link Try} instead of throwing.
   */
  default Try safeRunSync() {
    return runAsync().await();
  }

  /** Runs this program on {@code Future.DEFAULT_EXECUTOR} and hands the outcome to {@code callback}. */
  default void safeRunAsync(Consumer1> callback) {
    safeRunAsync(Future.DEFAULT_EXECUTOR, callback);
  }

  /** Runs this program on {@code executor} and hands the outcome to {@code callback} once it completes. */
  default void safeRunAsync(Executor executor, Consumer1> callback) {
    runAsync(executor).onComplete(callback);
  }

  /**
   * Transforms the value of this program with {@code map}. Built on {@code flatMap} and
   * {@link #pure}, so the mapped value must be non-null ({@code Pure} rejects null).
   */
  @Override
  default  IO map(Function1 map) {
    return flatMap(map.andThen(IO::pure));
  }

  /** Sequences this program with the program that {@code map} builds from its value. */
  @Override
  default  IO flatMap(Function1> map) {
    return new FlatMapped<>(this, map.andThen(IOOf::narrowK));
  }

  /** Runs {@code after} once this program has succeeded, discarding this program's value. */
  @Override
  default  IO andThen(Kind after) {
    return flatMap(ignore -> after);
  }

  /**
   * Applies the function produced by {@code apply} to this program's value. Both programs are
   * forked and evaluated concurrently on {@code Future.DEFAULT_EXECUTOR} via {@link #parMap2}.
   */
  @Override
  default  IO ap(Kind> apply) {
    return parMap2(Future.DEFAULT_EXECUTOR, this, apply, (v, a) -> a.apply(v));
  }

  /** Captures the outcome of this program as a {@link Try}; the resulting program does not fail. */
  default IO> attempt() {
    return map(Try::success).recover(Try::failure);
  }

  /** Like {@link #attempt()} but exposes the outcome as an {@link Either} (via {@code Try.toEither}). */
  default IO> either() {
    return attempt().map(Try::toEither);
  }

  /** Like {@link #either()} with the error mapped by {@code mapError} and the value by {@code mapper}. */
  default  IO> either(Function1 mapError,
                                         Function1 mapper) {
    return either().map(either -> either.bimap(mapError, mapper));
  }

  /** Folds both outcomes of this program into a single value: errors via {@code mapError}, values via {@code mapper}. */
  default  IO redeem(Function1 mapError,
                           Function1 mapper) {
    return attempt().map(result -> result.fold(mapError, mapper));
  }

  /** Like {@code redeem}, but each branch returns a program that is run next. */
  default  IO redeemWith(Function1> mapError,
                               Function1> mapper) {
    return attempt().flatMap(result -> result.fold(mapError, mapper));
  }

  /** Replaces any failure of this program with the value computed by {@code mapError}. */
  default IO recover(Function1 mapError) {
    return recoverWith(PartialFunction1.of(always(), mapError.andThen(IO::pure)));
  }

  /**
   * Recovers only from failures whose runtime class is exactly {@code type}. The check is
   * {@code getClass().equals(type)}, so subclasses of {@code type} are NOT matched and propagate.
   */
  @SuppressWarnings("unchecked") // cast to X is guarded by the exact-class predicate
  default  IO recover(Class type, Function1 function) {
    return recoverWith(PartialFunction1.of(error -> error.getClass().equals(type), t -> function.andThen(IO::pure).apply((X) t)));
  }
  
  /**
   * Switches to the program returned by {@code mapper} for failures at which {@code mapper} is
   * defined; other failures continue to outer handlers or fail the program.
   */
  default IO recoverWith(PartialFunction1> mapper) {
    return new Recover<>(this, mapper.andThen(IOOf::narrowK));
  }

  /**
   * Pairs the value of this program with the elapsed time, measured with {@code System.nanoTime()}
   * from just before this program starts until its value is available.
   */
  @Override
  default IO> timed() {
    return IO.task(System::nanoTime).flatMap(
      start -> map(result -> Tuple.of(Duration.ofNanos(System.nanoTime() - start), result)));
  }
  
  /**
   * Starts this program on a new cancellable connection and then yields a {@link Fiber} whose
   * {@code join} waits for the program's outcome and whose {@code cancel} cancels that connection.
   *
   * <p>The program is started from inside the {@code async} callback, so it runs on the thread that
   * evaluates the fork until its first asynchronous step. Prefix it with {@link #forked(Executor)}
   * to start it elsewhere.
   */
  default IO> fork() {
    return async(callback -> {
      IOConnection connection = IOConnection.cancellable();
      Promise promise = runAsync(this, connection);
      
      IO join = fromPromise(promise);
      IO cancel = exec(connection::cancel);
      
      callback.accept(Try.success(Fiber.of(join, cancel)));
    });
  }

  /** Fails with {@link TimeoutException} unless this program completes within {@code duration}. */
  default IO timeout(Duration duration) {
    return timeout(Future.DEFAULT_EXECUTOR, duration);
  }
  
  /**
   * Races this program against a sleep of {@code duration}. If this program wins, the sleep fiber is
   * cancelled and the value returned. If the sleep wins, this program's fiber is cancelled and the
   * result fails with {@link TimeoutException}.
   *
   * <p>Both sides are forked onto {@code executor} by {@code racePair}, but the timer itself comes
   * from {@code sleep(duration)}, i.e. {@code Future.DEFAULT_EXECUTOR}, whatever {@code executor} is.
   */
  default IO timeout(Executor executor, Duration duration) {
    return racePair(executor, this, sleep(duration)).flatMap(either -> either.fold(
        ta -> ta.get2().cancel().fix(IOOf.toIO()).map(x -> ta.get1()),
        tb -> tb.get1().cancel().fix(IOOf.toIO()).flatMap(x -> IO.raiseError(new TimeoutException()))));
  }

  /** Runs this program twice in total; same as {@code repeat(1)}. */
  @Override
  default IO repeat() {
    return repeat(1);
  }

  /**
   * Runs this program and then {@code times} more times (so {@code times + 1} runs), yielding the
   * value of the last run. The first failure stops the repetition and is propagated.
   */
  @Override
  default IO repeat(int times) {
    return repeat(this, unit(), times);
  }

  /** Like {@link #repeat()} but sleeps {@code delay} before the repetition. */
  @Override
  default IO repeat(Duration delay) {
    return repeat(delay, 1);
  }

  /** Like {@link #repeat(int)} but sleeps {@code delay} before each repetition. */
  @Override
  default IO repeat(Duration delay, int times) {
    return repeat(this, sleep(delay), times);
  }

  /** Retries this program at most once on failure; same as {@code retry(1)}. */
  @Override
  default IO retry() {
    return retry(1);
  }

  /** Retries this program up to {@code maxRetries} times on failure, then propagates the latest error. */
  @Override
  default IO retry(int maxRetries) {
    return retry(this, unit(), maxRetries);
  }

  /** Like {@link #retry()} but sleeps {@code delay} before retrying. */
  @Override
  default IO retry(Duration delay) {
    return retry(delay, 1);
  }

  /**
   * Like {@link #retry(int)} but pauses before each retry. The pause doubles on every attempt
   * ({@code delay}, {@code 2 * delay}, {@code 4 * delay}, ...) because the private helper passes
   * {@code pause.repeat()} to the next attempt.
   */
  @Override
  default IO retry(Duration delay, int maxRetries) {
    return retry(this, sleep(delay), maxRetries);
  }

  /** Lifts an already-computed value into a program; {@code value} must be non-null ({@code Pure} checks it). */
  static  IO pure(T value) {
    return new Pure<>(value);
  }
  
  /** Races {@code fa} against {@code fb} on {@code Future.DEFAULT_EXECUTOR}. */
  static  IO> race(Kind fa, Kind fb) {
    return race(Future.DEFAULT_EXECUTOR, fa, fb);
  }
  
  /**
   * Races {@code fa} against {@code fb} and yields the winner's value ({@code Left} when {@code fa}
   * wins, {@code Right} when {@code fb} wins) after cancelling the losing side through its fiber.
   */
  static  IO> race(Executor executor, Kind fa, Kind fb) {
    return racePair(executor, fa, fb).flatMap(either -> either.fold(
        ta -> ta.get2().cancel().fix(IOOf.toIO()).map(x -> Either.left(ta.get1())),
        tb -> tb.get1().cancel().fix(IOOf.toIO()).map(x -> Either.right(tb.get2()))));
  }
  
  /**
   * Races {@code fa} against {@code fb}, each forked onto {@code executor} with its own cancellable
   * connection, and completes as soon as either side completes.
   *
   * <p>If {@code fa} finishes first the result is {@code Left((a, fiberB))}; if {@code fb} finishes
   * first it is {@code Right((fiberA, b))}. The fiber wraps the side that is still running so the
   * caller can join or cancel it. If the first side to complete fails, the race fails with that
   * error. Cancelling the returned program cancels both connections.
   *
   * <p>Each fiber's cancel token must target the connection of the computation it wraps. The fiber
   * over {@code fa} used to cancel {@code connection2} (the winner's connection), so cancelling the
   * loser after {@code fb} had won, as {@code timeout} and {@code race} do, left {@code fa} running.
   */
  static  IO>, Tuple2, B>>> racePair(Executor executor, Kind fa, Kind fb) {
    return cancellable(callback -> {

      IOConnection connection1 = IOConnection.cancellable();
      IOConnection connection2 = IOConnection.cancellable();

      Promise promiseA = runAsync(IO.forked(executor).andThen(fa), connection1);
      Promise promiseB = runAsync(IO.forked(executor).andThen(fb), connection2);

      // fa finished first: the fiber over fb cancels fb's connection (connection2).
      promiseA.onComplete(result -> callback.accept(
          result.map(a -> Either.left(Tuple.of(a, Fiber.of(IO.fromPromise(promiseB), IO.exec(connection2::cancel)))))));
      // fb finished first: the fiber over fa must cancel fa's connection (connection1).
      promiseB.onComplete(result -> callback.accept(
          result.map(b -> Either.right(Tuple.of(Fiber.of(IO.fromPromise(promiseA), IO.exec(connection1::cancel)), b)))));

      // Cancelling the race cancels both sides; finally ensures connection2 is cancelled
      // even if cancelling connection1 throws.
      return IO.exec(() -> {
        try {
          connection1.cancel();
        } finally {
          connection2.cancel();
        }
      });
    });
  }

  /** Program that fails with {@code error} when run; {@code error} must be non-null. */
  static  IO raiseError(Throwable error) {
    return new Failure<>(error);
  }

  /** Sleeps for {@code delay} (see {@link #sleep(Duration)}) and then evaluates {@code lazy}. */
  static  IO delay(Duration delay, Producer lazy) {
    return sleep(delay).andThen(task(lazy));
  }

  /** Defers building a program until it is run; {@code lazy} is invoked each time the result is evaluated. */
  static  IO suspend(Producer> lazy) {
    return new Suspend<>(lazy.andThen(IOOf::narrowK));
  }

  /**
   * Turns a plain function into one returning a program. Note that {@code task} itself runs when
   * the returned function is applied, not when the resulting program is run.
   */
  static  Function1> lift(Function1 task) {
    return task.andThen(IO::pure);
  }

  /** Lifts an {@code Option}-returning function; {@code function} runs eagerly when applied. */
  static  Function1> liftOption(Function1> function) {
    return value -> fromOption(function.apply(value));
  }

  /** Lifts a {@code Try}-returning function; {@code function} runs eagerly when applied. */
  static  Function1> liftTry(Function1> function) {
    return value -> fromTry(function.apply(value));
  }

  /** Lifts an {@code Either}-returning function; {@code function} runs eagerly when applied. */
  static  Function1> liftEither(Function1> function) {
    return value -> fromEither(function.apply(value));
  }

  /**
   * Converts an {@code Option} through {@code Option.toEither()}: a present value succeeds. An empty
   * option fails with whatever error {@code toEither()} supplies (not visible here).
   */
  static  IO fromOption(Option task) {
    return fromEither(task.toEither());
  }

  /** Converts a {@code Try}: success becomes a value, failure becomes a failed program. */
  static  IO fromTry(Try task) {
    return fromEither(task.toEither());
  }

  /** Converts an {@code Either}: {@code Left} raises the error, {@code Right} yields the value. */
  static  IO fromEither(Either task) {
    return task.fold(IO::raiseError, IO::pure);
  }
  
  /**
   * Program that completes with the outcome of {@code promise}. Built with {@link #async}, so
   * cancelling it does not affect the promise.
   */
  static  IO fromPromise(Promise promise) {
    Consumer1>> callback = promise::onComplete;
    return async(callback);
  }
  
  /** Program that completes with the outcome of {@code promise}, adapted through {@code Promise.from}. */
  static  IO fromCompletableFuture(CompletableFuture promise) {
    return fromPromise(Promise.from(promise));
  }

  /** Sleeps for {@code duration} using {@code Future.DEFAULT_EXECUTOR}. */
  static IO sleep(Duration duration) {
    return sleep(Future.DEFAULT_EXECUTOR, duration);
  }

  /**
   * Cancellable sleep backed by {@code Future.sleep(executor, duration)}. It always completes with
   * {@code Unit} once that future completes (its outcome is ignored). The cancel token cancels the
   * underlying future with {@code cancel(true)}.
   */
  static IO sleep(Executor executor, Duration duration) {
    return cancellable(callback -> {
      Future sleep = Future.sleep(executor, duration)
        .onComplete(result -> callback.accept(Try.success(Unit.unit())));
      return IO.exec(() -> sleep.cancel(true));
    });
  }

  /** Program that runs {@code task} for its side effect each time it is evaluated and yields {@code Unit}. */
  static IO exec(CheckedRunnable task) {
    return task(task.asProducer());
  }

  /**
   * Program whose value is computed by {@code producer} each time it is evaluated. The value must be
   * non-null, because the run loop wraps it in {@code Pure}.
   */
  static  IO task(Producer producer) {
    return new Delay<>(producer);
  }

  /** Program that never completes: its async callback is never invoked. */
  static  IO never() {
    return async(callback -> {});
  }
  
  /** Asynchronous boundary on {@code Future.DEFAULT_EXECUTOR}; see {@link #forked(Executor)}. */
  static IO forked() {
    return forked(Future.DEFAULT_EXECUTOR);
  }
  
  /**
   * Asynchronous boundary that completes with {@code Unit} from a task submitted to {@code executor},
   * so that the steps that follow are handed to the executor.
   */
  static IO forked(Executor executor) {
    return async(callback -> executor.execute(() -> callback.accept(Try.success(Unit.unit()))));
  }

  /**
   * Program completed by {@code callback}, which receives the consumer to invoke with the outcome.
   * The cancel token handed to the run loop is a pure value, so cancellation is a no-op for it.
   */
  static  IO async(Consumer1>> callback) {
    return cancellable(callback.asFunction().andThen(IO::pure));
  }

  /**
   * Program completed by {@code callback}, which receives the completion consumer and returns the
   * program to run if this step is cancelled. The run loop registers that token on the connection.
   */
  static  IO cancellable(Function1>, IO> callback) {
    return new Async<>(callback);
  }

  /** {@link #memoize(Executor, Function1)} on {@code Future.DEFAULT_EXECUTOR}. */
  static  IO>> memoize(Function1> function) {
    return memoize(Future.DEFAULT_EXECUTOR, function);
  }

  /**
   * Yields, inside an {@code IO} because it allocates a {@code Ref}, a memoized version of
   * {@code function}. The first call for a key starts {@code function.apply(a)} on {@code executor}
   * and stores its promise in the map. Later calls whose key is found by the {@code ImmutableMap}
   * lookup reuse that promise. Failures are cached too, since a stored promise is never replaced.
   *
   * <p>NOTE(review): the computation is started inside {@code r.modify(...)}. If {@code Ref.modify}
   * re-runs its function under contention, the same key could be started more than once; confirm
   * against {@code Ref}.
   */
  static  IO>> memoize(Executor executor, Function1> function) {
    var ref = Ref.make(ImmutableMap.>empty());
    return ref.map(r -> {
      Function1>> result = a -> r.modify(map -> map.get(a).fold(() -> {
        Promise promise = Promise.make();
        function.apply(a).safeRunAsync(executor, promise::tryComplete);
        return Tuple.of(IO.fromPromise(promise), map.put(a, promise));
      }, promise -> Tuple.of(IO.fromPromise(promise), map)));
      // flatten: r.modify yields a program whose value is the program awaiting the cached promise
      return result.andThen(io -> io.flatMap(identity()));
    });
  }

  /** Returns the shared {@link #UNIT} program. */
  static IO unit() {
    return UNIT;
  }

  /**
   * Acquires a resource, uses it, and releases it once {@code use} has completed.
   *
   * <p>{@code acquire} and {@code use} run on one cancellable connection, and cancelling the
   * returned program cancels that connection. If {@code acquire} fails, the bracket fails with that
   * error and {@code release} is not invoked. Otherwise {@code release} runs after the program built
   * by {@code use} completes, whether it succeeds or fails. The bracket then completes with the
   * outcome of {@code use}; the outcome of {@code release} is ignored.
   *
   * <p>{@code use} and {@code release} are applied inside {@link #suspend(Producer)}. An exception
   * thrown while building either program is therefore reported as a failure by the run loop, instead
   * of escaping the promise callback; that escape skipped {@code release} and left the bracket
   * uncompleted. {@code release} runs on {@code IOConnection.UNCANCELLABLE}, so cancelling the
   * bracket cannot abort clean-up at an asynchronous boundary.
   */
  static  IO bracket(Kind acquire, 
      Function1> use, Function1> release) {
    return cancellable(callback -> {

      IOConnection cancellable = IOConnection.cancellable();

      Promise promise = runAsync(acquire.fix(IOOf::narrowK), cancellable);

      promise
        .onFailure(error -> callback.accept(Try.failure(error)))
        .onSuccess(resource -> runAsync(suspend(() -> use.apply(resource)), cancellable)
          .onComplete(result -> runAsync(suspend(() -> release.apply(resource)), IOConnection.UNCANCELLABLE)
            .onComplete(ignore -> callback.accept(result))
        ));

      return IO.exec(cancellable::cancel);
    });
  }

  /**
   * {@code bracket} whose clean-up is a plain consumer. The consumer runs when the release function
   * is applied to the resource (through {@code asFunction()}), and the resulting program is just
   * {@code pure(Unit)}.
   */
  static  IO bracket(Kind acquire, 
      Function1> use, Consumer1 release) {
    return bracket(acquire, use, release.asFunction().andThen(IO::pure));
  }

  /** {@code bracket} for {@link AutoCloseable} resources, released with {@code close()}. */
  static  IO bracket(Kind acquire, 
      Function1> use) {
    return bracket(acquire, use, AutoCloseable::close);
  }

  /** Runs the programs one after another, left to right, discarding their values; yields {@code Unit}. */
  static IO sequence(Sequence> sequence) {
    Kind initial = IO.unit().kind();
    return sequence.foldLeft(initial, 
        (Kind a, Kind b) -> a.fix(IOOf::narrowK).andThen(b.fix(IOOf::narrowK))).fix(IOOf::narrowK).andThen(IO.unit());
  }

  /** {@link #traverse(Executor, Sequence)} on {@code Future.DEFAULT_EXECUTOR}. */
  static  IO> traverse(Sequence> sequence) {
    return traverse(Future.DEFAULT_EXECUTOR, sequence);
  }

  /**
   * Collects the values of the programs into a list in sequence order. Each step combines the
   * accumulated list with the next program via {@code parMap2}, so the programs are forked onto
   * {@code executor} and may run concurrently.
   */
  static  IO> traverse(Executor executor, Sequence> sequence) {
    return sequence.foldLeft(pure(ImmutableList.empty()), 
        (Kind> xs, Kind a) -> parMap2(executor, xs, a, Sequence::append));
  }

  /** {@code parMap2} on {@code Future.DEFAULT_EXECUTOR}. */
  static  IO parMap2(Kind fa, Kind fb,
                              Function2 mapper) {
    return parMap2(Future.DEFAULT_EXECUTOR, fa, fb, mapper);
  }

  /**
   * Forks {@code fa} and {@code fb} onto {@code executor}, each on its own cancellable connection. It
   * waits for both to complete and combines their outcomes with {@code Try.map2}. A failure of one
   * side does not cancel the other: the result is produced only after both have completed.
   * Cancelling the returned program cancels both connections.
   */
  static  IO parMap2(Executor executor, Kind fa, Kind fb,
                              Function2 mapper) {
    return cancellable(callback -> {
      
      IOConnection connection1 = IOConnection.cancellable();
      IOConnection connection2 = IOConnection.cancellable();
      
      Promise promiseA = runAsync(IO.forked(executor).andThen(fa), connection1);
      Promise promiseB = runAsync(IO.forked(executor).andThen(fb), connection2);
      
      // wait for A, then for B, then combine both outcomes
      promiseA.onComplete(a -> promiseB.onComplete(b -> callback.accept(Try.map2(a, b, mapper))));
      
      // finally ensures connection2 is cancelled even if cancelling connection1 throws
      return IO.exec(() -> {
        try {
          connection1.cancel();
        } finally {
          connection2.cancel();
        }
      });
    });
  }

  /** Pairs the values of {@code fa} and {@code fb}, evaluated concurrently on {@code Future.DEFAULT_EXECUTOR}. */
  static  IO> tuple(Kind fa, Kind fb) {
    return tuple(Future.DEFAULT_EXECUTOR, fa, fb);
  }

  /** Pairs the values of {@code fa} and {@code fb}, evaluated concurrently on {@code executor} via {@code parMap2}. */
  static  IO> tuple(Executor executor, Kind fa, Kind fb) {
    return parMap2(executor, fa, fb, Tuple::of);
  }

  /** Starts the run loop for {@code current} with an empty recover stack and a fresh result promise. */
  private static  Promise runAsync(IO current, IOConnection connection) {
    return runAsync(current, connection, new CallStack<>(), Promise.make());
  }

  /**
   * Iterative interpreter for {@code IO} programs. It evaluates synchronously on the calling thread
   * until the program completes or an asynchronous step is reached. For an asynchronous step it
   * returns {@code promise} early, and the remaining work continues from that step's callback.
   *
   * @param current program to evaluate
   * @param connection connection consulted and updated when asynchronous steps start
   * @param stack recover handlers registered by enclosing {@code Recover} nodes
   * @param promise completed with the program's final outcome
   * @return {@code promise}
   */
  @SuppressWarnings("unchecked")
  private static  Promise runAsync(IO current, IOConnection connection, CallStack stack, Promise promise) {
    while (true) {
      try {
        // Peel off Failure/Recover/Suspend/Delay layers; what remains is Pure, Async or FlatMapped.
        current = unwrap(current, stack, identity());
        
        if (current instanceof Pure pure) {
          return promise.succeeded(pure.value);
        }
        
        if (current instanceof Async async) {
          // promise is completed by the async callback
          return executeAsync(async, connection, promise);
        }
        
        if (current instanceof FlatMapped) {
          // count this bind on the current handler frame (see CallStack)
          stack.push();

          var flatMapped = (FlatMapped) current;
          // Recover handlers found on the left side are chained with the continuation, so a
          // recovered value still flows into flatMapped.next.
          IO source = unwrap(flatMapped.current, stack, u -> u.flatMap(flatMapped.next)).fix(IOOf::narrowK);
          
          if (source instanceof Async async) {
            // Leave the loop: resume with the continuation once the async step succeeds, reusing
            // the same stack and result promise.
            // NOTE(review): only a success continuation is registered on nextPromise, and
            // andThen.apply(u) runs inside the promise callback, outside this try/catch. Confirm
            // how a failed async step, or an exception thrown by flatMapped.next, reaches
            // `promise` and the handlers on `stack`.
            Promise nextPromise = Promise.make();
            
            nextPromise.then(u -> {
              Function1> andThen = flatMapped.next.andThen(IOOf::narrowK);
              runAsync(andThen.apply(u), connection, stack, promise);
            });
            
            executeAsync(async, connection, nextPromise);
            
            return promise;
          }

          if (source instanceof Pure pure) {
            // apply the continuation and keep looping
            Function1> andThen = flatMapped.next.andThen(IOOf::narrowK);
            current = andThen.apply(pure.value);
          } else if (source instanceof FlatMapped) {
            // Re-associate ((x flatMap f) flatMap g) into (x flatMap (a -> f(a) flatMap g)) so that
            // left-nested binds are processed by this loop instead of by recursion.
            FlatMapped flatMapped2 = (FlatMapped) source;
            current = flatMapped2.current.flatMap(a -> flatMapped2.next.apply(a).flatMap(flatMapped.next));
          }
        } else {
          // NOTE(review): unwrap only returns Pure, Async or FlatMapped, so this branch looks unreachable.
          stack.pop();
        }
      } catch (Throwable error) {
        // Errors, including Failure nodes rethrown by unwrap, go to the nearest matching recover
        // handler; if none matches, the program fails.
        // NOTE(review): a handler that itself throws does so inside stack.tryHandle, i.e. from
        // this catch block, so the exception would escape runAsync instead of failing `promise`.
        Option> result = stack.tryHandle(error);
        
        if (result.isPresent()) {
          current = result.getOrElseThrow();
        } else {
          return promise.failed(error);
        }
      }
    }
  }

  /**
   * Strips layers the run loop does not need to handle itself, until a {@code Pure},
   * {@code FlatMapped} or {@code Async} node remains.
   *
   * <p>{@code Suspend} is forced and {@code Delay} is evaluated into a {@code Pure}. Each
   * {@code Recover} registers its handler, composed with {@code next}, on {@code stack}. A
   * {@code Failure} is passed to {@code stack.sneakyThrow}, presumably rethrowing it unchecked so
   * the run loop's catch block handles it.
   *
   * @param next continuation applied to the program produced by a recover handler
   */
  private static  IO unwrap(IO current, CallStack stack, Function1, IO> next) {
    while (true) {
      if (current instanceof Failure failure) {
        return stack.sneakyThrow(failure.error);
      } else if (current instanceof Recover recover) {
        stack.add(recover.mapper.andThen(next));
        current = recover.current;
      } else if (current instanceof Suspend suspend) {
        Producer> andThen = (suspend).lazy.andThen(IOOf::narrowK);
        current = andThen.get();
      } else if (current instanceof Delay delay) {
        // exceptions from the task propagate to the caller; a null result is rejected by Pure
        return IO.pure(delay.task.get());
      } else if (current instanceof Pure) {
        return current;
      } else if (current instanceof FlatMapped) {
        return current;
      } else if (current instanceof Async) {
        return current;
      } else {
        // unreachable while IO's nested implementations are the ones handled above
        throw new IllegalStateException();
      }
    }
  }

  /**
   * Starts an asynchronous step whose outcome completes {@code promise}, coordinating with a
   * concurrent {@code cancel()} through the connection's {@link StateIO}.
   *
   * <p>On a cancellable connection the state is first marked "starting". If it is not runnable
   * (already cancelled or cancelling), the step is skipped and {@code promise} is cancelled. Otherwise
   * the callback's cancel token is registered and reset to {@code UNIT} once {@code promise}
   * completes. A {@code cancel()} that arrives while "starting" does not run the token itself (the
   * state is not cancelable then). Instead, it is detected when the flag is cleared, and the token is
   * run here.
   */
  private static  Promise executeAsync(Async current, IOConnection connection, Promise promise) {
    if (connection.isCancellable() && !connection.updateState(StateIO::startingNow).isRunnable()) {
      return promise.cancel();
    }
    
    connection.setCancelToken(current.callback.apply(promise::tryComplete));
    
    promise.thenRun(() -> connection.setCancelToken(UNIT));
    
    if (connection.isCancellable() && connection.updateState(StateIO::notStartingNow).isCancellingNow()) {
      connection.cancelNow();
    }

    return promise;
  }

  /**
   * Runs {@code self}; on success, while {@code times > 0}, runs {@code pause} and then repeats with
   * {@code times - 1}. Yields the last value; the first failure is propagated unchanged.
   */
  private static  IO repeat(IO self, IO pause, int times) {
    return self.redeemWith(IO::raiseError, value -> {
      if (times > 0) {
        return pause.andThen(repeat(self, pause, times - 1));
      } else return IO.pure(value);
    });
  }

  /**
   * Runs {@code self}; on failure, while {@code maxRetries > 0}, runs {@code pause} and then retries.
   * The next attempt receives {@code pause.repeat()}, which runs the pause twice, so the delay
   * doubles on each retry. When retries are exhausted, the latest error is raised.
   */
  private static  IO retry(IO self, IO pause, int maxRetries) {
    return self.redeemWith(error -> {
      if (maxRetries > 0) {
        return pause.andThen(retry(self, pause.repeat(), maxRetries - 1));
      } else return IO.raiseError(error);
    }, IO::pure);
  }

  /** Program node holding an already-computed value; null values are rejected at construction. */
  final class Pure implements IO {

    private final T value;

    private Pure(T value) {
      this.value = checkNonNull(value);
    }

    @Override
    public String toString() {
      return "Pure(" + value + ")";
    }
  }

  /** Program node that fails with {@code error} when evaluated; {@code error} must be non-null. */
  final class Failure implements IO, Recoverable {

    private final Throwable error;

    private Failure(Throwable error) {
      this.error = checkNonNull(error);
    }

    @Override
    public String toString() {
      return "Failure(" + error + ")";
    }
  }

  /** Program node that evaluates {@code current} and feeds its value to {@code next} (monadic bind). */
  final class FlatMapped implements IO {

    private final IO current;
    private final Function1> next;

    private FlatMapped(IO current,
                         Function1> next) {
      this.current = checkNonNull(current);
      this.next = checkNonNull(next);
    }

    @Override
    public String toString() {
      return "FlatMapped(" + current + ", ?)";
    }
  }

  /** Program node whose value is computed by {@code task} each time the run loop evaluates it. */
  final class Delay implements IO {

    private final Producer task;

    private Delay(Producer task) {
      this.task = checkNonNull(task);
    }

    @Override
    public String toString() {
      return "Delay(?)";
    }
  }

  /**
   * Program node completed by an external callback. The callback receives the completion consumer
   * and returns the program that cancels the step.
   */
  final class Async implements IO {

    private final Function1>, IO> callback;

    private Async(Function1>, IO> callback) {
      this.callback = checkNonNull(callback);
    }

    @Override
    public String toString() {
      return "Async(?)";
    }
  }

  /** Program node whose actual program is produced by {@code lazy} each time it is evaluated. */
  final class Suspend implements IO {

    private final Producer> lazy;

    private Suspend(Producer> lazy) {
      this.lazy = checkNonNull(lazy);
    }

    @Override
    public String toString() {
      return "Suspend(?)";
    }
  }

  /**
   * Program node that evaluates {@code current}. Errors at which {@code mapper} is defined are
   * handled by switching to the program {@code mapper} returns.
   */
  final class Recover implements IO {

    private final IO current;
    private final PartialFunction1> mapper;

    private Recover(IO current, PartialFunction1> mapper) {
      this.current = checkNonNull(current);
      this.mapper = checkNonNull(mapper);
    }

    @Override
    public String toString() {
      return "Recover(" + current + ", ?)";
    }
  }
}

/**
 * Cancellation handle threaded through the IO run loop.
 *
 * <p>While a program runs, the loop uses {@link #setCancelToken(IO)} to store the cancel token of
 * the asynchronous step in flight; {@link #cancel()} later runs that token. {@link #UNCANCELLABLE}
 * ignores all of this. {@link #cancellable()} creates an independent handle whose progress is
 * tracked in a {@link StateIO}.
 */
sealed interface IOConnection {

  /** Shared connection that never cancels anything; every operation is a no-op. */
  IOConnection UNCANCELLABLE = new Uncancellable();

  /** Returns {@code true} if this connection takes part in cancellation. */
  boolean isCancellable();

  /** Records the token that cancels the asynchronous step currently running on this connection. */
  void setCancelToken(IO<Unit> cancel);

  /** Runs the current cancel token immediately, without consulting or updating the state. */
  void cancelNow();

  /** Requests cancellation of this connection. */
  void cancel();

  /** Atomically applies {@code update} to the state and returns the resulting state. */
  StateIO updateState(Operator1<StateIO> update);

  /** Creates a fresh cancellable connection in the {@link StateIO#INITIAL} state. */
  static IOConnection cancellable() {
    return new Cancellable();
  }

  /** Connection for programs that cannot be cancelled. */
  final class Uncancellable implements IOConnection {

    private Uncancellable() { }

    @Override
    public boolean isCancellable() {
      return false;
    }

    @Override
    public void setCancelToken(IO<Unit> cancel) {
      // uncancellable
    }

    @Override
    public void cancelNow() {
      // uncancellable
    }

    @Override
    public void cancel() {
      // uncancellable
    }

    /** Ignores {@code update}; an uncancellable connection always reports {@link StateIO#INITIAL}. */
    @Override
    public StateIO updateState(Operator1<StateIO> update) {
      return StateIO.INITIAL;
    }
  }

  /** Connection whose cancellation state is held in an {@link AtomicReference}. */
  final class Cancellable implements IOConnection {

    // Starts as a no-op token: cancel() can be called before any asynchronous step has registered a
    // token (e.g. a forked fiber over a purely synchronous program, or a bracket whose acquire/use
    // never suspend). With a null initial token, cancelNow() would throw NullPointerException and
    // leave the state stuck at "cancelling".
    // volatile: the token is written by the thread running the program (and by promise completion
    // callbacks) but read by whichever thread requests cancellation.
    private volatile IO<Unit> cancelToken = IO.unit();
    private final AtomicReference<StateIO> state = new AtomicReference<>(StateIO.INITIAL);

    private Cancellable() { }

    @Override
    public boolean isCancellable() {
      return true;
    }

    @Override
    public void setCancelToken(IO<Unit> cancel) {
      this.cancelToken = checkNonNull(cancel);
    }

    @Override
    public void cancelNow() {
      cancelToken.runAsync();
    }

    /**
     * Marks a cancellation as in progress. If the previous state allowed an immediate cancellation
     * ({@link StateIO#isCancelable()}), the token is run and the state becomes
     * {@link StateIO#CANCELLED}. Otherwise the "cancelling" flag stays set for the run loop to act on.
     */
    @Override
    public void cancel() {
      if (state.getAndUpdate(StateIO::cancellingNow).isCancelable()) {
        cancelNow();

        state.set(StateIO.CANCELLED);
      }
    }

    @Override
    public StateIO updateState(Operator1<StateIO> update) {
      return state.updateAndGet(update::apply);
    }
  }
}

/**
 * Immutable snapshot of the cancellation state of a cancellable {@code IOConnection}.
 *
 * <p>Three independent flags are tracked:
 * <ul>
 *   <li>whether the connection has been cancelled;</li>
 *   <li>whether a cancellation is in progress;</li>
 *   <li>whether an asynchronous step is currently being started.</li>
 * </ul>
 * Every transition returns a new instance.
 */
final class StateIO {

  /** State of a fresh connection: no flag set. */
  public static final StateIO INITIAL = new StateIO(false, false, false);
  /** State reached once a cancellation has been carried out. */
  public static final StateIO CANCELLED = new StateIO(true, false, false);

  private final boolean cancelled;
  private final boolean cancelling;
  private final boolean starting;

  public StateIO(boolean isCancelled, boolean cancellingNow, boolean startingNow) {
    this.cancelled = isCancelled;
    this.cancelling = cancellingNow;
    this.starting = startingNow;
  }

  public boolean isCancelled() {
    return cancelled;
  }

  public boolean isCancellingNow() {
    return cancelling;
  }

  public boolean isStartingNow() {
    return starting;
  }

  /** Copy of this state with the "cancellation in progress" flag raised. */
  public StateIO cancellingNow() {
    return withFlags(cancelled, true, starting);
  }

  /** Copy of this state with the "starting an asynchronous step" flag raised. */
  public StateIO startingNow() {
    return withFlags(cancelled, cancelling, true);
  }

  /** Copy of this state with the "starting an asynchronous step" flag cleared. */
  public StateIO notStartingNow() {
    return withFlags(cancelled, cancelling, false);
  }

  /** True only when no flag is set, i.e. a cancellation may run its token right away. */
  public boolean isCancelable() {
    return !(cancelled || cancelling || starting);
  }

  /** True while the connection is neither cancelled nor being cancelled. */
  public boolean isRunnable() {
    return !(cancelled || cancelling);
  }

  // single place where transitions build their successor state
  private static StateIO withFlags(boolean cancelled, boolean cancelling, boolean starting) {
    return new StateIO(cancelled, cancelling, starting);
  }
}

/**
 * Recover-handler stack used by the IO run loop ({@code runAsync}) to route an error to the nearest
 * enclosing {@code recoverWith} handler.
 *
 * <p>Handlers live in {@link StackItem} frames linked through {@code prev}. The run loop calls
 * {@link #push()} for every {@code FlatMapped} node it enters, and {@code unwrap} calls {@link #add}
 * for every {@code Recover} node. {@link #tryHandle} is called from the run loop's catch block. It
 * implements {@code Recoverable}, whose {@code sneakyThrow} {@code unwrap} uses to rethrow
 * {@code Failure} errors.
 *
 * <p>Contains no synchronization. The asynchronous continuation in {@code runAsync} reuses the same
 * instance on another callback, so accesses are presumably never concurrent; TODO confirm.
 */
final class CallStack implements Recoverable {
  
  // current (innermost) frame; becomes null once every frame has been discarded
  private StackItem top = new StackItem<>();
  
  /** Counts one more bind on the top frame. */
  public void push() {
    top.push();
  }

  /**
   * Undoes one bind on the top frame. When the top frame has no counted binds, it is dropped and
   * {@code prev} becomes the top (null for the bottom frame).
   */
  public void pop() {
    if (top.count() > 0) {
      top.pop();
    } else {
      top = top.prev();
    }
  }
  
  /**
   * Registers a recover handler. If the top frame has counted binds, one is undone and a new frame is
   * opened on top of it, so the handler is kept apart from handlers registered before those binds.
   */
  public void add(PartialFunction1> mapError) {
    if (top.count() > 0) {
      top.pop();
      top = new StackItem<>(top);
    }
    top.add(mapError);
  }
  
  /**
   * Looks for a handler defined at {@code error}, starting at the top frame. Each visited frame has
   * its bind counter reset and its handlers consumed as they are tried; frames without a match are
   * discarded. Returns the handler's program, or none once every frame is gone. In that case
   * {@code top} is null, so this stack must not be used again.
   */
  public Option> tryHandle(Throwable error) {
    while (top != null) {
      top.reset();
      Option> result = top.tryHandle(error);
      
      if (result.isPresent()) {
        return result;
      } else {
        top = top.prev();
      }
    }
    return Option.none();
  }
}

/**
 * One frame of a {@link CallStack}: a counter of binds entered while this frame was on top, plus
 * the recover handlers registered on it, newest first.
 */
final class StackItem {
  
  // number of push() calls not yet undone by pop()/reset()
  private int count = 0;
  // handlers, newest at the head (addFirst / removeFirst)
  private final Deque>> recover = new ArrayDeque<>();

  // frame below this one; null for the bottom frame
  private final StackItem prev;

  public StackItem() {
    this(null);
  }

  public StackItem(StackItem prev) {
    this.prev = prev;
  }
  
  public StackItem prev() {
    return prev;
  }
  
  public int count() {
    return count;
  }
  
  public void push() {
    count++;
  }
  
  public void pop() {
    count--;
  }
  
  public void reset() {
    count = 0;
  }
  
  /** Registers a handler ahead of those already on this frame. */
  public void add(PartialFunction1> mapError) {
    recover.addFirst(mapError);
  }

  /**
   * Tries handlers newest first, removing each one as it is tried. Returns the program produced by
   * the first handler defined at {@code error}, or none once the frame has no handlers left.
   * NOTE(review): the handler is applied here, so an exception it throws propagates to the caller.
   */
  public Option> tryHandle(Throwable error) {
    while (!recover.isEmpty()) {
      var mapError = recover.removeFirst();
      if (mapError.isDefinedAt(error)) {
        return Option.some(mapError.andThen(IOOf::narrowK).apply(error));
      }
    }
    return Option.none();
  }
}