// org.nd4j.linalg.jcublas.SimpleJCublas Maven / Gradle / Ivy
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg.jcublas;
import jcuda.*;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasOperation;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import jcuda.runtime.cudaError;
import jcuda.runtime.cudaMemcpyKind;
import jcuda.utils.KernelLauncher;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.CopyOp;
import org.nd4j.linalg.factory.DataTypeValidation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.kernel.KernelFunctionLoader;
import org.nd4j.linalg.jcublas.util.PointerUtil;
import org.nd4j.linalg.util.LinearUtil;
import javax.naming.Context;
/**
* Simple abstraction for jcublas operations
*
* @author mjk
* @author Adam Gibson
*/
public class SimpleJCublas {
// Set to true only after init() has run to completion on a device that can map host memory.
private static boolean init = false;
// Trigger the one-time setup as soon as the class is loaded.
static {
init();
}
/**
 * Assert that the data buffer backing each ndarray
 * is a cuda buffer.
 *
 * @param buffer the arrays to test
 * @throws IllegalArgumentException if any array is backed by a buffer
 *                                  that is not a {@link JCudaBuffer}
 */
public static void assertCudaBuffer(INDArray... buffer) {
    for (INDArray b1 : buffer) {
        DataBuffer data = b1.data();
        if (!(data instanceof JCudaBuffer)) {
            // Name the offending buffer's class; the class of the varargs array itself is useless to the caller.
            String type = data == null ? "null" : data.getClass().toString();
            throw new IllegalArgumentException("Unable to allocate pointer for buffer of type " + type);
        }
    }
}
/**
 * Assert that each data buffer is a cuda buffer.
 *
 * @param buffer the buffers to test
 * @throws IllegalArgumentException if any buffer is not a {@link JCudaBuffer}
 */
public static void assertCudaBuffer(DataBuffer... buffer) {
    for (DataBuffer b1 : buffer) {
        if (!(b1 instanceof JCudaBuffer)) {
            // Name the offending buffer's class; the class of the varargs array itself is useless to the caller.
            String type = b1 == null ? "null" : b1.getClass().toString();
            throw new IllegalArgumentException("Unable to allocate pointer for buffer of type " + type);
        }
    }
}
/**
 * One-time JCublas2 / JCuda setup.
 * <p>
 * Enables exceptions on the JCublas2, JCudaDriver and JCuda bindings, loads
 * the kernels through {@link KernelFunctionLoader} and checks that device 0
 * can map host memory. The "initialized" flag is only set when that check
 * passes; otherwise the device properties are printed to stderr and the
 * flag stays false. Calls made after a successful run return immediately.
 *
 * @throws RuntimeException wrapping any exception raised while loading the kernels
 */
public static void init() {
    if (!init) {
        JCublas2.setExceptionsEnabled(true);
        JCudaDriver.setExceptionsEnabled(true);
        JCuda.setExceptionsEnabled(true);

        try {
            KernelFunctionLoader.getInstance().load();
        } catch (Exception loadFailure) {
            throw new RuntimeException(loadFailure);
        }

        // Query device 0 to find out whether it supports mapped host memory.
        cudaDeviceProp props = new cudaDeviceProp();
        JCuda.cudaGetDeviceProperties(props, 0);

        boolean canMapHostMemory = props.canMapHostMemory != 0;
        if (canMapHostMemory) {
            init = true;
        } else {
            System.err.println("This device can not map host memory");
            System.err.println(props.toFormattedString());
        }
    }
}
/**
 * Block until outstanding device work has finished.
 * <p>
 * Synchronizes the device through JCuda first and then the stream
 * managed by {@link ContextHolder}.
 */
public static void sync() {
    // Device first, then the ContextHolder stream.
    JCuda.cudaDeviceSynchronize();
    ContextHolder.syncStream();
}
/**
 * General matrix vector multiplication (double precision).
 * <p>
 * Delegates to {@code cublasDgemv} without transposition, writing the
 * result into {@code C} and copying it back to the host.
 *
 * @param A     the matrix
 * @param B     the vector to multiply by
 * @param C     the vector that receives the result
 * @param alpha scalar applied to {@code A * B}
 * @param beta  scalar applied to the existing contents of {@code C}
 * @return {@code C}, holding the result
 */
public static INDArray gemv(INDArray A, INDArray B, INDArray C, double alpha, double beta) {
    DataTypeValidation.assertDouble(A, B, C);
    assertCudaBuffer(A.data(), B.data(), C.data());
    sync();

    CublasPointer cAPointer = null;
    CublasPointer cBPointer = null;
    CublasPointer cCPointer = null;
    try {
        // Arrays that are views with an offset are flattened before being handed to cuBLAS.
        cAPointer = new CublasPointer(A.offset() > 0 ? A.ravel() : A);
        cBPointer = new CublasPointer(B.offset() > 0 ? B.ravel() : B);
        cCPointer = new CublasPointer(C);

        JCublas2.cublasDgemv(
                ContextHolder.getInstance().getHandle(),
                cublasOperation.CUBLAS_OP_N,
                A.rows(),
                A.columns(),
                Pointer.to(new double[]{alpha}),
                cAPointer.getDevicePointer(),
                A.rows(),
                cBPointer.getDevicePointer(),
                B.majorStride(),
                Pointer.to(new double[]{beta}),
                cCPointer.getDevicePointer(),
                C.majorStride());

        // Wait for the kernel before reading the result back, as the float overload does.
        sync();
        cCPointer.copyToHost();
    } finally {
        // Release device memory even when the cuBLAS call throws.
        releaseCublasPointers(cAPointer, cBPointer, cCPointer);
    }
    return C;
}
/**
 * General matrix vector multiplication (single precision).
 * <p>
 * Delegates to {@code cublasSgemv} without transposition, writing the
 * result into {@code C} and copying it back to the host.
 *
 * @param A     the matrix
 * @param B     the vector to multiply by
 * @param C     the vector that receives the result
 * @param alpha scalar applied to {@code A * B}
 * @param beta  scalar applied to the existing contents of {@code C}
 * @return {@code C}, holding the result
 */
public static INDArray gemv(INDArray A, INDArray B, INDArray C, float alpha, float beta) {
    DataTypeValidation.assertFloat(A, B, C);

    // Views with an offset are flattened before upload.
    INDArray matrix = A.offset() > 0 ? A.ravel() : A;
    CublasPointer matrixPtr = new CublasPointer(matrix);
    INDArray vector = B.offset() > 0 ? B.ravel() : B;
    CublasPointer vectorPtr = new CublasPointer(vector);
    CublasPointer resultPtr = new CublasPointer(C);

    sync();

    Pointer alphaPtr = Pointer.to(new float[]{alpha});
    Pointer betaPtr = Pointer.to(new float[]{beta});
    JCublas2.cublasSgemv(
            ContextHolder.getInstance().getHandle(),
            cublasOperation.CUBLAS_OP_N,
            A.rows(),                       // m
            A.columns(),                    // n
            alphaPtr,
            matrixPtr.getDevicePointer(),
            A.size(0),                      // lda
            vectorPtr.getDevicePointer(),
            B.majorStride(),                // incx
            betaPtr,
            resultPtr.getDevicePointer(),
            C.majorStride());               // incy

    sync();
    resultPtr.copyToHost();
    releaseCublasPointers(resultPtr, matrixPtr, vectorPtr);
    return C;
}
/**
 * General matrix vector multiplication (double precision complex).
 * <p>
 * Delegates to {@code cublasZgemv} without transposition, writing the
 * result into {@code C} and copying it back to the host.
 *
 * @param A the matrix
 * @param B the vector to multiply by
 * @param a the complex scalar alpha applied to {@code A * B}
 * @param C the vector that receives the result
 * @param b the complex scalar beta applied to the existing contents of {@code C}
 * @return {@code C}, holding the result
 */
public static IComplexNDArray gemv(IComplexNDArray A, IComplexNDArray B, IComplexDouble a, IComplexNDArray C
        , IComplexDouble b) {
    DataTypeValidation.assertSameDataType(A, B, C);
    sync();

    CublasPointer cAPointer = null;
    CublasPointer cBPointer = null;
    CublasPointer cCPointer = null;
    try {
        cAPointer = new CublasPointer(A.ravel());
        cBPointer = new CublasPointer(B);
        cCPointer = new CublasPointer(C);

        // Both components of alpha come from a; beta comes entirely from b.
        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(a.realComponent().doubleValue(), a.imaginaryComponent().doubleValue());
        cuDoubleComplex beta = cuDoubleComplex.cuCmplx(b.realComponent().doubleValue(), b.imaginaryComponent().doubleValue());

        JCublas2.cublasZgemv(
                ContextHolder.getInstance().getHandle(),
                cublasOperation.CUBLAS_OP_N, // trans
                A.rows(),                    // m
                A.columns(),                 // n: column count, matching the other gemv overloads
                PointerUtil.getPointer(alpha),
                cAPointer.getDevicePointer(), // A
                A.size(0),                    // lda
                cBPointer.getDevicePointer(), // x
                B.majorStride() / 2,          // incx; halved, presumably because strides count real components (TODO confirm)
                PointerUtil.getPointer(beta), // beta
                cCPointer.getDevicePointer(), // y
                C.majorStride() / 2);         // incy

        sync();
        cCPointer.copyToHost();
    } finally {
        // Release device memory even when the cuBLAS call throws.
        releaseCublasPointers(cAPointer, cBPointer, cCPointer);
    }
    return C;
}
/**
 * General matrix vector multiplication (single precision complex).
 * <p>
 * Delegates to {@code cublasCgemv} without transposition, writing the
 * result into {@code C} and copying it back to the host.
 *
 * @param A the matrix
 * @param B the vector to multiply by
 * @param a the complex scalar alpha applied to {@code A * B}
 * @param C the vector that receives the result
 * @param b the complex scalar beta applied to the existing contents of {@code C}
 * @return {@code C}, holding the result
 * @throws RuntimeException wrapping any exception raised by the cuBLAS call
 *                          or while closing the device pointers
 */
public static IComplexNDArray gemv(IComplexNDArray A, IComplexNDArray B, IComplexFloat a, IComplexNDArray C
        , IComplexFloat b) {
    DataTypeValidation.assertFloat(A, B, C);
    assertCudaBuffer(A, B, C);
    sync();
    // try-with-resources closes the three pointers exactly once; no explicit release is needed.
    try (
            CublasPointer cAPointer = new CublasPointer(A.offset() > 0 ? A.ravel() : A);
            CublasPointer cBPointer = new CublasPointer(B.offset() > 0 ? B.ravel() : B);
            CublasPointer cCPointer = new CublasPointer(C)) {
        // Both components of alpha come from a; beta comes entirely from b.
        cuComplex alpha = cuComplex.cuCmplx(a.realComponent().floatValue(), a.imaginaryComponent().floatValue());
        cuComplex beta = cuComplex.cuCmplx(b.realComponent().floatValue(), b.imaginaryComponent().floatValue());

        JCublas2.cublasCgemv(
                ContextHolder.getInstance().getHandle(),
                cublasOperation.CUBLAS_OP_N, // trans
                A.rows(),                    // m
                A.columns(),                 // n
                PointerUtil.getPointer(alpha),
                cAPointer.getDevicePointer(), // A
                A.size(0),                    // lda
                cBPointer.getDevicePointer(), // x
                B.majorStride() / 2,          // incx
                PointerUtil.getPointer(beta), // beta
                cCPointer.getDevicePointer(), // y
                C.majorStride() / 2);         // incy

        sync();
        cCPointer.copyToHost();
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    return C;
}
/**
 * General matrix multiply (double precision complex).
 * <p>
 * Delegates to {@code cublasZgemm} without transposition, writing the
 * result into {@code C} and copying it back to the host.
 *
 * @param A the left matrix
 * @param B the right matrix
 * @param a the complex scalar alpha applied to {@code A * B}
 * @param C the matrix that receives the result
 * @param b the complex scalar beta applied to the existing contents of {@code C}
 * @return {@code C}, holding the result
 */
public static IComplexNDArray gemm(IComplexNDArray A, IComplexNDArray B, IComplexDouble a, IComplexNDArray C
        , IComplexDouble b) {
    DataTypeValidation.assertSameDataType(A, B, C);
    sync();

    CublasPointer cAPointer = null;
    CublasPointer cBPointer = null;
    CublasPointer cCPointer = null;
    try {
        // Views with an offset are flattened before upload.
        cAPointer = new CublasPointer(A.offset() > 0 ? A.ravel() : A);
        cBPointer = new CublasPointer(B.offset() > 0 ? B.ravel() : B);
        cCPointer = new CublasPointer(C);

        // Both components of alpha come from a; beta comes entirely from b.
        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(a.realComponent().doubleValue(), a.imaginaryComponent().doubleValue());
        cuDoubleComplex beta = cuDoubleComplex.cuCmplx(b.realComponent().doubleValue(), b.imaginaryComponent().doubleValue());

        JCublas2.cublasZgemm(
                ContextHolder.getInstance().getHandle(),
                cublasOperation.CUBLAS_OP_N, // transa
                cublasOperation.CUBLAS_OP_N, // transb
                C.rows(),                    // m
                C.columns(),                 // n
                A.columns(),                 // k
                PointerUtil.getPointer(alpha),
                cAPointer.getDevicePointer(), // A
                A.size(0),                    // lda
                cBPointer.getDevicePointer(), // B
                B.size(0),                    // ldb
                PointerUtil.getPointer(beta), // beta
                cCPointer.getDevicePointer(), // C
                C.size(0));                   // ldc

        sync();
        cCPointer.copyToHost();
    } finally {
        // Release device memory even when the cuBLAS call throws.
        releaseCublasPointers(cAPointer, cBPointer, cCPointer);
    }
    return C;
}
/**
 * General matrix multiply (single precision complex).
 * <p>
 * Delegates to {@code cublasCgemm} without transposition, writing the
 * result into {@code C} and copying it back to the host.
 *
 * @param A the left matrix
 * @param B the right matrix
 * @param a the complex scalar alpha applied to {@code A * B}
 * @param C the matrix that receives the result
 * @param b the complex scalar beta applied to the existing contents of {@code C}
 * @return {@code C}, holding the result
 */
public static IComplexNDArray gemm(IComplexNDArray A, IComplexNDArray B, IComplexFloat a, IComplexNDArray C
        , IComplexFloat b) {
    DataTypeValidation.assertFloat(A, B, C);
    sync();

    // Both components of alpha come from a; beta comes entirely from b.
    cuComplex alpha = cuComplex.cuCmplx(a.realComponent().floatValue(), a.imaginaryComponent().floatValue());
    cuComplex beta = cuComplex.cuCmplx(b.realComponent().floatValue(), b.imaginaryComponent().floatValue());

    CublasPointer cAPointer = null;
    CublasPointer cBPointer = null;
    CublasPointer cCPointer = null;
    try {
        // custom striding for blas doesn't work, so an offset view of A is flattened first
        cAPointer = new CublasPointer(A.offset() > 0 ? A.ravel() : A);
        cBPointer = new CublasPointer(B);
        cCPointer = new CublasPointer(C);

        JCublas2.cublasCgemm(
                ContextHolder.getInstance().getHandle(),
                cublasOperation.CUBLAS_OP_N, // transa
                cublasOperation.CUBLAS_OP_N, // transb
                C.rows(),                    // m
                C.columns(),                 // n
                A.columns(),                 // k
                PointerUtil.getPointer(alpha),
                cAPointer.getDevicePointer(), // A
                A.rows(),                     // lda
                cBPointer.getDevicePointer(), // B
                B.rows(),                     // ldb
                PointerUtil.getPointer(beta), // beta
                cCPointer.getDevicePointer(), // C
                C.rows());                    // ldc

        sync();
        cCPointer.copyToHost();
    } finally {
        // Release device memory even when the cuBLAS call throws.
        releaseCublasPointers(cAPointer, cBPointer, cCPointer);
    }
    return C;
}
/**
 * General matrix multiply (double precision).
 * <p>
 * Delegates to {@code cublasDgemm} without transposition, writing the
 * result into {@code C} and copying it back to the host.
 *
 * @param A     the left matrix
 * @param B     the right matrix
 * @param C     the matrix that receives the result
 * @param alpha scalar applied to {@code A * B}
 * @param beta  scalar applied to the existing contents of {@code C}
 * @return {@code C}, holding the result
 */
public static INDArray gemm(INDArray A, INDArray B, INDArray C,
                            double alpha, double beta) {
    DataTypeValidation.assertDouble(A, B, C);
    sync();

    CublasPointer cAPointer = null;
    CublasPointer cBPointer = null;
    CublasPointer cCPointer = null;
    try {
        cAPointer = new CublasPointer(A);
        cBPointer = new CublasPointer(B);
        cCPointer = new CublasPointer(C);

        JCublas2.cublasDgemm(
                ContextHolder.getInstance().getHandle(),
                cublasOperation.CUBLAS_OP_N, // transa
                cublasOperation.CUBLAS_OP_N, // transb
                C.rows(),                    // m
                C.columns(),                 // n
                A.columns(),                 // k
                Pointer.to(new double[]{alpha}),
                cAPointer.getDevicePointer(), // A
                A.rows(),                     // lda
                cBPointer.getDevicePointer(), // B
                B.rows(),                     // ldb
                Pointer.to(new double[]{beta}),
                cCPointer.getDevicePointer(), // C
                C.rows());                    // ldc

        sync();
        cCPointer.copyToHost();
    } finally {
        // Release device memory even when the cuBLAS call throws.
        releaseCublasPointers(cAPointer, cBPointer, cCPointer);
    }
    return C;
}
/**
 * General matrix multiply (single precision).
 * <p>
 * Delegates to {@code cublasSgemm} without transposition, writing the
 * result into {@code C} and copying it back to the host.
 *
 * @param A     the left matrix
 * @param B     the right matrix
 * @param C     the matrix that receives the result
 * @param alpha scalar applied to {@code A * B}
 * @param beta  scalar applied to the existing contents of {@code C}
 * @return {@code C}, holding the result
 */
public static INDArray gemm(INDArray A, INDArray B, INDArray C,
                            float alpha, float beta) {
    DataTypeValidation.assertFloat(A, B, C);
    sync();

    CublasPointer leftPtr = new CublasPointer(A);
    CublasPointer rightPtr = new CublasPointer(B);
    CublasPointer outPtr = new CublasPointer(C);

    Pointer alphaPtr = Pointer.to(new float[]{alpha});
    Pointer betaPtr = Pointer.to(new float[]{beta});

    JCublas2.cublasSgemm(
            ContextHolder.getInstance().getHandle(),
            cublasOperation.CUBLAS_OP_N,  // transa
            cublasOperation.CUBLAS_OP_N,  // transb
            C.rows(),                     // m
            C.columns(),                  // n
            A.columns(),                  // k
            alphaPtr,
            leftPtr.getDevicePointer(),   // A
            A.rows(),                     // lda
            rightPtr.getDevicePointer(),  // B
            B.rows(),                     // ldb
            betaPtr,
            outPtr.getDevicePointer(),    // C
            C.rows());                    // ldc

    sync();
    outPtr.copyToHost();
    releaseCublasPointers(leftPtr, rightPtr, outPtr);
    return C;
}
/**
 * Calculate the 2 norm of the ndarray.
 * Note that this is a stand-in for a true complex norm: the array is
 * handed to the real-valued {@code cublasSnrm2} / {@code cublasDnrm2}
 * routine with a stride of 2.
 *
 * @param A the ndarray to calculate the norm2 of
 * @return the norm2 reported by cuBLAS
 */
public static double nrm2(IComplexNDArray A) {
    sync();
    CublasPointer cAPointer = null;
    try {
        cAPointer = new CublasPointer(A);
        if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            float[] ret = new float[1];
            Pointer result = Pointer.to(ret);
            JCublas2.cublasSnrm2(
                    ContextHolder.getInstance().getHandle(),
                    A.length(),
                    cAPointer.getDevicePointer(),
                    2,
                    result);
            return ret[0];
        } else {
            double[] ret = new double[1];
            Pointer result = Pointer.to(ret);
            JCublas2.cublasDnrm2(
                    ContextHolder.getInstance().getHandle(),
                    A.length(),
                    cAPointer.getDevicePointer(),
                    2,
                    result);
            return ret[0];
        }
    } finally {
        // The device copy is only needed for the duration of the call.
        releaseCublasPointers(cAPointer);
    }
}
/**
 * Copy the contents of x into y by running a {@link CopyOp}
 * on the current op executioner.
 *
 * @param x the origin
 * @param y the destination
 */
public static void copy(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    // y is passed both as the second operand and as the result array.
    CopyOp copyOp = new CopyOp(x, y, y, x.length());
    Nd4j.getExecutioner().exec(copyOp);
}
/**
 * Return the index of the max in the given complex ndarray.
 * <p>
 * Uses {@code cublasIcamax} for float data and {@code cublasIzamax}
 * otherwise. cuBLAS reports a 1-based index, which is converted to a
 * 0-based one here.
 *
 * @param x the ndarray to get the max for
 * @return the 0-based max index of the given ndarray
 */
public static int iamax(IComplexNDArray x) {
    sync();
    CublasPointer xCPointer = null;
    try {
        xCPointer = new CublasPointer(x);
        // cuBLAS writes the index as a single int, so the host result buffer must be an int array.
        int[] result = new int[1];
        Pointer resultPointer = Pointer.to(result);
        if (x.data().dataType() == DataBuffer.Type.FLOAT) {
            JCublas2.cublasIcamax(ContextHolder.getInstance().getHandle(), x.length(), xCPointer.getDevicePointer(), 1, resultPointer);
        } else {
            JCublas2.cublasIzamax(ContextHolder.getInstance().getHandle(), x.length(), xCPointer.getDevicePointer(), 1, resultPointer);
        }
        sync();
        return result[0] - 1;
    } finally {
        // The device copy is only needed for the duration of the call.
        releaseCublasPointers(xCPointer);
    }
}
/**
 * Sum of absolute values of a complex ndarray, as computed by
 * {@code cublasScasum} (float data) or {@code cublasDzasum} (other data).
 *
 * @param x the ndarray to sum
 * @return the sum reported by cuBLAS, narrowed to float for double data
 */
public static float asum(IComplexNDArray x) {
    CublasPointer xCPointer = null;
    try {
        xCPointer = new CublasPointer(x);
        if (x.data().dataType() == DataBuffer.Type.FLOAT) {
            float[] ret = new float[1];
            Pointer result = Pointer.to(ret);
            JCublas2.cublasScasum(ContextHolder.getInstance().getHandle(), x.length(), xCPointer.getDevicePointer(), 1, result);
            return ret[0];
        }
        // Double precision data needs the double-complex routine and a double result buffer.
        double[] ret = new double[1];
        Pointer result = Pointer.to(ret);
        JCublas2.cublasDzasum(ContextHolder.getInstance().getHandle(), x.length(), xCPointer.getDevicePointer(), 1, result);
        return (float) ret[0];
    } finally {
        // The device copy is only needed for the duration of the call.
        releaseCublasPointers(xCPointer);
    }
}
/**
 * Swap the elements in each ndarray.
 * <p>
 * Uses {@code cublasSswap} for float data and {@code cublasDswap}
 * otherwise, then copies both arrays back to the host.
 *
 * @param x the first ndarray
 * @param y the second ndarray
 */
public static void swap(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    CublasPointer xCPointer = null;
    CublasPointer yCPointer = null;
    try {
        xCPointer = new CublasPointer(x);
        yCPointer = new CublasPointer(y);
        sync();
        if (x.data().dataType() == DataBuffer.Type.FLOAT) {
            JCublas2.cublasSswap(
                    ContextHolder.getInstance().getHandle(),
                    x.length(),
                    xCPointer.getDevicePointer(),
                    1,
                    yCPointer.getDevicePointer(),
                    1);
        } else {
            JCublas2.cublasDswap(
                    ContextHolder.getInstance().getHandle(),
                    x.length(),
                    xCPointer.getDevicePointer(),
                    1,
                    yCPointer.getDevicePointer(),
                    1);
        }
        sync();
        // Both operands were modified on the device, so both must be copied back,
        // mirroring how the other mutating operations in this class publish results.
        xCPointer.copyToHost();
        yCPointer.copyToHost();
    } finally {
        releaseCublasPointers(xCPointer, yCPointer);
    }
}
/**
 * Sum of absolute values of the given ndarray, as computed by
 * {@code cublasSasum} (float data) or {@code cublasDasum} (other data).
 *
 * @param x the ndarray to sum
 * @return the sum reported by cuBLAS
 */
public static double asum(INDArray x) {
    CublasPointer xCPointer = null;
    try {
        xCPointer = new CublasPointer(x);
        Pointer result;
        if (x.data().dataType() == DataBuffer.Type.FLOAT) {
            float[] ret = new float[1];
            result = Pointer.to(ret);
            JCublas2.cublasSasum(ContextHolder.getInstance().getHandle(), x.length(), xCPointer.getDevicePointer(), 1, result);
            return ret[0];
        } else {
            double[] ret = new double[1];
            result = Pointer.to(ret);
            JCublas2.cublasDasum(ContextHolder.getInstance().getHandle(), x.length(), xCPointer.getDevicePointer(), 1, result);
            return ret[0];
        }
    } finally {
        // The device copy is only needed for the duration of the call.
        releaseCublasPointers(xCPointer);
    }
}
/**
 * Returns the norm2 of the given ndarray, computed by
 * {@code cublasSnrm2} or {@code cublasDnrm2} depending on the data type.
 *
 * @param x the ndarray to calculate the norm2 of
 * @return the norm2 of {@code x}
 * @throws IllegalStateException if the array is neither float nor double
 */
public static double nrm2(INDArray x) {
    DataBuffer.Type type = x.data().dataType();
    if (type != DataBuffer.Type.FLOAT && type != DataBuffer.Type.DOUBLE)
        throw new IllegalStateException("Illegal data type on array ");

    CublasPointer xCPointer = null;
    try {
        xCPointer = new CublasPointer(x);
        if (type == DataBuffer.Type.FLOAT) {
            float[] ret = new float[1];
            JCublas2.cublasSnrm2(ContextHolder.getInstance().getHandle(), x.length(), xCPointer.getDevicePointer(), 1, Pointer.to(ret));
            return ret[0];
        }
        double[] ret = new double[1];
        // The JCublas2 call returns a status code; the norm itself is written to the result buffer.
        JCublas2.cublasDnrm2(ContextHolder.getInstance().getHandle(), x.length(), xCPointer.getDevicePointer(), 1, Pointer.to(ret));
        return ret[0];
    } finally {
        // The device copy is only needed for the duration of the call.
        releaseCublasPointers(xCPointer);
    }
}
/**
 * Returns the index of the max element
 * in the given ndarray.
 * <p>
 * Uses {@code cublasIsamax} for float data and {@code cublasIdamax} for
 * double data. cuBLAS reports a 1-based index, which is converted to a
 * 0-based one here.
 *
 * @param x the ndarray to search
 * @return the 0-based index of the max element
 * @throws IllegalStateException if the array is neither float nor double
 */
public static int iamax(INDArray x) {
    DataBuffer.Type type = x.data().dataType();
    if (type != DataBuffer.Type.FLOAT && type != DataBuffer.Type.DOUBLE)
        throw new IllegalStateException("Illegal data type on array ");

    CublasPointer xCPointer = null;
    try {
        xCPointer = new CublasPointer(x);
        // cuBLAS writes the index as a single int, so the host result buffer must be an int array.
        int[] ret = new int[1];
        Pointer result = Pointer.to(ret);
        sync();
        if (type == DataBuffer.Type.FLOAT) {
            JCublas2.cublasIsamax(
                    ContextHolder.getInstance().getHandle(),
                    x.length(), // element count, not byte count
                    xCPointer.getDevicePointer(),
                    1, result);
        } else {
            JCublas2.cublasIdamax(
                    ContextHolder.getInstance().getHandle(),
                    x.length(),
                    xCPointer.getDevicePointer(),
                    1, result);
        }
        sync();
        return ret[0] - 1;
    } finally {
        // The device copy is only needed for the duration of the call.
        releaseCublasPointers(xCPointer);
    }
}
/**
 * Add and scale by the given scalar da: delegates to
 * {@code cublasSaxpy} and copies the updated {@code B} back to the host.
 *
 * @param da alpha
 * @param A  the element to add
 * @param B  the matrix to add to
 */
public static void axpy(float da, INDArray A, INDArray B) {
    DataTypeValidation.assertFloat(A, B);

    CublasPointer srcPtr = new CublasPointer(A);
    CublasPointer dstPtr = new CublasPointer(B);
    sync();

    // Element strides are derived from the arrays rather than assumed to be 1.
    int srcStride = LinearUtil.linearStride(A);
    int dstStride = LinearUtil.linearStride(B);
    Pointer scale = Pointer.to(new float[]{da});

    JCublas2.cublasSaxpy(
            ContextHolder.getInstance().getHandle(),
            A.length(),
            scale,
            srcPtr.getDevicePointer(),
            srcStride,
            dstPtr.getDevicePointer(),
            dstStride);

    sync();
    dstPtr.copyToHost();
    releaseCublasPointers(srcPtr, dstPtr);
}
/**
 * Complex (single precision) axpy: delegates to {@code cublasCaxpy}
 * and copies the updated {@code B} back to the host.
 *
 * @param da the complex scalar to scale {@code A} by
 * @param A  the array to add
 * @param B  the array to add to; receives the result
 */
public static void axpy(IComplexFloat da, IComplexNDArray A, IComplexNDArray B) {
    DataTypeValidation.assertFloat(A, B);
    CublasPointer aCPointer = null;
    CublasPointer bCPointer = null;
    try {
        aCPointer = new CublasPointer(A);
        bCPointer = new CublasPointer(B);
        sync();
        JCublas2.cublasCaxpy(
                ContextHolder.getInstance().getHandle(),
                A.length(),
                PointerUtil.getPointer(jcuda.cuComplex.cuCmplx(da.realComponent().floatValue(), da.imaginaryComponent().floatValue())),
                aCPointer.getDevicePointer(),
                A.majorStride() / 2,
                bCPointer.getDevicePointer(),
                B.majorStride() / 2
        );
        sync();
        // Publish the result to the host, as the real-valued axpy does.
        bCPointer.copyToHost();
    } finally {
        releaseCublasPointers(aCPointer, bCPointer);
    }
}
/**
 * Complex (double precision) axpy: delegates to {@code cublasZaxpy}
 * and copies the updated {@code B} back to the host.
 *
 * @param da the complex scalar to scale {@code A} by
 * @param A  the array to add
 * @param B  the array to add to; receives the result
 */
public static void axpy(IComplexDouble da, IComplexNDArray A, IComplexNDArray B) {
    DataTypeValidation.assertDouble(A, B);
    CublasPointer aCPointer = null;
    CublasPointer bCPointer = null;
    try {
        aCPointer = new CublasPointer(A);
        bCPointer = new CublasPointer(B);
        sync();
        JCublas2.cublasZaxpy(
                ContextHolder.getInstance().getHandle(),
                A.length(),
                // Keep full double precision for alpha instead of narrowing through float.
                PointerUtil.getPointer(jcuda.cuDoubleComplex.cuCmplx(da.realComponent().doubleValue(), da.imaginaryComponent().doubleValue())),
                aCPointer.getDevicePointer(),
                A.majorStride() / 2, // halved like the float overload; NOTE(review): assumes strides count real components - confirm
                bCPointer.getDevicePointer(),
                B.majorStride() / 2
        );
        sync();
        // Publish the result to the host, as the real-valued axpy does.
        bCPointer.copyToHost();
    } finally {
        releaseCublasPointers(aCPointer, bCPointer);
    }
}
/**
 * Scale the given double ndarray in place by alpha using
 * {@code cublasDscal}, then copy the result back to the host.
 *
 * @param alpha the scalar to multiply by
 * @param x     the ndarray to scale
 * @return {@code x}, scaled
 */
public static INDArray scal(double alpha, INDArray x) {
    DataTypeValidation.assertDouble(x);
    sync();

    CublasPointer devX = new CublasPointer(x);
    Pointer factor = Pointer.to(new double[]{alpha});
    JCublas2.cublasDscal(
            ContextHolder.getInstance().getHandle(),
            x.length(),
            factor,
            devX.getDevicePointer(),
            x.majorStride());

    sync();
    devX.copyToHost();
    releaseCublasPointers(devX);
    return x;
}
/**
 * Scale the given float ndarray in place by alpha using
 * {@code cublasSscal}, then copy the result back to the host.
 *
 * @param alpha the scalar to multiply by
 * @param x     the ndarray to scale
 * @return {@code x}, scaled
 */
public static INDArray scal(float alpha, INDArray x) {
    DataTypeValidation.assertFloat(x);
    sync();

    CublasPointer devX = new CublasPointer(x);
    Pointer factor = Pointer.to(new float[]{alpha});
    JCublas2.cublasSscal(
            ContextHolder.getInstance().getHandle(),
            x.length(),
            factor,
            devX.getDevicePointer(),
            x.majorStride());

    sync();
    devX.copyToHost();
    releaseCublasPointers(devX);
    return x;
}
/**
 * Copy x to y on the device with {@code cublasDcopy} / {@code cublasScopy}
 * (chosen by data type) and copy y back to the host. For any other
 * data type no cuBLAS copy is issued.
 *
 * @param x the src
 * @param y the destination
 */
public static void copy(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    sync();

    CublasPointer srcPtr = new CublasPointer(x);
    CublasPointer dstPtr = new CublasPointer(y);

    DataBuffer.Type type = x.data().dataType();
    if (type == DataBuffer.Type.DOUBLE) {
        JCublas2.cublasDcopy(
                ContextHolder.getInstance().getHandle(),
                x.length(),
                srcPtr.getDevicePointer(),
                x.majorStride(),
                dstPtr.getDevicePointer(),
                y.majorStride());
    } else if (type == DataBuffer.Type.FLOAT) {
        JCublas2.cublasScopy(
                ContextHolder.getInstance().getHandle(),
                x.length(),
                srcPtr.getDevicePointer(),
                x.majorStride(),
                dstPtr.getDevicePointer(),
                y.majorStride());
    }

    sync();
    dstPtr.copyToHost();
    releaseCublasPointers(dstPtr, srcPtr);
}
/**
 * Dot product between 2 ndarrays, computed by {@code cublasSdot}
 * for float data and {@code cublasDdot} otherwise.
 *
 * @param x the first ndarray
 * @param y the second ndarray
 * @return the dot product between the two ndarrays
 */
public static double dot(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    sync();

    CublasPointer lhsPtr = new CublasPointer(x);
    CublasPointer rhsPtr = new CublasPointer(y);

    boolean singlePrecision = x.data().dataType() == (DataBuffer.Type.FLOAT);
    float[] floatOut = new float[1];
    double[] doubleOut = new double[1];

    if (singlePrecision) {
        JCublas2.cublasSdot(
                ContextHolder.getInstance().getHandle(),
                x.length(),
                lhsPtr.getDevicePointer(),
                1,
                rhsPtr.getDevicePointer(),
                1,
                Pointer.to(floatOut));
    } else {
        JCublas2.cublasDdot(
                ContextHolder.getInstance().getHandle(),
                x.length(),
                lhsPtr.getDevicePointer(),
                1,
                rhsPtr.getDevicePointer(),
                1,
                Pointer.to(doubleOut));
    }

    // The result buffer is only read after the sync and release.
    sync();
    releaseCublasPointers(lhsPtr, rhsPtr);
    return singlePrecision ? floatOut[0] : doubleOut[0];
}
/**
 * Close every non-null pointer in order. The first failure aborts the
 * loop and is rethrown wrapped in a {@link RuntimeException}.
 *
 * @param pointers the pointers to close; null entries are skipped
 */
private static void releaseCublasPointers(CublasPointer... pointers) {
    for (int i = 0; i < pointers.length; i++) {
        CublasPointer current = pointers[i];
        if (current == null) {
            continue;
        }
        try {
            current.close();
        } catch (Exception closeFailure) {
            throw new RuntimeException("Could not run cublas command", closeFailure);
        }
    }
}
/**
 * Dot product between two complex ndarrays, computed by
 * {@code cublasCdotc} for float data and {@code cublasZdotc} otherwise.
 *
 * @param x the first ndarray
 * @param y the second ndarray
 * @return the dot product reported by cuBLAS
 */
public static IComplexDouble dot(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    sync();

    CublasPointer aCPointer = null;
    CublasPointer bCPointer = null;
    try {
        aCPointer = new CublasPointer(x);
        bCPointer = new CublasPointer(y);

        // The result is received in a plain {real, imaginary} host array so that the
        // values read back are the ones cuBLAS actually wrote.
        if (x.data().dataType() == DataBuffer.Type.FLOAT) {
            float[] result = new float[2];
            JCublas2.cublasCdotc(
                    ContextHolder.getInstance().getHandle(),
                    x.length(),
                    aCPointer.getDevicePointer(),
                    1,
                    bCPointer.getDevicePointer(),
                    1, Pointer.to(result));
            sync();
            return Nd4j.createDouble(result[0], result[1]);
        }

        double[] result = new double[2];
        JCublas2.cublasZdotc(
                ContextHolder.getInstance().getHandle(),
                x.length(),
                aCPointer.getDevicePointer(),
                1,
                bCPointer.getDevicePointer(),
                1, Pointer.to(result));
        sync();
        return Nd4j.createDouble(result[0], result[1]);
    } finally {
        // Release device memory even when the cuBLAS call throws.
        releaseCublasPointers(aCPointer, bCPointer);
    }
}
/**
 * Rank-1 update (double precision): delegates to {@code cublasDger},
 * i.e. {@code C = alpha * A * transpose(B) + C}, and copies {@code C}
 * back to the host.
 *
 * @param A     the first vector
 * @param B     the second vector
 * @param C     the matrix to update; receives the result
 * @param alpha the scalar applied to the outer product
 * @return {@code C}, holding the result
 */
public static INDArray ger(INDArray A, INDArray B, INDArray C, double alpha) {
    DataTypeValidation.assertDouble(A, B, C);
    sync();

    CublasPointer aCPointer = null;
    CublasPointer bCPointer = null;
    CublasPointer cCPointer = null;
    try {
        aCPointer = new CublasPointer(A);
        bCPointer = new CublasPointer(B);
        cCPointer = new CublasPointer(C);

        // NOTE(review): the row counts are passed as incx/incy; confirm this is intended for the vector shapes callers use.
        JCublas2.cublasDger(
                ContextHolder.getInstance().getHandle(),
                A.rows(),                        // m
                A.columns(),                     // n
                Pointer.to(new double[]{alpha}), // alpha
                aCPointer.getDevicePointer(),    // x
                A.rows(),                        // incx
                bCPointer.getDevicePointer(),    // y
                B.rows(),                        // incy
                cCPointer.getDevicePointer(),    // A
                C.rows()                         // lda
        );

        // Wait for the kernel before reading the result back, as the float overload does.
        sync();
        cCPointer.copyToHost();
    } finally {
        // Release device memory even when the cuBLAS call throws.
        releaseCublasPointers(aCPointer, bCPointer, cCPointer);
    }
    return C;
}
/**
 * Rank-1 update (single precision): delegates to {@code cublasSger},
 * i.e. {@code C = alpha * A * transpose(B) + C}, and copies {@code C}
 * back to the host.
 *
 * @param A     the first vector
 * @param B     the second vector
 * @param C     the matrix to update; receives the result
 * @param alpha the scalar applied to the outer product
 * @return {@code C}, holding the result
 */
public static INDArray ger(INDArray A, INDArray B, INDArray C, float alpha) {
    DataTypeValidation.assertFloat(A, B, C);
    sync();

    CublasPointer xPtr = new CublasPointer(A);
    CublasPointer yPtr = new CublasPointer(B);
    CublasPointer outPtr = new CublasPointer(C);
    Pointer scale = Pointer.to(new float[]{alpha});

    JCublas2.cublasSger(
            ContextHolder.getInstance().getHandle(),
            A.rows(),                  // m
            A.columns(),               // n
            scale,                     // alpha
            xPtr.getDevicePointer(),   // x
            A.rows(),                  // incx
            yPtr.getDevicePointer(),   // y
            B.rows(),                  // incy
            outPtr.getDevicePointer(), // A
            C.rows());                 // lda

    sync();
    outPtr.copyToHost();
    releaseCublasPointers(xPtr, yPtr, outPtr);
    return C;
}
/**
 * Scale a complex float ndarray in place by a complex scalar using
 * {@code cublasCscal}, then copy the result back to the host.
 *
 * @param alpha the complex scalar to multiply by
 * @param x     the ndarray to scale
 * @return {@code x}, scaled
 */
public static IComplexNDArray scal(IComplexFloat alpha, IComplexNDArray x) {
    DataTypeValidation.assertFloat(x);
    sync();

    CublasPointer devX = new CublasPointer(x);
    Pointer factor = PointerUtil.getPointer(jcuda.cuComplex.cuCmplx(alpha.realComponent(), alpha.imaginaryComponent()));
    JCublas2.cublasCscal(
            ContextHolder.getInstance().getHandle(),
            x.length(),
            factor,
            devX.getDevicePointer(),
            1);

    sync();
    devX.copyToHost();
    releaseCublasPointers(devX);
    return x;
}
/**
 * Scale a complex double ndarray in place by a complex scalar using
 * {@code cublasZscal}, then copy the result back to the host.
 *
 * @param alpha the complex scalar to multiply by
 * @param x     the ndarray to scale
 * @return {@code x}, scaled
 */
public static IComplexNDArray scal(IComplexDouble alpha, IComplexNDArray x) {
    DataTypeValidation.assertDouble(x);
    sync();

    CublasPointer devX = new CublasPointer(x);
    Pointer factor = PointerUtil.getPointer(jcuda.cuDoubleComplex.cuCmplx(alpha.realComponent(), alpha.imaginaryComponent()));
    JCublas2.cublasZscal(
            ContextHolder.getInstance().getHandle(),
            x.length(),
            factor,
            devX.getDevicePointer(),
            1);

    sync();
    devX.copyToHost();
    releaseCublasPointers(devX);
    return x;
}
/**
 * Complex dot product (unconjugated), computed by {@code cublasZdotu}
 * for double data and {@code cublasCdotu} otherwise.
 *
 * @param x the first ndarray
 * @param y the second ndarray
 * @return the dot product reported by cuBLAS
 */
public static IComplexDouble dotu(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    sync();

    CublasPointer xCPointer = null;
    CublasPointer yCPointer = null;
    try {
        xCPointer = new CublasPointer(x);
        yCPointer = new CublasPointer(y);

        // The result is received in a plain {real, imaginary} host array so that the
        // values read back are the ones cuBLAS actually wrote.
        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            double[] result = new double[2];
            JCublas2.cublasZdotu(
                    ContextHolder.getInstance().getHandle(),
                    x.length(),
                    xCPointer.getDevicePointer(),
                    1,
                    yCPointer.getDevicePointer(),
                    1, Pointer.to(result));
            sync();
            return Nd4j.createDouble(result[0], result[1]);
        }

        float[] result = new float[2];
        JCublas2.cublasCdotu(
                ContextHolder.getInstance().getHandle(),
                x.length(),
                xCPointer.getDevicePointer(),
                1,
                yCPointer.getDevicePointer(),
                1, Pointer.to(result));
        sync();
        return Nd4j.createDouble(result[0], result[1]);
    } finally {
        // Release device memory even when the cuBLAS call throws.
        releaseCublasPointers(xCPointer, yCPointer);
    }
}
/**
 * Complex rank-1 update (double precision, unconjugated): delegates to
 * {@code cublasZgeru}, i.e. {@code C = alpha * A * transpose(B) + C},
 * and copies {@code C} back to the host.
 *
 * @param A     the first vector
 * @param B     the second vector
 * @param C     the matrix to update; receives the result
 * @param Alpha the complex scalar applied to the outer product
 * @return {@code C}, holding the result
 */
public static IComplexNDArray geru(IComplexNDArray A,
                                   IComplexNDArray B,
                                   IComplexNDArray C, IComplexDouble Alpha) {
    sync();
    DataTypeValidation.assertDouble(A, B, C);

    CublasPointer xPtr = new CublasPointer(A);
    CublasPointer yPtr = new CublasPointer(B);
    CublasPointer outPtr = new CublasPointer(C);
    Pointer scale = PointerUtil.getPointer(
            cuDoubleComplex.cuCmplx(Alpha.realComponent(), Alpha.imaginaryComponent()));

    JCublas2.cublasZgeru(
            ContextHolder.getInstance().getHandle(),
            A.rows(),                  // m
            A.columns(),               // n
            scale,                     // alpha
            xPtr.getDevicePointer(),   // x
            A.rows(),                  // incx
            yPtr.getDevicePointer(),   // y
            B.rows(),                  // incy
            outPtr.getDevicePointer(), // A
            C.rows());                 // lda

    sync();
    outPtr.copyToHost();
    releaseCublasPointers(xPtr, yPtr, outPtr);
    return C;
}
/**
 * Complex rank-1 update (single precision, conjugated): delegates to
 * {@code cublasCgerc} and copies {@code C} back to the host.
 *
 * @param A     the first vector
 * @param B     the second vector
 * @param C     the matrix to update; receives the result
 * @param Alpha the complex scalar applied to the outer product
 * @return {@code C}, holding the result
 */
public static IComplexNDArray gerc(IComplexNDArray A, IComplexNDArray B, IComplexNDArray C,
                                   IComplexFloat Alpha) {
    DataTypeValidation.assertFloat(A, B, C);
    sync();

    CublasPointer xPtr = new CublasPointer(A);
    CublasPointer yPtr = new CublasPointer(B);
    CublasPointer outPtr = new CublasPointer(C);
    Pointer scale = PointerUtil.getPointer(
            cuComplex.cuCmplx(Alpha.realComponent(), Alpha.imaginaryComponent()));

    JCublas2.cublasCgerc(
            ContextHolder.getInstance().getHandle(),
            A.rows(),                  // m
            A.columns(),               // n
            scale,                     // alpha
            xPtr.getDevicePointer(),   // x
            A.rows(),                  // incx
            yPtr.getDevicePointer(),   // y
            B.rows(),                  // incy
            outPtr.getDevicePointer(), // A
            C.rows());                 // lda

    sync();
    outPtr.copyToHost();
    releaseCublasPointers(xPtr, yPtr, outPtr);
    return C;
}
/**
 * Complex rank-1 update (single precision, unconjugated): delegates to
 * {@code cublasCgeru}, i.e. {@code C = alpha * A * transpose(B) + C},
 * and copies {@code C} back to the host.
 *
 * @param A     the first vector
 * @param B     the second vector
 * @param C     the matrix to update; receives the result
 * @param Alpha the complex scalar applied to the outer product
 * @return {@code C}, holding the result
 */
public static IComplexNDArray geru(IComplexNDArray A,
                                   IComplexNDArray B,
                                   IComplexNDArray C, IComplexFloat Alpha) {
    DataTypeValidation.assertFloat(A, B, C);
    sync();

    CublasPointer aCPointer = null;
    CublasPointer bCPointer = null;
    CublasPointer cCPointer = null;
    try {
        aCPointer = new CublasPointer(A);
        bCPointer = new CublasPointer(B);
        cCPointer = new CublasPointer(C);

        // The arrays are asserted to hold float data, so the single precision
        // complex scalar and routine must be used.
        cuComplex alpha = cuComplex.cuCmplx(Alpha.realComponent(), Alpha.imaginaryComponent());
        JCublas2.cublasCgeru(
                ContextHolder.getInstance().getHandle(),
                A.rows(),                      // m
                A.columns(),                   // n
                PointerUtil.getPointer(alpha), // alpha
                aCPointer.getDevicePointer(),  // x
                A.rows(),                      // incx
                bCPointer.getDevicePointer(),  // y
                B.rows(),                      // incy
                cCPointer.getDevicePointer(),  // A
                C.rows()                       // lda
        );

        sync();
        cCPointer.copyToHost();
    } finally {
        // Release device memory even when the cuBLAS call throws.
        releaseCublasPointers(aCPointer, bCPointer, cCPointer);
    }
    return C;
}
/**
 * Complex rank-1 update (double precision, conjugated): delegates to
 * {@code cublasZgerc} and copies {@code C} back to the host.
 *
 * @param A     the first vector
 * @param B     the second vector
 * @param C     the matrix to update; receives the result
 * @param Alpha the complex scalar applied to the outer product
 * @return {@code C}, holding the result
 */
public static IComplexNDArray gerc(IComplexNDArray A, IComplexNDArray B, IComplexNDArray C,
                                   IComplexDouble Alpha) {
    DataTypeValidation.assertDouble(A, B, C);
    sync();

    CublasPointer xPtr = new CublasPointer(A);
    CublasPointer yPtr = new CublasPointer(B);
    CublasPointer outPtr = new CublasPointer(C);
    Pointer scale = PointerUtil.getPointer(
            cuDoubleComplex.cuCmplx(Alpha.realComponent(), Alpha.imaginaryComponent()));

    JCublas2.cublasZgerc(
            ContextHolder.getInstance().getHandle(),
            A.rows(),                  // m
            A.columns(),               // n
            scale,                     // alpha
            xPtr.getDevicePointer(),   // x
            A.rows(),                  // incx
            yPtr.getDevicePointer(),   // y
            B.rows(),                  // incy
            outPtr.getDevicePointer(), // A
            C.rows());                 // lda

    sync();
    outPtr.copyToHost();
    releaseCublasPointers(xPtr, yPtr, outPtr);
    return C;
}
/**
 * Double precision axpy with unit strides: delegates to
 * {@code cublasDaxpy} and copies the updated {@code y} back to the host.
 *
 * @param alpha the alpha to scale by
 * @param x     the x
 * @param y     the y
 */
public static void axpy(double alpha, INDArray x, INDArray y) {
    DataTypeValidation.assertDouble(x, y);
    sync();

    CublasPointer srcPtr = new CublasPointer(x);
    CublasPointer dstPtr = new CublasPointer(y);
    Pointer scale = Pointer.to(new double[]{alpha});

    JCublas2.cublasDaxpy(
            ContextHolder.getInstance().getHandle(),
            x.length(),
            scale,
            srcPtr.getDevicePointer(),
            1,
            dstPtr.getDevicePointer(),
            1);

    sync();
    dstPtr.copyToHost();
    releaseCublasPointers(srcPtr, dstPtr);
}
/**
 * Single precision axpy with unit strides: delegates to
 * {@code cublasSaxpy} and copies the updated {@code y} back to the host.
 *
 * @param alpha the alpha to scale by
 * @param x     the x
 * @param y     the y; receives the result
 */
public static void saxpy(float alpha, INDArray x, INDArray y) {
    DataTypeValidation.assertFloat(x, y);
    sync();

    CublasPointer xCPointer = null;
    CublasPointer yCPointer = null;
    try {
        xCPointer = new CublasPointer(x);
        yCPointer = new CublasPointer(y);
        JCublas2.cublasSaxpy(
                ContextHolder.getInstance().getHandle(),
                x.length(),
                Pointer.to(new float[]{alpha}),
                xCPointer.getDevicePointer(),
                1,
                yCPointer.getDevicePointer(),
                1);
        sync();
        // axpy writes its result into y, so y (not x) is what must be copied back.
        yCPointer.copyToHost();
    } finally {
        // Release device memory even when the cuBLAS call throws.
        releaseCublasPointers(xCPointer, yCPointer);
    }
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy