// org.nd4j.linalg.cpu.nativecpu.blas.CpuLapack Maven / Gradle / Ivy
package org.nd4j.linalg.cpu.nativecpu.blas;
import org.nd4j.linalg.api.blas.impl.BaseLapack;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import static org.bytedeco.javacpp.openblas.*;
import org.nd4j.linalg.api.blas.BlasException ;
/**
* CPU LAPACK implementation that delegates to the OpenBLAS LAPACKE routines
* exposed through the JavaCPP bindings.
*/
public class CpuLapack extends BaseLapack {
/**
 * Selects the LAPACKE matrix-layout flag that matches an array's ordering.
 *
 * @param A array whose {@code ordering()} is inspected
 * @return {@code LAPACK_COL_MAJOR} when the ordering is {@code 'f'},
 *         {@code LAPACK_ROW_MAJOR} for any other ordering
 */
protected static int getColumnOrder(INDArray A) {
    if (A.ordering() == 'f') {
        return LAPACK_COL_MAJOR;
    }
    return LAPACK_ROW_MAJOR;
}
/**
 * Picks the leading-dimension value that the LAPACKE calls in this class
 * receive alongside the matrix pointer.
 *
 * @param A array whose ordering and shape are inspected
 * @return {@code A.rows()} when the ordering is {@code 'f'},
 *         {@code A.columns()} for any other ordering
 */
protected static int getLda(INDArray A) {
    final boolean fortranOrder = A.ordering() == 'f';
    if (fortranOrder) {
        return A.rows();
    }
    return A.columns();
}
//=========================
// LU decomposition
/**
 * Invokes the single-precision {@code LAPACKE_sgetrf} routine on {@code A}.
 *
 * <p>The native call receives the address of {@code A}'s data buffer (as a
 * {@link FloatPointer}) and the address of {@code IPIV}'s data buffer (as an
 * {@link IntPointer}). The layout flag and leading dimension are derived from
 * {@code A} via {@code getColumnOrder} and {@code getLda}.
 *
 * <p>Only a negative status is treated as a failure; a zero or positive status
 * returns normally. {@code INFO} is neither read nor written by this method.
 *
 * @param M    forwarded unchanged as the first dimension argument (by LAPACK convention, the row count)
 * @param N    forwarded unchanged as the second dimension argument (by LAPACK convention, the column count)
 * @param A    matrix whose buffer is handed to the native routine
 * @param IPIV array whose buffer is handed to the native routine as the pivot-index storage
 * @param INFO unused here
 * @throws BlasException if the native routine returns a negative status
 */
@Override
public void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
    // Locals are computed in the same order the arguments were originally evaluated.
    final int layout = getColumnOrder(A);
    final FloatPointer matrix = (FloatPointer) A.data().addressPointer();
    final int lda = getLda(A);
    final IntPointer pivots = (IntPointer) IPIV.data().addressPointer();

    final int status = LAPACKE_sgetrf(layout, M, N, matrix, lda, pivots);
    if (status >= 0) {
        return;
    }
    throw new BlasException("Failed to execute sgetrf", status);
}
/**
 * Invokes the double-precision {@code LAPACKE_dgetrf} routine on {@code A}.
 *
 * <p>The native call receives the address of {@code A}'s data buffer (as a
 * {@link DoublePointer}) and the address of {@code IPIV}'s data buffer (as an
 * {@link IntPointer}). The layout flag and leading dimension are derived from
 * {@code A} via {@code getColumnOrder} and {@code getLda}.
 *
 * <p>Only a negative status is treated as a failure; a zero or positive status
 * returns normally. {@code INFO} is neither read nor written by this method.
 *
 * @param M    forwarded unchanged as the first dimension argument (by LAPACK convention, the row count)
 * @param N    forwarded unchanged as the second dimension argument (by LAPACK convention, the column count)
 * @param A    matrix whose buffer is handed to the native routine
 * @param IPIV array whose buffer is handed to the native routine as the pivot-index storage
 * @param INFO unused here
 * @throws BlasException if the native routine returns a negative status
 */
@Override
public void dgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
    // Locals are computed in the same order the arguments were originally evaluated.
    final int layout = getColumnOrder(A);
    final DoublePointer matrix = (DoublePointer) A.data().addressPointer();
    final int lda = getLda(A);
    final IntPointer pivots = (IntPointer) IPIV.data().addressPointer();

    final int status = LAPACKE_dgetrf(layout, M, N, matrix, lda, pivots);
    if (status >= 0) {
        return;
    }
    throw new BlasException("Failed to execute dgetrf", status);
}
//=========================
// QR decomposition
@Override
public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
INDArray tau = Nd4j.create( N ) ;
int status = LAPACKE_sgeqrf(getColumnOrder(A), M, N,
(FloatPointer)A.data().addressPointer(), getLda(A),
(FloatPointer)tau.data().addressPointer()
);
if( status != 0 ) {
throw new BlasException( "Failed to execute sgeqrf", status ) ;
}
// Copy R ( upper part of Q ) into result
if( R != null ) {
R.assign( A.get( NDArrayIndex.interval( 0, A.columns() ), NDArrayIndex.all() ) ) ;
INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
for( int i=1 ; i