Downloading a project requires significant server resources. Please understand that we need to cover our server costs. Thank you in advance. The project costs only $1.
You can buy this project and download or modify it as often as you want.
package org.nd4j.linalg.api.blas.impl;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.blas.Lapack;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Base LAPACK implementation that dispatches each decomposition to its float ({@code s...}) or
 * double ({@code d...}) variant, based on the data type of the input array's buffer.
 *
 * <p>Subclasses supply the actual native calls by implementing the abstract {@code s...}/{@code d...}
 * methods. Each of those reports its status through a 1x1 int-backed INFO array (or, for
 * {@code ?syev}, a return value). This class interprets that status.
 *
 * @author Adam Gibson
 * @author rcorbish
 */
@Slf4j
public abstract class BaseLapack implements Lapack {

    /**
     * Allocates the 1x1 int-backed array into which the native routines write their INFO status code.
     *
     * @return a fresh 1x1 INFO array
     */
    private static INDArray createInfoArray() {
        return Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
                        Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}).getFirst());
    }

    /**
     * LU-factorizes {@code A} in place and returns the pivot indices.
     *
     * @param A the matrix to factorize (overwritten with the factors by the native routine)
     * @return a 1 x min(rows, columns) int array of 1-based pivot indices, as filled in by {@code ?getrf}
     * @throws UnsupportedOperationException if {@code A} is neither FLOAT nor DOUBLE
     * @throws Error if the native routine reports an invalid parameter (INFO &lt; 0)
     */
    @Override
    public INDArray getrf(INDArray A) {
        int m = A.rows();
        int n = A.columns();
        INDArray INFO = createInfoArray();

        int mn = Math.min(m, n);
        INDArray IPIV = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(mn),
                        Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, mn}).getFirst());

        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            dgetrf(m, n, A, IPIV, INFO);
        } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            sgetrf(m, n, A, IPIV, INFO);
        } else {
            throw new UnsupportedOperationException();
        }

        int info = INFO.getInt(0);
        if (info < 0) {
            // LAPACK encodes "argument i is illegal" as INFO = -i
            throw new Error("Parameter #" + (-info) + " to getrf() was not valid");
        } else if (info > 0) {
            // LAPACK: INFO = i > 0 means U(i,i) is exactly zero, so U is singular
            log.warn("The matrix is singular - cannot be used for inverse op. Check U matrix at row " + info);
        }
        return IPIV;
    }

    /**
     * Float/Double versions of LU decomp.
     * This is the official LAPACK interface (in case you want to call this directly).
     * See getrf for full details on LU Decomp.
     *
     * @param M the number of rows in the matrix A
     * @param N the number of cols in the matrix A
     * @param A the matrix to factorize - data must be in column order ( create with 'f' ordering )
     * @param IPIV an output array for the permutations ( must be int based storage )
     * @param INFO error details 1 int array, a positive number (i) implies row i cannot be factored,
     *             a negative value (-i) implies parameter i is invalid
     */
    public abstract void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO);

    public abstract void dgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO);

    /**
     * Cholesky-factorizes the positive definite matrix {@code A} in place.
     *
     * @param A     the matrix to factorize; its column count is passed as the order N
     * @param lower {@code true} to request the lower factor ('L'), {@code false} for the upper ('U')
     * @throws UnsupportedOperationException if {@code A} is neither FLOAT nor DOUBLE
     * @throws Error if a parameter is invalid (INFO &lt; 0) or the matrix is not positive definite (INFO &gt; 0)
     */
    @Override
    public void potrf(INDArray A, boolean lower) {
        byte uplo = (byte) (lower ? 'L' : 'U'); // which triangular factor is desired
        int n = A.columns();
        INDArray INFO = createInfoArray();

        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            dpotrf(uplo, n, A, INFO);
        } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            spotrf(uplo, n, A, INFO);
        } else {
            throw new UnsupportedOperationException();
        }

        int info = INFO.getInt(0);
        if (info < 0) {
            throw new Error("Parameter #" + (-info) + " to potrf() was not valid");
        } else if (info > 0) {
            throw new Error("The matrix is not positive definite! (potrf fails @ order " + info + ")");
        }
    }

    /**
     * Float/Double versions of cholesky decomp for positive definite matrices.
     *
     * A = LL*
     *
     * @param uplo which factor to return L or U
     * @param N the number of rows &amp; cols in the matrix A
     * @param A the matrix to factorize - data must be in column order ( create with 'f' ordering )
     * @param INFO error details 1 int array, a positive number (i) implies the leading minor of order i
     *             is not positive definite, a negative value (-i) implies parameter i is invalid
     */
    public abstract void spotrf(byte uplo, int N, INDArray A, INDArray INFO);

    public abstract void dpotrf(byte uplo, int N, INDArray A, INDArray INFO);

    /**
     * QR-factorizes {@code A}, writing the R factor into {@code R}.
     *
     * @param A the matrix to factorize
     * @param R output array; must be N x N where N is the number of columns of {@code A}
     * @throws Error if {@code R} has the wrong shape or the native routine reports an invalid parameter
     * @throws UnsupportedOperationException if {@code A} is neither FLOAT nor DOUBLE
     */
    @Override
    public void geqrf(INDArray A, INDArray R) {
        int m = A.rows();
        int n = A.columns();
        INDArray INFO = createInfoArray();

        if (R.rows() != A.columns() || R.columns() != A.columns()) {
            throw new Error("geqrf: R must be N x N (n = columns in A)");
        }

        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            dgeqrf(m, n, A, R, INFO);
        } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            sgeqrf(m, n, A, R, INFO);
        } else {
            throw new UnsupportedOperationException();
        }

        int info = INFO.getInt(0);
        if (info < 0) {
            throw new Error("Parameter #" + (-info) + " to geqrf() was not valid");
        }
    }

    /**
     * Float/Double versions of QR decomp.
     * This is the official LAPACK interface (in case you want to call this directly).
     * See geqrf for full details on QR Decomp.
     *
     * @param M the number of rows in the matrix A
     * @param N the number of cols in the matrix A
     * @param A the matrix to factorize - data must be in column order ( create with 'f' ordering )
     * @param R an output array for other part of factorization
     * @param INFO error details 1 int array, a negative value (-i) implies parameter i is invalid
     */
    public abstract void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO);

    public abstract void dgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO);

    /**
     * Computes eigenvalues (and optionally eigenvectors) of the symmetric matrix {@code A}.
     *
     * @param jobz 'N' - eigenvalues only, 'V' - also eigenvectors
     * @param uplo which triangle of {@code A} to use
     * @param A    the square matrix to decompose
     * @param V    output for the eigenvalues; its length must equal the matrix dimension
     * @return the status code returned by the native {@code ?syev} implementation
     * @throws Error if {@code A} is not square or {@code V} has the wrong length
     * @throws UnsupportedOperationException if {@code A} is neither FLOAT nor DOUBLE
     */
    @Override
    public int syev(char jobz, char uplo, INDArray A, INDArray V) {
        if (A.rows() != A.columns()) {
            throw new Error("syev: A must be square.");
        }
        if (A.rows() != V.length()) {
            throw new Error("syev: V must be the length of the matrix dimension.");
        }

        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            return dsyev(jobz, uplo, A.rows(), A, V);
        } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            return ssyev(jobz, uplo, A.rows(), A, V);
        }
        throw new UnsupportedOperationException();
    }

    /**
     * Float/Double versions of eigen value/vector calc.
     *
     * @param jobz 'N' - no eigen vectors, 'V' - return eigenvectors
     * @param uplo upper or lower part of symmetric matrix to use
     * @param N the number of rows &amp; cols in the matrix A
     * @param A the matrix to calculate eigenvectors
     * @param R an output array for eigenvalues ( may be null )
     * @return the native status code
     */
    public abstract int ssyev(char jobz, char uplo, int N, INDArray A, INDArray R);

    public abstract int dsyev(char jobz, char uplo, int N, INDArray A, INDArray R);

    /**
     * Computes the singular value decomposition of {@code A}.
     *
     * @param A  the matrix to decompose
     * @param S  output for the singular values
     * @param U  output for the left singular vectors, or {@code null} to skip them (jobu = 'N')
     * @param VT output for the transposed right singular vectors, or {@code null} to skip them (jobvt = 'N')
     * @throws UnsupportedOperationException if {@code A} is neither FLOAT nor DOUBLE
     * @throws Error if the native routine reports an invalid parameter (INFO &lt; 0)
     */
    @Override
    public void gesvd(INDArray A, INDArray S, INDArray U, INDArray VT) {
        int m = A.rows();
        int n = A.columns();

        byte jobu = (byte) (U == null ? 'N' : 'A');
        byte jobvt = (byte) (VT == null ? 'N' : 'A');

        INDArray INFO = createInfoArray();

        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            dgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO);
        } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            sgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO);
        } else {
            throw new UnsupportedOperationException();
        }

        int info = INFO.getInt(0);
        if (info < 0) {
            throw new Error("Parameter #" + (-info) + " to gesvd() was not valid");
        } else if (info > 0) {
            log.warn("The matrix contains singular elements. Check S matrix at row " + info);
        }
    }

    public abstract void sgesvd(byte jobu, byte jobvt, int M, int N, INDArray A, INDArray S, INDArray U, INDArray VT,
                    INDArray INFO);

    public abstract void dgesvd(byte jobu, byte jobvt, int M, int N, INDArray A, INDArray S, INDArray U, INDArray VT,
                    INDArray INFO);

    /**
     * Builds the M x M permutation matrix P from the 1-based pivot vector produced by {@code getrf}.
     * The swaps are applied in order, to the columns of an identity matrix.
     *
     * @param M    the size of the (square) permutation matrix
     * @param ipiv the 1-based pivot indices
     * @return the permutation matrix; it contains a single 1 in every row and column
     */
    @Override
    public INDArray getPFactor(int M, INDArray ipiv) {
        // The simplest permutation is the identity matrix
        INDArray P = Nd4j.eye(M);
        for (int i = 0; i < ipiv.length(); i++) {
            int pivot = ipiv.getInt(i) - 1; // convert LAPACK's 1-based index to 0-based
            if (pivot > i) { // pivot == i means row i was not exchanged, so skip the no-op swap
                // Swapping columns (not rows) builds the product of the elementary transpositions
                INDArray v1 = P.getColumn(i).dup(); // dup() is required - getColumn() returns a view
                INDArray v2 = P.getColumn(pivot);
                P.putColumn(i, v2);
                P.putColumn(pivot, v1);
            }
        }
        return P;
    }

    /**
     * Extracts the unit lower-triangular L factor from an LU-factorized matrix.
     *
     * <p>The result has the same shape as the input: the strictly-lower entries of {@code A}, a
     * diagonal of 1s, and zeros above the diagonal. Values are copied in double precision, so
     * double-typed factors are not truncated.
     *
     * <p>TODO: consider doing this in place to save memory. This implies U is taken out first.
     *
     * @param A the matrix returned in place by {@code getrf}
     * @return a new rows x columns matrix holding L
     */
    @Override
    public INDArray getLFactor(INDArray A) {
        int m = A.rows();
        int n = A.columns();

        INDArray L = Nd4j.create(m, n);
        for (int r = 0; r < m; r++) {
            for (int c = 0; c < n; c++) {
                if (r > c) {
                    L.putScalar(r, c, A.getDouble(r, c));
                } else if (r < c) {
                    L.putScalar(r, c, 0.0);
                } else {
                    L.putScalar(r, c, 1.0);
                }
            }
        }
        return L;
    }

    /**
     * Extracts the upper-triangular U factor from an LU-factorized matrix.
     *
     * <p>The result is N x N, where N is the column count of {@code A}. Rows at or beyond the row
     * count of {@code A} are zero. Values are copied in double precision.
     *
     * @param A the matrix returned in place by {@code getrf}
     * @return a new columns x columns matrix holding U
     */
    @Override
    public INDArray getUFactor(INDArray A) {
        int m = A.rows();
        int n = A.columns();

        INDArray U = Nd4j.create(n, n);
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                if (r <= c && r < m) {
                    U.putScalar(r, c, A.getDouble(r, c));
                } else {
                    U.putScalar(r, c, 0.0);
                }
            }
        }
        return U;
    }
}