Please wait. This may take a few minutes...
Many resources are needed to download a project. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
org.nd4j.linalg.checkutil.CheckUtil Maven / Gradle / Ivy
package org.nd4j.linalg.checkutil;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
/**@author Alex Black
*/
public class CheckUtil {
/**
 * Checks {@code first.mmul(second)} against a reference product computed with Apache Commons Math
 * (both inputs are converted via {@code convertToApacheMatrix}). Float/double matrices only.
 * <p>
 * Each element is compared by relative error {@code |a-b| / (|a|+|b|)}; however, the absolute error
 * {@code |a-b|} must also exceed {@code minAbsDifference} for the element to count as a failure. This is
 * necessary to avoid instability near 0: e.g. ND4J's mmul might return an element of 0.0 (due to float
 * underflow) while the Apache Commons Math product, computed in doubles, is something like 1e-30.
 * <p>
 * If the result shape does not match the reference, {@code false} is returned immediately. If any element
 * fails, the multiplication is repeated on offset-zero copies of the inputs and everything is handed to
 * {@code printFailureDetails} before {@code false} is returned.
 *
 * @param first                 first (left) matrix
 * @param second                second (right) matrix
 * @param maxRelativeDifference maximum allowed relative error per element
 * @param minAbsDifference      minimum absolute difference for an element to be considered a failure
 * @return {@code true} if the result matches the reference, {@code false} otherwise
 * @throws IllegalArgumentException if {@code first.size(1) != second.size(0)} (matrices cannot be multiplied)
 */
public static boolean checkMmul(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference) {
    if (first.size(1) != second.size(0)) {
        // Include the shapes so a failing test reports what was actually passed in.
        throw new IllegalArgumentException("first.columns != second.rows: first.shape="
                + Arrays.toString(first.shape()) + ", second.shape=" + Arrays.toString(second.shape()));
    }
    RealMatrix rmFirst = convertToApacheMatrix(first);
    RealMatrix rmSecond = convertToApacheMatrix(second);
    INDArray result = first.mmul(second);
    RealMatrix rmResult = rmFirst.multiply(rmSecond);
    if (!checkShape(rmResult, result)) {
        return false;
    }
    boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference);
    if (!ok) {
        // Recompute on offset-zero copies of the inputs so the failure report can compare both results.
        INDArray onCopies = Shape.toOffsetZeroCopy(first).mmul(Shape.toOffsetZeroCopy(second));
        printFailureDetails(first, second, rmResult, result, onCopies, "mmul");
    }
    return ok;
}
/**
 * Checks {@code Nd4j.gemm(a, b, c, transposeA, transposeB, alpha, beta)} against a reference result computed
 * with Apache Commons Math as {@code alpha * op(a) * op(b) + beta * c}, where {@code op(x)} is
 * {@code x.transpose()} when the corresponding transpose flag is set and {@code x} otherwise.
 * <p>
 * {@code Nd4j.gemm} is expected to return {@code c} itself. If it returns a different array, a message is
 * printed and {@code false} is returned. The original contents of {@code c} are snapshotted before gemm runs,
 * so the failure report can show the input value of {@code c}.
 * <p>
 * On an element mismatch, gemm is re-run on offset-zero copies of {@code a}, {@code b} and of the original
 * {@code c}, and all results are handed to {@code printGemmFailureDetails}.
 *
 * @param a                     left input matrix
 * @param b                     right input matrix
 * @param c                     accumulator/output matrix; must have zero offset and 'f' ordering
 * @param transposeA            whether {@code a} is transposed in the product
 * @param transposeB            whether {@code b} is transposed in the product
 * @param alpha                 scale applied to {@code op(a) * op(b)}
 * @param beta                  scale applied to the original {@code c}
 * @param maxRelativeDifference maximum allowed relative error per element
 * @param minAbsDifference      minimum absolute difference for an element to be considered a failure
 * @return {@code true} if gemm's result matches the reference, {@code false} otherwise
 * @throws IllegalArgumentException if the common dimensions of {@code op(a)} and {@code op(b)} differ, if
 *                                  {@code c} does not have the output shape, or if {@code c} has a non-zero
 *                                  offset or non-'f' ordering
 */
public static boolean checkGemm(INDArray a, INDArray b, INDArray c, boolean transposeA, boolean transposeB,
        double alpha, double beta,
        double maxRelativeDifference, double minAbsDifference) {
    int commonDimA = transposeA ? a.rows() : a.columns();
    int commonDimB = transposeB ? b.columns() : b.rows();
    if (commonDimA != commonDimB) {
        throw new IllegalArgumentException("Common dimensions don't match: a.shape=" +
                Arrays.toString(a.shape()) + ", b.shape=" + Arrays.toString(b.shape()) + ", tA=" + transposeA + ", tb=" + transposeB);
    }
    int outRows = transposeA ? a.columns() : a.rows();
    int outCols = transposeB ? b.rows() : b.columns();
    if (c.rows() != outRows || c.columns() != outCols) {
        throw new IllegalArgumentException("C does not match outRows or outCols: expected [" + outRows + ", "
                + outCols + "], c.shape=" + Arrays.toString(c.shape()));
    }
    if (c.offset() != 0 || c.ordering() != 'f') {
        throw new IllegalArgumentException("Invalid c: expected offset 0 and 'f' ordering, got offset="
                + c.offset() + ", ordering='" + c.ordering() + "'");
    }

    // Build the reference result before gemm runs, since gemm's output is expected in c itself.
    RealMatrix rmA = convertToApacheMatrix(transposeA ? a.transpose() : a);
    RealMatrix rmB = convertToApacheMatrix(transposeB ? b.transpose() : b);
    RealMatrix rmC = convertToApacheMatrix(c);
    RealMatrix rmExpected = rmA.scalarMultiply(alpha).multiply(rmB).add(rmC.scalarMultiply(beta));

    // Snapshot the original contents of c for the failure report.
    INDArray cOriginal = Nd4j.create(c.shape(), 'f');
    cOriginal.assign(c);

    INDArray out = Nd4j.gemm(a, b, c, transposeA, transposeB, alpha, beta);
    if (out != c) {
        System.out.println("Returned different array than c");
        return false;
    }
    if (!checkShape(rmExpected, out)) {
        return false;
    }
    boolean ok = checkEntries(rmExpected, out, maxRelativeDifference, minAbsDifference);
    if (!ok) {
        // Re-run gemm on copies; use a scratch copy of c so cOriginal stays intact for the report.
        INDArray cScratch = Nd4j.create(c.shape(), 'f');
        cScratch.assign(cOriginal);
        INDArray aCopy = Shape.toOffsetZeroCopy(a);
        INDArray bCopy = Shape.toOffsetZeroCopy(b);
        INDArray onCopies = Nd4j.gemm(aCopy, bCopy, cScratch, transposeA, transposeB, alpha, beta);
        printGemmFailureDetails(a, b, cOriginal, transposeA, transposeB, alpha, beta, rmExpected, out, onCopies);
    }
    return ok;
}
/**
 * Verifies {@code first.add(second)} element by element against a reference sum computed with
 * Apache Commons Math, using the same tolerance rules as {@code checkMmul}.
 *
 * @param first                 left operand
 * @param second                right operand
 * @param maxRelativeDifference maximum allowed relative error per element
 * @param minAbsDifference      minimum absolute difference for an element to be considered a failure
 * @return {@code true} if shape and all entries agree with the reference, {@code false} otherwise
 */
public static boolean checkAdd(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference) {
    RealMatrix apacheFirst = convertToApacheMatrix(first);
    RealMatrix apacheSecond = convertToApacheMatrix(second);
    INDArray actualSum = first.add(second);
    RealMatrix expectedSum = apacheFirst.add(apacheSecond);

    if (!checkShape(expectedSum, actualSum)) {
        return false;
    }
    if (checkEntries(expectedSum, actualSum, maxRelativeDifference, minAbsDifference)) {
        return true;
    }
    // Mismatch: repeat the addition on offset-zero copies of the inputs for the failure report.
    INDArray sumOfCopies = Shape.toOffsetZeroCopy(first).add(Shape.toOffsetZeroCopy(second));
    printFailureDetails(first, second, expectedSum, actualSum, sumOfCopies, "add");
    return false;
}
/**
 * Verifies {@code first.sub(second)} element by element against a reference difference computed with
 * Apache Commons Math, using the same tolerance rules as {@code checkMmul}.
 *
 * @param first                 minuend
 * @param second                subtrahend
 * @param maxRelativeDifference maximum allowed relative error per element
 * @param minAbsDifference      minimum absolute difference for an element to be considered a failure
 * @return {@code true} if shape and all entries agree with the reference, {@code false} otherwise
 */
public static boolean checkSubtract(INDArray first, INDArray second, double maxRelativeDifference,
        double minAbsDifference) {
    RealMatrix left = convertToApacheMatrix(first);
    RealMatrix right = convertToApacheMatrix(second);
    INDArray actualDifference = first.sub(second);
    RealMatrix expectedDifference = left.subtract(right);

    boolean shapesMatch = checkShape(expectedDifference, actualDifference);
    if (!shapesMatch) {
        return false;
    }
    boolean entriesMatch = checkEntries(expectedDifference, actualDifference, maxRelativeDifference, minAbsDifference);
    if (!entriesMatch) {
        // Repeat the subtraction on offset-zero copies of the inputs for the failure report.
        INDArray firstCopy = Shape.toOffsetZeroCopy(first);
        INDArray secondCopy = Shape.toOffsetZeroCopy(second);
        printFailureDetails(first, second, expectedDifference, actualDifference, firstCopy.sub(secondCopy), "sub");
    }
    return entriesMatch;
}
public static boolean checkMulManually(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference ){
//No apache commons element-wise multiply, but can do this manually
INDArray result = first.mul(second);
int[] shape = first.shape();
INDArray expected = Nd4j.zeros(first.shape());
for(int i=0; i