All downloads are FREE. Search and download functionality uses the official Maven repository.

org.nd4j.jita.allocator.tad.BasicTADManager Maven / Gradle / Ivy

The newest version!
package org.nd4j.jita.allocator.tad;

import org.nd4j.linalg.primitives.Pair;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Non-caching {@link TADManager} for the CUDA backend: every call to
 * {@link #getTADOnlyShapeInfo(INDArray, int[])} allocates and fills a fresh TAD shape-info buffer
 * and a matching TAD offsets buffer.
 *
 * @author [email protected]
 */
public class BasicTADManager implements TADManager {
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    private static final Logger logger = LoggerFactory.getLogger(BasicTADManager.class);

    /** Value reported by {@link #getCachedBytes()}; this class itself never increments it. */
    protected AtomicLong bytes = new AtomicLong(0);

    /**
     * Builds the TAD shape info and per-TAD offsets for {@code array} along {@code dimension}.
     *
     * <p>A {@code null} dimension, or the single dimension {@link Integer#MAX_VALUE}, is treated as
     * the "scalar" case. No native call is made; a hard-coded rank-2 shape-info is written on the
     * host, and a one-element offsets buffer is allocated (its value is not explicitly written here).
     * Otherwise the native {@code tadOnlyShapeInfo} fills both buffers.
     *
     * @param array     source array
     * @param dimension dimensions to build TADs along, or {@code null}. The caller's array is NOT
     *                  modified; a sorted copy is used internally.
     * @return pair of (TAD shape-info buffer, TAD offsets buffer)
     */
    @Override
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
        // Work on a copy: sorting in place would silently reorder the caller's array.
        int[] dims;
        if (dimension == null) {
            dims = new int[] {Integer.MAX_VALUE};
        } else {
            dims = dimension.clone();
            if (dims.length > 1) {
                Arrays.sort(dims);
            }
        }

        boolean isScalar = dims.length == 1 && dims[0] == Integer.MAX_VALUE;

        // FIXME: this is fast triage, remove it later
        int targetRank = isScalar ? 2 : array.rank();

        // tadLength = number of elements in one TAD; offsetLength = number of TADs.
        long tadLength = 1;
        long offsetLength;
        if (isScalar) {
            offsetLength = 1;
        } else {
            for (int dim : dims) {
                tadLength *= array.shape()[dim];
            }
            offsetLength = array.lengthLong() / tadLength;
        }

        // Shape-info length: rank entry, shape (targetRank), stride (targetRank), plus 3 trailing fields.
        DataBuffer outputBuffer = new CudaIntDataBuffer(targetRank * 2 + 4);
        DataBuffer offsetsBuffer = new CudaLongDataBuffer(offsetLength);

        AtomicAllocator.getInstance().getAllocationPoint(outputBuffer).tickHostWrite();
        AtomicAllocator.getInstance().getAllocationPoint(offsetsBuffer).tickHostWrite();

        DataBuffer dimensionBuffer = AtomicAllocator.getInstance().getConstantBuffer(dims);
        Pointer dimensionPointer = AtomicAllocator.getInstance().getHostPointer(dimensionBuffer);

        Pointer xShapeInfo = AddressRetriever.retrieveHostPointer(array.shapeInfoDataBuffer());
        Pointer targetPointer = AddressRetriever.retrieveHostPointer(outputBuffer);
        Pointer offsetsPointer = AddressRetriever.retrieveHostPointer(offsetsBuffer);

        if (!isScalar) {
            nativeOps.tadOnlyShapeInfo((IntPointer) xShapeInfo, (IntPointer) dimensionPointer, dims.length,
                    (IntPointer) targetPointer, new LongPointerWrapper(offsetsPointer));
        } else {
            // Hard-coded rank-2 shape info: rank=2, shape=[1, 1], stride=[1, 1], then 0, 0, 99.
            // NOTE(review): presumably the trailing fields are offset, element-wise stride and
            // order ('c' == 99) — confirm against the native shape-info layout.
            outputBuffer.put(0, 2);
            outputBuffer.put(1, 1);
            outputBuffer.put(2, 1);
            outputBuffer.put(3, 1);
            outputBuffer.put(4, 1);
            outputBuffer.put(5, 0);
            outputBuffer.put(6, 0);
            outputBuffer.put(7, 99);
        }

        // Both buffers were just written on the host side.
        AtomicAllocator.getInstance().getAllocationPoint(outputBuffer).tickHostWrite();
        AtomicAllocator.getInstance().getAllocationPoint(offsetsBuffer).tickHostWrite();

        return new Pair<>(outputBuffer, offsetsBuffer);
    }

    /**
     * This method removes all cached shape buffers. This implementation caches nothing, so it is a
     * no-op.
     */
    @Override
    public void purgeBuffers() {
        // no-op
    }

    /**
     * @return the current value of the {@code bytes} counter
     */
    @Override
    public long getCachedBytes() {
        return bytes.get();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy