// ai.djl.engine.Engine (Maven / Gradle / Ivy)
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.engine;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.Ec2Utils;
import ai.djl.util.RandomUtils;
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.InputStream;
import java.lang.management.MemoryUsage;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.Properties;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
* The {@code Engine} interface is the base of the provided implementation for DJL.
* Any engine-specific functionality should be provided through this class. In general, it should
* contain methods to detect information about the usable machine hardware and to create a new
* {@link NDManager} and {@link Model}.
* @see Engine Guide
* @see EngineProvider
* @see The guide on resource
* and engine caching
public abstract class Engine {
private static final Logger logger = LoggerFactory.getLogger(Engine.class);
private static final Map ALL_ENGINES = new ConcurrentHashMap<>();
private static String defaultEngine = initEngine();
private static final Pattern PATTERN =
private Device defaultDevice;
// use object to check if it's set
private Integer seed;
private static synchronized String initEngine() {
if (Boolean.parseBoolean(Utils.getenv("DJL_ENGINE_MANUAL_INIT"))) {
return null;
ServiceLoader loaders = ServiceLoader.load(EngineProvider.class);
for (EngineProvider provider : loaders) {
if (ALL_ENGINES.isEmpty()) {
logger.debug("No engine found from EngineProvider");
return null;
String def = System.getProperty("ai.djl.default_engine");
String newDefaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def);
if (newDefaultEngine == null || newDefaultEngine.isEmpty()) {
int rank = Integer.MAX_VALUE;
for (EngineProvider provider : ALL_ENGINES.values()) {
if (provider.getEngineRank() < rank) {
newDefaultEngine = provider.getEngineName();
rank = provider.getEngineRank();
} else if (!ALL_ENGINES.containsKey(newDefaultEngine)) {
throw new EngineException("Unknown default engine: " + newDefaultEngine);
logger.debug("Found default engine: {}", newDefaultEngine);
return newDefaultEngine;
* Returns the alternative {@code engine} if available.
* @return the alternative {@code engine}
public abstract Engine getAlternativeEngine();
* Returns the name of the Engine.
* @return the name of the engine
public abstract String getEngineName();
* Return the rank of the {@code Engine}.
* @return the rank of the engine
public abstract int getRank();
* Returns the default Engine name.
* @return the default Engine name
public static String getDefaultEngineName() {
return System.getProperty("ai.djl.default_engine", defaultEngine);
* Returns the default Engine.
* @return the instance of {@code Engine}
* @see EngineProvider
public static Engine getInstance() {
if (defaultEngine == null) {
throw new EngineException(
"No deep learning engine found."
+ System.lineSeparator()
+ "Please refer to"
+ " https://github.com/deepjavalibrary/djl/blob/master/docs/development/troubleshooting.md"
+ " for more details.");
return getEngine(getDefaultEngineName());
* Returns if the specified engine is available.
* @param engineName the name of Engine to check
* @return {@code true} if the specified engine is available
* @see EngineProvider
public static boolean hasEngine(String engineName) {
return ALL_ENGINES.containsKey(engineName);
* Registers a {@link EngineProvider} if not registered.
* @param provider the {@code EngineProvider} to be registered
public static void registerEngine(EngineProvider provider) {
logger.debug("Registering EngineProvider: {}", provider.getEngineName());
ALL_ENGINES.put(provider.getEngineName(), provider);
* Returns the default engine.
* @return the default engine
public static String getDefaultEngine() {
return defaultEngine;
* Sets the default engine returned by {@link #getInstance()}.
* @param engineName the new default engine's name
public static void setDefaultEngine(String engineName) {
// Requires an engine to be loaded (without exception) before being the default
logger.debug("Setting new default engine: {}", engineName);
defaultEngine = engineName;
* Returns a set of engine names that are loaded.
* @return a set of engine names that are loaded
public static Set getAllEngines() {
return ALL_ENGINES.keySet();
* Returns the {@code Engine} with the given name.
* @param engineName the name of Engine to retrieve
* @return the instance of {@code Engine}
* @see EngineProvider
public static Engine getEngine(String engineName) {
EngineProvider provider = ALL_ENGINES.get(engineName);
if (provider == null) {
throw new IllegalArgumentException("Deep learning engine not found: " + engineName);
Engine engine = provider.getEngine();
if (engine == null) {
throw new IllegalStateException(
"The engine " + engineName + " was not able to initialize");
return engine;
* Returns the version of the deep learning engine.
* @return the version number of the deep learning engine
public abstract String getVersion();
* Returns whether the engine has the specified capability.
* @param capability the capability to retrieve
* @return {@code true} if the engine has the specified capability
public abstract boolean hasCapability(String capability);
* Returns the engine's default {@link Device}.
* @return the engine's default {@link Device}
public Device defaultDevice() {
if (defaultDevice == null) {
if (hasCapability(StandardCapabilities.CUDA) && CudaUtils.getGpuCount() > 0) {
defaultDevice = Device.gpu();
} else {
defaultDevice = Device.cpu();
return defaultDevice;
* Returns an array of devices.
* If GPUs are available, it will return an array of {@code Device} of size
* \(min(numAvailable, maxGpus)\). Else, it will return an array with a single CPU device.
* @return an array of devices
public Device[] getDevices() {
return getDevices(Integer.MAX_VALUE);
* Returns an array of devices given the maximum number of GPUs to use.
If GPUs are available, it will return an array of {@code Device} of size
* \(min(numAvailable, maxGpus)\). Else, it will return an array with a single CPU device.
* @param maxGpus the max number of GPUs to use. Use 0 for no GPUs.
* @return an array of devices
public Device[] getDevices(int maxGpus) {
int count = getGpuCount();
if (maxGpus <= 0 || count <= 0) {
return new Device[] {Device.cpu()};
count = Math.min(maxGpus, count);
Device[] devices = new Device[count];
for (int i = 0; i < count; ++i) {
devices[i] = Device.gpu(i);
return devices;
* Returns the number of GPUs available in the system.
* @return the number of GPUs available in the system
public int getGpuCount() {
if (hasCapability(StandardCapabilities.CUDA)) {
return CudaUtils.getGpuCount();
return 0;
* Construct an empty SymbolBlock for loading.
* @param manager the manager to manage parameters
* @return Empty {@link SymbolBlock} for static graph
public SymbolBlock newSymbolBlock(NDManager manager) {
throw new UnsupportedOperationException("Not supported.");
* Constructs a new model.
* @param name the model name
* @param device the device that the model will be loaded onto
* @return a new Model instance using the network defined in block
public abstract Model newModel(String name, Device device);
* Creates a new top-level {@link NDManager}.
{@code NDManager} will inherit default {@link Device}.
* @return a new top-level {@code NDManager}
public abstract NDManager newBaseManager();
* Creates a new top-level {@link NDManager} with specified {@link Device}.
* @param device the default {@link Device}
* @return a new top-level {@code NDManager}
public abstract NDManager newBaseManager(Device device);
* Returns a new instance of {@link GradientCollector}.
* @return a new instance of {@link GradientCollector}
public GradientCollector newGradientCollector() {
throw new UnsupportedOperationException("Not supported.");
* Returns a new instance of {@link ParameterServer}.
* @param optimizer the optimizer to update
* @return a new instance of {@link ParameterServer}
public ParameterServer newParameterServer(Optimizer optimizer) {
return new LocalParameterServer(optimizer);
* Seeds the random number generator in DJL Engine.
This will affect all {@link Device}s and all operators using Engine's random number
* generator.
* @param seed the seed to be fixed in Engine
public void setRandomSeed(int seed) {
this.seed = seed;
* Returns the random seed in DJL Engine.
* @return seed the seed to be fixed in Engine
public Integer getSeed() {
return seed;
* Returns the DJL API version.
* @return seed the seed to be fixed in Engine
public static String getDjlVersion() {
String version = Engine.class.getPackage().getSpecificationVersion();
if (version != null) {
return version;
try (InputStream is = Engine.class.getResourceAsStream("api.properties")) {
Properties prop = new Properties();
return prop.getProperty("djl_version");
} catch (IOException e) {
throw new AssertionError("Failed to open api.properties", e);
/** {@inheritDoc} */
public String toString() {
return getEngineName() + ':' + getVersion();
/** Prints debug information about the environment for debugging environment issues. */
public static void debugEnvironment() {
System.out.println("----------- System Properties -----------");
System.getProperties().forEach((k, v) -> print((String) k, v));
System.out.println("--------- Environment Variables ---------");
System.out.println("-------------- Directories --------------");
try {
Path temp = Paths.get(System.getProperty("java.io.tmpdir"));
System.out.println("temp directory: " + temp);
Path tmpFile = Files.createTempFile("test", ".tmp");
Path cacheDir = Utils.getCacheDir();
System.out.println("DJL cache directory: " + cacheDir.toAbsolutePath());
Path path = Utils.getEngineCacheDir();
System.out.println("Engine cache directory: " + path.toAbsolutePath());
if (!Files.isWritable(path)) {
System.out.println("Engine cache directory is not writable!!!");
} catch (Throwable e) {
System.out.println("------------------ CUDA -----------------");
int gpuCount = CudaUtils.getGpuCount();
System.out.println("GPU Count: " + gpuCount);
if (gpuCount > 0) {
System.out.println("CUDA: " + CudaUtils.getCudaVersionString());
System.out.println("ARCH: " + CudaUtils.getComputeCapability(0));
for (int i = 0; i < gpuCount; ++i) {
Device device = Device.gpu(i);
MemoryUsage mem = CudaUtils.getGpuMemory(device);
System.out.println("GPU(" + i + ") memory used: " + mem.getCommitted() + " bytes");
System.out.println("----------------- Engines ---------------");
System.out.println("DJL version: " + getDjlVersion());
System.out.println("Default Engine: " + Engine.getInstance());
System.out.println("Default Device: " + Engine.getInstance().defaultDevice());
for (EngineProvider provider : ALL_ENGINES.values()) {
System.out.println(provider.getEngineName() + ": " + provider.getEngineRank());
private static void print(String key, Object value) {
if (PATTERN.matcher(key).find()) {
value = "*********";
System.out.println(key + ": " + value); // NOPMD