All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.spark.impl.evaluation.EvaluationRunner Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.spark.impl.evaluation;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;

import java.io.ByteArrayInputStream;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Process-wide (singleton) runner used to evaluate networks on Spark executors.
 * <p>
 * At most {@code evalWorkers} threads at a time act as evaluation workers. Each worker builds its own copy of the
 * network, evaluates its own data, and then also evaluates any data queued by other threads. All of that data is
 * accumulated into the worker's own {@link IEvaluation} instances. Threads that arrive while every worker slot is
 * taken queue their data instead and receive a future that completes with a {@code null} result once a worker has
 * processed that data. The data's contribution is then part of that worker's returned evaluations.
 */
@Slf4j
public class EvaluationRunner {

    private static final EvaluationRunner INSTANCE = new EvaluationRunner();

    /**
     * @return The singleton instance shared by all callers in this JVM
     */
    public static EvaluationRunner getInstance(){
        return INSTANCE;
    }

    //Number of threads currently holding a worker slot (never exceeds the evalWorkers value they were admitted with)
    private final AtomicInteger workerCount = new AtomicInteger(0);
    //Data submitted by threads that could not become workers, waiting to be evaluated by a worker
    private final Queue<Eval> queue = new ConcurrentLinkedQueue<>();
    //Map of device-local parameters for a given broadcast.
    //Note: byte[] doesn't override Object.equals, hence this is effectively an *identity* weak hash map, which is what
    // we want here, i.e., the DeviceLocalNDArray can be GC'd once the broadcast's content is no longer referenced anywhere.
    //This approach relies on the fact that a single Broadcast object's *content* will be shared by all of Spark's
    // threads, even though the Broadcast object itself may not be.
    //Also, by storing params as a byte[] (i.e., in serialized form), we sidestep a lot of the thread locality issues.
    //WeakHashMap is not thread-safe: all access is guarded by synchronized(this).
    private final Map<byte[], DeviceLocalNDArray> paramsMap = new WeakHashMap<>();


    private EvaluationRunner(){ }

    /**
     * Evaluate the data using the specified evaluations.
     * <p>
     * If a worker slot is free, the calling thread becomes a worker. It evaluates {@code ds}/{@code mds} and then
     * drains any queued submissions into {@code evals} before returning an already-completed future whose result is
     * {@code evals}. That future carries the exception, if any, raised while evaluating the caller's own data.
     * Otherwise the data is queued for a worker and the returned future completes with a {@code null} result, or with
     * the failure, once a worker has evaluated it.
     *
     * @param evals         Evaluations to perform
     * @param evalWorkers   Maximum number of concurrent workers; must be > 0
     * @param evalBatchSize Evaluation batch size to use
     * @param ds            DataSet iterator; used in preference to {@code mds} when non-null
     * @param mds           MultiDataSet iterator; may be null when {@code ds} is provided
     * @param isCG          True if ComputationGraph, false if MultiLayerNetwork
     * @param json          JSON configuration for the network
     * @param params        Serialized parameters for the network
     * @return Future for the results
     */
    public Future<IEvaluation[]> execute(IEvaluation[] evals, int evalWorkers, int evalBatchSize, Iterator<DataSet> ds,
                                         Iterator<MultiDataSet> mds, boolean isCG, Broadcast<String> json,
                                         Broadcast<byte[]> params){
        Preconditions.checkArgument(evalWorkers > 0, "Invalid number of evaluation workers: must be > 0. Got: %s", evalWorkers);
        Preconditions.checkState(ds != null || mds != null, "No data provided - both DataSet and MultiDataSet iterators were null");

        DeviceLocalNDArray deviceLocalParams = getDeviceLocalParams(params);

        while (true) {
            //For load balancing: we're relying on the fact that threads are mapped to devices in a round-robin
            // fashion the first time they touch an INDArray. If we assume this method is called by new threads,
            // then the first N workers will be distributed evenly across available devices.
            if (tryAcquireWorkerSlot(evalWorkers)) {
                return runAsWorker(evals, evalWorkers, evalBatchSize, ds, mds, isCG, json, deviceLocalParams);
            }

            //All worker slots are taken: hand the data over to a worker thread
            log.debug("Submitting evaluation from thread {} for processing in evaluation thread", Thread.currentThread().getId());
            EvaluationFuture f = new EvaluationFuture();
            Eval submission = new Eval(ds, mds, evals, f);
            queue.add(submission);

            //Re-check after publishing. Suppose every worker that was active when we found the slots full exits
            // without seeing our submission. Then a slot is free now. In that case, reclaim the submission
            // (remove() is atomic with respect to a worker's poll()) and retry, so it is never left without a worker.
            // If a worker already took it, or slots are still full, a worker is guaranteed to process it.
            if (workerCount.get() >= evalWorkers || !queue.remove(submission)) {
                return f;
            }
        }
    }

    /**
     * Returns the device-local parameters for the given broadcast, deserializing them on first use.
     */
    private DeviceLocalNDArray getDeviceLocalParams(Broadcast<byte[]> params) {
        synchronized (this) {
            //Due to the singleton pattern, deserialization happens only once per broadcast (by the first thread).
            //For CPU, this means we only have a single copy of the params INDArray shared by all threads, which is
            // both safe and uses the least amount of memory. For CUDA, we can't share arrays between threads, otherwise
            // they would be continually relocated between devices; DeviceLocalNDArray is used to avoid that.
            return paramsMap.computeIfAbsent(params.getValue(),
                    bytes -> new DeviceLocalNDArray(Nd4j.read(new ByteArrayInputStream(bytes))));
        }
    }

    /**
     * Atomically claims a worker slot if fewer than {@code evalWorkers} are currently taken.
     *
     * @return true if the calling thread now holds a worker slot
     */
    private boolean tryAcquireWorkerSlot(int evalWorkers) {
        int current;
        while ((current = workerCount.get()) < evalWorkers) {
            if (workerCount.compareAndSet(current, current + 1)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Body of a worker thread: evaluates its own data, then drains queued submissions, then releases its slot.
     * The caller must already hold a worker slot.
     */
    private EvaluationFuture runAsWorker(IEvaluation[] evals, int evalWorkers, int evalBatchSize, Iterator<DataSet> ds,
                                         Iterator<MultiDataSet> mds, boolean isCG, Broadcast<String> json,
                                         DeviceLocalNDArray deviceLocalParams) {
        log.debug("Starting evaluation in thread {}", Thread.currentThread().getId());
        EvaluationFuture f = new EvaluationFuture();
        f.setResult(evals);
        boolean holdingSlot = true;
        try {
            Model m = createModel(isCG, json.getValue(), deviceLocalParams);

            //Perform eval on this thread's data
            try {
                doEval(m, evals, ds, mds, evalBatchSize);
            } catch (Throwable t) {
                f.setException(t);
            } finally {
                f.markDone();
            }

            //Perform eval on other threads' data
            while (true) {
                drainQueue(m, evals, evalBatchSize);
                workerCount.decrementAndGet();
                holdingSlot = false;
                //A submission may have been queued after our last empty check but before the decrement, by a
                // thread that still saw all slots as taken. Re-check so it is not left without a worker.
                if (queue.isEmpty() || !tryAcquireWorkerSlot(evalWorkers)) {
                    break;
                }
                holdingSlot = true;
            }
        } finally {
            if (holdingSlot) {
                //Only reached when an exception escaped (e.g., while building the model)
                workerCount.decrementAndGet();
            }
            log.debug("Finished evaluation in thread {}", Thread.currentThread().getId());
        }

        Nd4j.getExecutioner().commit();
        return f;
    }

    /**
     * Builds and initializes a network from its JSON configuration and the device-local parameters.
     */
    private static Model createModel(boolean isCG, String json, DeviceLocalNDArray deviceLocalParams) {
        if (isCG) {
            ComputationGraph cg = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
            cg.init(deviceLocalParams.get(), false);
            return cg;
        }
        MultiLayerNetwork net = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
        net.init(deviceLocalParams.get(), false);
        return net;
    }

    /**
     * Evaluates every queued submission into {@code evals}, completing each submission's future.
     * A failure is recorded on that submission's future and does not stop the drain.
     */
    private void drainQueue(Model m, IEvaluation[] evals, int evalBatchSize) {
        Eval e;
        while ((e = queue.poll()) != null) {    //poll, not remove: another worker may take the last element first
            try {
                doEval(m, evals, e.getDs(), e.getMds(), evalBatchSize);
            } catch (Throwable t) {
                e.getFuture().setException(t);
            } finally {
                e.getFuture().markDone();
            }
        }
    }

    /**
     * Runs the evaluations over the data: {@code ds} is used if non-null, otherwise {@code mds}.
     */
    private static void doEval(Model m, IEvaluation[] e, Iterator<DataSet> ds, Iterator<MultiDataSet> mds, int evalBatchSize){
        if(m instanceof MultiLayerNetwork){
            MultiLayerNetwork mln = (MultiLayerNetwork)m;
            if(ds != null){
                mln.doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
            } else {
                mln.doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
            }
        } else {
            ComputationGraph cg = (ComputationGraph)m;
            if(ds != null){
                cg.doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
            } else {
                cg.doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
            }
        }
    }


    /**
     * A submission from a thread that could not become a worker.
     * Uses identity equality (no @Data) so that queue.remove(...) only ever matches this exact submission.
     */
    @AllArgsConstructor
    @Getter
    private static class Eval {
        private final Iterator<DataSet> ds;
        private final Iterator<MultiDataSet> mds;
        //Not used by workers: queued data is accumulated into the worker's own evaluations
        private final IEvaluation[] evaluations;
        private final EvaluationFuture future;
    }

    /**
     * Minimal future completed exactly once via {@link #markDone()}. Cancellation is not supported.
     */
    @Setter
    @Getter
    private static class EvaluationFuture implements Future<IEvaluation[]> {

        //A latch (rather than a semaphore permit) so that completion is never "consumed" by get()
        private final CountDownLatch done = new CountDownLatch(1);
        //volatile: written by a worker thread and read by the submitting thread
        private volatile IEvaluation[] result;
        private volatile Throwable exception;

        /** Signals completion (success or failure) to any thread waiting in get() */
        void markDone() {
            done.countDown();
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            throw new UnsupportedOperationException("Not supported");
        }

        @Override
        public boolean isCancelled() {
            return false;
        }

        @Override
        public boolean isDone() {
            return done.getCount() == 0;
        }

        @Override
        public IEvaluation[] get() throws InterruptedException, ExecutionException {
            if(result == null && exception == null)
                done.await();    //Block until completion (or failure) is reported
            return resultOrThrow();
        }

        @Override
        public IEvaluation[] get(long timeout, @NonNull TimeUnit unit)
                throws InterruptedException, ExecutionException, TimeoutException {
            if (result == null && exception == null && !done.await(timeout, unit)) {
                throw new TimeoutException("Evaluation did not complete within " + timeout + " " + unit);
            }
            return resultOrThrow();
        }

        /** Returns the result, or throws the recorded failure wrapped in an ExecutionException */
        private IEvaluation[] resultOrThrow() throws ExecutionException {
            Throwable t = exception;
            if (t != null) {
                throw new ExecutionException(t);
            }
            return result;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy