All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.mxnet.engine.MxSymbolBlock Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package ai.djl.mxnet.engine;

import ai.djl.MalformedModelException;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * {@code MxSymbolBlock} is the MXNet implementation of {@link SymbolBlock}.
 *
 * 

You can create a {@code MxSymbolBlock} using {@link ai.djl.Model#load(java.nio.file.Path, * String)}. */ public class MxSymbolBlock extends AbstractSymbolBlock { private static final Logger logger = LoggerFactory.getLogger(MxSymbolBlock.class); private static final byte VERSION = 3; private NDManager manager; private CachedOp op; private Symbol symbol; private List mxNetParams; // includes input data private Map parameters; private Map paramShapes; private Shape[] outputShapes; private PairList inputDescriptions; private PairList outputDescriptions; private boolean first; /** * Constructs a {@code MxSymbolBlock} for a {@link Symbol}. * *

You can create a {@code MxSymbolBlock} using {@link ai.djl.Model#load(java.nio.file.Path, * String)}. * * @param manager the manager to use for the block * @param symbol the symbol containing the block's symbolic graph */ public MxSymbolBlock(NDManager manager, Symbol symbol) { this(manager); this.symbol = symbol; initBlock(); } /** * Constructs an empty {@code MxSymbolBlock}. * * @param manager the manager to use for the block */ public MxSymbolBlock(NDManager manager) { super(VERSION); this.manager = manager; } /** * Sets the names of the input data. * * @param inputNames the names of the input data */ public void setInputNames(List inputNames) { this.inputNames = inputNames; // now that we know which of the parameters are just input placeholders and which // are trainable, add them properly so they are correctly handled Set nameLookup = new HashSet<>(inputNames); parameters = new LinkedHashMap<>(mxNetParams.size()); for (Parameter mxNetParameter : mxNetParams) { if (!nameLookup.contains(mxNetParameter.getName())) { parameters.put(mxNetParameter.getName(), mxNetParameter); } } } /** * Returns the list of inputs and parameter NDArrays. * * @return the list of inputs and parameter NDArrays */ public List getAllParameters() { return mxNetParams; } /** * Returns the layers' name. * * @return a List of String containing the layers' name */ public List getLayerNames() { return symbol.getLayerNames(); } /** * Returns the Symbolic graph from the model. * * @return a {@link Symbol} object */ public Symbol getSymbol() { return symbol; } /** * Applies Optimization algorithm for the model. 
* * @param optimization the name of the optimization */ public void optimizeFor(String optimization) { Symbol newSymbol = symbol.optimizeFor(optimization, manager.getDevice()); symbol.close(); symbol = newSymbol; } /** {@inheritDoc} */ @Override public PairList describeInput() { if (inputDescriptions == null) { inputDescriptions = new PairList<>(); for (String name : inputNames) { // Add empty shapes as input shapes are not saved // in MXNet models logger.warn( "Input shapes are unknown, please run predict or forward once" + " and call describeInput again."); inputDescriptions.add(name, new Shape()); } } return inputDescriptions; } /** {@inheritDoc} */ @Override public ParameterList getDirectParameters() { return new ParameterList(parameters); } /** {@inheritDoc} */ @Override public PairList describeOutput() { if (outputDescriptions == null) { logger.warn( "Output shapes are unknown, please run predict or forward once" + " and call describeOutput again."); } return outputDescriptions; } /** {@inheritDoc} */ @Override protected NDList forwardInternal( ParameterStore parameterStore, NDList inputs, boolean training, PairList params) { if (first) { synchronized (this) { if (first) { // create CachedOp is not thread-safe // add synchronized block to avoid creating multiple CachedOps op = JnaUtils.createCachedOp(this, (MxNDManager) manager, training); inputDescriptions = new PairList<>(); outputDescriptions = new PairList<>(); for (NDArray array : inputs) { inputDescriptions.add(array.getName(), array.getShape()); } NDList outputs = op.forward(parameterStore, inputs, training); for (NDArray array : outputs) { outputDescriptions.add(array.getName(), array.getShape()); } first = false; return outputs; } } } return op.forward(parameterStore, inputs, training); } /** {@inheritDoc} */ @Override public Shape[] getOutputShapes(Shape[] inputShapes) { if (outputShapes == null) { String[] outputNames = symbol.getOutputNames(); outputShapes = new Shape[outputNames.length]; for (int 
i = 0; i < outputShapes.length; ++i) { outputShapes[i] = getParameterShape(outputNames[i], inputShapes); } } return outputShapes; } /** {@inheritDoc} */ @Override public void removeLastBlock() { List layerNames = getLayerNames(); String layerName = layerNames.get(layerNames.size() - 2); Symbol sliced = symbol.get(layerName); symbol.close(); symbol = sliced; HashSet set = new HashSet<>(Arrays.asList(symbol.getAllNames())); for (int i = mxNetParams.size() - 1; i >= 0; --i) { Parameter parameter = mxNetParams.get(i); if (!set.contains(parameter.getName())) { mxNetParams.remove(i).close(); parameters.remove(parameter.getName(), parameter); } } } private Shape getParameterShape(String name, Shape[] inputShapes) { if (paramShapes == null) { PairList pairs = new PairList<>(); for (int i = 0; i < inputNames.size(); i++) { pairs.add(inputNames.get(i), inputShapes[i]); } paramShapes = symbol.inferShape(pairs); } if (paramShapes.containsKey(name)) { return paramShapes.get(name); } else { throw new IllegalArgumentException("Name " + name + " not found"); } } /** {@inheritDoc} */ @Override public void saveParameters(DataOutputStream os) throws IOException { os.writeByte(VERSION); String json = symbol.toJsonString(); // symbol size may go beyond os.writeUTF() size (65535) byte[] bytes = json.getBytes(StandardCharsets.UTF_8); os.writeInt(bytes.length); os.write(bytes); int size = inputNames.size(); os.writeInt(size); for (String name : inputNames) { os.writeUTF(name); } for (Parameter parameter : mxNetParams) { parameter.save(os); } } /** {@inheritDoc} */ @Override public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException { byte version = is.readByte(); if (version > VERSION) { throw new MalformedModelException("Unsupported encoding version: " + version); } if (version < VERSION && symbol == null) { throw new IllegalStateException( "Symbol is required for version 2, please use Model to load"); } if (version == VERSION) { int len 
= is.readInt(); byte[] bytes = new byte[len]; if (is.read(bytes) == -1) { throw new MalformedModelException("InputStream ends at symbol loading!"); } // init block only if it is not set symbol = Symbol.loadJson( (MxNDManager) manager, new String(bytes, StandardCharsets.UTF_8)); initBlock(); } int size = is.readInt(); for (int i = 0; i < size; ++i) { inputNames.add(is.readUTF()); } for (Parameter parameter : mxNetParams) { parameter.load(this.manager, is); } setInputNames(inputNames); } private void initBlock() { inputNames = new ArrayList<>(); String[] allNames = symbol.getAllNames(); mxNetParams = new ArrayList<>(allNames.length); Set auxNameSet = new HashSet<>(Arrays.asList(symbol.getAuxNames())); for (String name : allNames) { Parameter.Type type = inferType(name); boolean requireGrad = !auxNameSet.contains(name); mxNetParams.add( Parameter.builder() .setName(name) .setType(type) .optRequiresGrad(requireGrad) .build()); } first = true; } private static Parameter.Type inferType(String name) { if (name.endsWith("bias")) { return Parameter.Type.BIAS; } else if (name.endsWith("gamma")) { return Parameter.Type.GAMMA; } else if (name.endsWith("beta")) { return Parameter.Type.BETA; } else if (name.endsWith("moving_mean") || name.endsWith("running_mean")) { return Parameter.Type.RUNNING_MEAN; } else if (name.endsWith("moving_var") || name.endsWith("running_var")) { return Parameter.Type.RUNNING_VAR; } else if (name.endsWith("weight")) { return Parameter.Type.WEIGHT; } return Parameter.Type.OTHER; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy