Downloading a project requires considerable server resources. Please understand that we have to cover our server costs. Thank you in advance. Project price: only $1.
You can buy this project and download or modify it as often as you want.
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu.blas;
import org.nd4j.linalg.api.blas.impl.BaseLapack;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.nd4j.linalg.api.blas.BlasException ;
import static org.bytedeco.javacpp.openblas.*;
/**
* CPU lapack implementation
*/
public class CpuLapack extends BaseLapack {
/**
 * Maps the ND4J memory ordering of {@code A} to the LAPACKE matrix-layout constant.
 * Fortran ('f') ordering maps to column-major; any other ordering is treated as row-major.
 */
protected static int getColumnOrder(INDArray A) {
    final boolean fortranOrdered = A.ordering() == 'f';
    if (fortranOrdered) {
        return LAPACK_COL_MAJOR;
    }
    return LAPACK_ROW_MAJOR;
}
/**
 * Returns the leading dimension of {@code A} as passed to LAPACKE.
 * For Fortran ('f') ordered arrays this is the row count; otherwise it is the column count.
 * NOTE(review): using rows/columns as lda presumably assumes A is a contiguous (non-view)
 * array — confirm for strided views.
 *
 * @param A the matrix whose leading dimension is required
 * @return the leading dimension as an int
 * @throws ArithmeticException if the dimension does not fit in an int
 */
protected static int getLda(INDArray A) {
    long lda = A.ordering() == 'f' ? A.rows() : A.columns();
    // LAPACKE takes a 32-bit leading dimension; fail loudly instead of silently
    // truncating (the previous (int) casts could pass a garbage lda to native code).
    return Math.toIntExact(lda);
}
//=========================
// L U DECOMP
/**
 * Computes an LU factorization (with partial pivoting) of the M-by-N float matrix {@code A}
 * in place via {@code LAPACKE_sgetrf}, writing the pivot indices into {@code IPIV}.
 *
 * @param M    number of rows of A
 * @param N    number of columns of A
 * @param A    float matrix; overwritten with the L and U factors
 * @param IPIV int buffer receiving the pivot indices
 * @param INFO if non-null, receives the non-negative LAPACK status code: 0 on success, or
 *             i &gt; 0 when U(i,i) is exactly zero (the matrix is singular)
 * @throws BlasException if LAPACKE returns a negative status (invalid argument or
 *                       internal error)
 */
@Override
public void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
    int status = LAPACKE_sgetrf(getColumnOrder(A), M, N,
            (FloatPointer) A.data().addressPointer(),
            getLda(A), (IntPointer) IPIV.data().addressPointer());
    if (status < 0) {
        throw new BlasException("Failed to execute sgetrf", status);
    }
    // A positive status is not an error: the factorization completed but U is singular.
    // Report it through INFO so callers can detect it instead of losing it silently.
    if (INFO != null) {
        INFO.putScalar(0, status);
    }
}
/**
 * Computes an LU factorization (with partial pivoting) of the M-by-N double matrix {@code A}
 * in place via {@code LAPACKE_dgetrf}, writing the pivot indices into {@code IPIV}.
 *
 * @param M    number of rows of A
 * @param N    number of columns of A
 * @param A    double matrix; overwritten with the L and U factors
 * @param IPIV int buffer receiving the pivot indices
 * @param INFO if non-null, receives the non-negative LAPACK status code: 0 on success, or
 *             i &gt; 0 when U(i,i) is exactly zero (the matrix is singular)
 * @throws BlasException if LAPACKE returns a negative status (invalid argument or
 *                       internal error)
 */
@Override
public void dgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
    int status = LAPACKE_dgetrf(getColumnOrder(A), M, N,
            (DoublePointer) A.data().addressPointer(),
            getLda(A), (IntPointer) IPIV.data().addressPointer());
    if (status < 0) {
        throw new BlasException("Failed to execute dgetrf", status);
    }
    // A positive status is not an error: the factorization completed but U is singular.
    // Report it through INFO so callers can detect it instead of losing it silently.
    if (INFO != null) {
        INFO.putScalar(0, status);
    }
}
//=========================
// Q R DECOMP
@Override
public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
INDArray tau = Nd4j.create( N ) ;
int status = LAPACKE_sgeqrf(getColumnOrder(A), M, N,
(FloatPointer)A.data().addressPointer(), getLda(A),
(FloatPointer)tau.data().addressPointer()
);
if( status != 0 ) {
throw new BlasException( "Failed to execute sgeqrf", status ) ;
}
// Copy R ( upper part of Q ) into result
if( R != null ) {
R.assign( A.get( NDArrayIndex.interval( 0, A.columns() ), NDArrayIndex.all() ) ) ;
INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
for( int i=1 ; i