// org.nd4j.linalg.jcublas.fft.JcudaFft Maven / Gradle / Ivy
package org.nd4j.linalg.jcublas.fft;
import jcuda.jcufft.JCufft;
import jcuda.jcufft.cufftHandle;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.fft.DefaultFFTInstance;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.fft.ops.JCudaVectorFFT;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * FFT instance that uses JCufft rather than the built-in FFTs.
 *
 * <p>cuFFT handles are created lazily by {@link #getHandle()} and cached per
 * thread name. Every cached handle is destroyed by a JVM shutdown hook that
 * the constructor registers.
 *
 * @author Adam Gibson
 */
public class JcudaFft extends DefaultFFTInstance {
    /**
     * Map of thread names to handles (one fft handle per thread name).
     *
     * NOTE(review): the key is {@link Thread#getName()}, and thread names are
     * not guaranteed to be unique, so two threads that share a name would
     * share a handle. Confirm whether that is acceptable for callers.
     */
    private final Map<String, cufftHandle> handles;

    /**
     * Creates the FFT instance with an empty handle cache, registers a JVM
     * shutdown hook that destroys every cached handle, and enables JCufft
     * exceptions.
     */
    public JcudaFft() {
        this.handles = new ConcurrentHashMap<>();
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
            @Override
            public void run() {
                // Release every handle that was created through getHandle().
                for (cufftHandle handle : handles.values()) {
                    JCufft.cufftDestroy(handle);
                }
            }
        }));
        JCufft.setExceptionsEnabled(true);
    }

    /**
     * Get the handle for the current thread, creating and caching it on
     * first use.
     *
     * @return the handle for the current thread
     */
    public cufftHandle getHandle() {
        // Read the name once so the lookup and the insert use the same key.
        final String threadName = Thread.currentThread().getName();
        cufftHandle handle = handles.get(threadName);
        if (handle == null) {
            handle = new cufftHandle();
            JCufft.cufftCreate(handle);
            handles.put(threadName, handle);
        }
        return handle;
    }

    /**
     * Supplies the JCuda vector FFT op for the given array.
     *
     * @param arr the array passed to the op
     * @param n   the size argument passed to the op
     * @return a new {@link JCudaVectorFFT} built from {@code arr} and {@code n}
     */
    @Override
    protected Op getFftOp(INDArray arr, int n) {
        return new JCudaVectorFFT(arr, n);
    }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy