org.nd4j.linalg.cache.TADManager Maven / Gradle / Ivy
package org.nd4j.linalg.cache;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
 * This interface describes TAD caching.
 *
 * <p>While working with tensors, all operations happen along some dimensions, and since training tasks are
 * repetitive, TAD shapes and offsets can be pre-calculated once and then reused during the whole training
 * process.
 *
 * @author [email protected]
 */
public interface TADManager {
    /**
     * Returns the TAD shape info and all offsets for the specified tensor and dimensions.
     *
     * <p>The return type was previously the raw {@code Pair}; it is now parameterized with {@link DataBuffer}
     * (already imported by this file) so callers no longer need unchecked casts. Raw-typed callers and
     * implementations remain source-compatible.
     *
     * @param array     tensor for TAD pre-calculation
     * @param dimension dimensions of {@code array} the TAD shape info and offsets are computed for
     * @return pair of buffers holding the TAD shape info and the TAD offsets
     */
    Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension);

    /**
     * Removes all cached shape buffers.
     */
    void purgeBuffers();
}