/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.parallelism;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObserver;
import org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.util.ArrayList;
import java.util.List;
import java.util.Observer;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;

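/**
 * ParallelInference wraps a single model and serves inference requests from
 * multiple threads, replicating the model across workers (and across devices
 * on multi-GPU systems) and optionally batching incoming requests.
 *
 * <p>A minimal usage sketch (assumes {@code net} is a trained, initialized
 * network and {@code features} is an input array; both names are placeholders):
 * <pre>{@code
 * ParallelInference pi = new ParallelInference.Builder(net)
 *         .inferenceMode(InferenceMode.BATCHED)
 *         .batchLimit(32)
 *         .workers(2)
 *         .build();
 *
 * INDArray prediction = pi.output(features);
 * pi.shutdown();
 * }</pre>
 */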
@Slf4j
public class ParallelInference {
    protected Model model;
    protected long nanos;
    protected int workers;
    protected int batchLimit;
    protected InferenceMode inferenceMode;
    protected int queueLimit;
    protected LoadBalanceMode loadBalanceMode = LoadBalanceMode.FIFO;

    // this queue holds data for inference
    private BlockingQueue<InferenceObservable> observables;

    private final Object locker = new Object();

    private InferenceWorker[] zoo;
    private ObservablesProvider provider;



    public static final int DEFAULT_NUM_WORKERS = Nd4j.getAffinityManager().getNumberOfDevices();
    public static final int DEFAULT_BATCH_LIMIT = 32;
    public static final InferenceMode DEFAULT_INFERENCE_MODE = InferenceMode.BATCHED;
    public static final int DEFAULT_QUEUE_LIMIT = 64;



    protected ParallelInference() {
        //
    }

    /**
     * This method allows the Model used for inference to be updated at runtime,
     * without resetting the queue.
     *
     * @param model new model to use for inference
     */
    public void updateModel(@NonNull Model model) {
        if (zoo != null) {
            for (val w: zoo)
                w.updateModel(model);
        } else {
            // if the zoo wasn't initialized yet, just replace the model
            this.model = model;
        }
    }

    /**
     * This method returns the models currently used by the workers.
     * PLEASE NOTE: This method is NOT thread safe, and should NOT be used anywhere but tests.
     *
     * @return models currently assigned to the workers
     */
    protected Model[] getCurrentModelsFromWorkers() {
        if (zoo == null)
            return new Model[0];

        val models = new Model[zoo.length];
        int cnt = 0;
        for (val w:zoo) {
            models[cnt++] = w.replicatedModel;
        }

        return models;
    }

    protected void init() {
        observables = new LinkedBlockingQueue<>(queueLimit);

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int currentDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        AtomicBoolean assignedRoot = new AtomicBoolean(false);

        zoo = new InferenceWorker[workers];
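        // Workers are assigned to devices round-robin; the first worker that lands
        // on the calling thread's device becomes the "root" worker and reuses the
        // original model, while all others receive replicas.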
        for (int i = 0; i < workers; i++) {
            int cDevice = i % numDevices;
            boolean cRoot = !assignedRoot.get() && cDevice == currentDevice;
            assignedRoot.compareAndSet(false, cRoot);

            zoo[i] = new InferenceWorker(i, model, observables, cRoot, cDevice);

            zoo[i].setDaemon(true);
            zoo[i].start();
        }


        if (inferenceMode == InferenceMode.BATCHED) {
            log.info("Initializing ObservablesProvider...");
            provider = new ObservablesProvider(nanos, batchLimit, observables);
        }
    }

    protected long getWorkerCounter(int workerIdx) {
        return zoo[workerIdx].getCounterValue();
    }

    /**
     * This method gracefully shuts down this ParallelInference instance
     */
    public synchronized void shutdown() {
        if (zoo == null)
            return;

        for (int e = 0; e < zoo.length; e++) {
            if (zoo[e] == null)
                continue;

            zoo[e].interrupt();
            zoo[e].shutdown();
            zoo[e] = null;
        }
        zoo = null;

        System.gc();
    }

    /**
     * Generate output from the network for the given input.
     *
     * @param input input to the network, as a flat double array
     * @return output from the network
     */
    public INDArray output(double[] input) {
        return output(Nd4j.create(input));
    }

    /**
     * Generate output from the network for the given input.
     *
     * @param input input to the network, as a flat float array
     * @return output from the network
     */
    public INDArray output(float[] input) {
        return output(Nd4j.create(input));
    }

    public INDArray output(INDArray input) {
        return output(input, null);
    }

    public INDArray output(INDArray input, INDArray inputMask){
        INDArray[] out = output(new INDArray[]{input}, (inputMask == null ? null : new INDArray[]{inputMask}));
        if(out.length != 1){
            throw new IllegalArgumentException("Network has multiple (" + out.length + ") output arrays, but only a" +
                    " single output can be returned using this method. Use output(INDArray[] input, INDArray[] " +
                    "inputMasks) for multi-output nets");
        }
        return out[0];
    }

    /**
     * Generate output from the network for the features in the given DataSet,
     * using the features mask array if one is present.
     *
     * @param dataSet DataSet containing the input features
     * @return output from the network
     */
    public INDArray output(DataSet dataSet) {
        return output(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
    }

    /**
     * Generate predictions/output from the network
     *
     * @param input Input to the network
     * @return Output from the network
     */
    public INDArray[] output(INDArray... input) {
        return output(input, null);
    }

    /**
     * Generate predictions/outputs from the network, optionally using input masks for predictions
     *
     * @param input      Input to the network
     * @param inputMasks Input masks for the network. May be null.
     * @return Output from the network
     */
    public INDArray[] output(INDArray[] input, INDArray[] inputMasks){
        Nd4j.getExecutioner().commit(); //Commit before passing input to other thread

        // basically, depending on model type we either throw stuff to specific model, or wait for batch
        BasicInferenceObserver observer = new BasicInferenceObserver();
        InferenceObservable observable;

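        // SEQUENTIAL mode: each request becomes its own observable and goes straight
        // onto the queue. BATCHED mode: the provider may merge the request into a
        // pending batch instead of enqueueing a new observable.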
        if (inferenceMode == InferenceMode.SEQUENTIAL) {
            observable = new BasicInferenceObservable(input, inputMasks);
            observable.addObserver(observer);
            try {
                observables.put(observable);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        } else {
            observable = provider.setInput(observer, input, inputMasks);
        }

        try {
            // submit the query for processing and block until the Observable returns
            observer.waitTillDone();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        return observable.getOutput();
    }

    /**
     * This method does a forward pass and returns the output produced by the given ModelAdapter.
     *
     * @param adapter adapter used to convert the raw network output
     * @param inputs  inputs to the network
     * @param <T>     type returned by the adapter
     * @return adapted output
     */
    public <T> T output(@NonNull ModelAdapter<T> adapter, INDArray... inputs) {
        return output(adapter, inputs, null);
    }

    /**
     * This method does a forward pass and returns the output produced by the given ModelAdapter.
     *
     * @param adapter    adapter used to convert the raw network output
     * @param input      inputs to the network
     * @param inputMasks input masks for the network; may be null
     * @param <T>        type returned by the adapter
     * @return adapted output
     */
    public <T> T output(@NonNull ModelAdapter<T> adapter, INDArray[] input, INDArray[] inputMasks) {
        throw new ND4JIllegalStateException("Adapter-based output requires INPLACE inference mode");
    }


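    /**
     * Builder for ParallelInference instances; see the class-level javadoc
     * for a usage sketch.
     */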
    public static class Builder {
        private Model model;
        private int workers = DEFAULT_NUM_WORKERS;
        private int batchLimit = DEFAULT_BATCH_LIMIT;
        private InferenceMode inferenceMode = DEFAULT_INFERENCE_MODE;
        private int queueLimit = DEFAULT_QUEUE_LIMIT;
        protected LoadBalanceMode loadBalanceMode = LoadBalanceMode.FIFO;

        public Builder(@NonNull Model model) {
            this.model = model;
        }


        /**
         * This method allows you to define the mode that will be used during inference. Options are:
         *
         * SEQUENTIAL: Each input is sent to a worker unmodified.
         * BATCHED: Multiple inputs are packed into a single batch before
         * being sent to a worker.
         *
         * @param inferenceMode inference mode to use
         * @return this Builder
         */
        public Builder inferenceMode(@NonNull InferenceMode inferenceMode) {
            this.inferenceMode = inferenceMode;
            return this;
        }


        /**
         * This method allows you to specify the load balancing mode.
         *
         * @param loadBalanceMode load balancing mode to use
         * @return this Builder
         */
        public Builder loadBalanceMode(@NonNull LoadBalanceMode loadBalanceMode) {
            this.loadBalanceMode = loadBalanceMode;
            return this;
        }


        /**
         * This method defines how many model copies will be used for inference.
         *
         * PLEASE NOTE: This method is primarily suited for multi-GPU systems.
         * PLEASE NOTE: For INPLACE inference mode this value means the number of models per DEVICE.
         *
         * @param workers number of workers
         * @return this Builder
         */
        public Builder workers(int workers) {
            if (workers < 1)
                throw new IllegalStateException("Workers should be positive value");

            this.workers = workers;
            return this;
        }

        /**
         * This method defines how many input samples can
         * be batched within a given time frame.
         *
         * PLEASE NOTE: This value has no effect in
         * SEQUENTIAL inference mode.
         *
         * @param limit maximum number of samples per batch
         * @return this Builder
         */
        public Builder batchLimit(int limit) {
            if (limit < 1)
                throw new IllegalStateException("Batch limit should be positive value");

            this.batchLimit = limit;
            return this;
        }

        /**
         * This method defines the size of the buffer queue.
         *
         * Default value: 64
         *
         * @param limit queue size limit
         * @return this Builder
         */
        public Builder queueLimit(int limit) {
            if (limit < 1)
                throw new IllegalStateException("Queue limit should be positive value");

            this.queueLimit = limit;
            return this;
        }

        /**
         * This method builds a new ParallelInference instance.
         *
         * @return fully configured ParallelInference instance
         */
        public ParallelInference build() {
            if (this.inferenceMode == InferenceMode.INPLACE) {
                val inf = new InplaceParallelInference();
                inf.inferenceMode = this.inferenceMode;
                inf.model = this.model;
                inf.workers = this.workers;
                inf.loadBalanceMode = this.loadBalanceMode;

                inf.init();

                return inf;
            } else {
                ParallelInference inference = new ParallelInference();
                inference.batchLimit = this.batchLimit;
                inference.queueLimit = this.queueLimit;
                inference.inferenceMode = this.inferenceMode;
                inference.model = this.model;
                inference.workers = this.workers;
                inference.loadBalanceMode = this.loadBalanceMode;

                inference.init();

                return inference;
            }
        }
    }


    /**
     * This class actually does the inference, respecting device affinity.
     */
    private class InferenceWorker extends Thread {
        private BlockingQueue<InferenceObservable> inputQueue;
        private AtomicBoolean shouldWork = new AtomicBoolean(true);
        private AtomicBoolean isStopped = new AtomicBoolean(false);
        private Model protoModel;
        private Model replicatedModel;
        private AtomicLong counter = new AtomicLong(0);
        private boolean rootDevice;
        private int deviceId;

        private ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();

        private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue<InferenceObservable> inputQueue, boolean rootDevice, int deviceId) {
            this.inputQueue = inputQueue;
            this.protoModel = model;
            this.rootDevice = rootDevice;
            this.deviceId = deviceId;

            this.setDaemon(true);
            this.setName("InferenceThread-" + id);

        }

        protected long getCounterValue() {
            return counter.get();
        }

        protected void updateModel(@NonNull Model model) {
            try {
                modelLock.writeLock().lock();
                this.protoModel = model;

                // now re-init model
                initializeReplicaModel();
            } finally {
                modelLock.writeLock().unlock();
            }
        }

        /**
         * This method duplicates the prototype model for use by this worker during inference.
         */
        protected void initializeReplicaModel() {
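            // Non-root workers build a structural clone from the configuration JSON and
            // copy parameters from the prototype; the root worker (already on the calling
            // thread's device) reuses the prototype model directly.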
            if (protoModel instanceof ComputationGraph) {
                if (!rootDevice) {
                    this.replicatedModel = new ComputationGraph(ComputationGraphConfiguration
                            .fromJson(((ComputationGraph) protoModel).getConfiguration().toJson()));
                    this.replicatedModel.init();

                    synchronized (locker) {
                        this.replicatedModel.setParams(protoModel.params().unsafeDuplication(true));

                        Nd4j.getExecutioner().commit();
                    }
                } else {
                    this.replicatedModel = protoModel;
                }
            } else if (protoModel instanceof MultiLayerNetwork) {
                if (!rootDevice) {
                    this.replicatedModel = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(
                            ((MultiLayerNetwork) protoModel).getLayerWiseConfigurations().toJson()));
                    this.replicatedModel.init();

                    synchronized (locker) {
                        this.replicatedModel.setParams(protoModel.params().unsafeDuplication(true));

                        Nd4j.getExecutioner().commit();
                    }
                } else {
                    this.replicatedModel = protoModel;
                }
            }
        }

        @Override
        public void run() {
            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
            try {
                // model should be replicated & initialized here
                initializeReplicaModel();

                boolean isCG = replicatedModel instanceof ComputationGraph;
                boolean isMLN = replicatedModel instanceof MultiLayerNetwork;

                while (shouldWork.get()) {
                    InferenceObservable request = inputQueue.take();

                    if (request != null) {
                        counter.incrementAndGet();

                        // FIXME: get rid of instanceof here, model won't change during runtime anyway
                        if (isCG) {
                            List<Pair<INDArray[], INDArray[]>> batches = request.getInputBatches();
                            List<INDArray[]> out = new ArrayList<>(batches.size());
                            try {
                                for (Pair<INDArray[], INDArray[]> inBatch : batches) {
                                    try {
                                        modelLock.readLock().lock();

                                        INDArray[] output = ((ComputationGraph) replicatedModel).output(false, inBatch.getFirst(), inBatch.getSecond());
                                        out.add(output);
                                    } finally {
                                        Nd4j.getExecutioner().commit();
                                        modelLock.readLock().unlock();
                                    }

                                }
                                request.setOutputBatches(out);
                            } catch (Exception e){
                                request.setOutputException(e);
                            }
                        } else if (isMLN) {
                            List<Pair<INDArray[], INDArray[]>> batches = request.getInputBatches();
                            List<INDArray[]> out = new ArrayList<>(batches.size());
                            try {
                                for (Pair<INDArray[], INDArray[]> inBatch : batches) {
                                    INDArray f = inBatch.getFirst()[0];
                                    INDArray fm = (inBatch.getSecond() == null ? null : inBatch.getSecond()[0]);
                                    try {
                                        modelLock.readLock().lock();

                                        INDArray output = ((MultiLayerNetwork) replicatedModel).output(f, false, fm, null);
                                        out.add(new INDArray[]{output});
                                    } finally {
                                        Nd4j.getExecutioner().commit();
                                        modelLock.readLock().unlock();
                                    }
                                }
                                request.setOutputBatches(out);
                            } catch (Exception e){
                                request.setOutputException(e);
                            }
                        }


                    } else {
                        // no request available; nothing to do this round
                    }
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                // do nothing
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                isStopped.set(true);
            }
        }

        protected void shutdown() {
            shouldWork.set(false);
            while (!isStopped.get()) {
                // block until main loop is finished
            }
        }
    }


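    /**
     * Batches incoming requests for BATCHED inference mode. Inputs are appended to
     * the current BatchedInferenceObservable until it reaches batchLimit or a worker
     * locks it for processing; only newly created observables are placed on the
     * target queue, so later inputs piggyback on the pending batch.
     */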
    protected static class ObservablesProvider {
        private BlockingQueue<InferenceObservable> targetQueue;
        private long nanos;
        private int batchLimit;

        private volatile BatchedInferenceObservable currentObservable;
        private final Object locker = new Object();

        protected ObservablesProvider(long nanos, int batchLimit, @NonNull BlockingQueue<InferenceObservable> queue) {
            this.targetQueue = queue;
            this.nanos = nanos;
            this.batchLimit = batchLimit;
        }

        protected InferenceObservable setInput(@NonNull Observer observer, INDArray input){
            return setInput(observer, new INDArray[]{input}, null);
        }

        protected InferenceObservable setInput(@NonNull Observer observer, INDArray... input){
            return setInput(observer, input, null);
        }

        protected InferenceObservable setInput(@NonNull Observer observer, INDArray[] input, INDArray[] inputMask) {
            synchronized (locker) {
                boolean isNew = false;
                if (currentObservable == null || currentObservable.getCounter() >= batchLimit
                                || currentObservable.isLocked()) {
                    isNew = true;
                    currentObservable = new BatchedInferenceObservable();
                }

                currentObservable.addInput(input, inputMask);
                currentObservable.addObserver(observer);

                try {
                    if (isNew)
                        targetQueue.put(currentObservable);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }

                return currentObservable;
            }
        }
    }
}



