All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.checkutil.CheckUtil Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.checkutil;

import lombok.val;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;

/**@author Alex Black
 */
public class CheckUtil {

    /**Check first.mmul(second) using Apache commons math mmul. Float/double matrices only.
* Returns true if OK, false otherwise.
* Checks each element according to relative error (|a-b|/(|a|+|b|); however absolute error |a-b| must * also exceed minAbsDifference for it to be considered a failure. This is necessary to avoid instability * near 0: i.e., Nd4j mmul might return element of 0.0 (due to underflow on float) while Apache commons math * mmul might be say 1e-30 or something (using doubles). * Throws exception if matrices can't be multiplied * Checks each element of the result. If * @param first First matrix * @param second Second matrix * @param maxRelativeDifference Maximum relative error * @param minAbsDifference Minimum absolute difference for failure * @return true if OK, false if result incorrect */ public static boolean checkMmul(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference) { if (first.size(1) != second.size(0)) throw new IllegalArgumentException("first.columns != second.rows"); RealMatrix rmFirst = convertToApacheMatrix(first); RealMatrix rmSecond = convertToApacheMatrix(second); INDArray result = first.mmul(second); RealMatrix rmResult = rmFirst.multiply(rmSecond); if (!checkShape(rmResult, result)) return false; boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference); if (!ok) { INDArray onCopies = Shape.toOffsetZeroCopy(first).mmul(Shape.toOffsetZeroCopy(second)); printFailureDetails(first, second, rmResult, result, onCopies, "mmul"); } return ok; } public static boolean checkGemm(INDArray a, INDArray b, INDArray c, boolean transposeA, boolean transposeB, double alpha, double beta, double maxRelativeDifference, double minAbsDifference) { long commonDimA = (transposeA ? a.rows() : a.columns()); long commonDimB = (transposeB ? b.columns() : b.rows()); if (commonDimA != commonDimB) throw new IllegalArgumentException("Common dimensions don't match: a.shape=" + Arrays.toString(a.shape()) + ", b.shape=" + Arrays.toString(b.shape()) + ", tA=" + transposeA + ", tb=" + transposeB); long outRows = (transposeA ? a.columns() : a.rows()); long outCols = (transposeB ? b.rows() : b.columns()); if (c.rows() != outRows || c.columns() != outCols) throw new IllegalArgumentException("C does not match outRows or outCols"); if (c.offset() != 0 || c.ordering() != 'f') throw new IllegalArgumentException("Invalid c"); INDArray aConvert = transposeA ? a.transpose() : a; RealMatrix rmA = convertToApacheMatrix(aConvert); INDArray bConvet = transposeB ? b.transpose() : b; RealMatrix rmB = convertToApacheMatrix(bConvet); RealMatrix rmC = convertToApacheMatrix(c); RealMatrix rmExpected = rmA.scalarMultiply(alpha).multiply(rmB).add(rmC.scalarMultiply(beta)); INDArray cCopy1 = Nd4j.create(c.shape(), 'f'); cCopy1.assign(c); INDArray cCopy2 = Nd4j.create(c.shape(), 'f'); cCopy2.assign(c); INDArray out = Nd4j.gemm(a, b, c, transposeA, transposeB, alpha, beta); if (out != c) { System.out.println("Returned different array than c"); return false; } if (!checkShape(rmExpected, out)) return false; boolean ok = checkEntries(rmExpected, out, maxRelativeDifference, minAbsDifference); if (!ok) { INDArray aCopy = Shape.toOffsetZeroCopy(a); INDArray bCopy = Shape.toOffsetZeroCopy(b); INDArray onCopies = Nd4j.gemm(aCopy, bCopy, cCopy1, transposeA, transposeB, alpha, beta); printGemmFailureDetails(a, b, cCopy2, transposeA, transposeB, alpha, beta, rmExpected, out, onCopies); } return ok; } /**Same as checkMmul, but for matrix addition */ public static boolean checkAdd(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference) { RealMatrix rmFirst = convertToApacheMatrix(first); RealMatrix rmSecond = convertToApacheMatrix(second); INDArray result = first.add(second); RealMatrix rmResult = rmFirst.add(rmSecond); if (!checkShape(rmResult, result)) return false; boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference); if (!ok) { INDArray onCopies = Shape.toOffsetZeroCopy(first).add(Shape.toOffsetZeroCopy(second)); printFailureDetails(first, second, rmResult, result, onCopies, "add"); } return ok; } /** Same as checkMmul, but for matrix subtraction */ public static boolean checkSubtract(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference) { RealMatrix rmFirst = convertToApacheMatrix(first); RealMatrix rmSecond = convertToApacheMatrix(second); INDArray result = first.sub(second); RealMatrix rmResult = rmFirst.subtract(rmSecond); if (!checkShape(rmResult, result)) return false; boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference); if (!ok) { INDArray onCopies = Shape.toOffsetZeroCopy(first).sub(Shape.toOffsetZeroCopy(second)); printFailureDetails(first, second, rmResult, result, onCopies, "sub"); } return ok; } public static boolean checkMulManually(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference) { //No apache commons element-wise multiply, but can do this manually INDArray result = first.mul(second); long[] shape = first.shape(); INDArray expected = Nd4j.zeros(first.shape()); for (int i = 0; i < shape[0]; i++) { for (int j = 0; j < shape[1]; j++) { double v = first.getDouble(i, j) * second.getDouble(i, j); expected.putScalar(new int[] {i, j}, v); } } if (!checkShape(expected, result)) return false; boolean ok = checkEntries(expected, result, maxRelativeDifference, minAbsDifference); if (!ok) { INDArray onCopies = Shape.toOffsetZeroCopy(first).mul(Shape.toOffsetZeroCopy(second)); printFailureDetails(first, second, expected, result, onCopies, "mul"); } return ok; } public static boolean checkDivManually(INDArray first, INDArray second, double maxRelativeDifference, double minAbsDifference) { //No apache commons element-wise division, but can do this manually INDArray result = first.div(second); long[] shape = first.shape(); INDArray expected = Nd4j.zeros(first.shape()); for (int i = 0; i < shape[0]; i++) { for (int j = 0; j < shape[1]; j++) { double v = first.getDouble(i, j) / second.getDouble(i, j); expected.putScalar(new int[] {i, j}, v); } } if (!checkShape(expected, result)) return false; boolean ok = checkEntries(expected, result, maxRelativeDifference, minAbsDifference); if (!ok) { INDArray onCopies = Shape.toOffsetZeroCopy(first).mul(Shape.toOffsetZeroCopy(second)); printFailureDetails(first, second, expected, result, onCopies, "div"); } return ok; } private static boolean checkShape(RealMatrix rmResult, INDArray result) { long[] outShape = {rmResult.getRowDimension(), rmResult.getColumnDimension()}; if (!Arrays.equals(outShape, result.shape())) { System.out.println("Failure on shape: " + Arrays.toString(result.shape()) + ", expected " + Arrays.toString(outShape)); return false; } return true; } private static boolean checkShape(INDArray expected, INDArray actual) { if (!Arrays.equals(expected.shape(), actual.shape())) { System.out.println("Failure on shape: " + Arrays.toString(actual.shape()) + ", expected " + Arrays.toString(expected.shape())); return false; } return true; } public static boolean checkEntries(RealMatrix rmResult, INDArray result, double maxRelativeDifference, double minAbsDifference) { int[] outShape = {rmResult.getRowDimension(), rmResult.getColumnDimension()}; for (int i = 0; i < outShape[0]; i++) { for (int j = 0; j < outShape[1]; j++) { double expOut = rmResult.getEntry(i, j); double actOut = result.getDouble(i, j); if (Double.isNaN(actOut)) { System.out.println("NaN failure on value: (" + i + "," + j + " exp=" + expOut + ", act=" + actOut); return false; } if (expOut == 0.0 && actOut == 0.0) continue; double absError = Math.abs(expOut - actOut); double relError = absError / (Math.abs(expOut) + Math.abs(actOut)); if (relError > maxRelativeDifference && absError > minAbsDifference) { System.out.println("Failure on value: (" + i + "," + j + " exp=" + expOut + ", act=" + actOut + ", absError=" + absError + ", relError=" + relError); return false; } } } return true; } public static boolean checkEntries(INDArray expected, INDArray actual, double maxRelativeDifference, double minAbsDifference) { long[] outShape = expected.shape(); for (int i = 0; i < outShape[0]; i++) { for (int j = 0; j < outShape[1]; j++) { double expOut = expected.getDouble(i, j); double actOut = actual.getDouble(i, j); if (expOut == 0.0 && actOut == 0.0) continue; double absError = Math.abs(expOut - actOut); double relError = absError / (Math.abs(expOut) + Math.abs(actOut)); if (relError > maxRelativeDifference && absError > minAbsDifference) { System.out.println("Failure on value: (" + i + "," + j + " exp=" + expOut + ", act=" + actOut + ", absError=" + absError + ", relError=" + relError); return false; } } } return true; } public static RealMatrix convertToApacheMatrix(INDArray matrix) { if (matrix.rank() != 2) throw new IllegalArgumentException("Input rank is not 2 (not matrix)"); long[] shape = matrix.shape(); if (matrix.columns() > Integer.MAX_VALUE || matrix.rows() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); BlockRealMatrix out = new BlockRealMatrix((int) shape[0], (int) shape[1]); for (int i = 0; i < shape[0]; i++) { for (int j = 0; j < shape[1]; j++) { double value = matrix.getDouble(i, j); out.setEntry(i, j, value); } } return out; } public static INDArray convertFromApacheMatrix(RealMatrix matrix, DataType dataType) { val shape = new long[] {matrix.getRowDimension(), matrix.getColumnDimension()}; INDArray out = Nd4j.create(dataType, shape); for (int i = 0; i < shape[0]; i++) { for (int j = 0; j < shape[1]; j++) { double value = matrix.getEntry(i, j); out.putScalar(new int[] {i, j}, value); } } return out; } public static void printFailureDetails(INDArray first, INDArray second, RealMatrix expected, INDArray actual, INDArray onCopies, String op) { System.out.println("\nFactory: " + Nd4j.factory().getClass() + "\n"); System.out.println("First:"); printMatrixFullPrecision(first); System.out.println("\nSecond:"); printMatrixFullPrecision(second); System.out.println("\nExpected (Apache Commons)"); printApacheMatrix(expected); System.out.println("\nSame Nd4j op on copies: (Shape.toOffsetZeroCopy(first)." + op + "(Shape.toOffsetZeroCopy(second)))"); printMatrixFullPrecision(onCopies); System.out.println("\nActual:"); printMatrixFullPrecision(actual); } public static void printGemmFailureDetails(INDArray a, INDArray b, INDArray c, boolean transposeA, boolean transposeB, double alpha, double beta, RealMatrix expected, INDArray actual, INDArray onCopies) { System.out.println("\nFactory: " + Nd4j.factory().getClass() + "\n"); System.out.println("Op: gemm(a,b,c,transposeA=" + transposeA + ",transposeB=" + transposeB + ",alpha=" + alpha + ",beta=" + beta + ")"); System.out.println("a:"); printMatrixFullPrecision(a); System.out.println("\nb:"); printMatrixFullPrecision(b); System.out.println("\nc:"); printMatrixFullPrecision(c); System.out.println("\nExpected (Apache Commons)"); printApacheMatrix(expected); System.out.println("\nSame Nd4j op on zero offset copies: gemm(aCopy,bCopy,cCopy," + transposeA + "," + transposeB + "," + alpha + "," + beta + ")"); printMatrixFullPrecision(onCopies); System.out.println("\nActual:"); printMatrixFullPrecision(actual); } public static void printMatrixFullPrecision(INDArray matrix) { boolean floatType = (matrix.data().dataType() == DataType.FLOAT); printNDArrayHeader(matrix); long[] shape = matrix.shape(); for (int i = 0; i < shape[0]; i++) { for (int j = 0; j < shape[1]; j++) { if (floatType) System.out.print(matrix.getFloat(i, j)); else System.out.print(matrix.getDouble(i, j)); if (j != shape[1] - 1) System.out.print(", "); else System.out.println(); } } } public static void printNDArrayHeader(INDArray array) { System.out.println(array.data().dataType() + " - order=" + array.ordering() + ", offset=" + array.offset() + ", shape=" + Arrays.toString(array.shape()) + ", stride=" + Arrays.toString(array.stride()) + ", length=" + array.length() + ", data().length()=" + array.data().length()); } public static void printFailureDetails(INDArray first, INDArray second, INDArray expected, INDArray actual, INDArray onCopies, String op) { System.out.println("\nFactory: " + Nd4j.factory().getClass() + "\n"); System.out.println("First:"); printMatrixFullPrecision(first); System.out.println("\nSecond:"); printMatrixFullPrecision(second); System.out.println("\nExpected"); printMatrixFullPrecision(expected); System.out.println("\nSame Nd4j op on copies: (Shape.toOffsetZeroCopy(first)." + op + "(Shape.toOffsetZeroCopy(second)))"); printMatrixFullPrecision(onCopies); System.out.println("\nActual:"); printMatrixFullPrecision(actual); } public static void printApacheMatrix(RealMatrix matrix) { int nRows = matrix.getRowDimension(); int nCols = matrix.getColumnDimension(); System.out.println("Apache Commons RealMatrix: Shape: [" + nRows + "," + nCols + "]"); for (int i = 0; i < nRows; i++) { for (int j = 0; j < nCols; j++) { System.out.print(matrix.getEntry(i, j)); if (j != nCols - 1) System.out.print(", "); else System.out.println(); } } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy