All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.rouz.task.context.InMemImmediateContext Maven / Gradle / Ivy

package io.rouz.task.context;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;

import io.rouz.task.Task;
import io.rouz.task.TaskContext;
import io.rouz.task.TaskId;
import io.rouz.task.dsl.TaskBuilder.F0;

/**
 * A {@link TaskContext} that evaluates tasks immediately and memoizes results in memory.
 *
 * Memoized results are tied to the instance the evaluated the values.
 *
 * This context is not thread safe.
 */
public class InMemImmediateContext implements TaskContext {

  private Map cache =  new HashMap<>();

  private InMemImmediateContext() {
  }

  public static TaskContext create() {
    return new InMemImmediateContext();
  }

  @Override
  public  Value evaluate(Task task) {
    final TaskId taskId =  task.id();

    final Value value;
    if (has(taskId)) {
      value = get(taskId);
      LOG.debug("Found calculated value for {} = {}", taskId, value);
    } else {
      value = TaskContext.super.evaluate(task);
      put(taskId, value);
    }

    return value;
  }

  @Override
  public  Value value(F0 value) {
    return new DirectValue<>(value.get());
  }

  @Override
  public  Promise promise() {
    return new ValuePromise<>();
  }

  private boolean has(TaskId taskId) {
    return cache.containsKey(taskId);
  }

  private  void put(TaskId taskId, V value) {
    cache.put(taskId, value);
  }

  private  V get(TaskId taskId) {
    //noinspection unchecked
    return (V) cache.get(taskId);
  }

  private final class DirectValue implements Value {

    private final Semaphore setLatch;
    private final List> valueConsumers = new ArrayList<>();
    private final List> failureConsumers = new ArrayList<>();
    private final AtomicReference>> valueReceiver;
    private final AtomicReference>> failureReceiver;

    private DirectValue() {
      valueReceiver = new AtomicReference<>(valueConsumers::add);
      failureReceiver = new AtomicReference<>(failureConsumers::add);
      this.setLatch = new Semaphore(1);
    }

    private DirectValue(T value) {
      valueReceiver = new AtomicReference<>(c -> c.accept(value));
      failureReceiver = new AtomicReference<>(c -> {});
      this.setLatch = new Semaphore(0);
    }

    @Override
    public TaskContext context() {
      return InMemImmediateContext.this;
    }

    @Override
    public  Value flatMap(Function> fn) {
      Promise promise = promise();
      consume(t -> {
        final Value apply = fn.apply(t);
        apply.consume(promise::set);
        apply.onFail(promise::fail);
      });
      onFail(promise::fail);
      return promise.value();
    }

    @Override
    public void consume(Consumer consumer) {
      valueReceiver.get().accept(consumer);
    }

    @Override
    public void onFail(Consumer errorConsumer) {
      failureReceiver.get().accept(errorConsumer);
    }
  }

  private final class ValuePromise implements Promise {

    private final DirectValue value = new DirectValue<>();

    @Override
    public Value value() {
      return value;
    }

    @Override
    public void set(T t) {
      final boolean completed = value.setLatch.tryAcquire();
      if (!completed) {
        throw new IllegalStateException("Promise was already completed");
      } else {
        value.valueReceiver.set(c -> c.accept(t));
        value.valueConsumers.forEach(c -> c.accept(t));
      }
    }

    @Override
    public void fail(Throwable throwable) {
      final boolean completed = value.setLatch.tryAcquire();
      if (!completed) {
        throw new IllegalStateException("Promise was already completed");
      } else {
        value.failureReceiver.set(c -> c.accept(throwable));
        value.failureConsumers.forEach(c -> c.accept(throwable));
      }
    }
  }
}