Please wait. This can take a few minutes...
Downloading a project requires significant server resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
org.nd4j.linalg.checkutil.CheckUtil Maven / Gradle / Ivy
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.checkutil;
import lombok.val;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
/**
 * Utilities for checking ND4J matrix operations against reference implementations: Apache Commons Math
 * {@link RealMatrix} arithmetic, or plain element-wise loops over {@code getDouble} values.
 * <p>
 * The checks index arrays as {@code (row, column)}, so they are intended for rank 2 arrays (matrices).
 * On a mismatch, diagnostic details are printed to {@code System.out} and {@code false} is returned.
 *
 * @author Alex Black
 */
public class CheckUtil {

    /**
     * Check {@code first.mmul(second)} against Apache Commons Math matrix multiplication. Float/double matrices only.
     * <p>
     * Each element is checked according to the relative error {@code |a-b| / (|a|+|b|)}; however the absolute
     * error {@code |a-b|} must also exceed {@code minAbsDifference} for it to be considered a failure. This is
     * necessary to avoid instability near 0: e.g., Nd4j mmul might return an element of 0.0 (due to underflow on
     * float) while Apache Commons Math mmul might return, say, 1e-30 (computed using doubles).
     * A NaN in the Nd4j result is always a failure.
     *
     * @param first                 First matrix
     * @param second                Second matrix
     * @param maxRelativeDifference Maximum relative error
     * @param minAbsDifference      Minimum absolute difference for failure
     * @return true if OK, false if the result shape or any entry is incorrect
     * @throws IllegalArgumentException if the matrices can't be multiplied
     */
    public static boolean checkMmul(INDArray first, INDArray second, double maxRelativeDifference,
                    double minAbsDifference) {
        if (first.size(1) != second.size(0))
            throw new IllegalArgumentException("first.columns != second.rows");
        RealMatrix rmFirst = convertToApacheMatrix(first);
        RealMatrix rmSecond = convertToApacheMatrix(second);

        INDArray result = first.mmul(second);
        RealMatrix rmResult = rmFirst.multiply(rmSecond);

        if (!checkShape(rmResult, result))
            return false;
        boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            // Re-run on zero-offset copies to help tell whether a failure is view/offset related
            INDArray onCopies = Shape.toOffsetZeroCopy(first).mmul(Shape.toOffsetZeroCopy(second));
            printFailureDetails(first, second, rmResult, result, onCopies, "mmul");
        }
        return ok;
    }

    /**
     * Check {@code Nd4j.gemm(a, b, c, transposeA, transposeB, alpha, beta)} against the Apache Commons Math
     * computation {@code alpha * op(a) * op(b) + beta * c}, where {@code op(x)} is {@code x} or its transpose.
     * <p>
     * gemm writes its result into {@code c} in place, so {@code c} is modified by this method. Entries are
     * compared using the same relative/absolute tolerance rule as {@link #checkMmul}.
     *
     * @param a                     Left matrix
     * @param b                     Right matrix
     * @param c                     Accumulator and output matrix; must have offset 0 and 'f' ordering
     * @param transposeA            Whether to transpose {@code a}
     * @param transposeB            Whether to transpose {@code b}
     * @param alpha                 Scale factor for {@code op(a) * op(b)}
     * @param beta                  Scale factor for the original {@code c}
     * @param maxRelativeDifference Maximum relative error
     * @param minAbsDifference      Minimum absolute difference for failure
     * @return true if OK; false if gemm returned an array other than {@code c}, or the result is incorrect
     * @throws IllegalArgumentException if the shapes are incompatible, or {@code c} has a non-zero offset or
     *                                  is not 'f' ordered
     */
    public static boolean checkGemm(INDArray a, INDArray b, INDArray c, boolean transposeA, boolean transposeB,
                    double alpha, double beta, double maxRelativeDifference, double minAbsDifference) {
        long commonDimA = (transposeA ? a.rows() : a.columns());
        long commonDimB = (transposeB ? b.columns() : b.rows());
        if (commonDimA != commonDimB)
            throw new IllegalArgumentException("Common dimensions don't match: a.shape=" + Arrays.toString(a.shape())
                            + ", b.shape=" + Arrays.toString(b.shape()) + ", tA=" + transposeA + ", tb=" + transposeB);
        long outRows = (transposeA ? a.columns() : a.rows());
        long outCols = (transposeB ? b.rows() : b.columns());
        if (c.rows() != outRows || c.columns() != outCols)
            throw new IllegalArgumentException("C does not match outRows or outCols");
        if (c.offset() != 0 || c.ordering() != 'f')
            throw new IllegalArgumentException("Invalid c: expected offset 0 and 'f' ordering, got offset="
                            + c.offset() + ", ordering=" + c.ordering());

        // The reference result must be computed before gemm runs, because gemm overwrites c in place
        RealMatrix rmA = convertToApacheMatrix(transposeA ? a.transpose() : a);
        RealMatrix rmB = convertToApacheMatrix(transposeB ? b.transpose() : b);
        RealMatrix rmC = convertToApacheMatrix(c);
        RealMatrix rmExpected = rmA.scalarMultiply(alpha).multiply(rmB).add(rmC.scalarMultiply(beta));

        // Snapshot the original c for diagnostics: one copy is the output of the re-run, the other is printed.
        // dup keeps c's data type (Nd4j.create(shape, 'f') would use the default data type instead).
        INDArray cCopy1 = c.dup('f');
        INDArray cCopy2 = c.dup('f');

        INDArray out = Nd4j.gemm(a, b, c, transposeA, transposeB, alpha, beta);
        if (out != c) {
            System.out.println("Returned different array than c");
            return false;
        }
        if (!checkShape(rmExpected, out))
            return false;
        boolean ok = checkEntries(rmExpected, out, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            INDArray aCopy = Shape.toOffsetZeroCopy(a);
            INDArray bCopy = Shape.toOffsetZeroCopy(b);
            INDArray onCopies = Nd4j.gemm(aCopy, bCopy, cCopy1, transposeA, transposeB, alpha, beta);
            printGemmFailureDetails(a, b, cCopy2, transposeA, transposeB, alpha, beta, rmExpected, out, onCopies);
        }
        return ok;
    }

    /**
     * Same as {@link #checkMmul}, but checks {@code first.add(second)} against Apache Commons Math matrix addition.
     *
     * @param first                 First matrix
     * @param second                Second matrix
     * @param maxRelativeDifference Maximum relative error
     * @param minAbsDifference      Minimum absolute difference for failure
     * @return true if OK, false if the result shape or any entry is incorrect
     */
    public static boolean checkAdd(INDArray first, INDArray second, double maxRelativeDifference,
                    double minAbsDifference) {
        RealMatrix rmFirst = convertToApacheMatrix(first);
        RealMatrix rmSecond = convertToApacheMatrix(second);

        INDArray result = first.add(second);
        RealMatrix rmResult = rmFirst.add(rmSecond);

        if (!checkShape(rmResult, result))
            return false;
        boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            INDArray onCopies = Shape.toOffsetZeroCopy(first).add(Shape.toOffsetZeroCopy(second));
            printFailureDetails(first, second, rmResult, result, onCopies, "add");
        }
        return ok;
    }

    /**
     * Same as {@link #checkMmul}, but checks {@code first.sub(second)} against Apache Commons Math matrix
     * subtraction.
     *
     * @param first                 First matrix
     * @param second                Second matrix
     * @param maxRelativeDifference Maximum relative error
     * @param minAbsDifference      Minimum absolute difference for failure
     * @return true if OK, false if the result shape or any entry is incorrect
     */
    public static boolean checkSubtract(INDArray first, INDArray second, double maxRelativeDifference,
                    double minAbsDifference) {
        RealMatrix rmFirst = convertToApacheMatrix(first);
        RealMatrix rmSecond = convertToApacheMatrix(second);

        INDArray result = first.sub(second);
        RealMatrix rmResult = rmFirst.subtract(rmSecond);

        if (!checkShape(rmResult, result))
            return false;
        boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            INDArray onCopies = Shape.toOffsetZeroCopy(first).sub(Shape.toOffsetZeroCopy(second));
            printFailureDetails(first, second, rmResult, result, onCopies, "sub");
        }
        return ok;
    }

    /**
     * Check element-wise {@code first.mul(second)} against products computed manually with
     * {@code getDouble(i, j)}. Both arrays are indexed as matrices using {@code first}'s shape.
     *
     * @param first                 First matrix
     * @param second                Second matrix
     * @param maxRelativeDifference Maximum relative error
     * @param minAbsDifference      Minimum absolute difference for failure
     * @return true if OK, false if the result shape or any entry is incorrect
     */
    public static boolean checkMulManually(INDArray first, INDArray second, double maxRelativeDifference,
                    double minAbsDifference) {
        //No apache commons element-wise multiply, but can do this manually

        INDArray result = first.mul(second);
        long[] shape = first.shape();

        INDArray expected = Nd4j.zeros(first.shape());

        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                double v = first.getDouble(i, j) * second.getDouble(i, j);
                expected.putScalar(new int[] {i, j}, v);
            }
        }
        if (!checkShape(expected, result))
            return false;
        boolean ok = checkEntries(expected, result, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            INDArray onCopies = Shape.toOffsetZeroCopy(first).mul(Shape.toOffsetZeroCopy(second));
            printFailureDetails(first, second, expected, result, onCopies, "mul");
        }
        return ok;
    }

    /**
     * Check element-wise {@code first.div(second)} against quotients computed manually with
     * {@code getDouble(i, j)}. Both arrays are indexed as matrices using {@code first}'s shape.
     *
     * @param first                 Numerator matrix
     * @param second                Denominator matrix
     * @param maxRelativeDifference Maximum relative error
     * @param minAbsDifference      Minimum absolute difference for failure
     * @return true if OK, false if the result shape or any entry is incorrect
     */
    public static boolean checkDivManually(INDArray first, INDArray second, double maxRelativeDifference,
                    double minAbsDifference) {
        //No apache commons element-wise division, but can do this manually

        INDArray result = first.div(second);
        long[] shape = first.shape();

        INDArray expected = Nd4j.zeros(first.shape());

        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                double v = first.getDouble(i, j) / second.getDouble(i, j);
                expected.putScalar(new int[] {i, j}, v);
            }
        }
        if (!checkShape(expected, result))
            return false;
        boolean ok = checkEntries(expected, result, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            // The diagnostic re-run must use the same op being checked (div, not mul)
            INDArray onCopies = Shape.toOffsetZeroCopy(first).div(Shape.toOffsetZeroCopy(second));
            printFailureDetails(first, second, expected, result, onCopies, "div");
        }
        return ok;
    }

    /** Returns true if {@code result} has shape {@code [rows, columns]} of {@code rmResult}; prints a message otherwise. */
    private static boolean checkShape(RealMatrix rmResult, INDArray result) {
        long[] outShape = {rmResult.getRowDimension(), rmResult.getColumnDimension()};
        if (!Arrays.equals(outShape, result.shape())) {
            System.out.println("Failure on shape: " + Arrays.toString(result.shape()) + ", expected "
                            + Arrays.toString(outShape));
            return false;
        }
        return true;
    }

    /** Returns true if both arrays have identical shapes; prints a message otherwise. */
    private static boolean checkShape(INDArray expected, INDArray actual) {
        if (!Arrays.equals(expected.shape(), actual.shape())) {
            System.out.println("Failure on shape: " + Arrays.toString(actual.shape()) + ", expected "
                            + Arrays.toString(expected.shape()));
            return false;
        }
        return true;
    }

    /**
     * Compare every entry of {@code result} against the corresponding entry of {@code rmResult}.
     * The shapes are assumed to have been checked already (see {@code checkShape}).
     * See {@link #checkMmul} for the tolerance rule; a NaN in {@code result} is always a failure.
     *
     * @return true if all entries match, false on the first mismatch (which is printed)
     */
    public static boolean checkEntries(RealMatrix rmResult, INDArray result, double maxRelativeDifference,
                    double minAbsDifference) {
        int[] outShape = {rmResult.getRowDimension(), rmResult.getColumnDimension()};
        for (int i = 0; i < outShape[0]; i++) {
            for (int j = 0; j < outShape[1]; j++) {
                double expOut = rmResult.getEntry(i, j);
                double actOut = result.getDouble(i, j);
                if (!checkEntry(i, j, expOut, actOut, maxRelativeDifference, minAbsDifference))
                    return false;
            }
        }
        return true;
    }

    /**
     * Compare every entry of {@code actual} against the corresponding entry of {@code expected}, iterating
     * over {@code expected}'s shape. The shapes are assumed to have been checked already (see {@code checkShape}).
     * Uses the same tolerance rule as the {@link RealMatrix} overload, including failing on NaN actual values.
     *
     * @return true if all entries match, false on the first mismatch (which is printed)
     */
    public static boolean checkEntries(INDArray expected, INDArray actual, double maxRelativeDifference,
                    double minAbsDifference) {
        long[] outShape = expected.shape();
        for (int i = 0; i < outShape[0]; i++) {
            for (int j = 0; j < outShape[1]; j++) {
                double expOut = expected.getDouble(i, j);
                double actOut = actual.getDouble(i, j);
                if (!checkEntry(i, j, expOut, actOut, maxRelativeDifference, minAbsDifference))
                    return false;
            }
        }
        return true;
    }

    /**
     * Compare a single expected/actual entry at position (i, j), printing a message on mismatch.
     * <ul>
     * <li>A NaN actual value is always a failure.</li>
     * <li>Exactly equal values always match (this covers 0.0 vs 0.0 and equal infinities, for which the
     * relative error would otherwise be NaN).</li>
     * <li>Otherwise it is a failure if the relative error exceeds {@code maxRelativeDifference} AND the absolute
     * error exceeds {@code minAbsDifference}.</li>
     * <li>A NaN relative error (e.g. expected NaN, or an infinity vs a finite value) is a failure. This is needed
     * because NaN compares false against any threshold and would otherwise pass silently.</li>
     * </ul>
     */
    private static boolean checkEntry(long i, long j, double expOut, double actOut, double maxRelativeDifference,
                    double minAbsDifference) {
        if (Double.isNaN(actOut)) {
            System.out.println("NaN failure on value: (" + i + "," + j + "), exp=" + expOut + ", act=" + actOut);
            return false;
        }
        if (expOut == actOut)
            return true;
        double absError = Math.abs(expOut - actOut);
        double relError = absError / (Math.abs(expOut) + Math.abs(actOut));
        if (Double.isNaN(relError) || (relError > maxRelativeDifference && absError > minAbsDifference)) {
            System.out.println("Failure on value: (" + i + "," + j + "), exp=" + expOut + ", act=" + actOut
                            + ", absError=" + absError + ", relError=" + relError);
            return false;
        }
        return true;
    }

    /**
     * Copy a rank 2 INDArray into a new Apache Commons Math {@link BlockRealMatrix}, element by element.
     *
     * @param matrix Rank 2 array to convert
     * @return A new RealMatrix holding the same values (read via {@code getDouble})
     * @throws IllegalArgumentException if the input is not rank 2
     * @throws ND4JArraySizeException   if either dimension exceeds {@code Integer.MAX_VALUE}
     */
    public static RealMatrix convertToApacheMatrix(INDArray matrix) {
        if (matrix.rank() != 2)
            throw new IllegalArgumentException("Input rank is not 2 (not matrix)");
        long[] shape = matrix.shape();

        if (matrix.columns() > Integer.MAX_VALUE || matrix.rows() > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();
        BlockRealMatrix out = new BlockRealMatrix((int) shape[0], (int) shape[1]);
        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                double value = matrix.getDouble(i, j);
                out.setEntry(i, j, value);
            }
        }
        return out;
    }

    /**
     * Copy an Apache Commons Math {@link RealMatrix} into a new INDArray of the given data type.
     *
     * @param matrix   Matrix to convert
     * @param dataType Data type of the created array
     * @return A new array of shape {@code [rows, columns]} holding the same values
     */
    public static INDArray convertFromApacheMatrix(RealMatrix matrix, DataType dataType) {
        val shape = new long[] {matrix.getRowDimension(), matrix.getColumnDimension()};
        INDArray out = Nd4j.create(dataType, shape);
        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                double value = matrix.getEntry(i, j);
                out.putScalar(new int[] {i, j}, value);
            }
        }
        return out;
    }

    /**
     * Print the inputs, the Apache Commons expected result, the result of the same op on zero-offset copies,
     * and the actual result, all to {@code System.out}.
     */
    public static void printFailureDetails(INDArray first, INDArray second, RealMatrix expected, INDArray actual,
                    INDArray onCopies, String op) {
        System.out.println("\nFactory: " + Nd4j.factory().getClass() + "\n");

        System.out.println("First:");
        printMatrixFullPrecision(first);
        System.out.println("\nSecond:");
        printMatrixFullPrecision(second);
        System.out.println("\nExpected (Apache Commons)");
        printApacheMatrix(expected);
        System.out.println("\nSame Nd4j op on copies: (Shape.toOffsetZeroCopy(first)." + op
                        + "(Shape.toOffsetZeroCopy(second)))");
        printMatrixFullPrecision(onCopies);
        System.out.println("\nActual:");
        printMatrixFullPrecision(actual);
    }

    /**
     * Print the gemm arguments, the Apache Commons expected result, the result of gemm on zero-offset copies,
     * and the actual result, all to {@code System.out}.
     */
    public static void printGemmFailureDetails(INDArray a, INDArray b, INDArray c, boolean transposeA,
                    boolean transposeB, double alpha, double beta, RealMatrix expected, INDArray actual,
                    INDArray onCopies) {
        System.out.println("\nFactory: " + Nd4j.factory().getClass() + "\n");
        System.out.println("Op: gemm(a,b,c,transposeA=" + transposeA + ",transposeB=" + transposeB + ",alpha=" + alpha
                        + ",beta=" + beta + ")");

        System.out.println("a:");
        printMatrixFullPrecision(a);
        System.out.println("\nb:");
        printMatrixFullPrecision(b);
        System.out.println("\nc:");
        printMatrixFullPrecision(c);
        System.out.println("\nExpected (Apache Commons)");
        printApacheMatrix(expected);
        System.out.println("\nSame Nd4j op on zero offset copies: gemm(aCopy,bCopy,cCopy," + transposeA + ","
                        + transposeB + "," + alpha + "," + beta + ")");
        printMatrixFullPrecision(onCopies);
        System.out.println("\nActual:");
        printMatrixFullPrecision(actual);
    }

    /**
     * Print a header line followed by every entry of a rank 2 array, one row per line. FLOAT arrays are printed
     * via {@code getFloat} (so values show float precision); all others via {@code getDouble}.
     */
    public static void printMatrixFullPrecision(INDArray matrix) {
        boolean floatType = (matrix.data().dataType() == DataType.FLOAT);
        printNDArrayHeader(matrix);
        long[] shape = matrix.shape();
        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                if (floatType)
                    System.out.print(matrix.getFloat(i, j));
                else
                    System.out.print(matrix.getDouble(i, j));
                if (j != shape[1] - 1)
                    System.out.print(", ");
                else
                    System.out.println();
            }
        }
    }

    /** Print the array's data type, ordering, offset, shape, stride, length and underlying buffer length. */
    public static void printNDArrayHeader(INDArray array) {
        System.out.println(array.data().dataType() + " - order=" + array.ordering() + ", offset=" + array.offset()
                        + ", shape=" + Arrays.toString(array.shape()) + ", stride=" + Arrays.toString(array.stride())
                        + ", length=" + array.length() + ", data().length()=" + array.data().length());
    }

    /**
     * Print the inputs, the manually computed expected result, the result of the same op on zero-offset copies,
     * and the actual result, all to {@code System.out}.
     */
    public static void printFailureDetails(INDArray first, INDArray second, INDArray expected, INDArray actual,
                    INDArray onCopies, String op) {
        System.out.println("\nFactory: " + Nd4j.factory().getClass() + "\n");

        System.out.println("First:");
        printMatrixFullPrecision(first);
        System.out.println("\nSecond:");
        printMatrixFullPrecision(second);
        System.out.println("\nExpected");
        printMatrixFullPrecision(expected);
        System.out.println("\nSame Nd4j op on copies: (Shape.toOffsetZeroCopy(first)." + op
                        + "(Shape.toOffsetZeroCopy(second)))");
        printMatrixFullPrecision(onCopies);
        System.out.println("\nActual:");
        printMatrixFullPrecision(actual);
    }

    /** Print a shape header followed by every entry of the matrix, one row per line. */
    public static void printApacheMatrix(RealMatrix matrix) {
        int nRows = matrix.getRowDimension();
        int nCols = matrix.getColumnDimension();

        System.out.println("Apache Commons RealMatrix: Shape: [" + nRows + "," + nCols + "]");
        for (int i = 0; i < nRows; i++) {
            for (int j = 0; j < nCols; j++) {
                System.out.print(matrix.getEntry(i, j));
                if (j != nCols - 1)
                    System.out.print(", ");
                else
                    System.out.println();
            }
        }
    }
}