com.linkedin.dagli.dag.DAGStructure Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of core Show documentation
Show all versions of core Show documentation
DAG-oriented machine learning framework for bug-resistant, readable, efficient, maintainable and trivially deployable models in Java and other JVM languages
package com.linkedin.dagli.dag;
import com.linkedin.dagli.annotation.equality.ValueEquality;
import com.linkedin.dagli.generator.Constant;
import com.linkedin.dagli.generator.Generator;
import com.linkedin.dagli.placeholder.Placeholder;
import com.linkedin.dagli.producer.ChildProducer;
import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.transformer.AbstractPreparedTransformerDynamic;
import com.linkedin.dagli.transformer.PreparableTransformer;
import com.linkedin.dagli.transformer.PreparedTransformer;
import com.linkedin.dagli.transformer.Transformer;
import com.linkedin.dagli.tuple.Tuple;
import com.linkedin.dagli.util.collection.LinkedStack;
import com.linkedin.dagli.view.TransformerView;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Class used internally by Dagli to store the structure of a DAG. Note that some information is stored redundantly for
* efficient execution.
*
* @param <R> the type of result produced by this DAG. For result arity greater than 1, this will be a tuple. For a
* DAG1x2, for example, this would be a Tuple2.
*/
class DAGStructure implements Serializable, Graph> {
// serialization version of this class
private static final long serialVersionUID = 1;
// used to indicate that a node is not present/missing when storing a node index
private static final int MISSING_NODE_INDEX = -1;
// the placeholders of the DAG, as passed to the constructor; the size of this list is the DAG's input arity
final List> _placeholders;
// the outputs of the DAG, as passed to the constructor
final List> _outputs;
// maps each node to its children; the list of children will contain duplicates if a child has more than one input
// from a single parent
final HashMap, ArrayList>> _childrenMap;
// the generators among the keys of the children map; easily derived from childrenMap, but we can save time by
// storing this explicitly:
final List> _generators;
// nodes in DAG order (no node appears before a parent) AND phase order (no node occurs before any node of a
// previous phase). Placeholders are always first, followed by generators, followed by transformers and views. In
// any phase >= 1, nodes are always ordered within the phase as:
// (1) preparable transformers
// (2) views
// (3) prepared transformers
final Producer>[] _nodes;
// maps each node to its index in _nodes; the constructor sets the map's default return value to MISSING_NODE_INDEX
private final Object2IntOpenHashMap> _nodeIndexMap;
// Phases of the nodes (parallel to _nodes). These will be in monotonically increasing order, starting at 0.
final int[] _phases;
// All parents for a given node (as indices into _nodes), corresponding to its input list. Will contain duplicates
// if a node has the same parent as more than one input
final int[][] _parents;
// Children (as indices into _nodes), sorted in increasing order of index. Will contain duplicates for children with
// multiple inputs from the same parent.
final int[][] _children;
// Indices of the outputs in the node list.
final int[] _outputIndices;
// True iff all transformers in the DAG are prepared.
final boolean _isPrepared;
// The maximum minibatch size of all prepared transformers in the graph; if the graph contains no prepared
// transformers, the value will be 1.
final int _maxMinibatchSize;
// The maximum number of parents possessed by any node in the graph; if the graph contains no child nodes, the value
// will be 0.
final int _maxParentCount;
// A DAG is always constant if all of its outputs are constant (this would be unusual as it implies the placeholder
// values are ignored entirely, but possible)
final boolean _isAlwaysConstant;
// A preparable DAG will have an idempotent preparer iff all its preparable transformers have idempotent preparers.
// For prepared DAGs, this flag will be trivially true.
final boolean _hasIdempotentPreparer;
// a (modified) copy of the DAG that is used exclusively for equality checking (and hashing)
final EqualityLeaf _equalityDAG;
/**
* Creates a new instance from a {@link DeduplicatedDAG}.
*
* @param dag the deduplicated DAG whose placeholders, outputs and children map are used to build this structure
*/
DAGStructure(DeduplicatedDAG dag) {
// delegate to the main constructor, passing along the three fields read from the deduplicated DAG
this(dag._placeholders, dag._outputs, dag._childrenMap);
}
/**
* Creates a new instance. Every node (key of the children map) is validated, assigned an index and a phase, and the
* derived information stored by this class (parent and children index arrays, output indices and the summary flags)
* is computed.
*
* @param placeholders the placeholders of the DAG
* @param outputs the outputs of the DAG
* @param childrenMap the map of children from nodes to their children
*/
DAGStructure(List> placeholders, List> outputs,
HashMap, ArrayList>> childrenMap) {
// validate every producer present in the graph (each key of the children map)
childrenMap.keySet().forEach(Producer::validate);
_placeholders = placeholders;
_outputs = outputs;
_childrenMap = childrenMap;
// the generators are those keys of the children map that are Generator instances
_generators =
(List) childrenMap.keySet().stream().filter(p -> p instanceof Generator>).collect(Collectors.toList());
// one slot per key of the children map
_nodes = new Producer[_childrenMap.size()];
_nodeIndexMap = new Object2IntOpenHashMap<>(_nodes.length);
// looking up a node that has not been assigned an index yields MISSING_NODE_INDEX
_nodeIndexMap.defaultReturnValue(MISSING_NODE_INDEX);
_phases = new int[_nodes.length];
_parents = new int[_nodes.length][];
_children = new int[_nodes.length][];
// queues (one per kind of child node) of nodes whose dependencies have all been added but which have not yet been
// added themselves; addNode() fills these as dependencies become satisfied
LinkedList> preparableQueue = new LinkedList<>();
LinkedList> preparedQueue = new LinkedList<>();
LinkedList> viewQueue = new LinkedList<>();
// get a map of child producers to a set of their unsatisfied dependencies
IdentityHashMap, Set>> unsatisfiedDependencies =
DAGUtil.producerToInputSetMap(_childrenMap.keySet());
// the roots come first and all belong to phase 0: placeholders, then generators
for (Producer> root : _placeholders) {
addNode(root, 0, unsatisfiedDependencies, preparableQueue, preparedQueue, viewQueue);
}
for (Producer> root : _generators) {
addNode(root, 0, unsatisfiedDependencies, preparableQueue, preparedQueue, viewQueue);
}
int phase = 0;
// keep going until every node has been assigned an index
// NOTE(review): this loop only terminates if every node eventually becomes ready; presumably this is guaranteed by
// the graph being a well-formed DAG -- confirm against the callers
while (_nodeIndexMap.size() < _nodes.length) {
// add as many non-preparable nodes as possible; those nodes that are added will be those whose dependencies are
// satisfied in this or previous phases
// note that views always have a single preparable dependency; adding prepared transformers in the next loop will
// *not* possibly allow more views to be added to this phase (which would otherwise create a bug)
while (!viewQueue.isEmpty()) {
addNode(viewQueue.remove(), phase, unsatisfiedDependencies, preparableQueue, preparedQueue, viewQueue);
}
while (!preparedQueue.isEmpty()) {
addNode(preparedQueue.remove(), phase, unsatisfiedDependencies, preparableQueue, preparedQueue, viewQueue);
}
// the preparables that became ready so far open the next phase; a fresh queue collects any preparables that
// become ready while these are being added
phase++;
LinkedList> phasePreparables = preparableQueue;
preparableQueue = new LinkedList<>();
for (PreparableTransformer, ?> preparable : phasePreparables) {
addNode(preparable, phase, unsatisfiedDependencies, preparableQueue, preparedQueue, viewQueue);
}
}
// fill in children info
for (int i = 0; i < _nodes.length; i++) {
ArrayList> children = _childrenMap.get(_nodes[i]);
int[] childrenArray = new int[children.size()];
for (int j = 0; j < children.size(); j++) {
childrenArray[j] = getNodeIndex(children.get(j));
}
// children are stored in increasing order of index
Arrays.sort(childrenArray);
_children[i] = childrenArray;
}
// store output indices to avoid having to do a lookup later
_outputIndices = new int[_outputs.size()];
for (int i = 0; i < _outputIndices.length; i++) {
_outputIndices[i] = _nodeIndexMap.getInt(_outputs.get(i));
}
// the DAG is prepared iff every non-root node (roots occupy the first indices) is a prepared transformer
boolean isPrepared = true;
for (int i = _placeholders.size() + _generators.size(); i < _nodes.length; i++) {
if (!(_nodes[i] instanceof PreparedTransformer)) {
isPrepared = false;
break;
}
}
_isPrepared = isPrepared;
// a DAG transformer is constant-result if all its outputs have a constant result
_isAlwaysConstant = _outputs.stream().allMatch(Producer::hasConstantResult);
// a DAG is marked idempotent if all its preparable transformers are idempotent
_hasIdempotentPreparer = isPrepared || Arrays.stream(_nodes)
.allMatch(producer -> !(producer instanceof PreparableTransformer)
|| ((PreparableTransformer, ?>) producer).internalAPI().hasIdempotentPreparer());
// largest preferred minibatch size among the prepared transformers, or 1 if there are none
_maxMinibatchSize = Arrays.stream(_nodes)
.filter(node -> node instanceof PreparedTransformer)
.map(node -> (PreparedTransformer>) node)
.mapToInt(prepared -> prepared.internalAPI().getPreferredMinibatchSize())
.max()
.orElse(1);
// largest number of parents of any node, or 0 if there are no nodes
_maxParentCount = Arrays.stream(_parents).mapToInt(arr -> arr.length).max().orElse(0);
_equalityDAG = createEqualityDAG();
}
/**
 * Gets the input arity of the DAG, which is the size of its placeholder list.
 *
 * @return the number of inputs the DAG accepts
 */
int getInputArity() {
  final int placeholderCount = _placeholders.size();
  return placeholderCount;
}
/**
 * Gets the output arity of the DAG, which is the length of the stored array of output node indices.
 *
 * @return the number of outputs the DAG produces
 */
int getOutputArity() {
  final int[] outputIndices = _outputIndices;
  return outputIndices.length;
}
/**
* Looks up the index assigned to a node in this structure's node-to-index map.
*
* @param node the node whose index is sought
* @return the 0-based index of the node as assigned by this DAGStructure; for a node that has no assigned index the
* map's default return value is returned, which the constructor sets to MISSING_NODE_INDEX (-1)
*/
int getNodeIndex(Producer> node) {
return _nodeIndexMap.getInt(node);
}
/**
 * Determines whether the node at the given index is one of the DAG's outputs, i.e. whether its index occurs in the
 * array of output indices.
 *
 * @param nodeIndex the index of the node to check
 * @return true if the node's result is an output of the DAG, false otherwise
 */
boolean isOutput(int nodeIndex) {
  int position = 0;
  while (position < _outputIndices.length) {
    if (_outputIndices[position] == nodeIndex) {
      return true;
    }
    position++;
  }
  return false;
}
/**
 * Determines whether a node is a root of the DAG (a generator or placeholder). The constructor adds all placeholders
 * and then all generators before any other node, so the roots are exactly the nodes with the lowest indices.
 *
 * @param nodeIndex the index of the node to check
 * @return true if the index is below the combined number of placeholders and generators, false otherwise
 */
boolean isRoot(int nodeIndex) {
  final int rootCount = _placeholders.size() + _generators.size();
  return nodeIndex < rootCount;
}
/**
* Adds a node to this structure: the node receives the next free index (the current size of the node index map), its
* phase and the indices of its parents are recorded, and it is then removed from the unsatisfied-dependency sets of
* its children. Any child whose set becomes empty as a result is appended to the queue matching its kind.
*
* Because parent indices are looked up while the node is added, all parents must already have been added.
*
* @param node the node to add
* @param phase the phase to record for the node
* @param unsatisfiedDependencies map from each child producer to the set of its dependencies not yet added; modified
* by this method
* @param preparableQueue queue receiving preparable transformers that have become ready
* @param preparedQueue queue receiving prepared transformers that have become ready
* @param viewQueue queue receiving transformer views that have become ready
* @throws IllegalArgumentException if a child that has become ready is neither a prepared transformer, a preparable
* transformer nor a transformer view
*/
private void addNode(Producer> node, int phase,
IdentityHashMap, Set>> unsatisfiedDependencies,
LinkedList> preparableQueue, LinkedList> preparedQueue,
LinkedList> viewQueue) {
// the next free index equals the number of nodes indexed so far
int index = _nodeIndexMap.size();
_nodes[index] = node;
_nodeIndexMap.put(node, index);
_phases[index] = phase;
if (node instanceof Transformer>) {
// a transformer's parents are its inputs, in input-list order (duplicates are kept)
Transformer> transformer = (Transformer>) node;
List extends Producer>> parentList = transformer.internalAPI().getInputList();
int[] parents = new int[parentList.size()];
for (int i = 0; i < parents.length; i++) {
parents[i] = getNodeIndex(parentList.get(i));
}
_parents[index] = parents;
} else if (node instanceof TransformerView, ?>) {
// a view has exactly one parent: the transformer it views
TransformerView, ?> transformerView = (TransformerView, ?>) node;
int[] parents = new int[] { getNodeIndex(transformerView.internalAPI().getViewed()) };
_parents[index] = parents;
} else {
// anything else is treated as a root, which has no parents and is expected to be in phase 0
assert phase == 0;
_parents[index] = new int[0];
}
// this node is now available: update the bookkeeping of each of its children
for (ChildProducer> child : _childrenMap.get(node)) {
Set> dependencies = unsatisfiedDependencies.get(child);
// Could be empty/added already if this child appears multiple times in children list and we've already seen it:
if (!dependencies.isEmpty()) {
dependencies.remove(node);
// the child becomes ready once its last unsatisfied dependency has been removed
if (dependencies.isEmpty()) {
if (child instanceof PreparedTransformer>) {
preparedQueue.add((PreparedTransformer>) child);
} else if (child instanceof PreparableTransformer, ?>) {
preparableQueue.add((PreparableTransformer, ?>) child);
} else if (child instanceof TransformerView, ?>) {
viewQueue.add((TransformerView, ?>) child);
} else {
throw new IllegalArgumentException("Unknown dependency type");
}
}
}
}
}
/**
 * Gets the highest phase of any node in the DAG. The phase array is kept in increasing order, so this is simply the
 * phase recorded for the final node.
 *
 * @return the highest phase of any node in the DAG
 */
public int getLastPhase() {
  final int lastNodeIndex = _phases.length - 1;
  return _phases[lastNodeIndex];
}
/**
 * Determines whether the phase of the given node equals the last phase of the DAG.
 *
 * @param nodeIndex the node to check
 * @return true iff the node is in the last phase
 */
public boolean isLastPhase(int nodeIndex) {
  final int nodePhase = _phases[nodeIndex];
  return nodePhase == getLastPhase();
}
/**
 * Returns the index of the first node in a given phase.
 *
 * Phase 0 always starts at index 0. For any other phase the first node is located with a lower-bound binary search
 * over the phase array, which relies on that array being sorted in increasing order.
 *
 * @param phase the phase to look for
 * @return the index of the first node with the specified phase
 * @throws IllegalArgumentException if the phase is not 0 and no node belongs to it
 */
public int firstNodeInPhase(int phase) {
  if (phase == 0) {
    return 0;
  }

  // lower-bound search: find the smallest index whose phase is >= the requested phase
  int low = 0;
  int high = _phases.length;
  while (low < high) {
    int mid = (low + high) >>> 1; // unsigned shift avoids overflow of low + high
    if (_phases[mid] < phase) {
      low = mid + 1;
    } else {
      high = mid;
    }
  }

  // the lower bound is only a match if it exists and actually carries the requested phase
  if (low == _phases.length || _phases[low] != phase) {
    throw new IllegalArgumentException("No node in this DAG belongs to phase " + phase);
  }
  return low;
}
/**
 * Returns the index of the first prepared transformer in a given phase (prepared transformers are always last within
 * a phase). If the phase contains no prepared transformers, the result is the index just past the last node of the
 * phase: the index of the first node of the *next* phase or, for the last phase, the total number of nodes.
 *
 * @param phase the phase in which to search
 * @return the index of the first prepared transformer in the phase, or the index of the last node in the phase + 1 if
 *         there are no prepared transformers in the phase
 */
public int firstPreparedTransformerInPhase(int phase) {
  int cursor = firstNodeInPhase(phase);
  for (; cursor < _phases.length; cursor++) {
    // stop at the end of the phase or at the first prepared transformer, whichever comes first
    if (_phases[cursor] != phase || _nodes[cursor] instanceof PreparedTransformer) {
      break;
    }
  }
  return cursor;
}
/**
* Gets the nodes of this graph, which are the keys of the children map. The returned set is the map's own key set
* rather than a copy.
*
* @return the set of all nodes in the DAG
*/
@Override
public Set> nodes() {
return _childrenMap.keySet();
}
/**
* Gets the children of a node as stored in the children map. The list contains a child more than once if that child
* takes more than one input from the given node.
*
* @param vertex the node whose children are sought
* @return the list of children mapped to the node, or null if the node is not a key of the children map
*/
@Override
public List extends ChildProducer>> children(Producer> vertex) {
return _childrenMap.get(vertex);
}
/**
* Gets the parents of a node. For a {@link ChildProducer} this is its input list; any other node has no parents.
*
* @param vertex the node whose parents are sought; asserted to be a key of the children map
* @return the input list of the node if it is a child producer, otherwise an empty list
*/
@Override
public List extends Producer>> parents(Producer> vertex) {
// only checked when assertions are enabled
assert _childrenMap.containsKey(vertex);
return vertex instanceof ChildProducer ? ((ChildProducer>) vertex).internalAPI().getInputList()
: Collections.emptyList();
}
/**
* No-op prepared transformer that is used to help determine the equality and hash codes of DAGStructures. It is
* created with the DAG's outputs as its inputs (see createEqualityDAG()) and its result is always null.
*/
@ValueEquality
private static class EqualityLeaf extends AbstractPreparedTransformerDynamic {
private static final long serialVersionUID = 1L;
/**
* Creates a new leaf with the given inputs.
*
* @param inputs the producers that will serve as the inputs of this transformer
*/
public EqualityLeaf(List extends Producer>> inputs) {
super(inputs);
}
/**
* Ignores the input values and produces no result.
*
* @param values the input values (unused)
* @return always null
*/
@Override
protected Void apply(List values) {
return null;
}
}
/**
* Creates the graph used for equality checking and hashing: an {@link EqualityLeaf} whose inputs are this DAG's
* outputs, in which every placeholder has been replaced by a PositionPlaceholder created from that placeholder's
* position in the placeholder list.
*
* NOTE(review): presumably the replacement makes equality depend on the position of a placeholder rather than on the
* particular placeholder instance -- confirm against PositionPlaceholder and DAGUtil.replaceInputs.
*
* @return the equality leaf with its placeholders replaced
*/
private EqualityLeaf createEqualityDAG() {
EqualityLeaf equalityDAG = new EqualityLeaf(_outputs);
// map each placeholder (by identity) to a stand-in constructed from its position
IdentityHashMap, Producer>> placeholderMap = new IdentityHashMap<>(_placeholders.size());
for (int i = 0; i < _placeholders.size(); i++) {
placeholderMap.put(_placeholders.get(i), new PositionPlaceholder<>(i));
}
return DAGUtil.replaceInputs(equalityDAG, placeholderMap);
}
/**
* Compares this structure with another object for equality.
*
* @param obj the object to compare against
* @return true iff the other object is a DAGStructure with the same number of placeholders and an equal equality DAG
*/
@Override
public boolean equals(Object obj) {
// anything that is not a DAGStructure (including null) is unequal
if (!(obj instanceof DAGStructure)) {
return false;
}
DAGStructure> other = (DAGStructure>) obj;
// A DAGStructure is equal to another if their graphs and numbers of placeholders are equal
return this._placeholders.size() == other._placeholders.size()
&& this._equalityDAG.equals(other._equalityDAG);
}
/**
 * Computes a hash code from the same two components that {@link #equals(Object)} compares: the number of
 * placeholders and the equality DAG.
 *
 * @return the hash code of this structure
 */
@Override
public int hashCode() {
  final int placeholderCount = this._placeholders.size();
  return Objects.hash(placeholderCount, this._equalityDAG);
}
/**
* Creates an array with one slot per node, holding the execution cache created by each non-root node. The slots of
* the roots (placeholders and generators, which occupy the first indices) are left null.
*
* NOTE(review): every non-root node is cast to PreparedTransformer, so this presumably must only be called when all
* non-root nodes are prepared transformers (otherwise the cast fails) -- confirm against the callers.
*
* @param count value passed through to each transformer's createExecutionCache(); presumably the number of examples
* to be processed -- confirm
* @return the array of per-node execution states, indexed like the node list
*/
public Object[] createExecutionStateArray(long count) {
Object[] states = new Object[_nodes.length];
// skip the roots; only transformers create an execution cache
for (int i = _placeholders.size() + _generators.size(); i < states.length; i++) {
states[i] = ((PreparedTransformer>) _nodes[i]).internalAPI().createExecutionCache(count);
}
return states;
}
/**
 * Renders an array of ints as a comma-separated string without spaces, e.g. {@code "1,2,3"}. An empty array yields
 * the empty string.
 *
 * @param vals the values to render
 * @return the decimal representations of the values joined by commas
 */
private static String intSequenceString(int[] vals) {
  StringBuilder joined = new StringBuilder();
  for (int i = 0; i < vals.length; i++) {
    if (i > 0) {
      joined.append(',');
    }
    joined.append(vals[i]);
  }
  return joined.toString();
}
/**
 * Renders the nodes of the DAG as a plain-text table with one row per node, listing the node's index, its name and
 * the comma-separated indices of its children and parents. The table starts with a header row followed by a rule of
 * 85 dashes.
 *
 * @return the table as a multi-line string
 */
public String toProducerTable() {
  final String rowFormat = "%-5s%-35s%-25s%-25s\n";

  // horizontal rule separating the header from the rows
  final char[] rule = new char[85];
  Arrays.fill(rule, '-');

  StringBuilder table = new StringBuilder(String.format(rowFormat, "ID", "Name", "Children", "Parents"));
  table.append(rule).append('\n');

  for (int row = 0; row != _nodes.length; row++) {
    table.append(String.format(rowFormat, row, _nodes[row].getName(), intSequenceString(_children[row]),
        intSequenceString(_parents[row])));
  }
  return table.toString();
}
/**
* Returns a stream of the producers in the DAG as discovered by a breadth-first search starting from the outputs
* (producers with a lower distance to the outputs will be returned first).
*
* The producers are provided as {@link LinkedStack}s, each representing a shortest-path from that producer to one of
* the DAG's outputs (with the top of the stack, accessible via {@link LinkedStack#peek()}, being the producer of
* interest, and the last/bottom element in the stack being an output node). Each producer (and path to that
* producer) will be enumerated only once, even if multiple shortest-paths exist.
*
* {@link Placeholder}s that are disconnected from the outputs will not be included in the returned stream.
*
* @return a stream of {@link LinkedStack}s representing paths to each connected producer in the DAG
*/
public Stream>> producers() {
// the enumeration itself is delegated to Producer.subgraphProducers(), seeded with this DAG's outputs
return Producer.subgraphProducers(_outputs);
}
/**
* Gets the values for all outputs of this DAG that are {@link Constant}s; the values for non-{@link Constant} outputs
* will be {@code null}.
*
* Often a graph can be reduced such that all outputs that can be determined independently of the values provided by
* {@link com.linkedin.dagli.placeholder.Placeholder}s are replaced by their pre-computed values in the form of
* {@link Constant}s, but this is dependent upon the level of reduction applied to the DAG.
*
* Note that, as a {@link Constant} may itself have a null value, it is not possible to determine which outputs are
* {@link Constant} solely from the value returned by this method.
*
* @return the constant output values of this DAG
*/
@SuppressWarnings("unchecked") // R guaranteed to be the right output type or tuple-of-outputs type by DAG semantics
public R getConstantOutput() {
// some compilers will object if we don't cast Producer> to Producer
© 2015 - 2024 Weber Informatics LLC | Privacy Policy