All downloads are free. Search and download functionality uses the official Maven repository.

com.linkedin.dagli.dag.AbstractDAGExecutor Maven / Gradle / Ivy

Go to download

DAG-oriented machine learning framework for bug-resistant, readable, efficient, maintainable and trivially deployable models in Java and other JVM languages

There is a newer version: 15.0.0-beta9
Show newest version
package com.linkedin.dagli.dag;

import com.linkedin.dagli.annotation.Versioned;
import com.linkedin.dagli.objectio.ObjectReader;
import com.linkedin.dagli.preparer.PreparerResult;
import com.linkedin.dagli.util.cloneable.AbstractCloneable;


/**
 * DAG executors prepare and apply DAGs.  Note that certain DAG executors (e.g. {@link FastPreparedDAGExecutor}) may
 * not support DAG preparation (training), only application (inference).
 *
 * <p>This base class uses a template-method layout.  The package-private {@code final} entry points are
 * {@code prepareAndApplyUnsafe}, {@code applyUnsafe} and {@code prepareUnsafe}.  They delegate to the
 * {@code protected} {@code *Impl} hooks that concrete executors supply.  {@code prepareAndApplyUnsafe} and
 * {@code prepareUnsafe} also copy the original DAG's properties onto the prepared DAGs they return.  Concrete
 * executors must also define {@link #hashCode()} and {@link #equals(Object)}, which are redeclared abstract here.
 *
 * @param <S> the type of the derived DAGExecutor (the concrete subclass itself); {@link #internalAPI()} returns
 *            {@code this} cast to this type
 */
@Versioned
abstract class AbstractDAGExecutor> extends AbstractCloneable
    implements PreparedDAGExecutor {
  private static final long serialVersionUID = 1L;

  /**
   * Returns this executor cast to its derived type {@code S}.  The cast is unchecked and is only correct when the
   * concrete subclass supplies itself as {@code S}.
   *
   * @return {@code this}, typed as {@code S}
   */
  @Override
  @SuppressWarnings("unchecked") // S is the derived type of this base class
  public S internalAPI() {
    return (S) this;
  }

  // hashCode() and equals(Object) are redeclared abstract so that every concrete executor must supply its own
  // implementations rather than inheriting them.
  @Override
  public abstract int hashCode(); // force subclasses to override

  @Override
  public abstract boolean equals(Object obj); // force subclasses to override

  /**
   * Subclass hook that prepares {@code dag} on the given inputs and also applies the prepared DAG to those inputs.
   * It is invoked by {@code prepareAndApplyUnsafe}, which then re-wraps the returned {@link PreparerResult} so that
   * the prepared DAGs carry {@code dag}'s properties.  The default {@code prepareUnsafeImpl} also calls it.
   *
   * @param dag the DAG to prepare and apply
   * @param inputValueLists an array of ObjectReaders, one for each of the DAG's placeholders
   * @return a {@link DAGExecutionResult} holding the preparer result and the outputs of applying the prepared DAG to
   *         {@code inputValueLists}
   */
  protected abstract , T extends PreparableDAGTransformer>
  DAGExecutionResult prepareAndApplyUnsafeImpl(T dag, ObjectReader[] inputValueLists);

  /**
   * Subclass hook that prepares {@code dag} without being obliged to produce its outputs.
   *
   * <p>This default implementation calls {@code prepareAndApplyUnsafeImpl} and closes the resulting
   * {@link DAGExecutionResult} (via try-with-resources), discarding its outputs.  It returns only the preparer
   * result.  Because the method is non-final, executors may override it, e.g. to avoid computing outputs that would be
   * thrown away.
   *
   * @param dag the DAG to prepare
   * @param inputValueLists an array of ObjectReaders, one for each of the DAG's placeholders
   * @return the PreparerResult produced by preparation
   */
  protected , T extends PreparableDAGTransformer>
  PreparerResult prepareUnsafeImpl(T dag, ObjectReader[] inputValueLists) {
    try (DAGExecutionResult res = prepareAndApplyUnsafeImpl(dag, inputValueLists)) {
      return res.getPreparerResult();
    }
  }

  /**
   * Subclass hook that applies a prepared DAG to the inputs.  {@code applyUnsafe} calls it and returns its result
   * unchanged.
   *
   * @param dag the prepared DAG to apply
   * @param inputValueLists an array of ObjectReaders, one for each of the DAG's placeholders
   * @return an array of ObjectReaders, one for each of the DAG's outputs
   */
  protected abstract > ObjectReader[] applyUnsafeImpl(T dag,
      ObjectReader[] inputValueLists);

  // Gets a new PreparerResult containing prepared DAGs with the same properties as the DAG they were prepared from.
  // Each prepared DAG held by `result` is mapped (via PreparerResult.map) to
  // prepared.internalAPI().withSameProperties(dag).
  private static , T extends PreparableDAGTransformer>
  PreparerResult mapPreparerResult(T dag, PreparerResult result) {
    return result.map(prepared -> prepared.internalAPI().withSameProperties(dag));
  }

  /**
   * Prepares the DAG and applies the prepared DAG to the input values.
   *
   * <p>The subclass hook's outputs are passed through as-is.  Its prepared DAGs are given {@code dag}'s properties.
   *
   * @param dag the DAG to prepare and apply
   * @param inputValueLists an array of ObjectReaders, one for each of the DAG's placeholders
   * @param  the type of result produced by the DAG
   * @return a {@link DAGExecutionResult} containing both the prepared DAG and the results of applying that DAG on the
   *         inputValueLists
   */
  final , T extends PreparableDAGTransformer>
  DAGExecutionResult prepareAndApplyUnsafe(T dag, ObjectReader[] inputValueLists) {
    DAGExecutionResult result = prepareAndApplyUnsafeImpl(dag, inputValueLists);
    // Re-wrap: keep the executor's outputs but replace the preparer result with one whose DAGs carry dag's properties.
    // NOTE(review): if mapPreparerResult throws, `result` is never closed.  (prepareUnsafeImpl closes the same kind
    // of object via try-with-resources.)  Consider closing it on failure.
    return new DAGExecutionResult<>(mapPreparerResult(dag, result.getPreparerResult()), result.getOutputs());
  }

  /**
   * Applies a prepared DAG to the input values, returning an array of ObjectReaders, one for each of the outputs of the
   * DAG.  This is a direct pass-through to {@code applyUnsafeImpl}.
   *
   * @param dag the DAG to apply
   * @param inputValueLists an array of ObjectReaders, one for each of the DAG's placeholders
   * @param  the type of result returned by the DAG
   * @return an array of ObjectReaders containing the outputs, one for each of the DAG's outputs
   */
  final > ObjectReader[] applyUnsafe(T dag,
      ObjectReader[] inputValueLists) {
    return applyUnsafeImpl(dag, inputValueLists);
  }

  /**
   * Prepares the DAG and returns the prepared DAG.  Unlike prepareAndApplyUnsafe(...), the executor is not obliged
   * to return the results of applying the prepared DAG to the inputs, which speeds execution.  Use this method when
   * only the prepared model (and not the application of the model to the inputs) is required.
   *
   * <p>If the DAG's structure is already marked as prepared, no preparation is performed and
   * {@code inputValueLists} is not read.  Instead, a prepared DAG is built directly from the structure.
   *
   * @param dag the DAG to prepare
   * @param inputValueLists an array of ObjectReaders, one for each of the DAG's placeholders
   * @param  the type of result returned by the DAG
   * @return the PreparerResult containing the prepared DAGs for the training data and for new data
   */
  @SuppressWarnings("unchecked") // preparing to type N is guaranteed by the API spec
  final , T extends PreparableDAGTransformer>
  PreparerResult prepareUnsafe(T dag, ObjectReader[] inputValueLists) {
    // Short-circuit for an already-prepared structure: build the prepared DAG from it and copy dag's properties.
    // NOTE(review): this uses the single-argument PreparerResult constructor; presumably the same DAG then serves
    // both the training data and new data.  Confirm against PreparerResult.
    if (dag.internalAPI().getDAGStructure()._isPrepared) {
      return new PreparerResult<>(
          (N) DAGMakerUtil.makePreparedDAGTransformer(dag.internalAPI().getDAGStructure())
              .internalAPI()
              .withSameProperties(dag));
    }
    // Otherwise delegate to the (possibly overridden) preparation hook, then copy dag's properties onto the results.
    return mapPreparerResult(dag, prepareUnsafeImpl(dag, inputValueLists));
  }
}