All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.engine.Engine Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.Ec2Utils;
import ai.djl.util.RandomUtils;
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.lang.management.MemoryUsage;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.Properties;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

/**
 * The {@code Engine} interface is the base of the provided implementation for DJL.
 *
 * 

Any engine-specific functionality should be provided through this class. In general, it should * contain methods to detect information about the usable machine hardware and to create a new * {@link NDManager} and {@link Model}. * * @see Engine Guide * @see EngineProvider * @see The guide on * resource and engine caching */ public abstract class Engine { private static final Logger logger = LoggerFactory.getLogger(Engine.class); private static final Map ALL_ENGINES = new ConcurrentHashMap<>(); private static final String DEFAULT_ENGINE = initEngine(); private static final Pattern PATTERN = Pattern.compile("KEY|TOKEN|PASSWORD", Pattern.CASE_INSENSITIVE); private Device defaultDevice; // use object to check if it's set private Integer seed; private static synchronized String initEngine() { ServiceLoader loaders = ServiceLoader.load(EngineProvider.class); for (EngineProvider provider : loaders) { registerEngine(provider); } if (ALL_ENGINES.isEmpty()) { logger.debug("No engine found from EngineProvider"); return null; } String def = System.getProperty("ai.djl.default_engine"); String defaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); if (defaultEngine == null || defaultEngine.isEmpty()) { int rank = Integer.MAX_VALUE; for (EngineProvider provider : ALL_ENGINES.values()) { if (provider.getEngineRank() < rank) { defaultEngine = provider.getEngineName(); rank = provider.getEngineRank(); } } } else if (!ALL_ENGINES.containsKey(defaultEngine)) { throw new EngineException("Unknown default engine: " + defaultEngine); } logger.debug("Found default engine: {}", defaultEngine); Ec2Utils.callHome(defaultEngine); return defaultEngine; } /** * Returns the alternative {@code engine} if available. * * @return the alternative {@code engine} */ public abstract Engine getAlternativeEngine(); /** * Returns the name of the Engine. * * @return the name of the engine */ public abstract String getEngineName(); /** * Return the rank of the {@code Engine}. * * @return the rank of the engine */ public abstract int getRank(); /** * Returns the default Engine name. * * @return the default Engine name */ public static String getDefaultEngineName() { return System.getProperty("ai.djl.default_engine", DEFAULT_ENGINE); } /** * Returns the default Engine. * * @return the instance of {@code Engine} * @see EngineProvider */ public static Engine getInstance() { if (DEFAULT_ENGINE == null) { throw new EngineException( "No deep learning engine found." + System.lineSeparator() + "Please refer to" + " https://github.com/deepjavalibrary/djl/blob/master/docs/development/troubleshooting.md" + " for more details."); } return getEngine(getDefaultEngineName()); } /** * Returns if the specified engine is available. * * @param engineName the name of Engine to check * @return {@code true} if the specified engine is available * @see EngineProvider */ public static boolean hasEngine(String engineName) { return ALL_ENGINES.containsKey(engineName); } /** * Registers a {@link EngineProvider} if not registered. * * @param provider the {@code EngineProvider} to be registered */ public static void registerEngine(EngineProvider provider) { logger.debug("Registering EngineProvider: {}", provider.getEngineName()); ALL_ENGINES.putIfAbsent(provider.getEngineName(), provider); } /** * Returns a set of engine names that are loaded. * * @return a set of engine names that are loaded */ public static Set getAllEngines() { return ALL_ENGINES.keySet(); } /** * Returns the {@code Engine} with the given name. * * @param engineName the name of Engine to retrieve * @return the instance of {@code Engine} * @see EngineProvider */ public static Engine getEngine(String engineName) { EngineProvider provider = ALL_ENGINES.get(engineName); if (provider == null) { throw new IllegalArgumentException("Deep learning engine not found: " + engineName); } return provider.getEngine(); } /** * Returns the version of the deep learning engine. * * @return the version number of the deep learning engine */ public abstract String getVersion(); /** * Returns whether the engine has the specified capability. * * @param capability the capability to retrieve * @return {@code true} if the engine has the specified capability */ public abstract boolean hasCapability(String capability); /** * Returns the engine's default {@link Device}. * * @return the engine's default {@link Device} */ public Device defaultDevice() { if (defaultDevice == null) { if (hasCapability(StandardCapabilities.CUDA) && CudaUtils.getGpuCount() > 0) { defaultDevice = Device.gpu(); } else { defaultDevice = Device.cpu(); } } return defaultDevice; } /** * Returns an array of devices. * *

If GPUs are available, it will return an array of {@code Device} of size * \(min(numAvailable, maxGpus)\). Else, it will return an array with a single CPU device. * * @return an array of devices */ public Device[] getDevices() { return getDevices(Integer.MAX_VALUE); } /** * Returns an array of devices given the maximum number of GPUs to use. * *

If GPUs are available, it will return an array of {@code Device} of size * \(min(numAvailable, maxGpus)\). Else, it will return an array with a single CPU device. * * @param maxGpus the max number of GPUs to use. Use 0 for no GPUs. * @return an array of devices */ public Device[] getDevices(int maxGpus) { int count = getGpuCount(); if (maxGpus <= 0 || count <= 0) { return new Device[] {Device.cpu()}; } count = Math.min(maxGpus, count); Device[] devices = new Device[count]; for (int i = 0; i < count; ++i) { devices[i] = Device.gpu(i); } return devices; } /** * Returns the number of GPUs available in the system. * * @return the number of GPUs available in the system */ public int getGpuCount() { if (hasCapability(StandardCapabilities.CUDA)) { return CudaUtils.getGpuCount(); } return 0; } /** * Construct an empty SymbolBlock for loading. * * @param manager the manager to manage parameters * @return Empty {@link SymbolBlock} for static graph */ public SymbolBlock newSymbolBlock(NDManager manager) { throw new UnsupportedOperationException("Not supported."); } /** * Constructs a new model. * * @param name the model name * @param device the device that the model will be loaded onto * @return a new Model instance using the network defined in block */ public abstract Model newModel(String name, Device device); /** * Creates a new top-level {@link NDManager}. * *

{@code NDManager} will inherit default {@link Device}. * * @return a new top-level {@code NDManager} */ public abstract NDManager newBaseManager(); /** * Creates a new top-level {@link NDManager} with specified {@link Device}. * * @param device the default {@link Device} * @return a new top-level {@code NDManager} */ public abstract NDManager newBaseManager(Device device); /** * Returns a new instance of {@link GradientCollector}. * * @return a new instance of {@link GradientCollector} */ public GradientCollector newGradientCollector() { throw new UnsupportedOperationException("Not supported."); } /** * Returns a new instance of {@link ParameterServer}. * * @param optimizer the optimizer to update * @return a new instance of {@link ParameterServer} */ public ParameterServer newParameterServer(Optimizer optimizer) { return new LocalParameterServer(optimizer); } /** * Seeds the random number generator in DJL Engine. * *

This will affect all {@link Device}s and all operators using Engine's random number * generator. * * @param seed the seed to be fixed in Engine */ public void setRandomSeed(int seed) { this.seed = seed; RandomUtils.RANDOM.setSeed(seed); } /** * Returns the random seed in DJL Engine. * * @return seed the seed to be fixed in Engine */ public Integer getSeed() { return seed; } /** * Returns the DJL API version. * * @return seed the seed to be fixed in Engine */ public static String getDjlVersion() { String version = Engine.class.getPackage().getSpecificationVersion(); if (version != null) { return version; } try (InputStream is = Engine.class.getResourceAsStream("api.properties")) { Properties prop = new Properties(); prop.load(is); return prop.getProperty("djl_version"); } catch (IOException e) { throw new AssertionError("Failed to open api.properties", e); } } /** {@inheritDoc} */ @Override public String toString() { return getEngineName() + ':' + getVersion(); } /** Prints debug information about the environment for debugging environment issues. */ @SuppressWarnings("PMD.SystemPrintln") public static void debugEnvironment() { System.out.println("----------- System Properties -----------"); System.getProperties().forEach((k, v) -> print((String) k, v)); System.out.println(); System.out.println("--------- Environment Variables ---------"); Utils.getenv().forEach(Engine::print); System.out.println(); System.out.println("-------------- Directories --------------"); try { Path temp = Paths.get(System.getProperty("java.io.tmpdir")); System.out.println("temp directory: " + temp); Path tmpFile = Files.createTempFile("test", ".tmp"); Files.delete(tmpFile); Path cacheDir = Utils.getCacheDir(); System.out.println("DJL cache directory: " + cacheDir.toAbsolutePath()); Path path = Utils.getEngineCacheDir(); System.out.println("Engine cache directory: " + path.toAbsolutePath()); Files.createDirectories(path); if (!Files.isWritable(path)) { System.out.println("Engine cache directory is not writable!!!"); } } catch (Throwable e) { e.printStackTrace(System.out); } System.out.println(); System.out.println("------------------ CUDA -----------------"); int gpuCount = CudaUtils.getGpuCount(); System.out.println("GPU Count: " + gpuCount); if (gpuCount > 0) { System.out.println("CUDA: " + CudaUtils.getCudaVersionString()); System.out.println("ARCH: " + CudaUtils.getComputeCapability(0)); } for (int i = 0; i < gpuCount; ++i) { Device device = Device.gpu(i); MemoryUsage mem = CudaUtils.getGpuMemory(device); System.out.println("GPU(" + i + ") memory used: " + mem.getCommitted() + " bytes"); } System.out.println(); System.out.println("----------------- Engines ---------------"); System.out.println("DJL version: " + getDjlVersion()); System.out.println("Default Engine: " + Engine.getInstance()); System.out.println("Default Device: " + Engine.getInstance().defaultDevice()); for (EngineProvider provider : ALL_ENGINES.values()) { System.out.println(provider.getEngineName() + ": " + provider.getEngineRank()); } } private static void print(String key, Object value) { if (PATTERN.matcher(key).find()) { value = "*********"; } System.out.println(key + ": " + value); // NOPMD } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy