All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.util.ND4JTestUtils Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.linalg.util;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.commons.io.FileUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.function.BiFunction;
import org.nd4j.common.primitives.Triple;

import java.io.File;
import java.net.URI;
import java.util.*;

public class ND4JTestUtils {

    private ND4JTestUtils(){ }


    @AllArgsConstructor
    @Data
    public static class ComparisonResult {
        List> allResults;
        List> passed;
        List> failed;
        List skippedDir1;
        List skippedDir2;
    }

    /**
     * A function for use with {@link #validateSerializedArrays(File, File, boolean, BiFunction)} using {@link INDArray#equals(Object)}
     */
    public static class EqualsFn implements BiFunction {
        @Override
        public Boolean apply(INDArray i1, INDArray i2) {
            return i1.equals(i2);
        }
    }

    /**
     * A function for use with {@link #validateSerializedArrays(File, File, boolean, BiFunction)} using {@link INDArray#equalsWithEps(Object, double)}
     */
    @AllArgsConstructor
    public static class EqualsWithEpsFn implements BiFunction {
        private final double eps;

        @Override
        public Boolean apply(INDArray i1, INDArray i2) {
            return i1.equalsWithEps(i2, eps);
        }
    }

    /**
     * Scan the specified directories for matching files (i.e., same path relative to their respective root directories)
     * and compare the contents using INDArray.equals (via {@link EqualsFn}
     * Assumes the saved files represent INDArrays saved with {@link Nd4j#saveBinary(INDArray, File)}
     * @param dir1      First directory
     * @param dir2      Second directory
     * @param recursive Whether to search recursively (i.e., include files in subdirectories
     * @return Comparison results
     */
    public static ComparisonResult validateSerializedArrays(File dir1, File dir2, boolean recursive) throws Exception {
        return validateSerializedArrays(dir1, dir2, recursive, new EqualsFn());
    }

    /**
     * Scan the specified directories for matching files (i.e., same path relative to their respective root directories)
     * and compare the contents using a provided function.
* Assumes the saved files represent INDArrays saved with {@link Nd4j#saveBinary(INDArray, File)} * @param dir1 First directory * @param dir2 Second directory * @param recursive Whether to search recursively (i.e., include files in subdirectories * @return Comparison results */ public static ComparisonResult validateSerializedArrays(File dir1, File dir2, boolean recursive, BiFunction evalFn) throws Exception { File[] f1 = FileUtils.listFiles(dir1, null, recursive).toArray(new File[0]); File[] f2 = FileUtils.listFiles(dir2, null, recursive).toArray(new File[0]); Preconditions.checkState(f1.length > 0, "No files found for directory 1: %s", dir1.getAbsolutePath() ); Preconditions.checkState(f2.length > 0, "No files found for directory 2: %s", dir2.getAbsolutePath() ); Map relativized1 = new HashMap<>(); Map relativized2 = new HashMap<>(); URI u = dir1.toURI(); for(File f : f1){ if(!f.isFile()) continue; String relative = u.relativize(f.toURI()).getPath(); relativized1.put(relative, f); } u = dir2.toURI(); for(File f : f2){ if(!f.isFile()) continue; String relative = u.relativize(f.toURI()).getPath(); relativized2.put(relative, f); } List skipped1 = new ArrayList<>(); for(String s : relativized1.keySet()){ if(!relativized2.containsKey(s)){ skipped1.add(relativized1.get(s)); } } List skipped2 = new ArrayList<>(); for(String s : relativized2.keySet()){ if(!relativized1.containsKey(s)){ skipped2.add(relativized1.get(s)); } } List> allResults = new ArrayList<>(); List> passed = new ArrayList<>(); List> failed = new ArrayList<>(); for(Map.Entry e : relativized1.entrySet()){ File file1 = e.getValue(); File file2 = relativized2.get(e.getKey()); if(file2 == null) continue; INDArray i1 = Nd4j.readBinary(file1); INDArray i2 = Nd4j.readBinary(file2); boolean b = evalFn.apply(i1, i2); Triple t = new Triple<>(file1, file2, b); allResults.add(t); if(b){ passed.add(t); } else { failed.add(t); } } Comparator> c = new Comparator>() { @Override public int compare(Triple o1, Triple o2) { 
return o1.getFirst().compareTo(o2.getFirst()); } }; Collections.sort(allResults, c); Collections.sort(passed, c); Collections.sort(failed, c); Collections.sort(skipped1); Collections.sort(skipped2); return new ComparisonResult(allResults, passed, failed, skipped1, skipped2); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy