Many resources are needed to download a project. Please understand that we have to cover our server costs. Thank you in advance. The project price is only $1.
You can buy this project and download/modify it as often as you want.
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.cpu.nativecpu.blas;
import org.nd4j.linalg.api.blas.impl.BaseLapack;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.nd4j.linalg.api.blas.BlasException ;
import static org.bytedeco.openblas.global.openblas.*;
public class CpuLapack extends BaseLapack {
/**
 * Selects the LAPACKE matrix-layout constant that matches an array's ordering.
 *
 * @param A array whose ordering character is inspected
 * @return {@code LAPACK_COL_MAJOR} when the ordering is {@code 'f'},
 *         {@code LAPACK_ROW_MAJOR} for any other ordering
 */
protected static int getColumnOrder(INDArray A) {
    if (A.ordering() == 'f') {
        return LAPACK_COL_MAJOR;
    }
    return LAPACK_ROW_MAJOR;
}
/**
 * Computes the leading-dimension value handed to the LAPACKE routines in this class.
 *
 * @param A array whose dimensions and ordering are inspected
 * @return the row count when the ordering is {@code 'f'}, otherwise the column count
 * @throws ND4JArraySizeException if either the row or column count exceeds
 *         {@link Integer#MAX_VALUE}
 */
protected static int getLda(INDArray A) {
    // Both dimensions must be representable as int before either one is narrowed.
    boolean fitsInInt = A.rows() <= Integer.MAX_VALUE && A.columns() <= Integer.MAX_VALUE;
    if (!fitsInInt) {
        throw new ND4JArraySizeException();
    }

    if (A.ordering() == 'f') {
        return (int) A.rows();
    }
    return (int) A.columns();
}
//=========================
// L U DECOMP
/**
 * Single-precision LU decomposition: forwards to {@code LAPACKE_sgetrf} using the
 * native addresses of the supplied arrays' data buffers.
 *
 * @param M    first dimension argument passed through to LAPACKE
 * @param N    second dimension argument passed through to LAPACKE
 * @param A    matrix whose buffer is handed to LAPACKE as a {@link FloatPointer};
 *             its ordering decides the layout flag and the leading dimension
 * @param IPIV array whose buffer is handed to LAPACKE as an {@link IntPointer}
 * @param INFO not read or written by this implementation
 * @throws BlasException if LAPACKE returns a negative status; zero and positive
 *         statuses return normally
 */
@Override
public void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
    // Locals are evaluated in the same order as the native call's argument list.
    final int layout = getColumnOrder(A);
    final FloatPointer matrix = (FloatPointer) A.data().addressPointer();
    final int lda = getLda(A);
    final IntPointer pivots = (IntPointer) IPIV.data().addressPointer();

    final int status = LAPACKE_sgetrf(layout, M, N, matrix, lda, pivots);
    if (status >= 0) {
        return;
    }
    throw new BlasException("Failed to execute sgetrf", status);
}
/**
 * Double-precision LU decomposition: forwards to {@code LAPACKE_dgetrf} using the
 * native addresses of the supplied arrays' data buffers.
 *
 * @param M    first dimension argument passed through to LAPACKE
 * @param N    second dimension argument passed through to LAPACKE
 * @param A    matrix whose buffer is handed to LAPACKE as a {@link DoublePointer};
 *             its ordering decides the layout flag and the leading dimension
 * @param IPIV array whose buffer is handed to LAPACKE as an {@link IntPointer}
 * @param INFO not read or written by this implementation
 * @throws BlasException if LAPACKE returns a negative status; zero and positive
 *         statuses return normally
 */
@Override
public void dgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
    // Locals are evaluated in the same order as the native call's argument list.
    final int layout = getColumnOrder(A);
    final DoublePointer matrix = (DoublePointer) A.data().addressPointer();
    final int lda = getLda(A);
    final IntPointer pivots = (IntPointer) IPIV.data().addressPointer();

    final int status = LAPACKE_dgetrf(layout, M, N, matrix, lda, pivots);
    if (status >= 0) {
        return;
    }
    throw new BlasException("Failed to execute dgetrf", status);
}
//=========================
// Q R DECOMP
@Override
public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
INDArray tau = Nd4j.create(DataType.FLOAT, N ) ;
int status = LAPACKE_sgeqrf(getColumnOrder(A), M, N,
(FloatPointer)A.data().addressPointer(), getLda(A),
(FloatPointer)tau.data().addressPointer()
);
if( status != 0 ) {
throw new BlasException( "Failed to execute sgeqrf", status ) ;
}
// Copy R ( upper part of Q ) into result
if( R != null ) {
R.assign( A.get( NDArrayIndex.interval( 0, A.columns() ), NDArrayIndex.all() ) ) ;
INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
for( int i=1 ; i