Please wait. This can take a few minutes ...
Many resources are needed to download a project. Please understand that we have to compensate for our server costs. Thank you in advance.
The project price is only $1.
Once you buy this project, you can download and modify it as often as you want.
com.linkedin.dagli.dag.SimpleDAGExecutor Maven / Gradle / Ivy
Go to download
DAG-oriented machine learning framework for bug-resistant, readable, efficient, maintainable and trivially deployable models in Java and other JVM languages
package com.linkedin.dagli.dag;
import com.linkedin.dagli.generator.Constant;
import com.linkedin.dagli.generator.Generator;
import com.linkedin.dagli.objectio.biglist.BigListWriter;
import com.linkedin.dagli.objectio.ConcatenatedReader;
import com.linkedin.dagli.objectio.ConstantReader;
import com.linkedin.dagli.objectio.ObjectIterator;
import com.linkedin.dagli.objectio.ObjectReader;
import com.linkedin.dagli.objectio.ObjectWriter;
import com.linkedin.dagli.placeholder.Placeholder;
import com.linkedin.dagli.preparer.Preparer;
import com.linkedin.dagli.preparer.PreparerContext;
import com.linkedin.dagli.preparer.PreparerResult;
import com.linkedin.dagli.preparer.PreparerResultMixed;
import com.linkedin.dagli.producer.ChildProducer;
import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.transformer.PreparableTransformer;
import com.linkedin.dagli.transformer.PreparedTransformer;
import com.linkedin.dagli.transformer.Transformer;
import com.linkedin.dagli.util.invariant.Arguments;
import com.linkedin.dagli.view.TransformerView;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
/**
* A Simple(r) DAG executor for training and inference on DAGs. The {@link SimpleDAGExecutor} is effectively a
* reference implementation and can be used for debugging and checking correctness relative to other executors.
*
* A consequence of its simplicity is that it's single-threaded and thus relatively slow vs.
* {@link MultithreadedDAGExecutor} when the machine has multiple logical processors. Normally this executor would not
* be used outside of testing.
*/
public final class SimpleDAGExecutor extends AbstractDAGExecutor implements DAGExecutor {
private static final long serialVersionUID = 1L;
@Override
public int hashCode() {
return SimpleDAGExecutor.class.hashCode();
}
@Override
public boolean equals(Object obj) {
return obj instanceof SimpleDAGExecutor;
}
private static ObjectReader generateIterable(long inputSize, Generator> generator) {
if (inputSize == 0) {
return ObjectReader.empty();
}
BigListWriter bbal = new BigListWriter<>(inputSize);
Object[] buffer = new Object[(int) Math.min(4096, inputSize)];
for (long i = 0; i < inputSize; i++) {
int bufferIndex = (int) (i % buffer.length);
buffer[bufferIndex] = generator.generate(i);
if (bufferIndex == buffer.length - 1) {
bbal.write(buffer, 0, buffer.length);
}
}
bbal.write(buffer, 0, (int) (inputSize % buffer.length));
assert bbal.size64() == inputSize;
return bbal.createReader();
}
private PreparedTransformer> transformerWithNewInputs(List extends Producer>> transformerInputs,
PreparedTransformer preparedTransformer, HashMap, Producer>> producerMap) {
//PreparedTransformer preparedTransformer =
// originalPreparedTransformer.internalAPI().withProgenitorHandleUnsafe(originalPreparedTransformer.getHandle());
if (transformerInputs.stream().anyMatch(input -> producerMap.get(input) != input)) {
PreparedTransformer res = preparedTransformer.internalAPI().withInputsUnsafe(
transformerInputs.stream().map(producerMap::get).collect(Collectors.toList()));
LogManager.getLogger()
.trace(() -> "Pre-prepared transformer " + preparedTransformer.toString()
+ " requires new, prepared inputs, became " + res.toString());
return res;
} else {
return preparedTransformer;
}
}
@Override
protected > ObjectReader>[] applyUnsafeImpl(T dag,
ObjectReader[] inputValueLists) {
return prepareAndApply(dag, inputValueLists).getOutputs();
}
@Override
protected , T extends PreparableDAGTransformer> DAGExecutionResult
prepareAndApplyUnsafeImpl(T dag, ObjectReader[] inputValueLists) {
return (DAGExecutionResult) prepareAndApply(dag, inputValueLists);
}
private > DAGExecutionResult
prepareAndApply(T dag, ObjectReader[] inputValueLists) {
HashMap, ObjectReader>> cache = new HashMap<>();
HashMap, Producer>> preparedForNewDataProducerMap = new HashMap<>();
HashMap, Producer>> preparedForPreparationDataProducerMap = new HashMap<>();
long inputSize = inputValueLists[0].size64();
DAGStructure dagStructure = dag.internalAPI().getDAGStructure();
for (int i = 0; i < inputValueLists.length; i++) {
Arguments.check(inputValueLists[i].size64() == inputSize);
cache.put(dagStructure._placeholders.get(i), inputValueLists[i]);
}
// placeholders are intrinsically "prepared"
for (Placeholder> placeholder : dagStructure._placeholders) {
preparedForNewDataProducerMap.put(placeholder, placeholder);
preparedForPreparationDataProducerMap.put(placeholder, placeholder);
}
HashMap, Set>> unsatisfiedDependencies =
new HashMap<>(dagStructure._childrenMap.size());
LinkedList> queue = new LinkedList<>();
for (Producer> producer : dagStructure._childrenMap.keySet()) {
if (producer instanceof Generator>) {
Generator> generator = (Generator>) producer;
// like placeholders, generators are intrinsically "prepared"
preparedForNewDataProducerMap.put(generator, generator);
preparedForPreparationDataProducerMap.put(generator, generator);
// and their values can be generated immediately
cache.put(generator, generateIterable(inputSize, generator));
} else if (producer instanceof ChildProducer>) {
ChildProducer> child = (ChildProducer>) producer;
Set> dependencies = child.internalAPI().getInputList()
.stream()
.filter(p -> p instanceof ChildProducer>)
.map(p -> (ChildProducer>) p)
.collect(Collectors.toSet());
if (dependencies.isEmpty()) {
queue.push(child);
} else {
unsatisfiedDependencies.put(child, dependencies);
}
}
}
while (!queue.isEmpty()) {
ChildProducer> producer = queue.pop();
List extends Producer>> parents = producer.internalAPI().getInputList();
List> args =
parents.stream().map(cache::get).collect(Collectors.toList());
final ObjectReader results;
if (producer instanceof Transformer>) {
final PreparedTransformer> preparedForNewData;
final PreparedTransformer> preparedForPreparationData;
if (producer instanceof PreparedTransformer>) {
PreparedTransformer> preparedTransformer = (PreparedTransformer>) producer;
preparedForNewData = transformerWithNewInputs(parents, preparedTransformer, preparedForNewDataProducerMap);
preparedForPreparationData =
transformerWithNewInputs(parents, preparedTransformer, preparedForPreparationDataProducerMap);
} else if (producer instanceof PreparableTransformer, ?>) {
PreparableTransformer, ?> preparableTransformer = (PreparableTransformer, ?>) producer;
Preparer, ?> transformerPreparer = preparableTransformer.internalAPI()
.getPreparer(PreparerContext.builder(inputSize).setExecutor(this).build());
ObjectIterator>[] iterators = args.stream().map(ObjectReader::iterator).toArray(ObjectIterator[]::new);
for (long i = 0; i < inputSize; i++) {
Object[] objs = new Object[args.size()];
for (int j = 0; j < parents.size(); j++) {
objs[j] = iterators[j].next();
}
transformerPreparer.processUnsafe(objs);
}
PreparerResultMixed extends PreparedTransformer>, ? extends PreparedTransformer>> preparerResult =
transformerPreparer.finishUnsafe(
new ConcatenatedReader<>(Object[]::new, args.toArray(new ObjectReader[0])));
List> preparedInputsForNewData =
parents.stream().map(preparedForNewDataProducerMap::get).collect(Collectors.toList());
List> preparedInputsForPreparationData =
parents.stream().map(preparedForPreparationDataProducerMap::get).collect(Collectors.toList());
preparedForNewData = preparerResult.getPreparedTransformerForNewData()
.internalAPI()
.withInputsUnsafe(preparedInputsForNewData);
preparedForPreparationData = preparerResult.getPreparedTransformerForPreparationData()
.internalAPI()
.withInputsUnsafe(preparedInputsForPreparationData);
if (LogManager.getLogger().getLevel().equals(Level.TRACE)) {
assert (preparedForNewData.internalAPI().getInputList().size() == preparedInputsForNewData.size());
for (int i = 0; i < preparedForNewData.internalAPI().getInputList().size(); i++) {
if (preparedForNewData.internalAPI().getInputList().get(i) != preparedInputsForNewData.get(i)) {
throw new IllegalStateException("Input mismatch while processing transformer " + producer.toString());
}
}
}
} else {
throw new IllegalArgumentException("Unknown transformer type");
}
ObjectWriter resultsAccumulator = new BigListWriter<>(inputSize);
ObjectIterator[] iterators = args.stream().map(ObjectReader::iterator).toArray(ObjectIterator[]::new);
Object executionState = preparedForPreparationData.internalAPI().createExecutionCache(inputSize);
long remaining = inputSize;
while (remaining > 0) {
int batchSize = (int) Math.min(remaining, Integer.MAX_VALUE - 8); // limit to safe-ish max array size
remaining -= batchSize;
// yes, we could reuse arrays across batches, but it is *highly* unlikely this executor will ever be applied
// with more than a single batch worth of examples
Object[][] objs = new Object[iterators.length][batchSize];
for (int j = 0; j < parents.size(); j++) {
iterators[j].next(objs[j], 0, batchSize);
}
Object[] resultArray = new Object[batchSize];
preparedForPreparationData.internalAPI().applyAllUnsafe(executionState, batchSize, objs, resultArray);
resultsAccumulator.writeAll(resultArray);
}
results = resultsAccumulator.createReader();
// check all ancestors for unprepared transformers
if (LogManager.getLogger().getLevel().equals(Level.TRACE)) {
HashSet> seen = new HashSet<>();
LinkedList toCheck = new LinkedList<>();
toCheck.add(preparedForNewData);
seen.add(preparedForNewData);
while (!toCheck.isEmpty()) {
PreparedTransformer next = toCheck.pop();
for (Object parent : next.internalAPI().getInputList()) {
if (parent instanceof PreparableTransformer) {
LogManager.getLogger().error(
"ERROR!: " + preparedForNewData.toString() + " has non-prepared ancestor: " + parent.toString());
} else if (parent instanceof PreparedTransformer) {
if (!seen.contains(parent)) {
toCheck.add((PreparedTransformer) parent);
seen.add((PreparedTransformer) parent);
}
}
}
}
}
preparedForNewDataProducerMap.put(producer, preparedForNewData);
preparedForPreparationDataProducerMap.put(producer, preparedForPreparationData);
} else if (producer instanceof TransformerView, ?>) {
TransformerView view = (TransformerView, ?>) producer;
assert parents.size() == 1;
PreparedTransformer> parentPreparedForNewData =
(PreparedTransformer>) preparedForNewDataProducerMap.get(parents.get(0));
PreparedTransformer> parentPreparedForPreparationData =
(PreparedTransformer>) preparedForPreparationDataProducerMap.get(parents.get(0));
Object valueForNewData = view.internalAPI().prepare(parentPreparedForNewData);
Object valueForPreparationData =
view.internalAPI().prepareForPreparationData(parentPreparedForPreparationData, parentPreparedForNewData);
preparedForNewDataProducerMap.put(view, new Constant<>(valueForNewData));
preparedForPreparationDataProducerMap.put(view, new Constant<>(valueForPreparationData));
results = new ConstantReader(valueForPreparationData, inputSize);
} else {
throw new IllegalArgumentException("Unknown ChildProducer type");
}
cache.put(producer, results);
for (ChildProducer> child : dagStructure._childrenMap.get(producer)) {
Set> dependencies = unsatisfiedDependencies.get(child);
dependencies.remove(producer);
if (dependencies.isEmpty()) {
queue.add(child);
}
}
}
PreparedTransformer preparedForNewDataDAG;
PreparedTransformer preparedForPreparationDataDAG;
if (dag instanceof PreparedDAGTransformer) {
preparedForNewDataDAG = (PreparedDAGTransformer) dag;
preparedForPreparationDataDAG = (PreparedDAGTransformer) dag;
} else {
PreparableDAGTransformer preparableDAG = (PreparableDAGTransformer) dag;
preparedForNewDataDAG = preparableDAG.internalAPI().createPreparedDAG(dagStructure._placeholders,
dagStructure._outputs.stream().map(preparedForNewDataProducerMap::get).collect(Collectors.toList()));
preparedForPreparationDataDAG = preparableDAG.internalAPI().createPreparedDAG(dagStructure._placeholders,
dagStructure._outputs.stream().map(preparedForPreparationDataProducerMap::get).collect(Collectors.toList()));
}
ObjectReader>[] resList = dagStructure._outputs.stream().map(cache::get).toArray(ObjectReader[]::new);
return new DAGExecutionResult(
new PreparerResult.Builder<>().withTransformerForNewData(preparedForNewDataDAG)
.withTransformerForPreparationData(preparedForPreparationDataDAG)
.build(), resList);
}
}