// org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable (Maven / Gradle / Ivy)
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.parallelism.inference.observers;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
* This class holds the reference inputs and implements the second use case: BATCHED inference.
*
* @author [email protected]
*/
@Slf4j
public class BatchedInferenceObservable extends BasicInferenceObservable implements InferenceObservable {
/** Feature arrays of every request, appended by {@code addInput} in arrival order. */
private List<INDArray[]> inputs = new ArrayList<>();
/** Mask arrays matching {@code inputs} index-for-index; an entry is null when a request has no masks. */
private List<INDArray[]> inputMasks = new ArrayList<>();
/** Per-request outputs; {@code getOutput} reads the entry at the calling thread's position. */
private List<INDArray[]> outputs = new ArrayList<>();
/** Number of inputs added so far; also the source of the next position handed out by {@code addInput}. */
private AtomicInteger counter = new AtomicInteger(0);
/** Index of the calling thread's request, set by {@code addInput} or {@code setPosition}. */
private ThreadLocal<Integer> position = new ThreadLocal<>();
/** For each batch built by {@code getInputBatches}: the inclusive {first, last} indices into {@code inputs}. */
private List<int[]> outputBatchInputArrays = new ArrayList<>();
/** Monitor that serialises calls to {@code addInput}. */
private final Object locker = new Object();
/** Write side is held while batches are built; read side is taken in {@code isLocked} and released in {@code addInput}. */
private ReentrantReadWriteLock realLocker = new ReentrantReadWriteLock();
/** Set to true as soon as {@code getInputBatches} has acquired the write lock. */
private AtomicBoolean isLocked = new AtomicBoolean(false);
/** Set to true by {@code isLocked()} when it reports "not locked" while keeping the read lock. */
private AtomicBoolean isReadLocked = new AtomicBoolean(false);
/**
 * Creates an empty observable: no inputs, counter at zero, nothing locked.
 */
public BatchedInferenceObservable() {
    super();
}
/**
 * Appends one request to this observable and records its index for the calling thread.
 *
 * <p>The input and its masks are stored at the same index, and that index (the previous
 * counter value) becomes the calling thread's position, later used by {@code getOutput()}.
 *
 * @param input      feature arrays of the request
 * @param inputMasks mask arrays of the request; stored as-is, may be null
 */
@Override
public void addInput(INDArray[] input, INDArray[] inputMasks) {
    synchronized (locker) {
        inputs.add(input);
        this.inputMasks.add(inputMasks);
        position.set(counter.getAndIncrement());

        // isLocked() leaves the read lock held when it reports "not locked"; release it here.
        // The flag is shared between threads, so also confirm that THIS thread holds the read
        // lock - unlocking without holding it throws IllegalMonitorStateException.
        if (isReadLocked.get() && realLocker.getReadHoldCount() > 0) {
            realLocker.readLock().unlock();
        }
    }
}
/**
 * Packs the inputs collected so far into batches.
 *
 * Takes the write lock, marks this observable as locked and clears the recorded
 * batch-to-input index ranges. When more than one input has been counted, each run of
 * consecutive inputs that canBatch accepts against the first input of the run is merged
 * through DataSetUtil.mergeFeatures, and the inclusive {first, last} input index range of
 * every merged batch is appended to outputBatchInputArrays. Otherwise the single range
 * {0, 0} is recorded and the first input/mask pair is returned as a one-element list.
 *
 * NOTE(review): the write lock is released only on the two normal return paths; there is
 * no try/finally, so an exception thrown while merging would leave it held.
 * NOTE(review): this listing appears to have lost text during extraction (generic type
 * arguments in the signature and in the declaration of out, plus part of the mask
 * initialisation loop further down) - restore it from the upstream file before compiling.
 *
 * @return one pair per batch - presumably (features, feature masks); confirm against
 *         DataSetUtil.mergeFeatures
 */
@Override
public List> getInputBatches() {
// Block until the write lock is available, then flag the observable as locked.
realLocker.writeLock().lock();
isLocked.set(true);
outputBatchInputArrays.clear();
// this method should pile individual examples into single batch
if (counter.get() > 1) {
int pos = 0;
List> out = new ArrayList<>();
// NOTE(review): numArrays is not read anywhere in the visible code.
int numArrays = inputs.get(0).length;
while(pos < inputs.size()) {
//First: determine which we can actually batch...
// Extend the run while each following input is batchable with the one at pos; stop at the first that is not.
int lastPossible = pos;
for (int i = pos+1; i < inputs.size(); i++) {
if (canBatch(inputs.get(pos), inputs.get(i))) {
lastPossible = i;
} else {
break;
}
}
int countToMerge = lastPossible-pos+1;
INDArray[][] featuresToMerge = new INDArray[countToMerge][0];
// The mask array stays null until the first input with non-null masks is seen.
INDArray[][] fMasksToMerge = null;
int fPos = 0;
for( int i=pos; i<=lastPossible; i++ ){
featuresToMerge[fPos] = inputs.get(i);
if(inputMasks.get(i) != null) {
if(fMasksToMerge == null){
fMasksToMerge = new INDArray[countToMerge][0];
// NOTE(review): the next line is truncated - the rest of the j-loop, the closing braces of the
// enclosing blocks and the declared type of merged are missing from this listing.
for( int j=0; j merged = DataSetUtil.mergeFeatures(featuresToMerge, fMasksToMerge);
out.add(merged);
// Remember which inputs (inclusive index range) went into this batch.
outputBatchInputArrays.add(new int[]{pos, lastPossible});
pos = lastPossible+1;
}
realLocker.writeLock().unlock();
return out;
} else {
// At most one input counted: no merging, the first input/mask pair is the only batch.
outputBatchInputArrays.add(new int[]{0,0});
realLocker.writeLock().unlock();
return Collections.singletonList(new Pair<>(inputs.get(0), inputMasks.get(0)));
}
}
/**
 * Decides whether two sets of input arrays may be merged into a single batch.
 * Per the comments below, the intended rule is that the inputs have the same shape.
 *
 * NOTE(review): the body of this method is cut off in this listing, so the actual
 * comparison is not visible here - confirm against the upstream file.
 */
private static boolean canBatch(INDArray[] first, INDArray[] candidate){
//Check if we can batch these inputs into the one array. This isn't always possible - for example, some fully
// convolutional nets can support different input image sizes
//For now: let's simply require that the inputs have the same shape
//In the future: we'll intelligently handle the RNN variable length case
//Note also we can ignore input masks here - they should have shared dimensions with the input, thus if the
// inputs can be batched, so can the masks
// NOTE(review): the next line is truncated - the remainder of canBatch and the header of the
// following method (which takes a parameter named output) were lost during extraction.
for(int i=0; i output) {
//this method should split batched output INDArray[] into multiple separate INDArrays
int countNumInputBatches = 0; //Counter for total number of input batches processed
// NOTE(review): the next line is truncated as well - the loop body and the start of the
// getOutputs() declaration are missing; only the tail of getOutputs() (returning outputs) survives.
for( int outBatchNum=0; outBatchNum getOutputs() {
return outputs;
}
/**
 * Overwrites the number of inputs this observable reports as collected.
 *
 * @param value new counter value
 */
protected void setCounter(int value) {
    this.counter.set(value);
}
/**
 * Sets the calling thread's position, i.e. the index {@code getOutput()} will read for this thread.
 *
 * @param pos index to associate with the current thread
 */
public void setPosition(int pos) {
    this.position.set(pos);
}
/**
 * Returns the current value of the input counter.
 *
 * @return current counter value
 */
public int getCounter() {
    final int current = this.counter.get();
    return current;
}
/**
 * Reports whether this observable can no longer accept inputs.
 *
 * <p>The observable counts as locked when the read lock cannot be taken right now or when
 * the {@code isLocked} flag has been set. When the answer is {@code false}, the read lock
 * taken by this call is deliberately kept and {@code isReadLocked} is raised, so that the
 * following {@code addInput} call can release it.
 *
 * @return {@code true} if locked, {@code false} if another input may still be added
 */
public boolean isLocked() {
    final boolean readAcquired = realLocker.readLock().tryLock();
    final boolean result = !readAcquired || isLocked.get();

    if (!result) {
        // Keep the read lock: addInput() releases it once the input has been stored.
        isReadLocked.set(true);
    } else if (readAcquired) {
        // Locked by flag, yet tryLock() succeeded: nobody will call addInput() on this
        // observable any more, so give the read lock back instead of leaking it.
        realLocker.readLock().unlock();
    }
    return result;
}
/**
 * Returns the output that belongs to the calling thread's request.
 *
 * <p>{@code checkOutputException()} runs first; the output is then looked up at the
 * position recorded for the current thread.
 *
 * @return output arrays stored at the current thread's position
 */
@Override
public INDArray[] getOutput() {
    // each client gets its own slice of the batched output, selected by its order number
    checkOutputException();
    final int slot = position.get();
    return outputs.get(slot);
}
}