// org.nd4j.linalg.cpu.nativecpu.CpuTADManager (Maven / Gradle / Ivy)
package org.nd4j.linalg.cpu.nativecpu;
import lombok.NonNull;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.IntBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.cache.TadDescriptor;
import org.nd4j.nativeblas.NativeOps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* @author [email protected]
*/
public class CpuTADManager implements TADManager {
private Map> cache = new ConcurrentHashMap<>();
private NativeOps nativeOps;
private ConstantHandler constantHandler;
private static Logger logger = LoggerFactory.getLogger(CpuTADManager.class);
public CpuTADManager() {
//
}
public void init(@NonNull NativeOps nativeOps, @NonNull ConstantHandler constantHandler) {
this.nativeOps = nativeOps;
this.constantHandler = constantHandler;
}
/**
* This method removes all cached shape buffers
*/
@Override
public void purgeBuffers() {
cache = new ConcurrentHashMap<>();
}
@Override
public Pair getTADOnlyShapeInfo(INDArray array, int[] dimension) {
if (dimension == null || dimension[0] == Integer.MAX_VALUE) {
return new Pair(array.shapeInfoDataBuffer(), null);
} else {
TadDescriptor descriptor = new TadDescriptor(array, dimension);
if (!cache.containsKey(descriptor)) {
int dimensionLength = dimension.length;
int targetRank = array.rank(); ///Math.max(array.rank() - dimensionLength, 2);
int offsetLength = 0;
int tadLength = 1;
for (int i = 0; i < dimensionLength; i++) {
tadLength *= array.shape()[dimension[i]];
}
offsetLength = array.length() / tadLength;
DataBuffer outputBuffer = new IntBuffer(targetRank * 2 + 4);
DataBuffer offsetsBuffer = new IntBuffer(offsetLength);
DataBuffer dimensionBuffer = constantHandler.getConstantBuffer(dimension);
Pointer dimensionPointer = dimensionBuffer.addressPointer();
Pointer xShapeInfo = array.shapeInfoDataBuffer().addressPointer();
Pointer targetPointer = outputBuffer.addressPointer();
Pointer offsetsPointer = offsetsBuffer.addressPointer();
nativeOps.tadOnlyShapeInfo((IntPointer)xShapeInfo, (IntPointer)dimensionPointer, dimension.length, (IntPointer)targetPointer, (IntPointer)offsetsPointer);
Pair pair = new Pair(outputBuffer, offsetsBuffer);
cache.put(descriptor, pair);
return pair;
}
return cache.get(descriptor);
}
}
}