All downloads are FREE. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.checkutil.CheckUtil Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.checkutil;

import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;

/**@author Alex Black
 */
public class CheckUtil {

	/**Check first.mmul(second) using Apache commons math mmul. Float/double matrices only.
* Returns true if OK, false otherwise.
* Checks each element according to relative error (|a-b|/(|a|+|b|); however absolute error |a-b| must * also exceed minAbsDifference for it to be considered a failure. This is necessary to avoid instability * near 0: i.e., Nd4j mmul might return element of 0.0 (due to underflow on float) while Apache commons math * mmul might be say 1e-30 or something (using doubles). * Throws exception if matrices can't be multiplied * Checks each element of the result. If * @param first First matrix * @param second Second matrix * @param maxRelativeDifference Maximum relative error * @param minAbsDifference Minimum absolute difference for failure * @return true if OK, false if result incorrect */ public static boolean checkMmul(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference) { if(first.size(1) != second.size(0)) throw new IllegalArgumentException("first.columns != second.rows"); RealMatrix rmFirst = convertToApacheMatrix(first); RealMatrix rmSecond = convertToApacheMatrix(second); INDArray result = first.mmul(second); RealMatrix rmResult = rmFirst.multiply(rmSecond); if(!checkShape(rmResult,result)) return false; boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference); if(!ok) { INDArray onCopies = Shape.toOffsetZeroCopy(first).mmul(Shape.toOffsetZeroCopy(second)); printFailureDetails(first, second, rmResult, result, onCopies, "mmul"); } return ok; } public static boolean checkGemm(INDArray a, INDArray b, INDArray c, boolean transposeA, boolean transposeB, double alpha, double beta, double maxRelativeDifference, double minAbsDifference) { int commonDimA = (transposeA ? a.rows() : a.columns()); int commonDimB = (transposeB ? b.columns() : b.rows()); if(commonDimA != commonDimB) throw new IllegalArgumentException("Common dimensions don't match: a.shape=" + Arrays.toString(a.shape()) + ", b.shape="+ Arrays.toString(b.shape()) + ", tA=" + transposeA + ", tb=" + transposeB); int outRows = (transposeA ? 
a.columns() : a.rows()); int outCols = (transposeB ? b.rows() : b.columns()); if(c.rows() != outRows || c.columns() != outCols) throw new IllegalArgumentException("C does not match outRows or outCols"); if(c.offset() != 0 || c.ordering() != 'f') throw new IllegalArgumentException("Invalid c"); INDArray aConvert = transposeA ? a.transpose() : a; RealMatrix rmA = convertToApacheMatrix(aConvert); INDArray bConvet = transposeB ? b.transpose() : b; RealMatrix rmB = convertToApacheMatrix(bConvet); RealMatrix rmC = convertToApacheMatrix(c); RealMatrix rmExpected = rmA.scalarMultiply(alpha).multiply(rmB).add(rmC.scalarMultiply(beta)); INDArray cCopy1 = Nd4j.create(c.shape(), 'f'); cCopy1.assign(c); INDArray cCopy2 = Nd4j.create(c.shape(), 'f'); cCopy2.assign(c); INDArray out = Nd4j.gemm(a, b, c, transposeA, transposeB, alpha, beta); if(out != c) { System.out.println("Returned different array than c"); return false; } if(!checkShape(rmExpected,out)) return false; boolean ok = checkEntries(rmExpected,out,maxRelativeDifference,minAbsDifference); if(!ok) { INDArray aCopy = Shape.toOffsetZeroCopy(a); INDArray bCopy = Shape.toOffsetZeroCopy(b); INDArray onCopies = Nd4j.gemm(aCopy, bCopy, cCopy1, transposeA, transposeB, alpha, beta); printGemmFailureDetails(a,b,cCopy2,transposeA,transposeB,alpha,beta,rmExpected,out,onCopies); } return ok; } /**Same as checkMmul, but for matrix addition */ public static boolean checkAdd(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference) { RealMatrix rmFirst = convertToApacheMatrix(first); RealMatrix rmSecond = convertToApacheMatrix(second); INDArray result = first.add(second); RealMatrix rmResult = rmFirst.add(rmSecond); if (!checkShape(rmResult, result)) return false; boolean ok = checkEntries(rmResult,result,maxRelativeDifference,minAbsDifference); if(!ok){ INDArray onCopies = Shape.toOffsetZeroCopy(first).add(Shape.toOffsetZeroCopy(second)); printFailureDetails(first, second, rmResult, result, onCopies, 
"add"); } return ok; } /** Same as checkMmul, but for matrix subtraction */ public static boolean checkSubtract(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference ){ RealMatrix rmFirst = convertToApacheMatrix(first); RealMatrix rmSecond = convertToApacheMatrix(second); INDArray result = first.sub(second); RealMatrix rmResult = rmFirst.subtract(rmSecond); if(!checkShape(rmResult,result)) return false; boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference); if(!ok){ INDArray onCopies = Shape.toOffsetZeroCopy(first).sub(Shape.toOffsetZeroCopy(second)); printFailureDetails(first, second, rmResult, result, onCopies, "sub"); } return ok; } public static boolean checkMulManually(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference ){ //No apache commons element-wise multiply, but can do this manually INDArray result = first.mul(second); int[] shape = first.shape(); INDArray expected = Nd4j.zeros(first.shape()); for(int i=0; i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy