All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.validation.TestCase Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.validation;

import lombok.Data;
import lombok.Getter;
import lombok.NonNull;
import lombok.experimental.Accessors;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.functions.EqualityFn;
import org.nd4j.autodiff.validation.functions.RelErrorFn;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.function.Function;

import java.util.*;

@Data
@Accessors(fluent = true)
@Getter
public class TestCase {
    public enum TestSerialization {BEFORE_EXEC, AFTER_EXEC, BOTH, NONE};

    public static final boolean GC_DEFAULT_PRINT = false;
    public static final boolean GC_DEFAULT_EXIT_FIRST_FAILURE = false;
    public static final boolean GC_DEFAULT_DEBUG_MODE = false;
    public static final double GC_DEFAULT_EPS = 1e-5;
    public static final double GC_DEFAULT_MAX_REL_ERROR = 1e-5;
    public static final double GC_DEFAULT_MIN_ABS_ERROR = 1e-6;

    //To test
    private SameDiff sameDiff;
    private String testName;

    //Forward pass test configuration
    /*
     * Note: These forward pass functions are used to validate the output of forward pass for inputs already set
     * on the SameDiff instance.
     * Key:     The name of the variable to check the forward pass output for
     * Value:   A function to check the correctness of the output
     * NOTE: The Function should return null on correct results, and an error message otherwise
     */
    private Map> fwdTestFns;
    private Map placeholderValues;

    //Gradient check configuration
    private boolean gradientCheck = true;
    private boolean gradCheckPrint = GC_DEFAULT_PRINT;
    private boolean gradCheckDefaultExitFirstFailure = GC_DEFAULT_EXIT_FIRST_FAILURE;
    private boolean gradCheckDebugMode = GC_DEFAULT_DEBUG_MODE;
    private double gradCheckEpsilon = GC_DEFAULT_EPS;
    private double gradCheckMaxRelativeError = GC_DEFAULT_MAX_REL_ERROR;
    private double gradCheckMinAbsError = GC_DEFAULT_MIN_ABS_ERROR;
    private Set gradCheckSkipVariables;
    private Map gradCheckMask;

    //FlatBuffers serialization configuration
    private TestSerialization testFlatBufferSerialization = TestSerialization.BOTH;


    /**
     * @param sameDiff SameDiff instance to test. Note: All of the required inputs should already be set
     */
    public TestCase(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
    }

    /**
     * Validate the output (forward pass) for a single variable using INDArray.equals(INDArray)
     *
     * @param name     Name of the variable to check
     * @param expected Expected INDArray
     * @param eps  the expected epsilon, defaults to 1e-3
     */
    public TestCase expectedOutput(@NonNull String name, @NonNull INDArray expected,double eps) {
        return expected(name, new EqualityFn(expected,eps));
    }

    /**
     * Validate the output (forward pass) for a single variable using INDArray.equals(INDArray)
     *
     * @param name     Name of the variable to check
     * @param expected Expected INDArray
     */
    public TestCase expectedOutput(@NonNull String name, @NonNull INDArray expected) {
      return expectedOutput(name,expected,1e-3);
    }

    /**
     * Validate the output (forward pass) for a single variable using element-wise relative error:
     * relError = abs(x-y)/(abs(x)+abs(y)), with x=y=0 case defined to be 0.0.
     * Also has a minimum absolute error condition, which must be satisfied for the relative error failure to be considered
     * legitimate
     *
     * @param name        Name of the variable to check
     * @param expected    Expected INDArray
     * @param maxRelError Maximum allowable relative error
     * @param minAbsError Minimum absolute error for a failure to be considered legitimate
     */
    public TestCase expectedOutputRelError(@NonNull String name, @NonNull INDArray expected, double maxRelError, double minAbsError) {
        return expected(name, new RelErrorFn(expected, maxRelError, minAbsError));
    }

    /**
     * Validate the output (forward pass) for a single variable using INDArray.equals(INDArray)
     *
     * @param var    Variable to check
     * @param output Expected INDArray
     */
    public TestCase expected(@NonNull SDVariable var, @NonNull INDArray output) {
        return expected(var.name(), output);
    }

    /**
     * Validate the output (forward pass) for a single variable using INDArray.equals(INDArray)
     *
     * @param name   Name of the variable to check
     * @param output Expected INDArray
     */
    public TestCase expected(@NonNull String name, @NonNull INDArray output) {
        return expectedOutput(name, output);
    }

    public TestCase expected(SDVariable var, Function validationFn) {
        return expected(var.name(), validationFn);
    }

    /**
     * @param name         The name of the variable to check
     * @param validationFn Function to use to validate the correctness of the specific Op. Should return null
     *                     if validation passes, or an error message if the op validation fails
     */
    public TestCase expected(String name, Function validationFn) {
        if (fwdTestFns == null)
            fwdTestFns = new LinkedHashMap<>();
        fwdTestFns.put(name, validationFn);
        return this;
    }

    public Set gradCheckSkipVariables() {
        return gradCheckSkipVariables;
    }

    public Map gradCheckMask() {
        return gradCheckMask;
    }

    /**
     * Specify the input variables that should NOT be gradient checked.
     * For example, if an input is an integer index (not real valued) it should be skipped as such an input cannot
     * be gradient checked
     *
     * @param toSkip Name of the input variables to skip gradient check for
     */
    public TestCase gradCheckSkipVariables(String... toSkip) {
        if (gradCheckSkipVariables == null)
            gradCheckSkipVariables = new LinkedHashSet<>();
        Collections.addAll(gradCheckSkipVariables, toSkip);
        return this;
    }

    public TestCase placeholderValues(Map placeholderValues){
        this.placeholderValues = placeholderValues;
        return this;
    }

    public TestCase placeholderValue(String variable, INDArray value){
        if(this.placeholderValues == null)
            this.placeholderValues = new HashMap<>();
        this.placeholderValues.put(variable, value);
        return this;
    }


    public void assertConfigValid() {
        Preconditions.checkNotNull(sameDiff, "SameDiff instance cannot be null%s", testNameErrMsg());
        Preconditions.checkState(gradientCheck || (fwdTestFns != null && fwdTestFns.size() > 0), "Test case is empty: nothing to test" +
                " (gradientCheck == false and no expected results available)%s", testNameErrMsg());
    }

    public String testNameErrMsg() {
        if (testName == null)
            return "";
        return " - Test name: \"" + testName + "\"";
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy