All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.jita.constant.ProtectedCudaConstantHandler Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.jita.constant;

import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.AllocationsTracker;
import org.nd4j.linalg.api.memory.enums.AllocationKind;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.cache.ArrayDescriptor;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.*;
import org.nd4j.linalg.api.memory.MemcpyDirection;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import lombok.extern.slf4j.Slf4j;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Created by raver on 08.06.2016.
 */
@Slf4j
public class ProtectedCudaConstantHandler implements ConstantHandler {
    private static ProtectedCudaConstantHandler ourInstance = new ProtectedCudaConstantHandler();

    protected Map constantOffsets = new HashMap<>();
    protected Map deviceLocks = new ConcurrentHashMap<>();

    protected Map> buffersCache = new HashMap<>();
    protected Map deviceAddresses = new HashMap<>();
    protected AtomicLong bytes = new AtomicLong(0);
    protected FlowController flowController;

    protected static final ConstantProtector protector = ConstantProtector.getInstance();

    private static Logger logger = LoggerFactory.getLogger(ProtectedCudaConstantHandler.class);

    private static final int MAX_CONSTANT_LENGTH = 49152;
    private static final int MAX_BUFFER_LENGTH = 272;

    protected Semaphore lock = new Semaphore(1);
    private boolean resetHappened = false;


    public static ProtectedCudaConstantHandler getInstance() {
        return ourInstance;
    }

    private ProtectedCudaConstantHandler() {}

    /**
     * This method removes all cached constants
     */
    @Override
    public void purgeConstants() {
        buffersCache = new HashMap<>();

        protector.purgeProtector();

        resetHappened = true;
        logger.info("Resetting Constants...");

        for (Integer device : constantOffsets.keySet()) {
            constantOffsets.get(device).set(0);
            buffersCache.put(device, new ConcurrentHashMap());
        }
    }

    /**
     * Method suited for debug purposes only
     *
     * @return
     */
    protected int amountOfEntries(int deviceId) {
        ensureMaps(deviceId);
        return buffersCache.get(0).size();
    }

    /**
     * This method moves specified dataBuffer to CUDA constant memory space.
     *
     * PLEASE NOTE: CUDA constant memory is limited to 48KB per device.
     *
     * @param dataBuffer
     * @return
     */
    @Override
    public synchronized long moveToConstantSpace(DataBuffer dataBuffer) {
        if (1 > 0)
            throw new RuntimeException("This code shouldn't be called, ever");

        // now, we move things to constant memory
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        ensureMaps(deviceId);

        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);

        long requiredMemoryBytes = point.getNumberOfBytes();
        val originalBytes = requiredMemoryBytes;
        requiredMemoryBytes += 8 - (requiredMemoryBytes % 8);

        val div = requiredMemoryBytes / 4;
        if (div % 2 != 0)
            requiredMemoryBytes += 4;

        //logger.info("shape: " + point.getShape());
        // and release device memory :)

        AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, deviceId, requiredMemoryBytes);

        long currentOffset = constantOffsets.get(deviceId).get();
        val context = AtomicAllocator.getInstance().getDeviceContext();
        if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) {
            if (point.getAllocationStatus() == AllocationStatus.HOST
                            && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
                //AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
                throw new UnsupportedOperationException("Pew-pew");
            }

            val profD = PerformanceTracker.getInstance().helperStartTransaction();

            if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) {
                throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            flowController.commitTransfer(context.getSpecialStream());

            PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);

            point.setConstant(true);
            point.tickDeviceWrite();
            point.tickHostRead();
            point.setDeviceId(deviceId);

            protector.persistDataBuffer(dataBuffer);

            return 0;
        }

        long bytes = requiredMemoryBytes;
        currentOffset = constantOffsets.get(deviceId).getAndAdd(bytes);

        if (currentOffset >= MAX_CONSTANT_LENGTH) {
            if (point.getAllocationStatus() == AllocationStatus.HOST
                            && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
                //AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
                throw new UnsupportedOperationException("Pew-pew");
            }

            val profD = PerformanceTracker.getInstance().helperStartTransaction();

            if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) {
                throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            flowController.commitTransfer(context.getSpecialStream());

            PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);

            point.setConstant(true);
            point.tickDeviceWrite();
            point.tickHostRead();
            point.setDeviceId(deviceId);

            protector.persistDataBuffer(dataBuffer);

            return 0;
        }



        NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getHostPointer(), originalBytes, 1, context.getSpecialStream());
        flowController.commitTransfer(context.getSpecialStream());

        long cAddr = deviceAddresses.get(deviceId).address() + currentOffset;

        //if (resetHappened)
        //    logger.info("copying to constant: {}, bufferLength: {}, bufferDtype: {}, currentOffset: {}, currentAddres: {}", requiredMemoryBytes, dataBuffer.length(), dataBuffer.dataType(), currentOffset, cAddr);

        point.setAllocationStatus(AllocationStatus.CONSTANT);
        //point.setDevicePointer(new CudaPointer(cAddr));
        if (1 > 0)
            throw new UnsupportedOperationException("Pew-pew");

        point.setConstant(true);
        point.tickDeviceWrite();
        point.setDeviceId(deviceId);
        point.tickHostRead();


        protector.persistDataBuffer(dataBuffer);

        return cAddr;
    }

    /**
     * PLEASE NOTE: This method implementation is hardware-dependant.
     * PLEASE NOTE: This method does NOT allow concurrent use of any array
     *
     * @param dataBuffer
     * @return
     */
    @Override
    public DataBuffer relocateConstantSpace(DataBuffer dataBuffer) {
        // we always assume that data is sync, and valid on host side
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        ensureMaps(deviceId);

        if (dataBuffer instanceof CudaIntDataBuffer) {
            int[] data = dataBuffer.asInt();
            return getConstantBuffer(data, DataType.INT);
        } else if (dataBuffer instanceof CudaFloatDataBuffer) {
            float[] data = dataBuffer.asFloat();
            return getConstantBuffer(data, DataType.FLOAT);
        } else if (dataBuffer instanceof CudaDoubleDataBuffer) {
            double[] data = dataBuffer.asDouble();
            return getConstantBuffer(data, DataType.DOUBLE);
        } else if (dataBuffer instanceof CudaHalfDataBuffer) {
            float[] data = dataBuffer.asFloat();
            return getConstantBuffer(data, DataType.HALF);
        } else if (dataBuffer instanceof CudaLongDataBuffer) {
            long[] data = dataBuffer.asLong();
            return getConstantBuffer(data, DataType.LONG);
        }

        throw new IllegalStateException("Unknown CudaDataBuffer opType");
    }

    private void ensureMaps(Integer deviceId) {
        if (!buffersCache.containsKey(deviceId)) {
            if (flowController == null)
                flowController = AtomicAllocator.getInstance().getFlowController();

            try {
                synchronized (this) {
                    if (!buffersCache.containsKey(deviceId)) {

                        // TODO: this op call should be checked
                        //nativeOps.setDevice(new CudaPointer(deviceId));

                        buffersCache.put(deviceId, new ConcurrentHashMap());
                        constantOffsets.put(deviceId, new AtomicLong(0));
                        deviceLocks.put(deviceId, new Semaphore(1));

                        Pointer cAddr = NativeOpsHolder.getInstance().getDeviceNativeOps().getConstantSpace();
                        //                    logger.info("constant pointer: {}", cAddr.address() );

                        deviceAddresses.put(deviceId, cAddr);
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * This method returns DataBuffer with contant equal to input array.
     *
     * PLEASE NOTE: This method assumes that you'll never ever change values within result DataBuffer
     *
     * @param array
     * @return
     */
    @Override
    public DataBuffer getConstantBuffer(int[] array, DataType type) {
        return Nd4j.getExecutioner().createConstantBuffer(array, type);
    }

    /**
     * This method returns DataBuffer with contant equal to input array.
     *
     * PLEASE NOTE: This method assumes that you'll never ever change values within result DataBuffer
     *
     * @param array
     * @return
     */
    @Override
    public DataBuffer getConstantBuffer(float[] array, DataType type) {
        return Nd4j.getExecutioner().createConstantBuffer(array, type);
    }

    /**
     * This method returns DataBuffer with contant equal to input array.
     *
     * PLEASE NOTE: This method assumes that you'll never ever change values within result DataBuffer
     *
     * @param array
     * @return
     */
    @Override
    public DataBuffer getConstantBuffer(double[] array, DataType type) {
        return Nd4j.getExecutioner().createConstantBuffer(array, type);
        /*
        ArrayDescriptor descriptor = new ArrayDescriptor(array, type);

        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();

        ensureMaps(deviceId);

        if (!buffersCache.get(deviceId).containsKey(descriptor)) {
            // we create new databuffer
            //logger.info("Creating new constant buffer...");
            DataBuffer buffer = Nd4j.createTypedBufferDetached(array, type);

            if (constantOffsets.get(deviceId).get() + (array.length * Nd4j.sizeOfDataType()) < MAX_CONSTANT_LENGTH) {
                buffer.setConstant(true);
                // now we move data to constant memory, and keep happy
                moveToConstantSpace(buffer);

                buffersCache.get(deviceId).put(descriptor, buffer);

                bytes.addAndGet(array.length * Nd4j.sizeOfDataType());
            }
            return buffer;
        } //else logger.info("Reusing constant buffer...");

        return buffersCache.get(deviceId).get(descriptor);
         */
    }

    @Override
    public DataBuffer getConstantBuffer(long[] array, DataType type) {
        return Nd4j.getExecutioner().createConstantBuffer(array, type);
        /*
        //  logger.info("getConstantBuffer(int[]) called");
        ArrayDescriptor descriptor = new ArrayDescriptor(array, type);

        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();

        ensureMaps(deviceId);

        if (!buffersCache.get(deviceId).containsKey(descriptor)) {
            // we create new databuffer
            //logger.info("Creating new constant buffer...");
            DataBuffer buffer = Nd4j.createTypedBufferDetached(array, type);

            if (constantOffsets.get(deviceId).get() + (array.length * 8) < MAX_CONSTANT_LENGTH) {
                buffer.setConstant(true);
                // now we move data to constant memory, and keep happy
                moveToConstantSpace(buffer);

                buffersCache.get(deviceId).put(descriptor, buffer);

                bytes.addAndGet(array.length * 8);
            }
            return buffer;
        } //else logger.info("Reusing constant buffer...");

        return buffersCache.get(deviceId).get(descriptor);
         */
    }

    @Override
    public DataBuffer getConstantBuffer(boolean[] array, DataType dataType) {
        return getConstantBuffer(ArrayUtil.toLongs(array), dataType);
    }

    @Override
    public long getCachedBytes() {
        return bytes.get();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy