ai.djl.mxnet.engine.MxModel Maven / Gradle / Ivy
Show all versions of mxnet-engine Show documentation
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.mxnet.engine;
import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.Parameter;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* {@code MxModel} is the MXNet implementation of {@link Model}.
*
* MxModel contains all the methods in Model to load and process a model. In addition, it
* provides MXNet Specific functionality, such as getSymbol to obtain the Symbolic graph and
* getParameters to obtain the parameter NDArrays
*/
public class MxModel extends BaseModel {
private static final Logger logger = LoggerFactory.getLogger(MxModel.class);
/**
* Constructs a new Model on a given device.
*
* @param name the model name
* @param device the device the model should be located on
*/
MxModel(String name, Device device) {
super(name);
dataType = DataType.FLOAT32;
properties = new ConcurrentHashMap<>();
manager = MxNDManager.getSystemManager().newSubManager(device);
manager.setName("mxModel");
}
/**
* Loads the MXNet model from a specified location.
*
*
MXNet engine looks for {MODEL_NAME}-symbol.json and {MODEL_NAME}-{EPOCH}.params files in
* the specified directory. By default, MXNet engine will pick up the latest epoch of the
* parameter file. However, users can explicitly specify an epoch to be loaded:
*
*
* Map<String, String> options = new HashMap<>()
* options.put("epoch", "3");
* model.load(modelPath, "squeezenet", options);
*
*
* @param modelPath the directory of the model
* @param prefix the model file name or path prefix
* @param options load model options, see documentation for the specific engine
* @throws IOException Exception for file loading
*/
@Override
@SuppressWarnings("PMD.EmptyControlStatement")
public void load(Path modelPath, String prefix, Map options)
throws IOException, MalformedModelException {
setModelDir(modelPath);
wasLoaded = true;
if (prefix == null) {
prefix = modelName;
}
boolean hasParameter = true;
String optimization = null;
if (options != null) {
String paramOption = (String) options.get("hasParameter");
if (paramOption != null) {
hasParameter = Boolean.parseBoolean(paramOption);
}
optimization = (String) options.get("MxOptimizeFor");
}
Path paramFile = paramPathResolver(prefix, options);
if (hasParameter && paramFile == null) {
prefix = modelDir.toFile().getName();
paramFile = paramPathResolver(prefix, options);
if (paramFile == null && block == null) {
throw new FileNotFoundException(
"Parameter file with prefix: "
+ prefix
+ " not found in: "
+ modelDir
+ " or not readable by the engine.");
}
}
if (block == null) {
// load MxSymbolBlock
Path symbolFile = modelDir.resolve(prefix + "-symbol.json");
if (Files.notExists(symbolFile)) {
throw new FileNotFoundException(
"Symbol file not found: "
+ symbolFile
+ ", please set block manually for imperative model.");
}
Symbol symbol =
Symbol.load((MxNDManager) manager, symbolFile.toAbsolutePath().toString());
// TODO: change default name "data" to model-specific one
block = new MxSymbolBlock(manager, symbol);
}
if (hasParameter) {
loadParameters(paramFile, options);
}
// TODO: Check if Symbol has all names that params file have
if (optimization != null) {
((MxSymbolBlock) block).optimizeFor(optimization);
}
// Freeze parameters to match Block spec for preTrained data
boolean trainParam =
options != null && Boolean.parseBoolean((String) options.get("trainParam"));
if (!trainParam) {
// TODO: See https://github.com/deepjavalibrary/djl/pull/2360
// NOPMD
// block.freezeParameters(true);
}
}
/** {@inheritDoc} */
@Override
@SuppressWarnings("PMD.EmptyControlStatement")
public Trainer newTrainer(TrainingConfig trainingConfig) {
PairList> initializer = trainingConfig.getInitializers();
if (block == null) {
throw new IllegalStateException(
"You must set a block for the model before creating a new trainer");
}
if (wasLoaded) {
// Unfreeze parameters if training directly
// TODO: See https://github.com/deepjavalibrary/djl/pull/2360
// block.freezeParameters(false);
}
for (Pair> pair : initializer) {
if (pair.getKey() != null && pair.getValue() != null) {
block.setInitializer(pair.getKey(), pair.getValue());
}
}
return new Trainer(this, trainingConfig);
}
/** {@inheritDoc} */
@Override
public String[] getArtifactNames() {
try (Stream stream = Files.walk(modelDir)) {
List files = stream.filter(Files::isRegularFile).collect(Collectors.toList());
List ret = new ArrayList<>(files.size());
for (Path path : files) {
String fileName = path.toFile().getName();
if (fileName.endsWith(".params") || fileName.endsWith("-symbol.json")) {
// ignore symbol and param files.
continue;
}
Path relative = modelDir.relativize(path);
ret.add(relative.toString());
}
return ret.toArray(Utils.EMPTY_ARRAY);
} catch (IOException e) {
throw new AssertionError("Failed list files", e);
}
}
/** {@inheritDoc} */
@Override
public void close() {
// TODO workaround for MXNet Engine crash issue
JnaUtils.waitAll();
super.close();
}
@SuppressWarnings("PMD.UseConcurrentHashMap")
private void loadParameters(Path paramFile, Map options)
throws IOException, MalformedModelException {
if (readParameters(paramFile, options)) {
return;
}
logger.debug("DJL formatted model not found, try to find MXNet model");
NDList paramNDlist = manager.load(paramFile);
MxSymbolBlock symbolBlock = (MxSymbolBlock) block;
List parameters = symbolBlock.getAllParameters();
Map map = new LinkedHashMap<>();
parameters.forEach(p -> map.put(p.getName(), p));
for (NDArray nd : paramNDlist) {
String key = nd.getName();
if (key == null) {
throw new IllegalArgumentException("Array names must be present in parameter file");
}
String paramName = key.split(":", 2)[1];
Parameter parameter = map.remove(paramName);
parameter.setArray(nd);
}
symbolBlock.setInputNames(new ArrayList<>(map.keySet()));
// TODO: Find a better to infer model DataType from SymbolBlock.
dataType = paramNDlist.head().getDataType();
logger.debug("MXNet Model {} ({}) loaded successfully.", paramFile, dataType);
}
}