All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.linkedin.dagli.reducer.AssociativeClassReducer Maven / Gradle / Ivy

Go to download

DAG-oriented machine learning framework for bug-resistant, readable, efficient, maintainable and trivially deployable models in Java and other JVM languages

There is a newer version: 15.0.0-beta9
Show newest version
package com.linkedin.dagli.reducer;

import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.transformer.TransformerVariadic;
import com.linkedin.dagli.transformer.TransformerWithInputBound;
import it.unimi.dsi.fastutil.objects.ObjectArraySet;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;


/**
 * Reduces the case where a child's inputs are associative relative to one or more of its parents, allowing the parents'
 * inputs to be directly accepted by the child (removing the parent) while still producing an equivalent result.  The
 * child must be variadic and the parent must be either variadic or unary.
 *
 * For example, {@code DensifiedVector} is associative relative to a {@code CompositeSparseVector}: creating a composite
 * sparse vector from a set of source vectors and then densifying it yields a result equivalent to just densifying the
 * source vectors directly.
 *
 * Careless use of this reducer could potentially make the DAG more expensive to execute, by removing an intermediate
 * parent whose result might be used elsewhere in the graph.  This reducer is best used only in situations where
 * consuming the parent's inputs directly is not substantively more expensive than consuming the parent's result (since
 * that parent might still be present in the reduced graph if it has other children).
 *
 * This class identifies associative parents by their class alone; consequently, it is not suitable if associativity is
 * conditional on the properties of the parent or the child.
 */
public class AssociativeClassReducer implements Reducer> {
  private final Set>> _parentClasses;

  @Override
  public Level getLevel() {
    return Level.ESSENTIAL; // high value relative to the cost
  }

  /**
   * Creates a new reducer that will check if the target transformer's parent is one of the provided classes.  If it is,
   * the the target transformer will disintermediate the parent and accept its inputs directly (if possible).
   *
   * @param parentClasses the set of classes of parents that are associative relative to the target transformer
   */
  @SafeVarargs
  public AssociativeClassReducer(Class>... parentClasses) {
    _parentClasses = new ObjectArraySet<>(parentClasses);
  }

  @Override
  @SuppressWarnings({"unchecked", "rawtypes"})
  public void reduce(TransformerVariadic target, Context context) {
    List> parents = context.getParents(target);
    if (parents.stream().anyMatch(parent -> _parentClasses.contains(parent.getClass()))) {
      ArrayList> newParentsList = new ArrayList<>(parents.size());
      for (Producer parent : parents) {
        if (_parentClasses.contains(parent.getClass())) {
          newParentsList.addAll((List>) context.getParents(parent));
        } else {
          newParentsList.add(parent);
        }
      }
      context.tryReplaceUnviewed(target, () -> (Producer) target.withInputs(newParentsList));
    }
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    return _parentClasses.equals(((AssociativeClassReducer) o)._parentClasses);
  }

  @Override
  public int hashCode() {
    return _parentClasses.hashCode();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy