All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.cpu.nativecpu.CpuTADManager Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.cpu.nativecpu;

import lombok.NonNull;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.IntBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.cache.TadDescriptor;
import org.nd4j.nativeblas.NativeOps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @author [email protected]
 */
public class CpuTADManager implements TADManager {
    private Map> cache = new ConcurrentHashMap<>();
    private NativeOps nativeOps;
    private ConstantHandler constantHandler;
    private static Logger logger = LoggerFactory.getLogger(CpuTADManager.class);

    public CpuTADManager() {
        //
    }

    public void init(@NonNull NativeOps nativeOps, @NonNull ConstantHandler constantHandler) {
        this.nativeOps = nativeOps;
        this.constantHandler = constantHandler;
    }

    /**
     * This method removes all cached shape buffers
     */
    @Override
    public void purgeBuffers() {
        cache = new ConcurrentHashMap<>();
    }

    @Override
    public Pair getTADOnlyShapeInfo(INDArray array, int[] dimension) {
        if (dimension == null || dimension[0] == Integer.MAX_VALUE) {
            return new Pair(array.shapeInfoDataBuffer(), null);
        } else {
            TadDescriptor descriptor = new TadDescriptor(array, dimension);

            if (!cache.containsKey(descriptor)) {
                int dimensionLength = dimension.length;

                int targetRank = array.rank(); ///Math.max(array.rank() - dimensionLength, 2);
                int offsetLength = 0;
                int tadLength = 1;
                for (int i = 0; i < dimensionLength; i++) {
                    tadLength *= array.shape()[dimension[i]];
                }

                offsetLength = array.length() / tadLength;

                DataBuffer outputBuffer = new IntBuffer(targetRank * 2 + 4);
                DataBuffer offsetsBuffer = new IntBuffer(offsetLength);

                DataBuffer dimensionBuffer = constantHandler.getConstantBuffer(dimension);
                Pointer dimensionPointer = dimensionBuffer.addressPointer();

                Pointer xShapeInfo = array.shapeInfoDataBuffer().addressPointer();
                Pointer targetPointer = outputBuffer.addressPointer();
                Pointer offsetsPointer = offsetsBuffer.addressPointer();

                nativeOps.tadOnlyShapeInfo((IntPointer)xShapeInfo, (IntPointer)dimensionPointer, dimension.length, (IntPointer)targetPointer, (IntPointer)offsetsPointer);

                Pair pair = new Pair(outputBuffer, offsetsBuffer);
                cache.put(descriptor, pair);
                return pair;
            }

            return cache.get(descriptor);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy