All downloads are free. Search and download functionality uses the official Maven repository.

com.linkedin.dagli.dag.LocalDAGExecutor Maven / Gradle / Ivy

Go to download

DAG-oriented machine learning framework for bug-resistant, readable, efficient, maintainable, and trivially deployable models in Java and other JVM languages.

There is a newer version: 15.0.0-beta9
Show newest version
package com.linkedin.dagli.dag;

import com.linkedin.dagli.objectio.ObjectReader;
import com.linkedin.dagli.preparer.PreparerResult;
import java.util.Objects;


/**
 * {@link LocalDAGExecutor} is a user-friendly way to execute a DAG on the local machine. It combines two executors
 * behind the scenes:
 * <ul>
 *   <li>a {@link MultithreadedDAGExecutor}, which handles all work that involves preparing a DAG
 *       ({@code prepareAndApplyUnsafeImpl} and {@code prepareUnsafeImpl}), and</li>
 *   <li>a {@link FastPreparedDAGExecutor}, which handles application of already-prepared DAGs
 *       ({@code applyUnsafeImpl}).</li>
 * </ul>
 *
 * <p>Both delegates are held in {@code final} fields. Every {@code with...} configuration method returns a new
 * {@link LocalDAGExecutor} instead of modifying this one.
 */
public class LocalDAGExecutor extends AbstractDAGExecutor implements DAGExecutor {
  /**
   * Default thread cap applied to both delegates: twice {@link Runtime#availableProcessors()}.
   * Because it is a static constant, it is evaluated once when this class is initialized.
   */
  private static final int DEFAULT_THREAD_COUNT = 2 * Runtime.getRuntime().availableProcessors();
  /** Value passed to {@code FastPreparedDAGExecutor.withMinInputsPerThread(...)} by the no-arg constructor. */
  private static final int DEFAULT_MIN_INPUTS_PER_THREAD = 128;

  private static final long serialVersionUID = 1L;

  // Delegate for every operation that prepares a DAG (see prepareAndApplyUnsafeImpl / prepareUnsafeImpl).
  private final MultithreadedDAGExecutor _multithreadedDAGExecutor;
  // Delegate for applying already-prepared DAGs (see applyUnsafeImpl).
  private final FastPreparedDAGExecutor _fastPreparedDAGExecutor;

  /**
   * Sets the maximum number of threads that will be used by the DAG executor. The limit is forwarded to both
   * delegates: the {@link MultithreadedDAGExecutor} used for preparation and the {@link FastPreparedDAGExecutor}
   * used for applying prepared DAGs.
   *
   * <p>The default value is twice {@link Runtime#availableProcessors()}, as measured when this class was initialized.
   *
   * @param maxThreads the maximum number of threads used by the executor
   * @return a copy of this instance that uses the specified number of threads
   */
  public LocalDAGExecutor withMaxThreads(int maxThreads) {
    return new LocalDAGExecutor(_multithreadedDAGExecutor.withMaxThreads(maxThreads),
        _fastPreparedDAGExecutor.withMaxThreads(maxThreads));
  }

  /**
   * Sets the batch size of the underlying {@link MultithreadedDAGExecutor}. That executor is used only for
   * operations that prepare a DAG, so applying prepared DAGs (via the {@link FastPreparedDAGExecutor}) is unaffected.
   *
   * <p>The no-arg constructor uses {@code MultithreadedDAGExecutor.DEFAULT_BATCH_SIZE}.
   *
   * @param batchSize the batch size forwarded to {@code MultithreadedDAGExecutor.withBatchSize(int)}
   * @return a new instance with the reconfigured multithreaded executor and the same fast prepared executor
   */
  public LocalDAGExecutor withBatchSize(int batchSize) {
    return new LocalDAGExecutor(_multithreadedDAGExecutor.withBatchSize(batchSize), _fastPreparedDAGExecutor);
  }

  /**
   * Sets the maximum number of concurrent batches of the underlying {@link MultithreadedDAGExecutor}. That executor
   * is used only for operations that prepare a DAG; applying prepared DAGs is unaffected.
   *
   * <p>The no-arg constructor uses {@code MultithreadedDAGExecutor.DEFAULT_MAX_CONCURRENT_BATCHES}.
   *
   * @param maxConcurrentBatches the value forwarded to {@code MultithreadedDAGExecutor.withConcurrentBatches(int)}
   * @return a new instance with the reconfigured multithreaded executor and the same fast prepared executor
   */
  public LocalDAGExecutor withConcurrentBatches(int maxConcurrentBatches) {
    return new LocalDAGExecutor(_multithreadedDAGExecutor.withConcurrentBatches(maxConcurrentBatches),
        _fastPreparedDAGExecutor);
  }

  /**
   * Sets the {@link LocalStorage} of the underlying {@link MultithreadedDAGExecutor}. Only that executor, which is
   * used for preparation, receives the storage; the {@link FastPreparedDAGExecutor} is passed through unchanged.
   *
   * <p>NOTE(review): the storage presumably controls where intermediate data is kept during preparation; confirm
   * against {@code MultithreadedDAGExecutor.withStorage}.
   *
   * @param storage the storage forwarded to {@code MultithreadedDAGExecutor.withStorage(LocalStorage)}
   * @return a new instance with the reconfigured multithreaded executor and the same fast prepared executor
   */
  public LocalDAGExecutor withStorage(LocalStorage storage) {
    return new LocalDAGExecutor(_multithreadedDAGExecutor.withStorage(storage), _fastPreparedDAGExecutor);
  }

  /**
   * Creates an executor with default settings.
   *
   * <p>The {@link MultithreadedDAGExecutor} is configured with:
   * <ul>
   *   <li>batch size {@code MultithreadedDAGExecutor.DEFAULT_BATCH_SIZE};</li>
   *   <li>concurrent batches {@code MultithreadedDAGExecutor.DEFAULT_MAX_CONCURRENT_BATCHES};</li>
   *   <li>at most {@code 2 * Runtime.getRuntime().availableProcessors()} threads.</li>
   * </ul>
   *
   * <p>The {@link FastPreparedDAGExecutor} is configured with:
   * <ul>
   *   <li>a minimum of 128 inputs per thread;</li>
   *   <li>the same thread cap as above.</li>
   * </ul>
   */
  public LocalDAGExecutor() {
    this(new MultithreadedDAGExecutor()
            .withBatchSize(MultithreadedDAGExecutor.DEFAULT_BATCH_SIZE)
            .withConcurrentBatches(MultithreadedDAGExecutor.DEFAULT_MAX_CONCURRENT_BATCHES)
            .withMaxThreads(DEFAULT_THREAD_COUNT),
         new FastPreparedDAGExecutor()
             .withMinInputsPerThread(DEFAULT_MIN_INPUTS_PER_THREAD)
            .withMaxThreads(DEFAULT_THREAD_COUNT));
  }

  /**
   * Creates an executor from explicitly configured delegates. Used by the no-arg constructor and by every
   * {@code with...} method.
   *
   * @param mtde the executor to which all DAG preparation is delegated
   * @param fpde the executor to which application of prepared DAGs is delegated
   * @throws NullPointerException if either argument is null
   */
  private LocalDAGExecutor(MultithreadedDAGExecutor mtde, FastPreparedDAGExecutor fpde) {
    _multithreadedDAGExecutor = Objects.requireNonNull(mtde);
    _fastPreparedDAGExecutor = Objects.requireNonNull(fpde);
  }

  /**
   * Prepares the DAG and applies it to the inputs. Delegated entirely to the {@link MultithreadedDAGExecutor}.
   *
   * <p>NOTE(review): the generic type-parameter lists of this override and of {@code prepareUnsafeImpl} /
   * {@code applyUnsafeImpl} appear to have been stripped from this copy of the file (note the dangling
   * {@code protected , T extends ...}). Restore them from the overridden declarations in
   * {@code AbstractDAGExecutor} before compiling.
   */
  @Override
  protected , T extends PreparableDAGTransformer> DAGExecutionResult prepareAndApplyUnsafeImpl(
      T dag, ObjectReader[] inputValueLists) {
    return _multithreadedDAGExecutor.prepareAndApplyUnsafeImpl(dag, inputValueLists);
  }

  /**
   * Prepares the DAG without separately applying it. Delegated entirely to the {@link MultithreadedDAGExecutor}.
   *
   * <p>NOTE(review): the generic type-parameter list of this signature appears to have been stripped; restore it
   * from the overridden declaration in {@code AbstractDAGExecutor}.
   */
  @Override
  protected , T extends PreparableDAGTransformer> PreparerResult
  prepareUnsafeImpl(T dag, ObjectReader[] inputValueLists) {
    return _multithreadedDAGExecutor.prepareUnsafeImpl(dag, inputValueLists);
  }

  /**
   * Applies an already-prepared DAG to the inputs. Delegated entirely to the {@link FastPreparedDAGExecutor}.
   *
   * <p>NOTE(review): the generic type-parameter list of this signature appears to have been stripped; restore it
   * from the overridden declaration in {@code AbstractDAGExecutor}.
   */
  @Override
  protected > ObjectReader[] applyUnsafeImpl(T dag,
      ObjectReader[] inputValueLists) {
    return _fastPreparedDAGExecutor.applyUnsafeImpl(dag, inputValueLists);
  }

  /**
   * Two {@link LocalDAGExecutor}s are equal when they have exactly the same runtime class and both of their
   * delegate executors are equal.
   *
   * @param o the object to compare against
   * @return true if {@code o} is an equivalent {@link LocalDAGExecutor}
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    // getClass() comparison (rather than instanceof) means subclasses never compare equal to this class.
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    LocalDAGExecutor that = (LocalDAGExecutor) o;
    return _multithreadedDAGExecutor.equals(that._multithreadedDAGExecutor) && _fastPreparedDAGExecutor.equals(
        that._fastPreparedDAGExecutor);
  }

  /**
   * Hashes the same two delegates that {@link #equals(Object)} compares, keeping the two methods consistent.
   *
   * @return a hash combining both delegate executors
   */
  @Override
  public int hashCode() {
    return Objects.hash(_multithreadedDAGExecutor, _fastPreparedDAGExecutor);
  }
}