All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.BaseDL4JTest Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.deeplearning4j;

import lombok.SneakyThrows;
import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.*;

import org.nd4j.common.base.Preconditions;
import org.nd4j.common.config.ND4JSystemProperties;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.ILoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.management.ManagementFactory;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.TimeUnit;

import static org.junit.jupiter.api.Assumptions.assumeTrue;


@DisplayName("Base DL 4 J Test")
public abstract class BaseDL4JTest {

   private static Logger log = LoggerFactory.getLogger(BaseDL4JTest.class.getName());

    protected long startTime;

    protected int threadCountBefore;

    private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();

    /**
     * Override this to specify the number of threads for C++ execution, via
     * {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
     * @return Number of threads to use for C++ op execution
     */
    public int numThreads() {
        return DEFAULT_THREADS;
    }

    /**
     * Override this method to set the default timeout for methods in the test class
     */
    public long getTimeoutMilliseconds() {
        return 90_000;
    }

    /**
     * Override this to set the profiling mode for the tests defined in the child class
     */
    public OpExecutioner.ProfilingMode getProfilingMode() {
        return OpExecutioner.ProfilingMode.SCOPE_PANIC;
    }

    /**
     * Override this to set the datatype of the tests defined in the child class
     */
    public DataType getDataType() {
        return DataType.DOUBLE;
    }

    public DataType getDefaultFPDataType() {
        return getDataType();
    }

    protected static Boolean integrationTest;

    /**
     * @return True if integration tests maven profile is enabled, false otherwise.
     */
    public static boolean isIntegrationTests() {
        if (integrationTest == null) {
            String prop = System.getenv("DL4J_INTEGRATION_TESTS");
            integrationTest = Boolean.parseBoolean(prop);
        }
        return integrationTest;
    }

    /**
     * Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled.
     * This can be used to dynamically skip integration tests when the integration test profile is not enabled.
     * Note that the integration test profile is not enabled by default - "integration-tests" profile
     */
    public static void skipUnlessIntegrationTests() {
        assumeTrue(isIntegrationTests(), "Skipping integration test - integration profile is not enabled");
    }

    @BeforeEach
    @Timeout(90000L)
    void beforeTest(TestInfo testInfo) {
        log.info("{}.{}", getClass().getSimpleName(), testInfo.getTestMethod().get().getName());
        // Suppress ND4J initialization - don't need this logged for every test...
        System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
        System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
        Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
        Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
        Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
        Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
        Nd4j.getExecutioner().enableDebugMode(false);
        Nd4j.getExecutioner().enableVerboseMode(false);
        int numThreads = numThreads();
        Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
        if (numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
            Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
        }
        startTime = System.currentTimeMillis();
        threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
    }

    @SneakyThrows
    @AfterEach
    void afterTest(TestInfo testInfo) {
        // Attempt to keep workspaces isolated between tests
        Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
        MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
        Nd4j.getMemoryManager().setCurrentWorkspace(null);
        if (currWS != null) {
            // Not really safe to continue testing under this situation... other tests will likely fail with obscure
            // errors that are hard to track back to this
            log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
            System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
            System.out.flush();
            // Try to flush logs also:
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
            }
            ILoggerFactory lf = LoggerFactory.getILoggerFactory();
            //work around to remove explicit dependency on logback
            if( lf.getClass().getName().equals("ch.qos.logback.classic.LoggerContext")) {
                Method method = lf.getClass().getMethod("stop");
                method.setAccessible(true);
                method.invoke(lf);
            }
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
            }
            System.exit(1);
        }
        StringBuilder sb = new StringBuilder();
        long maxPhys = Pointer.maxPhysicalBytes();
        long maxBytes = Pointer.maxBytes();
        long currPhys = Pointer.physicalBytes();
        long currBytes = Pointer.totalBytes();
        long jvmTotal = Runtime.getRuntime().totalMemory();
        long jvmMax = Runtime.getRuntime().maxMemory();
        int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
        long duration = System.currentTimeMillis() - startTime;
        sb.append(getClass().getSimpleName()).append(".").append(testInfo.getTestMethod().get().getName()).append(": ").append(duration).append(" ms").append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")").append(", jvmTotal=").append(jvmTotal).append(", jvmMax=").append(jvmMax).append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes).append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
        List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
        if (ws != null && ws.size() > 0) {
            long currSize = 0;
            for (MemoryWorkspace w : ws) {
                currSize += w.getCurrentSize();
            }
            if (currSize > 0) {
                sb.append(", threadWSSize=").append(currSize).append(" (").append(ws.size()).append(" WSs)");
            }
        }
        Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
        Object o = p.get("cuda.devicesInformation");
        if (o instanceof List) {
            List> l = (List>) o;
            if (l.size() > 0) {
                sb.append(" [").append(l.size()).append(" GPUs: ");
                for (int i = 0; i < l.size(); i++) {
                    Map m = l.get(i);
                    if (i > 0)
                        sb.append(",");
                    sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ").append(m.get("cuda.totalMemory")).append(" total)");
                }
                sb.append("]");
            }
        }
        log.info(sb.toString());
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy