ai.djl.engine.Engine Maven / Gradle / Ivy
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.engine;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;
import java.lang.management.MemoryUsage;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* The {@code Engine} interface is the base of the provided implementation for DJL.
*
* Any engine-specific functionality should be provided through this class. In general, it should
* contain methods to detect information about the usable machine hardware and to create a new
* {@link NDManager} and {@link Model}.
*
* @see Engine Guide
* @see EngineProvider
* @see The guide on resource
* and engine caching
*/
public abstract class Engine {
private static final Logger logger = LoggerFactory.getLogger(Engine.class);
private static final Map ALL_ENGINES = new ConcurrentHashMap<>();
private static final String DEFAULT_ENGINE = initEngine();
private Device defaultDevice;
// use object to check if it's set
private Integer seed;
private static synchronized String initEngine() {
ServiceLoader loaders = ServiceLoader.load(EngineProvider.class);
for (EngineProvider provider : loaders) {
logger.debug("Found EngineProvider: {}", provider.getEngineName());
ALL_ENGINES.put(provider.getEngineName(), provider);
}
if (ALL_ENGINES.isEmpty()) {
logger.debug("No engine found from EngineProvider");
return null;
}
String defaultEngine = System.getenv("DJL_DEFAULT_ENGINE");
defaultEngine = System.getProperty("ai.djl.default_engine", defaultEngine);
if (defaultEngine == null || defaultEngine.isEmpty()) {
int rank = Integer.MAX_VALUE;
for (EngineProvider provider : ALL_ENGINES.values()) {
if (provider.getEngineRank() < rank) {
defaultEngine = provider.getEngineName();
rank = provider.getEngineRank();
}
}
} else if (!ALL_ENGINES.containsKey(defaultEngine)) {
throw new EngineException("Unknown default engine: " + defaultEngine);
}
logger.debug("Found default engine: {}", defaultEngine);
return defaultEngine;
}
/**
* Returns the name of the Engine.
*
* @return the name of the engine
*/
public abstract String getEngineName();
/**
* Return the rank of the {@code Engine}.
*
* @return the rank of the engine
*/
public abstract int getRank();
/**
* Returns the default Engine.
*
* @return the instance of {@code Engine}
* @see EngineProvider
*/
public static Engine getInstance() {
if (DEFAULT_ENGINE == null) {
throw new EngineException(
"No deep learning engine found."
+ System.lineSeparator()
+ "Please refer to https://github.com/deepjavalibrary/djl/blob/master/docs/development/troubleshooting.md for more details.");
}
return getEngine(System.getProperty("ai.djl.default_engine", DEFAULT_ENGINE));
}
/**
* Returns if the specified engine is available.
*
* @param engineName the name of Engine to check
* @return {@code true} if the specified engine is available
* @see EngineProvider
*/
public static boolean hasEngine(String engineName) {
return ALL_ENGINES.containsKey(engineName);
}
/**
* Returns a set of engine names that are loaded.
*
* @return a set of engine names that are loaded
*/
public static Set getAllEngines() {
return ALL_ENGINES.keySet();
}
/**
* Returns the {@code Engine} with the given name.
*
* @param engineName the name of Engine to retrieve
* @return the instance of {@code Engine}
* @see EngineProvider
*/
public static Engine getEngine(String engineName) {
EngineProvider provider = ALL_ENGINES.get(engineName);
if (provider == null) {
throw new IllegalArgumentException("Deep learning engine not found: " + engineName);
}
return provider.getEngine();
}
/**
* Returns the version of the deep learning engine.
*
* @return the version number of the deep learning engine
*/
public abstract String getVersion();
/**
* Returns whether the engine has the specified capability.
*
* @param capability the capability to retrieve
* @return {@code true} if the engine has the specified capability
*/
public abstract boolean hasCapability(String capability);
/**
* Returns the engine's default {@link Device}.
*
* @return the engine's default {@link Device}
*/
public Device defaultDevice() {
if (defaultDevice == null) {
if (hasCapability(StandardCapabilities.CUDA) && CudaUtils.getGpuCount() > 0) {
defaultDevice = Device.gpu();
} else {
defaultDevice = Device.cpu();
}
}
return defaultDevice;
}
/**
* Construct an empty SymbolBlock for loading.
*
* @param manager the manager to manage parameters
* @return Empty {@link SymbolBlock} for static graph
*/
public abstract SymbolBlock newSymbolBlock(NDManager manager);
/**
* Constructs a new model.
*
* @param name the model name
* @param device the device that the model will be loaded onto
* @return a new Model instance using the network defined in block
*/
public abstract Model newModel(String name, Device device);
/**
* Creates a new top-level {@link NDManager}.
*
* {@code NDManager} will inherit default {@link Device}.
*
* @return a new top-level {@code NDManager}
*/
public abstract NDManager newBaseManager();
/**
* Creates a new top-level {@link NDManager} with specified {@link Device}.
*
* @param device the default {@link Device}
* @return a new top-level {@code NDManager}
*/
public abstract NDManager newBaseManager(Device device);
/**
* Returns a new instance of {@link GradientCollector}.
*
* @return a new instance of {@link GradientCollector}
*/
public abstract GradientCollector newGradientCollector();
/**
* Returns a new instance of {@link ParameterServer}.
*
* @param optimizer the optimizer to update
* @return a new instance of {@link ParameterServer}
*/
public ParameterServer newParameterServer(Optimizer optimizer) {
return new LocalParameterServer(optimizer);
}
/**
* Seeds the random number generator in DJL Engine.
*
*
This will affect all {@link Device}s and all operators using Engine's random number
* generator.
*
* @param seed the seed to be fixed in Engine
*/
public void setRandomSeed(int seed) {
this.seed = seed;
}
/**
* Returns the random seed in DJL Engine.
*
* @return seed the seed to be fixed in Engine
*/
public Integer getSeed() {
return seed;
}
/** Prints debug information about the environment for debugging environment issues. */
@SuppressWarnings("PMD.SystemPrintln")
public static void debugEnvironment() {
System.out.println("----------- System Properties -----------");
System.getProperties().forEach((k, v) -> System.out.println(k + ": " + v));
System.out.println();
System.out.println("--------- Environment Variables ---------");
System.getenv().forEach((k, v) -> System.out.println(k + ": " + v));
System.out.println();
System.out.println("-------------- Directories --------------");
try {
Path temp = Paths.get(System.getProperty("java.io.tmpdir"));
System.out.println("temp directory: " + temp);
Path tmpFile = Files.createTempFile("test", ".tmp");
Files.delete(tmpFile);
Path cacheDir = Utils.getCacheDir();
System.out.println("DJL cache directory: " + cacheDir.toAbsolutePath());
Path path = Utils.getEngineCacheDir();
System.out.println("Engine cache directory: " + path.toAbsolutePath());
Files.createDirectories(path);
if (!Files.isWritable(path)) {
System.out.println("Engine cache directory is not writable!!!");
}
} catch (Throwable e) {
e.printStackTrace(System.out);
}
System.out.println();
System.out.println("------------------ CUDA -----------------");
int gpuCount = Device.getGpuCount();
System.out.println("GPU Count: " + gpuCount);
System.out.println("Default Device: " + Device.defaultDevice());
if (gpuCount > 0) {
System.out.println("CUDA: " + CudaUtils.getCudaVersionString());
System.out.println("ARCH: " + CudaUtils.getComputeCapability(0));
}
for (int i = 0; i < gpuCount; ++i) {
Device device = Device.gpu(i);
MemoryUsage mem = CudaUtils.getGpuMemory(device);
System.out.println("GPU(" + i + ") memory used: " + mem.getCommitted() + " bytes");
}
System.out.println();
System.out.println("----------------- Engines ---------------");
System.out.println("Default Engine: " + DEFAULT_ENGINE);
for (EngineProvider provider : ALL_ENGINES.values()) {
System.out.println(provider.getEngineName() + ": " + provider.getEngineRank());
try {
provider.getEngine();
} catch (EngineException e) {
e.printStackTrace(System.out);
}
}
}
}