All downloads are FREE. Search and download functionality uses the official Maven repository.

com.linkedin.dagli.tester.PreparedTransformerTestBuilder Maven / Gradle / Ivy

Go to download

DAG-oriented machine learning framework for bug-resistant, readable, efficient, maintainable and trivially deployable models in Java and other JVM languages

There is a newer version: 15.0.0-beta9
Show newest version
package com.linkedin.dagli.tester;

import com.linkedin.dagli.dag.DAGTransformer;
import com.linkedin.dagli.dag.PreparedDAGTransformer;
import com.linkedin.dagli.reducer.Reducer;
import com.linkedin.dagli.transformer.PreparedTransformer;
import com.linkedin.dagli.util.array.ArraysEx;
import com.linkedin.dagli.util.collection.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Collectors;


/**
 * Tests {@link PreparedTransformer} nodes.
 *
 * @param  the type of result of the transformer
 * @param  the type of the transformer
 */
public final class PreparedTransformerTestBuilder>
    extends AbstractTransformerTestBuilder> {

  /**
   * Creates a new instance that will test the provided Dagli node.
   *
   * @param testSubject the primary test subject
   */
  public PreparedTransformerTestBuilder(T testSubject) {
    super(testSubject);
  }

  /**
   * @return a variety of ways the prepared transformer can be invoked on a minibatch of examples
   */
  private List, List>> getMinibatchAppliers() {
    List, List>> standard = Arrays.asList((prepared, inputs) -> {
      R[] results = (R[]) new Object[inputs.size()];
      prepared.internalAPI()
          .applyAllUnsafe(prepared.internalAPI().createExecutionCache(inputs.size()), inputs.size(),
              ArraysEx.transpose(inputs.toArray(new Object[0][])), results);
      return Arrays.asList(results);
    }, (prepared, inputs) -> {
      ArrayList resultList = new ArrayList<>(inputs.size());
      Object[][] transposed = ArraysEx.transpose(inputs.toArray(new Object[0][]));
      prepared.internalAPI()
          .applyAllUnsafe(prepared.internalAPI().createExecutionCache(inputs.size()), inputs.size(),
              Arrays.stream(transposed).map(Arrays::asList).collect(Collectors.toList()), resultList);
      return resultList;
    });

    // make sure handlers correctly deal with over-sized arrays
    List, List>> enlengthened = standard.stream()
        .map(function -> (BiFunction, List>) (prepared, inputs) -> function.apply(prepared,
            inputs.stream().map(arr -> Arrays.copyOf(arr, arr.length + 1)).collect(Collectors.toList())))
        .collect(Collectors.toList());

    return Iterables.concatenate(standard, enlengthened);
  }

  private static Object[][] lift(Object[] array) {
    Object[][] res = new Object[array.length][1];
    for (int i = 0; i < array.length; i++) {
      res[i][0] = array[i];
    }
    return res;
  }

  /**
   * @return a variety of ways the prepared transformer can be invoked on an example
   */
  private List> getAppliers() {
    List> standard = Arrays.asList((prepared, inputs) -> prepared.internalAPI()
            .applyUnsafe(prepared.internalAPI().createExecutionCache(1), inputs),
        (prepared, inputs) -> prepared.internalAPI()
            .applyUnsafe(prepared.internalAPI().createExecutionCache(1), Arrays.asList(inputs)),
        (prepared, inputs) -> prepared.internalAPI()
            .applyUnsafe(prepared.internalAPI().createExecutionCache(1), lift(inputs), 0));

    // make sure handlers correctly deal with over-sized arrays
    List> enlengthened = standard.stream()
        .map(function -> (BiFunction) (prepared, inputs) -> function.apply(prepared,
            Arrays.copyOf(inputs, inputs.length + 1)))
        .collect(Collectors.toList());

    return Iterables.concatenate(standard, enlengthened);
  }

  @SuppressWarnings("unchecked")
  private void simpleReductionTest() {
    if (_skipSimpleReductionTest) {
      return;
    }

    // test a reduced graph
    PreparedDAGTransformer dag =
        (PreparedDAGTransformer) DAGTransformer.withOutput(_testSubject).withReduction(Reducer.Level.EXPENSIVE);

    Tester.of(dag)
        .skipNonTrivialEqualityCheck()
        .skipValidation(_skipValidation)
        .skipSimpleReductionTest() // avoids infinite recursion!
        .allInputs(_inputs)
        .allOutputTests(_outputsTesters)
        .skipNonTrivialEqualityCheck()
        .distinctOutputs(_distinctOutputs)
        .test();
  }

  @Override
  public void test() {
    super.test();
    for (BiFunction applier : getAppliers()) {
      checkInputsAndOutputsForAll(applier);
      checkInputsAndOutputsFor(withPlaceholderInputs(_testSubject), applier);
    }
    for (BiFunction, List> minibatchApplier : getMinibatchAppliers()) {
      checkMinibatchedInputsAndOutputsForAll(minibatchApplier);
      checkMinibatchedInputsAndOutputsFor(withPlaceholderInputs(_testSubject), minibatchApplier);
    }

    if (_distinctOutputs) {
      HashSet resultSet = new HashSet<>();
      Object executionCache = _testSubject.internalAPI().createExecutionCache(_inputs.size());
      for (Object[] input : _inputs) {
        R result = _testSubject.internalAPI().applyUnsafe(executionCache, input);
        if (!resultSet.add(result)) {
          throw new AssertionError(
              "The prepared transformer " + _testSubject + " produced the result " + result + " for the input sequence "
                  + Arrays.toString(input)
                  + ", which is equals() to a result prepared for another tested input.  This is an error "
                  + "because this test was configured with distinctOutputs().");
        }
      }
    }

    simpleReductionTest();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy