All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.jita.concurrency.CudaAffinityManager Maven / Gradle / Ivy

The newest version!
package org.nd4j.jita.concurrency;

import lombok.NonNull;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.concurrency.BasicAffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * AffinityManager implementation for CUDA.
 *
 * Maintains a threadId -> deviceId mapping. A thread without a mapping is assigned a device
 * round-robin over the configured available devices the first time it asks for one. When the
 * thread asking is the thread being mapped, the native {@code setDevice} call is issued for it.
 *
 * @author [email protected]
 */
public class CudaAffinityManager extends BasicAffinityManager {
    private static final Logger logger = LoggerFactory.getLogger(CudaAffinityManager.class);

    // threadId -> deviceId. Concurrent because any thread may query or attach any other thread.
    private final Map<Long, Integer> affinityMap = new ConcurrentHashMap<>();

    // Round-robin cursor into the available devices list; only advanced while holding `this`.
    private final AtomicInteger devPtr = new AtomicInteger(0);

    // Per-thread flag: true once setDevice was issued for the calling thread's current mapping.
    private final ThreadLocal<AtomicBoolean> affiliated = new ThreadLocal<>();

    // Cached device count; -1 means "driver not queried yet".
    private final AtomicInteger numberOfDevices = new AtomicInteger(-1);

    public CudaAffinityManager() {
        super();
    }

    /**
     * Returns the deviceId for the current thread.
     *
     * If no device was assigned to this thread before this call, one is assigned here.
     *
     * @return deviceId of the calling thread
     */
    @Override
    public Integer getDeviceForCurrentThread() {
        return getDeviceForThread(Thread.currentThread().getId());
    }

    /**
     * Returns the deviceId for the given thread.
     *
     * If no device was assigned to this thread before this call, one is assigned here.
     *
     * @param thread thread to look up
     * @return deviceId of the given thread
     */
    @Override
    public Integer getDeviceForThread(Thread thread) {
        return getDeviceForThread(thread.getId());
    }

    /**
     * Returns the deviceId for the thread identified by {@code threadId}.
     *
     * If no device was assigned to this thread before this call, one is assigned here.
     * On single-device systems this always returns 0 and records no mapping.
     * If {@code threadId} is the calling thread, the native setDevice call is issued once
     * for the thread's current mapping.
     *
     * @param threadId id of the thread to look up
     * @return deviceId of the given thread
     */
    @Override
    public Integer getDeviceForThread(long threadId) {
        // Nothing to balance with a single device
        if (getNumberOfDevices() == 1)
            return 0;

        Integer deviceId = affinityMap.get(threadId);
        if (deviceId == null) {
            Integer candidate = getNextDevice(threadId);
            // If another thread raced us for the same threadId, keep the first assignment
            Integer existing = affinityMap.putIfAbsent(threadId, candidate);
            deviceId = existing != null ? existing : candidate;
        }

        // Only the calling thread can switch its own native device, so only its flag is touched
        if (threadId == Thread.currentThread().getId())
            bindCurrentThread(deviceId);

        return deviceId;
    }

    /**
     * Issues the native setDevice call for the calling thread, unless it was already issued for
     * the thread's current mapping.
     *
     * @param deviceId device the calling thread is mapped to
     */
    private void bindCurrentThread(Integer deviceId) {
        AtomicBoolean bound = affiliated.get();
        if (bound == null) {
            bound = new AtomicBoolean(false);
            affiliated.set(bound);
        }

        if (!bound.get()) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(new CudaPointer(deviceId));
            bound.set(true);
        }
    }

    /**
     * Pairs the specified thread with the specified device.
     *
     * @param thread   thread to map
     * @param deviceId target device
     */
    @Override
    public void attachThreadToDevice(Thread thread, Integer deviceId) {
        attachThreadToDevice(thread.getId(), deviceId);
    }

    /**
     * Pairs the specified thread with the specified device.
     *
     * This only records the mapping and does not call setDevice itself. If {@code threadId} is
     * the calling thread, its binding flag is reset, so the next
     * {@link #getDeviceForCurrentThread()} issues setDevice for the new mapping.
     *
     * @param threadId id of the thread to map
     * @param deviceId target device
     */
    @Override
    public void attachThreadToDevice(long threadId, Integer deviceId) {
        logger.debug("Manually mapping thread [{}] to device [{}], out of [{}] devices...", threadId, deviceId,
                        CudaEnvironment.getInstance().getConfiguration().getAvailableDevices().size());
        affinityMap.put(threadId, deviceId);

        // The mapping changed, so the "already bound" flag for this thread is no longer valid
        if (threadId == Thread.currentThread().getId()) {
            AtomicBoolean bound = affiliated.get();
            if (bound != null)
                bound.set(false);
        }
    }

    /**
     * Returns the next device id to assign, using round-robin balancing over the configured
     * available devices.
     *
     * If a single GPU is forced, or no devices were reported, the first available device
     * is returned.
     *
     * @param threadId used for logging only
     * @return device id to assign
     */
    protected Integer getNextDevice(long threadId) {
        Integer device;
        if (!CudaEnvironment.getInstance().getConfiguration().isForcedSingleGPU() && getNumberOfDevices() > 0) {
            // simple round-robin here
            synchronized (this) {
                List<Integer> available = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices();

                // Guard against the list having shrunk since the cursor was last advanced
                int index = devPtr.get();
                if (index >= available.size())
                    index = 0;

                device = available.get(index);

                // We check only for number of entries here, not their actual values
                devPtr.set(index + 1 >= available.size() ? 0 : index + 1);

                logger.debug("Mapping thread [{}] to device [{}], out of [{}] devices...", threadId, device,
                                available.size());
            }
        } else {
            device = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices().get(0);
            logger.debug("Single device is forced, mapping to device [{}]", device);
        }

        return device;
    }

    /**
     * Returns the number of devices reported by the native layer. The value is queried once
     * and then cached.
     *
     * Please note: the returned value might differ from the number of devices actually used.
     *
     * @return total number of devices
     */
    @Override
    public int getNumberOfDevices() {
        if (numberOfDevices.get() < 0) {
            synchronized (this) {
                // Re-check under the lock with the same sentinel test, so the driver is queried once
                if (numberOfDevices.get() < 0) {
                    numberOfDevices.set(NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices());
                }
            }
        }

        return numberOfDevices.get();
    }

    /**
     * Utility method that relocates the array's data and shape buffers (backend-specific).
     * Does nothing for {@code null}.
     *
     * @param array array to touch
     */
    @Override
    public void touch(INDArray array) {
        if (array == null)
            return;

        touch(array.data());
        touch(array.shapeInfoDataBuffer());
    }

    /**
     * Utility method that relocates a buffer (backend-specific).
     *
     * Constant buffers are relocated through the constant handler; all others through the
     * allocator's memory handler. Does nothing for {@code null}.
     *
     * @param buffer buffer to touch
     */
    @Override
    public void touch(DataBuffer buffer) {
        if (buffer == null)
            return;

        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(buffer);

        if (point.isConstant()) {
            Nd4j.getConstantHandler().relocateConstantSpace(buffer);
        } else {
            AtomicAllocator.getInstance().getMemoryHandler().relocateObject(buffer);
        }
    }

    /**
     * Replicates the given INDArray and places the copy on the target device.
     *
     * The calling thread is temporarily switched to {@code deviceId}. Its previous device and
     * mapping are restored afterwards, even if replication fails.
     *
     * @param deviceId target deviceId
     * @param array    INDArray to replicate
     * @return the replica, or {@code null} if {@code array} is {@code null}
     * @throws UnsupportedOperationException if {@code array} is a view
     */
    @Override
    public synchronized INDArray replicateToDevice(Integer deviceId, INDArray array) {
        if (array == null)
            return null;

        if (array.isView())
            throw new UnsupportedOperationException("It's impossible to replicate View");

        int[] shape = array.shape();
        int[] stride = array.stride();
        int elementWiseStride = array.elementWiseStride();
        char ordering = array.ordering();

        // we use this call to get device memory updated
        AtomicAllocator.getInstance().getPointer(array,
                        (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext());

        int currentDeviceId = getDeviceForCurrentThread();

        NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(new CudaPointer(deviceId));
        attachThreadToDevice(Thread.currentThread().getId(), deviceId);

        try {
            DataBuffer newDataBuffer = replicateToDevice(deviceId, array.data());
            DataBuffer newShapeBuffer = Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0,
                            elementWiseStride, ordering).getFirst();
            return Nd4j.createArrayFromShapeBuffer(newDataBuffer, newShapeBuffer);
        } finally {
            // Always restore the caller's device, even if replication threw
            attachThreadToDevice(Thread.currentThread().getId(), currentDeviceId);
            NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(new CudaPointer(currentDeviceId));
        }
    }

    /**
     * Replicates the given DataBuffer and places the copy on the target device.
     *
     * If the target differs from the allocator's current device, the calling thread is
     * temporarily switched to it. The previous device is restored afterwards, even on failure.
     *
     * @param deviceId target deviceId
     * @param buffer   buffer to replicate
     * @return the replica, or {@code null} if {@code buffer} is {@code null}
     */
    @Override
    public DataBuffer replicateToDevice(Integer deviceId, DataBuffer buffer) {
        if (buffer == null)
            return null;

        int currentDeviceId = AtomicAllocator.getInstance().getDeviceId();
        boolean switchDevice = currentDeviceId != deviceId;

        if (switchDevice) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(new CudaPointer(deviceId));
            Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread().getId(), deviceId);
        }

        try {
            DataBuffer dstBuffer = Nd4j.createBuffer(buffer.length(), false);
            AtomicAllocator.getInstance().memcpy(dstBuffer, buffer);
            return dstBuffer;
        } finally {
            if (switchDevice) {
                NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(new CudaPointer(currentDeviceId));
                Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread().getId(), currentDeviceId);
            }
        }
    }

    /**
     * Marks the given INDArray as actual in a specific location (host, device, or both).
     *
     * HOST ticks a host write, DEVICE ticks a device write, and EVERYWHERE ticks a device write
     * followed by a host read.
     *
     * @param array    array to tag
     * @param location location to mark as actual
     */
    @Override
    public void tagLocation(INDArray array, Location location) {
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(array);
        if (location == Location.HOST) {
            point.tickHostWrite();
        } else if (location == Location.DEVICE) {
            point.tickDeviceWrite();
        } else if (location == Location.EVERYWHERE) {
            point.tickDeviceWrite();
            point.tickHostRead();
        }
    }

    /**
     * Marks the given DataBuffer as actual in a specific location (host, device, or both).
     *
     * HOST ticks a host write, DEVICE ticks a device write, and EVERYWHERE ticks a device write
     * followed by a host read.
     *
     * @param buffer   buffer to tag
     * @param location location to mark as actual
     */
    @Override
    public void tagLocation(DataBuffer buffer, Location location) {
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(buffer);
        if (location == Location.HOST) {
            point.tickHostWrite();
        } else if (location == Location.DEVICE) {
            point.tickDeviceWrite();
        } else if (location == Location.EVERYWHERE) {
            point.tickDeviceWrite();
            point.tickHostRead();
        }
    }

    /**
     * Returns the device the given array's memory is associated with, as reported by the allocator.
     */
    @Override
    public Integer getDeviceForArray(@NonNull INDArray array) {
        return AtomicAllocator.getInstance().getDeviceId(array);
    }

    /**
     * Issues the native setDevice call directly, without touching the thread -> device mapping.
     */
    @Override
    public void unsafeSetDevice(Integer deviceId) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(new CudaPointer(deviceId));
    }

    /**
     * Makes the array's data actual in the requested location.
     *
     * HOST synchronizes host data, DEVICE synchronizes to the device, and EVERYWHERE does both.
     * Does nothing for {@code null}.
     *
     * @param array    array to synchronize
     * @param location location that must hold actual data afterwards
     * @throws UnsupportedOperationException for an unrecognized location
     */
    @Override
    public void ensureLocation(INDArray array, Location location) {
        if (array == null)
            return;

        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(array);
        switch (location) {
            case HOST:
                AtomicAllocator.getInstance().synchronizeHostData(array);
                break;
            case DEVICE:
                AtomicAllocator.getInstance().getFlowController().synchronizeToDevice(point);
                break;
            case EVERYWHERE:
                AtomicAllocator.getInstance().synchronizeHostData(array);
                AtomicAllocator.getInstance().getFlowController().synchronizeToDevice(point);
                break;
            default:
                throw new UnsupportedOperationException("Unknown location specified: " + location);
        }
    }

    /**
     * Reports where the array's data is currently actual.
     *
     * Returns EVERYWHERE when both sides are actual and DEVICE when only the device side is.
     * Otherwise returns HOST.
     */
    @Override
    public Location getActiveLocation(INDArray array) {
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(array);

        if (point.isActualOnDeviceSide() && point.isActualOnHostSide()) {
            return Location.EVERYWHERE;
        } else if (point.isActualOnDeviceSide()) {
            return Location.DEVICE;
        } else {
            return Location.HOST;
        }
    }

    /**
     * Cross-device access is supported only if the native layer reports P2P availability AND
     * the configuration allows it.
     */
    @Override
    public boolean isCrossDeviceAccessSupported() {
        return NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable()
                        && CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed();
    }

    /**
     * Forwards the cross-device access preference to the CUDA configuration.
     */
    @Override
    public void allowCrossDeviceAccess(boolean reallyAllow) {
        CudaEnvironment.getInstance().getConfiguration().allowCrossDeviceAccess(reallyAllow);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy